// codetether_agent/provider/google.rs
//! Google Gemini provider implementation
//!
//! Uses the Google AI Gemini OpenAI-compatible endpoint for simplicity.
//! Reference: https://ai.google.dev/gemini-api/docs/openai

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
16const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
17
18pub struct GoogleProvider {
19    client: Client,
20    api_key: String,
21}
22
23impl std::fmt::Debug for GoogleProvider {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("GoogleProvider")
26            .field("api_key", &"<REDACTED>")
27            .field("api_key_len", &self.api_key.len())
28            .finish()
29    }
30}
31
32impl GoogleProvider {
33    pub fn new(api_key: String) -> Result<Self> {
34        tracing::debug!(
35            provider = "google",
36            api_key_len = api_key.len(),
37            "Creating Google Gemini provider"
38        );
39        Ok(Self {
40            client: Client::new(),
41            api_key,
42        })
43    }
44
45    fn validate_api_key(&self) -> Result<()> {
46        if self.api_key.is_empty() {
47            anyhow::bail!("Google API key is empty");
48        }
49        Ok(())
50    }
51
52    fn convert_messages(messages: &[Message]) -> Vec<Value> {
53        messages
54            .iter()
55            .map(|msg| {
56                let role = match msg.role {
57                    Role::System => "system",
58                    Role::User => "user",
59                    Role::Assistant => "assistant",
60                    Role::Tool => "tool",
61                };
62
63                // For tool messages, we need to produce one message per tool result
64                if msg.role == Role::Tool {
65                    let mut content_parts: Vec<Value> = Vec::new();
66                    let mut tool_call_id = None;
67                    for part in &msg.content {
68                        match part {
69                            ContentPart::ToolResult {
70                                tool_call_id: id,
71                                content,
72                            } => {
73                                tool_call_id = Some(id.clone());
74                                content_parts.push(json!(content));
75                            }
76                            ContentPart::Text { text } => {
77                                content_parts.push(json!(text));
78                            }
79                            _ => {}
80                        }
81                    }
82                    let content_str = content_parts
83                        .iter()
84                        .filter_map(|v| v.as_str())
85                        .collect::<Vec<_>>()
86                        .join("\n");
87                    let mut m = json!({
88                        "role": "tool",
89                        "content": content_str,
90                    });
91                    if let Some(id) = tool_call_id {
92                        m["tool_call_id"] = json!(id);
93                    }
94                    return m;
95                }
96
97                // For assistant messages with tool calls
98                if msg.role == Role::Assistant {
99                    let mut text_parts = Vec::new();
100                    let mut tool_calls = Vec::new();
101                    for part in &msg.content {
102                        match part {
103                            ContentPart::Text { text } => {
104                                if !text.is_empty() {
105                                    text_parts.push(text.clone());
106                                }
107                            }
108                            ContentPart::ToolCall {
109                                id,
110                                name,
111                                arguments,
112                                thought_signature,
113                            } => {
114                                let mut tc = json!({
115                                    "id": id,
116                                    "type": "function",
117                                    "function": {
118                                        "name": name,
119                                        "arguments": arguments
120                                    }
121                                });
122                                // Include thought signature for Gemini 3.x models
123                                if let Some(sig) = thought_signature {
124                                    tc["extra_content"] = json!({
125                                        "google": {
126                                            "thought_signature": sig
127                                        }
128                                    });
129                                }
130                                tool_calls.push(tc);
131                            }
132                            _ => {}
133                        }
134                    }
135                    let content = text_parts.join("\n");
136                    let mut m = json!({"role": "assistant"});
137                    if !content.is_empty() || tool_calls.is_empty() {
138                        m["content"] = json!(content);
139                    }
140                    if !tool_calls.is_empty() {
141                        m["tool_calls"] = json!(tool_calls);
142                    }
143                    return m;
144                }
145
146                let text: String = msg
147                    .content
148                    .iter()
149                    .filter_map(|p| match p {
150                        ContentPart::Text { text } => Some(text.clone()),
151                        _ => None,
152                    })
153                    .collect::<Vec<_>>()
154                    .join("\n");
155
156                json!({
157                    "role": role,
158                    "content": text
159                })
160            })
161            .collect()
162    }
163
164    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
165        tools
166            .iter()
167            .map(|t| {
168                json!({
169                    "type": "function",
170                    "function": {
171                        "name": t.name,
172                        "description": t.description,
173                        "parameters": t.parameters
174                    }
175                })
176            })
177            .collect()
178    }
179}
180
// OpenAI-compatible types for parsing Google's response
182
183#[derive(Debug, Deserialize)]
184struct ChatCompletion {
185    #[allow(dead_code)]
186    id: Option<String>,
187    choices: Vec<Choice>,
188    #[serde(default)]
189    usage: Option<ApiUsage>,
190}
191
192#[derive(Debug, Deserialize)]
193struct Choice {
194    message: ChoiceMessage,
195    #[serde(default)]
196    finish_reason: Option<String>,
197}
198
199#[derive(Debug, Deserialize)]
200struct ChoiceMessage {
201    #[allow(dead_code)]
202    role: Option<String>,
203    #[serde(default)]
204    content: Option<String>,
205    #[serde(default)]
206    tool_calls: Option<Vec<ToolCall>>,
207}
208
209#[derive(Debug, Deserialize)]
210struct ToolCall {
211    id: String,
212    function: FunctionCall,
213    /// Thought signature for Gemini 3.x models
214    #[serde(default)]
215    extra_content: Option<ExtraContent>,
216}
217
218#[derive(Debug, Deserialize)]
219struct ExtraContent {
220    google: Option<GoogleExtra>,
221}
222
223#[derive(Debug, Deserialize)]
224struct GoogleExtra {
225    thought_signature: Option<String>,
226}
227
228#[derive(Debug, Deserialize)]
229struct FunctionCall {
230    name: String,
231    arguments: String,
232}
233
234#[derive(Debug, Deserialize)]
235struct ApiUsage {
236    #[serde(default)]
237    prompt_tokens: usize,
238    #[serde(default)]
239    completion_tokens: usize,
240    #[serde(default)]
241    total_tokens: usize,
242}
243
244#[derive(Debug, Deserialize)]
245struct ApiError {
246    error: ApiErrorDetail,
247}
248
249#[derive(Debug, Deserialize)]
250struct ApiErrorDetail {
251    message: String,
252}
253
254#[async_trait]
255impl Provider for GoogleProvider {
256    fn name(&self) -> &str {
257        "google"
258    }
259
260    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
261        self.validate_api_key()?;
262
263        Ok(vec![
264            // Gemini 3.x models (require thought signatures for tool calls)
265            ModelInfo {
266                id: "gemini-3.1-pro-preview".to_string(),
267                name: "Gemini 3.1 Pro Preview".to_string(),
268                provider: "google".to_string(),
269                context_window: 1_048_576,
270                max_output_tokens: Some(65_536),
271                supports_vision: true,
272                supports_tools: true,
273                supports_streaming: true,
274                input_cost_per_million: Some(2.0),
275                output_cost_per_million: Some(12.0),
276            },
277            ModelInfo {
278                id: "gemini-3.1-pro-preview-customtools".to_string(),
279                name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
280                provider: "google".to_string(),
281                context_window: 1_048_576,
282                max_output_tokens: Some(65_536),
283                supports_vision: true,
284                supports_tools: true,
285                supports_streaming: true,
286                input_cost_per_million: Some(2.0),
287                output_cost_per_million: Some(12.0),
288            },
289            ModelInfo {
290                id: "gemini-3-pro-preview".to_string(),
291                name: "Gemini 3 Pro Preview".to_string(),
292                provider: "google".to_string(),
293                context_window: 1_048_576,
294                max_output_tokens: Some(65_536),
295                supports_vision: true,
296                supports_tools: true,
297                supports_streaming: true,
298                input_cost_per_million: Some(2.0),
299                output_cost_per_million: Some(12.0),
300            },
301            ModelInfo {
302                id: "gemini-3-flash-preview".to_string(),
303                name: "Gemini 3 Flash Preview".to_string(),
304                provider: "google".to_string(),
305                context_window: 1_048_576,
306                max_output_tokens: Some(65_536),
307                supports_vision: true,
308                supports_tools: true,
309                supports_streaming: true,
310                input_cost_per_million: Some(0.50),
311                output_cost_per_million: Some(3.0),
312            },
313            ModelInfo {
314                id: "gemini-3-pro-image-preview".to_string(),
315                name: "Gemini 3 Pro Image Preview".to_string(),
316                provider: "google".to_string(),
317                context_window: 65_536,
318                max_output_tokens: Some(32_768),
319                supports_vision: true,
320                supports_tools: false,
321                supports_streaming: false,
322                input_cost_per_million: Some(2.0),
323                output_cost_per_million: Some(134.0),
324            },
325            // Gemini 2.5 models
326            ModelInfo {
327                id: "gemini-2.5-pro".to_string(),
328                name: "Gemini 2.5 Pro".to_string(),
329                provider: "google".to_string(),
330                context_window: 1_048_576,
331                max_output_tokens: Some(65_536),
332                supports_vision: true,
333                supports_tools: true,
334                supports_streaming: true,
335                input_cost_per_million: Some(1.25),
336                output_cost_per_million: Some(10.0),
337            },
338            ModelInfo {
339                id: "gemini-2.5-flash".to_string(),
340                name: "Gemini 2.5 Flash".to_string(),
341                provider: "google".to_string(),
342                context_window: 1_048_576,
343                max_output_tokens: Some(65_536),
344                supports_vision: true,
345                supports_tools: true,
346                supports_streaming: true,
347                input_cost_per_million: Some(0.15),
348                output_cost_per_million: Some(0.60),
349            },
350            ModelInfo {
351                id: "gemini-2.0-flash".to_string(),
352                name: "Gemini 2.0 Flash".to_string(),
353                provider: "google".to_string(),
354                context_window: 1_048_576,
355                max_output_tokens: Some(8_192),
356                supports_vision: true,
357                supports_tools: true,
358                supports_streaming: true,
359                input_cost_per_million: Some(0.10),
360                output_cost_per_million: Some(0.40),
361            },
362        ])
363    }
364
365    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
366        tracing::debug!(
367            provider = "google",
368            model = %request.model,
369            message_count = request.messages.len(),
370            tool_count = request.tools.len(),
371            "Starting Google Gemini completion request"
372        );
373
374        self.validate_api_key()?;
375
376        let messages = Self::convert_messages(&request.messages);
377        let tools = Self::convert_tools(&request.tools);
378
379        let mut body = json!({
380            "model": request.model,
381            "messages": messages,
382        });
383
384        if let Some(max_tokens) = request.max_tokens {
385            body["max_tokens"] = json!(max_tokens);
386        }
387        if !tools.is_empty() {
388            body["tools"] = json!(tools);
389        }
390        if let Some(temp) = request.temperature {
391            body["temperature"] = json!(temp);
392        }
393        if let Some(top_p) = request.top_p {
394            body["top_p"] = json!(top_p);
395        }
396
397        tracing::debug!("Google Gemini request to model {}", request.model);
398
399        // Google AI Studio OpenAI-compatible endpoint uses Bearer token auth
400        let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
401        let response = self
402            .client
403            .post(&url)
404            .header("content-type", "application/json")
405            .header("Authorization", format!("Bearer {}", self.api_key))
406            .json(&body)
407            .send()
408            .await
409            .context("Failed to send request to Google Gemini")?;
410
411        let status = response.status();
412        let text = response
413            .text()
414            .await
415            .context("Failed to read Google Gemini response")?;
416
417        if !status.is_success() {
418            if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
419                anyhow::bail!("Google Gemini API error: {}", err.error.message);
420            }
421            anyhow::bail!("Google Gemini API error: {} {}", status, text);
422        }
423
424        let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
425            "Failed to parse Google Gemini response: {}",
426            &text[..text.len().min(200)]
427        ))?;
428
429        let choice = completion
430            .choices
431            .into_iter()
432            .next()
433            .context("No choices in Google Gemini response")?;
434
435        let mut content_parts = Vec::new();
436        let mut has_tool_calls = false;
437
438        if let Some(text) = choice.message.content {
439            if !text.is_empty() {
440                content_parts.push(ContentPart::Text { text });
441            }
442        }
443
444        if let Some(tool_calls) = choice.message.tool_calls {
445            has_tool_calls = !tool_calls.is_empty();
446            for tc in tool_calls {
447                // Extract thought signature from extra_content.google.thought_signature
448                let thought_signature = tc
449                    .extra_content
450                    .as_ref()
451                    .and_then(|ec| ec.google.as_ref())
452                    .and_then(|g| g.thought_signature.clone());
453
454                content_parts.push(ContentPart::ToolCall {
455                    id: tc.id,
456                    name: tc.function.name,
457                    arguments: tc.function.arguments,
458                    thought_signature,
459                });
460            }
461        }
462
463        let finish_reason = if has_tool_calls {
464            FinishReason::ToolCalls
465        } else {
466            match choice.finish_reason.as_deref() {
467                Some("stop") => FinishReason::Stop,
468                Some("length") => FinishReason::Length,
469                Some("tool_calls") => FinishReason::ToolCalls,
470                Some("content_filter") => FinishReason::ContentFilter,
471                _ => FinishReason::Stop,
472            }
473        };
474
475        let usage = completion.usage.as_ref();
476
477        Ok(CompletionResponse {
478            message: Message {
479                role: Role::Assistant,
480                content: content_parts,
481            },
482            usage: Usage {
483                prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
484                completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
485                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
486                cache_read_tokens: None,
487                cache_write_tokens: None,
488            },
489            finish_reason,
490        })
491    }
492
493    async fn complete_stream(
494        &self,
495        request: CompletionRequest,
496    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
497        // Fall back to non-streaming for now
498        let response = self.complete(request).await?;
499        let text = response
500            .message
501            .content
502            .iter()
503            .filter_map(|p| match p {
504                ContentPart::Text { text } => Some(text.clone()),
505                _ => None,
506            })
507            .collect::<Vec<_>>()
508            .join("");
509
510        Ok(Box::pin(futures::stream::once(async move {
511            StreamChunk::Text(text)
512        })))
513    }
514}