memvid_cli/
gemini_embeddings.rs

//! Gemini (Google AI) Embeddings Provider
//!
//! This module provides an `EmbeddingProvider` implementation that uses
//! Google's Gemini API for generating high-quality embeddings.
//!
//! ## Environment Variables
//! - `GOOGLE_API_KEY` or `GEMINI_API_KEY`: Required API key for Google AI
//!   (`GOOGLE_API_KEY` is checked first)
//! - `GEMINI_EMBEDDING_MODEL`: Optional model override (default: text-embedding-004)
//!
//! ## Features
//! - Supports all Gemini embedding models
//! - Efficient batch processing
//! - Thread-safe for concurrent use
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Base URL of the Gemini models REST endpoint. The model name and the method
/// (`embedContent` / `batchEmbedContents`) are appended to it per request.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Embedding model used when the caller does not name one.
const DEFAULT_MODEL: &str = "text-embedding-004";

/// Maximum number of texts sent in one batch request; `embed_batch` splits
/// larger inputs into chunks of this size.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout handed to the HTTP client factory when a provider is created.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum length of a text sent for embedding, used to avoid exceeding token limits.
///
/// The limit is compared against `str::len()`, so it is measured in UTF-8
/// bytes, not in characters.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
35
36/// Truncate text to MAX_EMBEDDING_TEXT_LEN to avoid token limit errors.
37fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
38    if text.len() <= MAX_EMBEDDING_TEXT_LEN {
39        std::borrow::Cow::Borrowed(text)
40    } else {
41        let end = text[..MAX_EMBEDDING_TEXT_LEN]
42            .char_indices()
43            .rev()
44            .next()
45            .map(|(i, c)| i + c.len_utf8())
46            .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
47        warn!(
48            "Truncating embedding text from {} to {} chars to avoid token limit",
49            text.len(),
50            end
51        );
52        std::borrow::Cow::Owned(text[..end].to_string())
53    }
54}
55
/// JSON body of a single-text `embedContent` request.
#[derive(Debug, Serialize)]
struct GeminiEmbedRequest {
    /// The text to embed, wrapped in the content/parts structure.
    content: GeminiContent,
    /// Optional task type; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    task_type: Option<String>,
}
63
/// JSON body of a `batchEmbedContents` request.
#[derive(Debug, Serialize)]
struct GeminiBatchEmbedRequest {
    /// One entry per text to embed.
    requests: Vec<GeminiEmbedRequestItem>,
}
69
/// A single entry of a batch embed request.
#[derive(Debug, Serialize)]
struct GeminiEmbedRequestItem {
    /// Model resource name; the batch builder fills it with `models/<model>`.
    model: String,
    /// The text to embed, wrapped in the content/parts structure.
    content: GeminiContent,
    /// Optional task type; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    task_type: Option<String>,
}
77
/// Content wrapper of an embed request: a list of parts.
/// This module always builds it with exactly one text part.
#[derive(Debug, Serialize)]
struct GeminiContent {
    /// The parts that make up the content.
    parts: Vec<GeminiPart>,
}
82
/// A text part inside a `GeminiContent`.
#[derive(Debug, Serialize)]
struct GeminiPart {
    /// The (possibly truncated) text to embed.
    text: String,
}
87
/// Successful response body of a single-text embed request.
#[derive(Debug, Deserialize)]
struct GeminiEmbedResponse {
    /// The embedding computed for the submitted text.
    embedding: GeminiEmbedding,
}
93
/// One embedding vector as it appears in a response body.
#[derive(Debug, Deserialize)]
struct GeminiEmbedding {
    /// The vector components.
    values: Vec<f32>,
}
98
/// Successful response body of a batch embed request.
#[derive(Debug, Deserialize)]
struct GeminiBatchEmbedResponse {
    /// The embeddings returned for the batch.
    // NOTE(review): presumably one per request item, in request order — confirm against the API docs.
    embeddings: Vec<GeminiEmbedding>,
}
104
/// Error envelope parsed from the body of a non-success response.
/// When the body does not match this shape, the raw body is reported instead.
#[derive(Debug, Deserialize)]
struct GeminiErrorResponse {
    /// The error details.
    error: GeminiError,
}
110
/// Error details carried inside a `GeminiErrorResponse`.
#[derive(Debug, Deserialize)]
struct GeminiError {
    /// Error message; included in the error returned to the caller.
    message: String,
    /// Numeric error code reported in the body; included in the error returned to the caller.
    code: i32,
}
116
/// Embedding provider backed by the Gemini embeddings REST API.
///
/// All requests are made through a blocking HTTP client.
#[derive(Clone)]
pub struct GeminiEmbeddingProvider {
    /// API key sent with every request; not printed by the manual `Debug` impl.
    api_key: String,
    /// Name of the embedding model, e.g. `text-embedding-004`.
    model: String,
    /// Blocking HTTP client used for all requests.
    client: Client,
    /// Embedding dimension assumed for `model`; derived from the model name
    /// at construction time, not from API responses.
    dimension: usize,
}
125
126impl std::fmt::Debug for GeminiEmbeddingProvider {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("GeminiEmbeddingProvider")
129            .field("model", &self.model)
130            .field("dimension", &self.dimension)
131            .finish()
132    }
133}
134
135impl GeminiEmbeddingProvider {
136    /// Create a new Gemini embedding provider
137    pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
138        if api_key.is_empty() {
139            bail!("Gemini API key cannot be empty");
140        }
141
142        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
143            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
144
145        let model = model.unwrap_or(DEFAULT_MODEL).to_string();
146
147        // Dimension depends on model:
148        // text-embedding-004: 768 (default, can be reduced)
149        // gemini-embedding-001: 3072 (can be truncated)
150        let dimension = if model.contains("gemini-embedding") {
151            3072
152        } else {
153            768 // text-embedding-004 default
154        };
155
156        Ok(Self {
157            api_key,
158            model,
159            client,
160            dimension,
161        })
162    }
163
164    /// Create provider from environment variables
165    pub fn from_env() -> Result<Self> {
166        let api_key = std::env::var("GOOGLE_API_KEY")
167            .or_else(|_| std::env::var("GEMINI_API_KEY"))
168            .map_err(|_| {
169                anyhow!("GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set")
170            })?;
171
172        let model = std::env::var("GEMINI_EMBEDDING_MODEL").ok();
173        Self::new(api_key, model.as_deref())
174    }
175
176    /// Get model name
177    pub fn model(&self) -> &str {
178        &self.model
179    }
180
181    /// Get provider kind
182    pub fn kind(&self) -> &'static str {
183        "gemini"
184    }
185
186    /// Get embedding dimension
187    pub fn dimension(&self) -> usize {
188        self.dimension
189    }
190
191    /// Embed a single text
192    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
193        let text = truncate_for_embedding(text);
194        self.embed_with_retry(&text, 3)
195    }
196
197    /// Embed multiple texts in batch
198    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
199        if texts.is_empty() {
200            return Ok(Vec::new());
201        }
202
203        // Truncate all texts first
204        let truncated: Vec<std::borrow::Cow<'_, str>> =
205            texts.iter().map(|t| truncate_for_embedding(t)).collect();
206
207        let mut all_embeddings = Vec::with_capacity(texts.len());
208
209        // Process in batches
210        for chunk in truncated.chunks(MAX_BATCH_SIZE) {
211            let embeddings = self.embed_batch_with_retry(chunk, 3)?;
212            all_embeddings.extend(embeddings);
213        }
214
215        Ok(all_embeddings)
216    }
217
218    /// Embed single text with retry logic
219    fn embed_with_retry(&self, text: &str, max_retries: usize) -> Result<Vec<f32>> {
220        let url = format!(
221            "{}/{}:embedContent?key={}",
222            GEMINI_API_BASE, self.model, self.api_key
223        );
224
225        let request = GeminiEmbedRequest {
226            content: GeminiContent {
227                parts: vec![GeminiPart {
228                    text: text.to_string(),
229                }],
230            },
231            task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
232        };
233
234        let mut last_error = None;
235
236        for attempt in 0..max_retries {
237            let response = self
238                .client
239                .post(&url)
240                .header("Content-Type", "application/json")
241                .json(&request)
242                .send();
243
244            match response {
245                Ok(resp) => {
246                    let status = resp.status();
247                    let body = resp.text().unwrap_or_default();
248
249                    if status.is_success() {
250                        let embed_response: GeminiEmbedResponse = serde_json::from_str(&body)
251                            .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
252
253                        debug!(
254                            "Gemini embedding: {} values, model={}",
255                            embed_response.embedding.values.len(),
256                            self.model
257                        );
258
259                        return Ok(embed_response.embedding.values);
260                    }
261
262                    // Check for rate limiting
263                    if status.as_u16() == 429 {
264                        let backoff = Duration::from_millis(500 * (1 << attempt));
265                        warn!(
266                            "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
267                            backoff,
268                            attempt + 1,
269                            max_retries
270                        );
271                        std::thread::sleep(backoff);
272                        last_error = Some(anyhow!("Rate limited"));
273                        continue;
274                    }
275
276                    // Try to parse error response
277                    if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
278                        return Err(anyhow!(
279                            "Gemini API error ({}): {}",
280                            error_response.error.code,
281                            error_response.error.message
282                        ));
283                    }
284
285                    return Err(anyhow!(
286                        "Gemini API request failed with status {}: {}",
287                        status,
288                        body
289                    ));
290                }
291                Err(e) => {
292                    if attempt < max_retries - 1 {
293                        let backoff = Duration::from_millis(500 * (1 << attempt));
294                        warn!(
295                            "Gemini request failed, retrying in {:?} (attempt {}/{}): {}",
296                            backoff,
297                            attempt + 1,
298                            max_retries,
299                            e
300                        );
301                        std::thread::sleep(backoff);
302                        last_error = Some(anyhow!("Request failed: {}", e));
303                        continue;
304                    }
305                    return Err(anyhow!("Gemini API request failed: {}", e));
306                }
307            }
308        }
309
310        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
311    }
312
313    /// Embed batch with retry logic
314    fn embed_batch_with_retry(
315        &self,
316        texts: &[std::borrow::Cow<'_, str>],
317        max_retries: usize,
318    ) -> Result<Vec<Vec<f32>>> {
319        let url = format!(
320            "{}/{}:batchEmbedContents?key={}",
321            GEMINI_API_BASE, self.model, self.api_key
322        );
323
324        let requests: Vec<GeminiEmbedRequestItem> = texts
325            .iter()
326            .map(|text| GeminiEmbedRequestItem {
327                model: format!("models/{}", self.model),
328                content: GeminiContent {
329                    parts: vec![GeminiPart {
330                        text: text.to_string(),
331                    }],
332                },
333                task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
334            })
335            .collect();
336
337        let batch_request = GeminiBatchEmbedRequest { requests };
338
339        let mut last_error = None;
340
341        for attempt in 0..max_retries {
342            let response = self
343                .client
344                .post(&url)
345                .header("Content-Type", "application/json")
346                .json(&batch_request)
347                .send();
348
349            match response {
350                Ok(resp) => {
351                    let status = resp.status();
352                    let body = resp.text().unwrap_or_default();
353
354                    if status.is_success() {
355                        let batch_response: GeminiBatchEmbedResponse = serde_json::from_str(&body)
356                            .map_err(|e| anyhow!("Failed to parse Gemini batch response: {}", e))?;
357
358                        debug!(
359                            "Gemini batch embeddings: {} texts, model={}",
360                            batch_response.embeddings.len(),
361                            self.model
362                        );
363
364                        return Ok(batch_response
365                            .embeddings
366                            .into_iter()
367                            .map(|e| e.values)
368                            .collect());
369                    }
370
371                    // Check for rate limiting
372                    if status.as_u16() == 429 {
373                        let backoff = Duration::from_millis(500 * (1 << attempt));
374                        warn!(
375                            "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
376                            backoff,
377                            attempt + 1,
378                            max_retries
379                        );
380                        std::thread::sleep(backoff);
381                        last_error = Some(anyhow!("Rate limited"));
382                        continue;
383                    }
384
385                    // Try to parse error response
386                    if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
387                        return Err(anyhow!(
388                            "Gemini API error ({}): {}",
389                            error_response.error.code,
390                            error_response.error.message
391                        ));
392                    }
393
394                    return Err(anyhow!(
395                        "Gemini API request failed with status {}: {}",
396                        status,
397                        body
398                    ));
399                }
400                Err(e) => {
401                    if attempt < max_retries - 1 {
402                        let backoff = Duration::from_millis(500 * (1 << attempt));
403                        warn!(
404                            "Gemini batch request failed, retrying in {:?} (attempt {}/{}): {}",
405                            backoff,
406                            attempt + 1,
407                            max_retries,
408                            e
409                        );
410                        std::thread::sleep(backoff);
411                        last_error = Some(anyhow!("Request failed: {}", e));
412                        continue;
413                    }
414                    return Err(anyhow!("Gemini API batch request failed: {}", e));
415                }
416            }
417        }
418
419        Err(last_error
420            .unwrap_or_else(|| anyhow!("Failed to embed batch after {} retries", max_retries)))
421    }
422}
423
424/// Helper to create a Gemini provider or return error
425pub fn try_gemini_provider() -> Option<GeminiEmbeddingProvider> {
426    match GeminiEmbeddingProvider::from_env() {
427        Ok(provider) => {
428            info!("Gemini embedding provider available");
429            Some(provider)
430        }
431        Err(e) => {
432            debug!("Gemini provider not available: {}", e);
433            None
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty API key must be rejected by the constructor.
    #[test]
    fn test_empty_api_key() {
        assert!(GeminiEmbeddingProvider::new(String::new(), None).is_err());
    }

    /// The dimension follows the model name: 768 by default, 3072 for
    /// `gemini-embedding-*` models.
    #[test]
    fn test_model_dimensions() {
        let cases: [(Option<&str>, usize); 2] =
            [(None, 768), (Some("gemini-embedding-001"), 3072)];

        for &(model, expected) in cases.iter() {
            let provider = GeminiEmbeddingProvider::new("test-key".to_string(), model).unwrap();
            assert_eq!(provider.dimension(), expected);
        }
    }

    /// Round trip against the live API; ignored by default.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_embedding() {
        let embedding = GeminiEmbeddingProvider::from_env()
            .expect("GOOGLE_API_KEY must be set")
            .embed_text("Hello, world!")
            .expect("embed");
        assert!(!embedding.is_empty());
    }
}