// gemini_client_api/gemini/ask.rs

use super::error::GeminiResponseError;
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

#[derive(Clone, Default, Debug)]
pub struct Gemini {
    client: Client,
    api_key: String,
    model: String,
    sys_prompt: Option<SystemInstruction>,
    generation_config: Option<Value>,
    safety_settings: Option<Vec<SafetySetting>>,
    tools: Option<Vec<Tool>>,
    tool_config: Option<ToolConfig>,
}
impl Gemini {
    /// # Arguments
    /// * `api_key`: get one from [Google AI Studio](https://aistudio.google.com/app/apikey)
    /// * `model`: one of the models listed [here](https://ai.google.dev/gemini-api/docs/models#model-variations), shown in bold
    /// * `sys_prompt`: should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
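    ///
    /// # Example
    /// ```ignore
    /// // A minimal sketch; the model name here is only an example.
    /// let gemini = Gemini::new("YOUR_API_KEY", "gemini-2.0-flash", None);
    /// ```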
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
    ) -> Self {
        Self {
            // 60-second default request timeout; use `new_with_timeout` to customize.
            client: Client::builder()
                .timeout(Duration::from_secs(60))
                .build()
                .unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        }
    }
    /// Like [`Gemini::new`], but with a caller-supplied request timeout.
    /// `sys_prompt` should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
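    ///
    /// # Example
    /// ```ignore
    /// // Sketch: allow up to two minutes per request.
    /// let gemini =
    ///     Gemini::new_with_timeout("YOUR_API_KEY", "gemini-2.0-flash", None, Duration::from_secs(120));
    /// ```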
    pub fn new_with_timeout(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
        api_timeout: Duration,
    ) -> Self {
        Self {
            client: Client::builder().timeout(api_timeout).build().unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        }
    }
    /// The generation config schema should follow the [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
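    ///
    /// # Example
    /// ```ignore
    /// // Sketch: set a hypothetical sampling field through the mutable handle;
    /// // valid fields are defined by the docs linked above.
    /// gemini.set_generation_config()["temperature"] = json!(0.7);
    /// ```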
    pub fn set_generation_config(&mut self) -> &mut Value {
        // Lazily initialize the config to an empty JSON object.
        self.generation_config.get_or_insert_with(|| json!({}))
    }
    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
        self.tool_config = Some(config);
        self
    }
    /// Stores `config` in the generation config under the `thinking_config` key.
    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
        if let Value::Object(map) = self.set_generation_config() {
            if let Ok(thinking_value) = serde_json::to_value(config) {
                map.insert("thinking_config".to_string(), thinking_value);
            }
        }
        self
    }
    pub fn set_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
        self.sys_prompt = sys_prompt;
        self
    }
    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
        self.safety_settings = settings;
        self
    }
    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = api_key.into();
        self
    }
    /// `schema` should follow the [Gemini Schema spec](https://ai.google.dev/api/caching#Schema).
    /// To verify your schema, visit [Google AI Studio](https://aistudio.google.com/prompts/new_chat):
    /// - Under tools, toggle on Structured output
    /// - Click Edit
    /// - Build the schema in the `Visual Editor`, or in the `Code Editor` with error detection
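    ///
    /// # Example
    /// ```ignore
    /// // Sketch with a hypothetical schema: an object holding one string field.
    /// let gemini = gemini.set_json_mode(json!({
    ///     "type": "object",
    ///     "properties": { "answer": { "type": "string" } }
    /// }));
    /// ```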
    pub fn set_json_mode(mut self, schema: Value) -> Self {
        let config = self.set_generation_config();
        config["response_mime_type"] = "application/json".into();
        config["response_schema"] = schema;
        self
    }
    pub fn unset_json_mode(mut self) -> Self {
        // Remove the keys entirely rather than leaving explicit nulls in the request body.
        if let Some(Value::Object(ref mut map)) = self.generation_config {
            map.remove("response_schema");
            map.remove("response_mime_type");
        }
        self
    }
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }
    pub fn unset_tools(mut self) -> Self {
        self.tools = None;
        self
    }

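    /// Sends the whole session history and waits for the complete reply.
    /// On success the reply is also stored back into `session`.
    /// # Example
    /// ```ignore
    /// // Sketch: the session must end with a non-model (e.g. user) message.
    /// let reply = gemini.ask(&mut session).await?;
    /// println!("{}", reply.get_text(""));
    /// ```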
    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
        // There must be a pending non-model message to respond to.
        if !session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() != Role::Model)
        {
            return Err(GeminiResponseError::NothingToRespond);
        }
        let req_url = format!(
            "{BASE_URL}/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let response = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
            ))
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let reply = GeminiResponse::new(response)
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        session.update(&reply);
        Ok(reply)
    }
    /// # Warning
    /// You must read the response stream to completion for the reply to be
    /// stored in the `session` context.
    /// `data_extractor` extracts the data you want from each chunk of the stream.
    /// # Example
    /// ```ignore
    /// use futures::StreamExt;
    /// // Use `_gemini_response.get_text("")` instead to get only the text of each chunk.
    /// let mut response_stream = gemini
    ///     .ask_as_stream_with_extractor(session, |session, _gemini_response| {
    ///         session.get_last_message_text("").unwrap()
    ///     })
    ///     .await
    ///     .unwrap();
    ///
    /// while let Some(response) = response_stream.next().await {
    ///     if let Ok(response) = response {
    ///         println!("{}", response);
    ///     }
    /// }
    /// ```
    pub async fn ask_as_stream_with_extractor<F, StreamType>(
        &self,
        session: Session,
        data_extractor: F,
    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
    where
        F: FnMut(&Session, GeminiResponse) -> StreamType,
    {
        if !session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() != Role::Model)
        {
            return Err((session, GeminiResponseError::NothingToRespond));
        }
        let req_url = format!(
            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
            self.model, self.api_key
        );

        let request = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
            ))
            .send()
            .await;
        let response = match request {
            Ok(response) => response,
            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
        };

        if !response.status().is_success() {
            let text = match response.text().await {
                Ok(text) => text,
                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
            };
            return Err((session, GeminiResponseError::StatusNotOk(text)));
        }

        Ok(ResponseStream::new(
            Box::new(response.bytes_stream()),
            session,
            data_extractor,
        ))
    }
    /// # Warning
    /// You must read the response stream to completion for the reply to be
    /// stored in the `session` context.
    /// # Example
    /// ```ignore
    /// use futures::StreamExt;
    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
    ///
    /// while let Some(response) = response_stream.next().await {
    ///     if let Ok(response) = response {
    ///         println!("{}", response.get_text(""));
    ///     }
    /// }
    /// ```
    pub async fn ask_as_stream(
        &self,
        session: Session,
    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
        // Cast the closure to a fn pointer so the stream's extractor type is nameable.
        self.ask_as_stream_with_extractor(
            session,
            (|_, gemini_response| gemini_response)
                as fn(&Session, GeminiResponse) -> GeminiResponse,
        )
        .await
    }
}