//! Gemini API client — `gemini_client_api/gemini/ask.rs`.
1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Base URL for the model endpoints (`generateContent` / `streamGenerateContent`).
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
/// The main client for interacting with the Gemini API.
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    // HTTP client used for every request issued by this struct.
    client: Client,
    // API key appended as the `key` query parameter on each request.
    api_key: String,
    // Model variation, e.g. "gemini-2.5-flash".
    model: String,
    // Optional system instructions sent with every request.
    sys_prompt: Option<SystemInstruction>,
    // Free-form `generationConfig` JSON object (schema, mime type, thinking config, ...).
    generation_config: Option<Value>,
    // Optional safety settings forwarded verbatim to the API.
    safety_settings: Option<Vec<SafetySetting>>,
    // Tools (function declarations) the model may call.
    tools: Option<Vec<Tool>>,
    // Configuration controlling how the model uses the tools above.
    tool_config: Option<ToolConfig>,
    // Resource name of a cached content entry ("cachedContents/...") to reuse, if any.
    cached_content: Option<String>,
}
25
26impl Gemini {
27    /// Creates a new `Gemini` client.
28    ///
29    /// # Arguments
30    /// * `api_key` - Your Gemini API key. Get one from [Google AI studio](https://aistudio.google.com/app/apikey).
31    /// * `model` - The model variation to use (e.g., "gemini-2.5-flash"). See [model variations](https://ai.google.dev/gemini-api/docs/models#model-variations).
32    /// * `sys_prompt` - Optional system instructions. See [system instructions](https://ai.google.dev/gemini-api/docs/text-generation#image-input).
33    pub fn new(
34        api_key: impl Into<String>,
35        model: impl Into<String>,
36        sys_prompt: Option<SystemInstruction>,
37    ) -> Self {
38        Self {
39            client: Client::builder()
40                .timeout(Duration::from_secs(60))
41                .build()
42                .unwrap(),
43            api_key: api_key.into(),
44            model: model.into(),
45            sys_prompt,
46            generation_config: None,
47            safety_settings: None,
48            tools: None,
49            tool_config: None,
50            cached_content: None,
51        }
52    }
53    /// Creates a new `Gemini` client with a custom API timeout.
54    ///
55    /// # Arguments
56    /// * `api_key` - Your Gemini API key.
57    /// * `model` - The model variation to use.
58    /// * `sys_prompt` - Optional system instructions.
59    /// * `api_timeout` - Custom duration for request timeouts.
60    #[deprecated]
61    pub fn new_with_timeout(
62        api_key: impl Into<String>,
63        model: impl Into<String>,
64        sys_prompt: Option<SystemInstruction>,
65        api_timeout: Duration,
66    ) -> Self {
67        Self {
68            client: Client::builder().timeout(api_timeout).build().unwrap(),
69            api_key: api_key.into(),
70            model: model.into(),
71            sys_prompt,
72            generation_config: None,
73            safety_settings: None,
74            tools: None,
75            tool_config: None,
76            cached_content: None,
77        }
78    }
79    /// Creates a new `Gemini` client with a custom API reqwest::Client.
80    ///
81    /// # Arguments
82    /// * `api_key` - Your Gemini API key.
83    /// * `model` - The model variation to use.
84    /// * `sys_prompt` - Optional system instructions.
85    /// * `client` - reqwest::Client to request gemini API.
86    pub fn new_with_client(
87        api_key: impl Into<String>,
88        model: impl Into<String>,
89        sys_prompt: Option<SystemInstruction>,
90        client: Client,
91    ) -> Self {
92        Self {
93            client,
94            api_key: api_key.into(),
95            model: model.into(),
96            sys_prompt,
97            generation_config: None,
98            safety_settings: None,
99            tools: None,
100            tool_config: None,
101            cached_content: None,
102        }
103    }
104    /// Returns a mutable reference to the generation configuration.
105    /// If not already set, initializes it to an empty object.
106    ///
107    /// See [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig) for schema details.
108    pub fn set_generation_config(&mut self) -> &mut Value {
109        if let None = self.generation_config {
110            self.generation_config = Some(json!({}));
111        }
112        self.generation_config.as_mut().unwrap()
113    }
114    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
115        self.tool_config = Some(config);
116        self
117    }
118    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
119        if let Value::Object(map) = self.set_generation_config() {
120            if let Ok(thinking_value) = serde_json::to_value(config) {
121                map.insert("thinking_config".to_string(), thinking_value);
122            }
123        }
124        self
125    }
126    pub fn set_model(mut self, model: impl Into<String>) -> Self {
127        self.model = model.into();
128        self
129    }
130    /// # Warning
131    /// Changing sys_prompt in middle of a conversation can confuse the model.
132    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
133        self.sys_prompt = sys_prompt;
134        self
135    }
136    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
137        self.safety_settings = settings;
138        self
139    }
140    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
141        self.api_key = api_key.into();
142        self
143    }
144    /// Sets the response format to JSON mode with a specific schema.
145    ///
146    /// To use a Rust struct as a schema, decorate it with `#[gemini_schema]` and pass
147    /// `StructName::gemini_schema()`.
148    ///
149    /// # Arguments
150    /// * `schema` - The JSON schema for the response. See [Gemini Schema docs](https://ai.google.dev/api/caching#Schema).
151    pub fn set_json_mode(mut self, schema: Value) -> Self {
152        let config = self.set_generation_config();
153        config["response_mime_type"] = "application/json".into();
154        config["response_schema"] = schema.into();
155        self
156    }
157    pub fn remove_json_mode(mut self) -> Self {
158        if let Some(ref mut generation_config) = self.generation_config {
159            generation_config["response_schema"] = None::<Value>.into();
160            generation_config["response_mime_type"] = None::<Value>.into();
161        }
162        self
163    }
164    /// Sets the tools (functions) available to the model.
165    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
166        self.tools = Some(tools);
167        self
168    }
169    /// Removes all tools.
170    pub fn remove_tools(mut self) -> Self {
171        self.tools = None;
172        self
173    }
174    pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
175        self.cached_content = Some(name.into());
176        self
177    }
178    pub fn remove_cached_content(mut self) -> Self {
179        self.cached_content = None;
180        self
181    }
182
183    // Cache management methods
184
185    pub async fn create_cache(
186        &self,
187        cached_content: &CachedContent,
188    ) -> Result<CachedContent, GeminiResponseError> {
189        let req_url = format!(
190            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
191            self.api_key
192        );
193
194        let response = self
195            .client
196            .post(req_url)
197            .json(cached_content)
198            .send()
199            .await
200            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
201
202        if !response.status().is_success() {
203            let error = response
204                .json()
205                .await
206                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
207            return Err(GeminiResponseError::StatusNotOk(error));
208        }
209
210        let cached_content: CachedContent = response
211            .json()
212            .await
213            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
214        Ok(cached_content)
215    }
216
217    pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
218        let req_url = format!(
219            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
220            self.api_key
221        );
222
223        let response = self
224            .client
225            .get(req_url)
226            .send()
227            .await
228            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
229
230        if !response.status().is_success() {
231            let error = response
232                .json()
233                .await
234                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
235            return Err(GeminiResponseError::StatusNotOk(error));
236        }
237
238        let list: CachedContentList = response
239            .json()
240            .await
241            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
242        Ok(list)
243    }
244
245    pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
246        let req_url = format!(
247            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
248            name, self.api_key
249        );
250
251        let response = self
252            .client
253            .get(req_url)
254            .send()
255            .await
256            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
257
258        if !response.status().is_success() {
259            let error = response
260                .json()
261                .await
262                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
263            return Err(GeminiResponseError::StatusNotOk(error));
264        }
265
266        let cached_content: CachedContent = response
267            .json()
268            .await
269            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
270        Ok(cached_content)
271    }
272
273    pub async fn update_cache(
274        &self,
275        name: &str,
276        update: &CachedContentUpdate,
277    ) -> Result<CachedContent, GeminiResponseError> {
278        let req_url = format!(
279            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
280            name, self.api_key
281        );
282
283        let response = self
284            .client
285            .patch(req_url)
286            .json(update)
287            .send()
288            .await
289            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
290
291        if !response.status().is_success() {
292            let error = response
293                .json()
294                .await
295                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
296            return Err(GeminiResponseError::StatusNotOk(error));
297        }
298
299        let cached_content: CachedContent = response
300            .json()
301            .await
302            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
303        Ok(cached_content)
304    }
305
306    pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
307        let req_url = format!(
308            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
309            name, self.api_key
310        );
311
312        let response = self
313            .client
314            .delete(req_url)
315            .send()
316            .await
317            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
318
319        if !response.status().is_success() {
320            let error = response
321                .json()
322                .await
323                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
324            return Err(GeminiResponseError::StatusNotOk(error));
325        }
326
327        Ok(())
328    }
329
330    /// Sends a prompt to the model and waits for the full response.
331    ///
332    /// Updates the `session` history with the model's reply.
333    ///
334    /// # Errors
335    /// Returns `GeminiResponseError::NothingToRespond` if the last message in history is from the model.
336    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
337        if session
338            .get_last_chat()
339            .is_some_and(|chat| *chat.role() == Role::Model)
340        {
341            return Err(GeminiResponseError::NothingToRespond);
342        }
343        let req_url = format!(
344            "{BASE_URL}/{}:generateContent?key={}",
345            self.model, self.api_key
346        );
347
348        let response = self
349            .client
350            .post(req_url)
351            .json(&GeminiRequestBody::new(
352                self.sys_prompt.as_ref(),
353                self.tools.as_deref(),
354                &session.get_history().as_slice(),
355                self.generation_config.as_ref(),
356                self.safety_settings.as_deref(),
357                self.tool_config.as_ref(),
358                self.cached_content.clone(),
359            ))
360            .send()
361            .await
362            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
363
364        if !response.status().is_success() {
365            let error = response
366                .json()
367                .await
368                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
369            return Err(GeminiResponseError::StatusNotOk(error));
370        }
371
372        let reply = GeminiResponse::new(response)
373            .await
374            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
375        session.update(&reply);
376        Ok(reply)
377    }
378    /// # Warning
379    /// You must read the response stream to get reply stored context in `session`.
380    /// `data_extractor` is used to extract data that you get as a stream of futures.
381    /// # Example
382    ///```ignore
383    ///use futures::StreamExt
384    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
385    ///     |session, _gemini_response| session.get_last_chat().unwrap().get_text_no_think("\n"))
386    ///    .await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
387    ///while let Some(response) = response_stream.next().await {
388    ///    println!("{}", response);
389    ///}
390    ///```
391    pub async fn ask_as_stream_with_extractor<F, StreamType>(
392        &self,
393        session: Session,
394        data_extractor: F,
395    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
396    where
397        F: FnMut(&Session, GeminiResponse) -> StreamType,
398    {
399        if session
400            .get_last_chat()
401            .is_some_and(|chat| *chat.role() == Role::Model)
402        {
403            return Err((session, GeminiResponseError::NothingToRespond));
404        }
405        let req_url = format!(
406            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
407            self.model, self.api_key
408        );
409
410        let request = self
411            .client
412            .post(req_url)
413            .json(&GeminiRequestBody::new(
414                self.sys_prompt.as_ref(),
415                self.tools.as_deref(),
416                session.get_history().as_slice(),
417                self.generation_config.as_ref(),
418                self.safety_settings.as_deref(),
419                self.tool_config.as_ref(),
420                self.cached_content.clone(),
421            ))
422            .send()
423            .await;
424        let response = match request {
425            Ok(response) => response,
426            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
427        };
428
429        if !response.status().is_success() {
430            let error = match response.json().await {
431                Ok(response) => response,
432                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
433            };
434            return Err((session, GeminiResponseError::StatusNotOk(error)));
435        }
436
437        Ok(ResponseStream::new(
438            Box::new(response.bytes_stream()),
439            session,
440            data_extractor,
441        ))
442    }
443    /// Sends a prompt to the model and returns a stream of responses.
444    ///
445    /// # Warning
446    /// You must exhaust the response stream to ensure the `session` history is correctly updated.
447    ///
448    /// # Example
449    /// ```no_run
450    /// use futures::StreamExt;
451    /// # async fn run(gemini: gemini_client_api::gemini::ask::Gemini, session: gemini_client_api::gemini::types::sessions::Session) {
452    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
453    ///
454    /// while let Some(response) = response_stream.next().await {
455    ///     if let Ok(response) = response {
456    ///         println!("{}", response.get_chat().get_text_no_think("\n"));
457    ///     }
458    /// }
459    /// # }
460    /// ```
461    pub async fn ask_as_stream(
462        &self,
463        session: Session,
464    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
465        self.ask_as_stream_with_extractor(
466            session,
467            (|_, gemini_response| gemini_response)
468                as fn(&Session, GeminiResponse) -> GeminiResponse,
469        )
470        .await
471    }
472}