1use ceres_core::HttpConfig;
25use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
26use reqwest::Client;
27use serde::{Deserialize, Serialize};
28
29#[derive(Clone)]
53pub struct GeminiClient {
54 client: Client,
55 api_key: String,
56}
57
58#[derive(Serialize)]
60struct EmbeddingRequest {
61 model: String,
62 content: Content,
63 #[serde(skip_serializing_if = "Option::is_none")]
66 output_dimensionality: Option<usize>,
67}
68
69#[derive(Serialize)]
70struct Content {
71 parts: Vec<Part>,
72}
73
74#[derive(Serialize)]
75struct Part {
76 text: String,
77}
78
79#[derive(Deserialize)]
81struct EmbeddingResponse {
82 embedding: EmbeddingData,
83}
84
85#[derive(Deserialize)]
86struct EmbeddingData {
87 values: Vec<f32>,
88}
89
90#[derive(Serialize)]
92struct BatchEmbeddingRequest {
93 requests: Vec<EmbeddingRequest>,
94}
95
96#[derive(Deserialize)]
98struct BatchEmbeddingResponse {
99 embeddings: Vec<EmbeddingData>,
100}
101
102#[derive(Deserialize)]
104struct GeminiError {
105 error: GeminiErrorDetail,
106}
107
108#[derive(Deserialize)]
109struct GeminiErrorDetail {
110 message: String,
111 #[allow(dead_code)]
112 status: Option<String>,
113}
114
115fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
117 match status_code {
118 401 => GeminiErrorKind::Authentication,
119 429 => {
120 if message.contains("insufficient_quota") || message.contains("quota") {
122 GeminiErrorKind::QuotaExceeded
123 } else {
124 GeminiErrorKind::RateLimit
125 }
126 }
127 500..=599 => GeminiErrorKind::ServerError,
128 _ => {
129 if message.contains("API key") || message.contains("Unauthorized") {
131 GeminiErrorKind::Authentication
132 } else if message.contains("rate") {
133 GeminiErrorKind::RateLimit
134 } else if message.contains("quota") {
135 GeminiErrorKind::QuotaExceeded
136 } else {
137 GeminiErrorKind::Unknown
138 }
139 }
140 }
141}
142
143fn map_send_error(e: reqwest::Error) -> AppError {
145 if e.is_timeout() {
146 AppError::Timeout(30)
147 } else if e.is_connect() {
148 AppError::GeminiError(GeminiErrorDetails::new(
149 GeminiErrorKind::NetworkError,
150 format!("Connection failed: {}", e),
151 0,
152 ))
153 } else {
154 AppError::ClientError(e.to_string())
155 }
156}
157
158async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, AppError> {
160 let status = response.status();
161 if !status.is_success() {
162 let status_code = status.as_u16();
163 let error_text = response.text().await.unwrap_or_default();
164
165 let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text) {
166 gemini_error.error.message
167 } else {
168 format!("HTTP {}: {}", status_code, error_text)
169 };
170
171 let kind = classify_gemini_error(status_code, &message);
172
173 return Err(AppError::GeminiError(GeminiErrorDetails::new(
174 kind,
175 message,
176 status_code,
177 )));
178 }
179 Ok(response)
180}
181
182impl GeminiClient {
183 pub fn new(api_key: &str) -> Result<Self, AppError> {
185 let http_config = HttpConfig::default();
186 let client = Client::builder()
187 .timeout(http_config.timeout)
188 .build()
189 .map_err(|e| AppError::ClientError(e.to_string()))?;
190
191 Ok(Self {
192 client,
193 api_key: api_key.to_string(),
194 })
195 }
196
197 pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
215 let sanitized_text = text.replace('\n', " ");
217
218 let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent";
219
220 let request_body = EmbeddingRequest {
221 model: "models/gemini-embedding-001".to_string(),
222 content: Content {
223 parts: vec![Part {
224 text: sanitized_text,
225 }],
226 },
227 output_dimensionality: Some(768),
228 };
229
230 let response = self
231 .client
232 .post(url)
233 .header("x-goog-api-key", self.api_key.as_str())
234 .json(&request_body)
235 .send()
236 .await
237 .map_err(map_send_error)?;
238
239 let response = check_response(response).await?;
240
241 let embedding_response: EmbeddingResponse = response
242 .json()
243 .await
244 .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
245
246 Ok(embedding_response.embedding.values)
247 }
248
249 pub async fn get_embeddings_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, AppError> {
261 if texts.is_empty() {
262 return Ok(Vec::new());
263 }
264
265 let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents";
266
267 let requests: Vec<EmbeddingRequest> = texts
268 .iter()
269 .map(|text| EmbeddingRequest {
270 model: "models/gemini-embedding-001".to_string(),
271 content: Content {
272 parts: vec![Part {
273 text: text.replace('\n', " "),
274 }],
275 },
276 output_dimensionality: Some(768),
277 })
278 .collect();
279
280 let request_body = BatchEmbeddingRequest { requests };
281
282 let response = self
283 .client
284 .post(url)
285 .header("x-goog-api-key", self.api_key.as_str())
286 .json(&request_body)
287 .send()
288 .await
289 .map_err(map_send_error)?;
290
291 let response = check_response(response).await?;
292
293 let batch_response: BatchEmbeddingResponse = response
294 .json()
295 .await
296 .map_err(|e| AppError::ClientError(format!("Failed to parse batch response: {}", e)))?;
297
298 if batch_response.embeddings.len() != texts.len() {
299 return Err(AppError::ClientError(format!(
300 "Batch embedding count mismatch: expected {}, got {}",
301 texts.len(),
302 batch_response.embeddings.len()
303 )));
304 }
305
306 Ok(batch_response
307 .embeddings
308 .into_iter()
309 .map(|e| e.values)
310 .collect())
311 }
312}
313
314impl ceres_core::traits::EmbeddingProvider for GeminiClient {
319 fn name(&self) -> &'static str {
320 "gemini"
321 }
322
323 fn dimension(&self) -> usize {
324 768
326 }
327
328 fn max_batch_size(&self) -> usize {
329 100 }
331
332 async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
333 self.get_embeddings(text).await
334 }
335
336 async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
337 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
338 self.get_embeddings_batch(&text_refs).await
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_client() {
        let client = GeminiClient::new("test-api-key");
        assert!(client.is_ok());
    }

    #[test]
    fn test_text_sanitization() {
        let text_with_newlines = "Line 1\nLine 2\nLine 3";
        let sanitized = text_with_newlines.replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
            output_dimensionality: Some(768),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("gemini-embedding-001"));
        assert!(json.contains("Hello world"));
        assert!(json.contains("output_dimensionality"));
    }

    #[test]
    fn test_request_serialization_omits_missing_dimensionality() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello".to_string(),
                }],
            },
            output_dimensionality: None,
        };

        // `skip_serializing_if` must drop the key entirely rather than emit `null`.
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("output_dimensionality"));
    }

    #[test]
    fn test_response_deserialization() {
        let json = r#"{ "embedding": { "values": [0.5, -0.25] } }"#;
        let response: EmbeddingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.embedding.values, vec![0.5, -0.25]);
    }

    #[test]
    fn test_error_envelope_deserialization() {
        let json = r#"{
            "error": {
                "code": 400,
                "message": "API key not valid. Please pass a valid API key.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let parsed: GeminiError = serde_json::from_str(json).unwrap();
        assert_eq!(
            parsed.error.message,
            "API key not valid. Please pass a valid API key."
        );
        assert_eq!(parsed.error.status.as_deref(), Some("INVALID_ARGUMENT"));
        assert_eq!(
            classify_gemini_error(400, &parsed.error.message),
            GeminiErrorKind::Authentication
        );
    }

    #[test]
    fn test_error_envelope_without_status() {
        let parsed: GeminiError =
            serde_json::from_str(r#"{ "error": { "message": "boom" } }"#).unwrap();
        assert_eq!(parsed.error.message, "boom");
        assert!(parsed.error.status.is_none());
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        let kind = classify_gemini_error(401, "Invalid API key");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        let kind = classify_gemini_error(400, "API key not valid");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        let kind = classify_gemini_error(429, "Rate limit exceeded");
        assert_eq!(kind, GeminiErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        let kind = classify_gemini_error(429, "insufficient_quota");
        assert_eq!(kind, GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_server() {
        let kind = classify_gemini_error(500, "Internal server error");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        let kind = classify_gemini_error(503, "Service unavailable");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        let kind = classify_gemini_error(400, "Bad request");
        assert_eq!(kind, GeminiErrorKind::Unknown);
    }

    #[test]
    fn test_batch_request_serialization() {
        let request = BatchEmbeddingRequest {
            requests: vec![
                EmbeddingRequest {
                    model: "models/gemini-embedding-001".to_string(),
                    content: Content {
                        parts: vec![Part {
                            text: "First text".to_string(),
                        }],
                    },
                    output_dimensionality: Some(768),
                },
                EmbeddingRequest {
                    model: "models/gemini-embedding-001".to_string(),
                    content: Content {
                        parts: vec![Part {
                            text: "Second text".to_string(),
                        }],
                    },
                    output_dimensionality: Some(768),
                },
            ],
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("requests"));
        assert!(json.contains("First text"));
        assert!(json.contains("Second text"));

        // Check the exact structure, not just substrings.
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        let requests = parsed["requests"].as_array().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0]["model"], "models/gemini-embedding-001");
        assert_eq!(requests[0]["output_dimensionality"], 768);
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{
            "embeddings": [
                { "values": [0.1, 0.2, 0.3] },
                { "values": [0.4, 0.5, 0.6] }
            ]
        }"#;

        let response: BatchEmbeddingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(response.embeddings[0].values, vec![0.1, 0.2, 0.3]);
        assert_eq!(response.embeddings[1].values, vec![0.4, 0.5, 0.6]);
    }
}