1use crate::{
2 models::{
3 Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
4 GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
5 },
6 tools::{FunctionDeclaration, Tool},
7 Error, Result,
8};
9use futures::stream::Stream;
10use futures_util::StreamExt;
11use reqwest::Client;
12use std::pin::Pin;
13use std::sync::Arc;
14use url::Url;
15
/// Root of the Gemini REST API (`v1beta`). The model path is appended to it directly when
/// request URLs are built, so it must keep its trailing `/`.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model resource name used by `Gemini::new` when no model is given explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
18
19pub struct ContentBuilder {
21 client: Arc<GeminiClient>,
22 pub contents: Vec<Content>,
23 generation_config: Option<GenerationConfig>,
24 tools: Option<Vec<Tool>>,
25 tool_config: Option<ToolConfig>,
26 system_instruction: Option<Content>,
27}
28
29impl ContentBuilder {
30 fn new(client: Arc<GeminiClient>) -> Self {
32 Self {
33 client,
34 contents: Vec::new(),
35 generation_config: None,
36 tools: None,
37 tool_config: None,
38 system_instruction: None,
39 }
40 }
41
42 pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
44 self.with_system_instruction(text)
46 }
47
48 pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
50 let content = Content::text(text);
52 self.system_instruction = Some(content);
53 self
54 }
55
56 pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
58 let content = Content::text(text).with_role(Role::User);
59 self.contents.push(content);
60 self
61 }
62
63 pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
65 let content = Content::text(text).with_role(Role::Model);
66 self.contents.push(content);
67 self
68 }
69
70 pub fn with_function_response(
72 mut self,
73 name: impl Into<String>,
74 response: serde_json::Value,
75 ) -> Self {
76 let content = Content::function_response_json(name, response).with_role(Role::Function);
77 self.contents.push(content);
78 self
79 }
80
81 pub fn with_function_response_str(
83 mut self,
84 name: impl Into<String>,
85 response: impl Into<String>,
86 ) -> std::result::Result<Self, serde_json::Error> {
87 let response_str = response.into();
88 let json = serde_json::from_str(&response_str)?;
89 let content = Content::function_response_json(name, json).with_role(Role::Function);
90 self.contents.push(content);
91 Ok(self)
92 }
93
94 pub fn with_message(mut self, message: Message) -> Self {
96 let content = message.content.clone();
97 match &content.role {
98 Some(role) => {
99 let role_clone = role.clone();
100 self.contents.push(content.with_role(role_clone));
101 }
102 None => {
103 self.contents.push(content.with_role(message.role));
104 }
105 }
106 self
107 }
108
109 pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
111 for message in messages {
112 self = self.with_message(message);
113 }
114 self
115 }
116
117 pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
119 self.generation_config = Some(config);
120 self
121 }
122
123 pub fn with_temperature(mut self, temperature: f32) -> Self {
125 if self.generation_config.is_none() {
126 self.generation_config = Some(GenerationConfig::default());
127 }
128 if let Some(config) = &mut self.generation_config {
129 config.temperature = Some(temperature);
130 }
131 self
132 }
133
134 pub fn with_top_p(mut self, top_p: f32) -> Self {
136 if self.generation_config.is_none() {
137 self.generation_config = Some(GenerationConfig::default());
138 }
139 if let Some(config) = &mut self.generation_config {
140 config.top_p = Some(top_p);
141 }
142 self
143 }
144
145 pub fn with_top_k(mut self, top_k: i32) -> Self {
147 if self.generation_config.is_none() {
148 self.generation_config = Some(GenerationConfig::default());
149 }
150 if let Some(config) = &mut self.generation_config {
151 config.top_k = Some(top_k);
152 }
153 self
154 }
155
156 pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
158 if self.generation_config.is_none() {
159 self.generation_config = Some(GenerationConfig::default());
160 }
161 if let Some(config) = &mut self.generation_config {
162 config.max_output_tokens = Some(max_output_tokens);
163 }
164 self
165 }
166
167 pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
169 if self.generation_config.is_none() {
170 self.generation_config = Some(GenerationConfig::default());
171 }
172 if let Some(config) = &mut self.generation_config {
173 config.candidate_count = Some(candidate_count);
174 }
175 self
176 }
177
178 pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
180 if self.generation_config.is_none() {
181 self.generation_config = Some(GenerationConfig::default());
182 }
183 if let Some(config) = &mut self.generation_config {
184 config.stop_sequences = Some(stop_sequences);
185 }
186 self
187 }
188
189 pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
191 if self.generation_config.is_none() {
192 self.generation_config = Some(GenerationConfig::default());
193 }
194 if let Some(config) = &mut self.generation_config {
195 config.response_mime_type = Some(mime_type.into());
196 }
197 self
198 }
199
200 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
202 if self.generation_config.is_none() {
203 self.generation_config = Some(GenerationConfig::default());
204 }
205 if let Some(config) = &mut self.generation_config {
206 config.response_schema = Some(schema);
207 }
208 self
209 }
210
211 pub fn with_tool(mut self, tool: Tool) -> Self {
213 if self.tools.is_none() {
214 self.tools = Some(Vec::new());
215 }
216 if let Some(tools) = &mut self.tools {
217 tools.push(tool);
218 }
219 self
220 }
221
222 pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
224 let tool = Tool::new(function);
225 self = self.with_tool(tool);
226 self
227 }
228
229 pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
231 if self.tool_config.is_none() {
232 self.tool_config = Some(ToolConfig {
233 function_calling_config: Some(FunctionCallingConfig { mode }),
234 });
235 } else if let Some(tool_config) = &mut self.tool_config {
236 tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
237 }
238 self
239 }
240
241 pub async fn execute(self) -> Result<GenerationResponse> {
243 let request = GenerateContentRequest {
244 contents: self.contents,
245 generation_config: self.generation_config,
246 safety_settings: None,
247 tools: self.tools,
248 tool_config: self.tool_config,
249 system_instruction: self.system_instruction,
250 };
251
252 self.client.generate_content_raw(request).await
253 }
254
255 pub async fn execute_stream(
257 self,
258 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
259 let request = GenerateContentRequest {
260 contents: self.contents,
261 generation_config: self.generation_config,
262 safety_settings: None,
263 tools: self.tools,
264 tool_config: self.tool_config,
265 system_instruction: self.system_instruction,
266 };
267
268 self.client.generate_content_stream(request).await
269 }
270}
271
272struct GeminiClient {
274 http_client: Client,
275 api_key: String,
276 model: String,
277}
278
279impl GeminiClient {
280 fn new(api_key: impl Into<String>, model: String) -> Self {
282 Self {
283 http_client: Client::new(),
284 api_key: api_key.into(),
285 model,
286 }
287 }
288
289 async fn generate_content_raw(
291 &self,
292 request: GenerateContentRequest,
293 ) -> Result<GenerationResponse> {
294 let url = self.build_url("generateContent")?;
295
296 let response = self.http_client.post(url).json(&request).send().await?;
297
298 let status = response.status();
299 if !status.is_success() {
300 let error_text = response.text().await?;
301 return Err(Error::ApiError {
302 status_code: status.as_u16(),
303 message: error_text,
304 });
305 }
306
307 let response = response.json().await?;
308 Ok(response)
309 }
310
311 async fn generate_content_stream(
313 &self,
314 request: GenerateContentRequest,
315 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
316 let url = self.build_url("streamGenerateContent")?;
317
318 let response = self.http_client.post(url).json(&request).send().await?;
319
320 let status = response.status();
321 if !status.is_success() {
322 let error_text = response.text().await?;
323 return Err(Error::ApiError {
324 status_code: status.as_u16(),
325 message: error_text,
326 });
327 }
328
329 let stream = response
330 .bytes_stream()
331 .map(|result| {
332 match result {
333 Ok(bytes) => {
334 let text = String::from_utf8_lossy(&bytes);
335 let mut responses = Vec::new();
338 for line in text.lines() {
339 if let Some(json_str) = line.strip_prefix("data: ") {
340 if json_str == "[DONE]" {
341 continue;
342 }
343 match serde_json::from_str::<GenerationResponse>(json_str) {
344 Ok(response) => responses.push(Ok(response)),
345 Err(e) => responses.push(Err(Error::JsonError(e))),
346 }
347 }
348 }
349 futures::stream::iter(responses)
350 }
351 Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
352 }
353 })
354 .flatten();
355
356 Ok(Box::pin(stream))
357 }
358
359 fn build_url(&self, endpoint: &str) -> Result<Url> {
361 let url_str = format!(
364 "{}{}:{}?key={}",
365 BASE_URL, self.model, endpoint, self.api_key
366 );
367 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
368 }
369}
370
371#[derive(Clone)]
373pub struct Gemini {
374 client: Arc<GeminiClient>,
375}
376
377impl Gemini {
378 pub fn new(api_key: impl Into<String>) -> Self {
380 Self::with_model(api_key, DEFAULT_MODEL.to_string())
381 }
382
383 pub fn pro(api_key: impl Into<String>) -> Self {
385 Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
386 }
387
388 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
390 let client = GeminiClient::new(api_key, model);
391 Self {
392 client: Arc::new(client),
393 }
394 }
395
396 pub fn generate_content(&self) -> ContentBuilder {
398 ContentBuilder::new(self.client.clone())
399 }
400}