// memvid_cli/gemini_embeddings.rs

1//! Gemini (Google AI) Embeddings Provider
2//!
3//! This module provides an `EmbeddingProvider` implementation that uses
4//! Google's Gemini API for generating high-quality embeddings.
5//!
6//! ## Environment Variables
7//! - `GOOGLE_API_KEY` or `GEMINI_API_KEY`: Required API key for Google AI
8//! - `GEMINI_EMBEDDING_MODEL`: Optional model override (default: text-embedding-004)
9//!
10//! ## Features
11//! - Supports all Gemini embedding models
12//! - Efficient batch processing
13//! - Thread-safe for concurrent use
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Gemini embeddings API base URL (v1beta generative language endpoint)
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Default embedding model used when no override is supplied
const DEFAULT_MODEL: &str = "text-embedding-004";

/// Maximum texts per batch request; larger inputs are split into chunks
/// of this size. NOTE(review): presumably matches the API's documented
/// batchEmbedContents limit — confirm against current Gemini docs.
const MAX_BATCH_SIZE: usize = 100;

/// Request timeout applied to every HTTP call to the Gemini API
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum length in BYTES (`str::len`) for embedding text, to avoid
/// exceeding the model's token limits. Longer inputs are truncated.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
35
36/// Truncate text to MAX_EMBEDDING_TEXT_LEN to avoid token limit errors.
37fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
38    if text.len() <= MAX_EMBEDDING_TEXT_LEN {
39        std::borrow::Cow::Borrowed(text)
40    } else {
41        let end = text[..MAX_EMBEDDING_TEXT_LEN]
42            .char_indices()
43            .rev()
44            .next()
45            .map(|(i, c)| i + c.len_utf8())
46            .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
47        warn!(
48            "Truncating embedding text from {} to {} chars to avoid token limit",
49            text.len(),
50            end
51        );
52        std::borrow::Cow::Owned(text[..end].to_string())
53    }
54}
55
56/// Gemini embed content request
57#[derive(Debug, Serialize)]
58struct GeminiEmbedRequest {
59    content: GeminiContent,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    task_type: Option<String>,
62}
63
64/// Gemini batch embed request
65#[derive(Debug, Serialize)]
66struct GeminiBatchEmbedRequest {
67    requests: Vec<GeminiEmbedRequestItem>,
68}
69
70#[derive(Debug, Serialize)]
71struct GeminiEmbedRequestItem {
72    model: String,
73    content: GeminiContent,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    task_type: Option<String>,
76}
77
78#[derive(Debug, Serialize)]
79struct GeminiContent {
80    parts: Vec<GeminiPart>,
81}
82
83#[derive(Debug, Serialize)]
84struct GeminiPart {
85    text: String,
86}
87
/// Gemini embed response for `:embedContent`.
#[derive(Debug, Deserialize)]
struct GeminiEmbedResponse {
    // The single embedding produced for the request's content.
    embedding: GeminiEmbedding,
}

/// An embedding vector as returned by the API.
#[derive(Debug, Deserialize)]
struct GeminiEmbedding {
    // Raw float values; length depends on the model (see provider `dimension`).
    values: Vec<f32>,
}

/// Gemini batch embed response for `:batchEmbedContents`.
#[derive(Debug, Deserialize)]
struct GeminiBatchEmbedResponse {
    // One embedding per request item, in request order.
    embeddings: Vec<GeminiEmbedding>,
}

/// Gemini error response envelope (non-2xx replies).
#[derive(Debug, Deserialize)]
struct GeminiErrorResponse {
    error: GeminiError,
}

