1use std::future::Future;
2use std::time::Duration;
3
4use reqwest::Client;
5use serde::Deserialize;
6use serde_json::{Value, json};
7use tokio::time::sleep;
8
9use crate::error::{OmniError, Result};
10
/// Gemini model used for every `embedContent` call.
const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview";
/// API root used when `OMNIGRAPH_GEMINI_BASE_URL` is unset or empty.
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Gemini task type sent with query-side embeddings.
const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
/// Gemini task type sent with document-side embeddings.
const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";

/// HTTP timeout in milliseconds when `OMNIGRAPH_EMBED_TIMEOUT_MS` is unset or invalid.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Attempts per embedding call when `OMNIGRAPH_EMBED_RETRY_ATTEMPTS` is unset or invalid.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Delay in milliseconds before the first retry; it doubles on each later retry.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
18
19#[derive(Clone, Debug)]
20enum EmbeddingTransport {
21 Mock,
22 Gemini {
23 api_key: String,
24 base_url: String,
25 http: Client,
26 },
27}
28
29#[derive(Clone, Debug)]
30pub struct EmbeddingClient {
31 retry_attempts: usize,
32 retry_backoff_ms: u64,
33 transport: EmbeddingTransport,
34}
35
/// Failure of a single embedding attempt, tagged with whether retrying may help.
struct EmbedCallError {
    /// Human-readable description; becomes the final error once retries stop.
    message: String,
    /// `true` for transient failures (e.g. network errors, 5xx, 429) worth retrying.
    retryable: bool,
}
40
41#[derive(Debug, Deserialize)]
42struct GeminiEmbedResponse {
43 embedding: GeminiContentEmbedding,
44}
45
46#[derive(Debug, Deserialize)]
47struct GeminiContentEmbedding {
48 values: Vec<f32>,
49}
50
51#[derive(Debug, Deserialize)]
52struct GoogleErrorEnvelope {
53 error: GoogleErrorBody,
54}
55
56#[derive(Debug, Deserialize)]
57struct GoogleErrorBody {
58 message: String,
59}
60
61impl EmbeddingClient {
62 pub fn from_env() -> Result<Self> {
63 let retry_attempts =
64 parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
65 let retry_backoff_ms =
66 parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
67
68 if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
69 return Ok(Self {
70 retry_attempts,
71 retry_backoff_ms,
72 transport: EmbeddingTransport::Mock,
73 });
74 }
75
76 let api_key = std::env::var("GEMINI_API_KEY")
77 .ok()
78 .map(|v| v.trim().to_string())
79 .filter(|v| !v.is_empty())
80 .ok_or_else(|| {
81 OmniError::manifest_internal(
82 "GEMINI_API_KEY is required when nearest() needs a string embedding",
83 )
84 })?;
85 let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL")
86 .ok()
87 .map(|v| v.trim_end_matches('/').to_string())
88 .filter(|v| !v.is_empty())
89 .unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string());
90 let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
91 let http = Client::builder()
92 .timeout(Duration::from_millis(timeout_ms))
93 .build()
94 .map_err(|e| {
95 OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
96 })?;
97
98 Ok(Self {
99 retry_attempts,
100 retry_backoff_ms,
101 transport: EmbeddingTransport::Gemini {
102 api_key,
103 base_url,
104 http,
105 },
106 })
107 }
108
109 #[cfg(test)]
110 fn mock_for_tests() -> Self {
111 Self {
112 retry_attempts: DEFAULT_RETRY_ATTEMPTS,
113 retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
114 transport: EmbeddingTransport::Mock,
115 }
116 }
117
118 pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
119 self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await
120 }
121
122 pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
123 self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE)
124 .await
125 }
126
127 async fn embed_text(
128 &self,
129 input: &str,
130 expected_dim: usize,
131 task_type: &'static str,
132 ) -> Result<Vec<f32>> {
133 if expected_dim == 0 {
134 return Err(OmniError::manifest_internal(
135 "embedding dimension must be greater than zero",
136 ));
137 }
138
139 match &self.transport {
140 EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)),
141 EmbeddingTransport::Gemini { .. } => {
142 self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type))
143 .await
144 }
145 }
146 }
147
148 async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
149 where
150 F: FnMut() -> Fut,
151 Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
152 {
153 let max_attempt = self.retry_attempts.max(1);
154 let mut attempt = 0usize;
155 loop {
156 attempt += 1;
157 match operation().await {
158 Ok(value) => return Ok(value),
159 Err(err) => {
160 if !err.retryable || attempt >= max_attempt {
161 return Err(OmniError::manifest_internal(err.message));
162 }
163 let shift = (attempt - 1).min(10) as u32;
164 let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
165 sleep(Duration::from_millis(delay)).await;
166 }
167 }
168 }
169 }
170
171 async fn embed_text_gemini_once(
172 &self,
173 input: &str,
174 expected_dim: usize,
175 task_type: &'static str,
176 ) -> std::result::Result<Vec<f32>, EmbedCallError> {
177 let (api_key, base_url, http) = match &self.transport {
178 EmbeddingTransport::Gemini {
179 api_key,
180 base_url,
181 http,
182 } => (api_key, base_url, http),
183 EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"),
184 };
185
186 let response = http
187 .post(gemini_endpoint(base_url))
188 .header("x-goog-api-key", api_key)
189 .json(&build_gemini_request(input, expected_dim, task_type))
190 .send()
191 .await;
192 let response = match response {
193 Ok(response) => response,
194 Err(err) => {
195 let retryable = err.is_timeout() || err.is_connect() || err.is_request();
196 return Err(EmbedCallError {
197 message: format!("embedding request failed: {}", err),
198 retryable,
199 });
200 }
201 };
202
203 let status = response.status();
204 let body = match response.text().await {
205 Ok(body) => body,
206 Err(err) => {
207 return Err(EmbedCallError {
208 message: format!(
209 "embedding response read failed (status {}): {}",
210 status, err
211 ),
212 retryable: status.is_server_error() || status.as_u16() == 429,
213 });
214 }
215 };
216
217 if !status.is_success() {
218 let message = parse_google_error_message(&body).unwrap_or(body);
219 return Err(EmbedCallError {
220 message: format!(
221 "embedding request failed with status {}: {}",
222 status, message
223 ),
224 retryable: status.is_server_error() || status.as_u16() == 429,
225 });
226 }
227
228 let parsed: GeminiEmbedResponse =
229 serde_json::from_str(&body).map_err(|err| EmbedCallError {
230 message: format!("embedding response decode failed: {}", err),
231 retryable: false,
232 })?;
233
234 validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
235 EmbedCallError {
236 message,
237 retryable: false,
238 }
239 })
240 }
241}
242
243fn gemini_endpoint(base_url: &str) -> String {
244 format!(
245 "{}/models/{}:embedContent",
246 base_url.trim_end_matches('/'),
247 GEMINI_EMBED_MODEL
248 )
249}
250
251fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value {
252 json!({
253 "model": format!("models/{}", GEMINI_EMBED_MODEL),
254 "content": {
255 "parts": [
256 {
257 "text": input
258 }
259 ]
260 },
261 "taskType": task_type,
262 "outputDimensionality": expected_dim,
263 })
264}
265
266fn validate_and_normalize_embedding(
267 values: Vec<f32>,
268 expected_dim: usize,
269) -> std::result::Result<Vec<f32>, String> {
270 if values.len() != expected_dim {
271 return Err(format!(
272 "embedding dimension mismatch: expected {}, got {}",
273 expected_dim,
274 values.len()
275 ));
276 }
277 Ok(normalize_vector(values))
278}
279
/// Rescales `values` to unit Euclidean (L2) length and returns them.
///
/// The norm is accumulated and applied in `f64`. This limits rounding error and
/// avoids overflowing to infinity for large-magnitude `f32` inputs.
///
/// Vectors whose norm is zero or not finite are returned unchanged, because no
/// direction can be recovered from them. That covers all-zero input and input
/// containing NaN or infinity. Any strictly positive finite norm is honoured,
/// so even very small vectors come out unit length.
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
    let norm = values
        .iter()
        .map(|&v| f64::from(v) * f64::from(v))
        .sum::<f64>()
        .sqrt();
    if norm > 0.0 && norm.is_finite() {
        for value in &mut values {
            *value = (f64::from(*value) / norm) as f32;
        }
    }
    values
}
293
294fn parse_google_error_message(body: &str) -> Option<String> {
295 serde_json::from_str::<GoogleErrorEnvelope>(body)
296 .ok()
297 .map(|e| e.error.message)
298 .filter(|msg| !msg.trim().is_empty())
299}
300
/// Reads environment variable `name` as a strictly positive number.
///
/// Surrounding whitespace is ignored, matching how `env_flag` and the API key
/// are read. Without trimming, a value like `"4\n"` from a `.env` file would
/// silently fall back to `default`.
///
/// Returns `default` in any of these cases:
/// * the variable is unset or not valid UTF-8;
/// * the value is unparseable;
/// * the value is not greater than zero.
fn parse_env_positive<T>(name: &str, default: T) -> T
where
    T: std::str::FromStr + PartialOrd + Default,
{
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<T>().ok())
        .filter(|v| *v > T::default())
        .unwrap_or(default)
}

