Skip to main content

omnigraph/
embedding.rs

1use std::future::Future;
2use std::time::Duration;
3
4use reqwest::Client;
5use serde::Deserialize;
6use serde_json::{Value, json};
7use tokio::time::sleep;
8
9use crate::error::{OmniError, Result};
10
// Provider endpoints and default models, selected by `provider_profile`.
// OpenRouter is the gateway used when no provider alias is configured.
const DEFAULT_OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
const DEFAULT_OPENROUTER_MODEL: &str = "openai/text-embedding-3-large";
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
const DEFAULT_OPENAI_MODEL: &str = "text-embedding-3-large";
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_GEMINI_MODEL: &str = "gemini-embedding-2";

// Transport defaults used by `EmbeddingClient::new` when the matching
// OMNIGRAPH_EMBED_* environment variable is unset or unparsable.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
const DEFAULT_DEADLINE_MS: u64 = 60 * 1_000;

// Gemini `taskType` values sent for query vs. document embeddings.
const GEMINI_QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
const GEMINI_DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
23
/// Which embedding API a client speaks. Each variant owns its request shape,
/// auth, and response parsing; everything else (retry, deadline, normalization,
/// tracing) is provider-independent.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Provider {
    /// OpenAI-compatible (`POST {base}/embeddings`, bearer auth,
    /// `{model, input, dimensions}`). Covers OpenRouter (the default gateway),
    /// OpenAI direct, and self-hosted endpoints (vLLM/Ollama/LM Studio).
    OpenAiCompatible,
    /// Google Gemini `generativelanguage` (`POST {base}/models/{model}:embedContent`,
    /// `x-goog-api-key`), with `RETRIEVAL_QUERY` / `RETRIEVAL_DOCUMENT` task types.
    Gemini,
    /// Deterministic, offline. No network, no key.
    Mock,
}

/// Whether the text being embedded is a search query or a stored document.
/// Only Gemini distinguishes these (`RETRIEVAL_QUERY` vs `RETRIEVAL_DOCUMENT`);
/// OpenAI-compatible providers and Mock produce the identical request for both,
/// which is also the same-space property a query relies on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EmbedRole {
    Query,
    Document,
}

/// The single source of truth for how embedding text becomes a vector:
/// provider + model + endpoint + key. Resolved once (from env for direct
/// engine/CLI callers, or from an applied cluster `providers.embedding` profile
/// at server boot) and shared by the query path and the offline CLI so stored
/// and query vectors stay same-space by construction.
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// error output. This also covers any type that embeds the config and derives
/// `Debug`.
#[derive(Clone)]
pub struct EmbeddingConfig {
    pub provider: Provider,
    pub model: String,
    pub base_url: String,
    /// Secret credential; redacted in the `Debug` output.
    pub api_key: String,
}

