Skip to main content

omnigraph/
embedding.rs

1use std::future::Future;
2use std::time::Duration;
3
4use reqwest::Client;
5use serde::Deserialize;
6use serde_json::{Value, json};
7use tokio::time::sleep;
8
9use crate::error::{OmniError, Result};
10
/// Gemini model that serves every embedding request (named in both the URL and the body).
const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview";
/// Gemini API root used unless `OMNIGRAPH_GEMINI_BASE_URL` is set.
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// HTTP timeout in milliseconds unless `OMNIGRAPH_EMBED_TIMEOUT_MS` overrides it.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// Attempts per embedding call unless `OMNIGRAPH_EMBED_RETRY_ATTEMPTS` overrides it.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Initial retry delay in milliseconds unless `OMNIGRAPH_EMBED_RETRY_BACKOFF_MS` overrides it.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;

/// Gemini `taskType` for text embedded as a search query.
const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
/// Gemini `taskType` for text embedded as a stored document.
const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
18
19#[derive(Clone, Debug)]
20enum EmbeddingTransport {
21    Mock,
22    Gemini {
23        api_key: String,
24        base_url: String,
25        http: Client,
26    },
27}
28
29#[derive(Clone, Debug)]
30pub struct EmbeddingClient {
31    retry_attempts: usize,
32    retry_backoff_ms: u64,
33    transport: EmbeddingTransport,
34}
35
/// Failure from a single embedding attempt, tagged with whether retrying could help.
struct EmbedCallError {
    /// Human-readable description, surfaced to the caller once retries stop.
    message: String,
    /// `true` for transient failures (e.g. timeouts, HTTP 429/5xx) that may be re-attempted.
    retryable: bool,
}
40
41#[derive(Debug, Deserialize)]
42struct GeminiEmbedResponse {
43    embedding: GeminiContentEmbedding,
44}
45
46#[derive(Debug, Deserialize)]
47struct GeminiContentEmbedding {
48    values: Vec<f32>,
49}
50
51#[derive(Debug, Deserialize)]
52struct GoogleErrorEnvelope {
53    error: GoogleErrorBody,
54}
55
56#[derive(Debug, Deserialize)]
57struct GoogleErrorBody {
58    message: String,
59}
60
61impl EmbeddingClient {
62    pub fn from_env() -> Result<Self> {
63        let retry_attempts =
64            parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
65        let retry_backoff_ms =
66            parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
67
68        if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
69            return Ok(Self {
70                retry_attempts,
71                retry_backoff_ms,
72                transport: EmbeddingTransport::Mock,
73            });
74        }
75
76        let api_key = std::env::var("GEMINI_API_KEY")
77            .ok()
78            .map(|v| v.trim().to_string())
79            .filter(|v| !v.is_empty())
80            .ok_or_else(|| {
81                OmniError::manifest_internal(
82                    "GEMINI_API_KEY is required when nearest() needs a string embedding",
83                )
84            })?;
85        let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL")
86            .ok()
87            .map(|v| v.trim_end_matches('/').to_string())
88            .filter(|v| !v.is_empty())
89            .unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string());
90        let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
91        let http = Client::builder()
92            .timeout(Duration::from_millis(timeout_ms))
93            .build()
94            .map_err(|e| {
95                OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
96            })?;
97
98        Ok(Self {
99            retry_attempts,
100            retry_backoff_ms,
101            transport: EmbeddingTransport::Gemini {
102                api_key,
103                base_url,
104                http,
105            },
106        })
107    }
108
109    #[cfg(test)]
110    fn mock_for_tests() -> Self {
111        Self {
112            retry_attempts: DEFAULT_RETRY_ATTEMPTS,
113            retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
114            transport: EmbeddingTransport::Mock,
115        }
116    }
117
118    pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
119        self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await
120    }
121
122    pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
123        self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE)
124            .await
125    }
126
127    async fn embed_text(
128        &self,
129        input: &str,
130        expected_dim: usize,
131        task_type: &'static str,
132    ) -> Result<Vec<f32>> {
133        if expected_dim == 0 {
134            return Err(OmniError::manifest_internal(
135                "embedding dimension must be greater than zero",
136            ));
137        }
138
139        match &self.transport {
140            EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)),
141            EmbeddingTransport::Gemini { .. } => {
142                self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type))
143                    .await
144            }
145        }
146    }
147
148    async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
149    where
150        F: FnMut() -> Fut,
151        Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
152    {
153        let max_attempt = self.retry_attempts.max(1);
154        let mut attempt = 0usize;
155        loop {
156            attempt += 1;
157            match operation().await {
158                Ok(value) => return Ok(value),
159                Err(err) => {
160                    if !err.retryable || attempt >= max_attempt {
161                        return Err(OmniError::manifest_internal(err.message));
162                    }
163                    let shift = (attempt - 1).min(10) as u32;
164                    let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
165                    sleep(Duration::from_millis(delay)).await;
166                }
167            }
168        }
169    }
170
171    async fn embed_text_gemini_once(
172        &self,
173        input: &str,
174        expected_dim: usize,
175        task_type: &'static str,
176    ) -> std::result::Result<Vec<f32>, EmbedCallError> {
177        let (api_key, base_url, http) = match &self.transport {
178            EmbeddingTransport::Gemini {
179                api_key,
180                base_url,
181                http,
182            } => (api_key, base_url, http),
183            EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"),
184        };
185
186        let response = http
187            .post(gemini_endpoint(base_url))
188            .header("x-goog-api-key", api_key)
189            .json(&build_gemini_request(input, expected_dim, task_type))
190            .send()
191            .await;
192        let response = match response {
193            Ok(response) => response,
194            Err(err) => {
195                let retryable = err.is_timeout() || err.is_connect() || err.is_request();
196                return Err(EmbedCallError {
197                    message: format!("embedding request failed: {}", err),
198                    retryable,
199                });
200            }
201        };
202
203        let status = response.status();
204        let body = match response.text().await {
205            Ok(body) => body,
206            Err(err) => {
207                return Err(EmbedCallError {
208                    message: format!(
209                        "embedding response read failed (status {}): {}",
210                        status, err
211                    ),
212                    retryable: status.is_server_error() || status.as_u16() == 429,
213                });
214            }
215        };
216
217        if !status.is_success() {
218            let message = parse_google_error_message(&body).unwrap_or(body);
219            return Err(EmbedCallError {
220                message: format!(
221                    "embedding request failed with status {}: {}",
222                    status, message
223                ),
224                retryable: status.is_server_error() || status.as_u16() == 429,
225            });
226        }
227
228        let parsed: GeminiEmbedResponse =
229            serde_json::from_str(&body).map_err(|err| EmbedCallError {
230                message: format!("embedding response decode failed: {}", err),
231                retryable: false,
232            })?;
233
234        validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
235            EmbedCallError {
236                message,
237                retryable: false,
238            }
239        })
240    }
241}
242
243fn gemini_endpoint(base_url: &str) -> String {
244    format!(
245        "{}/models/{}:embedContent",
246        base_url.trim_end_matches('/'),
247        GEMINI_EMBED_MODEL
248    )
249}
250
251fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value {
252    json!({
253        "model": format!("models/{}", GEMINI_EMBED_MODEL),
254        "content": {
255            "parts": [
256                {
257                    "text": input
258                }
259            ]
260        },
261        "taskType": task_type,
262        "outputDimensionality": expected_dim,
263    })
264}
265
266fn validate_and_normalize_embedding(
267    values: Vec<f32>,
268    expected_dim: usize,
269) -> std::result::Result<Vec<f32>, String> {
270    if values.len() != expected_dim {
271        return Err(format!(
272            "embedding dimension mismatch: expected {}, got {}",
273            expected_dim,
274            values.len()
275        ));
276    }
277    Ok(normalize_vector(values))
278}
279
/// Scales `values` to unit L2 length in place and returns it.
///
/// The squared magnitude is accumulated in `f64` before taking the root. Vectors whose
/// magnitude does not exceed `f32::EPSILON` (including all-zero and empty input) are
/// returned unchanged instead of being divided by a near-zero value.
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
    let squared_sum: f64 = values
        .iter()
        .map(|&component| {
            let wide = f64::from(component);
            wide * wide
        })
        .sum();
    let magnitude = squared_sum.sqrt() as f32;
    // Kept as a positive comparison so a NaN magnitude also leaves the input untouched.
    if magnitude > f32::EPSILON {
        values
            .iter_mut()
            .for_each(|component| *component /= magnitude);
    }
    values
}
293
294fn parse_google_error_message(body: &str) -> Option<String> {
295    serde_json::from_str::<GoogleErrorEnvelope>(body)
296        .ok()
297        .map(|e| e.error.message)
298        .filter(|msg| !msg.trim().is_empty())
299}
300
/// Reads a positive `usize` from environment variable `name`, falling back to `default`.
///
/// Surrounding whitespace (e.g. a trailing newline from a `.env` file) is ignored, matching
/// how `env_flag` and `GEMINI_API_KEY` are read. Unset, non-UTF-8, unparsable, and zero
/// values all yield `default`.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&value| value > 0)
        .unwrap_or(default)
}
308
/// Reads a positive `u64` from environment variable `name`, falling back to `default`.
///
/// Surrounding whitespace (e.g. a trailing newline from a `.env` file) is ignored, matching
/// how `env_flag` and `GEMINI_API_KEY` are read. Unset, non-UTF-8, unparsable, and zero
/// values all yield `default`.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .filter(|&value| value > 0)
        .unwrap_or(default)
}
316
/// Interprets environment variable `name` as a boolean switch.
///
/// After trimming and ASCII-lowercasing, `1`, `true`, `yes`, and `on` are `true`. Anything
/// else, including an unset or non-UTF-8 variable, is `false`.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
326
327fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
328    let mut seed = fnv1a64(input.as_bytes());
329    let mut out = Vec::with_capacity(dim);
330    for _ in 0..dim {
331        seed = xorshift64(seed);
332        let ratio = (seed as f64 / u64::MAX as f64) as f32;
333        out.push((ratio * 2.0) - 1.0);
334    }
335    normalize_vector(out)
336}
337
/// 64-bit FNV-1a hash of `bytes` (XOR each byte in, then multiply by the FNV prime).
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
346
/// One step of the 64-bit xorshift generator with shift triple (13, 7, 17).
///
/// A zero state maps to zero, so the stream is only non-trivial for a non-zero seed.
fn xorshift64(x: u64) -> u64 {
    let after_left = x ^ (x << 13);
    let after_right = after_left ^ (after_left >> 7);
    after_right ^ (after_right << 17)
}
353
#[cfg(test)]
mod tests {
    use std::ffi::OsString;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serial_test::serial;

