// gemini_client_api/gemini/ask.rs

use std::fmt;
use std::time::Duration;

#[cfg(feature = "reqwest")]
use reqwest::Client;
use serde_json::{Value, json};

use super::error::GeminiResponseError;
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
9
/// Base URL of the Gemini REST `models` collection; requests are sent to
/// `{BASE_URL}/{model}:{method}`.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
12#[derive(Clone, Default, Debug)]
13pub struct Gemini {
14    #[cfg(feature = "reqwest")]
15    client: Client,
16    api_key: String,
17    model: String,
18    sys_prompt: Option<SystemInstruction>,
19    generation_config: Option<Value>,
20    safety_settings: Option<Vec<SafetySetting>>,
21    tools: Option<Vec<Tool>>,
22    tool_config: Option<ToolConfig>,
23}
24
25impl Gemini {
26    /// # Arguments
27    /// `api_key` get one from [Google AI studio](https://aistudio.google.com/app/apikey)
28    /// `model` should be of those mentioned [here](https://ai.google.dev/gemini-api/docs/models#model-variations) in bold black color
29    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
30    #[cfg(feature = "reqwest")]
31    pub fn new(
32        api_key: impl Into<String>,
33        model: impl Into<String>,
34        sys_prompt: Option<SystemInstruction>,
35    ) -> Self {
36        Self {
37            client: Client::builder()
38                .timeout(Duration::from_secs(60))
39                .build()
40                .unwrap(),
41            api_key: api_key.into(),
42            model: model.into(),
43            sys_prompt,
44            generation_config: None,
45            safety_settings: None,
46            tools: None,
47            tool_config: None,
48        }
49    }
50    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
51    #[cfg(feature = "reqwest")]
52    pub fn new_with_timeout(
53        api_key: impl Into<String>,
54        model: impl Into<String>,
55        sys_prompt: Option<SystemInstruction>,
56        api_timeout: Duration,
57    ) -> Self {
58        Self {
59            client: Client::builder().timeout(api_timeout).build().unwrap(),
60            api_key: api_key.into(),
61            model: model.into(),
62            sys_prompt,
63            generation_config: None,
64            safety_settings: None,
65            tools: None,
66            tool_config: None,
67        }
68    }
69    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
70    pub fn set_generation_config(&mut self) -> &mut Value {
71        if let None = self.generation_config {
72            self.generation_config = Some(json!({}));
73        }
74        self.generation_config.as_mut().unwrap()
75    }
76    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
77        self.tool_config = Some(config);
78        self
79    }
80    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
81        if let Value::Object(map) = self.set_generation_config() {
82            if let Ok(thinking_value) = serde_json::to_value(config) {
83                map.insert("thinking_config".to_string(), thinking_value);
84            }
85        }
86        self
87    }
88    pub fn set_model(mut self, model: impl Into<String>) -> Self {
89        self.model = model.into();
90        self
91    }
92    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
93        self.sys_prompt = sys_prompt;
94        self
95    }
96    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
97        self.safety_settings = settings;
98        self
99    }
100    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
101        self.api_key = api_key.into();
102        self
103    }
104    ///To use struct as schema write #[gemini_schema] above struct then pass
105    /// `StructName::gemini_schema()`
106    /// OR
107    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
108    /// To verify your schema visit [here](https://aistudio.google.com/prompts/new_chat):
109    /// - Under tools, toggle on Structured output
110    /// - Click Edit
111    /// - Here you can create schema with `Visual Editor` or `Code Editor` with error detection
112    pub fn set_json_mode(mut self, schema: Value) -> Self {
113        let config = self.set_generation_config();
114        config["response_mime_type"] = "application/json".into();
115        config["response_schema"] = schema.into();
116        self
117    }
118    pub fn unset_json_mode(mut self) -> Self {
119        if let Some(ref mut generation_config) = self.generation_config {
120            generation_config["response_schema"] = None::<Value>.into();
121            generation_config["response_mime_type"] = None::<Value>.into();
122        }
123        self
124    }
125    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
126        self.tools = Some(tools);
127        self
128    }
129    pub fn unset_tools(mut self) -> Self {
130        self.tools = None;
131        self
132    }
133
134    #[cfg(feature = "reqwest")]
135    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
136        if !session
137            .get_last_chat()
138            .is_some_and(|chat| *chat.role() != Role::Model)
139        {
140            return Err(GeminiResponseError::NothingToRespond);
141        }
142        let req_url = format!(
143            "{BASE_URL}/{}:generateContent?key={}",
144            self.model, self.api_key
145        );
146
147        let response = self
148            .client
149            .post(req_url)
150            .json(&GeminiRequestBody::new(
151                self.sys_prompt.as_ref(),
152                self.tools.as_deref(),
153                &session.get_history().as_slice(),
154                self.generation_config.as_ref(),
155                self.safety_settings.as_deref(),
156                self.tool_config.as_ref(),
157            ))
158            .send()
159            .await
160            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
161
162        if !response.status().is_success() {
163            let text = response
164                .text()
165                .await
166                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
167            return Err(GeminiResponseError::StatusNotOk(text));
168        }
169
170        let reply = GeminiResponse::new(response)
171            .await
172            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
173        session.update(&reply);
174        Ok(reply)
175    }
176    /// # Warning
177    /// You must read the response stream to get reply stored context in `session`.
178    /// `data_extractor` is used to extract data that you get as a stream of futures.
179    /// # Example
180    ///```ignore
181    ///use futures::StreamExt
182    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
183    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
184    ///.await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
185    ///
186    ///while let Some(response) = response_stream.next().await {
187    ///    if let Ok(response) = response {
188    ///        println!("{}", response);
189    ///    }
190    ///}
191    ///```
192    #[cfg(feature = "reqwest")]
193    pub async fn ask_as_stream_with_extractor<F, StreamType>(
194        &self,
195        session: Session,
196        data_extractor: F,
197    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
198    where
199        F: FnMut(&Session, GeminiResponse) -> StreamType,
200    {
201        if !session
202            .get_last_chat()
203            .is_some_and(|chat| *chat.role() != Role::Model)
204        {
205            return Err((session, GeminiResponseError::NothingToRespond));
206        }
207        let req_url = format!(
208            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
209            self.model, self.api_key
210        );
211
212        let request = self
213            .client
214            .post(req_url)
215            .json(&GeminiRequestBody::new(
216                self.sys_prompt.as_ref(),
217                self.tools.as_deref(),
218                session.get_history().as_slice(),
219                self.generation_config.as_ref(),
220                self.safety_settings.as_deref(),
221                self.tool_config.as_ref(),
222            ))
223            .send()
224            .await;
225        let response = match request {
226            Ok(response) => response,
227            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
228        };
229
230        if !response.status().is_success() {
231            let text = match response.text().await {
232                Ok(response) => response,
233                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
234            };
235            return Err((session, GeminiResponseError::StatusNotOk(text.into())));
236        }
237
238        Ok(ResponseStream::new(
239            Box::new(response.bytes_stream()),
240            session,
241            data_extractor,
242        ))
243    }
244    /// # Warning
245    /// You must read the response stream to get reply stored context in `session`.
246    /// # Example
247    ///```ignore
248    ///use futures::StreamExt
249    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
250    ///
251    ///while let Some(response) = response_stream.next().await {
252    ///    if let Ok(response) = response {
253    ///        println!("{}", response.get_text(""));
254    ///    }
255    ///}
256    ///```
257    #[cfg(feature = "reqwest")]
258    pub async fn ask_as_stream(
259        &self,
260        session: Session,
261    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
262        self.ask_as_stream_with_extractor(
263            session,
264            (|_, gemini_response| gemini_response)
265                as fn(&Session, GeminiResponse) -> GeminiResponse,
266        )
267        .await
268    }
269}