impl std::fmt::Debug for EmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("EmbeddingConfig")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_key", &api_key)
            .finish()
    }
}
62
63impl EmbeddingConfig {
64    /// Resolve from the environment. Precedence:
65    /// 1. `OMNIGRAPH_EMBEDDINGS_MOCK` → Mock.
66    /// 2. `OMNIGRAPH_EMBED_PROVIDER` (`openai-compatible`|`openai`|`gemini`|`mock`);
67    ///    unset defaults to `openai-compatible` (OpenRouter).
68    /// 3. `OMNIGRAPH_EMBED_BASE_URL` else the provider default.
69    /// 4. `OMNIGRAPH_EMBED_MODEL` else the provider default.
70    /// 5. provider api-key env (`OPENROUTER_API_KEY`/`OPENAI_API_KEY`, or `GEMINI_API_KEY`).
71    pub fn from_env() -> Result<Self> {
72        if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
73            return Ok(Self::mock());
74        }
75
76        let alias = env_string("OMNIGRAPH_EMBED_PROVIDER");
77        if alias.as_deref() == Some("mock") {
78            return Ok(Self::mock());
79        }
80
81        let (provider, default_base, default_model, key_envs) = provider_profile(alias.as_deref())?;
82        let base_url = env_string("OMNIGRAPH_EMBED_BASE_URL")
83            .unwrap_or_else(|| default_base.to_string())
84            .trim_end_matches('/')
85            .to_string();
86        let model =
87            env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_else(|| default_model.to_string());
88
89        let api_key = key_envs.iter().copied().find_map(env_string).ok_or_else(|| {
90            OmniError::manifest_internal(format!(
91                "{} is required for the {} embedding provider",
92                key_envs.join(" or "),
93                alias.as_deref().unwrap_or("openai-compatible")
94            ))
95        })?;
96
97        Ok(Self {
98            provider,
99            model,
100            base_url,
101            api_key,
102        })
103    }
104
105    /// Build a config from explicit parts — the cluster `providers.embedding` profile path
106    /// (RFC-012 Phase 5). `provider`/`base_url`/`model` default exactly as
107    /// `from_env` does (shared `provider_profile`); `api_key` is already resolved
108    /// (the cluster path resolves a `${NAME}` ref before calling this).
109    pub fn from_parts(
110        provider: Option<&str>,
111        base_url: Option<String>,
112        model: Option<String>,
113        api_key: String,
114    ) -> Result<Self> {
115        if provider == Some("mock") {
116            // An explicit `model` (e.g. a cluster `providers.embedding` profile) is
117            // authoritative — it is what the same-space check compares against —
118            // so honor it; fall back to `mock()`'s env-based model only when the
119            // caller supplied none. Without this, a profile's `model` is silently
120            // dropped and the same-space check resolves to OMNIGRAPH_EMBED_MODEL.
121            let mut config = Self::mock();
122            if let Some(model) = model {
123                config.model = model;
124            }
125            return Ok(config);
126        }
127        let (provider, default_base, default_model, _key_envs) = provider_profile(provider)?;
128        let base_url = base_url
129            .unwrap_or_else(|| default_base.to_string())
130            .trim_end_matches('/')
131            .to_string();
132        let model = model.unwrap_or_else(|| default_model.to_string());
133        Ok(Self {
134            provider,
135            model,
136            base_url,
137            api_key,
138        })
139    }
140
141    fn mock() -> Self {
142        Self {
143            provider: Provider::Mock,
144            // Honor OMNIGRAPH_EMBED_MODEL so the same-space check is exercisable
145            // under mock; the mock vectors themselves don't depend on the model.
146            model: env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_default(),
147            base_url: String::new(),
148            api_key: String::new(),
149        }
150    }
151}
152
153#[derive(Clone, Debug)]
154pub struct EmbeddingClient {
155    config: EmbeddingConfig,
156    http: Client,
157    retry_attempts: usize,
158    retry_backoff_ms: u64,
159    /// Total wall-clock budget for one embed call, across all retries
160    /// (`OMNIGRAPH_EMBED_DEADLINE_MS`). `0` = unbounded.
161    deadline_ms: u64,
162}
163
/// Failure of a single provider attempt.
///
/// `with_retry` uses `retryable` to decide whether another attempt is worth
/// making; `message` becomes the final error text.
struct EmbedCallError {
    /// Human-readable cause, surfaced to the caller on final failure.
    message: String,
    /// `true` for transient failures (network, 5xx, rate limits).
    retryable: bool,
}
168
169#[derive(Debug, Deserialize)]
170struct GeminiEmbedResponse {
171    embedding: GeminiContentEmbedding,
172}
173
174#[derive(Debug, Deserialize)]
175struct GeminiContentEmbedding {
176    values: Vec<f32>,
177}
178
179#[derive(Debug, Deserialize)]
180struct GoogleErrorEnvelope {
181    error: GoogleErrorBody,
182}
183
184#[derive(Debug, Deserialize)]
185struct GoogleErrorBody {
186    message: String,
187}
188
189#[derive(Debug, Deserialize)]
190struct OpenAiEmbeddingResponse {
191    data: Vec<OpenAiEmbeddingDatum>,
192}
193
194#[derive(Debug, Deserialize)]
195struct OpenAiEmbeddingDatum {
196    index: usize,
197    embedding: Vec<f32>,
198}
199
200#[derive(Debug, Deserialize)]
201struct OpenAiErrorEnvelope {
202    error: OpenAiErrorBody,
203}
204
205#[derive(Debug, Deserialize)]
206struct OpenAiErrorBody {
207    message: String,
208}
209
210impl EmbeddingClient {
211    pub fn from_env() -> Result<Self> {
212        Self::new(EmbeddingConfig::from_env()?)
213    }
214
215    pub fn new(config: EmbeddingConfig) -> Result<Self> {
216        let retry_attempts =
217            parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
218        let retry_backoff_ms =
219            parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
220        let deadline_ms =
221            parse_env_u64_allow_zero("OMNIGRAPH_EMBED_DEADLINE_MS", DEFAULT_DEADLINE_MS);
222        let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
223        let http = Client::builder()
224            .timeout(Duration::from_millis(timeout_ms))
225            .build()
226            .map_err(|e| {
227                OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
228            })?;
229
230        Ok(Self {
231            config,
232            http,
233            retry_attempts,
234            retry_backoff_ms,
235            deadline_ms,
236        })
237    }
238
239    pub fn config(&self) -> &EmbeddingConfig {
240        &self.config
241    }
242
243    #[cfg(test)]
244    fn mock_for_tests() -> Self {
245        Self::new(EmbeddingConfig::mock()).expect("mock client builds")
246    }
247
248    pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
249        self.embed_text(input, expected_dim, EmbedRole::Query).await
250    }
251
252    pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
253        self.embed_text(input, expected_dim, EmbedRole::Document).await
254    }
255
256    async fn embed_text(
257        &self,
258        input: &str,
259        expected_dim: usize,
260        role: EmbedRole,
261    ) -> Result<Vec<f32>> {
262        if expected_dim == 0 {
263            return Err(OmniError::manifest_internal(
264                "embedding dimension must be greater than zero",
265            ));
266        }
267
268        let started = std::time::Instant::now();
269        let result = self
270            .run_with_deadline(self.embed_text_inner(input, expected_dim, role))
271            .await;
272        let elapsed_ms = started.elapsed().as_millis() as u64;
273
274        match &result {
275            Ok(_) => tracing::info!(
276                target: "omnigraph::embedding",
277                provider = ?self.config.provider,
278                model = %self.config.model,
279                dim = expected_dim,
280                elapsed_ms,
281                outcome = "ok",
282                "embedding succeeded"
283            ),
284            Err(err) => tracing::warn!(
285                target: "omnigraph::embedding",
286                provider = ?self.config.provider,
287                model = %self.config.model,
288                dim = expected_dim,
289                elapsed_ms,
290                outcome = "error",
291                error = %err,
292                "embedding failed"
293            ),
294        }
295        result
296    }
297
298    /// Bound the whole embed operation (all retries + backoff) by `deadline_ms`,
299    /// so a degraded provider can never hang the caller for the full retry
300    /// envelope. Applies to every embed call (query and document). `0` =
301    /// unbounded. Embedding has no Lance/manifest side effects, so cancelling the
302    /// in-flight request future on elapse is safe.
303    async fn run_with_deadline<F>(&self, fut: F) -> Result<Vec<f32>>
304    where
305        F: Future<Output = Result<Vec<f32>>>,
306    {
307        if self.deadline_ms == 0 {
308            return fut.await;
309        }
310        match tokio::time::timeout(Duration::from_millis(self.deadline_ms), fut).await {
311            Ok(res) => res,
312            Err(_elapsed) => Err(OmniError::manifest_internal(format!(
313                "embedding deadline exceeded after {} ms (provider={:?}, model={})",
314                self.deadline_ms, self.config.provider, self.config.model
315            ))),
316        }
317    }
318
319    async fn embed_text_inner(
320        &self,
321        input: &str,
322        expected_dim: usize,
323        role: EmbedRole,
324    ) -> Result<Vec<f32>> {
325        match self.config.provider {
326            Provider::Mock => Ok(mock_embedding(input, expected_dim)),
327            Provider::Gemini => {
328                self.with_retry(|| self.embed_gemini_once(input, expected_dim, role))
329                    .await
330            }
331            Provider::OpenAiCompatible => {
332                self.with_retry(|| self.embed_openai_once(input, expected_dim))
333                    .await
334            }
335        }
336    }
337
338    async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
339    where
340        F: FnMut() -> Fut,
341        Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
342    {
343        let max_attempt = self.retry_attempts.max(1);
344        let mut attempt = 0usize;
345        loop {
346            attempt += 1;
347            match operation().await {
348                Ok(value) => return Ok(value),
349                Err(err) => {
350                    if !err.retryable || attempt >= max_attempt {
351                        return Err(OmniError::manifest_internal(err.message));
352                    }
353                    tracing::warn!(
354                        target: "omnigraph::embedding",
355                        provider = ?self.config.provider,
356                        model = %self.config.model,
357                        attempt,
358                        error = %err.message,
359                        "embedding attempt failed, retrying"
360                    );
361                    let shift = (attempt - 1).min(10) as u32;
362                    let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
363                    sleep(Duration::from_millis(delay)).await;
364                }
365            }
366        }
367    }
368
369    async fn embed_gemini_once(
370        &self,
371        input: &str,
372        expected_dim: usize,
373        role: EmbedRole,
374    ) -> std::result::Result<Vec<f32>, EmbedCallError> {
375        let task_type = match role {
376            EmbedRole::Query => GEMINI_QUERY_TASK_TYPE,
377            EmbedRole::Document => GEMINI_DOCUMENT_TASK_TYPE,
378        };
379
380        let response = self
381            .http
382            .post(gemini_endpoint(&self.config.base_url, &self.config.model))
383            .header("x-goog-api-key", &self.config.api_key)
384            .json(&build_gemini_request(
385                &self.config.model,
386                input,
387                expected_dim,
388                task_type,
389            ))
390            .send()
391            .await;
392        let response = match response {
393            Ok(response) => response,
394            Err(err) => {
395                let retryable = err.is_timeout() || err.is_connect() || err.is_request();
396                return Err(EmbedCallError {
397                    message: format!("embedding request failed: {}", err),
398                    retryable,
399                });
400            }
401        };
402
403        let status = response.status();
404        let body = match response.text().await {
405            Ok(body) => body,
406            Err(err) => {
407                return Err(EmbedCallError {
408                    message: format!("embedding response read failed (status {}): {}", status, err),
409                    retryable: status.is_server_error() || status.as_u16() == 429,
410                });
411            }
412        };
413
414        if !status.is_success() {
415            let message = parse_google_error_message(&body).unwrap_or(body);
416            return Err(EmbedCallError {
417                message: format!("embedding request failed with status {}: {}", status, message),
418                retryable: status.is_server_error() || status.as_u16() == 429,
419            });
420        }
421
422        let parsed: GeminiEmbedResponse =
423            serde_json::from_str(&body).map_err(|err| EmbedCallError {
424                message: format!("embedding response decode failed: {}", err),
425                retryable: false,
426            })?;
427
428        validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
429            EmbedCallError {
430                message,
431                retryable: false,
432            }
433        })
434    }
435
436    async fn embed_openai_once(
437        &self,
438        input: &str,
439        expected_dim: usize,
440    ) -> std::result::Result<Vec<f32>, EmbedCallError> {
441        let response = self
442            .http
443            .post(format!("{}/embeddings", self.config.base_url))
444            .bearer_auth(&self.config.api_key)
445            .json(&build_openai_request(&self.config.model, input, expected_dim))
446            .send()
447            .await;
448        let response = match response {
449            Ok(response) => response,
450            Err(err) => {
451                let retryable = err.is_timeout() || err.is_connect() || err.is_request();
452                return Err(EmbedCallError {
453                    message: format!("embedding request failed: {}", err),
454                    retryable,
455                });
456            }
457        };
458
459        let status = response.status();
460        let body = match response.text().await {
461            Ok(body) => body,
462            Err(err) => {
463                return Err(EmbedCallError {
464                    message: format!("embedding response read failed (status {}): {}", status, err),
465                    retryable: status.is_server_error() || status.as_u16() == 429,
466                });
467            }
468        };
469
470        if !status.is_success() {
471            let message = parse_openai_error_message(&body).unwrap_or(body);
472            return Err(EmbedCallError {
473                message: format!("embedding request failed with status {}: {}", status, message),
474                retryable: status.is_server_error() || status.as_u16() == 429,
475            });
476        }
477
478        let parsed: OpenAiEmbeddingResponse =
479            serde_json::from_str(&body).map_err(|err| EmbedCallError {
480                message: format!("embedding response decode failed: {}", err),
481                retryable: false,
482            })?;
483
484        // The query path embeds exactly one string, so expect one datum at index 0.
485        let datum = parsed
486            .data
487            .into_iter()
488            .find(|d| d.index == 0)
489            .ok_or_else(|| EmbedCallError {
490                message: "embedding response missing data[0]".to_string(),
491                retryable: false,
492            })?;
493
494        validate_and_normalize_embedding(datum.embedding, expected_dim).map_err(|message| {
495            EmbedCallError {
496                message,
497                retryable: false,
498            }
499        })
500    }
501}
502
/// URL of Gemini's `embedContent` method for `model`.
///
/// All trailing slashes are stripped from `base_url` first, so the result never
/// contains `//models`.
fn gemini_endpoint(base_url: &str, model: &str) -> String {
    let root = base_url.trim_end_matches('/');
    format!("{root}/models/{model}:embedContent")
}
510
511fn build_gemini_request(model: &str, input: &str, expected_dim: usize, task_type: &str) -> Value {
512    json!({
513        "model": format!("models/{}", model),
514        "content": {
515            "parts": [
516                {
517                    "text": input
518                }
519            ]
520        },
521        "taskType": task_type,
522        "outputDimensionality": expected_dim,
523    })
524}
525
526fn build_openai_request(model: &str, input: &str, expected_dim: usize) -> Value {
527    json!({
528        "model": model,
529        "input": [input],
530        "dimensions": expected_dim,
531    })
532}
533
/// Reject a vector whose length differs from `expected_dim`; otherwise return it
/// scaled to unit length.
fn validate_and_normalize_embedding(
    values: Vec<f32>,
    expected_dim: usize,
) -> std::result::Result<Vec<f32>, String> {
    let got = values.len();
    if got == expected_dim {
        Ok(normalize_vector(values))
    } else {
        Err(format!(
            "embedding dimension mismatch: expected {}, got {}",
            expected_dim, got
        ))
    }
}

