1use std::future::Future;
2use std::time::Duration;
3
4use reqwest::Client;
5use serde::Deserialize;
6use serde_json::{Value, json};
7use tokio::time::sleep;
8
9use crate::error::{OmniError, Result};
10
// --- OpenAI-compatible providers -----------------------------------------------------------

/// Base URL used when no provider alias (or `openai-compatible`) is configured.
const DEFAULT_OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Vendor-prefixed model id used together with [`DEFAULT_OPENROUTER_BASE_URL`].
const DEFAULT_OPENROUTER_MODEL: &str = "openai/text-embedding-3-large";
/// Base URL for the `openai` alias (talks to OpenAI directly).
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Model id used with [`DEFAULT_OPENAI_BASE_URL`].
const DEFAULT_OPENAI_MODEL: &str = "text-embedding-3-large";

// --- Gemini --------------------------------------------------------------------------------

/// Base URL for the `gemini` alias.
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Model id used with [`DEFAULT_GEMINI_BASE_URL`].
const DEFAULT_GEMINI_MODEL: &str = "gemini-embedding-2";
/// Gemini `taskType` sent when embedding a search query.
const GEMINI_QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
/// Gemini `taskType` sent when embedding a document to be indexed.
const GEMINI_DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";

// --- HTTP / retry tuning (each can be overridden from the environment in
// `EmbeddingClient::new`) ------------------------------------------------------------------

/// Per-HTTP-request timeout.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// Overall wall-clock budget for one embedding across all retry attempts.
const DEFAULT_DEADLINE_MS: u64 = 60 * 1_000;
/// Total number of attempts per embedding call.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
/// Base delay between attempts; doubled after each failure.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
23
/// Backend used to compute embeddings.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Provider {
    /// Any endpoint speaking the OpenAI embeddings API: `POST {base_url}/embeddings`
    /// with bearer-token auth (OpenRouter, OpenAI, ...).
    OpenAiCompatible,
    /// Google Gemini: `POST {base_url}/models/{model}:embedContent` authenticated
    /// with the `x-goog-api-key` header.
    Gemini,
    /// Offline, deterministic hash-based vectors; performs no network I/O.
    Mock,
}
39
/// What the text being embedded will be used for.
///
/// Only the Gemini provider forwards this (as `taskType`); the OpenAI-compatible
/// and mock providers ignore it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EmbedRole {
    /// A search query; maps to Gemini `RETRIEVAL_QUERY`.
    Query,
    /// A document being indexed; maps to Gemini `RETRIEVAL_DOCUMENT`.
    Document,
}
49
50#[derive(Clone, Debug)]
56pub struct EmbeddingConfig {
57 pub provider: Provider,
58 pub model: String,
59 pub base_url: String,
60 pub api_key: String,
61}
62
63impl EmbeddingConfig {
64 pub fn from_env() -> Result<Self> {
72 if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
73 return Ok(Self::mock());
74 }
75
76 let alias = env_string("OMNIGRAPH_EMBED_PROVIDER");
77 if alias.as_deref() == Some("mock") {
78 return Ok(Self::mock());
79 }
80
81 let (provider, default_base, default_model, key_envs) = provider_profile(alias.as_deref())?;
82 let base_url = env_string("OMNIGRAPH_EMBED_BASE_URL")
83 .unwrap_or_else(|| default_base.to_string())
84 .trim_end_matches('/')
85 .to_string();
86 let model =
87 env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_else(|| default_model.to_string());
88
89 let api_key = key_envs.iter().copied().find_map(env_string).ok_or_else(|| {
90 OmniError::manifest_internal(format!(
91 "{} is required for the {} embedding provider",
92 key_envs.join(" or "),
93 alias.as_deref().unwrap_or("openai-compatible")
94 ))
95 })?;
96
97 Ok(Self {
98 provider,
99 model,
100 base_url,
101 api_key,
102 })
103 }
104
105 pub fn from_parts(
110 provider: Option<&str>,
111 base_url: Option<String>,
112 model: Option<String>,
113 api_key: String,
114 ) -> Result<Self> {
115 if provider == Some("mock") {
116 let mut config = Self::mock();
122 if let Some(model) = model {
123 config.model = model;
124 }
125 return Ok(config);
126 }
127 let (provider, default_base, default_model, _key_envs) = provider_profile(provider)?;
128 let base_url = base_url
129 .unwrap_or_else(|| default_base.to_string())
130 .trim_end_matches('/')
131 .to_string();
132 let model = model.unwrap_or_else(|| default_model.to_string());
133 Ok(Self {
134 provider,
135 model,
136 base_url,
137 api_key,
138 })
139 }
140
141 fn mock() -> Self {
142 Self {
143 provider: Provider::Mock,
144 model: env_string("OMNIGRAPH_EMBED_MODEL").unwrap_or_default(),
147 base_url: String::new(),
148 api_key: String::new(),
149 }
150 }
151}
152
/// Embedding client with per-request timeout, bounded retries with exponential
/// backoff, and an overall deadline per embedding.
///
/// NOTE(review): `Debug` is derived, so formatting a client also formats `config`.
/// Make sure that does not expose `api_key`.
#[derive(Clone, Debug)]
pub struct EmbeddingClient {
    /// Provider, model, endpoint and credentials.
    config: EmbeddingConfig,
    /// HTTP client; its per-request timeout is configured in `EmbeddingClient::new`.
    http: Client,
    /// Maximum attempts per embedding (`with_retry` treats values below 1 as 1).
    retry_attempts: usize,
    /// Base delay between attempts in ms; doubled after each failure (exponent capped at 10).
    retry_backoff_ms: u64,
    /// Wall-clock budget in ms for one embedding across all attempts; 0 disables it.
    deadline_ms: u64,
}
163
/// Failure of a single provider call, tagged with whether `with_retry` may try again.
struct EmbedCallError {
    /// Human-readable description; becomes the final error message once retrying stops.
    message: String,
    /// `true` for transient failures (e.g. timeouts, connection errors, 5xx, 429).
    retryable: bool,
}
168
/// Successful Gemini `embedContent` response body (only the fields we use).
#[derive(Debug, Deserialize)]
struct GeminiEmbedResponse {
    embedding: GeminiContentEmbedding,
}

