Skip to main content

gemini_client_api/gemini/
ask.rs

1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Base endpoint for model-scoped calls (`:generateContent`, `:streamGenerateContent`).
/// Cache management uses the sibling `…/v1beta/cachedContents` path instead.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
/// The main client for interacting with the Gemini API.
///
/// Use `Gemini::new` or `Gemini::new_with_client` to create an instance.
/// You can configure various aspects of the request like model, system instructions,
/// generation config, safety settings, and tools using the provided builder-like methods.
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    // Underlying HTTP client; `new` builds one with a 60 s timeout.
    client: Client,
    // API key, sent as the `key` query parameter on every request.
    api_key: String,
    // Model variation, e.g. "gemini-2.5-flash".
    model: String,
    // Optional system instructions included in every request body.
    sys_prompt: Option<SystemInstruction>,
    // Free-form generationConfig JSON; `None` until first configured.
    generation_config: Option<Value>,
    safety_settings: Option<Vec<SafetySetting>>,
    tools: Option<Vec<Tool>>,
    tool_config: Option<ToolConfig>,
    // Name of a server-side cached-content resource to attach, if any.
    cached_content: Option<String>,
}
29
30impl Gemini {
31    /// Creates a new `Gemini` client.
32    ///
33    /// # Arguments
34    /// * `api_key` - Your Gemini API key. Get one from [Google AI studio](https://aistudio.google.com/app/apikey).
35    /// * `model` - The model variation to use (e.g., "gemini-2.5-flash"). See [model variations](https://ai.google.dev/gemini-api/docs/models#model-variations).
36    /// * `sys_prompt` - Optional system instructions. See [system instructions](https://ai.google.dev/gemini-api/docs/text-generation#image-input).
37    pub fn new(
38        api_key: impl Into<String>,
39        model: impl Into<String>,
40        sys_prompt: Option<SystemInstruction>,
41    ) -> Self {
42        Self {
43            client: Client::builder()
44                .timeout(Duration::from_secs(60))
45                .build()
46                .unwrap(),
47            api_key: api_key.into(),
48            model: model.into(),
49            sys_prompt,
50            generation_config: None,
51            safety_settings: None,
52            tools: None,
53            tool_config: None,
54            cached_content: None,
55        }
56    }
57    /// Creates a new `Gemini` client with a custom API timeout.
58    ///
59    /// # Arguments
60    /// * `api_key` - Your Gemini API key.
61    /// * `model` - The model variation to use.
62    /// * `sys_prompt` - Optional system instructions.
63    /// * `api_timeout` - Custom duration for request timeouts.
64    #[deprecated]
65    pub fn new_with_timeout(
66        api_key: impl Into<String>,
67        model: impl Into<String>,
68        sys_prompt: Option<SystemInstruction>,
69        api_timeout: Duration,
70    ) -> Self {
71        Self {
72            client: Client::builder().timeout(api_timeout).build().unwrap(),
73            api_key: api_key.into(),
74            model: model.into(),
75            sys_prompt,
76            generation_config: None,
77            safety_settings: None,
78            tools: None,
79            tool_config: None,
80            cached_content: None,
81        }
82    }
83    /// Creates a new `Gemini` client with a custom API reqwest::Client.
84    ///
85    /// # Arguments
86    /// * `api_key` - Your Gemini API key.
87    /// * `model` - The model variation to use.
88    /// * `sys_prompt` - Optional system instructions.
89    /// * `client` - reqwest::Client to request gemini API.
90    pub fn new_with_client(
91        api_key: impl Into<String>,
92        model: impl Into<String>,
93        sys_prompt: Option<SystemInstruction>,
94        client: Client,
95    ) -> Self {
96        Self {
97            client,
98            api_key: api_key.into(),
99            model: model.into(),
100            sys_prompt,
101            generation_config: None,
102            safety_settings: None,
103            tools: None,
104            tool_config: None,
105            cached_content: None,
106        }
107    }
108    /// Returns a mutable reference to the generation configuration.
109    /// If not already set, initializes it to an empty object.
110    ///
111    /// See [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig) for schema details.
112    pub fn set_generation_config(&mut self) -> &mut Value {
113        if let None = self.generation_config {
114            self.generation_config = Some(json!({}));
115        }
116        self.generation_config.as_mut().unwrap()
117    }
118    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
119        self.tool_config = Some(config);
120        self
121    }
122    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
123        if let Value::Object(map) = self.set_generation_config() {
124            if let Ok(thinking_value) = serde_json::to_value(config) {
125                map.insert("thinking_config".to_string(), thinking_value);
126            }
127        }
128        self
129    }
130    pub fn set_model(mut self, model: impl Into<String>) -> Self {
131        self.model = model.into();
132        self
133    }
134    /// # Warning
135    /// Changing sys_prompt in middle of a conversation can confuse the model.
136    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
137        self.sys_prompt = sys_prompt;
138        self
139    }
140    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
141        self.safety_settings = settings;
142        self
143    }
144    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
145        self.api_key = api_key.into();
146        self
147    }
148    /// Sets the response format to JSON mode with a specific schema.
149    ///
150    /// To use a Rust struct as a schema, decorate it with `#[gemini_schema]` and pass
151    /// `StructName::gemini_schema()`.
152    ///
153    /// # Arguments
154    /// * `schema` - The JSON schema for the response. See [Gemini Schema docs](https://ai.google.dev/api/caching#Schema).
155    pub fn set_json_mode(mut self, schema: Value) -> Self {
156        let config = self.set_generation_config();
157        config["response_mime_type"] = "application/json".into();
158        config["response_schema"] = schema.into();
159        self
160    }
161    pub fn remove_json_mode(mut self) -> Self {
162        if let Some(ref mut generation_config) = self.generation_config {
163            generation_config["response_schema"] = None::<Value>.into();
164            generation_config["response_mime_type"] = None::<Value>.into();
165        }
166        self
167    }
168    /// Sets the tools (functions) available to the model.
169    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
170        self.tools = Some(tools);
171        self
172    }
173    /// Removes all tools.
174    pub fn remove_tools(mut self) -> Self {
175        self.tools = None;
176        self
177    }
178    pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
179        self.cached_content = Some(name.into());
180        self
181    }
182    pub fn remove_cached_content(mut self) -> Self {
183        self.cached_content = None;
184        self
185    }
186
187    // Cache management methods
188
189    pub async fn create_cache(
190        &self,
191        cached_content: &CachedContent,
192    ) -> Result<CachedContent, GeminiResponseError> {
193        let req_url = format!(
194            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
195            self.api_key
196        );
197
198        let response = self
199            .client
200            .post(req_url)
201            .json(cached_content)
202            .send()
203            .await
204            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
205
206        if !response.status().is_success() {
207            let error = response
208                .json()
209                .await
210                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
211            return Err(GeminiResponseError::StatusNotOk(error));
212        }
213
214        let cached_content: CachedContent = response
215            .json()
216            .await
217            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
218        Ok(cached_content)
219    }
220
221    pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
222        let req_url = format!(
223            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
224            self.api_key
225        );
226
227        let response = self
228            .client
229            .get(req_url)
230            .send()
231            .await
232            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
233
234        if !response.status().is_success() {
235            let error = response
236                .json()
237                .await
238                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
239            return Err(GeminiResponseError::StatusNotOk(error));
240        }
241
242        let list: CachedContentList = response
243            .json()
244            .await
245            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
246        Ok(list)
247    }
248
249    pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
250        let req_url = format!(
251            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
252            name, self.api_key
253        );
254
255        let response = self
256            .client
257            .get(req_url)
258            .send()
259            .await
260            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
261
262        if !response.status().is_success() {
263            let error = response
264                .json()
265                .await
266                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
267            return Err(GeminiResponseError::StatusNotOk(error));
268        }
269
270        let cached_content: CachedContent = response
271            .json()
272            .await
273            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
274        Ok(cached_content)
275    }
276
277    pub async fn update_cache(
278        &self,
279        name: &str,
280        update: &CachedContentUpdate,
281    ) -> Result<CachedContent, GeminiResponseError> {
282        let req_url = format!(
283            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
284            name, self.api_key
285        );
286
287        let response = self
288            .client
289            .patch(req_url)
290            .json(update)
291            .send()
292            .await
293            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
294
295        if !response.status().is_success() {
296            let error = response
297                .json()
298                .await
299                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
300            return Err(GeminiResponseError::StatusNotOk(error));
301        }
302
303        let cached_content: CachedContent = response
304            .json()
305            .await
306            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
307        Ok(cached_content)
308    }
309
310    pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
311        let req_url = format!(
312            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
313            name, self.api_key
314        );
315
316        let response = self
317            .client
318            .delete(req_url)
319            .send()
320            .await
321            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
322
323        if !response.status().is_success() {
324            let error = response
325                .json()
326                .await
327                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
328            return Err(GeminiResponseError::StatusNotOk(error));
329        }
330
331        Ok(())
332    }
333
334    /// Sends a prompt to the model and waits for the full response.
335    ///
336    /// Updates the `session` history with the model's reply.
337    ///
338    /// # Errors
339    /// Returns `GeminiResponseError::NothingToRespond` if the last message in history is from the model.
340    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
341        if session
342            .get_last_chat()
343            .is_some_and(|chat| *chat.role() == Role::Model)
344        {
345            return Err(GeminiResponseError::NothingToRespond);
346        }
347        let req_url = format!(
348            "{BASE_URL}/{}:generateContent?key={}",
349            self.model, self.api_key
350        );
351
352        let response = self
353            .client
354            .post(req_url)
355            .json(&GeminiRequestBody::new(
356                self.sys_prompt.as_ref(),
357                self.tools.as_deref(),
358                &session.get_history().as_slice(),
359                self.generation_config.as_ref(),
360                self.safety_settings.as_deref(),
361                self.tool_config.as_ref(),
362                self.cached_content.clone(),
363            ))
364            .send()
365            .await
366            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
367
368        if !response.status().is_success() {
369            let error = response
370                .json()
371                .await
372                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
373            return Err(GeminiResponseError::StatusNotOk(error));
374        }
375
376        let reply = GeminiResponse::new(response)
377            .await
378            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
379        session.update(&reply);
380        Ok(reply)
381    }
    /// Streams the model's reply, mapping each chunk through `data_extractor`.
    ///
    /// # Warning
    /// You must read the response stream to get the reply stored into the
    /// `session` context — the `Session` is moved into the returned stream and
    /// updated as chunks are consumed.
    ///
    /// `data_extractor` maps each decoded `GeminiResponse` chunk (together with
    /// the session state so far) to the item type yielded by the stream.
    /// On failure the moved `Session` is handed back in the error tuple so the
    /// caller can retry without losing history.
    /// # Example
    ///```ignore
    ///use futures::StreamExt
    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
    ///     |session, _gemini_response| session.get_last_chat().unwrap().get_text_no_think("\n"))
    ///    .await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
    ///while let Some(response) = response_stream.next().await {
    ///    println!("{}", response);
    ///}
    ///```
    pub async fn ask_as_stream_with_extractor<F, StreamType>(
        &self,
        session: Session,
        data_extractor: F,
    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
    where
        F: FnMut(&Session, GeminiResponse) -> StreamType,
    {
        // Nothing to answer if the model spoke last.
        if session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() == Role::Model)
        {
            return Err((session, GeminiResponseError::NothingToRespond));
        }
        // `alt=sse` requests a server-sent-events body for incremental parsing.
        let req_url = format!(
            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
            self.model, self.api_key
        );

        let request = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
                self.cached_content.clone(),
            ))
            .send()
            .await;
        // Explicit `match` instead of `?` so the moved `session` can be
        // returned to the caller on every error path.
        let response = match request {
            Ok(response) => response,
            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
        };

        if !response.status().is_success() {
            let error = match response.json().await {
                Ok(response) => response,
                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
            };
            return Err((session, GeminiResponseError::StatusNotOk(error)));
        }

        // The session moves into the stream and is updated as bytes arrive.
        Ok(ResponseStream::new(
            Box::new(response.bytes_stream()),
            session,
            data_extractor,
        ))
    }
447    /// Sends a prompt to the model and returns a stream of responses.
448    ///
449    /// # Warning
450    /// You must exhaust the response stream to ensure the `session` history is correctly updated.
451    ///
452    /// # Example
453    /// ```no_run
454    /// use futures::StreamExt;
455    /// # async fn run(gemini: gemini_client_api::gemini::ask::Gemini, session: gemini_client_api::gemini::types::sessions::Session) {
456    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
457    ///
458    /// while let Some(response) = response_stream.next().await {
459    ///     if let Ok(response) = response {
460    ///         println!("{}", response.get_chat().get_text_no_think("\n"));
461    ///     }
462    /// }
463    /// # }
464    /// ```
465    pub async fn ask_as_stream(
466        &self,
467        session: Session,
468    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
469        self.ask_as_stream_with_extractor(
470            session,
471            (|_, gemini_response| gemini_response)
472                as fn(&Session, GeminiResponse) -> GeminiResponse,
473        )
474        .await
475    }
476}