gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini REST `models` collection (API version `v1beta`);
/// per-model endpoints are addressed as `{BASE_URL}/{model}:{method}`.
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta/models"
);
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 tools: Option<Vec<Tool>>,
19}
20impl Gemini {
21 pub fn new(
26 api_key: impl Into<String>,
27 model: impl Into<String>,
28 sys_prompt: Option<SystemInstruction>,
29 ) -> Self {
30 Self {
31 client: Client::builder()
32 .timeout(Duration::from_secs(60))
33 .build()
34 .unwrap(),
35 api_key: api_key.into(),
36 model: model.into(),
37 sys_prompt,
38 generation_config: None,
39 tools: None,
40 }
41 }
42 pub fn new_with_timeout(
44 api_key: impl Into<String>,
45 model: impl Into<String>,
46 sys_prompt: Option<SystemInstruction>,
47 api_timeout: Duration,
48 ) -> Self {
49 Self {
50 client: Client::builder().timeout(api_timeout).build().unwrap(),
51 api_key: api_key.into(),
52 model: model.into(),
53 sys_prompt,
54 generation_config: None,
55 tools: None,
56 }
57 }
58 pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
60 self.generation_config = Some(generation_config);
61 self
62 }
63 pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
64 self.model = model.into();
65 self
66 }
67 pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
68 self.api_key = api_key.into();
69 self
70 }
71 pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
73 if let None = self.generation_config {
74 self.generation_config = Some(json!({
75 "response_mime_type": "application/json",
76 "response_schema":schema
77 }))
78 } else if let Some(config) = self.generation_config.as_mut() {
79 config["response_mime_type"] = "application/json".into();
80 config["response_schema"] = schema.into();
81 }
82 self
83 }
84 pub fn unset_json_mode(&mut self) -> &mut Self {
85 if let Some(ref mut generation_config) = self.generation_config {
86 generation_config["response_schema"] = None::<Value>.into();
87 generation_config["response_mime_type"] = None::<Value>.into();
88 }
89 self
90 }
91 pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
94 self.tools = tools;
95 self
96 }
97 pub fn unset_code_execution_mode(&mut self) -> &mut Self {
98 self.tools.take();
99 self
100 }
101
102 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
103 let req_url = format!(
104 "{BASE_URL}/{}:generateContent?key={}",
105 self.model, self.api_key
106 );
107
108 let response = self
109 .client
110 .post(req_url)
111 .json(&GeminiRequestBody::new(
112 self.sys_prompt.as_ref(),
113 self.tools.as_deref(),
114 &session.get_history().as_slice(),
115 self.generation_config.as_ref(),
116 ))
117 .send()
118 .await?;
119
120 if !response.status().is_success() {
121 let text = response.text().await?;
122 return Err(text.into());
123 }
124
125 let reply = GeminiResponse::new(response).await?;
126 session.update(&reply);
127 Ok(reply)
128 }
129 pub async fn ask_as_stream_with_extractor<F, StreamType>(
146 &self,
147 session: Session,
148 data_extractor: F,
149 ) -> Result<ResponseStream<F, StreamType>, GeminiResponseError>
150 where
151 F: FnMut(&Session, GeminiResponse) -> StreamType,
152 {
153 let req_url = format!(
154 "{BASE_URL}/{}:streamGenerateContent?key={}",
155 self.model, self.api_key
156 );
157
158 let response = self
159 .client
160 .post(req_url)
161 .json(&GeminiRequestBody::new(
162 self.sys_prompt.as_ref(),
163 self.tools.as_deref(),
164 session.get_history().as_slice(),
165 self.generation_config.as_ref(),
166 ))
167 .send()
168 .await?;
169
170 if !response.status().is_success() {
171 let text = response.text().await?;
172 return Err(text.into());
173 }
174
175 Ok(ResponseStream::new(
176 Box::new(response.bytes_stream()),
177 session,
178 data_extractor,
179 ))
180 }
181 pub async fn ask_as_stream(
195 &self,
196 session: Session,
197 ) -> Result<GeminiResponseStream, GeminiResponseError> {
198 self.ask_as_stream_with_extractor(
199 session,
200 (|_, gemini_response| gemini_response)
201 as fn(&Session, GeminiResponse) -> GeminiResponse,
202 )
203 .await
204 }
205}