// codetether_agent/provider/google.rs

1//! Google Gemini provider implementation
2//!
3//! Uses the Google AI Gemini OpenAI-compatible endpoint for simplicity.
4//! Reference: https://ai.google.dev/gemini-api/docs/openai
5
6use super::util;
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
17const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
18
/// Provider for Google Gemini models, speaking the OpenAI-compatible
/// protocol against Google AI Studio.
pub struct GoogleProvider {
    /// Shared HTTP client, reused across requests.
    client: Client,
    /// Google AI Studio API key; sent as a Bearer token on each request.
    api_key: String,
}
23
24impl std::fmt::Debug for GoogleProvider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("GoogleProvider")
27            .field("api_key", &"<REDACTED>")
28            .field("api_key_len", &self.api_key.len())
29            .finish()
30    }
31}
32
33impl GoogleProvider {
34    pub fn new(api_key: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "google",
37            api_key_len = api_key.len(),
38            "Creating Google Gemini provider"
39        );
40        Ok(Self {
41            client: Client::new(),
42            api_key,
43        })
44    }
45
46    fn validate_api_key(&self) -> Result<()> {
47        if self.api_key.is_empty() {
48            anyhow::bail!("Google API key is empty");
49        }
50        Ok(())
51    }
52
53    fn convert_messages(messages: &[Message]) -> Vec<Value> {
54        messages
55            .iter()
56            .map(|msg| {
57                let role = match msg.role {
58                    Role::System => "system",
59                    Role::User => "user",
60                    Role::Assistant => "assistant",
61                    Role::Tool => "tool",
62                };
63
64                // For tool messages, we need to produce one message per tool result
65                if msg.role == Role::Tool {
66                    let mut content_parts: Vec<Value> = Vec::new();
67                    let mut tool_call_id = None;
68                    for part in &msg.content {
69                        match part {
70                            ContentPart::ToolResult {
71                                tool_call_id: id,
72                                content,
73                            } => {
74                                tool_call_id = Some(id.clone());
75                                content_parts.push(json!(content));
76                            }
77                            ContentPart::Text { text } => {
78                                content_parts.push(json!(text));
79                            }
80                            _ => {}
81                        }
82                    }
83                    let content_str = content_parts
84                        .iter()
85                        .filter_map(|v| v.as_str())
86                        .collect::<Vec<_>>()
87                        .join("\n");
88                    let mut m = json!({
89                        "role": "tool",
90                        "content": content_str,
91                    });
92                    if let Some(id) = tool_call_id {
93                        m["tool_call_id"] = json!(id);
94                    }
95                    return m;
96                }
97
98                // For assistant messages with tool calls
99                if msg.role == Role::Assistant {
100                    let mut text_parts = Vec::new();
101                    let mut tool_calls = Vec::new();
102                    for part in &msg.content {
103                        match part {
104                            ContentPart::Text { text } => {
105                                if !text.is_empty() {
106                                    text_parts.push(text.clone());
107                                }
108                            }
109                            ContentPart::ToolCall {
110                                id,
111                                name,
112                                arguments,
113                                thought_signature,
114                            } => {
115                                let mut tc = json!({
116                                    "id": id,
117                                    "type": "function",
118                                    "function": {
119                                        "name": name,
120                                        "arguments": arguments
121                                    }
122                                });
123                                // Include thought signature for Gemini 3.x models
124                                if let Some(sig) = thought_signature {
125                                    tc["extra_content"] = json!({
126                                        "google": {
127                                            "thought_signature": sig
128                                        }
129                                    });
130                                }
131                                tool_calls.push(tc);
132                            }
133                            _ => {}
134                        }
135                    }
136                    let content = text_parts.join("\n");
137                    let mut m = json!({"role": "assistant"});
138                    if !content.is_empty() || tool_calls.is_empty() {
139                        m["content"] = json!(content);
140                    }
141                    if !tool_calls.is_empty() {
142                        m["tool_calls"] = json!(tool_calls);
143                    }
144                    return m;
145                }
146
147                let text: String = msg
148                    .content
149                    .iter()
150                    .filter_map(|p| match p {
151                        ContentPart::Text { text } => Some(text.clone()),
152                        _ => None,
153                    })
154                    .collect::<Vec<_>>()
155                    .join("\n");
156
157                json!({
158                    "role": role,
159                    "content": text
160                })
161            })
162            .collect()
163    }
164
165    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
166        tools
167            .iter()
168            .map(|t| {
169                json!({
170                    "type": "function",
171                    "function": {
172                        "name": t.name,
173                        "description": t.description,
174                        "parameters": t.parameters
175                    }
176                })
177            })
178            .collect()
179    }
180}
181
// OpenAI-compatible types for parsing Google's response
183
/// Top-level chat-completion response body.
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    // Response id; not consumed, kept for serde completeness/debugging.
    #[allow(dead_code)]
    id: Option<String>,
    /// Candidate completions; only the first is consumed.
    choices: Vec<Choice>,
    /// Token accounting; may be absent on some responses.
    #[serde(default)]
    usage: Option<ApiUsage>,
}
192
/// One completion candidate.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    /// Raw finish reason, e.g. "stop", "length", "tool_calls",
    /// "content_filter"; mapped to `FinishReason` by the caller.
    #[serde(default)]
    finish_reason: Option<String>,
}
199
/// Assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    // Always "assistant" in practice; not consumed.
    #[allow(dead_code)]
    role: Option<String>,
    /// Text content; may be absent when the model only emits tool calls.
    #[serde(default)]
    content: Option<String>,
    /// Tool calls requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
