ceres_client/
gemini.rs

1//! Google Gemini embeddings client.
2//!
3//! # Future Extensions
4//!
5//! TODO: Implement switchable embedding providers (roadmap v0.3+)
6//! Consider creating an `EmbeddingProvider` trait:
7//! ```ignore
8//! #[async_trait]
9//! pub trait EmbeddingProvider: Send + Sync {
10//!     fn dimension(&self) -> usize;
11//!     async fn embed(&self, text: &str) -> Result<Vec<f32>, AppError>;
12//! }
13//! ```
14//!
15//! Potential providers to support:
16//! - OpenAI text-embedding-3-small/large
17//! - Cohere embed-multilingual-v3.0
18//! - E5-multilingual (local, for cross-language search)
19//! - Ollama (local embeddings)
20
21use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
22use ceres_core::HttpConfig;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25
26/// HTTP client for interacting with Google's Gemini Embeddings API.
27///
28/// This client provides methods to generate text embeddings using Google's
29/// text-embedding-004 model. Embeddings are vector representations of text
30/// that can be used for semantic search, clustering, and similarity comparisons.
31///
32/// # Security
33///
34/// The API key is securely transmitted via the `x-goog-api-key` HTTP header,
35/// not in the URL, to prevent accidental exposure in logs and proxies.
36///
37/// # Examples
38///
39/// ```no_run
40/// use ceres_client::GeminiClient;
41///
42/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
43/// let client = GeminiClient::new("your-api-key")?;
44/// let embedding = client.get_embeddings("Hello, world!").await?;
45/// println!("Embedding dimension: {}", embedding.len()); // 768
46/// # Ok(())
47/// # }
48/// ```
49#[derive(Clone)]
50pub struct GeminiClient {
51    client: Client,
52    api_key: String,
53}
54
55/// Request body for Gemini embedding API
56#[derive(Serialize)]
57struct EmbeddingRequest {
58    model: String,
59    content: Content,
60}
61
62#[derive(Serialize)]
63struct Content {
64    parts: Vec<Part>,
65}
66
67#[derive(Serialize)]
68struct Part {
69    text: String,
70}
71
72/// Response from Gemini embedding API
73#[derive(Deserialize)]
74struct EmbeddingResponse {
75    embedding: EmbeddingData,
76}
77
78#[derive(Deserialize)]
79struct EmbeddingData {
80    values: Vec<f32>,
81}
82
83/// Error response from Gemini API
84#[derive(Deserialize)]
85struct GeminiError {
86    error: GeminiErrorDetail,
87}
88
89#[derive(Deserialize)]
90struct GeminiErrorDetail {
91    message: String,
92    #[allow(dead_code)]
93    status: Option<String>,
94}
95
96/// Classify Gemini API error based on status code and message
97fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
98    match status_code {
99        401 => GeminiErrorKind::Authentication,
100        429 => {
101            // Check if it's quota exceeded or rate limit
102            if message.contains("insufficient_quota") || message.contains("quota") {
103                GeminiErrorKind::QuotaExceeded
104            } else {
105                GeminiErrorKind::RateLimit
106            }
107        }
108        500..=599 => GeminiErrorKind::ServerError,
109        _ => {
110            // Check message content for specific error types
111            if message.contains("API key") || message.contains("Unauthorized") {
112                GeminiErrorKind::Authentication
113            } else if message.contains("rate") {
114                GeminiErrorKind::RateLimit
115            } else if message.contains("quota") {
116                GeminiErrorKind::QuotaExceeded
117            } else {
118                GeminiErrorKind::Unknown
119            }
120        }
121    }
122}
123
124impl GeminiClient {
125    /// Creates a new Gemini client with the specified API key.
126    pub fn new(api_key: &str) -> Result<Self, AppError> {
127        let http_config = HttpConfig::default();
128        let client = Client::builder()
129            .timeout(http_config.timeout)
130            .build()
131            .map_err(|e| AppError::ClientError(e.to_string()))?;
132
133        Ok(Self {
134            client,
135            api_key: api_key.to_string(),
136        })
137    }
138
139    /// Generates text embeddings using Google's text-embedding-004 model.
140    ///
141    /// This method converts input text into a 768-dimensional vector representation
142    /// that captures semantic meaning.
143    ///
144    /// # Arguments
145    ///
146    /// * `text` - The input text to generate embeddings for
147    ///
148    /// # Returns
149    ///
150    /// A vector of 768 floating-point values representing the text embedding.
151    ///
152    /// # Errors
153    ///
154    /// Returns `AppError::ClientError` if the HTTP request fails.
155    /// Returns `AppError::Generic` if the API returns an error.
156    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
157        // Sanitize text - replace newlines with spaces
158        let sanitized_text = text.replace('\n', " ");
159
160        // TODO(config): Make API endpoint configurable via GEMINI_API_ENDPOINT env var
161        // Useful for: (1) Proxy servers, (2) Self-hosted alternatives, (3) Testing
162        let url = "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent";
163
164        // TODO(config): Make embedding model configurable via GEMINI_EMBEDDING_MODEL env var
165        // Different models offer different cost/quality tradeoffs:
166        // - text-embedding-004 (current): 768 dimensions
167        // - Future models may have different dimensions - handle dynamically
168        let request_body = EmbeddingRequest {
169            model: "models/text-embedding-004".to_string(),
170            content: Content {
171                parts: vec![Part {
172                    text: sanitized_text,
173                }],
174            },
175        };
176
177        let response = self
178            .client
179            .post(url)
180            .header("x-goog-api-key", self.api_key.clone())
181            .json(&request_body)
182            .send()
183            .await
184            .map_err(|e| {
185                if e.is_timeout() {
186                    AppError::Timeout(30)
187                } else if e.is_connect() {
188                    AppError::GeminiError(GeminiErrorDetails::new(
189                        GeminiErrorKind::NetworkError,
190                        format!("Connection failed: {}", e),
191                        0, // No HTTP status for connection failures
192                    ))
193                } else {
194                    AppError::ClientError(e.to_string())
195                }
196            })?;
197
198        let status = response.status();
199
200        if !status.is_success() {
201            let status_code = status.as_u16();
202            let error_text = response.text().await.unwrap_or_default();
203
204            // Try to parse as structured Gemini error
205            let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
206            {
207                gemini_error.error.message
208            } else {
209                format!("HTTP {}: {}", status_code, error_text)
210            };
211
212            // Classify the error
213            let kind = classify_gemini_error(status_code, &message);
214
215            // Return structured error
216            return Err(AppError::GeminiError(GeminiErrorDetails::new(
217                kind,
218                message,
219                status_code,
220            )));
221        }
222
223        let embedding_response: EmbeddingResponse = response
224            .json()
225            .await
226            .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
227
228        Ok(embedding_response.embedding.values)
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a client from a plain key string succeeds.
    #[test]
    fn test_new_client() {
        assert!(GeminiClient::new("test-api-key").is_ok());
    }

    /// Newlines are flattened to single spaces.
    #[test]
    fn test_text_sanitization() {
        let sanitized = "Line 1\nLine 2\nLine 3".replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    /// The serialized request carries both the model name and the input text.
    #[test]
    fn test_request_serialization() {
        let part = Part {
            text: "Hello world".to_string(),
        };
        let request = EmbeddingRequest {
            model: "models/text-embedding-004".to_string(),
            content: Content { parts: vec![part] },
        };

        let json = serde_json::to_string(&request).unwrap();
        for needle in ["text-embedding-004", "Hello world"] {
            assert!(json.contains(needle));
        }
    }

    /// HTTP 401 is always an authentication failure.
    #[test]
    fn test_classify_gemini_error_auth() {
        assert_eq!(
            classify_gemini_error(401, "Invalid API key"),
            GeminiErrorKind::Authentication
        );
    }

    /// A non-401 status is still an authentication failure when the message says so.
    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        assert_eq!(
            classify_gemini_error(400, "API key not valid"),
            GeminiErrorKind::Authentication
        );
    }

    /// HTTP 429 without a quota hint is a rate limit.
    #[test]
    fn test_classify_gemini_error_rate_limit() {
        assert_eq!(
            classify_gemini_error(429, "Rate limit exceeded"),
            GeminiErrorKind::RateLimit
        );
    }

    /// HTTP 429 with a quota hint is an exhausted quota.
    #[test]
    fn test_classify_gemini_error_quota() {
        assert_eq!(
            classify_gemini_error(429, "insufficient_quota"),
            GeminiErrorKind::QuotaExceeded
        );
    }

    /// HTTP 500 is a server error.
    #[test]
    fn test_classify_gemini_error_server() {
        assert_eq!(
            classify_gemini_error(500, "Internal server error"),
            GeminiErrorKind::ServerError
        );
    }

    /// Other 5xx statuses are server errors as well.
    #[test]
    fn test_classify_gemini_error_server_503() {
        assert_eq!(
            classify_gemini_error(503, "Service unavailable"),
            GeminiErrorKind::ServerError
        );
    }

    /// Anything unrecognized falls through to `Unknown`.
    #[test]
    fn test_classify_gemini_error_unknown() {
        assert_eq!(
            classify_gemini_error(400, "Bad request"),
            GeminiErrorKind::Unknown
        );
    }
}