1use crate::{
2 models::{
3 Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
4 GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
5 },
6 tools::{FunctionDeclaration, Tool},
7 Error, Result,
8};
9use futures::stream::Stream;
10use futures_util::StreamExt;
11use reqwest::Client;
12use std::pin::Pin;
13use std::sync::Arc;
14use url::Url;
15
/// Root URL of the Gemini REST API (`v1beta`); request URLs are built by appending a model
/// path and endpoint name to it.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model path used when the caller does not choose a model explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
18
19pub struct ContentBuilder {
21 client: Arc<GeminiClient>,
22 pub contents: Vec<Content>,
23 generation_config: Option<GenerationConfig>,
24 tools: Option<Vec<Tool>>,
25 tool_config: Option<ToolConfig>,
26 system_instruction: Option<Content>,
27}
28
29impl ContentBuilder {
30 fn new(client: Arc<GeminiClient>) -> Self {
32 Self {
33 client,
34 contents: Vec::new(),
35 generation_config: None,
36 tools: None,
37 tool_config: None,
38 system_instruction: None,
39 }
40 }
41
42 pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
44 self.with_system_instruction(text)
46 }
47
48 pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
50 let content = Content::text(text);
52 self.system_instruction = Some(content);
53 self
54 }
55
56 pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
58 let message = Message::user(text);
59 let content = message.content;
60 self.contents.push(content);
61 self
62 }
63
64 pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
66 let message = Message::model(text);
67 let content = message.content;
68 self.contents.push(content);
69 self
70 }
71
72 pub fn with_function_response(
74 mut self,
75 name: impl Into<String>,
76 response: serde_json::Value,
77 ) -> Self {
78 let content = Content::function_response_json(name, response).with_role(Role::User);
79 self.contents.push(content);
80 self
81 }
82
83 pub fn with_function_response_str(
85 mut self,
86 name: impl Into<String>,
87 response: impl Into<String>,
88 ) -> std::result::Result<Self, serde_json::Error> {
89 let response_str = response.into();
90 let json = serde_json::from_str(&response_str)?;
91 let content = Content::function_response_json(name, json).with_role(Role::User);
92 self.contents.push(content);
93 Ok(self)
94 }
95
96 pub fn with_message(mut self, message: Message) -> Self {
98 let content = message.content.clone();
99 match &content.role {
100 Some(role) => {
101 let role_clone = role.clone();
102 self.contents.push(content.with_role(role_clone));
103 }
104 None => {
105 self.contents.push(content.with_role(message.role));
106 }
107 }
108 self
109 }
110
111 pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
113 for message in messages {
114 self = self.with_message(message);
115 }
116 self
117 }
118
119 pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
121 self.generation_config = Some(config);
122 self
123 }
124
125 pub fn with_temperature(mut self, temperature: f32) -> Self {
127 if self.generation_config.is_none() {
128 self.generation_config = Some(GenerationConfig::default());
129 }
130 if let Some(config) = &mut self.generation_config {
131 config.temperature = Some(temperature);
132 }
133 self
134 }
135
136 pub fn with_top_p(mut self, top_p: f32) -> Self {
138 if self.generation_config.is_none() {
139 self.generation_config = Some(GenerationConfig::default());
140 }
141 if let Some(config) = &mut self.generation_config {
142 config.top_p = Some(top_p);
143 }
144 self
145 }
146
147 pub fn with_top_k(mut self, top_k: i32) -> Self {
149 if self.generation_config.is_none() {
150 self.generation_config = Some(GenerationConfig::default());
151 }
152 if let Some(config) = &mut self.generation_config {
153 config.top_k = Some(top_k);
154 }
155 self
156 }
157
158 pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
160 if self.generation_config.is_none() {
161 self.generation_config = Some(GenerationConfig::default());
162 }
163 if let Some(config) = &mut self.generation_config {
164 config.max_output_tokens = Some(max_output_tokens);
165 }
166 self
167 }
168
169 pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
171 if self.generation_config.is_none() {
172 self.generation_config = Some(GenerationConfig::default());
173 }
174 if let Some(config) = &mut self.generation_config {
175 config.candidate_count = Some(candidate_count);
176 }
177 self
178 }
179
180 pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
182 if self.generation_config.is_none() {
183 self.generation_config = Some(GenerationConfig::default());
184 }
185 if let Some(config) = &mut self.generation_config {
186 config.stop_sequences = Some(stop_sequences);
187 }
188 self
189 }
190
191 pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
193 if self.generation_config.is_none() {
194 self.generation_config = Some(GenerationConfig::default());
195 }
196 if let Some(config) = &mut self.generation_config {
197 config.response_mime_type = Some(mime_type.into());
198 }
199 self
200 }
201
202 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
204 if self.generation_config.is_none() {
205 self.generation_config = Some(GenerationConfig::default());
206 }
207 if let Some(config) = &mut self.generation_config {
208 config.response_schema = Some(schema);
209 }
210 self
211 }
212
213 pub fn with_tool(mut self, tool: Tool) -> Self {
215 if self.tools.is_none() {
216 self.tools = Some(Vec::new());
217 }
218 if let Some(tools) = &mut self.tools {
219 tools.push(tool);
220 }
221 self
222 }
223
224 pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
226 let tool = Tool::new(function);
227 self = self.with_tool(tool);
228 self
229 }
230
231 pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
233 if self.tool_config.is_none() {
234 self.tool_config = Some(ToolConfig {
235 function_calling_config: Some(FunctionCallingConfig { mode }),
236 });
237 } else if let Some(tool_config) = &mut self.tool_config {
238 tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
239 }
240 self
241 }
242
243 pub async fn execute(self) -> Result<GenerationResponse> {
245 let request = GenerateContentRequest {
246 contents: self.contents,
247 generation_config: self.generation_config,
248 safety_settings: None,
249 tools: self.tools,
250 tool_config: self.tool_config,
251 system_instruction: self.system_instruction,
252 };
253
254 self.client.generate_content_raw(request).await
255 }
256
257 pub async fn execute_stream(
259 self,
260 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
261 let request = GenerateContentRequest {
262 contents: self.contents,
263 generation_config: self.generation_config,
264 safety_settings: None,
265 tools: self.tools,
266 tool_config: self.tool_config,
267 system_instruction: self.system_instruction,
268 };
269
270 self.client.generate_content_stream(request).await
271 }
272}
273
/// Internal client holding what is needed to send a request: the HTTP client, the API key
/// and the model path.
struct GeminiClient {
    /// HTTP client used for all requests.
    http_client: Client,
    /// API key sent as the `key` query parameter.
    api_key: String,
    /// Model path inserted into the request URL.
    model: String,
}
280
281impl GeminiClient {
282 fn new(api_key: impl Into<String>, model: String) -> Self {
284 Self {
285 http_client: Client::new(),
286 api_key: api_key.into(),
287 model,
288 }
289 }
290
291 async fn generate_content_raw(
293 &self,
294 request: GenerateContentRequest,
295 ) -> Result<GenerationResponse> {
296 let url = self.build_url("generateContent")?;
297
298 let response = self.http_client.post(url).json(&request).send().await?;
299
300 let status = response.status();
301 if !status.is_success() {
302 let error_text = response.text().await?;
303 return Err(Error::ApiError {
304 status_code: status.as_u16(),
305 message: error_text,
306 });
307 }
308
309 let response = response.json().await?;
310 Ok(response)
311 }
312
313 async fn generate_content_stream(
315 &self,
316 request: GenerateContentRequest,
317 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
318 let url = self.build_url("streamGenerateContent")?;
319
320 let response = self.http_client.post(url).json(&request).send().await?;
321
322 let status = response.status();
323 if !status.is_success() {
324 let error_text = response.text().await?;
325 return Err(Error::ApiError {
326 status_code: status.as_u16(),
327 message: error_text,
328 });
329 }
330
331 let stream = response
332 .bytes_stream()
333 .map(|result| {
334 match result {
335 Ok(bytes) => {
336 let text = String::from_utf8_lossy(&bytes);
337 let mut responses = Vec::new();
340 for line in text.lines() {
341 if let Some(json_str) = line.strip_prefix("data: ") {
342 if json_str == "[DONE]" {
343 continue;
344 }
345 match serde_json::from_str::<GenerationResponse>(json_str) {
346 Ok(response) => responses.push(Ok(response)),
347 Err(e) => responses.push(Err(Error::JsonError(e))),
348 }
349 }
350 }
351 futures::stream::iter(responses)
352 }
353 Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
354 }
355 })
356 .flatten();
357
358 Ok(Box::pin(stream))
359 }
360
361 fn build_url(&self, endpoint: &str) -> Result<Url> {
363 let url_str = format!(
366 "{}{}:{}?key={}",
367 BASE_URL, self.model, endpoint, self.api_key
368 );
369 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
370 }
371}
372
/// Public entry point of the client.
///
/// Cloning copies only the `Arc` handle, so clones share the same underlying client.
#[derive(Clone)]
pub struct Gemini {
    /// Shared internal client handed to every builder created from this value.
    client: Arc<GeminiClient>,
}
378
379impl Gemini {
380 pub fn new(api_key: impl Into<String>) -> Self {
382 Self::with_model(api_key, DEFAULT_MODEL.to_string())
383 }
384
385 pub fn pro(api_key: impl Into<String>) -> Self {
387 Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
388 }
389
390 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
392 let client = GeminiClient::new(api_key, model);
393 Self {
394 client: Arc::new(client),
395 }
396 }
397
398 pub fn generate_content(&self) -> ContentBuilder {
400 ContentBuilder::new(self.client.clone())
401 }
402}