agcodex_core/embeddings/providers/
gemini.rs

//! Gemini embeddings provider - completely separate from chat models
//!
//! Supports:
//! - gemini-embedding-001 (768 dimensions)
//! - Batch processing (one request per text, issued sequentially)
//! - API key supplied by the caller, typically read from the GEMINI_API_KEY environment variable
7
8use super::super::EmbeddingError;
9use super::super::EmbeddingProvider;
10use super::super::EmbeddingVector;
11use reqwest::Client;
12use serde::Deserialize;
13use serde::Serialize;
14use tokio::time::Duration;
15use tokio::time::sleep;
16
17/// Gemini embedding provider
18pub struct GeminiProvider {
19    client: Client,
20    api_key: String,
21    model: String,
22    api_endpoint: Option<String>,
23}
24
25impl GeminiProvider {
26    /// Create a new Gemini provider
27    pub fn new(api_key: String, model: String) -> Self {
28        Self {
29            client: Client::new(),
30            api_key,
31            model,
32            api_endpoint: None,
33        }
34    }
35
36    /// Create a new Gemini provider with custom endpoint
37    pub fn new_with_endpoint(api_key: String, model: String, api_endpoint: String) -> Self {
38        Self {
39            client: Client::new(),
40            api_key,
41            model,
42            api_endpoint: Some(api_endpoint),
43        }
44    }
45}
46
47#[derive(Debug, Serialize)]
48struct GeminiRequest {
49    content: GeminiContent,
50}
51
52#[derive(Debug, Serialize)]
53struct GeminiContent {
54    parts: Vec<GeminiPart>,
55}
56
57#[derive(Debug, Serialize)]
58struct GeminiPart {
59    text: String,
60}
61
62#[derive(Debug, Deserialize)]
63struct GeminiResponse {
64    embedding: GeminiEmbedding,
65}
66
67#[derive(Debug, Deserialize)]
68struct GeminiEmbedding {
69    values: Vec<f32>,
70}
71
72#[derive(Debug, Deserialize)]
73struct GeminiError {
74    error: GeminiErrorDetail,
75}
76
77#[derive(Debug, Deserialize)]
78struct GeminiErrorDetail {
79    message: String,
80    _code: Option<u32>,
81    status: Option<String>,
82}
83
84#[async_trait::async_trait]
85impl EmbeddingProvider for GeminiProvider {
86    fn model_id(&self) -> String {
87        format!("gemini:{}", self.model)
88    }
89
90    fn dimensions(&self) -> usize {
91        // Return model-specific dimensions
92        match self.model.as_str() {
93            "gemini-embedding-001" => 768,
94            "text-embedding-004" => 768,
95            "embedding-001" => 768,
96            "textembedding-gecko@001" => 768,
97            "textembedding-gecko@003" => 768,
98            _ => 768, // Default fallback
99        }
100    }
101
102    async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
103        let request = GeminiRequest {
104            content: GeminiContent {
105                parts: vec![GeminiPart {
106                    text: text.to_string(),
107                }],
108            },
109        };
110
111        let endpoint = self
112            .api_endpoint
113            .as_deref()
114            .unwrap_or("https://generativelanguage.googleapis.com");
115        let url = format!(
116            "{}/v1/models/{}:embedContent?key={}",
117            endpoint, self.model, self.api_key
118        );
119
120        let response = self
121            .client
122            .post(&url)
123            .header("Content-Type", "application/json")
124            .json(&request)
125            .send()
126            .await
127            .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;
128
129        let status = response.status();
130        if !status.is_success() {
131            let error_text = response
132                .text()
133                .await
134                .unwrap_or_else(|_| "Unknown error".to_string());
135
136            // Try to parse Gemini error format
137            if let Ok(error) = serde_json::from_str::<GeminiError>(&error_text) {
138                return Err(EmbeddingError::ApiError(format!(
139                    "Gemini API error ({}): {} - {}",
140                    status,
141                    error.error.status.unwrap_or_else(|| "Unknown".to_string()),
142                    error.error.message
143                )));
144            }
145
146            return Err(EmbeddingError::ApiError(format!(
147                "Gemini API error ({}): {}",
148                status, error_text
149            )));
150        }
151
152        let gemini_response: GeminiResponse = response
153            .json()
154            .await
155            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;
156
157        // Validate dimensions
158        let expected_dims = self.dimensions();
159        if gemini_response.embedding.values.len() != expected_dims {
160            return Err(EmbeddingError::DimensionMismatch {
161                expected: expected_dims,
162                actual: gemini_response.embedding.values.len(),
163            });
164        }
165
166        Ok(gemini_response.embedding.values)
167    }
168
169    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
170        if texts.is_empty() {
171            return Ok(vec![]);
172        }
173
174        // Gemini doesn't support batch embedding in a single call,
175        // so we need to make multiple requests
176        let mut embeddings = Vec::with_capacity(texts.len());
177
178        for text in texts {
179            let embedding = self.embed(text).await?;
180            embeddings.push(embedding);
181
182            // Add a small delay to respect rate limits
183            sleep(Duration::from_millis(100)).await;
184        }
185
186        Ok(embeddings)
187    }
188
189    fn is_available(&self) -> bool {
190        !self.api_key.is_empty()
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a provider with a non-empty key and the given model.
    fn provider_for(model: &str) -> GeminiProvider {
        GeminiProvider::new("test-key".to_string(), model.to_string())
    }

    #[test]
    fn test_model_id() {
        assert_eq!(
            provider_for("embedding-001").model_id(),
            "gemini:embedding-001"
        );
    }

    #[test]
    fn test_dimensions() {
        // Current model, older model, and an unrecognized model all report 768.
        for model in ["gemini-embedding-001", "text-embedding-004", "unknown-model"] {
            assert_eq!(provider_for(model).dimensions(), 768);
        }
    }

    #[test]
    fn test_is_available() {
        assert!(provider_for("embedding-001").is_available());

        let without_key = GeminiProvider::new(String::new(), "embedding-001".to_string());
        assert!(!without_key.is_available());
    }
}
227}