1use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Root of the Gemini REST model endpoints; `/<model>:<method>` is appended per call.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Embedding model used when the caller does not name one.
const DEFAULT_MODEL: &str = "text-embedding-004";

/// Largest number of texts sent together in a single batch call.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout handed to the shared blocking HTTP client builder.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(60_000);

/// Inputs longer than this many bytes (UTF-8 length) are truncated before embedding.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
35
36fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
38 if text.len() <= MAX_EMBEDDING_TEXT_LEN {
39 std::borrow::Cow::Borrowed(text)
40 } else {
41 let end = text[..MAX_EMBEDDING_TEXT_LEN]
42 .char_indices()
43 .rev()
44 .next()
45 .map(|(i, c)| i + c.len_utf8())
46 .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
47 warn!(
48 "Truncating embedding text from {} to {} chars to avoid token limit",
49 text.len(),
50 end
51 );
52 std::borrow::Cow::Owned(text[..end].to_string())
53 }
54}
55
56#[derive(Debug, Serialize)]
58struct GeminiEmbedRequest {
59 content: GeminiContent,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 task_type: Option<String>,
62}
63
64#[derive(Debug, Serialize)]
66struct GeminiBatchEmbedRequest {
67 requests: Vec<GeminiEmbedRequestItem>,
68}
69
70#[derive(Debug, Serialize)]
71struct GeminiEmbedRequestItem {
72 model: String,
73 content: GeminiContent,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 task_type: Option<String>,
76}
77
78#[derive(Debug, Serialize)]
79struct GeminiContent {
80 parts: Vec<GeminiPart>,
81}
82
83#[derive(Debug, Serialize)]
84struct GeminiPart {
85 text: String,
86}
87
88#[derive(Debug, Deserialize)]
90struct GeminiEmbedResponse {
91 embedding: GeminiEmbedding,
92}
93
94#[derive(Debug, Deserialize)]
95struct GeminiEmbedding {
96 values: Vec<f32>,
97}
98
99#[derive(Debug, Deserialize)]
101struct GeminiBatchEmbedResponse {
102 embeddings: Vec<GeminiEmbedding>,
103}
104
105#[derive(Debug, Deserialize)]
107struct GeminiErrorResponse {
108 error: GeminiError,
109}
110
111#[derive(Debug, Deserialize)]
112struct GeminiError {
113 message: String,
114 code: i32,
115}
116
117#[derive(Clone)]
119pub struct GeminiEmbeddingProvider {
120 api_key: String,
121 model: String,
122 client: Client,
123 dimension: usize,
124}
125
126impl std::fmt::Debug for GeminiEmbeddingProvider {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("GeminiEmbeddingProvider")
129 .field("model", &self.model)
130 .field("dimension", &self.dimension)
131 .finish()
132 }
133}
134
135impl GeminiEmbeddingProvider {
136 pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
138 if api_key.is_empty() {
139 bail!("Gemini API key cannot be empty");
140 }
141
142 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
143 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
144
145 let model = model.unwrap_or(DEFAULT_MODEL).to_string();
146
147 let dimension = if model.contains("gemini-embedding") {
151 3072
152 } else {
153 768 };
155
156 Ok(Self {
157 api_key,
158 model,
159 client,
160 dimension,
161 })
162 }
163
164 pub fn from_env() -> Result<Self> {
166 let api_key = std::env::var("GOOGLE_API_KEY")
167 .or_else(|_| std::env::var("GEMINI_API_KEY"))
168 .map_err(|_| {
169 anyhow!("GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set")
170 })?;
171
172 let model = std::env::var("GEMINI_EMBEDDING_MODEL").ok();
173 Self::new(api_key, model.as_deref())
174 }
175
176 pub fn model(&self) -> &str {
178 &self.model
179 }
180
181 pub fn kind(&self) -> &'static str {
183 "gemini"
184 }
185
186 pub fn dimension(&self) -> usize {
188 self.dimension
189 }
190
191 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
193 let text = truncate_for_embedding(text);
194 self.embed_with_retry(&text, 3)
195 }
196
197 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
199 if texts.is_empty() {
200 return Ok(Vec::new());
201 }
202
203 let truncated: Vec<std::borrow::Cow<'_, str>> =
205 texts.iter().map(|t| truncate_for_embedding(t)).collect();
206
207 let mut all_embeddings = Vec::with_capacity(texts.len());
208
209 for chunk in truncated.chunks(MAX_BATCH_SIZE) {
211 let embeddings = self.embed_batch_with_retry(chunk, 3)?;
212 all_embeddings.extend(embeddings);
213 }
214
215 Ok(all_embeddings)
216 }
217
218 fn embed_with_retry(&self, text: &str, max_retries: usize) -> Result<Vec<f32>> {
220 let url = format!(
221 "{}/{}:embedContent?key={}",
222 GEMINI_API_BASE, self.model, self.api_key
223 );
224
225 let request = GeminiEmbedRequest {
226 content: GeminiContent {
227 parts: vec![GeminiPart {
228 text: text.to_string(),
229 }],
230 },
231 task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
232 };
233
234 let mut last_error = None;
235
236 for attempt in 0..max_retries {
237 let response = self
238 .client
239 .post(&url)
240 .header("Content-Type", "application/json")
241 .json(&request)
242 .send();
243
244 match response {
245 Ok(resp) => {
246 let status = resp.status();
247 let body = resp.text().unwrap_or_default();
248
249 if status.is_success() {
250 let embed_response: GeminiEmbedResponse = serde_json::from_str(&body)
251 .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
252
253 debug!(
254 "Gemini embedding: {} values, model={}",
255 embed_response.embedding.values.len(),
256 self.model
257 );
258
259 return Ok(embed_response.embedding.values);
260 }
261
262 if status.as_u16() == 429 {
264 let backoff = Duration::from_millis(500 * (1 << attempt));
265 warn!(
266 "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
267 backoff,
268 attempt + 1,
269 max_retries
270 );
271 std::thread::sleep(backoff);
272 last_error = Some(anyhow!("Rate limited"));
273 continue;
274 }
275
276 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
278 return Err(anyhow!(
279 "Gemini API error ({}): {}",
280 error_response.error.code,
281 error_response.error.message
282 ));
283 }
284
285 return Err(anyhow!(
286 "Gemini API request failed with status {}: {}",
287 status,
288 body
289 ));
290 }
291 Err(e) => {
292 if attempt < max_retries - 1 {
293 let backoff = Duration::from_millis(500 * (1 << attempt));
294 warn!(
295 "Gemini request failed, retrying in {:?} (attempt {}/{}): {}",
296 backoff,
297 attempt + 1,
298 max_retries,
299 e
300 );
301 std::thread::sleep(backoff);
302 last_error = Some(anyhow!("Request failed: {}", e));
303 continue;
304 }
305 return Err(anyhow!("Gemini API request failed: {}", e));
306 }
307 }
308 }
309
310 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
311 }
312
313 fn embed_batch_with_retry(
315 &self,
316 texts: &[std::borrow::Cow<'_, str>],
317 max_retries: usize,
318 ) -> Result<Vec<Vec<f32>>> {
319 let url = format!(
320 "{}/{}:batchEmbedContents?key={}",
321 GEMINI_API_BASE, self.model, self.api_key
322 );
323
324 let requests: Vec<GeminiEmbedRequestItem> = texts
325 .iter()
326 .map(|text| GeminiEmbedRequestItem {
327 model: format!("models/{}", self.model),
328 content: GeminiContent {
329 parts: vec![GeminiPart {
330 text: text.to_string(),
331 }],
332 },
333 task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
334 })
335 .collect();
336
337 let batch_request = GeminiBatchEmbedRequest { requests };
338
339 let mut last_error = None;
340
341 for attempt in 0..max_retries {
342 let response = self
343 .client
344 .post(&url)
345 .header("Content-Type", "application/json")
346 .json(&batch_request)
347 .send();
348
349 match response {
350 Ok(resp) => {
351 let status = resp.status();
352 let body = resp.text().unwrap_or_default();
353
354 if status.is_success() {
355 let batch_response: GeminiBatchEmbedResponse = serde_json::from_str(&body)
356 .map_err(|e| anyhow!("Failed to parse Gemini batch response: {}", e))?;
357
358 debug!(
359 "Gemini batch embeddings: {} texts, model={}",
360 batch_response.embeddings.len(),
361 self.model
362 );
363
364 return Ok(batch_response
365 .embeddings
366 .into_iter()
367 .map(|e| e.values)
368 .collect());
369 }
370
371 if status.as_u16() == 429 {
373 let backoff = Duration::from_millis(500 * (1 << attempt));
374 warn!(
375 "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
376 backoff,
377 attempt + 1,
378 max_retries
379 );
380 std::thread::sleep(backoff);
381 last_error = Some(anyhow!("Rate limited"));
382 continue;
383 }
384
385 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
387 return Err(anyhow!(
388 "Gemini API error ({}): {}",
389 error_response.error.code,
390 error_response.error.message
391 ));
392 }
393
394 return Err(anyhow!(
395 "Gemini API request failed with status {}: {}",
396 status,
397 body
398 ));
399 }
400 Err(e) => {
401 if attempt < max_retries - 1 {
402 let backoff = Duration::from_millis(500 * (1 << attempt));
403 warn!(
404 "Gemini batch request failed, retrying in {:?} (attempt {}/{}): {}",
405 backoff,
406 attempt + 1,
407 max_retries,
408 e
409 );
410 std::thread::sleep(backoff);
411 last_error = Some(anyhow!("Request failed: {}", e));
412 continue;
413 }
414 return Err(anyhow!("Gemini API batch request failed: {}", e));
415 }
416 }
417 }
418
419 Err(last_error
420 .unwrap_or_else(|| anyhow!("Failed to embed batch after {} retries", max_retries)))
421 }
422}
423
424pub fn try_gemini_provider() -> Option<GeminiEmbeddingProvider> {
426 match GeminiEmbeddingProvider::from_env() {
427 Ok(provider) => {
428 info!("Gemini embedding provider available");
429 Some(provider)
430 }
431 Err(e) => {
432 debug!("Gemini provider not available: {}", e);
433 None
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with a dummy key for the requested model.
    fn provider_with_model(model: Option<&str>) -> GeminiEmbeddingProvider {
        GeminiEmbeddingProvider::new("test-key".to_string(), model)
            .expect("a non-empty key should be accepted")
    }

    #[test]
    fn test_empty_api_key() {
        let outcome = GeminiEmbeddingProvider::new(String::new(), None);
        assert!(outcome.is_err(), "an empty API key must be rejected");
    }

    #[test]
    fn test_model_dimensions() {
        let expectations: [(Option<&str>, usize); 2] =
            [(None, 768), (Some("gemini-embedding-001"), 3072)];
        for &(model, dimension) in expectations.iter() {
            assert_eq!(provider_with_model(model).dimension(), dimension);
        }
    }

    /// Calls the live Gemini API; run with `cargo test -- --ignored` and a key in the environment.
    #[test]
    #[ignore]
    fn test_real_embedding() {
        let provider = GeminiEmbeddingProvider::from_env().expect("GOOGLE_API_KEY must be set");
        let vector = provider.embed_text("Hello, world!").expect("embed");
        assert!(!vector.is_empty());
    }
}