ricecoder_providers/providers/google.rs

1//! Google Gemini provider implementation
2//!
3//! Supports Gemini models via the Google AI API.
4
5use async_trait::async_trait;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use tracing::{debug, error, warn};
10
11use crate::error::ProviderError;
12use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
13use crate::provider::Provider;
14use crate::token_counter::TokenCounter;
15
16/// Google Gemini provider implementation
17pub struct GoogleProvider {
18    api_key: String,
19    client: Arc<Client>,
20    base_url: String,
21    token_counter: Arc<TokenCounter>,
22}
23
24impl GoogleProvider {
25    /// Create a new Google provider instance
26    pub fn new(api_key: String) -> Result<Self, ProviderError> {
27        if api_key.is_empty() {
28            return Err(ProviderError::ConfigError(
29                "Google API key is required".to_string(),
30            ));
31        }
32
33        Ok(Self {
34            api_key,
35            client: Arc::new(Client::new()),
36            base_url: "https://generativelanguage.googleapis.com/v1beta/models".to_string(),
37            token_counter: Arc::new(TokenCounter::new()),
38        })
39    }
40
41    /// Create a new Google provider with a custom base URL
42    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
43        if api_key.is_empty() {
44            return Err(ProviderError::ConfigError(
45                "Google API key is required".to_string(),
46            ));
47        }
48
49        Ok(Self {
50            api_key,
51            client: Arc::new(Client::new()),
52            base_url,
53            token_counter: Arc::new(TokenCounter::new()),
54        })
55    }
56
57    /// Convert Google API response to our ChatResponse
58    fn convert_response(
59        response: GoogleChatResponse,
60        model: String,
61    ) -> Result<ChatResponse, ProviderError> {
62        let content = response
63            .candidates
64            .first()
65            .and_then(|c| c.content.as_ref())
66            .and_then(|c| c.parts.first())
67            .map(|p| p.text.clone())
68            .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;
69
70        let finish_reason = response
71            .candidates
72            .first()
73            .and_then(|c| c.finish_reason.as_deref())
74            .map(|reason| match reason {
75                "STOP" => FinishReason::Stop,
76                "MAX_TOKENS" => FinishReason::Length,
77                "ERROR" => FinishReason::Error,
78                _ => FinishReason::Stop,
79            })
80            .unwrap_or(FinishReason::Stop);
81
82        // Google API doesn't always return usage info, so we estimate
83        let total_tokens = response
84            .usage_metadata
85            .as_ref()
86            .map(|u| u.total_token_count)
87            .unwrap_or(0);
88
89        let prompt_tokens = response
90            .usage_metadata
91            .as_ref()
92            .map(|u| u.prompt_token_count)
93            .unwrap_or(0);
94
95        let completion_tokens = response
96            .usage_metadata
97            .as_ref()
98            .map(|u| u.candidates_token_count)
99            .unwrap_or(0);
100
101        Ok(ChatResponse {
102            content,
103            model,
104            usage: TokenUsage {
105                prompt_tokens,
106                completion_tokens,
107                total_tokens,
108            },
109            finish_reason,
110        })
111    }
112}
113
114#[async_trait]
115impl Provider for GoogleProvider {
116    fn id(&self) -> &str {
117        "google"
118    }
119
120    fn name(&self) -> &str {
121        "Google"
122    }
123
124    fn models(&self) -> Vec<ModelInfo> {
125        vec![
126            ModelInfo {
127                id: "gemini-2.0-flash".to_string(),
128                name: "Gemini 2.0 Flash".to_string(),
129                provider: "google".to_string(),
130                context_window: 1000000,
131                capabilities: vec![
132                    Capability::Chat,
133                    Capability::Code,
134                    Capability::Vision,
135                    Capability::Streaming,
136                ],
137                pricing: Some(crate::models::Pricing {
138                    input_per_1k_tokens: 0.075,
139                    output_per_1k_tokens: 0.3,
140                }),
141            },
142            ModelInfo {
143                id: "gemini-1.5-pro".to_string(),
144                name: "Gemini 1.5 Pro".to_string(),
145                provider: "google".to_string(),
146                context_window: 2000000,
147                capabilities: vec![
148                    Capability::Chat,
149                    Capability::Code,
150                    Capability::Vision,
151                    Capability::Streaming,
152                ],
153                pricing: Some(crate::models::Pricing {
154                    input_per_1k_tokens: 1.25,
155                    output_per_1k_tokens: 5.0,
156                }),
157            },
158            ModelInfo {
159                id: "gemini-1.5-flash".to_string(),
160                name: "Gemini 1.5 Flash".to_string(),
161                provider: "google".to_string(),
162                context_window: 1000000,
163                capabilities: vec![
164                    Capability::Chat,
165                    Capability::Code,
166                    Capability::Vision,
167                    Capability::Streaming,
168                ],
169                pricing: Some(crate::models::Pricing {
170                    input_per_1k_tokens: 0.075,
171                    output_per_1k_tokens: 0.3,
172                }),
173            },
174            ModelInfo {
175                id: "gemini-1.0-pro".to_string(),
176                name: "Gemini 1.0 Pro".to_string(),
177                provider: "google".to_string(),
178                context_window: 32000,
179                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
180                pricing: Some(crate::models::Pricing {
181                    input_per_1k_tokens: 0.5,
182                    output_per_1k_tokens: 1.5,
183                }),
184            },
185        ]
186    }
187
188    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
189        // Validate model
190        let model_id = &request.model;
191        if !self.models().iter().any(|m| m.id == *model_id) {
192            return Err(ProviderError::InvalidModel(model_id.clone()));
193        }
194
195        let google_request = GoogleChatRequest {
196            contents: vec![GoogleContent {
197                role: "user".to_string(),
198                parts: request
199                    .messages
200                    .iter()
201                    .map(|m| GooglePart {
202                        text: m.content.clone(),
203                    })
204                    .collect(),
205            }],
206            generation_config: Some(GoogleGenerationConfig {
207                temperature: request.temperature,
208                max_output_tokens: request.max_tokens,
209            }),
210        };
211
212        debug!(
213            "Sending chat request to Google for model: {}",
214            request.model
215        );
216
217        let url = format!("{}:generateContent?key={}", self.base_url, self.api_key);
218
219        let response = self
220            .client
221            .post(&url)
222            .header("Content-Type", "application/json")
223            .json(&google_request)
224            .send()
225            .await
226            .map_err(|e| {
227                error!("Google API request failed: {}", e);
228                ProviderError::from(e)
229            })?;
230
231        let status = response.status();
232        if !status.is_success() {
233            let error_text = response.text().await.unwrap_or_default();
234            error!("Google API error ({}): {}", status, error_text);
235
236            return match status.as_u16() {
237                401 | 403 => Err(ProviderError::AuthError),
238                429 => Err(ProviderError::RateLimited(60)),
239                _ => Err(ProviderError::ProviderError(format!(
240                    "Google API error: {}",
241                    status
242                ))),
243            };
244        }
245
246        let google_response: GoogleChatResponse = response.json().await?;
247        Self::convert_response(google_response, request.model)
248    }
249
250    async fn chat_stream(
251        &self,
252        _request: ChatRequest,
253    ) -> Result<crate::provider::ChatStream, ProviderError> {
254        // Streaming support will be implemented in a future iteration
255        Err(ProviderError::ProviderError(
256            "Streaming not yet implemented for Google".to_string(),
257        ))
258    }
259
260    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
261        // Validate model
262        if !self.models().iter().any(|m| m.id == model) {
263            return Err(ProviderError::InvalidModel(model.to_string()));
264        }
265
266        // Use token counter with caching for performance
267        let tokens = self.token_counter.count_tokens_openai(content, model);
268        Ok(tokens)
269    }
270
271    async fn health_check(&self) -> Result<bool, ProviderError> {
272        debug!("Performing health check for Google provider");
273
274        // Try to list models as a health check
275        let url = format!("{}?key={}", self.base_url, self.api_key);
276
277        let response = self.client.get(&url).send().await.map_err(|e| {
278            warn!("Google health check failed: {}", e);
279            ProviderError::from(e)
280        })?;
281
282        match response.status().as_u16() {
283            200 => {
284                debug!("Google health check passed");
285                Ok(true)
286            }
287            401 | 403 => {
288                error!("Google health check failed: authentication error");
289                Err(ProviderError::AuthError)
290            }
291            _ => {
292                warn!(
293                    "Google health check failed with status: {}",
294                    response.status()
295                );
296                Ok(false)
297            }
298        }
299    }
300}
301
302/// Google API request format
303#[derive(Debug, Serialize)]
304struct GoogleChatRequest {
305    contents: Vec<GoogleContent>,
306    #[serde(skip_serializing_if = "Option::is_none")]
307    generation_config: Option<GoogleGenerationConfig>,
308}
309
310/// Google API content format
311#[derive(Debug, Serialize, Deserialize)]
312struct GoogleContent {
313    role: String,
314    parts: Vec<GooglePart>,
315}
316
317/// Google API part format
318#[derive(Debug, Serialize, Deserialize)]
319struct GooglePart {
320    text: String,
321}
322
323/// Google API generation config
324#[derive(Debug, Serialize)]
325struct GoogleGenerationConfig {
326    #[serde(skip_serializing_if = "Option::is_none")]
327    temperature: Option<f32>,
328    #[serde(skip_serializing_if = "Option::is_none")]
329    max_output_tokens: Option<usize>,
330}
331
332/// Google API response format
333#[derive(Debug, Deserialize)]
334struct GoogleChatResponse {
335    candidates: Vec<GoogleCandidate>,
336    #[serde(default)]
337    usage_metadata: Option<GoogleUsageMetadata>,
338}
339
340/// Google API candidate format
341#[derive(Debug, Deserialize)]
342struct GoogleCandidate {
343    content: Option<GoogleContent>,
344    finish_reason: Option<String>,
345}
346
347/// Google API usage metadata
348#[derive(Debug, Deserialize)]
349struct GoogleUsageMetadata {
350    prompt_token_count: usize,
351    candidates_token_count: usize,
352    total_token_count: usize,
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider configured with a placeholder API key.
    fn provider() -> GoogleProvider {
        GoogleProvider::new("test-key".to_string()).expect("a non-empty key must be accepted")
    }

    #[test]
    fn test_google_provider_creation() {
        assert!(GoogleProvider::new("test-key".to_string()).is_ok());
    }

    #[test]
    fn test_google_provider_creation_empty_key() {
        assert!(GoogleProvider::new(String::new()).is_err());
    }

    #[test]
    fn test_google_provider_id() {
        assert_eq!(provider().id(), "google");
    }

    #[test]
    fn test_google_provider_name() {
        assert_eq!(provider().name(), "Google");
    }

    #[test]
    fn test_google_models() {
        let ids: Vec<String> = provider().models().into_iter().map(|m| m.id).collect();
        assert_eq!(ids.len(), 4);

        let expected = [
            "gemini-2.0-flash",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ];
        for want in expected.iter() {
            assert!(ids.iter().any(|id| id == *want), "missing model {}", want);
        }
    }

    #[test]
    fn test_token_counting() {
        let count = provider()
            .count_tokens("Hello, world!", "gemini-1.5-pro")
            .expect("gemini-1.5-pro is a listed model");
        assert!(count > 0);
    }

    #[test]
    fn test_token_counting_invalid_model() {
        let outcome = provider().count_tokens("Hello, world!", "invalid-model");
        assert!(outcome.is_err());
    }
}