// codetether_agent/provider/google.rs
1//! Google Gemini provider implementation
2//!
3//! Uses the Google AI Gemini OpenAI-compatible endpoint for simplicity.
4//! Reference: https://ai.google.dev/gemini-api/docs/openai
5
6use super::util;
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
/// Base URL of Google AI Studio's OpenAI-compatible API surface.
const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
18
/// Provider for Google Gemini models, speaking the OpenAI-compatible
/// endpoint at `generativelanguage.googleapis.com`.
pub struct GoogleProvider {
    // Shared reqwest client (cloned from the process-wide pool in `new`).
    client: Client,
    // AI Studio API key, sent as a Bearer token; redacted in `Debug`.
    api_key: String,
}
23
24impl std::fmt::Debug for GoogleProvider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("GoogleProvider")
27            .field("api_key", &"<REDACTED>")
28            .field("api_key_len", &self.api_key.len())
29            .finish()
30    }
31}
32
33impl GoogleProvider {
34    pub fn new(api_key: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "google",
37            api_key_len = api_key.len(),
38            "Creating Google Gemini provider"
39        );
40        Ok(Self {
41            client: crate::provider::shared_http::shared_client().clone(),
42            api_key,
43        })
44    }
45
46    fn validate_api_key(&self) -> Result<()> {
47        if self.api_key.is_empty() {
48            anyhow::bail!("Google API key is empty");
49        }
50        Ok(())
51    }
52
53    fn convert_messages(messages: &[Message]) -> Vec<Value> {
54        messages
55            .iter()
56            .map(|msg| {
57                let role = match msg.role {
58                    Role::System => "system",
59                    Role::User => "user",
60                    Role::Assistant => "assistant",
61                    Role::Tool => "tool",
62                };
63
64                // For tool messages, we need to produce one message per tool result
65                if msg.role == Role::Tool {
66                    let mut content_parts: Vec<Value> = Vec::new();
67                    let mut tool_call_id = None;
68                    for part in &msg.content {
69                        match part {
70                            ContentPart::ToolResult {
71                                tool_call_id: id,
72                                content,
73                            } => {
74                                tool_call_id = Some(id.clone());
75                                content_parts.push(json!(content));
76                            }
77                            ContentPart::Text { text } => {
78                                content_parts.push(json!(text));
79                            }
80                            _ => {}
81                        }
82                    }
83                    let content_str = content_parts
84                        .iter()
85                        .filter_map(|v| v.as_str())
86                        .collect::<Vec<_>>()
87                        .join("\n");
88                    let mut m = json!({
89                        "role": "tool",
90                        "content": content_str,
91                    });
92                    if let Some(id) = tool_call_id {
93                        m["tool_call_id"] = json!(id);
94                    }
95                    return m;
96                }
97
98                // For assistant messages with tool calls
99                if msg.role == Role::Assistant {
100                    let mut text_parts = Vec::new();
101                    let mut tool_calls = Vec::new();
102                    for part in &msg.content {
103                        match part {
104                            ContentPart::Text { text } => {
105                                if !text.is_empty() {
106                                    text_parts.push(text.clone());
107                                }
108                            }
109                            ContentPart::ToolCall {
110                                id,
111                                name,
112                                arguments,
113                                thought_signature,
114                            } => {
115                                let mut tc = json!({
116                                    "id": id,
117                                    "type": "function",
118                                    "function": {
119                                        "name": name,
120                                        "arguments": arguments
121                                    }
122                                });
123                                // Include thought signature for Gemini 3.x models
124                                if let Some(sig) = thought_signature {
125                                    tc["extra_content"] = json!({
126                                        "google": {
127                                            "thought_signature": sig
128                                        }
129                                    });
130                                }
131                                tool_calls.push(tc);
132                            }
133                            _ => {}
134                        }
135                    }
136                    let content = text_parts.join("\n");
137                    let mut m = json!({"role": "assistant"});
138                    if !content.is_empty() || tool_calls.is_empty() {
139                        m["content"] = json!(content);
140                    }
141                    if !tool_calls.is_empty() {
142                        m["tool_calls"] = json!(tool_calls);
143                    }
144                    return m;
145                }
146
147                let text: String = msg
148                    .content
149                    .iter()
150                    .filter_map(|p| match p {
151                        ContentPart::Text { text } => Some(text.clone()),
152                        _ => None,
153                    })
154                    .collect::<Vec<_>>()
155                    .join("\n");
156
157                json!({
158                    "role": role,
159                    "content": text
160                })
161            })
162            .collect()
163    }
164
165    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
166        tools
167            .iter()
168            .map(|t| {
169                json!({
170                    "type": "function",
171                    "function": {
172                        "name": t.name,
173                        "description": t.description,
174                        "parameters": t.parameters
175                    }
176                })
177            })
178            .collect()
179    }
180}
181
// OpenAI-compatible types for parsing Google's response.
183
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    #[allow(dead_code)]
    id: Option<String>,
    /// Candidate completions; only the first is consumed.
    choices: Vec<Choice>,
    /// Token accounting; absent on some responses, hence `Option`.
    #[serde(default)]
    usage: Option<ApiUsage>,
}
192
/// One candidate completion.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    /// Raw finish reason string, e.g. "stop", "length", "tool_calls",
    /// "content_filter".
    #[serde(default)]
    finish_reason: Option<String>,
}
199
/// Assistant message payload within a choice.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[allow(dead_code)]
    role: Option<String>,
    /// Text content; may be absent when the model only returns tool calls.
    #[serde(default)]
    content: Option<String>,
    /// Function calls requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
