Skip to main content

ceres_client/
gemini.rs

1//! Google Gemini embeddings client.
2//!
3//! # Future Extensions
4//!
//! TODO: Implement switchable embedding providers (roadmap v0.3+)
//! `GeminiClient` already implements `ceres_core::traits::EmbeddingProvider`
//! (its `generate` method delegates to `get_embeddings`), so further providers
//! should implement that same trait. A possible extension is a dimension
//! accessor, so callers can check vector sizes before storing them:
//! ```ignore
//! fn dimension(&self) -> usize;
//! ```
14//!
15//! Potential providers to support:
16//! - OpenAI text-embedding-3-small/large
17//! - Cohere embed-multilingual-v3.0
18//! - E5-multilingual (local, for cross-language search)
19//! - Ollama (local embeddings)
20//!
21//! TODO(observability): Add OpenTelemetry instrumentation for cloud deployment
22//! Use `tracing-opentelemetry` crate to export spans to cloud observability platforms
23//! (AWS X-Ray, GCP Cloud Trace, Azure Monitor). Add `#[instrument]` spans on:
24//! - `get_embeddings()` - track API latency, token counts
25//! - `sync_portal()` in harvest.rs - track harvest duration breakdown
//!
//! This enables a "waterfall" visualization showing the time spent in each
//! component (e.g., 80% waiting on the Gemini API, 20% on DB inserts).
28//!
29//! TODO(security): Encrypt API keys for multi-tenant deployment
//! If supporting user-provided Gemini/CKAN API keys, store them encrypted
//! in the database (e.g. with the `age` or `ring` crates) instead of as plaintext in `.env`.
//! Consider an `api_keys` table with an `encrypted_key` column and per-user isolation.
33
34use ceres_core::HttpConfig;
35use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
36use reqwest::Client;
37use serde::{Deserialize, Serialize};
38
39/// HTTP client for interacting with Google's Gemini Embeddings API.
40///
41/// This client provides methods to generate text embeddings using Google's
42/// text-embedding-004 model. Embeddings are vector representations of text
43/// that can be used for semantic search, clustering, and similarity comparisons.
44///
45/// # Security
46///
47/// The API key is securely transmitted via the `x-goog-api-key` HTTP header,
48/// not in the URL, to prevent accidental exposure in logs and proxies.
49///
50/// # Examples
51///
52/// ```no_run
53/// use ceres_client::GeminiClient;
54///
55/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
56/// let client = GeminiClient::new("your-api-key")?;
57/// let embedding = client.get_embeddings("Hello, world!").await?;
58/// println!("Embedding dimension: {}", embedding.len()); // 768
59/// # Ok(())
60/// # }
61/// ```
62#[derive(Clone)]
63pub struct GeminiClient {
64    client: Client,
65    api_key: String,
66}
67
68/// Request body for Gemini embedding API
69#[derive(Serialize)]
70struct EmbeddingRequest {
71    model: String,
72    content: Content,
73}
74
75#[derive(Serialize)]
76struct Content {
77    parts: Vec<Part>,
78}
79
80#[derive(Serialize)]
81struct Part {
82    text: String,
83}
84
85/// Response from Gemini embedding API
86#[derive(Deserialize)]
87struct EmbeddingResponse {
88    embedding: EmbeddingData,
89}
90
91#[derive(Deserialize)]
92struct EmbeddingData {
93    values: Vec<f32>,
94}
95
96/// Error response from Gemini API
97#[derive(Deserialize)]
98struct GeminiError {
99    error: GeminiErrorDetail,
100}
101
102#[derive(Deserialize)]
103struct GeminiErrorDetail {
104    message: String,
105    #[allow(dead_code)]
106    status: Option<String>,
107}
108
109/// Classify Gemini API error based on status code and message
110fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
111    match status_code {
112        401 => GeminiErrorKind::Authentication,
113        429 => {
114            // Check if it's quota exceeded or rate limit
115            if message.contains("insufficient_quota") || message.contains("quota") {
116                GeminiErrorKind::QuotaExceeded
117            } else {
118                GeminiErrorKind::RateLimit
119            }
120        }
121        500..=599 => GeminiErrorKind::ServerError,
122        _ => {
123            // Check message content for specific error types
124            if message.contains("API key") || message.contains("Unauthorized") {
125                GeminiErrorKind::Authentication
126            } else if message.contains("rate") {
127                GeminiErrorKind::RateLimit
128            } else if message.contains("quota") {
129                GeminiErrorKind::QuotaExceeded
130            } else {
131                GeminiErrorKind::Unknown
132            }
133        }
134    }
135}
136
137impl GeminiClient {
138    /// Creates a new Gemini client with the specified API key.
139    pub fn new(api_key: &str) -> Result<Self, AppError> {
140        let http_config = HttpConfig::default();
141        let client = Client::builder()
142            .timeout(http_config.timeout)
143            .build()
144            .map_err(|e| AppError::ClientError(e.to_string()))?;
145
146        Ok(Self {
147            client,
148            api_key: api_key.to_string(),
149        })
150    }
151
152    /// Generates text embeddings using Google's text-embedding-004 model.
153    ///
154    /// This method converts input text into a 768-dimensional vector representation
155    /// that captures semantic meaning.
156    ///
157    /// # Arguments
158    ///
159    /// * `text` - The input text to generate embeddings for
160    ///
161    /// # Returns
162    ///
163    /// A vector of 768 floating-point values representing the text embedding.
164    ///
165    /// # Errors
166    ///
167    /// Returns `AppError::ClientError` if the HTTP request fails.
168    /// Returns `AppError::Generic` if the API returns an error.
169    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
170        // Sanitize text - replace newlines with spaces
171        let sanitized_text = text.replace('\n', " ");
172
173        // TODO(config): Make API endpoint configurable via GEMINI_API_ENDPOINT env var
174        // Useful for: (1) Proxy servers, (2) Self-hosted alternatives, (3) Testing
175        let url = "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent";
176
177        // TODO(config): Make embedding model configurable via GEMINI_EMBEDDING_MODEL env var
178        // Different models offer different cost/quality tradeoffs:
179        // - text-embedding-004 (current): 768 dimensions
180        // - Future models may have different dimensions - handle dynamically
181        let request_body = EmbeddingRequest {
182            model: "models/text-embedding-004".to_string(),
183            content: Content {
184                parts: vec![Part {
185                    text: sanitized_text,
186                }],
187            },
188        };
189
190        let response = self
191            .client
192            .post(url)
193            .header("x-goog-api-key", self.api_key.clone())
194            .json(&request_body)
195            .send()
196            .await
197            .map_err(|e| {
198                if e.is_timeout() {
199                    AppError::Timeout(30)
200                } else if e.is_connect() {
201                    AppError::GeminiError(GeminiErrorDetails::new(
202                        GeminiErrorKind::NetworkError,
203                        format!("Connection failed: {}", e),
204                        0, // No HTTP status for connection failures
205                    ))
206                } else {
207                    AppError::ClientError(e.to_string())
208                }
209            })?;
210
211        let status = response.status();
212
213        if !status.is_success() {
214            let status_code = status.as_u16();
215            let error_text = response.text().await.unwrap_or_default();
216
217            // Try to parse as structured Gemini error
218            let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
219            {
220                gemini_error.error.message
221            } else {
222                format!("HTTP {}: {}", status_code, error_text)
223            };
224
225            // Classify the error
226            let kind = classify_gemini_error(status_code, &message);
227
228            // Return structured error
229            return Err(AppError::GeminiError(GeminiErrorDetails::new(
230                kind,
231                message,
232                status_code,
233            )));
234        }
235
236        let embedding_response: EmbeddingResponse = response
237            .json()
238            .await
239            .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
240
241        Ok(embedding_response.embedding.values)
242    }
243}
244
245// =============================================================================
246// Trait Implementation: EmbeddingProvider
247// =============================================================================
248
249impl ceres_core::traits::EmbeddingProvider for GeminiClient {
250    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
251        self.get_embeddings(text).await
252    }
253}
254
#[cfg(test)]
mod tests {
    //! Unit tests for the Gemini client that need no network access: construction,
    //! wire-format (de)serialization, and error classification.