/// `usize` variant of [`parse_env_positive`].
fn parse_env_usize(name: &str, default: usize) -> usize {
    parse_env_positive(name, default)
}

/// `u64` variant of [`parse_env_positive`].
fn parse_env_u64(name: &str, default: u64) -> u64 {
    parse_env_positive(name, default)
}
316
/// Interprets environment variable `name` as a boolean switch.
///
/// The value is trimmed and ASCII-lowercased first. `1`, `true`, `yes` and `on`
/// enable the switch. Anything else is `false`, including an unset or non-UTF-8
/// variable.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
326
327fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
328 let mut seed = fnv1a64(input.as_bytes());
329 let mut out = Vec::with_capacity(dim);
330 for _ in 0..dim {
331 seed = xorshift64(seed);
332 let ratio = (seed as f64 / u64::MAX as f64) as f32;
333 out.push((ratio * 2.0) - 1.0);
334 }
335 normalize_vector(out)
336}
337
/// 64-bit FNV-1a hash of `bytes`.
///
/// Uses offset basis `0xcbf29ce484222325` and prime `0x100000001b3`.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
346
/// One step of Marsaglia's xorshift64 generator, with shift triple (13, 7, 17).
///
/// `0` is a fixed point: `xorshift64(0) == 0`.
fn xorshift64(state: u64) -> u64 {
    let state = state ^ (state << 13);
    let state = state ^ (state >> 7);
    state ^ (state << 17)
}
353
// Unit tests for the embedding client. They cover:
// * mock determinism;
// * Gemini request shape;
// * vector validation;
// * the retry policy;
// * `from_env` configuration errors.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serial_test::serial;

    use super::*;

    /// Test-only RAII guard. It applies a set of environment-variable overrides and,
    /// when dropped, restores whatever values those variables had beforehand.
    struct EnvGuard {
        // Each variable name paired with its value before `set` ran. `None` means it
        // was unset; a non-UTF-8 value is also recorded as `None`.
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvGuard {
        /// Applies `vars` in order: `Some(value)` sets the variable, `None` removes it.
        /// All previous values are captured before any variable is modified.
        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|(name, _)| (*name, std::env::var(name).ok()))
                .collect::<Vec<_>>();
            for (name, value) in vars {
                // SAFETY: `set_var`/`remove_var` are unsound if another thread accesses the
                // environment at the same time. Tests using this guard are `#[serial]`,
                // which keeps them from overlapping one another.
                // NOTE(review): non-`#[serial]` tests may still run concurrently. Confirm
                // that none of them read environment variables.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        /// Restores every captured variable, removing those that were originally unset.
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                // SAFETY: same single-threaded-environment requirement as in `EnvGuard::set`.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    // Same text gives the same vector, different text gives a different vector,
    // and the requested dimension is honoured.
    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let a = client.embed_query_text("alpha", 8).await.unwrap();
        let b = client.embed_query_text("alpha", 8).await.unwrap();
        let c = client.embed_query_text("beta", 8).await.unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.len(), 8);
    }

    // Query requests name the model and carry the task type, output dimension and text.
    #[test]
    fn gemini_request_uses_preview_model_retrieval_query_and_dimension() {
        let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE);
        assert_eq!(request["model"], "models/gemini-embedding-2-preview");
        assert_eq!(request["taskType"], QUERY_TASK_TYPE);
        assert_eq!(request["outputDimensionality"], 4);
        assert_eq!(request["content"]["parts"][0]["text"], "alpha");
    }

    // Document requests switch the task type.
    #[test]
    fn gemini_document_request_uses_retrieval_document_task_type() {
        let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE);
        assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE);
    }

    // A 3-4 vector normalizes to 0.6-0.8; a wrong length is reported with both sizes.
    #[test]
    fn validate_and_normalize_embedding_enforces_dimension() {
        let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
        assert!(err.contains("expected 3, got 2"));
    }

    // One retryable failure followed by success yields the value after exactly two calls.
    // Note: this sleeps for the default backoff before the retry.
    #[tokio::test]
    async fn with_retry_retries_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_for_call = Arc::clone(&attempts);

        let value = client
            .with_retry(|| {
                let attempts_for_call = Arc::clone(&attempts_for_call);
                async move {
                    let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
                    if attempt == 0 {
                        Err(EmbedCallError {
                            message: "retry me".to_string(),
                            retryable: true,
                        })
                    } else {
                        Ok("ok")
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(value, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    // A non-retryable failure is returned immediately, with its message preserved.
    #[tokio::test]
    async fn with_retry_stops_on_non_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let err = client
            .with_retry(|| async {
                Err::<(), _>(EmbedCallError {
                    message: "do not retry".to_string(),
                    retryable: false,
                })
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("do not retry"));
    }

    // With mocking off and no API key, `from_env` fails and names the missing variable.
    // Runs `#[serial]` because it mutates process environment variables.
    #[test]
    #[serial]
    fn from_env_requires_gemini_api_key_when_not_mocking() {
        let _guard = EnvGuard::set(&[
            ("OMNIGRAPH_EMBEDDINGS_MOCK", None),
            ("GEMINI_API_KEY", None),
        ]);

        let err = EmbeddingClient::from_env().unwrap_err();
        assert!(err.to_string().contains("GEMINI_API_KEY"));
    }
}