    use super::*;

    /// Overrides environment variables for the guard's lifetime and restores the previous
    /// values when dropped.
    struct EnvGuard {
        // Captured with `var_os` so values that are not valid UTF-8 are restored verbatim,
        // instead of being mistaken for "unset" and removed on drop.
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Applies each `(name, value)` pair: `Some` sets the variable, `None` removes it.
        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|(name, _)| (*name, std::env::var_os(name)))
                .collect::<Vec<_>>();
            for (name, value) in vars {
                // SAFETY: mutating the environment is only sound while no other thread reads
                // or writes it. The only user of this guard in this module is a `#[serial]`
                // test, and the other tests here do not touch the environment.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                // SAFETY: same argument as in `EnvGuard::set`; the guard is dropped inside
                // the same serialized test that created it.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let a = client.embed_query_text("alpha", 8).await.unwrap();
        let b = client.embed_query_text("alpha", 8).await.unwrap();
        let c = client.embed_query_text("beta", 8).await.unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.len(), 8);
    }

    #[test]
    fn gemini_request_uses_preview_model_retrieval_query_and_dimension() {
        let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE);
        assert_eq!(request["model"], "models/gemini-embedding-2-preview");
        assert_eq!(request["taskType"], QUERY_TASK_TYPE);
        assert_eq!(request["outputDimensionality"], 4);
        assert_eq!(request["content"]["parts"][0]["text"], "alpha");
    }

    #[test]
    fn gemini_document_request_uses_retrieval_document_task_type() {
        let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE);
        assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE);
    }

    #[test]
    fn validate_and_normalize_embedding_enforces_dimension() {
        let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
        assert!(err.contains("expected 3, got 2"));
    }

    #[tokio::test]
    async fn with_retry_retries_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_for_call = Arc::clone(&attempts);

        let value = client
            .with_retry(|| {
                let attempts_for_call = Arc::clone(&attempts_for_call);
                async move {
                    let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
                    if attempt == 0 {
                        Err(EmbedCallError {
                            message: "retry me".to_string(),
                            retryable: true,
                        })
                    } else {
                        Ok("ok")
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(value, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn with_retry_stops_on_non_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let err = client
            .with_retry(|| async {
                Err::<(), _>(EmbedCallError {
                    message: "do not retry".to_string(),
                    retryable: false,
                })
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("do not retry"));
    }

    #[test]
    #[serial]
    fn from_env_requires_gemini_api_key_when_not_mocking() {
        let _guard = EnvGuard::set(&[
            ("OMNIGRAPH_EMBEDDINGS_MOCK", None),
            ("GEMINI_API_KEY", None),
        ]);

        let err = EmbeddingClient::from_env().unwrap_err();
        assert!(err.to_string().contains("GEMINI_API_KEY"));
    }
}