ricecoder_providers/providers/
google.rs1use async_trait::async_trait;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use tracing::{debug, error, warn};
10
11use crate::error::ProviderError;
12use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
13use crate::provider::Provider;
14use crate::token_counter::TokenCounter;
15
/// Provider implementation for Google's Gemini models, speaking the
/// `generativelanguage.googleapis.com` REST API.
pub struct GoogleProvider {
    // API key appended as a `key=` query parameter on every request.
    api_key: String,
    // Shared reqwest client; `Arc` allows cheap cloning across tasks.
    client: Arc<Client>,
    // Endpoint prefix; overridable via `with_base_url` (proxies, tests).
    base_url: String,
    // Local token estimator used by `count_tokens`.
    token_counter: Arc<TokenCounter>,
}
23
24impl GoogleProvider {
25 pub fn new(api_key: String) -> Result<Self, ProviderError> {
27 if api_key.is_empty() {
28 return Err(ProviderError::ConfigError(
29 "Google API key is required".to_string(),
30 ));
31 }
32
33 Ok(Self {
34 api_key,
35 client: Arc::new(Client::new()),
36 base_url: "https://generativelanguage.googleapis.com/v1beta/models".to_string(),
37 token_counter: Arc::new(TokenCounter::new()),
38 })
39 }
40
41 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
43 if api_key.is_empty() {
44 return Err(ProviderError::ConfigError(
45 "Google API key is required".to_string(),
46 ));
47 }
48
49 Ok(Self {
50 api_key,
51 client: Arc::new(Client::new()),
52 base_url,
53 token_counter: Arc::new(TokenCounter::new()),
54 })
55 }
56
57 fn convert_response(
59 response: GoogleChatResponse,
60 model: String,
61 ) -> Result<ChatResponse, ProviderError> {
62 let content = response
63 .candidates
64 .first()
65 .and_then(|c| c.content.as_ref())
66 .and_then(|c| c.parts.first())
67 .map(|p| p.text.clone())
68 .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;
69
70 let finish_reason = response
71 .candidates
72 .first()
73 .and_then(|c| c.finish_reason.as_deref())
74 .map(|reason| match reason {
75 "STOP" => FinishReason::Stop,
76 "MAX_TOKENS" => FinishReason::Length,
77 "ERROR" => FinishReason::Error,
78 _ => FinishReason::Stop,
79 })
80 .unwrap_or(FinishReason::Stop);
81
82 let total_tokens = response
84 .usage_metadata
85 .as_ref()
86 .map(|u| u.total_token_count)
87 .unwrap_or(0);
88
89 let prompt_tokens = response
90 .usage_metadata
91 .as_ref()
92 .map(|u| u.prompt_token_count)
93 .unwrap_or(0);
94
95 let completion_tokens = response
96 .usage_metadata
97 .as_ref()
98 .map(|u| u.candidates_token_count)
99 .unwrap_or(0);
100
101 Ok(ChatResponse {
102 content,
103 model,
104 usage: TokenUsage {
105 prompt_tokens,
106 completion_tokens,
107 total_tokens,
108 },
109 finish_reason,
110 })
111 }
112}
113
114#[async_trait]
115impl Provider for GoogleProvider {
116 fn id(&self) -> &str {
117 "google"
118 }
119
120 fn name(&self) -> &str {
121 "Google"
122 }
123
124 fn models(&self) -> Vec<ModelInfo> {
125 vec![
126 ModelInfo {
127 id: "gemini-2.0-flash".to_string(),
128 name: "Gemini 2.0 Flash".to_string(),
129 provider: "google".to_string(),
130 context_window: 1000000,
131 capabilities: vec![
132 Capability::Chat,
133 Capability::Code,
134 Capability::Vision,
135 Capability::Streaming,
136 ],
137 pricing: Some(crate::models::Pricing {
138 input_per_1k_tokens: 0.075,
139 output_per_1k_tokens: 0.3,
140 }),
141 },
142 ModelInfo {
143 id: "gemini-1.5-pro".to_string(),
144 name: "Gemini 1.5 Pro".to_string(),
145 provider: "google".to_string(),
146 context_window: 2000000,
147 capabilities: vec![
148 Capability::Chat,
149 Capability::Code,
150 Capability::Vision,
151 Capability::Streaming,
152 ],
153 pricing: Some(crate::models::Pricing {
154 input_per_1k_tokens: 1.25,
155 output_per_1k_tokens: 5.0,
156 }),
157 },
158 ModelInfo {
159 id: "gemini-1.5-flash".to_string(),
160 name: "Gemini 1.5 Flash".to_string(),
161 provider: "google".to_string(),
162 context_window: 1000000,
163 capabilities: vec![
164 Capability::Chat,
165 Capability::Code,
166 Capability::Vision,
167 Capability::Streaming,
168 ],
169 pricing: Some(crate::models::Pricing {
170 input_per_1k_tokens: 0.075,
171 output_per_1k_tokens: 0.3,
172 }),
173 },
174 ModelInfo {
175 id: "gemini-1.0-pro".to_string(),
176 name: "Gemini 1.0 Pro".to_string(),
177 provider: "google".to_string(),
178 context_window: 32000,
179 capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
180 pricing: Some(crate::models::Pricing {
181 input_per_1k_tokens: 0.5,
182 output_per_1k_tokens: 1.5,
183 }),
184 },
185 ]
186 }
187
188 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
189 let model_id = &request.model;
191 if !self.models().iter().any(|m| m.id == *model_id) {
192 return Err(ProviderError::InvalidModel(model_id.clone()));
193 }
194
195 let google_request = GoogleChatRequest {
196 contents: vec![GoogleContent {
197 role: "user".to_string(),
198 parts: request
199 .messages
200 .iter()
201 .map(|m| GooglePart {
202 text: m.content.clone(),
203 })
204 .collect(),
205 }],
206 generation_config: Some(GoogleGenerationConfig {
207 temperature: request.temperature,
208 max_output_tokens: request.max_tokens,
209 }),
210 };
211
212 debug!(
213 "Sending chat request to Google for model: {}",
214 request.model
215 );
216
217 let url = format!("{}:generateContent?key={}", self.base_url, self.api_key);
218
219 let response = self
220 .client
221 .post(&url)
222 .header("Content-Type", "application/json")
223 .json(&google_request)
224 .send()
225 .await
226 .map_err(|e| {
227 error!("Google API request failed: {}", e);
228 ProviderError::from(e)
229 })?;
230
231 let status = response.status();
232 if !status.is_success() {
233 let error_text = response.text().await.unwrap_or_default();
234 error!("Google API error ({}): {}", status, error_text);
235
236 return match status.as_u16() {
237 401 | 403 => Err(ProviderError::AuthError),
238 429 => Err(ProviderError::RateLimited(60)),
239 _ => Err(ProviderError::ProviderError(format!(
240 "Google API error: {}",
241 status
242 ))),
243 };
244 }
245
246 let google_response: GoogleChatResponse = response.json().await?;
247 Self::convert_response(google_response, request.model)
248 }
249
250 async fn chat_stream(
251 &self,
252 _request: ChatRequest,
253 ) -> Result<crate::provider::ChatStream, ProviderError> {
254 Err(ProviderError::ProviderError(
256 "Streaming not yet implemented for Google".to_string(),
257 ))
258 }
259
260 fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
261 if !self.models().iter().any(|m| m.id == model) {
263 return Err(ProviderError::InvalidModel(model.to_string()));
264 }
265
266 let tokens = self.token_counter.count_tokens_openai(content, model);
268 Ok(tokens)
269 }
270
271 async fn health_check(&self) -> Result<bool, ProviderError> {
272 debug!("Performing health check for Google provider");
273
274 let url = format!("{}?key={}", self.base_url, self.api_key);
276
277 let response = self.client.get(&url).send().await.map_err(|e| {
278 warn!("Google health check failed: {}", e);
279 ProviderError::from(e)
280 })?;
281
282 match response.status().as_u16() {
283 200 => {
284 debug!("Google health check passed");
285 Ok(true)
286 }
287 401 | 403 => {
288 error!("Google health check failed: authentication error");
289 Err(ProviderError::AuthError)
290 }
291 _ => {
292 warn!(
293 "Google health check failed with status: {}",
294 response.status()
295 );
296 Ok(false)
297 }
298 }
299 }
300}
301
302#[derive(Debug, Serialize)]
304struct GoogleChatRequest {
305 contents: Vec<GoogleContent>,
306 #[serde(skip_serializing_if = "Option::is_none")]
307 generation_config: Option<GoogleGenerationConfig>,
308}
309
/// A single conversational turn in Gemini's wire format: a role plus one or
/// more text parts. (Single-word field names match the API's camelCase keys.)
#[derive(Debug, Serialize, Deserialize)]
struct GoogleContent {
    role: String,
    parts: Vec<GooglePart>,
}
316
/// One text fragment within a content turn. Only the `text` part type is
/// modeled here; other part kinds (inline data, function calls) are not used.
#[derive(Debug, Serialize, Deserialize)]
struct GooglePart {
    text: String,
}
322
323#[derive(Debug, Serialize)]
325struct GoogleGenerationConfig {
326 #[serde(skip_serializing_if = "Option::is_none")]
327 temperature: Option<f32>,
328 #[serde(skip_serializing_if = "Option::is_none")]
329 max_output_tokens: Option<usize>,
330}
331
332#[derive(Debug, Deserialize)]
334struct GoogleChatResponse {
335 candidates: Vec<GoogleCandidate>,
336 #[serde(default)]
337 usage_metadata: Option<GoogleUsageMetadata>,
338}
339
340#[derive(Debug, Deserialize)]
342struct GoogleCandidate {
343 content: Option<GoogleContent>,
344 finish_reason: Option<String>,
345}
346
347#[derive(Debug, Deserialize)]
349struct GoogleUsageMetadata {
350 prompt_token_count: usize,
351 candidates_token_count: usize,
352 total_token_count: usize,
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    const KEY: &str = "test-key";

    /// Constructing with a non-empty key succeeds.
    #[test]
    fn test_google_provider_creation() {
        assert!(GoogleProvider::new(KEY.to_string()).is_ok());
    }

    /// An empty key is rejected at construction time.
    #[test]
    fn test_google_provider_creation_empty_key() {
        assert!(GoogleProvider::new(String::new()).is_err());
    }

    /// The registry id is the fixed string "google".
    #[test]
    fn test_google_provider_id() {
        let p = GoogleProvider::new(KEY.to_string()).unwrap();
        assert_eq!(p.id(), "google");
    }

    /// The display name is "Google".
    #[test]
    fn test_google_provider_name() {
        let p = GoogleProvider::new(KEY.to_string()).unwrap();
        assert_eq!(p.name(), "Google");
    }

    /// The catalog lists exactly the four expected Gemini models.
    #[test]
    fn test_google_models() {
        let p = GoogleProvider::new(KEY.to_string()).unwrap();
        let models = p.models();
        assert_eq!(models.len(), 4);
        for expected in [
            "gemini-2.0-flash",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ] {
            assert!(models.iter().any(|m| m.id == expected));
        }
    }

    /// Counting tokens of a non-empty string yields a positive estimate.
    #[test]
    fn test_token_counting() {
        let p = GoogleProvider::new(KEY.to_string()).unwrap();
        let n = p.count_tokens("Hello, world!", "gemini-1.5-pro").unwrap();
        assert!(n > 0);
    }

    /// Unknown model ids are rejected by count_tokens.
    #[test]
    fn test_token_counting_invalid_model() {
        let p = GoogleProvider::new(KEY.to_string()).unwrap();
        assert!(p.count_tokens("Hello, world!", "invalid-model").is_err());
    }
}