1use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Base URL of the Gemini REST model collection; endpoints are addressed as
/// `{GEMINI_API_BASE}/{model}:{method}`.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Model used when the caller does not name one (treated as 768-dimensional).
const DEFAULT_MODEL: &str = "text-embedding-004";

/// Largest number of texts sent in one `batchEmbedContents` call; bigger
/// inputs are split into chunks of this size.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout passed to `crate::http::blocking_client` when the HTTP client is built.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Upper bound, in UTF-8 bytes (`str::len`), on text sent for embedding;
/// longer inputs are truncated before any request is made.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
35
36fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
38 if text.len() <= MAX_EMBEDDING_TEXT_LEN {
39 std::borrow::Cow::Borrowed(text)
40 } else {
41 let end = text[..MAX_EMBEDDING_TEXT_LEN]
42 .char_indices()
43 .rev()
44 .next()
45 .map(|(i, c)| i + c.len_utf8())
46 .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
47 warn!(
48 "Truncating embedding text from {} to {} chars to avoid token limit",
49 text.len(),
50 end
51 );
52 std::borrow::Cow::Owned(text[..end].to_string())
53 }
54}
55
/// JSON body for the single-text `{model}:embedContent` endpoint.
#[derive(Debug, Serialize)]
struct GeminiEmbedRequest {
    /// The text to embed, wrapped in Gemini's content/parts envelope.
    content: GeminiContent,
    /// Embedding task hint (this module sends "RETRIEVAL_DOCUMENT");
    /// omitted from the JSON entirely when `None`.
    // NOTE(review): serialized with the snake_case key `task_type`, while the
    // REST docs spell it `taskType` — confirm the endpoint accepts both.
    #[serde(skip_serializing_if = "Option::is_none")]
    task_type: Option<String>,
}
63
/// JSON body for the `{model}:batchEmbedContents` endpoint.
#[derive(Debug, Serialize)]
struct GeminiBatchEmbedRequest {
    /// One entry per text to embed.
    requests: Vec<GeminiEmbedRequestItem>,
}
69
/// One element of a batch embedding request.
#[derive(Debug, Serialize)]
struct GeminiEmbedRequestItem {
    /// Fully-qualified model name; this module sets it to `models/{model}`.
    model: String,
    /// The text to embed.
    content: GeminiContent,
    /// Optional task hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    task_type: Option<String>,
}
77
/// Gemini content envelope: a list of parts. This module always sends exactly
/// one text part.
#[derive(Debug, Serialize)]
struct GeminiContent {
    parts: Vec<GeminiPart>,
}
82
/// A single text part inside a `GeminiContent`.
#[derive(Debug, Serialize)]
struct GeminiPart {
    text: String,
}
87
/// Successful response body of `embedContent`: exactly one embedding.
#[derive(Debug, Deserialize)]
struct GeminiEmbedResponse {
    embedding: GeminiEmbedding,
}
93
/// One embedding vector as returned by the API.
#[derive(Debug, Deserialize)]
struct GeminiEmbedding {
    /// The vector components; any other fields in the JSON object are ignored.
    values: Vec<f32>,
}
98
/// Successful response body of `batchEmbedContents`.
#[derive(Debug, Deserialize)]
struct GeminiBatchEmbedResponse {
    /// Presumably one entry per request item, in request order — the caller
    /// maps these positionally back onto its input texts.
    embeddings: Vec<GeminiEmbedding>,
}
104
/// Error envelope (`{"error": {...}}`) parsed from the body of non-success
/// responses to build a readable error message.
#[derive(Debug, Deserialize)]
struct GeminiErrorResponse {
    error: GeminiError,
}
110
/// Details inside a `GeminiErrorResponse`; any extra fields in the JSON are
/// ignored during deserialization.
#[derive(Debug, Deserialize)]
struct GeminiError {
    /// Human-readable error description from the API.
    message: String,
    /// Numeric error code reported by the API (presumably mirrors the HTTP
    /// status — confirm).
    code: i32,
}
116
117#[derive(Clone)]
119pub struct GeminiEmbeddingProvider {
120 api_key: String,
121 model: String,
122 client: Client,
123 dimension: usize,
124}
125
126impl std::fmt::Debug for GeminiEmbeddingProvider {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("GeminiEmbeddingProvider")
129 .field("model", &self.model)
130 .field("dimension", &self.dimension)
131 .finish()
132 }
133}
134
135impl GeminiEmbeddingProvider {
136 pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
138 if api_key.is_empty() {
139 bail!("Gemini API key cannot be empty");
140 }
141
142 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
143 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
144
145 let model = model.unwrap_or(DEFAULT_MODEL).to_string();
146
147 let dimension = if model.contains("gemini-embedding") {
151 3072
152 } else {
153 768 };
155
156 Ok(Self {
157 api_key,
158 model,
159 client,
160 dimension,
161 })
162 }
163
164 pub fn from_env() -> Result<Self> {
166 let api_key = std::env::var("GOOGLE_API_KEY")
167 .or_else(|_| std::env::var("GEMINI_API_KEY"))
168 .map_err(|_| anyhow!("GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set"))?;
169
170 let model = std::env::var("GEMINI_EMBEDDING_MODEL").ok();
171 Self::new(api_key, model.as_deref())
172 }
173
174 pub fn model(&self) -> &str {
176 &self.model
177 }
178
179 pub fn kind(&self) -> &'static str {
181 "gemini"
182 }
183
184 pub fn dimension(&self) -> usize {
186 self.dimension
187 }
188
189 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
191 let text = truncate_for_embedding(text);
192 self.embed_with_retry(&text, 3)
193 }
194
195 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
197 if texts.is_empty() {
198 return Ok(Vec::new());
199 }
200
201 let truncated: Vec<std::borrow::Cow<'_, str>> =
203 texts.iter().map(|t| truncate_for_embedding(t)).collect();
204
205 let mut all_embeddings = Vec::with_capacity(texts.len());
206
207 for chunk in truncated.chunks(MAX_BATCH_SIZE) {
209 let embeddings = self.embed_batch_with_retry(chunk, 3)?;
210 all_embeddings.extend(embeddings);
211 }
212
213 Ok(all_embeddings)
214 }
215
216 fn embed_with_retry(&self, text: &str, max_retries: usize) -> Result<Vec<f32>> {
218 let url = format!(
219 "{}/{}:embedContent?key={}",
220 GEMINI_API_BASE, self.model, self.api_key
221 );
222
223 let request = GeminiEmbedRequest {
224 content: GeminiContent {
225 parts: vec![GeminiPart {
226 text: text.to_string(),
227 }],
228 },
229 task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
230 };
231
232 let mut last_error = None;
233
234 for attempt in 0..max_retries {
235 let response = self
236 .client
237 .post(&url)
238 .header("Content-Type", "application/json")
239 .json(&request)
240 .send();
241
242 match response {
243 Ok(resp) => {
244 let status = resp.status();
245 let body = resp.text().unwrap_or_default();
246
247 if status.is_success() {
248 let embed_response: GeminiEmbedResponse = serde_json::from_str(&body)
249 .map_err(|e| anyhow!("Failed to parse Gemini response: {}", e))?;
250
251 debug!(
252 "Gemini embedding: {} values, model={}",
253 embed_response.embedding.values.len(),
254 self.model
255 );
256
257 return Ok(embed_response.embedding.values);
258 }
259
260 if status.as_u16() == 429 {
262 let backoff = Duration::from_millis(500 * (1 << attempt));
263 warn!(
264 "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
265 backoff,
266 attempt + 1,
267 max_retries
268 );
269 std::thread::sleep(backoff);
270 last_error = Some(anyhow!("Rate limited"));
271 continue;
272 }
273
274 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
276 return Err(anyhow!(
277 "Gemini API error ({}): {}",
278 error_response.error.code,
279 error_response.error.message
280 ));
281 }
282
283 return Err(anyhow!(
284 "Gemini API request failed with status {}: {}",
285 status,
286 body
287 ));
288 }
289 Err(e) => {
290 if attempt < max_retries - 1 {
291 let backoff = Duration::from_millis(500 * (1 << attempt));
292 warn!(
293 "Gemini request failed, retrying in {:?} (attempt {}/{}): {}",
294 backoff,
295 attempt + 1,
296 max_retries,
297 e
298 );
299 std::thread::sleep(backoff);
300 last_error = Some(anyhow!("Request failed: {}", e));
301 continue;
302 }
303 return Err(anyhow!("Gemini API request failed: {}", e));
304 }
305 }
306 }
307
308 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
309 }
310
311 fn embed_batch_with_retry(
313 &self,
314 texts: &[std::borrow::Cow<'_, str>],
315 max_retries: usize,
316 ) -> Result<Vec<Vec<f32>>> {
317 let url = format!(
318 "{}/{}:batchEmbedContents?key={}",
319 GEMINI_API_BASE, self.model, self.api_key
320 );
321
322 let requests: Vec<GeminiEmbedRequestItem> = texts
323 .iter()
324 .map(|text| GeminiEmbedRequestItem {
325 model: format!("models/{}", self.model),
326 content: GeminiContent {
327 parts: vec![GeminiPart {
328 text: text.to_string(),
329 }],
330 },
331 task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
332 })
333 .collect();
334
335 let batch_request = GeminiBatchEmbedRequest { requests };
336
337 let mut last_error = None;
338
339 for attempt in 0..max_retries {
340 let response = self
341 .client
342 .post(&url)
343 .header("Content-Type", "application/json")
344 .json(&batch_request)
345 .send();
346
347 match response {
348 Ok(resp) => {
349 let status = resp.status();
350 let body = resp.text().unwrap_or_default();
351
352 if status.is_success() {
353 let batch_response: GeminiBatchEmbedResponse = serde_json::from_str(&body)
354 .map_err(|e| anyhow!("Failed to parse Gemini batch response: {}", e))?;
355
356 debug!(
357 "Gemini batch embeddings: {} texts, model={}",
358 batch_response.embeddings.len(),
359 self.model
360 );
361
362 return Ok(batch_response
363 .embeddings
364 .into_iter()
365 .map(|e| e.values)
366 .collect());
367 }
368
369 if status.as_u16() == 429 {
371 let backoff = Duration::from_millis(500 * (1 << attempt));
372 warn!(
373 "Rate limited by Gemini, retrying in {:?} (attempt {}/{})",
374 backoff,
375 attempt + 1,
376 max_retries
377 );
378 std::thread::sleep(backoff);
379 last_error = Some(anyhow!("Rate limited"));
380 continue;
381 }
382
383 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&body) {
385 return Err(anyhow!(
386 "Gemini API error ({}): {}",
387 error_response.error.code,
388 error_response.error.message
389 ));
390 }
391
392 return Err(anyhow!(
393 "Gemini API request failed with status {}: {}",
394 status,
395 body
396 ));
397 }
398 Err(e) => {
399 if attempt < max_retries - 1 {
400 let backoff = Duration::from_millis(500 * (1 << attempt));
401 warn!(
402 "Gemini batch request failed, retrying in {:?} (attempt {}/{}): {}",
403 backoff,
404 attempt + 1,
405 max_retries,
406 e
407 );
408 std::thread::sleep(backoff);
409 last_error = Some(anyhow!("Request failed: {}", e));
410 continue;
411 }
412 return Err(anyhow!("Gemini API batch request failed: {}", e));
413 }
414 }
415 }
416
417 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed batch after {} retries", max_retries)))
418 }
419}
420
421pub fn try_gemini_provider() -> Option<GeminiEmbeddingProvider> {
423 match GeminiEmbeddingProvider::from_env() {
424 Ok(provider) => {
425 info!("Gemini embedding provider available");
426 Some(provider)
427 }
428 Err(e) => {
429 debug!("Gemini provider not available: {}", e);
430 None
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_api_key() {
        let result = GeminiEmbeddingProvider::new(String::new(), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_model_dimensions() {
        let provider = GeminiEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.dimension(), 768);

        let provider =
            GeminiEmbeddingProvider::new("test-key".to_string(), Some("gemini-embedding-001"))
                .unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_default_model_and_kind() {
        let provider = GeminiEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.model(), DEFAULT_MODEL);
        assert_eq!(provider.kind(), "gemini");
    }

    #[test]
    fn test_debug_redacts_api_key() {
        let provider =
            GeminiEmbeddingProvider::new("super-secret-key".to_string(), None).unwrap();
        assert!(!format!("{:?}", provider).contains("super-secret-key"));
    }

    #[test]
    fn test_short_text_is_borrowed_unchanged() {
        let text = "short text";
        assert!(matches!(
            truncate_for_embedding(text),
            std::borrow::Cow::Borrowed(t) if t == text
        ));
    }

    #[test]
    fn test_long_ascii_text_is_truncated_to_limit() {
        let text = "a".repeat(MAX_EMBEDDING_TEXT_LEN + 10);
        assert_eq!(truncate_for_embedding(&text).len(), MAX_EMBEDDING_TEXT_LEN);
    }

    #[test]
    #[ignore = "requires GOOGLE_API_KEY and network access"]
    fn test_real_embedding() {
        let provider = GeminiEmbeddingProvider::from_env().expect("GOOGLE_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
    }
}