Skip to main content

codemem_embeddings/
gemini.rs

//! Google Gemini embedding provider for Codemem.
//!
//! Uses the Generative Language API (`generativelanguage.googleapis.com`).
//! Default model: `text-embedding-004` (768 dimensions).
//!
//! ```bash
//! export CODEMEM_EMBED_PROVIDER=gemini
//! export CODEMEM_EMBED_API_KEY=AIza...
//! # Optional:
//! export CODEMEM_EMBED_MODEL=text-embedding-004
//! export CODEMEM_EMBED_DIMENSIONS=768
//! ```

14use codemem_core::CodememError;
15
/// Root of the Generative Language REST API (`v1beta` surface).
///
/// Used by [`GeminiProvider::new`] whenever no `base_url` override is given.
pub const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Embedding model used by default; the module docs list it as 768-dimensional.
pub const DEFAULT_MODEL: &str = "text-embedding-004";

22/// Gemini embedding provider.
23pub struct GeminiProvider {
24    api_key: String,
25    model: String,
26    dimensions: usize,
27    pub(crate) base_url: String,
28    client: reqwest::blocking::Client,
29}
30
31impl GeminiProvider {
32    /// Create a new Gemini provider.
33    pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
34        Self {
35            api_key: api_key.to_string(),
36            model: model.to_string(),
37            dimensions,
38            base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
39            client: reqwest::blocking::Client::new(),
40        }
41    }
42}
43
44impl super::EmbeddingProvider for GeminiProvider {
45    fn dimensions(&self) -> usize {
46        self.dimensions
47    }
48
49    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
50        let url = format!("{}/models/{}:embedContent", self.base_url, self.model);
51
52        let mut body = serde_json::json!({
53            "model": format!("models/{}", self.model),
54            "content": {
55                "parts": [{"text": text}]
56            },
57            "taskType": "RETRIEVAL_DOCUMENT",
58        });
59
60        if self.dimensions > 0 {
61            body["outputDimensionality"] = serde_json::json!(self.dimensions);
62        }
63
64        let response = self
65            .client
66            .post(&url)
67            .header("Content-Type", "application/json")
68            .header("x-goog-api-key", &self.api_key)
69            .json(&body)
70            .send()
71            .map_err(|e| CodememError::Embedding(format!("Gemini request failed: {e}")))?;
72
73        if !response.status().is_success() {
74            let status = response.status();
75            let body = response.text().unwrap_or_default();
76            return Err(CodememError::Embedding(format!(
77                "Gemini returned status {status}: {body}",
78            )));
79        }
80
81        let json: serde_json::Value = response
82            .json()
83            .map_err(|e| CodememError::Embedding(format!("Gemini response parse error: {e}")))?;
84
85        let embedding: Vec<f32> = json
86            .get("embedding")
87            .and_then(|v| v.get("values"))
88            .and_then(|v| v.as_array())
89            .ok_or_else(|| {
90                CodememError::Embedding("Missing embedding.values in Gemini response".into())
91            })?
92            .iter()
93            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
94            .collect();
95
96        if self.dimensions > 0 && embedding.len() != self.dimensions {
97            return Err(CodememError::Embedding(format!(
98                "Gemini returned {} dimensions, expected {}",
99                embedding.len(),
100                self.dimensions
101            )));
102        }
103
104        Ok(embedding)
105    }
106
107    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
108        if texts.is_empty() {
109            return Ok(vec![]);
110        }
111
112        let url = format!("{}/models/{}:batchEmbedContents", self.base_url, self.model);
113
114        // Gemini batchEmbedContents accepts max 100 requests per call.
115        const MAX_BATCH: usize = 100;
116        let mut all_embeddings = Vec::with_capacity(texts.len());
117
118        for chunk in texts.chunks(MAX_BATCH) {
119            let requests: Vec<serde_json::Value> = chunk
120                .iter()
121                .map(|text| {
122                    let mut req = serde_json::json!({
123                        "model": format!("models/{}", self.model),
124                        "content": {
125                            "parts": [{"text": text}]
126                        },
127                        "taskType": "RETRIEVAL_DOCUMENT",
128                    });
129                    if self.dimensions > 0 {
130                        req["outputDimensionality"] = serde_json::json!(self.dimensions);
131                    }
132                    req
133                })
134                .collect();
135
136            let body = serde_json::json!({ "requests": requests });
137
138            let response = self
139                .client
140                .post(&url)
141                .header("Content-Type", "application/json")
142                .header("x-goog-api-key", &self.api_key)
143                .json(&body)
144                .send()
145                .map_err(|e| {
146                    CodememError::Embedding(format!("Gemini batch request failed: {e}"))
147                })?;
148
149            if !response.status().is_success() {
150                let status = response.status();
151                let body = response.text().unwrap_or_default();
152                return Err(CodememError::Embedding(format!(
153                    "Gemini returned status {status}: {body}",
154                )));
155            }
156
157            let json: serde_json::Value = response.json().map_err(|e| {
158                CodememError::Embedding(format!("Gemini response parse error: {e}"))
159            })?;
160
161            let embeddings = json
162                .get("embeddings")
163                .and_then(|v| v.as_array())
164                .ok_or_else(|| {
165                    CodememError::Embedding("Missing 'embeddings' array in Gemini response".into())
166                })?;
167
168            if embeddings.len() != chunk.len() {
169                return Err(CodememError::Embedding(format!(
170                    "Gemini returned {} embeddings, expected {}",
171                    embeddings.len(),
172                    chunk.len()
173                )));
174            }
175
176            for (i, item) in embeddings.iter().enumerate() {
177                let embedding: Vec<f32> = item
178                    .get("values")
179                    .and_then(|v| v.as_array())
180                    .ok_or_else(|| {
181                        CodememError::Embedding(format!(
182                            "Missing values in Gemini embedding at index {i}"
183                        ))
184                    })?
185                    .iter()
186                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
187                    .collect();
188
189                if self.dimensions > 0 && embedding.len() != self.dimensions {
190                    return Err(CodememError::Embedding(format!(
191                        "Gemini returned {} dimensions at index {i}, expected {}",
192                        embedding.len(),
193                        self.dimensions
194                    )));
195                }
196                all_embeddings.push(embedding);
197            }
198        }
199
200        Ok(all_embeddings)
201    }
202
203    fn name(&self) -> &str {
204        "gemini"
205    }
206}
207
// Unit tests for this provider are kept in a separate file.
#[cfg(test)]
#[path = "tests/gemini_tests.rs"]
mod tests;