// ceres_client/gemini.rs

//! Google Gemini embeddings client.
//!
//! # Future Extensions
//!
//! Switchable embedding providers (roadmap v0.3+): `GeminiClient` already
//! implements `ceres_core::traits::EmbeddingProvider` (see the trait
//! implementation at the end of this file). The original design sketch for
//! that trait was:
//! ```ignore
//! #[async_trait]
//! pub trait EmbeddingProvider: Send + Sync {
//!     fn dimension(&self) -> usize;
//!     async fn embed(&self, text: &str) -> Result<Vec<f32>, AppError>;
//! }
//! ```
//!
//! TODO: Potential additional providers to support:
//! - OpenAI text-embedding-3-small/large
//! - Cohere embed-multilingual-v3.0
//! - E5-multilingual (local, for cross-language search)
//! - Ollama (local embeddings)
//!
//! TODO(observability): Add OpenTelemetry instrumentation for cloud deployment.
//! Use the `tracing-opentelemetry` crate to export spans to cloud observability
//! platforms (AWS X-Ray, GCP Cloud Trace, Azure Monitor). Add `#[instrument]`
//! spans on:
//! - `get_embeddings()` - track API latency and token counts
//! - `sync_portal()` in harvest.rs - track the harvest duration breakdown
//!
//! This enables a "waterfall" visualization showing the time spent in each
//! component (e.g., 80% Gemini API wait, 20% DB insert).
//!
//! TODO(security): Encrypt API keys for multi-tenant deployment.
//! If supporting user-provided Gemini/CKAN API keys, store them encrypted
//! in the database using the `age` or `ring` crates instead of as plaintext in
//! `.env`. Consider an `api_keys` table with an `encrypted_key` column and
//! per-user isolation.

use std::fmt;

use ceres_core::HttpConfig;
use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
use reqwest::Client;
use serde::{Deserialize, Serialize};

