gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Endpoint prefix for model-scoped methods (`{BASE_URL}/{model}:generateContent`, …).
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta",
    "/models"
);
11
12#[derive(Clone, Default, Debug)]
18pub struct Gemini {
19 client: Client,
20 api_key: String,
21 model: String,
22 sys_prompt: Option<SystemInstruction>,
23 generation_config: Option<Value>,
24 safety_settings: Option<Vec<SafetySetting>>,
25 tools: Option<Vec<Tool>>,
26 tool_config: Option<ToolConfig>,
27 cached_content: Option<String>,
28}
29
30impl Gemini {
31 pub fn new(
38 api_key: impl Into<String>,
39 model: impl Into<String>,
40 sys_prompt: Option<SystemInstruction>,
41 ) -> Self {
42 Self {
43 client: Client::builder()
44 .timeout(Duration::from_secs(60))
45 .build()
46 .unwrap(),
47 api_key: api_key.into(),
48 model: model.into(),
49 sys_prompt,
50 generation_config: None,
51 safety_settings: None,
52 tools: None,
53 tool_config: None,
54 cached_content: None,
55 }
56 }
57 #[deprecated]
65 pub fn new_with_timeout(
66 api_key: impl Into<String>,
67 model: impl Into<String>,
68 sys_prompt: Option<SystemInstruction>,
69 api_timeout: Duration,
70 ) -> Self {
71 Self {
72 client: Client::builder().timeout(api_timeout).build().unwrap(),
73 api_key: api_key.into(),
74 model: model.into(),
75 sys_prompt,
76 generation_config: None,
77 safety_settings: None,
78 tools: None,
79 tool_config: None,
80 cached_content: None,
81 }
82 }
83 pub fn new_with_client(
91 api_key: impl Into<String>,
92 model: impl Into<String>,
93 sys_prompt: Option<SystemInstruction>,
94 client: Client,
95 ) -> Self {
96 Self {
97 client,
98 api_key: api_key.into(),
99 model: model.into(),
100 sys_prompt,
101 generation_config: None,
102 safety_settings: None,
103 tools: None,
104 tool_config: None,
105 cached_content: None,
106 }
107 }
108 pub fn set_generation_config(&mut self) -> &mut Value {
113 if let None = self.generation_config {
114 self.generation_config = Some(json!({}));
115 }
116 self.generation_config.as_mut().unwrap()
117 }
118 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
119 self.tool_config = Some(config);
120 self
121 }
122 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
123 if let Value::Object(map) = self.set_generation_config() {
124 if let Ok(thinking_value) = serde_json::to_value(config) {
125 map.insert("thinking_config".to_string(), thinking_value);
126 }
127 }
128 self
129 }
130 pub fn set_model(mut self, model: impl Into<String>) -> Self {
131 self.model = model.into();
132 self
133 }
134 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
137 self.sys_prompt = sys_prompt;
138 self
139 }
140 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
141 self.safety_settings = settings;
142 self
143 }
144 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
145 self.api_key = api_key.into();
146 self
147 }
148 pub fn set_json_mode(mut self, schema: Value) -> Self {
156 let config = self.set_generation_config();
157 config["response_mime_type"] = "application/json".into();
158 config["response_schema"] = schema.into();
159 self
160 }
161 pub fn remove_json_mode(mut self) -> Self {
162 if let Some(ref mut generation_config) = self.generation_config {
163 generation_config["response_schema"] = None::<Value>.into();
164 generation_config["response_mime_type"] = None::<Value>.into();
165 }
166 self
167 }
168 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
170 self.tools = Some(tools);
171 self
172 }
173 pub fn remove_tools(mut self) -> Self {
175 self.tools = None;
176 self
177 }
178 pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
179 self.cached_content = Some(name.into());
180 self
181 }
182 pub fn remove_cached_content(mut self) -> Self {
183 self.cached_content = None;
184 self
185 }
186
187 pub async fn create_cache(
190 &self,
191 cached_content: &CachedContent,
192 ) -> Result<CachedContent, GeminiResponseError> {
193 let req_url = format!(
194 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
195 self.api_key
196 );
197
198 let response = self
199 .client
200 .post(req_url)
201 .json(cached_content)
202 .send()
203 .await
204 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
205
206 if !response.status().is_success() {
207 let error = response
208 .json()
209 .await
210 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
211 return Err(GeminiResponseError::StatusNotOk(error));
212 }
213
214 let cached_content: CachedContent = response
215 .json()
216 .await
217 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
218 Ok(cached_content)
219 }
220
221 pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
222 let req_url = format!(
223 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
224 self.api_key
225 );
226
227 let response = self
228 .client
229 .get(req_url)
230 .send()
231 .await
232 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
233
234 if !response.status().is_success() {
235 let error = response
236 .json()
237 .await
238 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
239 return Err(GeminiResponseError::StatusNotOk(error));
240 }
241
242 let list: CachedContentList = response
243 .json()
244 .await
245 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
246 Ok(list)
247 }
248
249 pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
250 let req_url = format!(
251 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
252 name, self.api_key
253 );
254
255 let response = self
256 .client
257 .get(req_url)
258 .send()
259 .await
260 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
261
262 if !response.status().is_success() {
263 let error = response
264 .json()
265 .await
266 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
267 return Err(GeminiResponseError::StatusNotOk(error));
268 }
269
270 let cached_content: CachedContent = response
271 .json()
272 .await
273 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
274 Ok(cached_content)
275 }
276
277 pub async fn update_cache(
278 &self,
279 name: &str,
280 update: &CachedContentUpdate,
281 ) -> Result<CachedContent, GeminiResponseError> {
282 let req_url = format!(
283 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
284 name, self.api_key
285 );
286
287 let response = self
288 .client
289 .patch(req_url)
290 .json(update)
291 .send()
292 .await
293 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
294
295 if !response.status().is_success() {
296 let error = response
297 .json()
298 .await
299 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
300 return Err(GeminiResponseError::StatusNotOk(error));
301 }
302
303 let cached_content: CachedContent = response
304 .json()
305 .await
306 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
307 Ok(cached_content)
308 }
309
310 pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
311 let req_url = format!(
312 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
313 name, self.api_key
314 );
315
316 let response = self
317 .client
318 .delete(req_url)
319 .send()
320 .await
321 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
322
323 if !response.status().is_success() {
324 let error = response
325 .json()
326 .await
327 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
328 return Err(GeminiResponseError::StatusNotOk(error));
329 }
330
331 Ok(())
332 }
333
334 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
341 if session
342 .get_last_chat()
343 .is_some_and(|chat| *chat.role() == Role::Model)
344 {
345 return Err(GeminiResponseError::NothingToRespond);
346 }
347 let req_url = format!(
348 "{BASE_URL}/{}:generateContent?key={}",
349 self.model, self.api_key
350 );
351
352 let response = self
353 .client
354 .post(req_url)
355 .json(&GeminiRequestBody::new(
356 self.sys_prompt.as_ref(),
357 self.tools.as_deref(),
358 &session.get_history().as_slice(),
359 self.generation_config.as_ref(),
360 self.safety_settings.as_deref(),
361 self.tool_config.as_ref(),
362 self.cached_content.clone(),
363 ))
364 .send()
365 .await
366 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
367
368 if !response.status().is_success() {
369 let error = response
370 .json()
371 .await
372 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
373 return Err(GeminiResponseError::StatusNotOk(error));
374 }
375
376 let reply = GeminiResponse::new(response)
377 .await
378 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
379 session.update(&reply);
380 Ok(reply)
381 }
382 pub async fn ask_as_stream_with_extractor<F, StreamType>(
396 &self,
397 session: Session,
398 data_extractor: F,
399 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
400 where
401 F: FnMut(&Session, GeminiResponse) -> StreamType,
402 {
403 if session
404 .get_last_chat()
405 .is_some_and(|chat| *chat.role() == Role::Model)
406 {
407 return Err((session, GeminiResponseError::NothingToRespond));
408 }
409 let req_url = format!(
410 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
411 self.model, self.api_key
412 );
413
414 let request = self
415 .client
416 .post(req_url)
417 .json(&GeminiRequestBody::new(
418 self.sys_prompt.as_ref(),
419 self.tools.as_deref(),
420 session.get_history().as_slice(),
421 self.generation_config.as_ref(),
422 self.safety_settings.as_deref(),
423 self.tool_config.as_ref(),
424 self.cached_content.clone(),
425 ))
426 .send()
427 .await;
428 let response = match request {
429 Ok(response) => response,
430 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
431 };
432
433 if !response.status().is_success() {
434 let error = match response.json().await {
435 Ok(response) => response,
436 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
437 };
438 return Err((session, GeminiResponseError::StatusNotOk(error)));
439 }
440
441 Ok(ResponseStream::new(
442 Box::new(response.bytes_stream()),
443 session,
444 data_extractor,
445 ))
446 }
447 pub async fn ask_as_stream(
466 &self,
467 session: Session,
468 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
469 self.ask_as_stream_with_extractor(
470 session,
471 (|_, gemini_response| gemini_response)
472 as fn(&Session, GeminiResponse) -> GeminiResponse,
473 )
474 .await
475 }
476}