gemini_client_api/gemini/
ask.rs1use super::types::request::*;
2use super::types::response::*;
3use super::types::sessions::Session;
4use awc::Client;
5use serde_json::{Value, json};
6use std::time::Duration;
7
/// Per-request timeout configured on the HTTP client built by `Gemini::new` (30 s).
const API_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Common prefix of every model endpoint; `/{model}:{action}` is appended per request.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11pub struct Gemini<'a> {
12 client: Client,
13 api_key: String,
14 model: String,
15 sys_prompt: Option<SystemInstruction<'a>>,
16 generation_config: Option<Value>,
17 tools: Option<Vec<Tool>>,
18}
19impl<'a> Gemini<'a> {
20 pub fn new(api_key: String, model: String, sys_prompt: Option<SystemInstruction<'a>>) -> Self {
22 Self {
23 client: Client::builder().timeout(API_TIMEOUT).finish(),
24 api_key,
25 model,
26 sys_prompt,
27 generation_config: None,
28 tools: None,
29 }
30 }
31 pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
33 self.generation_config = Some(generation_config);
34 self
35 }
36 pub fn set_model(&mut self, model: String) {
37 self.model = model;
38 }
39 pub fn set_api_key(&mut self, api_key: String) {
40 self.api_key = api_key;
41 }
42 pub fn set_json_mode(&mut self, schema: Value) -> &Self {
44 if let None = self.generation_config {
45 self.generation_config = Some(json!({
46 "response_mime_type": "application/json",
47 "response_schema":schema
48 }))
49 } else if let Some(config) = self.generation_config.as_mut() {
50 config["response_mime_type"] = "application/json".into();
51 config["response_schema"] = schema.into();
52 }
53 self
54 }
55 pub fn unset_json_mode(&mut self) -> &Self {
56 if let Some(ref mut generation_config) = self.generation_config {
57 generation_config["response_schema"] = None::<Value>.into();
58 generation_config["response_mime_type"] = None::<Value>.into();
59 }
60 self
61 }
62 pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &Self {
65 self.tools = tools;
66 self
67 }
68 pub fn unset_code_execution_mode(&mut self) -> &Self {
69 self.tools.take();
70 self
71 }
72
73 pub async fn ask<'b>(
74 &self,
75 session: &'b mut Session,
76 ) -> Result<GeminiResponse, Box<dyn std::error::Error>> {
77 let req_url = format!(
78 "{BASE_URL}/{}:generateContent?key={}",
79 self.model, self.api_key
80 );
81
82 let response = self
83 .client
84 .post(req_url)
85 .send_json(&GeminiBody::new(
86 self.sys_prompt.as_ref(),
87 self.tools.as_deref(),
88 &session.get_history().as_slice(),
89 self.generation_config.as_ref(),
90 ))
91 .await?;
92 let reply = GeminiResponse::new(response).await?;
93 session.update(&reply);
94 Ok(reply)
95 }
96 pub async fn ask_as_stream<'b>(
97 &self,
98 session: &'b mut Session,
99 ) -> Result<GeminiResponseStream<'b>, Box<dyn std::error::Error>> {
100 let req_url = format!(
101 "{BASE_URL}/{}:streamGenerateContent?key={}",
102 self.model, self.api_key
103 );
104
105 let response = self
106 .client
107 .post(req_url)
108 .send_json(&GeminiBody::new(
109 self.sys_prompt.as_ref(),
110 self.tools.as_deref(),
111 session.get_history().as_slice(),
112 self.generation_config.as_ref(),
113 ))
114 .await?;
115 if !response.status().is_success() {
116 return Err(format!(
117 "Found status due to {} from Gemini endpoint",
118 response.status()
119 )
120 .into());
121 }
122
123 Ok(GeminiResponseStream::new(response, session))
124 }
125}