gemini_client_api/gemini/ask.rs

1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
9const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13    client: Client,
14    api_key: String,
15    model: String,
16    sys_prompt: Option<SystemInstruction>,
17    generation_config: Option<Value>,
18    safety_settings: Option<Vec<SafetySetting>>,
19    tools: Option<Vec<Tool>>,
20}
21impl Gemini {
22    /// # Arguments
23    /// `api_key` get one from [Google AI studio](https://aistudio.google.com/app/apikey)
24    /// `model` should be of those mentioned [here](https://ai.google.dev/gemini-api/docs/models#model-variations) in bold black color
25    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
26    pub fn new(
27        api_key: impl Into<String>,
28        model: impl Into<String>,
29        sys_prompt: Option<SystemInstruction>,
30    ) -> Self {
31        Self {
32            client: Client::builder()
33                .timeout(Duration::from_secs(60))
34                .build()
35                .unwrap(),
36            api_key: api_key.into(),
37            model: model.into(),
38            sys_prompt,
39            generation_config: None,
40            safety_settings: None,
41            tools: None,
42        }
43    }
44    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
45    pub fn new_with_timeout(
46        api_key: impl Into<String>,
47        model: impl Into<String>,
48        sys_prompt: Option<SystemInstruction>,
49        api_timeout: Duration,
50    ) -> Self {
51        Self {
52            client: Client::builder().timeout(api_timeout).build().unwrap(),
53            api_key: api_key.into(),
54            model: model.into(),
55            sys_prompt,
56            generation_config: None,
57            safety_settings: None,
58            tools: None,
59        }
60    }
61    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
62    pub fn set_generation_config(&mut self) -> &mut Value {
63        if let None = self.generation_config {
64            self.generation_config = Some(json!({}));
65        }
66        self.generation_config.as_mut().unwrap()
67    }
68    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
69        if let Value::Object(map) = self.set_generation_config() {
70            if let Ok(thinking_value) = serde_json::to_value(config) {
71                map.insert("thinking_config".to_string(), thinking_value);
72            }
73        }
74        self
75    }
76    pub fn set_model(mut self, model: impl Into<String>) -> Self {
77        self.model = model.into();
78        self
79    }
80    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
81        self.sys_prompt = sys_prompt;
82        self
83    }
84    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
85        self.safety_settings = settings;
86        self
87    }
88    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
89        self.api_key = api_key.into();
90        self
91    }
92    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
93    /// To verify your schema visit [here](https://aistudio.google.com/prompts/new_chat):
94    /// - Under tools, toggle on Structured output
95    /// - Click Edit
96    /// - Here you can create schema with `Visual Editor` or `Code Editor` with error detection
97    pub fn set_json_mode(mut self, schema: Value) -> Self {
98        let config = self.set_generation_config();
99        config["response_mime_type"] = "application/json".into();
100        config["response_schema"] = schema.into();
101        self
102    }
103    pub fn unset_json_mode(mut self) -> Self {
104        if let Some(ref mut generation_config) = self.generation_config {
105            generation_config["response_schema"] = None::<Value>.into();
106            generation_config["response_mime_type"] = None::<Value>.into();
107        }
108        self
109    }
110    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
111        self.tools = Some(tools);
112        self
113    }
114    pub fn unset_tools(mut self) -> Self {
115        self.tools = None;
116        self
117    }
118
119    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
120        let req_url = format!(
121            "{BASE_URL}/{}:generateContent?key={}",
122            self.model, self.api_key
123        );
124
125        let response = self
126            .client
127            .post(req_url)
128            .json(&GeminiRequestBody::new(
129                self.sys_prompt.as_ref(),
130                self.tools.as_deref(),
131                &session.get_history().as_slice(),
132                self.generation_config.as_ref(),
133                self.safety_settings.as_deref(),
134            ))
135            .send()
136            .await
137            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
138
139        if !response.status().is_success() {
140            let text = response
141                .text()
142                .await
143                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
144            return Err(GeminiResponseError::StatusNotOk(text));
145        }
146
147        let reply = GeminiResponse::new(response)
148            .await
149            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
150        session.update(&reply);
151        Ok(reply)
152    }
153    /// # Warning
154    /// You must read the response stream to get reply stored context in `session`.
155    /// `data_extractor` is used to extract data that you get as a stream of futures.
156    /// # Example
157    ///```ignore
158    ///use futures::StreamExt
159    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
160    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
161    ///.await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
162    ///
163    ///while let Some(response) = response_stream.next().await {
164    ///    if let Ok(response) = response {
165    ///        println!("{}", response);
166    ///    }
167    ///}
168    ///```
169    pub async fn ask_as_stream_with_extractor<F, StreamType>(
170        &self,
171        session: Session,
172        data_extractor: F,
173    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
174    where
175        F: FnMut(&Session, GeminiResponse) -> StreamType,
176    {
177        let req_url = format!(
178            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
179            self.model, self.api_key
180        );
181
182        let request = self
183            .client
184            .post(req_url)
185            .json(&GeminiRequestBody::new(
186                self.sys_prompt.as_ref(),
187                self.tools.as_deref(),
188                session.get_history().as_slice(),
189                self.generation_config.as_ref(),
190                self.safety_settings.as_deref(),
191            ))
192            .send()
193            .await;
194        let response = match request {
195            Ok(response) => response,
196            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
197        };
198
199        if !response.status().is_success() {
200            let text = match response.text().await {
201                Ok(response) => response,
202                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
203            };
204            return Err((session, GeminiResponseError::StatusNotOk(text.into())));
205        }
206
207        Ok(ResponseStream::new(
208            Box::new(response.bytes_stream()),
209            session,
210            data_extractor,
211        ))
212    }
213    /// # Warning
214    /// You must read the response stream to get reply stored context in `session`.
215    /// # Example
216    ///```ignore
217    ///use futures::StreamExt
218    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
219    ///
220    ///while let Some(response) = response_stream.next().await {
221    ///    if let Ok(response) = response {
222    ///        println!("{}", response.get_text(""));
223    ///    }
224    ///}
225    ///```
226    pub async fn ask_as_stream(
227        &self,
228        session: Session,
229    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
230        self.ask_as_stream_with_extractor(
231            session,
232            (|_, gemini_response| gemini_response)
233                as fn(&Session, GeminiResponse) -> GeminiResponse,
234        )
235        .await
236    }
237}