Skip to main content

gemini_client_api/gemini/
ask.rs

1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5#[cfg(feature = "reqwest")]
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
10const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
/// The main client for interacting with the Gemini API.
///
/// Use `Gemini::new` or `Gemini::new_with_timeout` to create an instance.
/// You can configure various aspects of the request like model, system instructions,
/// generation config, safety settings, and tools using the provided builder-like methods.
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    // Underlying HTTP client; only present when the `reqwest` feature is enabled.
    #[cfg(feature = "reqwest")]
    client: Client,
    // API key sent as the `key` query parameter on every request.
    api_key: String,
    // Model variation name, e.g. "gemini-1.5-flash".
    model: String,
    // Optional system instructions included in every request body.
    sys_prompt: Option<SystemInstruction>,
    // Free-form generation config JSON; lazily initialized by `set_generation_config`.
    generation_config: Option<Value>,
    // Optional safety settings forwarded in the request body.
    safety_settings: Option<Vec<SafetySetting>>,
    // Tools (function declarations) made available to the model.
    tools: Option<Vec<Tool>>,
    // Configuration controlling how the model may use the tools.
    tool_config: Option<ToolConfig>,
}
29
30impl Gemini {
31    /// Creates a new `Gemini` client.
32    ///
33    /// # Arguments
34    /// * `api_key` - Your Gemini API key. Get one from [Google AI studio](https://aistudio.google.com/app/apikey).
35    /// * `model` - The model variation to use (e.g., "gemini-1.5-flash"). See [model variations](https://ai.google.dev/gemini-api/docs/models#model-variations).
36    /// * `sys_prompt` - Optional system instructions. See [system instructions](https://ai.google.dev/gemini-api/docs/text-generation#image-input).
37    #[cfg(feature = "reqwest")]
38    pub fn new(
39        api_key: impl Into<String>,
40        model: impl Into<String>,
41        sys_prompt: Option<SystemInstruction>,
42    ) -> Self {
43        Self {
44            client: Client::builder()
45                .timeout(Duration::from_secs(60))
46                .build()
47                .unwrap(),
48            api_key: api_key.into(),
49            model: model.into(),
50            sys_prompt,
51            generation_config: None,
52            safety_settings: None,
53            tools: None,
54            tool_config: None,
55        }
56    }
57    /// Creates a new `Gemini` client with a custom API timeout.
58    ///
59    /// # Arguments
60    /// * `api_key` - Your Gemini API key.
61    /// * `model` - The model variation to use.
62    /// * `sys_prompt` - Optional system instructions.
63    /// * `api_timeout` - Custom duration for request timeouts.
64    #[cfg(feature = "reqwest")]
65    pub fn new_with_timeout(
66        api_key: impl Into<String>,
67        model: impl Into<String>,
68        sys_prompt: Option<SystemInstruction>,
69        api_timeout: Duration,
70    ) -> Self {
71        Self {
72            client: Client::builder().timeout(api_timeout).build().unwrap(),
73            api_key: api_key.into(),
74            model: model.into(),
75            sys_prompt,
76            generation_config: None,
77            safety_settings: None,
78            tools: None,
79            tool_config: None,
80        }
81    }
82    /// Returns a mutable reference to the generation configuration.
83    /// If not already set, initializes it to an empty object.
84    ///
85    /// See [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig) for schema details.
86    pub fn set_generation_config(&mut self) -> &mut Value {
87        if let None = self.generation_config {
88            self.generation_config = Some(json!({}));
89        }
90        self.generation_config.as_mut().unwrap()
91    }
92    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
93        self.tool_config = Some(config);
94        self
95    }
96    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
97        if let Value::Object(map) = self.set_generation_config() {
98            if let Ok(thinking_value) = serde_json::to_value(config) {
99                map.insert("thinking_config".to_string(), thinking_value);
100            }
101        }
102        self
103    }
104    pub fn set_model(mut self, model: impl Into<String>) -> Self {
105        self.model = model.into();
106        self
107    }
108    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
109        self.sys_prompt = sys_prompt;
110        self
111    }
112    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
113        self.safety_settings = settings;
114        self
115    }
116    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
117        self.api_key = api_key.into();
118        self
119    }
120    /// Sets the response format to JSON mode with a specific schema.
121    ///
122    /// To use a Rust struct as a schema, decorate it with `#[gemini_schema]` and pass
123    /// `StructName::gemini_schema()`.
124    ///
125    /// # Arguments
126    /// * `schema` - The JSON schema for the response. See [Gemini Schema docs](https://ai.google.dev/api/caching#Schema).
127    pub fn set_json_mode(mut self, schema: Value) -> Self {
128        let config = self.set_generation_config();
129        config["response_mime_type"] = "application/json".into();
130        config["response_schema"] = schema.into();
131        self
132    }
133    pub fn unset_json_mode(mut self) -> Self {
134        if let Some(ref mut generation_config) = self.generation_config {
135            generation_config["response_schema"] = None::<Value>.into();
136            generation_config["response_mime_type"] = None::<Value>.into();
137        }
138        self
139    }
140    /// Sets the tools (functions) available to the model.
141    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
142        self.tools = Some(tools);
143        self
144    }
145    /// Removes all tools.
146    pub fn unset_tools(mut self) -> Self {
147        self.tools = None;
148        self
149    }
150
151    /// Sends a prompt to the model and waits for the full response.
152    ///
153    /// Updates the `session` history with the model's reply.
154    ///
155    /// # Errors
156    /// Returns `GeminiResponseError::NothingToRespond` if the last message in history is from the model.
157    #[cfg(feature = "reqwest")]
158    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
159        if session
160            .get_last_chat()
161            .is_some_and(|chat| *chat.role() == Role::Model)
162        {
163            return Err(GeminiResponseError::NothingToRespond);
164        }
165        let req_url = format!(
166            "{BASE_URL}/{}:generateContent?key={}",
167            self.model, self.api_key
168        );
169
170        let response = self
171            .client
172            .post(req_url)
173            .json(&GeminiRequestBody::new(
174                self.sys_prompt.as_ref(),
175                self.tools.as_deref(),
176                &session.get_history().as_slice(),
177                self.generation_config.as_ref(),
178                self.safety_settings.as_deref(),
179                self.tool_config.as_ref(),
180            ))
181            .send()
182            .await
183            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
184
185        if !response.status().is_success() {
186            let text = response
187                .text()
188                .await
189                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
190            return Err(GeminiResponseError::StatusNotOk(text));
191        }
192
193        let reply = GeminiResponse::new(response)
194            .await
195            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
196        session.update(&reply);
197        Ok(reply)
198    }
199    /// # Warning
200    /// You must read the response stream to get reply stored context in `session`.
201    /// `data_extractor` is used to extract data that you get as a stream of futures.
202    /// # Example
203    ///```ignore
204    ///use futures::StreamExt
205    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
206    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
207    ///.await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
208    ///
209    ///while let Some(response) = response_stream.next().await {
210    ///    if let Ok(response) = response {
211    ///        println!("{}", response);
212    ///    }
213    ///}
214    ///```
215    #[cfg(feature = "reqwest")]
216    pub async fn ask_as_stream_with_extractor<F, StreamType>(
217        &self,
218        session: Session,
219        data_extractor: F,
220    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
221    where
222        F: FnMut(&Session, GeminiResponse) -> StreamType,
223    {
224        if session
225            .get_last_chat()
226            .is_some_and(|chat| *chat.role() == Role::Model)
227        {
228            return Err((session, GeminiResponseError::NothingToRespond));
229        }
230        let req_url = format!(
231            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
232            self.model, self.api_key
233        );
234
235        let request = self
236            .client
237            .post(req_url)
238            .json(&GeminiRequestBody::new(
239                self.sys_prompt.as_ref(),
240                self.tools.as_deref(),
241                session.get_history().as_slice(),
242                self.generation_config.as_ref(),
243                self.safety_settings.as_deref(),
244                self.tool_config.as_ref(),
245            ))
246            .send()
247            .await;
248        let response = match request {
249            Ok(response) => response,
250            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
251        };
252
253        if !response.status().is_success() {
254            let text = match response.text().await {
255                Ok(response) => response,
256                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
257            };
258            return Err((session, GeminiResponseError::StatusNotOk(text.into())));
259        }
260
261        Ok(ResponseStream::new(
262            Box::new(response.bytes_stream()),
263            session,
264            data_extractor,
265        ))
266    }
267    /// Sends a prompt to the model and returns a stream of responses.
268    ///
269    /// # Warning
270    /// You must exhaust the response stream to ensure the `session` history is correctly updated.
271    ///
272    /// # Example
273    /// ```no_run
274    /// use futures::StreamExt;
275    /// # async fn run(gemini: gemini_client_api::gemini::ask::Gemini, session: gemini_client_api::gemini::types::sessions::Session) {
276    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
277    ///
278    /// while let Some(response) = response_stream.next().await {
279    ///     if let Ok(response) = response {
280    ///         println!("{}", response.get_text(""));
281    ///     }
282    /// }
283    /// # }
284    /// ```
285    #[cfg(feature = "reqwest")]
286    pub async fn ask_as_stream(
287        &self,
288        session: Session,
289    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
290        self.ask_as_stream_with_extractor(
291            session,
292            (|_, gemini_response| gemini_response)
293                as fn(&Session, GeminiResponse) -> GeminiResponse,
294        )
295        .await
296    }
297}