// agent_io/llm/google/mod.rs

1//! Google Gemini Chat Model implementation
2
3mod request;
4mod response;
5mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
/// Default base URL for the Gemini `generativelanguage` REST API
/// (v1beta `models` collection); overridable per-instance via `base_url`.
const GOOGLE_API_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
20
/// Google Gemini Chat Model
///
/// Configured through the derive_builder-generated `ChatGoogleBuilder`.
/// Note: `build_fn(skip)` means the generated `build` is replaced by the
/// hand-written `ChatGoogleBuilder::build`, so the `#[builder(default = …)]`
/// attributes below are effectively documentation — the real defaults are
/// applied in that hand-written `build`.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatGoogle {
    /// Model identifier (appended to the URL path, e.g. "gemini-1.5-pro").
    #[builder(setter(into))]
    pub(super) model: String,
    /// API key; embedded as a `key=` query parameter by `api_url`.
    pub(super) api_key: String,
    /// Base URL override; falls back to `GOOGLE_API_URL` when `None`.
    #[builder(setter(into, strip_option), default = "None")]
    pub(super) base_url: Option<String>,
    /// Maximum output tokens (default 8192, applied in `build`).
    #[builder(default = "8192")]
    pub(super) max_tokens: u64,
    /// Temperature for sampling (default 0.2, applied in `build`).
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Thinking budget (for thinking models); `None` leaves it unset.
    #[builder(default = "None")]
    pub(super) thinking_budget: Option<u64>,
    /// HTTP client — constructed internally, never set via the builder.
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context window in tokens, derived from the model name at build time.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
}
49
50impl ChatGoogle {
51    /// Create a new Google chat model
52    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
53        let api_key = std::env::var("GOOGLE_API_KEY")
54            .or_else(|_| std::env::var("GEMINI_API_KEY"))
55            .map_err(|_| LlmError::Config("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
56
57        Self::builder().model(model).api_key(api_key).build()
58    }
59
60    /// Create a builder for configuration
61    pub fn builder() -> ChatGoogleBuilder {
62        ChatGoogleBuilder::default()
63    }
64
65    /// Get the API URL for the model
66    fn api_url(&self, stream: bool) -> String {
67        let base = self.base_url.as_deref().unwrap_or(GOOGLE_API_URL);
68        let method = if stream {
69            "streamGenerateContent"
70        } else {
71            "generateContent"
72        };
73        format!("{}/{}:{}?key={}", base, self.model, method, self.api_key)
74    }
75
76    /// Build the HTTP client
77    fn build_client() -> Client {
78        Client::builder()
79            .timeout(Duration::from_secs(120))
80            .build()
81            .expect("Failed to create HTTP client")
82    }
83
84    /// Get context window for model
85    fn get_context_window(model: &str) -> u64 {
86        let model_lower = model.to_lowercase();
87
88        if model_lower.contains("gemini-1.5-pro") {
89            2_097_152 // 2M tokens
90        } else {
91            1_048_576 // 1M tokens - default for most Gemini models
92        }
93    }
94
95    /// Check if this is a thinking model
96    fn is_thinking_model(&self) -> bool {
97        let model_lower = self.model.to_lowercase();
98        model_lower.contains("gemini-2.5")
99            || model_lower.contains("thinking")
100            || model_lower.contains("gemini-exp")
101    }
102}
103
104impl ChatGoogleBuilder {
105    pub fn build(&self) -> Result<ChatGoogle, LlmError> {
106        let model = self
107            .model
108            .clone()
109            .ok_or_else(|| LlmError::Config("model is required".into()))?;
110        let api_key = self
111            .api_key
112            .clone()
113            .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
114
115        Ok(ChatGoogle {
116            context_window: ChatGoogle::get_context_window(&model),
117            client: ChatGoogle::build_client(),
118            model,
119            api_key,
120            base_url: self.base_url.clone().flatten(),
121            max_tokens: self.max_tokens.unwrap_or(8192),
122            temperature: self.temperature.unwrap_or(0.2),
123            thinking_budget: self.thinking_budget.flatten(),
124        })
125    }
126}
127
#[async_trait]
impl BaseChatModel for ChatGoogle {
    /// Model identifier this instance was configured with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider tag for this backend.
    fn provider(&self) -> &str {
        "google"
    }

    /// Context window in tokens (always known for this backend; derived
    /// from the model name at build time).
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Non-streaming completion via the `generateContent` endpoint.
    ///
    /// Builds the request body with `build_request` (see `request.rs`),
    /// POSTs it as JSON, and decodes the full response as a
    /// `GeminiResponse`. A non-success HTTP status becomes `LlmError::Api`
    /// carrying the status and raw response body.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(false))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body read for the error message; an unreadable
            // body degrades to "" rather than masking the status code.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        let completion: GeminiResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Streaming completion via the `streamGenerateContent` endpoint.
    ///
    /// Request construction and HTTP-error handling mirror `invoke`; on
    /// success the body is consumed incrementally and each raw chunk is
    /// handed to `parse_stream_chunk` (see `response.rs`).
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice)?;

        let response = self
            .client
            .post(self.api_url(true))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "Google API error ({}): {}",
                status, body
            )));
        }

        // Google returns JSON lines for streaming
        // NOTE(review): `bytes_stream` yields raw transport chunks whose
        // boundaries are not guaranteed to align with JSON object
        // boundaries — confirm `parse_stream_chunk` tolerates partial or
        // concatenated payloads. Also, the URL carries no `alt=sse`
        // parameter; verify the parser matches the endpoint's default
        // framing (a JSON array) rather than SSE.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    // Lossy UTF-8: invalid bytes become U+FFFD instead of
                    // aborting the stream.
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    fn supports_vision(&self) -> bool {
        // All Gemini models support vision
        // NOTE(review): asserted unconditionally for every model id —
        // confirm this holds for any text-only variants routed through
        // this type.
        true
    }
}
214}