/// Error payload: numeric status code plus human-readable message.
#[derive(Debug, Deserialize)]
struct GeminiError {
    message: String,
    code: i32,
}
116
/// Gemini Embedding Provider
///
/// Holds the API key, target model, a blocking HTTP client, and the
/// model's output dimension. `Clone` is cheap enough for handing copies
/// to worker threads (reqwest's blocking `Client` is internally ref-counted).
#[derive(Clone)]
pub struct GeminiEmbeddingProvider {
    // Secret: intentionally excluded from the manual Debug impl below.
    api_key: String,
    // Model identifier, e.g. "text-embedding-004".
    model: String,
    client: Client,
    // Embedding vector length, derived from the model name in `new`.
    dimension: usize,
}
125
// Manual Debug impl: deliberately omits `api_key` so the secret can never
// leak into logs or error messages via `{:?}` formatting.
impl std::fmt::Debug for GeminiEmbeddingProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiEmbeddingProvider")
            .field("model", &self.model)
            .field("dimension", &self.dimension)
            .finish()
    }
}
134
135impl GeminiEmbeddingProvider {
136    /// Create a new Gemini embedding provider
137    pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
138        if api_key.is_empty() {
139            bail!("Gemini API key cannot be empty");
140        }
141
142        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
143            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
144
145        let model = model.unwrap_or(DEFAULT_MODEL).to_string();
146
147        // Dimension depends on model:
148        // text-embedding-004: 768 (default, can be reduced)
149        // gemini-embedding-001: 3072 (can be truncated)
150        let dimension = if model.contains("gemini-embedding") {
151            3072
152        } else {
153            768  // text-embedding-004 default
154        };
155
156        Ok(Self {
157            api_key,
158            model,
159            client,
160            dimension,
161        })
162    }
163
164    /// Create provider from environment variables
165    pub fn from_env() -> Result<Self> {
166        let api_key = std::env::var("GOOGLE_API_KEY")
167            .or_else(|_| std::env::var("GEMINI_API_KEY"))
168            .map_err(|_| anyhow!("GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set"))?;
169
170        let model = std::env::var("GEMINI_EMBEDDING_MODEL").ok();
171        Self::new(api_key, model.as_deref())
172    }
173
    /// Get model name (e.g. "text-embedding-004")
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Get provider kind — constant tag identifying this backend
    pub fn kind(&self) -> &'static str {
        "gemini"
    }

    /// Get embedding dimension (vector length produced by `model`)
    pub fn dimension(&self) -> usize {
        self.dimension
    }
