gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini `v1beta` models endpoint; request URLs are built by
/// appending `/{model}:{action}` to this value.
9const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 safety_settings: Option<Vec<SafetySetting>>,
19 tools: Option<Vec<Tool>>,
20 tool_config: Option<ToolConfig>,
21}
22impl Gemini {
23 pub fn new(
28 api_key: impl Into<String>,
29 model: impl Into<String>,
30 sys_prompt: Option<SystemInstruction>,
31 ) -> Self {
32 Self {
33 client: Client::builder()
34 .timeout(Duration::from_secs(60))
35 .build()
36 .unwrap(),
37 api_key: api_key.into(),
38 model: model.into(),
39 sys_prompt,
40 generation_config: None,
41 safety_settings: None,
42 tools: None,
43 tool_config: None,
44 }
45 }
46 pub fn new_with_timeout(
48 api_key: impl Into<String>,
49 model: impl Into<String>,
50 sys_prompt: Option<SystemInstruction>,
51 api_timeout: Duration,
52 ) -> Self {
53 Self {
54 client: Client::builder().timeout(api_timeout).build().unwrap(),
55 api_key: api_key.into(),
56 model: model.into(),
57 sys_prompt,
58 generation_config: None,
59 safety_settings: None,
60 tools: None,
61 tool_config: None,
62 }
63 }
64 pub fn set_generation_config(&mut self) -> &mut Value {
66 if let None = self.generation_config {
67 self.generation_config = Some(json!({}));
68 }
69 self.generation_config.as_mut().unwrap()
70 }
71 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
72 self.tool_config = Some(config);
73 self
74 }
75 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
76 if let Value::Object(map) = self.set_generation_config() {
77 if let Ok(thinking_value) = serde_json::to_value(config) {
78 map.insert("thinking_config".to_string(), thinking_value);
79 }
80 }
81 self
82 }
83 pub fn set_model(mut self, model: impl Into<String>) -> Self {
84 self.model = model.into();
85 self
86 }
87 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
88 self.sys_prompt = sys_prompt;
89 self
90 }
91 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
92 self.safety_settings = settings;
93 self
94 }
95 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
96 self.api_key = api_key.into();
97 self
98 }
99 pub fn set_json_mode(mut self, schema: Value) -> Self {
108 let config = self.set_generation_config();
109 config["response_mime_type"] = "application/json".into();
110 config["response_schema"] = schema.into();
111 self
112 }
113 pub fn unset_json_mode(mut self) -> Self {
114 if let Some(ref mut generation_config) = self.generation_config {
115 generation_config["response_schema"] = None::<Value>.into();
116 generation_config["response_mime_type"] = None::<Value>.into();
117 }
118 self
119 }
120 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
121 self.tools = Some(tools);
122 self
123 }
124 pub fn unset_tools(mut self) -> Self {
125 self.tools = None;
126 self
127 }
128
129 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
130 if !session
131 .get_last_chat()
132 .is_some_and(|chat| *chat.role() != Role::Model)
133 {
134 return Err(GeminiResponseError::NothingToRespond);
135 }
136 let req_url = format!(
137 "{BASE_URL}/{}:generateContent?key={}",
138 self.model, self.api_key
139 );
140
141 let response = self
142 .client
143 .post(req_url)
144 .json(&GeminiRequestBody::new(
145 self.sys_prompt.as_ref(),
146 self.tools.as_deref(),
147 &session.get_history().as_slice(),
148 self.generation_config.as_ref(),
149 self.safety_settings.as_deref(),
150 self.tool_config.as_ref(),
151 ))
152 .send()
153 .await
154 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
155
156 if !response.status().is_success() {
157 let text = response
158 .text()
159 .await
160 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
161 return Err(GeminiResponseError::StatusNotOk(text));
162 }
163
164 let reply = GeminiResponse::new(response)
165 .await
166 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
167 session.update(&reply);
168 Ok(reply)
169 }
170 pub async fn ask_as_stream_with_extractor<F, StreamType>(
187 &self,
188 session: Session,
189 data_extractor: F,
190 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
191 where
192 F: FnMut(&Session, GeminiResponse) -> StreamType,
193 {
194 if !session
195 .get_last_chat()
196 .is_some_and(|chat| *chat.role() != Role::Model)
197 {
198 return Err((session, GeminiResponseError::NothingToRespond));
199 }
200 let req_url = format!(
201 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
202 self.model, self.api_key
203 );
204
205 let request = self
206 .client
207 .post(req_url)
208 .json(&GeminiRequestBody::new(
209 self.sys_prompt.as_ref(),
210 self.tools.as_deref(),
211 session.get_history().as_slice(),
212 self.generation_config.as_ref(),
213 self.safety_settings.as_deref(),
214 self.tool_config.as_ref(),
215 ))
216 .send()
217 .await;
218 let response = match request {
219 Ok(response) => response,
220 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
221 };
222
223 if !response.status().is_success() {
224 let text = match response.text().await {
225 Ok(response) => response,
226 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
227 };
228 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
229 }
230
231 Ok(ResponseStream::new(
232 Box::new(response.bytes_stream()),
233 session,
234 data_extractor,
235 ))
236 }
237 pub async fn ask_as_stream(
251 &self,
252 session: Session,
253 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
254 self.ask_as_stream_with_extractor(
255 session,
256 (|_, gemini_response| gemini_response)
257 as fn(&Session, GeminiResponse) -> GeminiResponse,
258 )
259 .await
260 }
261}