    use super::*;

    #[test]
    fn test_new_client() {
        let client = GeminiClient::new("test-api-key");
        assert!(client.is_ok());
    }

    #[test]
    fn test_text_sanitization() {
        // Documents the newline-to-space transformation expected before text is sent.
        let text_with_newlines = "Line 1\nLine 2\nLine 3";
        let sanitized = text_with_newlines.replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let request = EmbeddingRequest {
            model: "models/text-embedding-004".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
        };

        // Compare the full JSON structure, not just substrings, so that field renames or
        // nesting changes are caught.
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "model": "models/text-embedding-004",
                "content": { "parts": [ { "text": "Hello world" } ] }
            })
        );
    }

    #[test]
    fn test_response_deserialization() {
        let body = r#"{"embedding":{"values":[0.25,-0.5,1.0]}}"#;
        let parsed: EmbeddingResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.embedding.values, vec![0.25, -0.5, 1.0]);
    }

    #[test]
    fn test_error_deserialization() {
        // Extra fields such as `code` must be tolerated.
        let body = r#"{"error":{"code":400,"message":"API key not valid. Please pass a valid API key.","status":"INVALID_ARGUMENT"}}"#;
        let parsed: GeminiError = serde_json::from_str(body).unwrap();
        assert_eq!(
            parsed.error.message,
            "API key not valid. Please pass a valid API key."
        );
        assert_eq!(parsed.error.status.as_deref(), Some("INVALID_ARGUMENT"));
    }

    #[test]
    fn test_error_deserialization_without_status() {
        let body = r#"{"error":{"message":"Something went wrong"}}"#;
        let parsed: GeminiError = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.error.message, "Something went wrong");
        assert!(parsed.error.status.is_none());
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        let kind = classify_gemini_error(401, "Invalid API key");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        let kind = classify_gemini_error(400, "API key not valid");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        let kind = classify_gemini_error(429, "Rate limit exceeded");
        assert_eq!(kind, GeminiErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        let kind = classify_gemini_error(429, "insufficient_quota");
        assert_eq!(kind, GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_server() {
        let kind = classify_gemini_error(500, "Internal server error");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        let kind = classify_gemini_error(503, "Service unavailable");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        let kind = classify_gemini_error(400, "Bad request");
        assert_eq!(kind, GeminiErrorKind::Unknown);
    }
}