
gemini_client_api/gemini/ask.rs

use super::error::GeminiResponseError;
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

#[derive(Clone, Default, Debug)]
pub struct Gemini {
    client: Client,
    api_key: String,
    model: String,
    sys_prompt: Option<SystemInstruction>,
    generation_config: Option<Value>,
    safety_settings: Option<Vec<SafetySetting>>,
    tools: Option<Vec<Tool>>,
    tool_config: Option<ToolConfig>,
}
impl Gemini {
    /// # Arguments
    /// `api_key`: get one from [Google AI Studio](https://aistudio.google.com/app/apikey)
    /// `model`: should be one of the model codes listed [here](https://ai.google.dev/gemini-api/docs/models#model-variations)
    /// `sys_prompt` should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
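    /// # Example
    /// A minimal sketch; the model name is illustrative.
    ///```ignore
    ///let gemini = Gemini::new(
    ///    std::env::var("GEMINI_API_KEY").unwrap(),
    ///    "gemini-2.0-flash",
    ///    None,
    ///);
    ///```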
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
    ) -> Self {
        Self {
            client: Client::builder()
                .timeout(Duration::from_secs(60))
                .build()
                .unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        }
    }
    /// `sys_prompt` should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
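    /// # Example
    /// A sketch; the 120-second timeout is an arbitrary choice.
    ///```ignore
    ///let gemini = Gemini::new_with_timeout(api_key, "gemini-2.0-flash", None, Duration::from_secs(120));
    ///```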
    pub fn new_with_timeout(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
        api_timeout: Duration,
    ) -> Self {
        Self {
            client: Client::builder().timeout(api_timeout).build().unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        }
    }
    /// The generation config schema should follow the [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
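    /// # Example
    /// A sketch of setting one field on the returned JSON object; any key valid in
    /// GenerationConfig works the same way.
    ///```ignore
    ///gemini.set_generation_config()["temperature"] = 0.7.into();
    ///```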
    pub fn set_generation_config(&mut self) -> &mut Value {
        self.generation_config.get_or_insert_with(|| json!({}))
    }
    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
        self.tool_config = Some(config);
        self
    }
    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
        if let Value::Object(map) = self.set_generation_config() {
            if let Ok(thinking_value) = serde_json::to_value(config) {
                map.insert("thinking_config".to_string(), thinking_value);
            }
        }
        self
    }
    pub fn set_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
        self.sys_prompt = sys_prompt;
        self
    }
    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
        self.safety_settings = settings;
        self
    }
    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = api_key.into();
        self
    }
    /// To use a struct as the schema, write `#[gemini_schema]` above the struct, then pass
    /// `StructName::gemini_schema()`.
    /// Otherwise, `schema` should follow the [Gemini Schema](https://ai.google.dev/api/caching#Schema).
    /// To verify your schema, visit [AI Studio](https://aistudio.google.com/prompts/new_chat):
    /// - Under tools, toggle on Structured output
    /// - Click Edit
    /// - Here you can create a schema with the `Visual Editor` or the `Code Editor`, with error detection
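    /// # Example
    /// A sketch with a hand-written schema; the type names follow the Gemini Schema enum.
    ///```ignore
    ///let gemini = gemini.set_json_mode(json!({
    ///    "type": "OBJECT",
    ///    "properties": { "answer": { "type": "STRING" } },
    ///    "required": ["answer"]
    ///}));
    ///```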
    pub fn set_json_mode(mut self, schema: Value) -> Self {
        let config = self.set_generation_config();
        config["response_mime_type"] = "application/json".into();
        config["response_schema"] = schema;
        self
    }
    pub fn unset_json_mode(mut self) -> Self {
        if let Some(Value::Object(map)) = self.generation_config.as_mut() {
            // Remove the keys entirely so nulls are not serialized into the request.
            map.remove("response_schema");
            map.remove("response_mime_type");
        }
        self
    }
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }
    pub fn unset_tools(mut self) -> Self {
        self.tools = None;
        self
    }

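    /// Sends the whole session history to the `generateContent` endpoint and,
    /// on success, appends the model's reply to `session`.
    /// # Example
    /// A minimal sketch; how the pending user message got into `session` is assumed here.
    ///```ignore
    ///let reply = gemini.ask(&mut session).await?;
    ///println!("{}", reply.get_text(""));
    ///```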
    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
        // Only proceed when the last chat exists and did not come from the model.
        if !session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() != Role::Model)
        {
            return Err(GeminiResponseError::NothingToRespond);
        }
        let req_url = format!(
            "{BASE_URL}/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let response = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
            ))
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let reply = GeminiResponse::new(response)
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        session.update(&reply);
        Ok(reply)
    }
    /// # Warning
    /// You must read the response stream for the reply to be stored as context in `session`.
    /// `data_extractor` extracts the data you get from each future in the stream.
    /// # Example
    ///```ignore
    ///use futures::StreamExt;
    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
    ///    |session, _gemini_response| session.get_last_message_text("").unwrap())
    ///    .await.unwrap(); // Use _gemini_response.get_text("") to get just the text received in each chunk
    ///
    ///while let Some(response) = response_stream.next().await {
    ///    if let Ok(response) = response {
    ///        println!("{}", response);
    ///    }
    ///}
    ///```
    pub async fn ask_as_stream_with_extractor<F, StreamType>(
        &self,
        session: Session,
        data_extractor: F,
    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
    where
        F: FnMut(&Session, GeminiResponse) -> StreamType,
    {
        // Only proceed when the last chat exists and did not come from the model.
        if !session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() != Role::Model)
        {
            return Err((session, GeminiResponseError::NothingToRespond));
        }
        let req_url = format!(
            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
            self.model, self.api_key
        );

        let request = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
            ))
            .send()
            .await;
        let response = match request {
            Ok(response) => response,
            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
        };

        if !response.status().is_success() {
            let text = match response.text().await {
                Ok(response) => response,
                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
            };
            return Err((session, GeminiResponseError::StatusNotOk(text)));
        }

        Ok(ResponseStream::new(
            Box::new(response.bytes_stream()),
            session,
            data_extractor,
        ))
    }
    /// # Warning
    /// You must read the response stream for the reply to be stored as context in `session`.
    /// # Example
    ///```ignore
    ///use futures::StreamExt;
    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
    ///
    ///while let Some(response) = response_stream.next().await {
    ///    if let Ok(response) = response {
    ///        println!("{}", response.get_text(""));
    ///    }
    ///}
    ///```
    pub async fn ask_as_stream(
        &self,
        session: Session,
    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
        // Identity extractor: yield every GeminiResponse chunk unchanged.
        self.ask_as_stream_with_extractor(
            session,
            (|_, gemini_response| gemini_response)
                as fn(&Session, GeminiResponse) -> GeminiResponse,
        )
        .await
    }
}