gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini `v1beta` model endpoints; a model name and `:action` are appended to it.
const BASE_URL: &str = concat!("https://generativelanguage.googleapis.com", "/v1beta/models");
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 tools: Option<Vec<Tool>>,
19}
20impl Gemini {
21 pub fn new(
26 api_key: impl Into<String>,
27 model: impl Into<String>,
28 sys_prompt: Option<SystemInstruction>,
29 ) -> Self {
30 Self {
31 client: Client::builder()
32 .timeout(Duration::from_secs(60))
33 .build()
34 .unwrap(),
35 api_key: api_key.into(),
36 model: model.into(),
37 sys_prompt,
38 generation_config: None,
39 tools: None,
40 }
41 }
42 pub fn new_with_timeout(
44 api_key: impl Into<String>,
45 model: impl Into<String>,
46 sys_prompt: Option<SystemInstruction>,
47 api_timeout: Duration,
48 ) -> Self {
49 Self {
50 client: Client::builder().timeout(api_timeout).build().unwrap(),
51 api_key: api_key.into(),
52 model: model.into(),
53 sys_prompt,
54 generation_config: None,
55 tools: None,
56 }
57 }
58 pub fn set_generation_config(mut self, generation_config: Value) -> Self {
60 self.generation_config = Some(generation_config);
61 self
62 }
63 pub fn set_model(mut self, model: impl Into<String>) -> Self {
64 self.model = model.into();
65 self
66 }
67 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
68 self.api_key = api_key.into();
69 self
70 }
71 pub fn set_json_mode(mut self, schema: Value) -> Self {
77 if let None = self.generation_config {
78 self.generation_config = Some(json!({
79 "response_mime_type": "application/json",
80 "response_schema":schema
81 }))
82 } else if let Some(config) = self.generation_config.as_mut() {
83 config["response_mime_type"] = "application/json".into();
84 config["response_schema"] = schema.into();
85 }
86 self
87 }
88 pub fn unset_json_mode(mut self) -> Self {
89 if let Some(ref mut generation_config) = self.generation_config {
90 generation_config["response_schema"] = None::<Value>.into();
91 generation_config["response_mime_type"] = None::<Value>.into();
92 }
93 self
94 }
95 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
96 self.tools = Some(tools);
97 self
98 }
99 pub fn unset_tools(mut self) -> Self {
100 self.tools = None;
101 self
102 }
103
104 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
105 let req_url = format!(
106 "{BASE_URL}/{}:generateContent?key={}",
107 self.model, self.api_key
108 );
109
110 let response = self
111 .client
112 .post(req_url)
113 .json(&GeminiRequestBody::new(
114 self.sys_prompt.as_ref(),
115 self.tools.as_deref(),
116 &session.get_history().as_slice(),
117 self.generation_config.as_ref(),
118 ))
119 .send()
120 .await
121 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
122
123 if !response.status().is_success() {
124 let text = response
125 .text()
126 .await
127 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
128 return Err(GeminiResponseError::StatusNotOk(text));
129 }
130
131 let reply = GeminiResponse::new(response)
132 .await
133 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
134 session.update(&reply);
135 Ok(reply)
136 }
137 pub async fn ask_as_stream_with_extractor<F, StreamType>(
154 &self,
155 session: Session,
156 data_extractor: F,
157 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
158 where
159 F: FnMut(&Session, GeminiResponse) -> StreamType,
160 {
161 let req_url = format!(
162 "{BASE_URL}/{}:streamGenerateContent?key={}",
163 self.model, self.api_key
164 );
165
166 let request = self
167 .client
168 .post(req_url)
169 .json(&GeminiRequestBody::new(
170 self.sys_prompt.as_ref(),
171 self.tools.as_deref(),
172 session.get_history().as_slice(),
173 self.generation_config.as_ref(),
174 ))
175 .send()
176 .await;
177 let response = match request {
178 Ok(response) => response,
179 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
180 };
181
182 if !response.status().is_success() {
183 let text = match response.text().await {
184 Ok(response) => response,
185 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
186 };
187 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
188 }
189
190 Ok(ResponseStream::new(
191 Box::new(response.bytes_stream()),
192 session,
193 data_extractor,
194 ))
195 }
196 pub async fn ask_as_stream(
210 &self,
211 session: Session,
212 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
213 self.ask_as_stream_with_extractor(
214 session,
215 (|_, gemini_response| gemini_response)
216 as fn(&Session, GeminiResponse) -> GeminiResponse,
217 )
218 .await
219 }
220}