// agent_io/llm/google.rs
1//! Google Gemini Chat Model implementation
2
3use async_trait::async_trait;
4use derive_builder::Builder;
5use futures::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10use crate::llm::{
11    BaseChatModel, ChatCompletion, ChatStream, ContentPart, LlmError, Message, StopReason,
12    ToolChoice, ToolDefinition, Usage,
13};
14
/// Default Gemini REST endpoint; `api_url` appends `/{model}:{method}` and the
/// API key, and `ChatGoogle::base_url` can override this host entirely.
const GOOGLE_API_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
16
/// Google Gemini Chat Model.
///
/// Construct with [`ChatGoogle::new`] (reads the API key from the
/// environment) or configure via [`ChatGoogle::builder`].
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatGoogle {
    /// Model identifier (e.g. "gemini-1.5-pro"); appended to the URL path.
    #[builder(setter(into))]
    model: String,
    /// API key — sent as the `key` query parameter (see `api_url`).
    api_key: String,
    /// Optional base-URL override; `GOOGLE_API_URL` is used when `None`.
    #[builder(setter(into, strip_option), default = "None")]
    base_url: Option<String>,
    /// Maximum output tokens for a request (defaults to 8192).
    #[builder(default = "8192")]
    max_tokens: u64,
    /// Sampling temperature (defaults to 0.2).
    #[builder(default = "0.2")]
    temperature: f32,
    /// Thinking budget in tokens — only sent when `is_thinking_model()` is
    /// true (see `build_request`).
    #[builder(default = "None")]
    thinking_budget: Option<u64>,
    /// HTTP client, created in `ChatGoogleBuilder::build` (120 s timeout).
    #[builder(setter(skip))]
    client: Client,
    /// Context window in tokens, derived from the model name at build time.
    #[builder(setter(skip))]
    context_window: u64,
}
45
46impl ChatGoogle {
47    /// Create a new Google chat model
48    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
49        let api_key = std::env::var("GOOGLE_API_KEY")
50            .or_else(|_| std::env::var("GEMINI_API_KEY"))
51            .map_err(|_| LlmError::Config("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
52
53        Self::builder().model(model).api_key(api_key).build()
54    }
55
56    /// Create a builder for configuration
57    pub fn builder() -> ChatGoogleBuilder {
58        ChatGoogleBuilder::default()
59    }
60
61    /// Get the API URL for the model
62    fn api_url(&self, stream: bool) -> String {
63        let base = self.base_url.as_deref().unwrap_or(GOOGLE_API_URL);
64        let method = if stream {
65            "streamGenerateContent"
66        } else {
67            "generateContent"
68        };
69        format!("{}/{}:{}?key={}", base, self.model, method, self.api_key)
70    }
71
72    /// Build the HTTP client
73    fn build_client() -> Client {
74        Client::builder()
75            .timeout(Duration::from_secs(120))
76            .build()
77            .expect("Failed to create HTTP client")
78    }
79
80    /// Get context window for model
81    fn get_context_window(model: &str) -> u64 {
82        let model_lower = model.to_lowercase();
83
84        if model_lower.contains("gemini-1.5-pro") {
85            2_097_152 // 2M tokens
86        } else {
87            1_048_576 // 1M tokens - default for most Gemini models
88        }
89    }
90
91    /// Check if this is a thinking model
92    fn is_thinking_model(&self) -> bool {
93        let model_lower = self.model.to_lowercase();
94        model_lower.contains("gemini-2.5")
95            || model_lower.contains("thinking")
96            || model_lower.contains("gemini-exp")
97    }
98}
99
100impl ChatGoogleBuilder {
101    pub fn build(&self) -> Result<ChatGoogle, LlmError> {
102        let model = self
103            .model
104            .clone()
105            .ok_or_else(|| LlmError::Config("model is required".into()))?;
106        let api_key = self
107            .api_key
108            .clone()
109            .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
110
111        Ok(ChatGoogle {
112            context_window: ChatGoogle::get_context_window(&model),
113            client: ChatGoogle::build_client(),
114            model,
115            api_key,
116            base_url: self.base_url.clone().flatten(),
117            max_tokens: self.max_tokens.unwrap_or(8192),
118            temperature: self.temperature.unwrap_or(0.2),
119            thinking_budget: self.thinking_budget.flatten(),
120        })
121    }
122}
123
#[async_trait]
impl BaseChatModel for ChatGoogle {
    /// Model identifier supplied at construction time.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider tag for this backend.
    fn provider(&self) -> &str {
        "google"
    }

    /// Context window derived from the model name in the builder.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Send a non-streaming `generateContent` request and parse the reply.
    ///
    /// Non-2xx responses are surfaced as `LlmError::Api` carrying the status
    /// and raw body. The API key travels in the URL query (see `api_url`),
    /// so this URL should not be logged verbatim.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(false))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Body read is best-effort; an unreadable body still reports status.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        let completion: GeminiResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Send a streaming `streamGenerateContent` request.
    ///
    /// Each transport chunk is scanned for complete JSON objects by
    /// `parse_stream_chunk`; chunks yielding nothing are filtered out.
    /// NOTE(review): a JSON object split across two transport chunks is
    /// dropped rather than re-assembled — confirm callers tolerate this.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(true))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        // Google streams a JSON array of response objects, not SSE; each raw
        // bytes chunk is decoded lossily and parsed for complete objects.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    fn supports_vision(&self) -> bool {
        // All Gemini models support vision
        true
    }
}
211
212// =============================================================================
213// Request/Response Types
214// =============================================================================
215
216#[derive(Serialize)]
217struct GeminiRequest {
218    contents: Vec<GeminiContent>,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    system_instruction: Option<GeminiContent>,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    tools: Option<GeminiTools>,
223    #[serde(skip_serializing_if = "Option::is_none")]
224    generation_config: Option<GeminiGenerationConfig>,
225}
226
227#[derive(Serialize)]
228struct GeminiContent {
229    role: String,
230    parts: Vec<GeminiPart>,
231}
232
233#[derive(Serialize)]
234#[serde(untagged)]
235enum GeminiPart {
236    Text {
237        text: String,
238    },
239    InlineData {
240        inline_data: GeminiInlineData,
241    },
242    FunctionCall {
243        function_call: GeminiFunctionCall,
244    },
245    FunctionResponse {
246        function_response: GeminiFunctionResponse,
247    },
248    Thought {
249        thought: String,
250    },
251}
252
253#[derive(Serialize)]
254struct GeminiInlineData {
255    mime_type: String,
256    data: String,
257}
258
259#[derive(Serialize)]
260struct GeminiFunctionCall {
261    name: String,
262    args: serde_json::Value,
263}
264
265#[derive(Serialize)]
266struct GeminiFunctionResponse {
267    name: String,
268    response: GeminiToolResult,
269}
270
271#[derive(Serialize)]
272struct GeminiToolResult {
273    name: String,
274    content: String,
275}
276
277#[derive(Serialize)]
278struct GeminiTools {
279    function_declarations: Vec<GeminiFunctionDeclaration>,
280}
281
282#[derive(Serialize)]
283struct GeminiFunctionDeclaration {
284    name: String,
285    description: String,
286    parameters: serde_json::Map<String, serde_json::Value>,
287}
288
289#[derive(Serialize)]
290struct GeminiGenerationConfig {
291    temperature: f32,
292    max_output_tokens: u64,
293    #[serde(skip_serializing_if = "Option::is_none")]
294    thinking_config: Option<GeminiThinkingConfig>,
295}
296
297#[derive(Serialize)]
298struct GeminiThinkingConfig {
299    thinking_budget: u64,
300}
301
302#[derive(Deserialize)]
303struct GeminiResponse {
304    candidates: Vec<GeminiCandidate>,
305    usage_metadata: Option<GeminiUsage>,
306}
307
308#[derive(Deserialize)]
309struct GeminiCandidate {
310    content: GeminiResponseContent,
311    finish_reason: Option<String>,
312}
313
314#[derive(Deserialize)]
315struct GeminiResponseContent {
316    parts: Vec<GeminiResponsePart>,
317}
318
319#[derive(Deserialize)]
320#[serde(untagged)]
321enum GeminiResponsePart {
322    Text {
323        text: String,
324    },
325    Thought {
326        thought: String,
327    },
328    FunctionCall {
329        function_call: GeminiFunctionCallResponse,
330    },
331}
332
333#[derive(Deserialize)]
334struct GeminiFunctionCallResponse {
335    name: String,
336    args: serde_json::Value,
337    #[serde(default)]
338    id: Option<String>,
339}
340
341#[derive(Deserialize)]
342struct GeminiUsage {
343    prompt_token_count: u64,
344    candidates_token_count: u64,
345    total_token_count: u64,
346    #[serde(default)]
347    cached_content_token_count: u64,
348}
349
350impl ChatGoogle {
351    fn build_request(
352        &self,
353        messages: Vec<Message>,
354        tools: Option<Vec<ToolDefinition>>,
355        _tool_choice: Option<ToolChoice>,
356    ) -> Result<GeminiRequest, LlmError> {
357        let mut system_instruction: Option<GeminiContent> = None;
358        let mut contents: Vec<GeminiContent> = Vec::new();
359
360        for message in messages {
361            match message {
362                Message::System(s) => {
363                    system_instruction = Some(GeminiContent {
364                        role: "user".to_string(),
365                        parts: vec![GeminiPart::Text { text: s.content }],
366                    });
367                }
368                Message::User(u) => {
369                    let parts: Vec<GeminiPart> = u
370                        .content
371                        .into_iter()
372                        .map(|c| match c {
373                            ContentPart::Text(t) => GeminiPart::Text { text: t.text },
374                            ContentPart::Image(img) => {
375                                let (mime_type, data) = if img.image_url.url.starts_with("data:") {
376                                    let parts: Vec<&str> =
377                                        img.image_url.url.splitn(2, ',').collect();
378                                    let mime = parts[0]
379                                        .strip_prefix("data:")
380                                        .and_then(|s| s.strip_suffix(";base64"))
381                                        .unwrap_or("image/png");
382                                    (mime.to_string(), parts.get(1).unwrap_or(&"").to_string())
383                                } else {
384                                    ("image/png".to_string(), img.image_url.url.clone())
385                                };
386                                GeminiPart::InlineData {
387                                    inline_data: GeminiInlineData { mime_type, data },
388                                }
389                            }
390                            _ => GeminiPart::Text {
391                                text: "[Unsupported content]".to_string(),
392                            },
393                        })
394                        .collect();
395
396                    contents.push(GeminiContent {
397                        role: "user".to_string(),
398                        parts,
399                    });
400                }
401                Message::Assistant(a) => {
402                    let mut parts = Vec::new();
403
404                    if let Some(t) = a.thinking {
405                        parts.push(GeminiPart::Thought { thought: t });
406                    }
407
408                    if let Some(c) = a.content {
409                        parts.push(GeminiPart::Text { text: c });
410                    }
411
412                    for tc in a.tool_calls {
413                        let args: serde_json::Value = serde_json::from_str(&tc.function.arguments)
414                            .unwrap_or(serde_json::json!({}));
415                        parts.push(GeminiPart::FunctionCall {
416                            function_call: GeminiFunctionCall {
417                                name: tc.function.name,
418                                args,
419                            },
420                        });
421                    }
422
423                    contents.push(GeminiContent {
424                        role: "model".to_string(),
425                        parts,
426                    });
427                }
428                Message::Tool(t) => {
429                    contents.push(GeminiContent {
430                        role: "user".to_string(),
431                        parts: vec![GeminiPart::FunctionResponse {
432                            function_response: GeminiFunctionResponse {
433                                name: "function_result".to_string(),
434                                response: GeminiToolResult {
435                                    name: "result".to_string(),
436                                    content: t.content,
437                                },
438                            },
439                        }],
440                    });
441                }
442                Message::Developer(d) => {
443                    system_instruction = Some(GeminiContent {
444                        role: "user".to_string(),
445                        parts: vec![GeminiPart::Text { text: d.content }],
446                    });
447                }
448            }
449        }
450
451        let gemini_tools = tools.map(|ts| GeminiTools {
452            function_declarations: ts
453                .into_iter()
454                .map(|t| GeminiFunctionDeclaration {
455                    name: t.name,
456                    description: t.description,
457                    parameters: t.parameters,
458                })
459                .collect(),
460        });
461
462        let thinking_config = if self.is_thinking_model() {
463            self.thinking_budget.map(|budget| GeminiThinkingConfig {
464                thinking_budget: budget,
465            })
466        } else {
467            None
468        };
469
470        Ok(GeminiRequest {
471            contents,
472            system_instruction,
473            tools: gemini_tools,
474            generation_config: Some(GeminiGenerationConfig {
475                temperature: self.temperature,
476                max_output_tokens: self.max_tokens,
477                thinking_config,
478            }),
479        })
480    }
481
482    fn parse_response(&self, response: GeminiResponse) -> ChatCompletion {
483        let stop_reason = response
484            .candidates
485            .first()
486            .and_then(|c| c.finish_reason.as_ref())
487            .and_then(|r| match r.as_str() {
488                "STOP" => Some(StopReason::EndTurn),
489                "MAX_TOKENS" => Some(StopReason::MaxTokens),
490                "TOOL_CODE" => Some(StopReason::ToolUse),
491                _ => None,
492            });
493
494        let candidate = response.candidates.into_iter().next();
495
496        let (content, thinking, tool_calls) = candidate
497            .map(|c| {
498                let mut text: Option<String> = None;
499                let mut think: Option<String> = None;
500                let mut calls = Vec::new();
501
502                for part in c.content.parts {
503                    match part {
504                        GeminiResponsePart::Text { text: t } => {
505                            text = Some(t);
506                        }
507                        GeminiResponsePart::Thought { thought: t } => {
508                            think = Some(t);
509                        }
510                        GeminiResponsePart::FunctionCall { function_call: fc } => {
511                            let id = fc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
512                            calls.push(crate::llm::ToolCall::new(
513                                id,
514                                fc.name,
515                                serde_json::to_string(&fc.args).unwrap_or_default(),
516                            ));
517                        }
518                    }
519                }
520
521                (text, think, calls)
522            })
523            .unwrap_or((None, None, Vec::new()));
524
525        let usage = response.usage_metadata.map(|u| Usage {
526            prompt_tokens: u.prompt_token_count,
527            completion_tokens: u.candidates_token_count,
528            total_tokens: u.total_token_count,
529            prompt_cached_tokens: Some(u.cached_content_token_count),
530            ..Default::default()
531        });
532
533        ChatCompletion {
534            content,
535            thinking,
536            redacted_thinking: None,
537            tool_calls,
538            usage,
539            stop_reason,
540        }
541    }
542
543    fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
544        // Google returns JSON array chunks
545        for line in text.lines() {
546            let line = line.trim();
547            if line.is_empty() {
548                continue;
549            }
550
551            // Handle array wrapping
552            let line = line.trim_start_matches('[').trim_end_matches(']');
553            if line.is_empty() {
554                continue;
555            }
556
557            // Handle comma-separated chunks
558            for chunk_str in line.split("},") {
559                let chunk_str = if !chunk_str.ends_with('}') {
560                    format!("{}{}", chunk_str, "}")
561                } else {
562                    chunk_str.to_string()
563                };
564
565                let chunk: serde_json::Value = match serde_json::from_str(&chunk_str) {
566                    Ok(v) => v,
567                    Err(_) => continue,
568                };
569
570                let parts = chunk
571                    .get("candidates")?
572                    .as_array()?
573                    .first()?
574                    .get("content")?
575                    .get("parts")?
576                    .as_array()?;
577
578                for part in parts {
579                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
580                        return Some(Ok(ChatCompletion::text(text)));
581                    }
582
583                    if let Some(thought) = part.get("thought").and_then(|t| t.as_str()) {
584                        let mut completion = ChatCompletion::text("");
585                        completion.thinking = Some(thought.to_string());
586                        return Some(Ok(completion));
587                    }
588
589                    if let Some(fc) = part.get("function_call") {
590                        let name = fc.get("name")?.as_str()?.to_string();
591                        let args = fc.get("args").cloned().unwrap_or(serde_json::json!({}));
592                        let id = fc.get("id").and_then(|i| i.as_str()).unwrap_or("pending");
593
594                        return Some(Ok(ChatCompletion {
595                            content: None,
596                            thinking: None,
597                            redacted_thinking: None,
598                            tool_calls: vec![crate::llm::ToolCall::new(
599                                id,
600                                name,
601                                serde_json::to_string(&args).unwrap_or_default(),
602                            )],
603                            usage: None,
604                            stop_reason: Some(StopReason::ToolUse),
605                        }));
606                    }
607                }
608            }
609        }
610
611        None
612    }
613}