/// Scale `values` to unit L2 norm.
///
/// Squares are accumulated in f64 for precision, and the division is done in f32.
/// Vectors whose norm is not above `f32::EPSILON` (including zero and NaN norms)
/// are returned unchanged.
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
    let sum_of_squares: f64 = values
        .iter()
        .map(|&v| {
            let wide = f64::from(v);
            wide * wide
        })
        .sum();
    let norm = sum_of_squares.sqrt() as f32;
    if norm > f32::EPSILON {
        values.iter_mut().for_each(|v| *v /= norm);
    }
    values
}
561
562fn parse_google_error_message(body: &str) -> Option<String> {
563    serde_json::from_str::<GoogleErrorEnvelope>(body)
564        .ok()
565        .map(|e| e.error.message)
566        .filter(|msg| !msg.trim().is_empty())
567}
568
569fn parse_openai_error_message(body: &str) -> Option<String> {
570    serde_json::from_str::<OpenAiErrorEnvelope>(body)
571        .ok()
572        .map(|e| e.error.message)
573        .filter(|msg| !msg.trim().is_empty())
574}
575
576/// Map a provider alias to `(provider, default base URL, default model, ordered
577/// api-key envs)`. Shared by `from_env` and `from_parts` so both apply identical
578/// defaults: `openai-compatible`/unset → the OpenRouter gateway, `openai` →
579/// OpenAI's own host. `mock` is handled by callers before this is reached. The
580/// `Provider` enum alone would collapse the two openai aliases, so the alias
581/// (not the enum) determines the key-env order here.
582fn provider_profile(
583    alias: Option<&str>,
584) -> Result<(Provider, &'static str, &'static str, &'static [&'static str])> {
585    Ok(match alias {
586        None | Some("openai-compatible") => (
587            Provider::OpenAiCompatible,
588            DEFAULT_OPENROUTER_BASE_URL,
589            DEFAULT_OPENROUTER_MODEL,
590            &["OPENROUTER_API_KEY", "OPENAI_API_KEY"],
591        ),
592        Some("openai") => (
593            Provider::OpenAiCompatible,
594            DEFAULT_OPENAI_BASE_URL,
595            DEFAULT_OPENAI_MODEL,
596            &["OPENAI_API_KEY"],
597        ),
598        Some("gemini") => (
599            Provider::Gemini,
600            DEFAULT_GEMINI_BASE_URL,
601            DEFAULT_GEMINI_MODEL,
602            &["GEMINI_API_KEY"],
603        ),
604        Some(other) => {
605            return Err(OmniError::manifest_internal(format!(
606                "unknown embedding provider '{}' (expected openai-compatible|openai|gemini|mock)",
607                other
608            )));
609        }
610    })
611}
612
/// Read env var `name` and trim it. Returns `None` when it is unset, not valid
/// Unicode, or blank.
fn env_string(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
619
/// Parse env var `name` as a positive `usize`.
///
/// Surrounding whitespace is ignored, as in `env_string` and
/// `parse_env_u64_allow_zero`. Falls back to `default` when the var is unset,
/// unparsable, or `0`.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
627
/// Parse env var `name` as a positive `u64`.
///
/// Surrounding whitespace is ignored, as in `env_string` and
/// `parse_env_u64_allow_zero`. Falls back to `default` when the var is unset,
/// unparsable, or `0`.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
635
/// Parse env var `name` (trimmed) as a `u64`, where `0` is a valid value; the
/// deadline setting uses `0` to mean "unbounded". Falls back to `default` only
/// when the var is unset or unparsable.
fn parse_env_u64_allow_zero(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.trim().parse().unwrap_or(default),
        Err(_) => default,
    }
}
644
/// Return `true` when env var `name` is set to `1`, `true`, `yes` or `on`. The
/// comparison is case-insensitive and ignores surrounding whitespace. Any other
/// value, or an unset var, is `false`.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
654
655fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
656    let mut seed = fnv1a64(input.as_bytes());
657    let mut out = Vec::with_capacity(dim);
658    for _ in 0..dim {
659        seed = xorshift64(seed);
660        let ratio = (seed as f64 / u64::MAX as f64) as f32;
661        out.push((ratio * 2.0) - 1.0);
662    }
663    normalize_vector(out)
664}
665
/// 64-bit FNV-1a hash of `bytes`: XOR each byte into the state, then multiply
/// by the FNV prime with wrapping arithmetic.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
674
/// One step of Marsaglia's xorshift64 generator with shifts (13, 7, 17).
/// Zero is a fixed point.
fn xorshift64(state: u64) -> u64 {
    let a = state ^ (state << 13);
    let b = a ^ (a >> 7);
    b ^ (b << 17)
}
681
682#[cfg(test)]
683mod tests {
684    use std::sync::Arc;
685    use std::sync::atomic::{AtomicUsize, Ordering};
686
687    use serial_test::serial;
688
689    use super::*;
690
691    struct EnvGuard {
692        saved: Vec<(&'static str, Option<String>)>,
693    }
694
695    impl EnvGuard {
696        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
697            let saved = vars
698                .iter()
699                .map(|(name, _)| (*name, std::env::var(name).ok()))
700                .collect::<Vec<_>>();
701            for (name, value) in vars {
702                unsafe {
703                    match value {
704                        Some(value) => std::env::set_var(name, value),
705                        None => std::env::remove_var(name),
706                    }
707                }
708            }
709            Self { saved }
710        }
711    }
712
713    impl Drop for EnvGuard {
714        fn drop(&mut self) {
715            for (name, value) in self.saved.drain(..) {
716                unsafe {
717                    match value {
718                        Some(value) => std::env::set_var(name, value),
719                        None => std::env::remove_var(name),
720                    }
721                }
722            }
723        }
724    }
725
726    // Every test that calls `EmbeddingConfig::from_env` clears the full set of
727    // embedding env vars first so the host environment can't leak in.
728    const EMBED_ENV: &[&str] = &[
729        "OMNIGRAPH_EMBEDDINGS_MOCK",
730        "OMNIGRAPH_EMBED_PROVIDER",
731        "OMNIGRAPH_EMBED_BASE_URL",
732        "OMNIGRAPH_EMBED_MODEL",
733        "OPENROUTER_API_KEY",
734        "OPENAI_API_KEY",
735        "GEMINI_API_KEY",
736    ];
737
738    fn cleared_env(extra: &[(&'static str, Option<&str>)]) -> EnvGuard {
739        let mut vars: Vec<(&'static str, Option<&str>)> =
740            EMBED_ENV.iter().map(|n| (*n, None)).collect();
741        vars.extend_from_slice(extra);
742        EnvGuard::set(&vars)
743    }
744
745    #[tokio::test]
746    async fn mock_embeddings_are_deterministic() {
747        let client = EmbeddingClient::mock_for_tests();
748        let a = client.embed_query_text("alpha", 8).await.unwrap();
749        let b = client.embed_query_text("alpha", 8).await.unwrap();
750        let c = client.embed_query_text("beta", 8).await.unwrap();
751        assert_eq!(a, b);
752        assert_ne!(a, c);
753        assert_eq!(a.len(), 8);
754    }
755
756    #[test]
757    fn gemini_request_uses_model_retrieval_query_and_dimension() {
758        let request =
759            build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_QUERY_TASK_TYPE);
760        assert_eq!(request["model"], "models/gemini-embedding-2");
761        assert_eq!(request["taskType"], GEMINI_QUERY_TASK_TYPE);
762        assert_eq!(request["outputDimensionality"], 4);
763        assert_eq!(request["content"]["parts"][0]["text"], "alpha");
764    }
765
766    #[test]
767    fn gemini_document_request_uses_retrieval_document_task_type() {
768        let request =
769            build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_DOCUMENT_TASK_TYPE);
770        assert_eq!(request["taskType"], GEMINI_DOCUMENT_TASK_TYPE);
771    }
772
773    #[test]
774    fn openai_request_uses_model_input_array_and_dimensions() {
775        let request = build_openai_request("openai/text-embedding-3-large", "alpha", 4);
776        assert_eq!(request["model"], "openai/text-embedding-3-large");
777        assert_eq!(request["input"][0], "alpha");
778        assert!(request["input"].is_array());
779        assert_eq!(request["dimensions"], 4);
780        assert!(request.get("taskType").is_none());
781    }
782
783    #[test]
784    fn validate_and_normalize_embedding_enforces_dimension() {
785        let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
786        assert!((normalized[0] - 0.6).abs() < 1e-6);
787        assert!((normalized[1] - 0.8).abs() < 1e-6);
788
789        let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
790        assert!(err.contains("expected 3, got 2"));
791    }
792
793    #[tokio::test]
794    async fn with_retry_retries_retryable_failures() {
795        let client = EmbeddingClient::mock_for_tests();
796        let attempts = Arc::new(AtomicUsize::new(0));
797        let attempts_for_call = Arc::clone(&attempts);
798
799        let value = client
800            .with_retry(|| {
801                let attempts_for_call = Arc::clone(&attempts_for_call);
802                async move {
803                    let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
804                    if attempt == 0 {
805                        Err(EmbedCallError {
806                            message: "retry me".to_string(),
807                            retryable: true,
808                        })
809                    } else {
810                        Ok("ok")
811                    }
812                }
813            })
814            .await
815            .unwrap();
816
817        assert_eq!(value, "ok");
818        assert_eq!(attempts.load(Ordering::SeqCst), 2);
819    }
820
821    #[tokio::test]
822    async fn with_retry_stops_on_non_retryable_failures() {
823        let client = EmbeddingClient::mock_for_tests();
824        let err = client
825            .with_retry(|| async {
826                Err::<(), _>(EmbedCallError {
827                    message: "do not retry".to_string(),
828                    retryable: false,
829                })
830            })
831            .await
832            .unwrap_err();
833
834        assert!(err.to_string().contains("do not retry"));
835    }
836
837    #[tokio::test]
838    async fn run_with_deadline_aborts_slow_future() {
839        let mut client = EmbeddingClient::mock_for_tests();
840        client.deadline_ms = 20;
841        let slow = async {
842            tokio::time::sleep(Duration::from_secs(5)).await;
843            Ok(vec![0.0_f32])
844        };
845        let err = client.run_with_deadline(slow).await.unwrap_err();
846        assert!(err.to_string().contains("deadline exceeded"));
847    }
848
849    #[tokio::test]
850    async fn run_with_deadline_passes_through_fast_future() {
851        let client = EmbeddingClient::mock_for_tests();
852        let ok = client
853            .run_with_deadline(async { Ok(vec![1.0_f32, 2.0]) })
854            .await
855            .unwrap();
856        assert_eq!(ok, vec![1.0, 2.0]);
857    }
858
859    #[tokio::test]
860    async fn run_with_deadline_zero_is_unbounded() {
861        let mut client = EmbeddingClient::mock_for_tests();
862        client.deadline_ms = 0;
863        let ok = client
864            .run_with_deadline(async { Ok(vec![3.0_f32]) })
865            .await
866            .unwrap();
867        assert_eq!(ok, vec![3.0]);
868    }
869
870    #[test]
871    #[serial]
872    fn from_env_defaults_to_openai_compatible_openrouter() {
873        let _guard = cleared_env(&[("OPENROUTER_API_KEY", Some("sk-test"))]);
874        let config = EmbeddingConfig::from_env().unwrap();
875        assert_eq!(config.provider, Provider::OpenAiCompatible);
876        assert_eq!(config.base_url, DEFAULT_OPENROUTER_BASE_URL);
877        assert_eq!(config.model, DEFAULT_OPENROUTER_MODEL);
878        assert_eq!(config.api_key, "sk-test");
879    }
880
881    #[test]
882    #[serial]
883    fn from_env_openai_alias_uses_openai_host_not_openrouter() {
884        let _guard = cleared_env(&[
885            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
886            ("OPENAI_API_KEY", Some("k")),
887        ]);
888        let config = EmbeddingConfig::from_env().unwrap();
889        assert_eq!(config.provider, Provider::OpenAiCompatible);
890        assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL); // api.openai.com, not OpenRouter
891        assert_eq!(config.model, DEFAULT_OPENAI_MODEL); // text-embedding-3-large, no openai/ prefix
892        assert_eq!(config.api_key, "k");
893    }
894
895    #[test]
896    #[serial]
897    fn from_env_openai_alias_prefers_openai_key_over_openrouter() {
898        // `openai` targets api.openai.com, so an OpenRouter key must not be sent there.
899        let _guard = cleared_env(&[
900            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
901            ("OPENROUTER_API_KEY", Some("router")),
902            ("OPENAI_API_KEY", Some("openai")),
903        ]);
904        let config = EmbeddingConfig::from_env().unwrap();
905        assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL);
906        assert_eq!(config.api_key, "openai");
907    }
908
909    #[test]
910    #[serial]
911    fn from_env_openai_alias_errors_when_only_openrouter_key_is_set() {
912        let _guard = cleared_env(&[
913            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
914            ("OPENROUTER_API_KEY", Some("router")),
915        ]);
916        let err = EmbeddingConfig::from_env().unwrap_err();
917        assert!(err.to_string().contains("OPENAI_API_KEY"), "got: {err}");
918    }
919
920    #[test]
921    fn from_parts_applies_provider_defaults_and_overrides() {
922        let openrouter = EmbeddingConfig::from_parts(None, None, None, "k".to_string()).unwrap();
923        assert_eq!(openrouter.provider, Provider::OpenAiCompatible);
924        assert_eq!(openrouter.base_url, DEFAULT_OPENROUTER_BASE_URL);
925        assert_eq!(openrouter.model, DEFAULT_OPENROUTER_MODEL);
926        assert_eq!(openrouter.api_key, "k");
927
928        let gemini =
929            EmbeddingConfig::from_parts(Some("gemini"), None, None, "g".to_string()).unwrap();
930        assert_eq!(gemini.provider, Provider::Gemini);
931        assert_eq!(gemini.base_url, DEFAULT_GEMINI_BASE_URL);
932
933        let overridden = EmbeddingConfig::from_parts(
934            Some("openai"),
935            Some("https://x/v1/".to_string()),
936            Some("custom".to_string()),
937            "k".to_string(),
938        )
939        .unwrap();
940        assert_eq!(overridden.base_url, "https://x/v1"); // trailing slash trimmed
941        assert_eq!(overridden.model, "custom");
942
943        let err =
944            EmbeddingConfig::from_parts(Some("cohere"), None, None, "k".to_string()).unwrap_err();
945        assert!(
946            err.to_string().contains("unknown embedding provider"),
947            "got: {err}"
948        );
949    }
950
951    #[test]
952    #[serial]
953    fn from_parts_mock_honors_an_explicit_model() {
954        // A cluster `providers.embedding` profile that sets `kind: mock, model: X`
955        // must resolve to model X — it is what the query-time same-space check
956        // compares against. Env cleared so the assertion isolates the arg.
957        let _guard = cleared_env(&[]);
958        let pinned =
959            EmbeddingConfig::from_parts(Some("mock"), None, Some("recorded-x".to_string()), String::new())
960                .unwrap();
961        assert_eq!(pinned.provider, Provider::Mock);
962        assert_eq!(pinned.model, "recorded-x");
963        // With no explicit model, mock falls back to its env-based default (here
964        // empty, since the env is cleared).
965        let bare = EmbeddingConfig::from_parts(Some("mock"), None, None, String::new()).unwrap();
966        assert_eq!(bare.provider, Provider::Mock);
967        assert_eq!(bare.model, "");
968    }
969
970    #[test]
971    #[serial]
972    fn from_env_openai_compatible_prefers_openrouter_key() {
973        let _guard = cleared_env(&[
974            ("OPENROUTER_API_KEY", Some("router")),
975            ("OPENAI_API_KEY", Some("openai")),
976        ]);
977        let config = EmbeddingConfig::from_env().unwrap();
978        assert_eq!(config.api_key, "router");
979    }
980
981    #[test]
982    #[serial]
983    fn from_env_explicit_gemini_provider() {
984        let _guard = cleared_env(&[
985            ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
986            ("GEMINI_API_KEY", Some("g-key")),
987        ]);
988        let config = EmbeddingConfig::from_env().unwrap();
989        assert_eq!(config.provider, Provider::Gemini);
990        assert_eq!(config.base_url, DEFAULT_GEMINI_BASE_URL);
991        assert_eq!(config.model, DEFAULT_GEMINI_MODEL);
992        assert_eq!(config.api_key, "g-key");
993    }
994
995    #[test]
996    #[serial]
997    fn from_env_base_url_and_model_overrides_apply() {
998        let _guard = cleared_env(&[
999            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai-compatible")),
1000            ("OMNIGRAPH_EMBED_BASE_URL", Some("https://example.test/v1/")),
1001            ("OMNIGRAPH_EMBED_MODEL", Some("custom/model")),
1002            ("OPENAI_API_KEY", Some("k")),
1003        ]);
1004        let config = EmbeddingConfig::from_env().unwrap();
1005        assert_eq!(config.base_url, "https://example.test/v1"); // trailing slash trimmed
1006        assert_eq!(config.model, "custom/model");
1007    }
1008
1009    #[test]
1010    #[serial]
1011    fn from_env_unknown_provider_errors() {
1012        let _guard = cleared_env(&[("OMNIGRAPH_EMBED_PROVIDER", Some("cohere"))]);
1013        let err = EmbeddingConfig::from_env().unwrap_err();
1014        assert!(err.to_string().contains("unknown embedding provider"));
1015    }
1016
1017    #[test]
1018    #[serial]
1019    fn from_env_errors_when_no_key_present() {
1020        let _guard = cleared_env(&[]);
1021        let err = EmbeddingConfig::from_env().unwrap_err();
1022        assert!(err.to_string().contains("OPENROUTER_API_KEY or OPENAI_API_KEY"));
1023    }
1024
1025    #[test]
1026    #[serial]
1027    fn from_env_mock_flag_wins() {
1028        let _guard = cleared_env(&[
1029            ("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")),
1030            ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
1031        ]);
1032        let config = EmbeddingConfig::from_env().unwrap();
1033        assert_eq!(config.provider, Provider::Mock);
1034    }
1035}