209
/// A single function call requested by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
    /// Thought signature for Gemini 3.x models
    #[serde(default)]
    extra_content: Option<ExtraContent>,
}
218
/// Vendor-extension envelope (`extra_content.google`).
#[derive(Debug, Deserialize)]
struct ExtraContent {
    google: Option<GoogleExtra>,
}
223
/// Google-specific extras; carries the Gemini 3.x thought signature.
#[derive(Debug, Deserialize)]
struct GoogleExtra {
    thought_signature: Option<String>,
}
228
/// Function name plus its raw JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String,
}
234
/// Token-usage block in the OpenAI-compatible response shape.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    /// Gemini `usageMetadata.cachedContentTokenCount` exposed via the
    /// OpenAI-compat shim. Counted within `prompt_tokens`; we subtract
    /// it when populating [`Usage`] so the cost estimator prices it at
    /// the discounted rate (see [`crate::provider::pricing::cache_read_multiplier`]).
    #[serde(default, rename = "cached_tokens")]
    cached_tokens: Option<usize>,
    /// Alternative location for the cached-token count; see
    /// [`ApiUsage::cached_input_tokens`] for the fallback order.
    #[serde(default, rename = "prompt_tokens_details")]
    prompt_tokens_details: Option<PromptTokenDetails>,
}
252
/// Breakdown of `prompt_tokens` (OpenAI `prompt_tokens_details` object).
#[derive(Debug, Deserialize, Default)]
struct PromptTokenDetails {
    #[serde(default)]
    cached_tokens: usize,
}
258
259impl ApiUsage {
260    fn cached_input_tokens(&self) -> usize {
261        self.cached_tokens.unwrap_or_else(|| {
262            self.prompt_tokens_details
263                .as_ref()
264                .map(|d| d.cached_tokens)
265                .unwrap_or(0)
266        })
267    }
268}
269
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}
274
/// Human-readable error message carried inside [`ApiError`].
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
279
#[async_trait]
impl Provider for GoogleProvider {
    /// Registry/logging identifier for this provider.
    fn name(&self) -> &str {
        "google"
    }

    /// Returns a hand-maintained catalog of Gemini models.
    ///
    /// The API is not queried: context windows, output caps, and
    /// per-million-token prices are hard-coded here and can drift from
    /// Google's published values — keep in sync with the Gemini model
    /// and pricing documentation.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;