209
/// A single tool call emitted by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    /// Call id; must be echoed back in the matching tool-result message.
    id: String,
    function: FunctionCall,
    /// Thought signature for Gemini 3.x models
    #[serde(default)]
    extra_content: Option<ExtraContent>,
}
218
/// Vendor extension envelope on a tool call (`extra_content`).
#[derive(Debug, Deserialize)]
struct ExtraContent {
    google: Option<GoogleExtra>,
}
223
/// Google-specific extras; carries the Gemini 3.x thought signature.
#[derive(Debug, Deserialize)]
struct GoogleExtra {
    thought_signature: Option<String>,
}
228
/// Function name plus its arguments as a raw JSON string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    /// JSON-encoded arguments, passed through verbatim.
    arguments: String,
}
234
/// Token usage block; all fields default to 0 when missing.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
244
/// Error envelope returned on non-success HTTP statuses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}
249
/// Human-readable error message from the API.
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
254
255#[async_trait]
256impl Provider for GoogleProvider {
257    fn name(&self) -> &str {
258        "google"
259    }
260
261    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
262        self.validate_api_key()?;
263
264        Ok(vec![
265            // Gemini 3.x models (require thought signatures for tool calls)
266            ModelInfo {
267                id: "gemini-3.1-pro-preview".to_string(),
268                name: "Gemini 3.1 Pro Preview".to_string(),
269                provider: "google".to_string(),
270                context_window: 1_048_576,
271                max_output_tokens: Some(65_536),
272                supports_vision: true,
273                supports_tools: true,
274                supports_streaming: true,
275                input_cost_per_million: Some(2.0),
276                output_cost_per_million: Some(12.0),
277            },
278            ModelInfo {
279                id: "gemini-3.1-pro-preview-customtools".to_string(),
280                name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
281                provider: "google".to_string(),
282                context_window: 1_048_576,
283                max_output_tokens: Some(65_536),
284                supports_vision: true,
285                supports_tools: true,
286                supports_streaming: true,
287                input_cost_per_million: Some(2.0),
288                output_cost_per_million: Some(12.0),
289            },
290            ModelInfo {
291                id: "gemini-3-pro-preview".to_string(),
292                name: "Gemini 3 Pro Preview".to_string(),
293                provider: "google".to_string(),
294                context_window: 1_048_576,
295                max_output_tokens: Some(65_536),
296                supports_vision: true,
297                supports_tools: true,
298                supports_streaming: true,
299                input_cost_per_million: Some(2.0),
300                output_cost_per_million: Some(12.0),
301            },
302            ModelInfo {
303                id: "gemini-3-flash-preview".to_string(),
304                name: "Gemini 3 Flash Preview".to_string(),
305                provider: "google".to_string(),
306                context_window: 1_048_576,
307                max_output_tokens: Some(65_536),
308                supports_vision: true,
309                supports_tools: true,
310                supports_streaming: true,
311                input_cost_per_million: Some(0.50),
312                output_cost_per_million: Some(3.0),
313            },
314            ModelInfo {
315                id: "gemini-3-pro-image-preview".to_string(),
316                name: "Gemini 3 Pro Image Preview".to_string(),
317                provider: "google".to_string(),
318                context_window: 65_536,
319                max_output_tokens: Some(32_768),
320                supports_vision: true,
321                supports_tools: false,
322                supports_streaming: false,
323                input_cost_per_million: Some(2.0),
324                output_cost_per_million: Some(134.0),
325            },
326            // Gemini 2.5 models
327            ModelInfo {
328                id: "gemini-2.5-pro".to_string(),
329                name: "Gemini 2.5 Pro".to_string(),
330                provider: "google".to_string(),
331                context_window: 1_048_576,
332                max_output_tokens: Some(65_536),
333                supports_vision: true,
334                supports_tools: true,
335                supports_streaming: true,
336                input_cost_per_million: Some(1.25),
337                output_cost_per_million: Some(10.0),
338            },
339            ModelInfo {
340                id: "gemini-2.5-flash".to_string(),
341                name: "Gemini 2.5 Flash".to_string(),
342                provider: "google".to_string(),
343                context_window: 1_048_576,
344                max_output_tokens: Some(65_536),
345                supports_vision: true,
346                supports_tools: true,
347                supports_streaming: true,
348                input_cost_per_million: Some(0.15),
349                output_cost_per_million: Some(0.60),
350            },
351            ModelInfo {
352                id: "gemini-2.0-flash".to_string(),
353                name: "Gemini 2.0 Flash".to_string(),
354                provider: "google".to_string(),
355                context_window: 1_048_576,
356                max_output_tokens: Some(8_192),
357                supports_vision: true,
358                supports_tools: true,
359                supports_streaming: true,
360                input_cost_per_million: Some(0.10),
361                output_cost_per_million: Some(0.40),
362            },
363        ])
364    }
365
366    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
367        tracing::debug!(
368            provider = "google",
369            model = %request.model,
370            message_count = request.messages.len(),
371            tool_count = request.tools.len(),
372            "Starting Google Gemini completion request"
373        );
374
375        self.validate_api_key()?;
376
377        let messages = Self::convert_messages(&request.messages);
378        let tools = Self::convert_tools(&request.tools);
379
380        let mut body = json!({
381            "model": request.model,
382            "messages": messages,
383        });
384
385        if let Some(max_tokens) = request.max_tokens {
386            body["max_tokens"] = json!(max_tokens);
387        }
388        if !tools.is_empty() {
389            body["tools"] = json!(tools);
390        }
391        if let Some(temp) = request.temperature {
392            body["temperature"] = json!(temp);
393        }
394        if let Some(top_p) = request.top_p {
395            body["top_p"] = json!(top_p);
396        }
397
398        tracing::debug!("Google Gemini request to model {}", request.model);
399
400        // Google AI Studio OpenAI-compatible endpoint uses Bearer token auth
401        let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
402        let response = self
403            .client
404            .post(&url)
405            .header("content-type", "application/json")
406            .header("Authorization", format!("Bearer {}", self.api_key))
407            .json(&body)
408            .send()
409            .await
410            .context("Failed to send request to Google Gemini")?;
411
412        let status = response.status();
413        let text = response
414            .text()
415            .await
416            .context("Failed to read Google Gemini response")?;
417
418        if !status.is_success() {
419            if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
420                anyhow::bail!("Google Gemini API error: {}", err.error.message);
421            }
422            anyhow::bail!("Google Gemini API error: {} {}", status, text);
423        }
424
425        let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
426            "Failed to parse Google Gemini response: {}",
427            util::truncate_bytes_safe(&text, 200)
428        ))?;
429
430        let choice = completion
431            .choices
432            .into_iter()
433            .next()
434            .context("No choices in Google Gemini response")?;
435
436        let mut content_parts = Vec::new();
437        let mut has_tool_calls = false;
438
439        if let Some(text) = choice.message.content
440            && !text.is_empty()
441        {
442            content_parts.push(ContentPart::Text { text });
443        }
444
445        if let Some(tool_calls) = choice.message.tool_calls {
446            has_tool_calls = !tool_calls.is_empty();
447            for tc in tool_calls {
448                // Extract thought signature from extra_content.google.thought_signature
449                let thought_signature = tc
450                    .extra_content
451                    .as_ref()
452                    .and_then(|ec| ec.google.as_ref())
453                    .and_then(|g| g.thought_signature.clone());
454
455                content_parts.push(ContentPart::ToolCall {
456                    id: tc.id,
457                    name: tc.function.name,
458                    arguments: tc.function.arguments,
459                    thought_signature,
460                });
461            }
462        }
463
464        let finish_reason = if has_tool_calls {
465            FinishReason::ToolCalls
466        } else {
467            match choice.finish_reason.as_deref() {
468                Some("stop") => FinishReason::Stop,
469                Some("length") => FinishReason::Length,
470                Some("tool_calls") => FinishReason::ToolCalls,
471                Some("content_filter") => FinishReason::ContentFilter,
472                _ => FinishReason::Stop,
473            }
474        };
475
476        let usage = completion.usage.as_ref();
477
478        Ok(CompletionResponse {
479            message: Message {
480                role: Role::Assistant,
481                content: content_parts,
482            },
483            usage: Usage {
484                prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
485                completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
486                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
487                cache_read_tokens: None,
488                cache_write_tokens: None,
489            },
490            finish_reason,
491        })
492    }
493
494    async fn complete_stream(
495        &self,
496        request: CompletionRequest,
497    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
498        // Fall back to non-streaming for now
499        let response = self.complete(request).await?;
500        let text = response
501            .message
502            .content
503            .iter()
504            .filter_map(|p| match p {
505                ContentPart::Text { text } => Some(text.clone()),
506                _ => None,
507            })
508            .collect::<Vec<_>>()
509            .join("");
510
511        Ok(Box::pin(futures::stream::once(async move {
512            StreamChunk::Text(text)
513        })))
514    }
515}