1use async_trait::async_trait;
2use sha1::{Digest, Sha1};
3use std::time::Duration;
4use unicode_segmentation::UnicodeSegmentation;
5
6use crate::crypto;
7use crate::errors::CoreError;
8
9mod cloud;
10mod openai;
11mod sha1_embedder;
12
13pub use cloud::CloudEmbedder;
17pub use openai::OpenAICompatEmbedder;
18pub use sha1_embedder::Sha1Embedder;
19
/// Width of the local SHA1 bag-of-words vectors produced by `embed_text`,
/// and the dimension reported for `ActiveEmbedderKind::Sha1`.
pub const EMBEDDING_DIM: usize = 128;

/// Stored value of `embedding_provider_url` meaning "no BYOK URL, use the
/// managed cloud path"; provider selection treats it exactly like a blank URL.
pub const CLOUD_MANAGED_SENTINEL: &str = "cloud-managed";

/// Vector width assumed for a BYOK provider whose `embedding_dim` is unset,
/// and the width reported for the cloud embedder.
pub const DEFAULT_OPENAI_EMBEDDING_DIM: usize = 1_536;
/// Per-request timeout configured on the client built by `embedding_http_client`.
pub(crate) const EMBEDDING_PROVIDER_TIMEOUT: Duration = Duration::from_millis(45_000);
/// Back-off delays, in milliseconds, between retries of transient provider
/// failures. Not referenced in this file; presumably consumed by the `cloud` /
/// `openai` submodules' retry loops — confirm there.
const EMBEDDING_RETRY_DELAYS_MS: &[u64] = &[100, 300, 700];
/// Maximum number of texts handed to one `Embedder::embed_batch` call by
/// `embed_texts_with_embedder_and_timeout`.
pub const EMBEDDING_BATCH_SIZE: usize = 64;
38
39#[allow(clippy::panic)]
40fn embedding_http_client() -> reqwest::Client {
42 reqwest::Client::builder()
43 .timeout(EMBEDDING_PROVIDER_TIMEOUT)
44 .build()
45 .unwrap_or_else(|e| {
46 panic!("failed to build embedding HTTP client with provider timeout: {e}")
47 })
48}
49
50#[async_trait]
55pub trait Embedder: Send + Sync {
56 async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError>;
57
58 async fn embed_batch(
59 &self,
60 texts: &[String],
61 _rule_ids: Option<&[String]>,
62 ) -> Result<Vec<Vec<f32>>, CoreError> {
63 let mut vectors = Vec::with_capacity(texts.len());
64 for text in texts {
65 vectors.push(self.embed(text).await?);
66 }
67 Ok(vectors)
68 }
69
70 fn dim(&self) -> usize;
71
72 fn is_semantic(&self) -> bool {
77 true
78 }
79}
80
81pub fn store_embedding_key(api_key: &str) -> Result<String, CoreError> {
89 crypto::encrypt_secret(api_key)
90 .map_err(|e| CoreError::Internal(format!("failed to encrypt embedding key: {e}")))
91}
92
93pub fn load_embedding_key(storage_key: &str) -> Result<String, CoreError> {
99 crypto::decrypt_secret(storage_key)
100 .map_err(|e| CoreError::Internal(format!("failed to decrypt embedding key: {e}")))
101}
102
103fn retryable_embedding_status(status: reqwest::StatusCode) -> bool {
104 status == reqwest::StatusCode::REQUEST_TIMEOUT
105 || status == reqwest::StatusCode::BAD_GATEWAY
106 || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
107 || status == reqwest::StatusCode::GATEWAY_TIMEOUT
108 || status.is_server_error()
109}
110
111pub async fn get_embedder() -> Box<dyn Embedder> {
131 if let Ok(settings) = crate::settings::get().await {
135 let ce = &settings.context_engine;
136 let byok_url = ce
137 .embedding_provider_url
138 .as_ref()
139 .map(|u| u.trim().to_owned())
140 .filter(|u| !u.is_empty() && u != CLOUD_MANAGED_SENTINEL);
141 if ce.semantic_embedding
142 && let Some(url) = byok_url
143 {
144 let key = match ce.embedding_provider_key.as_ref() {
150 Some(storage_key) if !storage_key.trim().is_empty() => {
151 if let Ok(plain) = load_embedding_key(storage_key) {
152 Some(plain)
153 } else {
154 eprintln!(
155 "[embedder] failed to decrypt BYOK key; falling back to cloud/SHA1"
156 );
157 None
158 }
159 }
160 _ => Some(String::new()),
162 };
163 if let Some(key) = key {
164 let model = ce
165 .embedding_model
166 .clone()
167 .unwrap_or_else(|| "text-embedding-3-small".to_owned());
168 let dim = ce.embedding_dim.unwrap_or(DEFAULT_OPENAI_EMBEDDING_DIM);
169 return Box::new(OpenAICompatEmbedder::new(url, key, model, dim));
170 }
171 }
172 }
173
174 if let Some(token) = crate::cloud::client::CloudClient::load_token().await {
180 let base = crate::cloud::endpoints::api_base();
181 return Box::new(CloudEmbedder::new(base, token));
182 }
183
184 Box::new(Sha1Embedder::new())
186}
187
188#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum ActiveEmbedderKind {
193 Cloud {
194 model: String,
195 dim: usize,
196 },
197 Byok {
198 provider_host: String,
199 model: String,
200 dim: usize,
201 },
202 Sha1,
203}
204
205impl ActiveEmbedderKind {
206 pub const fn dim(&self) -> usize {
207 match self {
208 Self::Cloud { dim, .. } | Self::Byok { dim, .. } => *dim,
209 Self::Sha1 => EMBEDDING_DIM,
210 }
211 }
212
213 pub fn profile(&self) -> String {
214 match self {
215 Self::Cloud { model, dim } => format!("cloud:{model}:{dim}"),
216 Self::Byok {
217 provider_host,
218 model,
219 dim,
220 } => format!("byok:{provider_host}:{model}:{dim}"),
221 Self::Sha1 => format!("sha1:local:{EMBEDDING_DIM}"),
222 }
223 }
224}
225
226fn byok_from_settings(
239 ce: Option<&crate::models::ContextEngineRecord>,
240) -> Option<ActiveEmbedderKind> {
241 let ce = ce?;
242 if !ce.semantic_embedding {
243 return None;
244 }
245 let url = ce
246 .embedding_provider_url
247 .as_ref()
248 .map(|u| u.trim())
249 .filter(|u| !u.is_empty() && *u != CLOUD_MANAGED_SENTINEL)?;
250 let key_usable = match ce.embedding_provider_key.as_ref() {
255 Some(storage_key) if !storage_key.trim().is_empty() => {
256 load_embedding_key(storage_key).is_ok()
257 }
258 _ => true,
259 };
260 if !key_usable {
261 return None;
262 }
263 let host = url_host(url).map_or_else(|| "byok".to_owned(), str::to_owned);
264 let model = ce
265 .embedding_model
266 .clone()
267 .unwrap_or_else(|| "text-embedding-3-small".to_owned());
268 let dim = ce.embedding_dim.unwrap_or(DEFAULT_OPENAI_EMBEDDING_DIM);
269 Some(ActiveEmbedderKind::Byok {
270 provider_host: host,
271 model,
272 dim,
273 })
274}
275
276pub async fn probe_active_embedder() -> ActiveEmbedderKind {
281 let settings = crate::settings::get().await.ok();
282 if let Some(byok) = byok_from_settings(settings.as_ref().map(|s| &s.context_engine)) {
283 return byok;
284 }
285 if crate::cloud::client::CloudClient::load_token_quiet()
289 .await
290 .is_some()
291 {
292 return ActiveEmbedderKind::Cloud {
293 model: "text-embedding-3-small".to_owned(),
294 dim: DEFAULT_OPENAI_EMBEDDING_DIM,
295 };
296 }
297 ActiveEmbedderKind::Sha1
298}
299
300pub fn probe_active_embedder_sync() -> ActiveEmbedderKind {
308 std::thread::scope(|scope| {
309 scope
310 .spawn(|| {
311 match tokio::runtime::Builder::new_current_thread()
312 .enable_all()
313 .build()
314 {
315 Ok(rt) => rt.block_on(probe_active_embedder()),
316 Err(_) => ActiveEmbedderKind::Sha1,
317 }
318 })
319 .join()
320 .unwrap_or(ActiveEmbedderKind::Sha1)
321 })
322}
323
324pub async fn active_embedding_profile() -> String {
325 probe_active_embedder().await.profile()
326}
327
/// Extracts the host (plus port, if present) from a URL-like string.
///
/// - An optional `scheme://` prefix is skipped.
/// - The authority ends at the first `/`, `?` or `#`.
/// - Any `user[:password]@` userinfo is dropped. The host feeds the BYOK
///   embedding profile string, and credentials embedded in a provider URL
///   must not end up there.
///
/// Returns `None` when no host remains.
fn url_host(s: &str) -> Option<&str> {
    let after_scheme = s.split_once("://").map_or(s, |(_, rest)| rest);
    // `split` always yields at least one item, so the fallback never fires.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);
    // Userinfo ends at the last '@' of the authority.
    let host = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    (!host.is_empty()).then_some(host)
}
334
335pub fn embed_text(text: &str) -> Vec<f32> {
339 let mut vec = vec![0.0f32; EMBEDDING_DIM];
340 for word in text.unicode_words() {
341 let mut hasher = Sha1::new();
342 hasher.update(word.to_lowercase().as_bytes());
343 let hash = hasher.finalize();
344 for (i, byte) in hash.iter().enumerate() {
345 let dim = i % EMBEDDING_DIM;
346 vec[dim] += if byte & 1 == 0 { 1.0 } else { -1.0 };
347 }
348 }
349 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
350 if norm > 0.0 {
351 for x in &mut vec {
352 *x /= norm;
353 }
354 }
355 vec
356}
357
/// An embedding vector, tagged with whether it came from a semantic embedder.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddedText {
    /// The embedding components.
    pub vector: Vec<f32>,
    /// The embedder's `is_semantic()` value when the provider succeeded;
    /// `false` for the local SHA1 fallback (`sha1_fallback_embedding`).
    pub semantic: bool,
}
363
364pub async fn embed_text_async(text: &str) -> Vec<f32> {
372 embed_text_async_with_timeout(text, None).await.vector
373}
374
375pub async fn embed_text_async_with_timeout(text: &str, timeout: Option<Duration>) -> EmbeddedText {
382 let texts = vec![text.to_owned()];
383 embed_texts_async_with_timeout(&texts, None, timeout)
384 .await
385 .into_iter()
386 .next()
387 .unwrap_or_else(|| sha1_fallback_embedding(text))
388}
389
390pub async fn embed_texts_async_with_timeout(
391 texts: &[String],
392 rule_ids: Option<&[String]>,
393 timeout: Option<Duration>,
394) -> Vec<EmbeddedText> {
395 if texts.is_empty() {
396 return Vec::new();
397 }
398 let embedder = get_embedder().await;
399 embed_texts_with_embedder_and_timeout(embedder.as_ref(), texts, rule_ids, timeout).await
400}
401
402async fn embed_texts_with_embedder_and_timeout(
403 embedder: &dyn Embedder,
404 texts: &[String],
405 rule_ids: Option<&[String]>,
406 timeout: Option<Duration>,
407) -> Vec<EmbeddedText> {
408 let semantic = embedder.is_semantic();
409 let mut embedded = Vec::with_capacity(texts.len());
410 for (chunk_index, text_chunk) in texts.chunks(EMBEDDING_BATCH_SIZE).enumerate() {
411 let start = chunk_index * EMBEDDING_BATCH_SIZE;
412 let end = start + text_chunk.len();
413 let rule_id_chunk = rule_ids.and_then(|ids| ids.get(start..end));
414 let embed_fut = embedder.embed_batch(text_chunk, rule_id_chunk);
415 let result = match timeout {
416 Some(timeout) => match tokio::time::timeout(timeout, embed_fut).await {
417 Ok(result) => result,
418 Err(_) => Err(CoreError::Internal(format!(
419 "embedding provider timed out after {}ms",
420 timeout.as_millis()
421 ))),
422 },
423 None => embed_fut.await,
424 };
425
426 match result {
427 Ok(vectors)
428 if vectors.len() == text_chunk.len()
429 && vectors.iter().all(|vector| !vector.is_empty()) =>
430 {
431 embedded.extend(
432 vectors
433 .into_iter()
434 .map(|vector| EmbeddedText { vector, semantic }),
435 );
436 }
437 Ok(_) => {
438 warn_embedding_fallback_once("provider returned empty or mismatched vector batch");
439 if timeout.is_some() {
440 embedded.extend(
441 texts[start..]
442 .iter()
443 .map(|text| sha1_fallback_embedding(text)),
444 );
445 break;
446 }
447 embedded.extend(text_chunk.iter().map(|text| sha1_fallback_embedding(text)));
448 }
449 Err(e) => {
450 warn_embedding_fallback_once(&format!("provider failed ({e})"));
451 if timeout.is_some() {
452 embedded.extend(
453 texts[start..]
454 .iter()
455 .map(|text| sha1_fallback_embedding(text)),
456 );
457 break;
458 }
459 embedded.extend(text_chunk.iter().map(|text| sha1_fallback_embedding(text)));
460 }
461 }
462 }
463 embedded
464}
465
466fn sha1_fallback_embedding(text: &str) -> EmbeddedText {
467 EmbeddedText {
468 vector: embed_text(text),
469 semantic: false,
470 }
471}
472
473fn warn_embedding_fallback_once(reason: &str) {
490 use std::collections::HashSet;
491 use std::sync::Mutex;
492 static SEEN: Mutex<Option<HashSet<String>>> = Mutex::new(None);
493 let key = classify_reason(reason);
494 crate::activity_stream::record(crate::activity_stream::ActivityPayload::EmbeddingFallback {
495 reason: key.clone(),
496 });
497 let Ok(mut guard) = SEEN.lock() else {
498 return; };
500 let set = guard.get_or_insert_with(HashSet::new);
501 if !set.insert(key.clone()) {
502 return; }
504 eprintln!("[embedding] {}", calm_fallback_summary(&key));
505 eprintln!("{}", actionable_fix_for(&key));
506}
507
/// Maps a reason key (as produced by `classify_reason`) to a short,
/// non-alarming description of the degraded state. Unknown keys get a generic
/// "cloud embedding unavailable" message.
fn calm_fallback_summary(key: &str) -> &'static str {
    // Every summary shares the same framing; only the parenthesised cause differs.
    macro_rules! paused {
        ($cause:literal) => {
            concat!(
                "semantic vectors paused (",
                $cause,
                "); recall continues with file-pattern + keyword matching"
            )
        };
    }

    match key {
        "scope" | "forbidden" | "unauthorized" => paused!("cloud sign-in needs refresh"),
        "cap" => paused!("cloud embedding cap reached"),
        "timeout" | "network" => paused!("cloud unreachable"),
        "empty" => paused!("provider returned no vector"),
        _ => paused!("cloud embedding unavailable"),
    }
}
539
/// Collapses a free-form fallback reason into a stable key. The key is used to
/// de-duplicate warnings and to pick user guidance.
///
/// Keys are checked in this order: `scope`, `cap`, `forbidden`,
/// `unauthorized`, `timeout`, `network`, `empty`. Anything else is `other`.
/// Matching is ASCII case-insensitive.
///
/// - HTTP status codes `403` / `401` only count when they stand alone as a
///   number. Durations or ports such as `1403ms` or `:4010` are therefore not
///   mistaken for auth failures.
/// - `empty` covers both a provider error mentioning an "empty vector" and
///   this module's own "provider returned empty or mismatched vector batch"
///   warning. That warning previously never matched and fell through to
///   `other`.
fn classify_reason(reason: &str) -> String {
    /// True when `code` occurs in `haystack` with no ASCII digit directly
    /// before or after it.
    fn has_status_code(haystack: &str, code: &str) -> bool {
        haystack.match_indices(code).any(|(at, _)| {
            let before = haystack[..at].chars().next_back();
            let after = haystack[at + code.len()..].chars().next();
            !before.is_some_and(|c| c.is_ascii_digit())
                && !after.is_some_and(|c| c.is_ascii_digit())
        })
    }

    let lower = reason.to_ascii_lowercase();
    let key = if lower.contains("missing required scope") {
        "scope"
    } else if lower.contains("embed cap")
        || lower.contains("embedding cap reached")
        || lower.contains("embed_cap_reached")
    {
        "cap"
    } else if has_status_code(&lower, "403") || lower.contains("forbidden") {
        "forbidden"
    } else if has_status_code(&lower, "401") || lower.contains("unauthorized") {
        "unauthorized"
    } else if lower.contains("timeout") || lower.contains("timed out") {
        "timeout"
    } else if lower.contains("connect") || lower.contains("dns") {
        "network"
    } else if lower.contains("empty vector") || lower.contains("empty or mismatched vector") {
        "empty"
    } else {
        "other"
    };
    key.to_owned()
}
572
/// Maps a reason key (as produced by `classify_reason`) to a one-line,
/// actionable next step for the user, prefixed with `[embedding] -> `.
/// Unknown keys point at `difflore doctor` / BYOK setup.
fn actionable_fix_for(key: &str) -> &'static str {
    // All guidance lines share the same prefix; the parts are joined verbatim.
    macro_rules! fix {
        ($($part:literal),+ $(,)?) => {
            concat!("[embedding] -> ", $($part),+)
        };
    }

    match key {
        "scope" => fix!(
            "your cloud token is missing the embedding scope. ",
            "Re-run `difflore cloud login` to refresh, ",
            "or `difflore embeddings setup` to bring your own key.",
        ),
        "forbidden" => fix!(
            "cloud rejected the embed request. ",
            "Re-run `difflore cloud login` to refresh credentials.",
        ),
        "unauthorized" => fix!("cloud token expired. Run `difflore cloud login`."),
        "cap" => fix!(
            "cloud embedding cap reached. Recall stays usable via local SHA1 + FTS; ",
            "upgrade for unlimited managed embedding, ",
            "or run `difflore embeddings setup` for BYOK.",
        ),
        "timeout" | "network" => fix!(
            "cloud unreachable. Recall stays usable via local SHA1 + FTS; ",
            "retry when network recovers, or run `difflore embeddings setup` ",
            "for an offline BYOK key.",
        ),
        "empty" => fix!(
            "provider returned no vector. ",
            "Run `difflore doctor` to inspect the active embedder.",
        ),
        _ => fix!(
            "run `difflore doctor` for diagnostics, ",
            "or `difflore embeddings setup` to switch to BYOK.",
        ),
    }
}
607
/// Cosine similarity between `a` and `b` (nominally in `[-1.0, 1.0]`).
///
/// Returns `0.0` in these cases:
/// - either vector has zero magnitude, including empty slices;
/// - the slices differ in length.
///
/// Vectors of different lengths come from different embedders (e.g. the
/// 128-d SHA1 fallback vs 1536-d provider vectors) and are not comparable.
/// Zipping them would take the dot product over the shorter prefix while
/// normalising over the full lengths, which yields a meaningless score.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // One pass accumulates the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, sa, sb), (&x, &y)| {
            (dot + x * y, sa + x * x, sb + y * y)
        });
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
624
// Unit tests for the embedding module: the SHA1 embedder, cosine similarity,
// fallback/timeout behaviour, and the OpenAI-compatible and cloud embedders.
// Provider tests talk to tiny in-process TCP mocks (`spawn_mock*` below).
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    #[test]
    fn embed_text_produces_fixed_dim_vector() {
        let vec = embed_text("hello world");
        assert_eq!(vec.len(), EMBEDDING_DIM);
    }

    #[test]
    fn embed_text_is_unit_normalized() {
        let vec = embed_text("let x = 42;");
        let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "expected unit-norm, got {norm}");
    }

    #[test]
    fn embed_empty_text_returns_zero_vector() {
        let vec = embed_text("");
        assert_eq!(vec.len(), EMBEDDING_DIM);
        assert!(vec.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn cosine_similarity_identical_vectors_is_one() {
        let a = embed_text("fn main() {}");
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-4);
    }

    #[test]
    fn cosine_similarity_orthogonal_zero_vectors_is_zero() {
        // Zero-magnitude inputs must not divide by zero; the function returns 0.0.
        let a = vec![0.0; EMBEDDING_DIM];
        let b = vec![0.0; EMBEDDING_DIM];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_is_scale_invariant() {
        let a = [3.0_f32, 4.0];
        let b = [6.0_f32, 8.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = [0.0_f32, 5.0];
        let d = [7.0_f32, 0.0];
        assert!(cosine_similarity(&c, &d).abs() < 1e-6);
    }

    // NOTE(review): despite its name, this test calls `sha1_fallback_embedding`
    // directly; it does not exercise any provider retry path.
    #[test]
    fn provider_failure_fallback_uses_sha1_after_retry() {
        let fallback = sha1_fallback_embedding("hello world");
        assert_eq!(
            fallback.vector,
            embed_text("hello world"),
            "provider failures should fall back to local SHA1 only after retry"
        );
        assert!(
            !fallback.semantic,
            "provider failure fallback is local lexical hash, not semantic"
        );
    }

    #[test]
    fn provider_failure_warning_marks_sha1_as_fallback() {
        let message = actionable_fix_for("network");
        assert!(
            message.contains("local SHA1 + FTS"),
            "network fallback should name the degraded local path: {message}"
        );
        assert!(
            message.contains("retry when network recovers"),
            "provider failure guidance should prefer cloud recovery: {message}"
        );
    }

    #[tokio::test]
    async fn sha1_embedder_matches_embed_text() {
        let embedder = Sha1Embedder::new();
        assert_eq!(embedder.dim(), EMBEDDING_DIM);
        let out = embedder.embed("hello world").await.expect("sha1 embed");
        let expected = embed_text("hello world");
        assert_eq!(out.len(), EMBEDDING_DIM);
        assert_eq!(out, expected);
    }

    #[tokio::test]
    async fn sha1_embedder_is_deterministic_128d() {
        let embedder = Sha1Embedder::new();
        let a = embedder.embed("fn main() {}").await.unwrap();
        let b = embedder.embed("fn main() {}").await.unwrap();
        assert_eq!(a.len(), 128);
        assert_eq!(a, b);
    }

    /// Test embedder whose `embed_batch` counts calls and sleeps 50ms before
    /// answering, so any shorter timeout is guaranteed to fire.
    struct SlowBatchEmbedder {
        calls: std::sync::atomic::AtomicUsize,
    }

    #[async_trait::async_trait]
    impl Embedder for SlowBatchEmbedder {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, CoreError> {
            unreachable!("test calls embed_batch directly")
        }

        async fn embed_batch(
            &self,
            texts: &[String],
            _rule_ids: Option<&[String]>,
        ) -> Result<Vec<Vec<f32>>, CoreError> {
            self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(texts.iter().map(|_| vec![1.0]).collect())
        }

        fn dim(&self) -> usize {
            1
        }
    }

    // 0..=3*EMBEDDING_BATCH_SIZE gives 3*EMBEDDING_BATCH_SIZE + 1 texts, i.e.
    // four batches. The 5ms timeout is shorter than the 50ms sleep, so the first
    // batch times out and every text must fall back to SHA1 after one call.
    #[tokio::test]
    async fn timed_batch_embedding_falls_back_for_remaining_batches_after_first_timeout() {
        let embedder = SlowBatchEmbedder {
            calls: std::sync::atomic::AtomicUsize::new(0),
        };
        let texts = (0..=(EMBEDDING_BATCH_SIZE * 3))
            .map(|i| format!("rule body {i}"))
            .collect::<Vec<_>>();

        let embedded = embed_texts_with_embedder_and_timeout(
            &embedder,
            &texts,
            None,
            Some(Duration::from_millis(5)),
        )
        .await;

        assert_eq!(embedded.len(), texts.len());
        assert_eq!(
            embedder.calls.load(std::sync::atomic::Ordering::SeqCst),
            1,
            "latency-sensitive batch calls should not wait once per provider batch"
        );
        for (embedded, text) in embedded.iter().zip(&texts) {
            assert!(!embedded.semantic);
            assert_eq!(embedded.vector, embed_text(text));
        }
    }

    #[test]
    fn openai_embedder_endpoint_handles_url_variants() {
        let cases = &[
            (
                "https://api.openai.com/v1",
                "https://api.openai.com/v1/embeddings",
            ),
            (
                "https://api.example.com/v1/",
                "https://api.example.com/v1/embeddings",
            ),
            (
                "https://api.example.com/v1/embeddings",
                "https://api.example.com/v1/embeddings",
            ),
        ];
        for (base, expected) in cases {
            let e = OpenAICompatEmbedder::new((*base).into(), "k".into(), "m".into(), 128);
            assert_eq!(e.endpoint(), *expected, "base: {base}");
        }
    }

    /// Builds a raw HTTP/1.1 200 response whose JSON body holds a single
    /// `embedding` (no `index` field). The string is leaked to get `'static`.
    fn openai_embedding_response(values: &[f32]) -> &'static str {
        let body = serde_json::json!({ "data": [{ "embedding": values }] }).to_string();
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
            body.len(),
            body
        );
        Box::leak(response.into_boxed_str())
    }

    #[tokio::test]
    async fn openai_embedder_accepts_matching_dimension_without_sending_dimensions() {
        let (url, handle) = spawn_mock(openai_embedding_response(&[0.1, 0.2, 0.3]));
        let embedder =
            OpenAICompatEmbedder::new(url, "k".into(), "text-embedding-3-small".into(), 3);
        let v = embedder
            .embed("hello")
            .await
            .expect("matching dim should succeed");
        assert_eq!(v.len(), 3);
        // NOTE(review): `spawn_mock` captures only a single read of at most
        // 4096 bytes. If the body arrives in a later read, this absence check
        // passes vacuously.
        let req = String::from_utf8(handle.join().unwrap()).unwrap();
        assert!(
            !req.contains("\"dimensions\""),
            "request must not send a dimensions field: {req}"
        );
    }

    #[tokio::test]
    async fn openai_embedder_rejects_dimension_mismatch() {
        // Mock returns a 2-d vector while the embedder is configured for 3.
        let (url, handle) = spawn_mock(openai_embedding_response(&[0.1, 0.2]));
        let embedder =
            OpenAICompatEmbedder::new(url, "k".into(), "text-embedding-3-small".into(), 3);
        let err = embedder
            .embed("hello")
            .await
            .expect_err("dimension mismatch should error");
        match err {
            CoreError::Internal(msg) => {
                assert!(msg.contains("dimensions"), "msg: {msg}");
                assert!(msg.contains("difflore embeddings setup"), "msg: {msg}");
            }
            other => panic!("unexpected err: {other:?}"),
        }
        let _ = handle.join();
    }

    /// Builds a raw HTTP/1.1 200 response whose JSON `data` array holds
    /// `{index, embedding}` items in the given order. Leaked to get `'static`.
    fn openai_batch_response(items: &[(u64, &[f32])]) -> &'static str {
        let data: Vec<serde_json::Value> = items
            .iter()
            .map(|(index, vec)| serde_json::json!({ "index": index, "embedding": vec }))
            .collect();
        let body = serde_json::json!({ "data": data }).to_string();
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
            body.len(),
            body
        );
        Box::leak(response.into_boxed_str())
    }

    #[tokio::test]
    async fn openai_embedder_batches_into_single_request() {
        let resp = openai_batch_response(&[(0, &[0.1, 0.2, 0.3]), (1, &[0.4, 0.5, 0.6])]);
        // `spawn_mock` accepts exactly one connection, so success implies one request.
        let (url, handle) = spawn_mock(resp);
        let embedder = OpenAICompatEmbedder::new(url, "k".into(), "m".into(), 3);
        let texts = vec!["a".to_owned(), "b".to_owned()];
        let vectors = embedder
            .embed_batch(&texts, None)
            .await
            .expect("batch embed should succeed");
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![0.1f32, 0.2, 0.3]);
        assert_eq!(vectors[1], vec![0.4f32, 0.5, 0.6]);
        let req = String::from_utf8(handle.join().unwrap()).unwrap();
        assert!(
            req.contains("\"input\""),
            "request should batch input: {req}"
        );
    }

    #[tokio::test]
    async fn openai_embedder_batch_orders_by_response_index() {
        // Response lists index 1 before index 0; output must follow `index`.
        let resp = openai_batch_response(&[(1, &[0.4, 0.5]), (0, &[0.1, 0.2])]);
        let (url, handle) = spawn_mock(resp);
        let embedder = OpenAICompatEmbedder::new(url, "k".into(), "m".into(), 2);
        let texts = vec!["first".to_owned(), "second".to_owned()];
        let vectors = embedder
            .embed_batch(&texts, None)
            .await
            .expect("batch embed should succeed");
        assert_eq!(vectors[0], vec![0.1f32, 0.2]);
        assert_eq!(vectors[1], vec![0.4f32, 0.5]);
        let _ = handle.join();
    }

    // Reads the host's real settings and cloud token, so the selected kind is
    // environment-dependent; only the absence of a panic and a positive dim are
    // asserted.
    #[test]
    fn probe_active_embedder_sync_runs_without_panicking() {
        let kind = probe_active_embedder_sync();
        assert!(kind.dim() > 0);
    }

    #[tokio::test]
    async fn openai_embedder_omits_auth_header_when_keyless() {
        let (url, handle) = spawn_mock(openai_batch_response(&[(0, &[0.1, 0.2])]));
        let embedder = OpenAICompatEmbedder::new(url, String::new(), "m".into(), 2);
        embedder
            .embed_batch(&["x".to_owned()], None)
            .await
            .expect("keyless embed should succeed");
        let req = String::from_utf8(handle.join().unwrap())
            .unwrap()
            .to_ascii_lowercase();
        assert!(
            !req.contains("authorization:"),
            "keyless request must not send an auth header: {req}"
        );
    }

    #[tokio::test]
    async fn openai_embedder_sends_auth_header_when_keyed() {
        let (url, handle) = spawn_mock(openai_batch_response(&[(0, &[0.1, 0.2])]));
        let embedder = OpenAICompatEmbedder::new(url, "sk-x".into(), "m".into(), 2);
        embedder
            .embed_batch(&["x".to_owned()], None)
            .await
            .expect("keyed embed should succeed");
        let req = String::from_utf8(handle.join().unwrap())
            .unwrap()
            .to_ascii_lowercase();
        assert!(
            req.contains("authorization: bearer sk-x"),
            "keyed request must send bearer auth: {req}"
        );
    }

    #[test]
    #[ignore = "requires OS keyring or stable home dir; run with --ignored"]
    fn store_and_load_embedding_key_round_trip() {
        let plaintext = "sk-test-abcdef123456";
        let storage_key = store_embedding_key(plaintext).expect("store should succeed");
        assert_ne!(
            storage_key, plaintext,
            "stored value must not equal plaintext"
        );
        assert!(
            !storage_key.is_empty(),
            "storage key should be non-empty hex"
        );
        let recovered = load_embedding_key(&storage_key).expect("load should succeed");
        assert_eq!(recovered, plaintext);
    }

    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Starts a one-shot HTTP mock on an ephemeral localhost port.
    ///
    /// The mock accepts a single connection, performs one `read` of at most
    /// 4096 bytes, writes `response` verbatim, and returns the bytes it read
    /// through the join handle. NOTE(review): a request larger than the buffer,
    /// or one split across several reads, is only partially captured.
    fn spawn_mock(response: &'static str) -> (String, thread::JoinHandle<Vec<u8>>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().unwrap();
        let url = format!("http://{addr}");
        let handle = thread::spawn(move || {
            let (mut sock, _) = listener.accept().expect("accept");
            let mut buf = [0u8; 4096];
            let n = sock.read(&mut buf).unwrap_or(0);
            sock.write_all(response.as_bytes()).ok();
            sock.flush().ok();
            buf[..n].to_vec()
        });
        (url, handle)
    }

    /// Like `spawn_mock`, but serves `responses` in order, one per accepted
    /// connection, and returns every captured request. The join blocks until
    /// all responses have been served, so a test that makes fewer requests
    /// than `responses.len()` would hang on `join`.
    fn spawn_mock_sequence(
        responses: Vec<&'static str>,
    ) -> (String, thread::JoinHandle<Vec<Vec<u8>>>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().unwrap();
        let url = format!("http://{addr}");
        let handle = thread::spawn(move || {
            let mut requests = Vec::new();
            for response in responses {
                let (mut sock, _) = listener.accept().expect("accept");
                let mut buf = [0u8; 4096];
                let n = sock.read(&mut buf).unwrap_or(0);
                sock.write_all(response.as_bytes()).ok();
                sock.flush().ok();
                requests.push(buf[..n].to_vec());
            }
            requests
        });
        (url, handle)
    }

    #[tokio::test]
    async fn cloud_embedder_returns_first_vector_on_success() {
        let body = serde_json::json!({
            "vectors": [[0.1, 0.2, 0.3]],
            "model": "text-embedding-3-small",
            "dim": 1536,
        })
        .to_string();
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
            body.len(),
            body
        );
        let response_static: &'static str = Box::leak(response.into_boxed_str());
        let (url, handle) = spawn_mock(response_static);
        let embedder = CloudEmbedder::with_model(url, "tok".into(), "m".into(), 1536);
        let v = embedder.embed("hello").await.expect("embed");
        // The returned vector is taken from the response body (3 values), not
        // the configured `dim` of 1536.
        assert_eq!(v.len(), 3);
        assert!((v[0] - 0.1).abs() < 1e-4);
        let req = handle.join().unwrap();
        let req_str = String::from_utf8_lossy(&req);
        let req_lower = req_str.to_ascii_lowercase();
        assert!(
            req_lower.contains("authorization: bearer tok"),
            "auth header missing in: {req_str}"
        );
        assert!(req_str.contains("\"texts\""));
        assert!(req_str.contains("hello"));
    }

    // Four 502 responses are queued and all four must be consumed. Presumably
    // this is one initial attempt plus one retry per entry of
    // `EMBEDDING_RETRY_DELAYS_MS` — confirm in the `cloud` submodule.
    #[tokio::test]
    async fn cloud_embedder_maps_5xx_to_core_error() {
        let response =
            "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 4\r\nConnection: close\r\n\r\nfail";
        let (url, handle) = spawn_mock_sequence(vec![response, response, response, response]);
        let embedder = CloudEmbedder::with_model(url, "t".into(), "m".into(), 1536);
        let err = embedder.embed("x").await.expect_err("should fail");
        match err {
            CoreError::Internal(msg) => assert!(msg.contains("502"), "msg: {msg}"),
            other => panic!("unexpected err: {other:?}"),
        }
        assert_eq!(handle.join().unwrap().len(), 4);
    }

    // One transient 502 followed by a 200: the embedder must recover, using
    // exactly two requests.
    #[tokio::test]
    async fn cloud_embedder_retries_transient_5xx_once() {
        let ok_body = serde_json::json!({
            "vectors": [[0.4, 0.5]],
            "model": "text-embedding-3-small",
            "dim": 1536,
        })
        .to_string();
        let fail = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 4\r\nConnection: close\r\n\r\nfail";
        let ok = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
            ok_body.len(),
            ok_body
        );
        let ok_static: &'static str = Box::leak(ok.into_boxed_str());
        let (url, handle) = spawn_mock_sequence(vec![fail, ok_static]);
        let embedder = CloudEmbedder::with_model(url, "tok".into(), "m".into(), 1536);
        let v = embedder.embed("hello").await.expect("embed after retry");
        assert_eq!(v, vec![0.4, 0.5]);
        assert_eq!(handle.join().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn cloud_embedder_dim_is_reported() {
        let embedder = CloudEmbedder::new("http://example.invalid".into(), "t".into());
        assert_eq!(embedder.dim(), DEFAULT_OPENAI_EMBEDDING_DIM);
        assert!(embedder.is_semantic());
    }

    #[test]
    fn cloud_embedder_endpoint_handles_trailing_slash() {
        let a = CloudEmbedder::new("http://h/api".into(), "t".into());
        let b = CloudEmbedder::new("http://h/api/".into(), "t".into());
        assert_eq!(a.endpoint(), "http://h/api/embeddings");
        assert_eq!(b.endpoint(), "http://h/api/embeddings");
    }

    #[test]
    fn url_host_strips_scheme_and_path() {
        assert_eq!(
            url_host("https://api.openai.com/v1"),
            Some("api.openai.com")
        );
        assert_eq!(url_host("http://localhost:8080/x"), Some("localhost:8080"));
        assert_eq!(url_host("noscheme/path"), Some("noscheme"));
        assert_eq!(url_host(""), None);
    }

    #[test]
    fn load_embedding_key_rejects_invalid_storage_key() {
        let err = load_embedding_key("not-valid-hex-$$").unwrap_err();
        match err {
            CoreError::Internal(msg) => assert!(msg.contains("failed to decrypt")),
            other => panic!("unexpected error variant: {other:?}"),
        }

        let err2 = load_embedding_key("abcd").unwrap_err();
        assert!(matches!(err2, CoreError::Internal(_)));
    }
}
1141}