1use async_trait::async_trait;
4use derive_builder::Builder;
5use futures::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10use crate::llm::{
11 BaseChatModel, ChatCompletion, ChatStream, ContentPart, LlmError, Message, StopReason,
12 ToolChoice, ToolDefinition, Usage,
13};
14
/// Default base URL for the Gemini REST API (v1beta `models` collection).
const GOOGLE_API_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
16
/// Chat client for Google's Gemini `generateContent` REST API.
///
/// Construct via [`ChatGoogle::new`] (reads the API key from the environment)
/// or [`ChatGoogle::builder`] for full control. The `build_fn(skip)` attribute
/// means `ChatGoogleBuilder::build` below is hand-written, not derived.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatGoogle {
    // Model identifier, e.g. "gemini-1.5-pro".
    #[builder(setter(into))]
    model: String,
    // API key appended to the URL as a `key` query parameter (see `api_url`).
    api_key: String,
    // Optional override of GOOGLE_API_URL (e.g. for proxies or mocks).
    #[builder(setter(into, strip_option), default = "None")]
    base_url: Option<String>,
    // Upper bound on generated tokens; sent as `max_output_tokens`.
    #[builder(default = "8192")]
    max_tokens: u64,
    // Sampling temperature; low default favors deterministic output.
    #[builder(default = "0.2")]
    temperature: f32,
    // Token budget for model "thinking"; only sent for thinking-capable
    // models (see `is_thinking_model`).
    #[builder(default = "None")]
    thinking_budget: Option<u64>,
    // Derived in `ChatGoogleBuilder::build`; not settable by callers.
    #[builder(setter(skip))]
    client: Client,
    // Derived from the model name in `ChatGoogleBuilder::build`.
    #[builder(setter(skip))]
    context_window: u64,
}
45
46impl ChatGoogle {
47 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
49 let api_key = std::env::var("GOOGLE_API_KEY")
50 .or_else(|_| std::env::var("GEMINI_API_KEY"))
51 .map_err(|_| LlmError::Config("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
52
53 Self::builder().model(model).api_key(api_key).build()
54 }
55
56 pub fn builder() -> ChatGoogleBuilder {
58 ChatGoogleBuilder::default()
59 }
60
61 fn api_url(&self, stream: bool) -> String {
63 let base = self.base_url.as_deref().unwrap_or(GOOGLE_API_URL);
64 let method = if stream {
65 "streamGenerateContent"
66 } else {
67 "generateContent"
68 };
69 format!("{}/{}:{}?key={}", base, self.model, method, self.api_key)
70 }
71
72 fn build_client() -> Client {
74 Client::builder()
75 .timeout(Duration::from_secs(120))
76 .build()
77 .expect("Failed to create HTTP client")
78 }
79
80 fn get_context_window(model: &str) -> u64 {
82 let model_lower = model.to_lowercase();
83
84 if model_lower.contains("gemini-1.5-pro") {
85 2_097_152 } else {
87 1_048_576 }
89 }
90
91 fn is_thinking_model(&self) -> bool {
93 let model_lower = self.model.to_lowercase();
94 model_lower.contains("gemini-2.5")
95 || model_lower.contains("thinking")
96 || model_lower.contains("gemini-exp")
97 }
98}
99
100impl ChatGoogleBuilder {
101 pub fn build(&self) -> Result<ChatGoogle, LlmError> {
102 let model = self
103 .model
104 .clone()
105 .ok_or_else(|| LlmError::Config("model is required".into()))?;
106 let api_key = self
107 .api_key
108 .clone()
109 .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
110
111 Ok(ChatGoogle {
112 context_window: ChatGoogle::get_context_window(&model),
113 client: ChatGoogle::build_client(),
114 model,
115 api_key,
116 base_url: self.base_url.clone().flatten(),
117 max_tokens: self.max_tokens.unwrap_or(8192),
118 temperature: self.temperature.unwrap_or(0.2),
119 thinking_budget: self.thinking_budget.flatten(),
120 })
121 }
122}
123
#[async_trait]
impl BaseChatModel for ChatGoogle {
    /// Model identifier this client was configured with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider tag for routing/telemetry.
    fn provider(&self) -> &str {
        "google"
    }

    /// Context window derived from the model name at build time.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// One-shot (non-streaming) completion via `generateContent`.
    ///
    /// Non-2xx responses surface as `LlmError::Api` with the raw body for
    /// debugging; transport and JSON errors propagate via `?`.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(false))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body read so the error carries the API's
            // explanation; an unreadable body becomes "".
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        let completion: GeminiResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Streaming completion via `streamGenerateContent`.
    ///
    /// Each network chunk is handed to `parse_stream_chunk`; chunks that do
    /// not yield a parsable candidate are silently dropped from the stream.
    ///
    /// NOTE(review): `parse_stream_chunk` is stateless, so a JSON object
    /// split across two network reads cannot be reassembled and is lost —
    /// confirm whether a carry-over buffer is needed in practice.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(true))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        // filter_map drops chunks that produce None (no parsable payload).
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    // Lossy conversion: invalid UTF-8 bytes become U+FFFD
                    // rather than failing the whole stream.
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    /// Image input is advertised for every model id.
    // NOTE(review): returned unconditionally — verify for any text-only
    // model ids this client may be configured with.
    fn supports_vision(&self) -> bool {
        true
    }
}
211
/// Request body for `generateContent` / `streamGenerateContent`.
// NOTE(review): fields serialize with snake_case keys; the endpoint is
// expected to accept proto field names as well as lowerCamelCase — confirm
// against the Gemini REST reference before relying on this.
#[derive(Serialize)]
struct GeminiRequest {
    // Ordered conversation turns (user/model alternation).
    contents: Vec<GeminiContent>,
    // Single system prompt; Gemini takes it outside `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<GeminiTools>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
}
226
/// One conversation turn: a role ("user" or "model") plus its parts.
#[derive(Serialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}
232
/// One part of a request turn.
///
/// Serialized untagged, so only the present variant's field key is emitted
/// (e.g. `{"text": ...}` or `{"inline_data": ...}`).
#[derive(Serialize)]
#[serde(untagged)]
enum GeminiPart {
    Text {
        text: String,
    },
    // Base64 image payload (see the data-URL handling in `build_request`).
    InlineData {
        inline_data: GeminiInlineData,
    },
    // A tool call the assistant made earlier, replayed for context.
    FunctionCall {
        function_call: GeminiFunctionCall,
    },
    // A tool's result, sent back to the model.
    FunctionResponse {
        function_response: GeminiFunctionResponse,
    },
    // NOTE(review): replayed assistant thinking is sent as
    // `{"thought": "<string>"}`; current Gemini docs model thought as a
    // boolean flag on a text part — confirm the API accepts this shape.
    Thought {
        thought: String,
    },
}
252
/// Inline binary payload: MIME type plus base64-encoded data.
#[derive(Serialize)]
struct GeminiInlineData {
    mime_type: String,
    data: String,
}
258
/// A function call emitted by the model, replayed in assistant turns.
#[derive(Serialize)]
struct GeminiFunctionCall {
    name: String,
    // Arbitrary JSON arguments, parsed from the stored tool-call string.
    args: serde_json::Value,
}
264
/// Wrapper for a tool result sent back to the model.
#[derive(Serialize)]
struct GeminiFunctionResponse {
    name: String,
    response: GeminiToolResult,
}
270
/// Payload of a function response: a name plus the raw result text.
#[derive(Serialize)]
struct GeminiToolResult {
    name: String,
    content: String,
}
276
/// Tool container; Gemini nests declarations one level deeper than OpenAI.
#[derive(Serialize)]
struct GeminiTools {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
281
/// One callable function: name, description, and a JSON-schema parameter map.
#[derive(Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Map<String, serde_json::Value>,
}
288
/// Sampling/limit settings attached to every request.
#[derive(Serialize)]
struct GeminiGenerationConfig {
    temperature: f32,
    max_output_tokens: u64,
    // Only populated for thinking-capable models with a configured budget.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GeminiThinkingConfig>,
}
296
/// Token budget the model may spend on internal "thinking".
#[derive(Serialize)]
struct GeminiThinkingConfig {
    thinking_budget: u64,
}
301
302#[derive(Deserialize)]
303struct GeminiResponse {
304 candidates: Vec<GeminiCandidate>,
305 usage_metadata: Option<GeminiUsage>,
306}
307
308#[derive(Deserialize)]
309struct GeminiCandidate {
310 content: GeminiResponseContent,
311 finish_reason: Option<String>,
312}
313
/// Content of a candidate: the ordered list of response parts.
#[derive(Deserialize)]
struct GeminiResponseContent {
    parts: Vec<GeminiResponsePart>,
}
318
319#[derive(Deserialize)]
320#[serde(untagged)]
321enum GeminiResponsePart {
322 Text {
323 text: String,
324 },
325 Thought {
326 thought: String,
327 },
328 FunctionCall {
329 function_call: GeminiFunctionCallResponse,
330 },
331}
332
/// A `functionCall` payload returned by the model.
/// (Single-word keys are identical in snake_case and camelCase.)
#[derive(Deserialize)]
struct GeminiFunctionCallResponse {
    name: String,
    args: serde_json::Value,
    // Not always present; `parse_response` synthesizes a UUID when missing.
    #[serde(default)]
    id: Option<String>,
}
340
341#[derive(Deserialize)]
342struct GeminiUsage {
343 prompt_token_count: u64,
344 candidates_token_count: u64,
345 total_token_count: u64,
346 #[serde(default)]
347 cached_content_token_count: u64,
348}
349
350impl ChatGoogle {
    /// Translates provider-agnostic messages/tools into the Gemini wire
    /// format.
    ///
    /// Mapping rules: System/Developer messages become the single
    /// `system_instruction` (the LAST one seen wins — earlier ones are
    /// overwritten); User content becomes "user" turns; Assistant turns
    /// become "model" turns with thinking, text, and replayed tool calls;
    /// Tool results become "user" turns wrapping a function response.
    ///
    /// NOTE(review): `_tool_choice` is accepted but ignored — Gemini's
    /// `tool_config.function_calling_config` is never populated.
    fn build_request(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        _tool_choice: Option<ToolChoice>,
    ) -> Result<GeminiRequest, LlmError> {
        let mut system_instruction: Option<GeminiContent> = None;
        let mut contents: Vec<GeminiContent> = Vec::new();

        for message in messages {
            match message {
                Message::System(s) => {
                    // Role is nominal; Gemini keys the system prompt by
                    // position (`system_instruction`), not by role.
                    system_instruction = Some(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart::Text { text: s.content }],
                    });
                }
                Message::User(u) => {
                    let parts: Vec<GeminiPart> = u
                        .content
                        .into_iter()
                        .map(|c| match c {
                            ContentPart::Text(t) => GeminiPart::Text { text: t.text },
                            ContentPart::Image(img) => {
                                // Split a data URL ("data:<mime>;base64,<payload>")
                                // into its MIME type and base64 payload.
                                let (mime_type, data) = if img.image_url.url.starts_with("data:") {
                                    let parts: Vec<&str> =
                                        img.image_url.url.splitn(2, ',').collect();
                                    let mime = parts[0]
                                        .strip_prefix("data:")
                                        .and_then(|s| s.strip_suffix(";base64"))
                                        .unwrap_or("image/png");
                                    (mime.to_string(), parts.get(1).unwrap_or(&"").to_string())
                                } else {
                                    // NOTE(review): a non-data URL is passed
                                    // through as if it were base64 PNG data —
                                    // confirm callers pre-encode images.
                                    ("image/png".to_string(), img.image_url.url.clone())
                                };
                                GeminiPart::InlineData {
                                    inline_data: GeminiInlineData { mime_type, data },
                                }
                            }
                            // Audio/other content types degrade to a marker.
                            _ => GeminiPart::Text {
                                text: "[Unsupported content]".to_string(),
                            },
                        })
                        .collect();

                    contents.push(GeminiContent {
                        role: "user".to_string(),
                        parts,
                    });
                }
                Message::Assistant(a) => {
                    let mut parts = Vec::new();

                    if let Some(t) = a.thinking {
                        parts.push(GeminiPart::Thought { thought: t });
                    }

                    if let Some(c) = a.content {
                        parts.push(GeminiPart::Text { text: c });
                    }

                    for tc in a.tool_calls {
                        // Stored arguments are a JSON string; fall back to {}
                        // if they fail to parse rather than erroring the turn.
                        let args: serde_json::Value = serde_json::from_str(&tc.function.arguments)
                            .unwrap_or(serde_json::json!({}));
                        parts.push(GeminiPart::FunctionCall {
                            function_call: GeminiFunctionCall {
                                name: tc.function.name,
                                args,
                            },
                        });
                    }

                    contents.push(GeminiContent {
                        role: "model".to_string(),
                        parts,
                    });
                }
                Message::Tool(t) => {
                    // NOTE(review): Gemini matches function responses to calls
                    // by name, but these names are hardcoded placeholders —
                    // confirm whether Message::Tool carries the original call
                    // name and thread it through.
                    contents.push(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart::FunctionResponse {
                            function_response: GeminiFunctionResponse {
                                name: "function_result".to_string(),
                                response: GeminiToolResult {
                                    name: "result".to_string(),
                                    content: t.content,
                                },
                            },
                        }],
                    });
                }
                Message::Developer(d) => {
                    // Treated identically to System; also overwrites any
                    // earlier system instruction.
                    system_instruction = Some(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart::Text { text: d.content }],
                    });
                }
            }
        }

        let gemini_tools = tools.map(|ts| GeminiTools {
            function_declarations: ts
                .into_iter()
                .map(|t| GeminiFunctionDeclaration {
                    name: t.name,
                    description: t.description,
                    parameters: t.parameters,
                })
                .collect(),
        });

        // Thinking budget is only forwarded to models that understand it;
        // sending it to others would be rejected.
        let thinking_config = if self.is_thinking_model() {
            self.thinking_budget.map(|budget| GeminiThinkingConfig {
                thinking_budget: budget,
            })
        } else {
            None
        };

        Ok(GeminiRequest {
            contents,
            system_instruction,
            tools: gemini_tools,
            generation_config: Some(GeminiGenerationConfig {
                temperature: self.temperature,
                max_output_tokens: self.max_tokens,
                thinking_config,
            }),
        })
    }
