gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Generative Language `v1beta` models endpoint; request URLs are built by
/// appending `/{model}:{method}` to it.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 safety_settings: Option<Vec<SafetySetting>>,
19 tools: Option<Vec<Tool>>,
20 tool_config: Option<ToolConfig>,
21}
22impl Gemini {
23 pub fn new(
28 api_key: impl Into<String>,
29 model: impl Into<String>,
30 sys_prompt: Option<SystemInstruction>,
31 ) -> Self {
32 Self {
33 client: Client::builder()
34 .timeout(Duration::from_secs(60))
35 .build()
36 .unwrap(),
37 api_key: api_key.into(),
38 model: model.into(),
39 sys_prompt,
40 generation_config: None,
41 safety_settings: None,
42 tools: None,
43 tool_config: None,
44 }
45 }
46 pub fn new_with_timeout(
48 api_key: impl Into<String>,
49 model: impl Into<String>,
50 sys_prompt: Option<SystemInstruction>,
51 api_timeout: Duration,
52 ) -> Self {
53 Self {
54 client: Client::builder().timeout(api_timeout).build().unwrap(),
55 api_key: api_key.into(),
56 model: model.into(),
57 sys_prompt,
58 generation_config: None,
59 safety_settings: None,
60 tools: None,
61 tool_config: None,
62 }
63 }
64 pub fn set_generation_config(&mut self) -> &mut Value {
66 if let None = self.generation_config {
67 self.generation_config = Some(json!({}));
68 }
69 self.generation_config.as_mut().unwrap()
70 }
71 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
72 self.tool_config = Some(config);
73 self
74 }
75 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
76 if let Value::Object(map) = self.set_generation_config() {
77 if let Ok(thinking_value) = serde_json::to_value(config) {
78 map.insert("thinking_config".to_string(), thinking_value);
79 }
80 }
81 self
82 }
83 pub fn set_model(mut self, model: impl Into<String>) -> Self {
84 self.model = model.into();
85 self
86 }
87 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
88 self.sys_prompt = sys_prompt;
89 self
90 }
91 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
92 self.safety_settings = settings;
93 self
94 }
95 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
96 self.api_key = api_key.into();
97 self
98 }
99 pub fn set_json_mode(mut self, schema: Value) -> Self {
105 let config = self.set_generation_config();
106 config["response_mime_type"] = "application/json".into();
107 config["response_schema"] = schema.into();
108 self
109 }
110 pub fn unset_json_mode(mut self) -> Self {
111 if let Some(ref mut generation_config) = self.generation_config {
112 generation_config["response_schema"] = None::<Value>.into();
113 generation_config["response_mime_type"] = None::<Value>.into();
114 }
115 self
116 }
117 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
118 self.tools = Some(tools);
119 self
120 }
121 pub fn unset_tools(mut self) -> Self {
122 self.tools = None;
123 self
124 }
125
126 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
127 if !session
128 .get_last_chat()
129 .is_some_and(|chat| *chat.role() != Role::Model)
130 {
131 return Err(GeminiResponseError::NothingToRespond);
132 }
133 let req_url = format!(
134 "{BASE_URL}/{}:generateContent?key={}",
135 self.model, self.api_key
136 );
137
138 let response = self
139 .client
140 .post(req_url)
141 .json(&GeminiRequestBody::new(
142 self.sys_prompt.as_ref(),
143 self.tools.as_deref(),
144 &session.get_history().as_slice(),
145 self.generation_config.as_ref(),
146 self.safety_settings.as_deref(),
147 self.tool_config.as_ref(),
148 ))
149 .send()
150 .await
151 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
152
153 if !response.status().is_success() {
154 let text = response
155 .text()
156 .await
157 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
158 return Err(GeminiResponseError::StatusNotOk(text));
159 }
160
161 let reply = GeminiResponse::new(response)
162 .await
163 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
164 session.update(&reply);
165 Ok(reply)
166 }
167 pub async fn ask_as_stream_with_extractor<F, StreamType>(
184 &self,
185 session: Session,
186 data_extractor: F,
187 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
188 where
189 F: FnMut(&Session, GeminiResponse) -> StreamType,
190 {
191 if !session
192 .get_last_chat()
193 .is_some_and(|chat| *chat.role() != Role::Model)
194 {
195 return Err((session, GeminiResponseError::NothingToRespond));
196 }
197 let req_url = format!(
198 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
199 self.model, self.api_key
200 );
201
202 let request = self
203 .client
204 .post(req_url)
205 .json(&GeminiRequestBody::new(
206 self.sys_prompt.as_ref(),
207 self.tools.as_deref(),
208 session.get_history().as_slice(),
209 self.generation_config.as_ref(),
210 self.safety_settings.as_deref(),
211 self.tool_config.as_ref(),
212 ))
213 .send()
214 .await;
215 let response = match request {
216 Ok(response) => response,
217 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
218 };
219
220 if !response.status().is_success() {
221 let text = match response.text().await {
222 Ok(response) => response,
223 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
224 };
225 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
226 }
227
228 Ok(ResponseStream::new(
229 Box::new(response.bytes_stream()),
230 session,
231 data_extractor,
232 ))
233 }
234 pub async fn ask_as_stream(
248 &self,
249 session: Session,
250 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
251 self.ask_as_stream_with_extractor(
252 session,
253 (|_, gemini_response| gemini_response)
254 as fn(&Session, GeminiResponse) -> GeminiResponse,
255 )
256 .await
257 }
258}