        Ok(vec![
            // Gemini 3.x models (require thought signatures for tool calls)
            ModelInfo {
                id: "gemini-3.1-pro-preview".to_string(),
                name: "Gemini 3.1 Pro Preview".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.0),
                output_cost_per_million: Some(12.0),
            },
            ModelInfo {
                id: "gemini-3.1-pro-preview-customtools".to_string(),
                name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.0),
                output_cost_per_million: Some(12.0),
            },
            ModelInfo {
                id: "gemini-3-pro-preview".to_string(),
                name: "Gemini 3 Pro Preview".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.0),
                output_cost_per_million: Some(12.0),
            },
            ModelInfo {
                id: "gemini-3-flash-preview".to_string(),
                name: "Gemini 3 Flash Preview".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.50),
                output_cost_per_million: Some(3.0),
            },
            // Image model: no tool or streaming support.
            ModelInfo {
                id: "gemini-3-pro-image-preview".to_string(),
                name: "Gemini 3 Pro Image Preview".to_string(),
                provider: "google".to_string(),
                context_window: 65_536,
                max_output_tokens: Some(32_768),
                supports_vision: true,
                supports_tools: false,
                supports_streaming: false,
                input_cost_per_million: Some(2.0),
                output_cost_per_million: Some(134.0),
            },
            // Gemini 2.5 models
            ModelInfo {
                id: "gemini-2.5-pro".to_string(),
                name: "Gemini 2.5 Pro".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(1.25),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gemini-2.5-flash".to_string(),
                name: "Gemini 2.5 Flash".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(65_536),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.60),
            },
            ModelInfo {
                id: "gemini-2.0-flash".to_string(),
                name: "Gemini 2.0 Flash".to_string(),
                provider: "google".to_string(),
                context_window: 1_048_576,
                max_output_tokens: Some(8_192),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.10),
                output_cost_per_million: Some(0.40),
            },
        ])
    }

    /// Performs a non-streaming chat completion against the
    /// OpenAI-compatible endpoint.
    ///
    /// # Errors
    /// Fails on an empty API key, transport errors, non-2xx responses
    /// (surfacing the API's own error message when it parses), an
    /// unparseable body, or a response with no choices.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        tracing::debug!(
            provider = "google",
            model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Google Gemini completion request"
        );

        self.validate_api_key()?;

        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        // Build the OpenAI-compatible request body; optional knobs are
        // only serialized when the caller actually set them.
        let mut body = json!({
            "model": request.model,
            "messages": messages,
        });

        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(temp) = request.temperature {
            body["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            body["top_p"] = json!(top_p);
        }

        tracing::debug!("Google Gemini request to model {}", request.model);

        // Google AI Studio OpenAI-compatible endpoint uses Bearer token auth
        let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Google Gemini")?;

        // Read the body as text first so error payloads and parse
        // failures can be echoed back to the caller.
        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Google Gemini response")?;

        if !status.is_success() {
            // Prefer the structured API error message; fall back to the
            // raw status and body.
            if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
                anyhow::bail!("Google Gemini API error: {}", err.error.message);
            }
            anyhow::bail!("Google Gemini API error: {} {}", status, text);
        }

        let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
            "Failed to parse Google Gemini response: {}",
            util::truncate_bytes_safe(&text, 200)
        ))?;

        // Only the first choice is used.
        let choice = completion
            .choices
            .into_iter()
            .next()
            .context("No choices in Google Gemini response")?;

        let mut content_parts = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = choice.message.content
            && !text.is_empty()
        {
            content_parts.push(ContentPart::Text { text });
        }

        if let Some(tool_calls) = choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Extract thought signature from extra_content.google.thought_signature
                let thought_signature = tc
                    .extra_content
                    .as_ref()
                    .and_then(|ec| ec.google.as_ref())
                    .and_then(|g| g.thought_signature.clone());

                content_parts.push(ContentPart::ToolCall {
                    id: tc.id,
                    name: tc.function.name,
                    arguments: tc.function.arguments,
                    thought_signature,
                });
            }
        }

        // The presence of tool calls takes precedence over the reported
        // finish_reason string; unknown reasons default to Stop.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                Some("content_filter") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = completion.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content: content_parts,
            },
            usage: Usage {
                // Cached tokens are counted inside prompt_tokens by the
                // API; subtract them so they can be priced separately
                // (see ApiUsage).
                prompt_tokens: usage
                    .map(|u| u.prompt_tokens.saturating_sub(u.cached_input_tokens()))
                    .unwrap_or(0),
                completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: usage.map(ApiUsage::cached_input_tokens).filter(|&n| n > 0),
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    /// Pseudo-streaming: runs a full non-streaming completion and emits
    /// the result as a single text chunk.
    ///
    /// NOTE(review): only `Text` parts of the response are forwarded —
    /// tool calls, usage, and the finish reason are dropped here. Callers
    /// needing tool calls must use `complete` until true streaming lands.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fall back to non-streaming for now
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}