Skip to main content

codetether_agent/provider/
google.rs

1//! Google Gemini provider implementation
2//!
3//! Uses the Google AI Gemini OpenAI-compatible endpoint for simplicity.
4//! Reference: https://ai.google.dev/gemini-api/docs/openai
5
6use super::{
7    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
8    Role, StreamChunk, ToolDefinition, Usage,
9};
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::{Value, json};
15
/// Base URL of Google's OpenAI-compatible Gemini endpoint (see module docs).
const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
17
/// Provider for Google Gemini models, speaking the OpenAI-compatible API.
pub struct GoogleProvider {
    // Shared HTTP client; reused across requests for connection pooling.
    client: Client,
    // Google AI Studio API key; redacted in the manual Debug impl.
    api_key: String,
}
22
23impl std::fmt::Debug for GoogleProvider {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("GoogleProvider")
26            .field("api_key", &"<REDACTED>")
27            .field("api_key_len", &self.api_key.len())
28            .finish()
29    }
30}
31
32impl GoogleProvider {
33    pub fn new(api_key: String) -> Result<Self> {
34        tracing::debug!(
35            provider = "google",
36            api_key_len = api_key.len(),
37            "Creating Google Gemini provider"
38        );
39        Ok(Self {
40            client: Client::new(),
41            api_key,
42        })
43    }
44
45    fn validate_api_key(&self) -> Result<()> {
46        if self.api_key.is_empty() {
47            anyhow::bail!("Google API key is empty");
48        }
49        Ok(())
50    }
51
52    fn convert_messages(messages: &[Message]) -> Vec<Value> {
53        messages
54            .iter()
55            .map(|msg| {
56                let role = match msg.role {
57                    Role::System => "system",
58                    Role::User => "user",
59                    Role::Assistant => "assistant",
60                    Role::Tool => "tool",
61                };
62
63                // For tool messages, we need to produce one message per tool result
64                if msg.role == Role::Tool {
65                    let mut content_parts: Vec<Value> = Vec::new();
66                    let mut tool_call_id = None;
67                    for part in &msg.content {
68                        match part {
69                            ContentPart::ToolResult {
70                                tool_call_id: id,
71                                content,
72                            } => {
73                                tool_call_id = Some(id.clone());
74                                content_parts.push(json!(content));
75                            }
76                            ContentPart::Text { text } => {
77                                content_parts.push(json!(text));
78                            }
79                            _ => {}
80                        }
81                    }
82                    let content_str = content_parts
83                        .iter()
84                        .filter_map(|v| v.as_str())
85                        .collect::<Vec<_>>()
86                        .join("\n");
87                    let mut m = json!({
88                        "role": "tool",
89                        "content": content_str,
90                    });
91                    if let Some(id) = tool_call_id {
92                        m["tool_call_id"] = json!(id);
93                    }
94                    return m;
95                }
96
97                // For assistant messages with tool calls
98                if msg.role == Role::Assistant {
99                    let mut text_parts = Vec::new();
100                    let mut tool_calls = Vec::new();
101                    for part in &msg.content {
102                        match part {
103                            ContentPart::Text { text } => {
104                                if !text.is_empty() {
105                                    text_parts.push(text.clone());
106                                }
107                            }
108                            ContentPart::ToolCall {
109                                id,
110                                name,
111                                arguments,
112                            } => {
113                                tool_calls.push(json!({
114                                    "id": id,
115                                    "type": "function",
116                                    "function": {
117                                        "name": name,
118                                        "arguments": arguments
119                                    }
120                                }));
121                            }
122                            _ => {}
123                        }
124                    }
125                    let content = text_parts.join("\n");
126                    let mut m = json!({"role": "assistant"});
127                    if !content.is_empty() || tool_calls.is_empty() {
128                        m["content"] = json!(content);
129                    }
130                    if !tool_calls.is_empty() {
131                        m["tool_calls"] = json!(tool_calls);
132                    }
133                    return m;
134                }
135
136                let text: String = msg
137                    .content
138                    .iter()
139                    .filter_map(|p| match p {
140                        ContentPart::Text { text } => Some(text.clone()),
141                        _ => None,
142                    })
143                    .collect::<Vec<_>>()
144                    .join("\n");
145
146                json!({
147                    "role": role,
148                    "content": text
149                })
150            })
151            .collect()
152    }
153
154    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
155        tools
156            .iter()
157            .map(|t| {
158                json!({
159                    "type": "function",
160                    "function": {
161                        "name": t.name,
162                        "description": t.description,
163                        "parameters": t.parameters
164                    }
165                })
166            })
167            .collect()
168    }
169}
170
// OpenAI-compatible wire types for deserializing Google's responses.

/// Top-level chat-completion response body.
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    #[allow(dead_code)]
    id: Option<String>,
    // One entry per generated candidate; only the first is consumed.
    choices: Vec<Choice>,
    // Token accounting; may be absent, hence Option + default.
    #[serde(default)]
    usage: Option<ApiUsage>,
}
181
/// A single completion candidate.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    // e.g. "stop", "length", "tool_calls", "content_filter"; may be absent.
    #[serde(default)]
    finish_reason: Option<String>,
}
188
/// Message payload of a choice; text content and tool calls are both optional.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[allow(dead_code)]
    role: Option<String>,
    // Assistant text, if any.
    #[serde(default)]
    content: Option<String>,
    // Tool invocations requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
