Skip to main content

ceres_client/
gemini.rs

//! Google Gemini embeddings client.
//!
//! # Embedding Provider Architecture
//!
//! The `EmbeddingProvider` trait (defined in `ceres_core::traits`) was introduced in PR #81,
//! abstracting over embedding backends. Current implementations: Gemini, OpenAI.
//! Remaining providers tracked in issue #79:
//! - Ollama (local embeddings)
//! - E5-multilingual (local, for cross-language search)
//!
//! TODO(observability): Add OpenTelemetry instrumentation for cloud deployment.
//! Use the `tracing-opentelemetry` crate to export spans to cloud observability platforms
//! (AWS X-Ray, GCP Cloud Trace, Azure Monitor). Add `#[instrument]` spans on:
//! - `get_embeddings()` - track API latency, token counts
//! - `sync_portal()` in harvest.rs - track harvest duration breakdown
//!
//! This enables a "waterfall" visualization showing time spent in each component
//! (e.g., 80% Gemini API wait, 20% DB insert).
//!
//! TODO(security): Encrypt API keys for multi-tenant deployment.
//! If supporting user-provided Gemini/CKAN API keys, store them encrypted
//! in the database using the `age` or `ring` crates instead of plaintext in .env.
//! Consider an `api_keys` table with an `encrypted_key` column and per-user isolation.
23
24use ceres_core::HttpConfig;
25use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
26use reqwest::Client;
27use serde::{Deserialize, Serialize};
28
29/// HTTP client for interacting with Google's Gemini Embeddings API.
30///
31/// This client provides methods to generate text embeddings using Google's
32/// gemini-embedding-001 model. Embeddings are vector representations of text
33/// that can be used for semantic search, clustering, and similarity comparisons.
34///
35/// # Security
36///
37/// The API key is securely transmitted via the `x-goog-api-key` HTTP header,
38/// not in the URL, to prevent accidental exposure in logs and proxies.
39///
40/// # Examples
41///
42/// ```no_run
43/// use ceres_client::GeminiClient;
44///
45/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
46/// let client = GeminiClient::new("your-api-key")?;
47/// let embedding = client.get_embeddings("Hello, world!").await?;
48/// println!("Embedding dimension: {}", embedding.len()); // 768
49/// # Ok(())
50/// # }
51/// ```
52#[derive(Clone)]
53pub struct GeminiClient {
54    client: Client,
55    api_key: String,
56}
57
58/// Request body for Gemini embedding API
59#[derive(Serialize)]
60struct EmbeddingRequest {
61    model: String,
62    content: Content,
63    /// Output dimensionality for Matryoshka (MRL) models like gemini-embedding-001.
64    /// Allows reducing from 3072 default to 768 for compatibility.
65    #[serde(skip_serializing_if = "Option::is_none")]
66    output_dimensionality: Option<usize>,
67}
68
69#[derive(Serialize)]
70struct Content {
71    parts: Vec<Part>,
72}
73
74#[derive(Serialize)]
75struct Part {
76    text: String,
77}
78
79/// Response from Gemini embedding API
80#[derive(Deserialize)]
81struct EmbeddingResponse {
82    embedding: EmbeddingData,
83}
84
85#[derive(Deserialize)]
86struct EmbeddingData {
87    values: Vec<f32>,
88}
89
90/// Request body for Gemini batch embedding API (`batchEmbedContents`)
91#[derive(Serialize)]
92struct BatchEmbeddingRequest {
93    requests: Vec<EmbeddingRequest>,
94}
95
96/// Response from Gemini batch embedding API
97#[derive(Deserialize)]
98struct BatchEmbeddingResponse {
99    embeddings: Vec<EmbeddingData>,
100}
101
102/// Error response from Gemini API
103#[derive(Deserialize)]
104struct GeminiError {
105    error: GeminiErrorDetail,
106}
107
108#[derive(Deserialize)]
109struct GeminiErrorDetail {
110    message: String,
111    #[allow(dead_code)]
112    status: Option<String>,
113}
114
115/// Classify Gemini API error based on status code and message
116fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
117    match status_code {
118        401 => GeminiErrorKind::Authentication,
119        429 => {
120            // Check if it's quota exceeded or rate limit
121            if message.contains("insufficient_quota") || message.contains("quota") {
122                GeminiErrorKind::QuotaExceeded
123            } else {
124                GeminiErrorKind::RateLimit
125            }
126        }
127        500..=599 => GeminiErrorKind::ServerError,
128        _ => {
129            // Check message content for specific error types
130            if message.contains("API key") || message.contains("Unauthorized") {
131                GeminiErrorKind::Authentication
132            } else if message.contains("rate") {
133                GeminiErrorKind::RateLimit
134            } else if message.contains("quota") {
135                GeminiErrorKind::QuotaExceeded
136            } else {
137                GeminiErrorKind::Unknown
138            }
139        }
140    }
141}
142
143/// Maps a reqwest send error to the appropriate `AppError` variant.
144fn map_send_error(e: reqwest::Error) -> AppError {
145    if e.is_timeout() {
146        AppError::Timeout(30)
147    } else if e.is_connect() {
148        AppError::GeminiError(GeminiErrorDetails::new(
149            GeminiErrorKind::NetworkError,
150            format!("Connection failed: {}", e),
151            0,
152        ))
153    } else {
154        AppError::ClientError(e.to_string())
155    }
156}
157
158/// Checks a Gemini API response status and returns a structured error on failure.
159async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, AppError> {
160    let status = response.status();
161    if !status.is_success() {
162        let status_code = status.as_u16();
163        let error_text = response.text().await.unwrap_or_default();
164
165        let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text) {
166            gemini_error.error.message
167        } else {
168            format!("HTTP {}: {}", status_code, error_text)
169        };
170
171        let kind = classify_gemini_error(status_code, &message);
172
173        return Err(AppError::GeminiError(GeminiErrorDetails::new(
174            kind,
175            message,
176            status_code,
177        )));
178    }
179    Ok(response)
180}
181
182impl GeminiClient {
183    /// Creates a new Gemini client with the specified API key.
184    pub fn new(api_key: &str) -> Result<Self, AppError> {
185        let http_config = HttpConfig::default();
186        let client = Client::builder()
187            .timeout(http_config.timeout)
188            .build()
189            .map_err(|e| AppError::ClientError(e.to_string()))?;
190
191        Ok(Self {
192            client,
193            api_key: api_key.to_string(),
194        })
195    }
196
197    /// Generates text embeddings using Google's gemini-embedding-001 model.
198    ///
199    /// This method converts input text into a 768-dimensional vector representation
200    /// that captures semantic meaning.
201    ///
202    /// # Arguments
203    ///
204    /// * `text` - The input text to generate embeddings for
205    ///
206    /// # Returns
207    ///
208    /// A vector of 768 floating-point values representing the text embedding.
209    ///
210    /// # Errors
211    ///
212    /// Returns `AppError::ClientError` if the HTTP request fails.
213    /// Returns `AppError::Generic` if the API returns an error.
214    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
215        // Sanitize text - replace newlines with spaces
216        let sanitized_text = text.replace('\n', " ");
217
218        let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent";
219
220        let request_body = EmbeddingRequest {
221            model: "models/gemini-embedding-001".to_string(),
222            content: Content {
223                parts: vec![Part {
224                    text: sanitized_text,
225                }],
226            },
227            output_dimensionality: Some(768),
228        };
229
230        let response = self
231            .client
232            .post(url)
233            .header("x-goog-api-key", self.api_key.as_str())
234            .json(&request_body)
235            .send()
236            .await
237            .map_err(map_send_error)?;
238
239        let response = check_response(response).await?;
240
241        let embedding_response: EmbeddingResponse = response
242            .json()
243            .await
244            .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
245
246        Ok(embedding_response.embedding.values)
247    }
248
249    /// Generates embeddings for multiple texts in a single API call.
250    ///
251    /// Uses the `batchEmbedContents` endpoint which supports up to 100 texts.
252    ///
253    /// # Arguments
254    ///
255    /// * `texts` - Slice of text references to embed
256    ///
257    /// # Returns
258    ///
259    /// A vector of embedding vectors, one per input text, in the same order.
260    pub async fn get_embeddings_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, AppError> {
261        if texts.is_empty() {
262            return Ok(Vec::new());
263        }
264
265        let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents";
266
267        let requests: Vec<EmbeddingRequest> = texts
268            .iter()
269            .map(|text| EmbeddingRequest {
270                model: "models/gemini-embedding-001".to_string(),
271                content: Content {
272                    parts: vec![Part {
273                        text: text.replace('\n', " "),
274                    }],
275                },
276                output_dimensionality: Some(768),
277            })
278            .collect();
279
280        let request_body = BatchEmbeddingRequest { requests };
281
282        let response = self
283            .client
284            .post(url)
285            .header("x-goog-api-key", self.api_key.as_str())
286            .json(&request_body)
287            .send()
288            .await
289            .map_err(map_send_error)?;
290
291        let response = check_response(response).await?;
292
293        let batch_response: BatchEmbeddingResponse = response
294            .json()
295            .await
296            .map_err(|e| AppError::ClientError(format!("Failed to parse batch response: {}", e)))?;
297
298        if batch_response.embeddings.len() != texts.len() {
299            return Err(AppError::ClientError(format!(
300                "Batch embedding count mismatch: expected {}, got {}",
301                texts.len(),
302                batch_response.embeddings.len()
303            )));
304        }
305
306        Ok(batch_response
307            .embeddings
308            .into_iter()
309            .map(|e| e.values)
310            .collect())
311    }
312}
313
314// =============================================================================
315// Trait Implementation: EmbeddingProvider
316// =============================================================================
317
318impl ceres_core::traits::EmbeddingProvider for GeminiClient {
319    fn name(&self) -> &'static str {
320        "gemini"
321    }
322
323    fn dimension(&self) -> usize {
324        // gemini-embedding-001 with output_dimensionality=768
325        768
326    }
327
328    fn max_batch_size(&self) -> usize {
329        100 // Gemini batchEmbedContents limit
330    }
331
332    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
333        self.get_embeddings(text).await
334    }
335
336    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
337        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
338        self.get_embeddings_batch(&text_refs).await
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `text` with the model and dimensionality used in
    /// these tests.
    fn request_for(text: &str) -> EmbeddingRequest {
        EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: text.to_string(),
                }],
            },
            output_dimensionality: Some(768),
        }
    }

    #[test]
    fn test_new_client() {
        assert!(GeminiClient::new("test-api-key").is_ok());
    }

    #[test]
    fn test_text_sanitization() {
        // Mirrors the newline replacement applied before texts are sent.
        let sanitized = "Line 1\nLine 2\nLine 3".replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let json = serde_json::to_string(&request_for("Hello world")).unwrap();

        for expected in ["gemini-embedding-001", "Hello world", "output_dimensionality"] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        assert_eq!(
            classify_gemini_error(401, "Invalid API key"),
            GeminiErrorKind::Authentication
        );
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        assert_eq!(
            classify_gemini_error(400, "API key not valid"),
            GeminiErrorKind::Authentication
        );
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        assert_eq!(
            classify_gemini_error(429, "Rate limit exceeded"),
            GeminiErrorKind::RateLimit
        );
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        assert_eq!(
            classify_gemini_error(429, "insufficient_quota"),
            GeminiErrorKind::QuotaExceeded
        );
    }

    #[test]
    fn test_classify_gemini_error_server() {
        assert_eq!(
            classify_gemini_error(500, "Internal server error"),
            GeminiErrorKind::ServerError
        );
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        assert_eq!(
            classify_gemini_error(503, "Service unavailable"),
            GeminiErrorKind::ServerError
        );
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        assert_eq!(
            classify_gemini_error(400, "Bad request"),
            GeminiErrorKind::Unknown
        );
    }

    #[test]
    fn test_batch_request_serialization() {
        let batch = BatchEmbeddingRequest {
            requests: vec![request_for("First text"), request_for("Second text")],
        };

        let json = serde_json::to_string(&batch).unwrap();
        for expected in ["requests", "First text", "Second text"] {
            assert!(json.contains(expected));
        }

        // Verify structure matches Gemini API expectations
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        let requests = parsed["requests"].as_array().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0]["model"], "models/gemini-embedding-001");
        assert_eq!(requests[0]["output_dimensionality"], 768);
    }

    #[test]
    fn test_batch_response_deserialization() {
        let json = r#"{
            "embeddings": [
                { "values": [0.1, 0.2, 0.3] },
                { "values": [0.4, 0.5, 0.6] }
            ]
        }"#;

        let response: BatchEmbeddingResponse = serde_json::from_str(json).unwrap();
        let vectors: Vec<&Vec<f32>> = response.embeddings.iter().map(|e| &e.values).collect();

        assert_eq!(vectors.len(), 2);
        assert_eq!(*vectors[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(*vectors[1], vec![0.4, 0.5, 0.6]);
    }
}
471}