188
189    /// Embed a single text
190    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
191        let text = truncate_for_embedding(text);
192        self.embed_with_retry(&text, 3)
193    }
194
195    /// Embed multiple texts in batch
196    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
197        if texts.is_empty() {
198            return Ok(Vec::new());
199        }
200
201        // Truncate all texts first
202        let truncated: Vec<std::borrow::Cow<'_, str>> =
203            texts.iter().map(|t| truncate_for_embedding(t)).collect();
204
205        let mut all_embeddings = Vec::with_capacity(texts.len());
206
207        // Process in batches
208        for chunk in truncated.chunks(MAX_BATCH_SIZE) {
209            let embeddings = self.embed_batch_with_retry(chunk, 3)?;
210            all_embeddings.extend(embeddings);
211        }
212
213        Ok(all_embeddings)
214    }
215
216    /// Embed single text with retry logic
217    fn embed_with_retry(&self, text: &str, max_retries: usize) -> Result<Vec<f32>> {
218        let url = format!(
219            "{}/{}:embedContent?key={}",
220            GEMINI_API_BASE, self.model, self.api_key
221        );
222
223        let request = GeminiEmbedRequest {
224            content: GeminiContent {
225                parts: vec![GeminiPart {
226                    text: text.to_string(),
227                }],
228            },
229            task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
230        };
231
232        let mut last_error = None;
233
234        for attempt in 0..max_retries {
235            let response = self
236                .client
237                .post(&url)
238                .header("Content-Type", "application/json")
239                .json(&request)
240                .send();
241
242            match response {
243                Ok(resp) => {
244                    let status = resp.status();
245                    let body = resp.text().unwrap_or_default();
246
247                    if status.is_success() {
248                        let embed_response: GeminiEmbedResponse = serde_json::from_str(&body)
249                            .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
250
251                        debug!(
252                            "Gemini embedding: {} values, model={}",
253                            embed_response.embedding.values.len(),
254                            self.model
255                        );
256
257                        return Ok(embed_response.embedding.values);
258                    }
259
260                    // Check for rate limiting
261                    if status.as_u16() == 429 {
262                        let backoff = Duration::from_millis(500 * (1 << attempt));
263                        warn!(
264                            "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
265                            backoff,
266                            attempt + 1,
267                            max_retries
268                        );
269                        std::thread::sleep(backoff);
270                        last_error = Some(anyhow!("Rate limited"));
271                        continue;
272                    }
273
274                    // Try to parse error response
275                    if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
276                        return Err(anyhow!(
277                            "Gemini API error ({}): {}",
278                            error_response.error.code,
279                            error_response.error.message
280                        ));
281                    }
282
283                    return Err(anyhow!(
284                        "Gemini API request failed with status {}: {}",
285                        status,
286                        body
287                    ));
288                }
289                Err(e) => {
290                    if attempt < max_retries - 1 {
291                        let backoff = Duration::from_millis(500 * (1 << attempt));
292                        warn!(
293                            "Gemini request failed, retrying in {:?} (attempt {}/{}): {}",
294                            backoff,
295                            attempt + 1,
296                            max_retries,
297                            e
298                        );
299                        std::thread::sleep(backoff);
300                        last_error = Some(anyhow!("Request failed: {}", e));
301                        continue;
302                    }
303                    return Err(anyhow!("Gemini API request failed: {}", e));
304                }
305            }
306        }
307
308        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
309    }
310
311    /// Embed batch with retry logic
312    fn embed_batch_with_retry(
313        &self,
314        texts: &[std::borrow::Cow<'_, str>],
315        max_retries: usize,
316    ) -> Result<Vec<Vec<f32>>> {
317        let url = format!(
318            "{}/{}:batchEmbedContents?key={}",
319            GEMINI_API_BASE, self.model, self.api_key
320        );
321
322        let requests: Vec<GeminiEmbedRequestItem> = texts
323            .iter()
324            .map(|text| GeminiEmbedRequestItem {
325                model: format!("models/{}", self.model),
326                content: GeminiContent {
327                    parts: vec![GeminiPart {
328                        text: text.to_string(),
329                    }],
330                },
331                task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
332            })
333            .collect();
334
335        let batch_request = GeminiBatchEmbedRequest { requests };
336
337        let mut last_error = None;
338
339        for attempt in 0..max_retries {
340            let response = self
341                .client
342                .post(&url)
343                .header("Content-Type", "application/json")
344                .json(&batch_request)
345                .send();
346
347            match response {
348                Ok(resp) => {
349                    let status = resp.status();
350                    let body = resp.text().unwrap_or_default();
351
352                    if status.is_success() {
353                        let batch_response: GeminiBatchEmbedResponse = serde_json::from_str(&body)
354                            .map_err(|e| anyhow!("Failed to parse Gemini batch response: {}", e))?;
355
356                        debug!(
357                            "Gemini batch embeddings: {} texts, model={}",
358                            batch_response.embeddings.len(),
359                            self.model
360                        );
361
362                        return Ok(batch_response
363                            .embeddings
364                            .into_iter()
365                            .map(|e| e.values)
366                            .collect());
367                    }
368
369                    // Check for rate limiting
370                    if status.as_u16() == 429 {
371                        let backoff = Duration::from_millis(500 * (1 << attempt));
372                        warn!(
373                            "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
374                            backoff,
375                            attempt + 1,
376                            max_retries
377                        );
378                        std::thread::sleep(backoff);
379                        last_error = Some(anyhow!("Rate limited"));
380                        continue;
381                    }
382
383                    // Try to parse error response
384                    if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
385                        return Err(anyhow!(
386                            "Gemini API error ({}): {}",
387                            error_response.error.code,
388                            error_response.error.message
389                        ));
390                    }
391
392                    return Err(anyhow!(
393                        "Gemini API request failed with status {}: {}",
394                        status,
395                        body
396                    ));
397                }
398                Err(e) => {
399                    if attempt < max_retries - 1 {
400                        let backoff = Duration::from_millis(500 * (1 << attempt));
401                        warn!(
402                            "Gemini batch request failed, retrying in {:?} (attempt {}/{}): {}",
403                            backoff,
404                            attempt + 1,
405                            max_retries,
406                            e
407                        );
408                        std::thread::sleep(backoff);
409                        last_error = Some(anyhow!("Request failed: {}", e));
410                        continue;
411                    }
412                    return Err(anyhow!("Gemini API batch request failed: {}", e));
413                }
414            }
415        }
416
417        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed batch after {} retries", max_retries)))
418    }
419}
420
421/// Helper to create a Gemini provider or return error
422pub fn try_gemini_provider() -> Option<GeminiEmbeddingProvider> {
423    match GeminiEmbeddingProvider::from_env() {
424        Ok(provider) => {
425            info!("Gemini embedding provider available");
426            Some(provider)
427        }
428        Err(e) => {
429            debug!("Gemini provider not available: {}", e);
430            None
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    // An empty API key must be rejected at construction time.
    #[test]
    fn test_empty_api_key() {
        let result = GeminiEmbeddingProvider::new(String::new(), None);
        assert!(result.is_err());
    }

    // Dimension is inferred from the model name: 768 for the default
    // text-embedding-004, 3072 for the gemini-embedding family.
    #[test]
    fn test_model_dimensions() {
        let provider = GeminiEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.dimension(), 768);

        let provider = GeminiEmbeddingProvider::new("test-key".to_string(), Some("gemini-embedding-001")).unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    // Live integration test: hits the real API, so it is ignored by default.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_embedding() {
        let provider = GeminiEmbeddingProvider::from_env().expect("GOOGLE_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
    }
}
461}