481
482 fn parse_response(&self, response: GeminiResponse) -> ChatCompletion {
483 let stop_reason = response
484 .candidates
485 .first()
486 .and_then(|c| c.finish_reason.as_ref())
487 .and_then(|r| match r.as_str() {
488 "STOP" => Some(StopReason::EndTurn),
489 "MAX_TOKENS" => Some(StopReason::MaxTokens),
490 "TOOL_CODE" => Some(StopReason::ToolUse),
491 _ => None,
492 });
493
494 let candidate = response.candidates.into_iter().next();
495
496 let (content, thinking, tool_calls) = candidate
497 .map(|c| {
498 let mut text: Option<String> = None;
499 let mut think: Option<String> = None;
500 let mut calls = Vec::new();
501
502 for part in c.content.parts {
503 match part {
504 GeminiResponsePart::Text { text: t } => {
505 text = Some(t);
506 }
507 GeminiResponsePart::Thought { thought: t } => {
508 think = Some(t);
509 }
510 GeminiResponsePart::FunctionCall { function_call: fc } => {
511 let id = fc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
512 calls.push(crate::llm::ToolCall::new(
513 id,
514 fc.name,
515 serde_json::to_string(&fc.args).unwrap_or_default(),
516 ));
517 }
518 }
519 }
520
521 (text, think, calls)
522 })
523 .unwrap_or((None, None, Vec::new()));
524
525 let usage = response.usage_metadata.map(|u| Usage {
526 prompt_tokens: u.prompt_token_count,
527 completion_tokens: u.candidates_token_count,
528 total_tokens: u.total_token_count,
529 prompt_cached_tokens: Some(u.cached_content_token_count),
530 ..Default::default()
531 });
532
533 ChatCompletion {
534 content,
535 thinking,
536 redacted_thinking: None,
537 tool_calls,
538 usage,
539 stop_reason,
540 }
541 }
542
543 fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
544 for line in text.lines() {
546 let line = line.trim();
547 if line.is_empty() {
548 continue;
549 }
550
551 let line = line.trim_start_matches('[').trim_end_matches(']');
553 if line.is_empty() {
554 continue;
555 }
556
557 for chunk_str in line.split("},") {
559 let chunk_str = if !chunk_str.ends_with('}') {
560 format!("{}{}", chunk_str, "}")
561 } else {
562 chunk_str.to_string()
563 };
564
565 let chunk: serde_json::Value = match serde_json::from_str(&chunk_str) {
566 Ok(v) => v,
567 Err(_) => continue,
568 };
569
570 let parts = chunk
571 .get("candidates")?
572 .as_array()?
573 .first()?
574 .get("content")?
575 .get("parts")?
576 .as_array()?;
577
578 for part in parts {
579 if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
580 return Some(Ok(ChatCompletion::text(text)));
581 }
582
583 if let Some(thought) = part.get("thought").and_then(|t| t.as_str()) {
584 let mut completion = ChatCompletion::text("");
585 completion.thinking = Some(thought.to_string());
586 return Some(Ok(completion));
587 }
588
589 if let Some(fc) = part.get("function_call") {
590 let name = fc.get("name")?.as_str()?.to_string();
591 let args = fc.get("args").cloned().unwrap_or(serde_json::json!({}));
592 let id = fc.get("id").and_then(|i| i.as_str()).unwrap_or("pending");
593
594 return Some(Ok(ChatCompletion {
595 content: None,
596 thinking: None,
597 redacted_thinking: None,
598 tool_calls: vec![crate::llm::ToolCall::new(
599 id,
600 name,
601 serde_json::to_string(&args).unwrap_or_default(),
602 )],
603 usage: None,
604 stop_reason: Some(StopReason::ToolUse),
605 }));
606 }
607 }
608 }
609 }
610
611 None
612 }
613}