gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini REST model endpoints; requests append `/{model}:{method}` to it.
const BASE_URL: &str = concat!("https://generativelanguage.googleapis.com", "/v1beta/models");
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 safety_settings: Option<Vec<SafetySetting>>,
19 tools: Option<Vec<Tool>>,
20 tool_config: Option<ToolConfig>,
21}
22impl Gemini {
23 pub fn new(
28 api_key: impl Into<String>,
29 model: impl Into<String>,
30 sys_prompt: Option<SystemInstruction>,
31 ) -> Self {
32 Self {
33 client: Client::builder()
34 .timeout(Duration::from_secs(60))
35 .build()
36 .unwrap(),
37 api_key: api_key.into(),
38 model: model.into(),
39 sys_prompt,
40 generation_config: None,
41 safety_settings: None,
42 tools: None,
43 tool_config: None,
44 }
45 }
46 pub fn new_with_timeout(
48 api_key: impl Into<String>,
49 model: impl Into<String>,
50 sys_prompt: Option<SystemInstruction>,
51 api_timeout: Duration,
52 ) -> Self {
53 Self {
54 client: Client::builder().timeout(api_timeout).build().unwrap(),
55 api_key: api_key.into(),
56 model: model.into(),
57 sys_prompt,
58 generation_config: None,
59 safety_settings: None,
60 tools: None,
61 tool_config: None,
62 }
63 }
64 pub fn set_generation_config(&mut self) -> &mut Value {
66 if let None = self.generation_config {
67 self.generation_config = Some(json!({}));
68 }
69 self.generation_config.as_mut().unwrap()
70 }
71 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
72 self.tool_config = Some(config);
73 self
74 }
75 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
76 if let Value::Object(map) = self.set_generation_config() {
77 if let Ok(thinking_value) = serde_json::to_value(config) {
78 map.insert("thinking_config".to_string(), thinking_value);
79 }
80 }
81 self
82 }
83 pub fn set_model(mut self, model: impl Into<String>) -> Self {
84 self.model = model.into();
85 self
86 }
87 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
88 self.sys_prompt = sys_prompt;
89 self
90 }
91 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
92 self.safety_settings = settings;
93 self
94 }
95 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
96 self.api_key = api_key.into();
97 self
98 }
99 pub fn set_json_mode(mut self, schema: Value) -> Self {
105 let config = self.set_generation_config();
106 config["response_mime_type"] = "application/json".into();
107 config["response_schema"] = schema.into();
108 self
109 }
110 pub fn unset_json_mode(mut self) -> Self {
111 if let Some(ref mut generation_config) = self.generation_config {
112 generation_config["response_schema"] = None::<Value>.into();
113 generation_config["response_mime_type"] = None::<Value>.into();
114 }
115 self
116 }
117 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
118 self.tools = Some(tools);
119 self
120 }
121 pub fn unset_tools(mut self) -> Self {
122 self.tools = None;
123 self
124 }
125
126 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
127 let req_url = format!(
128 "{BASE_URL}/{}:generateContent?key={}",
129 self.model, self.api_key
130 );
131
132 let response = self
133 .client
134 .post(req_url)
135 .json(&GeminiRequestBody::new(
136 self.sys_prompt.as_ref(),
137 self.tools.as_deref(),
138 &session.get_history().as_slice(),
139 self.generation_config.as_ref(),
140 self.safety_settings.as_deref(),
141 self.tool_config.as_ref(),
142 ))
143 .send()
144 .await
145 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
146
147 if !response.status().is_success() {
148 let text = response
149 .text()
150 .await
151 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
152 return Err(GeminiResponseError::StatusNotOk(text));
153 }
154
155 let reply = GeminiResponse::new(response)
156 .await
157 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
158 session.update(&reply);
159 Ok(reply)
160 }
161 pub async fn ask_as_stream_with_extractor<F, StreamType>(
178 &self,
179 session: Session,
180 data_extractor: F,
181 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
182 where
183 F: FnMut(&Session, GeminiResponse) -> StreamType,
184 {
185 let req_url = format!(
186 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
187 self.model, self.api_key
188 );
189
190 let request = self
191 .client
192 .post(req_url)
193 .json(&GeminiRequestBody::new(
194 self.sys_prompt.as_ref(),
195 self.tools.as_deref(),
196 session.get_history().as_slice(),
197 self.generation_config.as_ref(),
198 self.safety_settings.as_deref(),
199 self.tool_config.as_ref(),
200 ))
201 .send()
202 .await;
203 let response = match request {
204 Ok(response) => response,
205 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
206 };
207
208 if !response.status().is_success() {
209 let text = match response.text().await {
210 Ok(response) => response,
211 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
212 };
213 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
214 }
215
216 Ok(ResponseStream::new(
217 Box::new(response.bytes_stream()),
218 session,
219 data_extractor,
220 ))
221 }
222 pub async fn ask_as_stream(
236 &self,
237 session: Session,
238 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
239 self.ask_as_stream_with_extractor(
240 session,
241 (|_, gemini_response| gemini_response)
242 as fn(&Session, GeminiResponse) -> GeminiResponse,
243 )
244 .await
245 }
246}