198
/// A tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    // Identifier to echo back in the matching `tool` result message.
    id: String,
    function: FunctionCall,
}
204
/// Function name plus its arguments, per the OpenAI function-calling format.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    // JSON arguments serialized as a raw string (not a parsed object).
    arguments: String,
}
210
/// Token-usage block; each counter defaults to 0 when the field is missing.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
220
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}
225
/// Human-readable error message inside the error envelope.
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
230
231#[async_trait]
232impl Provider for GoogleProvider {
233    fn name(&self) -> &str {
234        "google"
235    }
236
237    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
238        self.validate_api_key()?;
239
240        Ok(vec![
241            ModelInfo {
242                id: "gemini-2.5-pro".to_string(),
243                name: "Gemini 2.5 Pro".to_string(),
244                provider: "google".to_string(),
245                context_window: 1_048_576,
246                max_output_tokens: Some(65_536),
247                supports_vision: true,
248                supports_tools: true,
249                supports_streaming: true,
250                input_cost_per_million: Some(1.25),
251                output_cost_per_million: Some(10.0),
252            },
253            ModelInfo {
254                id: "gemini-2.5-flash".to_string(),
255                name: "Gemini 2.5 Flash".to_string(),
256                provider: "google".to_string(),
257                context_window: 1_048_576,
258                max_output_tokens: Some(65_536),
259                supports_vision: true,
260                supports_tools: true,
261                supports_streaming: true,
262                input_cost_per_million: Some(0.15),
263                output_cost_per_million: Some(0.60),
264            },
265            ModelInfo {
266                id: "gemini-2.0-flash".to_string(),
267                name: "Gemini 2.0 Flash".to_string(),
268                provider: "google".to_string(),
269                context_window: 1_048_576,
270                max_output_tokens: Some(8_192),
271                supports_vision: true,
272                supports_tools: true,
273                supports_streaming: true,
274                input_cost_per_million: Some(0.10),
275                output_cost_per_million: Some(0.40),
276            },
277        ])
278    }
279
280    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
281        tracing::debug!(
282            provider = "google",
283            model = %request.model,
284            message_count = request.messages.len(),
285            tool_count = request.tools.len(),
286            "Starting Google Gemini completion request"
287        );
288
289        self.validate_api_key()?;
290
291        let messages = Self::convert_messages(&request.messages);
292        let tools = Self::convert_tools(&request.tools);
293
294        let mut body = json!({
295            "model": request.model,
296            "messages": messages,
297        });
298
299        if let Some(max_tokens) = request.max_tokens {
300            body["max_tokens"] = json!(max_tokens);
301        }
302        if !tools.is_empty() {
303            body["tools"] = json!(tools);
304        }
305        if let Some(temp) = request.temperature {
306            body["temperature"] = json!(temp);
307        }
308        if let Some(top_p) = request.top_p {
309            body["top_p"] = json!(top_p);
310        }
311
312        tracing::debug!("Google Gemini request to model {}", request.model);
313
314        // Google AI Studio keys use query param auth, not Bearer tokens
315        let url = format!(
316            "{}/chat/completions?key={}",
317            GOOGLE_OPENAI_BASE, self.api_key
318        );
319        let response = self
320            .client
321            .post(&url)
322            .header("content-type", "application/json")
323            .json(&body)
324            .send()
325            .await
326            .context("Failed to send request to Google Gemini")?;
327
328        let status = response.status();
329        let text = response
330            .text()
331            .await
332            .context("Failed to read Google Gemini response")?;
333
334        if !status.is_success() {
335            if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
336                anyhow::bail!("Google Gemini API error: {}", err.error.message);
337            }
338            anyhow::bail!("Google Gemini API error: {} {}", status, text);
339        }
340
341        let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
342            "Failed to parse Google Gemini response: {}",
343            &text[..text.len().min(200)]
344        ))?;
345
346        let choice = completion
347            .choices
348            .into_iter()
349            .next()
350            .context("No choices in Google Gemini response")?;
351
352        let mut content_parts = Vec::new();
353        let mut has_tool_calls = false;
354
355        if let Some(text) = choice.message.content {
356            if !text.is_empty() {
357                content_parts.push(ContentPart::Text { text });
358            }
359        }
360
361        if let Some(tool_calls) = choice.message.tool_calls {
362            has_tool_calls = !tool_calls.is_empty();
363            for tc in tool_calls {
364                content_parts.push(ContentPart::ToolCall {
365                    id: tc.id,
366                    name: tc.function.name,
367                    arguments: tc.function.arguments,
368                });
369            }
370        }
371
372        let finish_reason = if has_tool_calls {
373            FinishReason::ToolCalls
374        } else {
375            match choice.finish_reason.as_deref() {
376                Some("stop") => FinishReason::Stop,
377                Some("length") => FinishReason::Length,
378                Some("tool_calls") => FinishReason::ToolCalls,
379                Some("content_filter") => FinishReason::ContentFilter,
380                _ => FinishReason::Stop,
381            }
382        };
383
384        let usage = completion.usage.as_ref();
385
386        Ok(CompletionResponse {
387            message: Message {
388                role: Role::Assistant,
389                content: content_parts,
390            },
391            usage: Usage {
392                prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
393                completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
394                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
395                cache_read_tokens: None,
396                cache_write_tokens: None,
397            },
398            finish_reason,
399        })
400    }
401
402    async fn complete_stream(
403        &self,
404        request: CompletionRequest,
405    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
406        // Fall back to non-streaming for now
407        let response = self.complete(request).await?;
408        let text = response
409            .message
410            .content
411            .iter()
412            .filter_map(|p| match p {
413                ContentPart::Text { text } => Some(text.clone()),
414                _ => None,
415            })
416            .collect::<Vec<_>>()
417            .join("");
418
419        Ok(Box::pin(futures::stream::once(async move {
420            StreamChunk::Text(text)
421        })))
422    }
423}