/// The `embedding` object of a Gemini `embedContent` response.
#[derive(Debug, Deserialize)]
struct GeminiContentEmbedding {
    /// Raw (not yet normalized) vector components.
    values: Vec<f32>,
}

/// Google API error body: `{"error": {"message": ...}}`.
#[derive(Debug, Deserialize)]
struct GoogleErrorEnvelope {
    error: GoogleErrorBody,
}

/// Inner `error` object of a Google API error body.
#[derive(Debug, Deserialize)]
struct GoogleErrorBody {
    message: String,
}
188
/// Successful OpenAI-style `/embeddings` response body (only the fields we use).
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbeddingDatum>,
}

/// One entry of `data`; `index` is the position of the corresponding input.
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingDatum {
    index: usize,
    /// Raw (not yet normalized) vector components.
    embedding: Vec<f32>,
}

/// OpenAI-style error body: `{"error": {"message": ...}}`.
#[derive(Debug, Deserialize)]
struct OpenAiErrorEnvelope {
    error: OpenAiErrorBody,
}

/// Inner `error` object of an OpenAI-style error body.
#[derive(Debug, Deserialize)]
struct OpenAiErrorBody {
    message: String,
}
209
210impl EmbeddingClient {
211 pub fn from_env() -> Result<Self> {
212 Self::new(EmbeddingConfig::from_env()?)
213 }
214
215 pub fn new(config: EmbeddingConfig) -> Result<Self> {
216 let retry_attempts =
217 parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
218 let retry_backoff_ms =
219 parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
220 let deadline_ms =
221 parse_env_u64_allow_zero("OMNIGRAPH_EMBED_DEADLINE_MS", DEFAULT_DEADLINE_MS);
222 let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
223 let http = Client::builder()
224 .timeout(Duration::from_millis(timeout_ms))
225 .build()
226 .map_err(|e| {
227 OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
228 })?;
229
230 Ok(Self {
231 config,
232 http,
233 retry_attempts,
234 retry_backoff_ms,
235 deadline_ms,
236 })
237 }
238
239 pub fn config(&self) -> &EmbeddingConfig {
240 &self.config
241 }
242
243 #[cfg(test)]
244 fn mock_for_tests() -> Self {
245 Self::new(EmbeddingConfig::mock()).expect("mock client builds")
246 }
247
248 pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
249 self.embed_text(input, expected_dim, EmbedRole::Query).await
250 }
251
252 pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
253 self.embed_text(input, expected_dim, EmbedRole::Document).await
254 }
255
256 async fn embed_text(
257 &self,
258 input: &str,
259 expected_dim: usize,
260 role: EmbedRole,
261 ) -> Result<Vec<f32>> {
262 if expected_dim == 0 {
263 return Err(OmniError::manifest_internal(
264 "embedding dimension must be greater than zero",
265 ));
266 }
267
268 let started = std::time::Instant::now();
269 let result = self
270 .run_with_deadline(self.embed_text_inner(input, expected_dim, role))
271 .await;
272 let elapsed_ms = started.elapsed().as_millis() as u64;
273
274 match &result {
275 Ok(_) => tracing::info!(
276 target: "omnigraph::embedding",
277 provider = ?self.config.provider,
278 model = %self.config.model,
279 dim = expected_dim,
280 elapsed_ms,
281 outcome = "ok",
282 "embedding succeeded"
283 ),
284 Err(err) => tracing::warn!(
285 target: "omnigraph::embedding",
286 provider = ?self.config.provider,
287 model = %self.config.model,
288 dim = expected_dim,
289 elapsed_ms,
290 outcome = "error",
291 error = %err,
292 "embedding failed"
293 ),
294 }
295 result
296 }
297
298 async fn run_with_deadline<F>(&self, fut: F) -> Result<Vec<f32>>
304 where
305 F: Future<Output = Result<Vec<f32>>>,
306 {
307 if self.deadline_ms == 0 {
308 return fut.await;
309 }
310 match tokio::time::timeout(Duration::from_millis(self.deadline_ms), fut).await {
311 Ok(res) => res,
312 Err(_elapsed) => Err(OmniError::manifest_internal(format!(
313 "embedding deadline exceeded after {} ms (provider={:?}, model={})",
314 self.deadline_ms, self.config.provider, self.config.model
315 ))),
316 }
317 }
318
319 async fn embed_text_inner(
320 &self,
321 input: &str,
322 expected_dim: usize,
323 role: EmbedRole,
324 ) -> Result<Vec<f32>> {
325 match self.config.provider {
326 Provider::Mock => Ok(mock_embedding(input, expected_dim)),
327 Provider::Gemini => {
328 self.with_retry(|| self.embed_gemini_once(input, expected_dim, role))
329 .await
330 }
331 Provider::OpenAiCompatible => {
332 self.with_retry(|| self.embed_openai_once(input, expected_dim))
333 .await
334 }
335 }
336 }
337
338 async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
339 where
340 F: FnMut() -> Fut,
341 Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
342 {
343 let max_attempt = self.retry_attempts.max(1);
344 let mut attempt = 0usize;
345 loop {
346 attempt += 1;
347 match operation().await {
348 Ok(value) => return Ok(value),
349 Err(err) => {
350 if !err.retryable || attempt >= max_attempt {
351 return Err(OmniError::manifest_internal(err.message));
352 }
353 tracing::warn!(
354 target: "omnigraph::embedding",
355 provider = ?self.config.provider,
356 model = %self.config.model,
357 attempt,
358 error = %err.message,
359 "embedding attempt failed, retrying"
360 );
361 let shift = (attempt - 1).min(10) as u32;
362 let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
363 sleep(Duration::from_millis(delay)).await;
364 }
365 }
366 }
367 }
368
369 async fn embed_gemini_once(
370 &self,
371 input: &str,
372 expected_dim: usize,
373 role: EmbedRole,
374 ) -> std::result::Result<Vec<f32>, EmbedCallError> {
375 let task_type = match role {
376 EmbedRole::Query => GEMINI_QUERY_TASK_TYPE,
377 EmbedRole::Document => GEMINI_DOCUMENT_TASK_TYPE,
378 };
379
380 let response = self
381 .http
382 .post(gemini_endpoint(&self.config.base_url, &self.config.model))
383 .header("x-goog-api-key", &self.config.api_key)
384 .json(&build_gemini_request(
385 &self.config.model,
386 input,
387 expected_dim,
388 task_type,
389 ))
390 .send()
391 .await;
392 let response = match response {
393 Ok(response) => response,
394 Err(err) => {
395 let retryable = err.is_timeout() || err.is_connect() || err.is_request();
396 return Err(EmbedCallError {
397 message: format!("embedding request failed: {}", err),
398 retryable,
399 });
400 }
401 };
402
403 let status = response.status();
404 let body = match response.text().await {
405 Ok(body) => body,
406 Err(err) => {
407 return Err(EmbedCallError {
408 message: format!("embedding response read failed (status {}): {}", status, err),
409 retryable: status.is_server_error() || status.as_u16() == 429,
410 });
411 }
412 };
413
414 if !status.is_success() {
415 let message = parse_google_error_message(&body).unwrap_or(body);
416 return Err(EmbedCallError {
417 message: format!("embedding request failed with status {}: {}", status, message),
418 retryable: status.is_server_error() || status.as_u16() == 429,
419 });
420 }
421
422 let parsed: GeminiEmbedResponse =
423 serde_json::from_str(&body).map_err(|err| EmbedCallError {
424 message: format!("embedding response decode failed: {}", err),
425 retryable: false,
426 })?;
427
428 validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
429 EmbedCallError {
430 message,
431 retryable: false,
432 }
433 })
434 }
435
436 async fn embed_openai_once(
437 &self,
438 input: &str,
439 expected_dim: usize,
440 ) -> std::result::Result<Vec<f32>, EmbedCallError> {
441 let response = self
442 .http
443 .post(format!("{}/embeddings", self.config.base_url))
444 .bearer_auth(&self.config.api_key)
445 .json(&build_openai_request(&self.config.model, input, expected_dim))
446 .send()
447 .await;
448 let response = match response {
449 Ok(response) => response,
450 Err(err) => {
451 let retryable = err.is_timeout() || err.is_connect() || err.is_request();
452 return Err(EmbedCallError {
453 message: format!("embedding request failed: {}", err),
454 retryable,
455 });
456 }
457 };
458
459 let status = response.status();
460 let body = match response.text().await {
461 Ok(body) => body,
462 Err(err) => {
463 return Err(EmbedCallError {
464 message: format!("embedding response read failed (status {}): {}", status, err),
465 retryable: status.is_server_error() || status.as_u16() == 429,
466 });
467 }
468 };
469
470 if !status.is_success() {
471 let message = parse_openai_error_message(&body).unwrap_or(body);
472 return Err(EmbedCallError {
473 message: format!("embedding request failed with status {}: {}", status, message),
474 retryable: status.is_server_error() || status.as_u16() == 429,
475 });
476 }
477
478 let parsed: OpenAiEmbeddingResponse =
479 serde_json::from_str(&body).map_err(|err| EmbedCallError {
480 message: format!("embedding response decode failed: {}", err),
481 retryable: false,
482 })?;
483
484 let datum = parsed
486 .data
487 .into_iter()
488 .find(|d| d.index == 0)
489 .ok_or_else(|| EmbedCallError {
490 message: "embedding response missing data[0]".to_string(),
491 retryable: false,
492 })?;
493
494 validate_and_normalize_embedding(datum.embedding, expected_dim).map_err(|message| {
495 EmbedCallError {
496 message,
497 retryable: false,
498 }
499 })
500 }
501}
502
/// Returns the Gemini `:embedContent` URL for `model` under `base_url`.
///
/// Trailing slashes on `base_url` are ignored, so `".../v1beta/"` and
/// `".../v1beta"` yield the same endpoint.
fn gemini_endpoint(base_url: &str, model: &str) -> String {
    let root = base_url.trim_end_matches('/');
    format!("{root}/models/{model}:embedContent")
}
510
511fn build_gemini_request(model: &str, input: &str, expected_dim: usize, task_type: &str) -> Value {
512 json!({
513 "model": format!("models/{}", model),
514 "content": {
515 "parts": [
516 {
517 "text": input
518 }
519 ]
520 },
521 "taskType": task_type,
522 "outputDimensionality": expected_dim,
523 })
524}
525
526fn build_openai_request(model: &str, input: &str, expected_dim: usize) -> Value {
527 json!({
528 "model": model,
529 "input": [input],
530 "dimensions": expected_dim,
531 })
532}
533
534fn validate_and_normalize_embedding(
535 values: Vec<f32>,
536 expected_dim: usize,
537) -> std::result::Result<Vec<f32>, String> {
538 if values.len() != expected_dim {
539 return Err(format!(
540 "embedding dimension mismatch: expected {}, got {}",
541 expected_dim,
542 values.len()
543 ));
544 }
545 Ok(normalize_vector(values))
546}
547
/// Scales `values` in place to unit L2 length and returns it.
///
/// The sum of squares is accumulated in `f64` to limit rounding error. Vectors
/// whose norm is not above `f32::EPSILON` (e.g. all zeros) are returned unchanged
/// rather than divided by a near-zero value.
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
    let sum_of_squares: f64 = values
        .iter()
        .map(|&component| {
            let wide = f64::from(component);
            wide * wide
        })
        .sum();
    let norm = sum_of_squares.sqrt() as f32;
    if norm > f32::EPSILON {
        values.iter_mut().for_each(|component| *component /= norm);
    }
    values
}
561
562fn parse_google_error_message(body: &str) -> Option<String> {
563 serde_json::from_str::<GoogleErrorEnvelope>(body)
564 .ok()
565 .map(|e| e.error.message)
566 .filter(|msg| !msg.trim().is_empty())
567}
568
569fn parse_openai_error_message(body: &str) -> Option<String> {
570 serde_json::from_str::<OpenAiErrorEnvelope>(body)
571 .ok()
572 .map(|e| e.error.message)
573 .filter(|msg| !msg.trim().is_empty())
574}
575
576fn provider_profile(
583 alias: Option<&str>,
584) -> Result<(Provider, &'static str, &'static str, &'static [&'static str])> {
585 Ok(match alias {
586 None | Some("openai-compatible") => (
587 Provider::OpenAiCompatible,
588 DEFAULT_OPENROUTER_BASE_URL,
589 DEFAULT_OPENROUTER_MODEL,
590 &["OPENROUTER_API_KEY", "OPENAI_API_KEY"],
591 ),
592 Some("openai") => (
593 Provider::OpenAiCompatible,
594 DEFAULT_OPENAI_BASE_URL,
595 DEFAULT_OPENAI_MODEL,
596 &["OPENAI_API_KEY"],
597 ),
598 Some("gemini") => (
599 Provider::Gemini,
600 DEFAULT_GEMINI_BASE_URL,
601 DEFAULT_GEMINI_MODEL,
602 &["GEMINI_API_KEY"],
603 ),
604 Some(other) => {
605 return Err(OmniError::manifest_internal(format!(
606 "unknown embedding provider '{}' (expected openai-compatible|openai|gemini|mock)",
607 other
608 )));
609 }
610 })
611}
612
/// Reads env var `name`, trimmed.
///
/// Returns `None` when it is unset, not valid Unicode, or blank.
fn env_string(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
619
/// Reads env var `name` as a positive `usize`, falling back to `default`.
///
/// Surrounding whitespace is ignored, consistent with `env_string` and
/// `parse_env_u64_allow_zero`. Unset, unparsable, or zero values yield `default`.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
627
/// Reads env var `name` as a positive `u64`, falling back to `default`.
///
/// Surrounding whitespace is ignored, consistent with `env_string` and
/// `parse_env_u64_allow_zero`. Unset, unparsable, or zero values yield `default`.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
635
/// Reads env var `name` as a `u64`, where `0` is a legitimate value (e.g.
/// "disabled"). Whitespace is trimmed; unset or unparsable values yield `default`.
fn parse_env_u64_allow_zero(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.trim().parse().unwrap_or(default),
        Err(_) => default,
    }
}
644
/// Returns `true` when env var `name` is `1`, `true`, `yes` or `on`
/// (case-insensitive, surrounding whitespace ignored). Any other value, or an
/// unset var, is `false`.
fn env_flag(name: &str) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    matches!(normalized.as_str(), "1" | "true" | "yes" | "on")
}
654
655fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
656 let mut seed = fnv1a64(input.as_bytes());
657 let mut out = Vec::with_capacity(dim);
658 for _ in 0..dim {
659 seed = xorshift64(seed);
660 let ratio = (seed as f64 / u64::MAX as f64) as f32;
661 out.push((ratio * 2.0) - 1.0);
662 }
663 normalize_vector(out)
664}
665
/// 64-bit FNV-1a hash of `bytes` (XOR each byte, then multiply by the FNV prime).
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
674
/// One step of Marsaglia's xorshift64 generator with shifts (13, 7, 17).
/// Zero is a fixed point.
fn xorshift64(x: u64) -> u64 {
    let x = x ^ (x << 13);
    let x = x ^ (x >> 7);
    x ^ (x << 17)
}
681
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serial_test::serial;

    use super::*;

    /// Sets or removes environment variables for the guard's lifetime and restores
    /// the previous values on drop.
    ///
    /// The process environment is global, so every test in this module that reads
    /// or writes the embedding env vars must be `#[serial]`.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvGuard {
        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|(name, _)| (*name, std::env::var(name).ok()))
                .collect::<Vec<_>>();
            for (name, value) in vars {
                // SAFETY: every test in this module that reads or writes the environment
                // is `#[serial]`, so no other test here touches it concurrently.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                // SAFETY: see `EnvGuard::set`; the guard only lives inside `#[serial]` tests.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    /// Variables that influence `EmbeddingConfig::from_env`; cleared before each config test.
    const EMBED_ENV: &[&str] = &[
        "OMNIGRAPH_EMBEDDINGS_MOCK",
        "OMNIGRAPH_EMBED_PROVIDER",
        "OMNIGRAPH_EMBED_BASE_URL",
        "OMNIGRAPH_EMBED_MODEL",
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "GEMINI_API_KEY",
    ];

    /// Clears all of `EMBED_ENV`, then applies `extra` (later entries win).
    fn cleared_env(extra: &[(&'static str, Option<&str>)]) -> EnvGuard {
        let mut vars: Vec<(&'static str, Option<&str>)> =
            EMBED_ENV.iter().map(|n| (*n, None)).collect();
        vars.extend_from_slice(extra);
        EnvGuard::set(&vars)
    }

    // Tests that build a client through `mock_for_tests` read env vars
    // (`OMNIGRAPH_EMBED_MODEL`, retry/timeout knobs) that the config tests below
    // mutate, so they are `#[serial]` as well.

    #[tokio::test]
    #[serial]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let a = client.embed_query_text("alpha", 8).await.unwrap();
        let b = client.embed_query_text("alpha", 8).await.unwrap();
        let c = client.embed_query_text("beta", 8).await.unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.len(), 8);
    }

    #[test]
    fn gemini_request_uses_model_retrieval_query_and_dimension() {
        let request =
            build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_QUERY_TASK_TYPE);
        assert_eq!(request["model"], "models/gemini-embedding-2");
        assert_eq!(request["taskType"], GEMINI_QUERY_TASK_TYPE);
        assert_eq!(request["outputDimensionality"], 4);
        assert_eq!(request["content"]["parts"][0]["text"], "alpha");
    }

    #[test]
    fn gemini_document_request_uses_retrieval_document_task_type() {
        let request =
            build_gemini_request("gemini-embedding-2", "alpha", 4, GEMINI_DOCUMENT_TASK_TYPE);
        assert_eq!(request["taskType"], GEMINI_DOCUMENT_TASK_TYPE);
    }

    #[test]
    fn openai_request_uses_model_input_array_and_dimensions() {
        let request = build_openai_request("openai/text-embedding-3-large", "alpha", 4);
        assert_eq!(request["model"], "openai/text-embedding-3-large");
        assert_eq!(request["input"][0], "alpha");
        assert!(request["input"].is_array());
        assert_eq!(request["dimensions"], 4);
        assert!(request.get("taskType").is_none());
    }

    #[test]
    fn validate_and_normalize_embedding_enforces_dimension() {
        let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
        assert!(err.contains("expected 3, got 2"));
    }

    #[tokio::test]
    #[serial]
    async fn with_retry_retries_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_for_call = Arc::clone(&attempts);

        let value = client
            .with_retry(|| {
                let attempts_for_call = Arc::clone(&attempts_for_call);
                async move {
                    let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
                    if attempt == 0 {
                        Err(EmbedCallError {
                            message: "retry me".to_string(),
                            retryable: true,
                        })
                    } else {
                        Ok("ok")
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(value, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    #[serial]
    async fn with_retry_stops_on_non_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let err = client
            .with_retry(|| async {
                Err::<(), _>(EmbedCallError {
                    message: "do not retry".to_string(),
                    retryable: false,
                })
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("do not retry"));
    }

    #[tokio::test]
    #[serial]
    async fn run_with_deadline_aborts_slow_future() {
        let mut client = EmbeddingClient::mock_for_tests();
        client.deadline_ms = 20;
        let slow = async {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(vec![0.0_f32])
        };
        let err = client.run_with_deadline(slow).await.unwrap_err();
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[tokio::test]
    #[serial]
    async fn run_with_deadline_passes_through_fast_future() {
        let client = EmbeddingClient::mock_for_tests();
        let ok = client
            .run_with_deadline(async { Ok(vec![1.0_f32, 2.0]) })
            .await
            .unwrap();
        assert_eq!(ok, vec![1.0, 2.0]);
    }

    #[tokio::test]
    #[serial]
    async fn run_with_deadline_zero_is_unbounded() {
        let mut client = EmbeddingClient::mock_for_tests();
        client.deadline_ms = 0;
        let ok = client
            .run_with_deadline(async { Ok(vec![3.0_f32]) })
            .await
            .unwrap();
        assert_eq!(ok, vec![3.0]);
    }

    #[test]
    #[serial]
    fn from_env_defaults_to_openai_compatible_openrouter() {
        let _guard = cleared_env(&[("OPENROUTER_API_KEY", Some("sk-test"))]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.provider, Provider::OpenAiCompatible);
        assert_eq!(config.base_url, DEFAULT_OPENROUTER_BASE_URL);
        assert_eq!(config.model, DEFAULT_OPENROUTER_MODEL);
        assert_eq!(config.api_key, "sk-test");
    }

    #[test]
    #[serial]
    fn from_env_openai_alias_uses_openai_host_not_openrouter() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
            ("OPENAI_API_KEY", Some("k")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.provider, Provider::OpenAiCompatible);
        assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL);
        assert_eq!(config.model, DEFAULT_OPENAI_MODEL);
        assert_eq!(config.api_key, "k");
    }

    #[test]
    #[serial]
    fn from_env_openai_alias_prefers_openai_key_over_openrouter() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
            ("OPENROUTER_API_KEY", Some("router")),
            ("OPENAI_API_KEY", Some("openai")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.base_url, DEFAULT_OPENAI_BASE_URL);
        assert_eq!(config.api_key, "openai");
    }

    #[test]
    #[serial]
    fn from_env_openai_alias_errors_when_only_openrouter_key_is_set() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai")),
            ("OPENROUTER_API_KEY", Some("router")),
        ]);
        let err = EmbeddingConfig::from_env().unwrap_err();
        assert!(err.to_string().contains("OPENAI_API_KEY"), "got: {err}");
    }

    #[test]
    fn from_parts_applies_provider_defaults_and_overrides() {
        let openrouter = EmbeddingConfig::from_parts(None, None, None, "k".to_string()).unwrap();
        assert_eq!(openrouter.provider, Provider::OpenAiCompatible);
        assert_eq!(openrouter.base_url, DEFAULT_OPENROUTER_BASE_URL);
        assert_eq!(openrouter.model, DEFAULT_OPENROUTER_MODEL);
        assert_eq!(openrouter.api_key, "k");

        let gemini =
            EmbeddingConfig::from_parts(Some("gemini"), None, None, "g".to_string()).unwrap();
        assert_eq!(gemini.provider, Provider::Gemini);
        assert_eq!(gemini.base_url, DEFAULT_GEMINI_BASE_URL);

        let overridden = EmbeddingConfig::from_parts(
            Some("openai"),
            Some("https://x/v1/".to_string()),
            Some("custom".to_string()),
            "k".to_string(),
        )
        .unwrap();
        assert_eq!(overridden.base_url, "https://x/v1");
        assert_eq!(overridden.model, "custom");

        let err =
            EmbeddingConfig::from_parts(Some("cohere"), None, None, "k".to_string()).unwrap_err();
        assert!(
            err.to_string().contains("unknown embedding provider"),
            "got: {err}"
        );
    }

    #[test]
    #[serial]
    fn from_parts_mock_honors_an_explicit_model() {
        let _guard = cleared_env(&[]);
        let pinned = EmbeddingConfig::from_parts(
            Some("mock"),
            None,
            Some("recorded-x".to_string()),
            String::new(),
        )
        .unwrap();
        assert_eq!(pinned.provider, Provider::Mock);
        assert_eq!(pinned.model, "recorded-x");
        let bare = EmbeddingConfig::from_parts(Some("mock"), None, None, String::new()).unwrap();
        assert_eq!(bare.provider, Provider::Mock);
        assert_eq!(bare.model, "");
    }

    #[test]
    #[serial]
    fn from_env_openai_compatible_prefers_openrouter_key() {
        let _guard = cleared_env(&[
            ("OPENROUTER_API_KEY", Some("router")),
            ("OPENAI_API_KEY", Some("openai")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.api_key, "router");
    }

    #[test]
    #[serial]
    fn from_env_explicit_gemini_provider() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
            ("GEMINI_API_KEY", Some("g-key")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.provider, Provider::Gemini);
        assert_eq!(config.base_url, DEFAULT_GEMINI_BASE_URL);
        assert_eq!(config.model, DEFAULT_GEMINI_MODEL);
        assert_eq!(config.api_key, "g-key");
    }

    #[test]
    #[serial]
    fn from_env_base_url_and_model_overrides_apply() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBED_PROVIDER", Some("openai-compatible")),
            ("OMNIGRAPH_EMBED_BASE_URL", Some("https://example.test/v1/")),
            ("OMNIGRAPH_EMBED_MODEL", Some("custom/model")),
            ("OPENAI_API_KEY", Some("k")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.base_url, "https://example.test/v1");
        assert_eq!(config.model, "custom/model");
    }

    #[test]
    #[serial]
    fn from_env_unknown_provider_errors() {
        let _guard = cleared_env(&[("OMNIGRAPH_EMBED_PROVIDER", Some("cohere"))]);
        let err = EmbeddingConfig::from_env().unwrap_err();
        assert!(err.to_string().contains("unknown embedding provider"));
    }

    #[test]
    #[serial]
    fn from_env_errors_when_no_key_present() {
        let _guard = cleared_env(&[]);
        let err = EmbeddingConfig::from_env().unwrap_err();
        assert!(err.to_string().contains("OPENROUTER_API_KEY or OPENAI_API_KEY"));
    }

    #[test]
    #[serial]
    fn from_env_mock_flag_wins() {
        let _guard = cleared_env(&[
            ("OMNIGRAPH_EMBEDDINGS_MOCK", Some("1")),
            ("OMNIGRAPH_EMBED_PROVIDER", Some("gemini")),
        ]);
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.provider, Provider::Mock);
    }
}