39/// HTTP client for interacting with Google's Gemini Embeddings API.
40///
41/// This client provides methods to generate text embeddings using Google's
42/// gemini-embedding-001 model. Embeddings are vector representations of text
43/// that can be used for semantic search, clustering, and similarity comparisons.
44///
45/// # Security
46///
47/// The API key is securely transmitted via the `x-goog-api-key` HTTP header,
48/// not in the URL, to prevent accidental exposure in logs and proxies.
49///
50/// # Examples
51///
52/// ```no_run
53/// use ceres_client::GeminiClient;
54///
55/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
56/// let client = GeminiClient::new("your-api-key")?;
57/// let embedding = client.get_embeddings("Hello, world!").await?;
58/// println!("Embedding dimension: {}", embedding.len()); // 768
59/// # Ok(())
60/// # }
61/// ```
62#[derive(Clone)]
63pub struct GeminiClient {
64    client: Client,
65    api_key: String,
66}
67
68/// Request body for Gemini embedding API
69#[derive(Serialize)]
70struct EmbeddingRequest {
71    model: String,
72    content: Content,
73    /// Output dimensionality for Matryoshka (MRL) models like gemini-embedding-001.
74    /// Allows reducing from 3072 default to 768 for compatibility.
75    #[serde(skip_serializing_if = "Option::is_none")]
76    output_dimensionality: Option<usize>,
77}
78
79#[derive(Serialize)]
80struct Content {
81    parts: Vec<Part>,
82}
83
84#[derive(Serialize)]
85struct Part {
86    text: String,
87}
88
89/// Response from Gemini embedding API
90#[derive(Deserialize)]
91struct EmbeddingResponse {
92    embedding: EmbeddingData,
93}
94
95#[derive(Deserialize)]
96struct EmbeddingData {
97    values: Vec<f32>,
98}
99
100/// Error response from Gemini API
101#[derive(Deserialize)]
102struct GeminiError {
103    error: GeminiErrorDetail,
104}
105
106#[derive(Deserialize)]
107struct GeminiErrorDetail {
108    message: String,
109    #[allow(dead_code)]
110    status: Option<String>,
111}
112
113/// Classify Gemini API error based on status code and message
114fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
115    match status_code {
116        401 => GeminiErrorKind::Authentication,
117        429 => {
118            // Check if it's quota exceeded or rate limit
119            if message.contains("insufficient_quota") || message.contains("quota") {
120                GeminiErrorKind::QuotaExceeded
121            } else {
122                GeminiErrorKind::RateLimit
123            }
124        }
125        500..=599 => GeminiErrorKind::ServerError,
126        _ => {
127            // Check message content for specific error types
128            if message.contains("API key") || message.contains("Unauthorized") {
129                GeminiErrorKind::Authentication
130            } else if message.contains("rate") {
131                GeminiErrorKind::RateLimit
132            } else if message.contains("quota") {
133                GeminiErrorKind::QuotaExceeded
134            } else {
135                GeminiErrorKind::Unknown
136            }
137        }
138    }
139}
140
141impl GeminiClient {
142    /// Creates a new Gemini client with the specified API key.
143    pub fn new(api_key: &str) -> Result<Self, AppError> {
144        let http_config = HttpConfig::default();
145        let client = Client::builder()
146            .timeout(http_config.timeout)
147            .build()
148            .map_err(|e| AppError::ClientError(e.to_string()))?;
149
150        Ok(Self {
151            client,
152            api_key: api_key.to_string(),
153        })
154    }
155
156    /// Generates text embeddings using Google's gemini-embedding-001 model.
157    ///
158    /// This method converts input text into a 768-dimensional vector representation
159    /// that captures semantic meaning.
160    ///
161    /// # Arguments
162    ///
163    /// * `text` - The input text to generate embeddings for
164    ///
165    /// # Returns
166    ///
167    /// A vector of 768 floating-point values representing the text embedding.
168    ///
169    /// # Errors
170    ///
171    /// Returns `AppError::ClientError` if the HTTP request fails.
172    /// Returns `AppError::Generic` if the API returns an error.
173    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
174        // Sanitize text - replace newlines with spaces
175        let sanitized_text = text.replace('\n', " ");
176
177        let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent";
178
179        let request_body = EmbeddingRequest {
180            model: "models/gemini-embedding-001".to_string(),
181            content: Content {
182                parts: vec![Part {
183                    text: sanitized_text,
184                }],
185            },
186            output_dimensionality: Some(768),
187        };
188
189        let response = self
190            .client
191            .post(url)
192            .header("x-goog-api-key", self.api_key.clone())
193            .json(&request_body)
194            .send()
195            .await
196            .map_err(|e| {
197                if e.is_timeout() {
198                    AppError::Timeout(30)
199                } else if e.is_connect() {
200                    AppError::GeminiError(GeminiErrorDetails::new(
201                        GeminiErrorKind::NetworkError,
202                        format!("Connection failed: {}", e),
203                        0, // No HTTP status for connection failures
204                    ))
205                } else {
206                    AppError::ClientError(e.to_string())
207                }
208            })?;
209
210        let status = response.status();
211
212        if !status.is_success() {
213            let status_code = status.as_u16();
214            let error_text = response.text().await.unwrap_or_default();
215
216            // Try to parse as structured Gemini error
217            let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
218            {
219                gemini_error.error.message
220            } else {
221                format!("HTTP {}: {}", status_code, error_text)
222            };
223
224            // Classify the error
225            let kind = classify_gemini_error(status_code, &message);
226
227            // Return structured error
228            return Err(AppError::GeminiError(GeminiErrorDetails::new(
229                kind,
230                message,
231                status_code,
232            )));
233        }
234
235        let embedding_response: EmbeddingResponse = response
236            .json()
237            .await
238            .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
239
240        Ok(embedding_response.embedding.values)
241    }
242}
243
244// =============================================================================
245// Trait Implementation: EmbeddingProvider
246// =============================================================================
247
248impl ceres_core::traits::EmbeddingProvider for GeminiClient {
249    fn name(&self) -> &'static str {
250        "gemini"
251    }
252
253    fn dimension(&self) -> usize {
254        // gemini-embedding-001 with output_dimensionality=768
255        768
256    }
257
258    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
259        self.get_embeddings(text).await
260    }
261
262    // Note: Gemini API supports batch embeddings via batchEmbedContents endpoint.
263    // For now, we use the default sequential implementation.
264    // TODO: Implement native batch API for improved efficiency.
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_client() {
        let client = GeminiClient::new("test-api-key");
        assert!(client.is_ok());
    }

    #[test]
    fn test_text_sanitization() {
        // Mirrors the newline flattening performed in `get_embeddings`.
        let text_with_newlines = "Line 1\nLine 2\nLine 3";
        let sanitized = text_with_newlines.replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
            output_dimensionality: Some(768),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("gemini-embedding-001"));
        assert!(json.contains("Hello world"));
        assert!(json.contains("output_dimensionality"));
    }

    #[test]
    fn test_request_serialization_omits_unset_dimensionality() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
            output_dimensionality: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("output_dimensionality"));
    }

    #[test]
    fn test_response_deserialization() {
        let body = r#"{"embedding":{"values":[0.5,-0.25,1.0]}}"#;
        let parsed: EmbeddingResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.embedding.values, vec![0.5, -0.25, 1.0]);
    }

    #[test]
    fn test_error_deserialization() {
        let body = r#"{"error":{"code":400,"message":"API key not valid","status":"INVALID_ARGUMENT"}}"#;
        let parsed: GeminiError = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.error.message, "API key not valid");
        assert_eq!(parsed.error.status.as_deref(), Some("INVALID_ARGUMENT"));

        // `status` is optional in the error body.
        let minimal: GeminiError = serde_json::from_str(r#"{"error":{"message":"boom"}}"#).unwrap();
        assert_eq!(minimal.error.message, "boom");
        assert!(minimal.error.status.is_none());
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        let kind = classify_gemini_error(401, "Invalid API key");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        let kind = classify_gemini_error(400, "API key not valid");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        let kind = classify_gemini_error(429, "Rate limit exceeded");
        assert_eq!(kind, GeminiErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        let kind = classify_gemini_error(429, "insufficient_quota");
        assert_eq!(kind, GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_server() {
        let kind = classify_gemini_error(500, "Internal server error");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        let kind = classify_gemini_error(503, "Service unavailable");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_upper_bound() {
        let kind = classify_gemini_error(599, "Network connect timeout");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        let kind = classify_gemini_error(400, "Bad request");
        assert_eq!(kind, GeminiErrorKind::Unknown);
    }
}