Skip to main content

difflore_core/context/embedding/
mod.rs

1use async_trait::async_trait;
2use sha1::{Digest, Sha1};
3use std::time::Duration;
4use unicode_segmentation::UnicodeSegmentation;
5
6use crate::crypto;
7use crate::errors::CoreError;
8
9mod cloud;
10mod openai;
11mod sha1_embedder;
12
13// Per-provider impls live in their own files; re-export the concrete
14// embedders so external paths (`difflore_core::context::embedding::*`)
15// continue to resolve exactly as before the split.
16pub use cloud::CloudEmbedder;
17pub use openai::OpenAICompatEmbedder;
18pub use sha1_embedder::Sha1Embedder;
19
/// Dimensionality of the local SHA1 lexical embedding produced by [`embed_text`].
pub const EMBEDDING_DIM: usize = 128;

/// Sentinel value for `context_engine.embedding_provider_url` meaning the user
/// explicitly chose the cloud-managed embedding source.
///
/// [`get_embedder`] resolves it to `CloudEmbedder` when a cloud token is
/// present and to the local lexical hash otherwise. It is not a real URL and
/// must never be handed to `OpenAICompatEmbedder`.
pub const CLOUD_MANAGED_SENTINEL: &str = "cloud-managed";

/// Default dimensionality for `OpenAI` `text-embedding-3-small`.
pub const DEFAULT_OPENAI_EMBEDDING_DIM: usize = 1536;

/// Client-side HTTP timeout for embedding provider requests.
///
/// Cloud may absorb transient upstream embedding timeouts with its own retry
/// window, so the client budget is kept longer; disconnecting early would push
/// the caller into the SHA1 fallback. `pub(crate)` so the per-query cold-start
/// retry budget can be asserted to stay under this cap.
pub(crate) const EMBEDDING_PROVIDER_TIMEOUT: Duration = Duration::from_secs(45);

/// Delays, in milliseconds, between embedding provider retries.
/// NOTE(review): not read in this file — presumably consumed by the provider
/// submodules' retry loops; confirm there.
const EMBEDDING_RETRY_DELAYS_MS: &[u64] = &[100, 300, 700];

/// Maximum number of texts sent to a provider in a single `embed_batch` call.
pub const EMBEDDING_BATCH_SIZE: usize = 64;
38
39#[allow(clippy::panic)]
40// reason: reqwest client construction with a static timeout is unrecoverable for provider setup.
41fn embedding_http_client() -> reqwest::Client {
42    reqwest::Client::builder()
43        .timeout(EMBEDDING_PROVIDER_TIMEOUT)
44        .build()
45        .unwrap_or_else(|e| {
46            panic!("failed to build embedding HTTP client with provider timeout: {e}")
47        })
48}
49
50/// Abstract embedding provider.
51///
52/// Uses `#[async_trait]` to keep the trait object-safe
53/// (i.e. usable as `Box<dyn Embedder>`) while allowing `async fn` syntax.
54#[async_trait]
55pub trait Embedder: Send + Sync {
56    async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError>;
57
58    async fn embed_batch(
59        &self,
60        texts: &[String],
61        _rule_ids: Option<&[String]>,
62    ) -> Result<Vec<Vec<f32>>, CoreError> {
63        let mut vectors = Vec::with_capacity(texts.len());
64        for text in texts {
65            vectors.push(self.embed(text).await?);
66        }
67        Ok(vectors)
68    }
69
70    fn dim(&self) -> usize;
71
72    /// Whether this embedder produces semantically meaningful vectors.
73    /// Defaults to `true` so real embedding providers don't have to opt
74    /// in; lexical-only fallbacks override to `false` so hybrid retrieval
75    /// knows to lean harder on the FTS baseline for keyword-heavy queries.
76    fn is_semantic(&self) -> bool {
77        true
78    }
79}
80
81/// Encrypt an embedding provider API key and return the opaque storage
82/// identifier that should be persisted in settings (`embedding_provider_key`).
83///
84/// Under the hood this uses the AES-GCM master key stored in the OS keyring
85/// (see `crate::crypto`). The returned string is ciphertext hex — it is safe
86/// to store on disk. Callers must round-trip through [`load_embedding_key`]
87/// before using the key with an embedding provider.
88pub fn store_embedding_key(api_key: &str) -> Result<String, CoreError> {
89    crypto::encrypt_secret(api_key)
90        .map_err(|e| CoreError::Internal(format!("failed to encrypt embedding key: {e}")))
91}
92
93/// Decrypt an embedding provider API key from the opaque storage identifier
94/// produced by [`store_embedding_key`].
95///
96/// Returns `CoreError::Internal` on any crypto / keyring failure so callers
97/// can fall back to [`Sha1Embedder`] without panicking.
98pub fn load_embedding_key(storage_key: &str) -> Result<String, CoreError> {
99    crypto::decrypt_secret(storage_key)
100        .map_err(|e| CoreError::Internal(format!("failed to decrypt embedding key: {e}")))
101}
102
103fn retryable_embedding_status(status: reqwest::StatusCode) -> bool {
104    status == reqwest::StatusCode::REQUEST_TIMEOUT
105        || status == reqwest::StatusCode::BAD_GATEWAY
106        || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
107        || status == reqwest::StatusCode::GATEWAY_TIMEOUT
108        || status.is_server_error()
109}
110
111/// Resolve the configured embedder from settings.
112///
113/// Priority chain (first match wins):
114///   1. `OpenAICompatEmbedder` — if the user explicitly configured a BYOK
115///      provider (`semantic_embedding=true` + a real `embedding_provider_url`,
116///      i.e. not the cloud-managed sentinel). This takes precedence over a
117///      stored cloud token: a user who ran `difflore embeddings setup` to
118///      bring their own key wants that provider used even while logged in.
119///   2. `CloudEmbedder` — if the user is logged in to cloud. Best-effort:
120///      we don't probe the network here, we just trust the stored token.
121///      On request failure the caller falls back to local SHA1 via
122///      `embed_text_async`.
123///   3. [`Sha1Embedder`] — deterministic offline fallback.
124///
125/// Falls back to [`Sha1Embedder`] on any settings error, so callers never
126/// have to deal with embedder construction failures. Runtime paths should
127/// treat SHA1 as a degraded fallback after cloud/BYOK retries, never as the
128/// preferred path when cloud or BYOK is available. `probe_active_embedder`
129/// mirrors this same order — keep the two in sync.
130pub async fn get_embedder() -> Box<dyn Embedder> {
131    // Step 1 — explicit BYOK provider configured via settings takes
132    // precedence over a cloud token. The cloud-managed sentinel is not a real
133    // URL, so it is excluded here and handled by the cloud branch below.
134    if let Ok(settings) = crate::settings::get().await {
135        let ce = &settings.context_engine;
136        let byok_url = ce
137            .embedding_provider_url
138            .as_ref()
139            .map(|u| u.trim().to_owned())
140            .filter(|u| !u.is_empty() && u != CLOUD_MANAGED_SENTINEL);
141        if ce.semantic_embedding
142            && let Some(url) = byok_url
143        {
144            // `embedding_provider_key` is a keyring storage identifier (a
145            // ciphertext hex blob produced by `store_embedding_key`).
146            // Decrypt it to get the real API key; on decrypt failure, warn
147            // and fall through to the cloud/SHA1 branches rather than sending
148            // empty credentials to the provider.
149            let key = match ce.embedding_provider_key.as_ref() {
150                Some(storage_key) if !storage_key.trim().is_empty() => {
151                    if let Ok(plain) = load_embedding_key(storage_key) {
152                        Some(plain)
153                    } else {
154                        eprintln!(
155                            "[embedder] failed to decrypt BYOK key; falling back to cloud/SHA1"
156                        );
157                        None
158                    }
159                }
160                // BYOK without a stored key (some local providers need none).
161                _ => Some(String::new()),
162            };
163            if let Some(key) = key {
164                let model = ce
165                    .embedding_model
166                    .clone()
167                    .unwrap_or_else(|| "text-embedding-3-small".to_owned());
168                let dim = ce.embedding_dim.unwrap_or(DEFAULT_OPENAI_EMBEDDING_DIM);
169                return Box::new(OpenAICompatEmbedder::new(url, key, model, dim));
170            }
171        }
172    }
173
174    // Step 2 — cloud-managed. Reads the same `cloud-auth.db` token used by the
175    // rest of the cloud client. No network probe at construction time: cheap
176    // and non-blocking. If the request later fails, the wrapper in
177    // `embed_text_async` falls back to local SHA1 after the provider retry
178    // path has been exhausted.
179    if let Some(token) = crate::cloud::client::CloudClient::load_token().await {
180        let base = crate::cloud::endpoints::api_base();
181        return Box::new(CloudEmbedder::new(base, token));
182    }
183
184    // Step 3 — deterministic offline SHA1 fallback.
185    Box::new(Sha1Embedder::new())
186}
187
188/// Lightweight tag for the active embedder, returned by
189/// [`probe_active_embedder`] so callers can render the right hint copy without
190/// re-implementing the priority chain.
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum ActiveEmbedderKind {
193    Cloud {
194        model: String,
195        dim: usize,
196    },
197    Byok {
198        provider_host: String,
199        model: String,
200        dim: usize,
201    },
202    Sha1,
203}
204
205impl ActiveEmbedderKind {
206    pub const fn dim(&self) -> usize {
207        match self {
208            Self::Cloud { dim, .. } | Self::Byok { dim, .. } => *dim,
209            Self::Sha1 => EMBEDDING_DIM,
210        }
211    }
212
213    pub fn profile(&self) -> String {
214        match self {
215            Self::Cloud { model, dim } => format!("cloud:{model}:{dim}"),
216            Self::Byok {
217                provider_host,
218                model,
219                dim,
220            } => format!("byok:{provider_host}:{model}:{dim}"),
221            Self::Sha1 => format!("sha1:local:{EMBEDDING_DIM}"),
222        }
223    }
224}
225
226/// Pure step shared by the async/sync probes: returns `Byok` iff settings
227/// select an explicit, usable BYOK provider (semantic_embedding on, a real
228/// non-sentinel URL, and a decryptable key if one is configured). `None` means
229/// BYOK does not apply — the caller then falls through to cloud, then SHA1,
230/// exactly like `get_embedder`.
231///
232/// Keeping this a separate, pure step lets [`probe_active_embedder`] defer the
233/// async cloud-token load until BYOK is ruled out, so a BYOK/`--no-key` user
234/// never triggers a `cloud-auth.db` decrypt / keyring access just to render
235/// status. This is the single source of truth for "is BYOK active" so
236/// diagnostics (`difflore doctor`), the TUI status bar, and the MCP hook never
237/// drift from the runtime resolver.
238fn byok_from_settings(
239    ce: Option<&crate::models::ContextEngineRecord>,
240) -> Option<ActiveEmbedderKind> {
241    let ce = ce?;
242    if !ce.semantic_embedding {
243        return None;
244    }
245    let url = ce
246        .embedding_provider_url
247        .as_ref()
248        .map(|u| u.trim())
249        .filter(|u| !u.is_empty() && *u != CLOUD_MANAGED_SENTINEL)?;
250    // A configured-but-undecryptable key means BYOK is not actually usable, so
251    // the resolver would fall through to cloud/SHA1. Reporting Byok here would
252    // mislabel the backend and let mismatched vectors persist under a BYOK
253    // embedding profile.
254    let key_usable = match ce.embedding_provider_key.as_ref() {
255        Some(storage_key) if !storage_key.trim().is_empty() => {
256            load_embedding_key(storage_key).is_ok()
257        }
258        _ => true,
259    };
260    if !key_usable {
261        return None;
262    }
263    let host = url_host(url).map_or_else(|| "byok".to_owned(), str::to_owned);
264    let model = ce
265        .embedding_model
266        .clone()
267        .unwrap_or_else(|| "text-embedding-3-small".to_owned());
268    let dim = ce.embedding_dim.unwrap_or(DEFAULT_OPENAI_EMBEDDING_DIM);
269    Some(ActiveEmbedderKind::Byok {
270        provider_host: host,
271        model,
272        dim,
273    })
274}
275
276/// Inspect the currently-resolved embedder and report its kind. Mirrors the
277/// priority chain in `get_embedder` (explicit BYOK → cloud → SHA1) without
278/// allocating the actual embedder. BYOK is checked first so a BYOK user never
279/// triggers a cloud-auth.db decrypt just to report status.
280pub async fn probe_active_embedder() -> ActiveEmbedderKind {
281    let settings = crate::settings::get().await.ok();
282    if let Some(byok) = byok_from_settings(settings.as_ref().map(|s| &s.context_engine)) {
283        return byok;
284    }
285    // `load_token_quiet`: this probe is a read-only status check (the TUI polls
286    // it on a 500ms cache), so a corrupt token must not spam stderr. Real
287    // recall/cloud paths use the loud `load_token`.
288    if crate::cloud::client::CloudClient::load_token_quiet()
289        .await
290        .is_some()
291    {
292        return ActiveEmbedderKind::Cloud {
293            model: "text-embedding-3-small".to_owned(),
294            dim: DEFAULT_OPENAI_EMBEDDING_DIM,
295        };
296    }
297    ActiveEmbedderKind::Sha1
298}
299
300/// Sync sibling of [`probe_active_embedder`] for non-async render paths (the
301/// TUI status bar). Runs the authoritative async probe on a short-lived
302/// scratch runtime on its own thread, so it returns the EXACT same answer as
303/// the runtime resolver — real token load (`DIFFLORE_TOKEN` env + decrypted
304/// `cloud-auth.db` row) and BYOK key validation — with no separate sync
305/// detection logic that could drift. The caller caches the result, so this
306/// only spawns on a cache miss.
307pub fn probe_active_embedder_sync() -> ActiveEmbedderKind {
308    std::thread::scope(|scope| {
309        scope
310            .spawn(|| {
311                match tokio::runtime::Builder::new_current_thread()
312                    .enable_all()
313                    .build()
314                {
315                    Ok(rt) => rt.block_on(probe_active_embedder()),
316                    Err(_) => ActiveEmbedderKind::Sha1,
317                }
318            })
319            .join()
320            .unwrap_or(ActiveEmbedderKind::Sha1)
321    })
322}
323
324pub async fn active_embedding_profile() -> String {
325    probe_active_embedder().await.profile()
326}
327
/// Extract the authority host (including any `:port`) from a provider URL.
///
/// Cheap parser, not a full URL implementation:
/// - the scheme (`…://`) is stripped when present;
/// - the authority ends at the first `/`, `?` or `#`;
/// - any `user:password@` userinfo is dropped, so credentials embedded in the
///   provider URL never reach the BYOK profile string or the status UI.
///
/// The port is deliberately kept, so providers on different ports of the same
/// host get distinct profiles. Returns `None` when no host remains.
fn url_host(s: &str) -> Option<&str> {
    let after_scheme = s.split_once("://").map_or(s, |(_, rest)| rest);
    let authority = after_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or(after_scheme);
    // `rsplit_once` takes the last `@`, so a raw `@` inside the password is
    // still treated as userinfo rather than host.
    let host = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    (!host.is_empty()).then_some(host)
}
334
335/// Synchronous SHA1 embedding — retained as the explicit local lexical
336/// embedder for offline users who have not configured cloud/BYOK semantic
337/// embeddings.
338pub fn embed_text(text: &str) -> Vec<f32> {
339    let mut vec = vec![0.0f32; EMBEDDING_DIM];
340    for word in text.unicode_words() {
341        let mut hasher = Sha1::new();
342        hasher.update(word.to_lowercase().as_bytes());
343        let hash = hasher.finalize();
344        for (i, byte) in hash.iter().enumerate() {
345            let dim = i % EMBEDDING_DIM;
346            vec[dim] += if byte & 1 == 0 { 1.0 } else { -1.0 };
347        }
348    }
349    let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
350    if norm > 0.0 {
351        for x in &mut vec {
352            *x /= norm;
353        }
354    }
355    vec
356}
357
/// An embedding vector plus whether it really came from a semantic provider.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddedText {
    /// The embedding itself.
    pub vector: Vec<f32>,
    /// `true` when `vector` came from a semantic provider; `false` for the
    /// local SHA1 lexical hash, including provider-failure fallbacks.
    pub semantic: bool,
}
363
364/// Async embedding helper that tries the configured embedder first and
365/// keeps retrieval usable if the provider is unavailable.
366///
367/// Callers that want a guaranteed `Vec<f32>` (never an error) should use
368/// this function. If the user has no semantic provider or a configured
369/// semantic provider fails after retry, this returns the local SHA1 vector
370/// as a degraded fallback.
371pub async fn embed_text_async(text: &str) -> Vec<f32> {
372    embed_text_async_with_timeout(text, None).await.vector
373}
374
375/// Async embedding helper for latency-sensitive paths.
376///
377/// When `timeout` is present, provider calls that exceed the budget fall
378/// back to local SHA1 after retry. The returned `semantic` flag describes
379/// the actual vector, not merely the configured provider, so retrieval can
380/// weight FTS more heavily after any provider failure or timeout.
381pub async fn embed_text_async_with_timeout(text: &str, timeout: Option<Duration>) -> EmbeddedText {
382    let texts = vec![text.to_owned()];
383    embed_texts_async_with_timeout(&texts, None, timeout)
384        .await
385        .into_iter()
386        .next()
387        .unwrap_or_else(|| sha1_fallback_embedding(text))
388}
389
390pub async fn embed_texts_async_with_timeout(
391    texts: &[String],
392    rule_ids: Option<&[String]>,
393    timeout: Option<Duration>,
394) -> Vec<EmbeddedText> {
395    if texts.is_empty() {
396        return Vec::new();
397    }
398    let embedder = get_embedder().await;
399    embed_texts_with_embedder_and_timeout(embedder.as_ref(), texts, rule_ids, timeout).await
400}
401
402async fn embed_texts_with_embedder_and_timeout(
403    embedder: &dyn Embedder,
404    texts: &[String],
405    rule_ids: Option<&[String]>,
406    timeout: Option<Duration>,
407) -> Vec<EmbeddedText> {
408    let semantic = embedder.is_semantic();
409    let mut embedded = Vec::with_capacity(texts.len());
410    for (chunk_index, text_chunk) in texts.chunks(EMBEDDING_BATCH_SIZE).enumerate() {
411        let start = chunk_index * EMBEDDING_BATCH_SIZE;
412        let end = start + text_chunk.len();
413        let rule_id_chunk = rule_ids.and_then(|ids| ids.get(start..end));
414        let embed_fut = embedder.embed_batch(text_chunk, rule_id_chunk);
415        let result = match timeout {
416            Some(timeout) => match tokio::time::timeout(timeout, embed_fut).await {
417                Ok(result) => result,
418                Err(_) => Err(CoreError::Internal(format!(
419                    "embedding provider timed out after {}ms",
420                    timeout.as_millis()
421                ))),
422            },
423            None => embed_fut.await,
424        };
425
426        match result {
427            Ok(vectors)
428                if vectors.len() == text_chunk.len()
429                    && vectors.iter().all(|vector| !vector.is_empty()) =>
430            {
431                embedded.extend(
432                    vectors
433                        .into_iter()
434                        .map(|vector| EmbeddedText { vector, semantic }),
435                );
436            }
437            Ok(_) => {
438                warn_embedding_fallback_once("provider returned empty or mismatched vector batch");
439                if timeout.is_some() {
440                    embedded.extend(
441                        texts[start..]
442                            .iter()
443                            .map(|text| sha1_fallback_embedding(text)),
444                    );
445                    break;
446                }
447                embedded.extend(text_chunk.iter().map(|text| sha1_fallback_embedding(text)));
448            }
449            Err(e) => {
450                warn_embedding_fallback_once(&format!("provider failed ({e})"));
451                if timeout.is_some() {
452                    embedded.extend(
453                        texts[start..]
454                            .iter()
455                            .map(|text| sha1_fallback_embedding(text)),
456                    );
457                    break;
458                }
459                embedded.extend(text_chunk.iter().map(|text| sha1_fallback_embedding(text)));
460            }
461        }
462    }
463    embedded
464}
465
466fn sha1_fallback_embedding(text: &str) -> EmbeddedText {
467    EmbeddedText {
468        vector: embed_text(text),
469        semantic: false,
470    }
471}
472
473/// Record an embedding fallback, and print a calm warning at most once per
474/// process per distinct cause.
475///
476/// The activity event is recorded on EVERY call, not just the first: the
477/// freshness-skip and health diagnostics
478/// (`recent_embedding_fallback[_strict]`) read it to decide whether the remote
479/// provider is currently down. Deduping the *record* per process would let a
480/// long-lived MCP / hook server's down-signal go stale after the recency
481/// window, so the freshness skip would stop engaging and the futile corpus
482/// re-embed would resume for the rest of the process. Each failed embed records
483/// at most one event (the batch loop breaks on the first failure under a
484/// timeout), so this does not flood the capped activity log.
485///
486/// Only the console print is deduped: without it, a single `difflore recall`
487/// could emit one identical line per failed rule chunk. With it, the user sees
488/// one clear line + the recovery command per cause class per process.
489fn warn_embedding_fallback_once(reason: &str) {
490    use std::collections::HashSet;
491    use std::sync::Mutex;
492    static SEEN: Mutex<Option<HashSet<String>>> = Mutex::new(None);
493    let key = classify_reason(reason);
494    crate::activity_stream::record(crate::activity_stream::ActivityPayload::EmbeddingFallback {
495        reason: key.clone(),
496    });
497    let Ok(mut guard) = SEEN.lock() else {
498        return; // poisoned mutex — event already recorded; just skip the print
499    };
500    let set = guard.get_or_insert_with(HashSet::new);
501    if !set.insert(key.clone()) {
502        return; // already printed this class of failure this process
503    }
504    eprintln!("[embedding] {}", calm_fallback_summary(&key));
505    eprintln!("{}", actionable_fix_for(&key));
506}
507
/// A calm, user-facing one-line summary of an embedding fallback, keyed by
/// the same cause classes as [`actionable_fix_for`].
///
/// Mirrors the `status` line ("semantic vectors paused; recall still works
/// with file-pattern + keyword matching") so a transient provider hiccup reads
/// as graceful degradation rather than breakage. The raw transport error (URLs,
/// "after N attempts" detail) stays off the hot path; `difflore doctor` is the
/// place for the verbose diagnostic.
fn calm_fallback_summary(key: &str) -> &'static str {
    // Every summary has the same shape; only the parenthesised cause differs.
    macro_rules! paused {
        ($cause:literal) => {
            concat!(
                "semantic vectors paused (",
                $cause,
                "); recall continues with file-pattern + keyword matching"
            )
        };
    }
    match key {
        "scope" | "forbidden" | "unauthorized" => paused!("cloud sign-in needs refresh"),
        "cap" => paused!("cloud embedding cap reached"),
        "timeout" | "network" => paused!("cloud unreachable"),
        "empty" => paused!("provider returned no vector"),
        _ => paused!("cloud embedding unavailable"),
    }
}
539
/// Bucket a raw fallback reason into a short, stable cause class.
///
/// The key is used to dedup console warnings and is recorded as the `reason`
/// of the embedding-fallback activity event. Different wordings of the same
/// cause, such as "provider failed (Internal error: cloud embedding endpoint
/// returned 403 Forbidden; ...)", collapse to one key without storing the
/// full message.
///
/// Keys are checked in this order: `scope`, `cap`, `forbidden`,
/// `unauthorized`, `timeout`, `network`, `empty`, otherwise `other`.
///
/// HTTP status codes (`401`, `403`) only match as standalone numbers, so a
/// timeout such as "timed out after 4010ms" is not mistaken for an auth
/// failure (which would wrongly tell the user to re-login). The batch-shape
/// message "empty or mismatched vector batch" is classified as `empty`.
fn classify_reason(reason: &str) -> String {
    let lower = reason.to_ascii_lowercase();
    let key = if lower.contains("missing required scope") {
        "scope"
    } else if ["embed cap", "embedding cap reached", "embed_cap_reached"]
        .iter()
        .any(|needle| lower.contains(needle))
    {
        "cap"
    } else if contains_status_code(&lower, "403") || lower.contains("forbidden") {
        "forbidden"
    } else if contains_status_code(&lower, "401") || lower.contains("unauthorized") {
        "unauthorized"
    } else if lower.contains("timeout") || lower.contains("timed out") {
        "timeout"
    } else if lower.contains("connect") || lower.contains("dns") {
        "network"
    } else if lower.contains("empty vector") || lower.contains("empty or mismatched vector") {
        "empty"
    } else {
        "other"
    };
    key.to_owned()
}

/// Whether `haystack` contains `code` as a standalone number, i.e. not
/// immediately preceded or followed by another ASCII digit.
fn contains_status_code(haystack: &str, code: &str) -> bool {
    let bytes = haystack.as_bytes();
    haystack.match_indices(code).any(|(start, matched)| {
        let end = start + matched.len();
        let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
        let digit_after = bytes.get(end).is_some_and(u8::is_ascii_digit);
        !digit_before && !digit_after
    })
}
572
/// An actionable next step the user can run to recover, tailored per cause
/// class so they never get a generic "check your configuration" wall.
fn actionable_fix_for(key: &str) -> &'static str {
    // Every hint carries the same console prefix.
    macro_rules! hint {
        ($text:literal) => {
            concat!("[embedding] -> ", $text)
        };
    }
    match key {
        "scope" => hint!(
            "your cloud token is missing the embedding scope. \
             Re-run `difflore cloud login` to refresh, \
             or `difflore embeddings setup` to bring your own key."
        ),
        "forbidden" => hint!(
            "cloud rejected the embed request. \
             Re-run `difflore cloud login` to refresh credentials."
        ),
        "unauthorized" => hint!("cloud token expired. Run `difflore cloud login`."),
        "cap" => hint!(
            "cloud embedding cap reached. Recall stays usable via local SHA1 + FTS; \
             upgrade for unlimited managed embedding, or run `difflore embeddings setup` for BYOK."
        ),
        "timeout" | "network" => hint!(
            "cloud unreachable. Recall stays usable via local SHA1 + FTS; \
             retry when network recovers, or run `difflore embeddings setup` \
             for an offline BYOK key."
        ),
        "empty" => hint!(
            "provider returned no vector. \
             Run `difflore doctor` to inspect the active embedder."
        ),
        _ => hint!(
            "run `difflore doctor` for diagnostics, \
             or `difflore embeddings setup` to switch to BYOK."
        ),
    }
}
607
/// Cosine similarity of `a` and `b`, in `[-1, 1]`, normalising both inputs.
///
/// The local SHA1 embedder returns unit-norm vectors, but managed/BYOK
/// providers may not. A bare dot product would rank by magnitude and disagree
/// with the ANN cosine path; dividing by the norms keeps the linear-scan
/// fallback consistent for any provider.
///
/// Returns `0.0` ("no similarity") rather than `NaN` when either input has
/// zero norm. It also returns `0.0` when the inputs differ in length: such
/// vectors come from different embedders (e.g. a 128-d SHA1 fallback query
/// against 1536-d provider vectors) and are not comparable. Zipping them
/// would silently score only a prefix against full-vector norms.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
624
625#[cfg(test)]
626#[allow(
627    clippy::expect_used,
628    clippy::unwrap_used,
629    clippy::panic,
630    clippy::float_cmp
631)] // reason: test code — explicit panic/expect/exact-cmp on known-finite vectors.
632mod tests {
633    use super::*;
634
635    #[test]
636    fn embed_text_produces_fixed_dim_vector() {
637        let vec = embed_text("hello world");
638        assert_eq!(vec.len(), EMBEDDING_DIM);
639    }
640
641    #[test]
642    fn embed_text_is_unit_normalized() {
643        let vec = embed_text("let x = 42;");
644        let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt();
645        // allow small rounding error
646        assert!((norm - 1.0).abs() < 1e-4, "expected unit-norm, got {norm}");
647    }
648
649    #[test]
650    fn embed_empty_text_returns_zero_vector() {
651        let vec = embed_text("");
652        assert_eq!(vec.len(), EMBEDDING_DIM);
653        assert!(vec.iter().all(|&v| v == 0.0));
654    }
655
656    #[test]
657    fn cosine_similarity_identical_vectors_is_one() {
658        let a = embed_text("fn main() {}");
659        let sim = cosine_similarity(&a, &a);
660        assert!((sim - 1.0).abs() < 1e-4);
661    }
662
663    #[test]
664    fn cosine_similarity_orthogonal_zero_vectors_is_zero() {
665        let a = vec![0.0; EMBEDDING_DIM];
666        let b = vec![0.0; EMBEDDING_DIM];
667        assert_eq!(cosine_similarity(&a, &b), 0.0);
668    }
669
670    #[test]
671    fn cosine_similarity_is_scale_invariant() {
672        // Same direction, different magnitudes: true cosine is 1.0. A bare
673        // dot product would return 50.0 here, mis-ranking non-unit-norm
674        // (BYOK) embeddings in the linear-scan fallback.
675        let a = [3.0_f32, 4.0];
676        let b = [6.0_f32, 8.0];
677        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
678        // Orthogonal directions: cosine 0.0 regardless of magnitude.
679        let c = [0.0_f32, 5.0];
680        let d = [7.0_f32, 0.0];
681        assert!(cosine_similarity(&c, &d).abs() < 1e-6);
682    }
683
684    #[test]
685    fn provider_failure_fallback_uses_sha1_after_retry() {
686        let fallback = sha1_fallback_embedding("hello world");
687        assert_eq!(
688            fallback.vector,
689            embed_text("hello world"),
690            "provider failures should fall back to local SHA1 only after retry"
691        );
692        assert!(
693            !fallback.semantic,
694            "provider failure fallback is local lexical hash, not semantic"
695        );
696    }
697
698    #[test]
699    fn provider_failure_warning_marks_sha1_as_fallback() {
700        let message = actionable_fix_for("network");
701        assert!(
702            message.contains("local SHA1 + FTS"),
703            "network fallback should name the degraded local path: {message}"
704        );
705        assert!(
706            message.contains("retry when network recovers"),
707            "provider failure guidance should prefer cloud recovery: {message}"
708        );
709    }
710
711    #[tokio::test]
712    async fn sha1_embedder_matches_embed_text() {
713        let embedder = Sha1Embedder::new();
714        assert_eq!(embedder.dim(), EMBEDDING_DIM);
715        let out = embedder.embed("hello world").await.expect("sha1 embed");
716        let expected = embed_text("hello world");
717        assert_eq!(out.len(), EMBEDDING_DIM);
718        assert_eq!(out, expected);
719    }
720
721    #[tokio::test]
722    async fn sha1_embedder_is_deterministic_128d() {
723        let embedder = Sha1Embedder::new();
724        let a = embedder.embed("fn main() {}").await.unwrap();
725        let b = embedder.embed("fn main() {}").await.unwrap();
726        assert_eq!(a.len(), 128);
727        assert_eq!(a, b);
728    }
729
730    struct SlowBatchEmbedder {
731        calls: std::sync::atomic::AtomicUsize,
732    }
733
734    #[async_trait::async_trait]
735    impl Embedder for SlowBatchEmbedder {
736        async fn embed(&self, _text: &str) -> Result<Vec<f32>, CoreError> {
737            unreachable!("test calls embed_batch directly")
738        }
739
740        async fn embed_batch(
741            &self,
742            texts: &[String],
743            _rule_ids: Option<&[String]>,
744        ) -> Result<Vec<Vec<f32>>, CoreError> {
745            self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
746            tokio::time::sleep(Duration::from_millis(50)).await;
747            Ok(texts.iter().map(|_| vec![1.0]).collect())
748        }
749
750        fn dim(&self) -> usize {
751            1
752        }
753    }
754
755    #[tokio::test]
756    async fn timed_batch_embedding_falls_back_for_remaining_batches_after_first_timeout() {
757        let embedder = SlowBatchEmbedder {
758            calls: std::sync::atomic::AtomicUsize::new(0),
759        };
760        let texts = (0..=(EMBEDDING_BATCH_SIZE * 3))
761            .map(|i| format!("rule body {i}"))
762            .collect::<Vec<_>>();
763
764        let embedded = embed_texts_with_embedder_and_timeout(
765            &embedder,
766            &texts,
767            None,
768            Some(Duration::from_millis(5)),
769        )
770        .await;
771
772        assert_eq!(embedded.len(), texts.len());
773        assert_eq!(
774            embedder.calls.load(std::sync::atomic::Ordering::SeqCst),
775            1,
776            "latency-sensitive batch calls should not wait once per provider batch"
777        );
778        for (embedded, text) in embedded.iter().zip(&texts) {
779            assert!(!embedded.semantic);
780            assert_eq!(embedded.vector, embed_text(text));
781        }
782    }
783
784    #[test]
785    fn openai_embedder_endpoint_handles_url_variants() {
786        let cases = &[
787            (
788                "https://api.openai.com/v1",
789                "https://api.openai.com/v1/embeddings",
790            ),
791            (
792                "https://api.example.com/v1/",
793                "https://api.example.com/v1/embeddings",
794            ),
795            (
796                "https://api.example.com/v1/embeddings",
797                "https://api.example.com/v1/embeddings",
798            ),
799        ];
800        for (base, expected) in cases {
801            let e = OpenAICompatEmbedder::new((*base).into(), "k".into(), "m".into(), 128);
802            assert_eq!(e.endpoint(), *expected, "base: {base}");
803        }
804    }
805
806    fn openai_embedding_response(values: &[f32]) -> &'static str {
807        let body = serde_json::json!({ "data": [{ "embedding": values }] }).to_string();
808        let response = format!(
809            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
810            body.len(),
811            body
812        );
813        Box::leak(response.into_boxed_str())
814    }
815
816    #[tokio::test]
817    async fn openai_embedder_accepts_matching_dimension_without_sending_dimensions() {
818        let (url, handle) = spawn_mock(openai_embedding_response(&[0.1, 0.2, 0.3]));
819        let embedder =
820            OpenAICompatEmbedder::new(url, "k".into(), "text-embedding-3-small".into(), 3);
821        let v = embedder
822            .embed("hello")
823            .await
824            .expect("matching dim should succeed");
825        assert_eq!(v.len(), 3);
826        // Must NOT send a `dimensions` field: models like ada-002 and strict
827        // local providers reject it. Length is validated from the response.
828        let req = String::from_utf8(handle.join().unwrap()).unwrap();
829        assert!(
830            !req.contains("\"dimensions\""),
831            "request must not send a dimensions field: {req}"
832        );
833    }
834
835    #[tokio::test]
836    async fn openai_embedder_rejects_dimension_mismatch() {
837        // Provider returns 2 dims while 3 are configured — must error rather than
838        // store mismatched-length vectors under the configured profile.
839        let (url, handle) = spawn_mock(openai_embedding_response(&[0.1, 0.2]));
840        let embedder =
841            OpenAICompatEmbedder::new(url, "k".into(), "text-embedding-3-small".into(), 3);
842        let err = embedder
843            .embed("hello")
844            .await
845            .expect_err("dimension mismatch should error");
846        match err {
847            CoreError::Internal(msg) => {
848                assert!(msg.contains("dimensions"), "msg: {msg}");
849                assert!(msg.contains("difflore embeddings setup"), "msg: {msg}");
850            }
851            other => panic!("unexpected err: {other:?}"),
852        }
853        let _ = handle.join();
854    }
855
856    fn openai_batch_response(items: &[(u64, &[f32])]) -> &'static str {
857        let data: Vec<serde_json::Value> = items
858            .iter()
859            .map(|(index, vec)| serde_json::json!({ "index": index, "embedding": vec }))
860            .collect();
861        let body = serde_json::json!({ "data": data }).to_string();
862        let response = format!(
863            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
864            body.len(),
865            body
866        );
867        Box::leak(response.into_boxed_str())
868    }
869
870    #[tokio::test]
871    async fn openai_embedder_batches_into_single_request() {
872        // spawn_mock accepts exactly one TCP connection, so this also proves the
873        // batch is sent as ONE request rather than one-per-text.
874        let resp = openai_batch_response(&[(0, &[0.1, 0.2, 0.3]), (1, &[0.4, 0.5, 0.6])]);
875        let (url, handle) = spawn_mock(resp);
876        let embedder = OpenAICompatEmbedder::new(url, "k".into(), "m".into(), 3);
877        let texts = vec!["a".to_owned(), "b".to_owned()];
878        let vectors = embedder
879            .embed_batch(&texts, None)
880            .await
881            .expect("batch embed should succeed");
882        assert_eq!(vectors.len(), 2);
883        assert_eq!(vectors[0], vec![0.1f32, 0.2, 0.3]);
884        assert_eq!(vectors[1], vec![0.4f32, 0.5, 0.6]);
885        let req = String::from_utf8(handle.join().unwrap()).unwrap();
886        assert!(
887            req.contains("\"input\""),
888            "request should batch input: {req}"
889        );
890    }
891
892    #[tokio::test]
893    async fn openai_embedder_batch_orders_by_response_index() {
894        // Items returned out of order must be sorted back to input order.
895        let resp = openai_batch_response(&[(1, &[0.4, 0.5]), (0, &[0.1, 0.2])]);
896        let (url, handle) = spawn_mock(resp);
897        let embedder = OpenAICompatEmbedder::new(url, "k".into(), "m".into(), 2);
898        let texts = vec!["first".to_owned(), "second".to_owned()];
899        let vectors = embedder
900            .embed_batch(&texts, None)
901            .await
902            .expect("batch embed should succeed");
903        assert_eq!(vectors[0], vec![0.1f32, 0.2]);
904        assert_eq!(vectors[1], vec![0.4f32, 0.5]);
905        let _ = handle.join();
906    }
907
908    #[test]
909    fn probe_active_embedder_sync_runs_without_panicking() {
910        // The TUI status bar calls this from a sync render path; it must drive
911        // the async probe on a scratch runtime without panicking or deadlocking.
912        // The exact kind depends on the test environment; we only assert the
913        // sync→async bridge works and returns a usable kind.
914        let kind = probe_active_embedder_sync();
915        assert!(kind.dim() > 0);
916    }
917
918    #[tokio::test]
919    async fn openai_embedder_omits_auth_header_when_keyless() {
920        let (url, handle) = spawn_mock(openai_batch_response(&[(0, &[0.1, 0.2])]));
921        // Empty key = keyless local provider (`--no-key`).
922        let embedder = OpenAICompatEmbedder::new(url, String::new(), "m".into(), 2);
923        embedder
924            .embed_batch(&["x".to_owned()], None)
925            .await
926            .expect("keyless embed should succeed");
927        let req = String::from_utf8(handle.join().unwrap())
928            .unwrap()
929            .to_ascii_lowercase();
930        assert!(
931            !req.contains("authorization:"),
932            "keyless request must not send an auth header: {req}"
933        );
934    }
935
936    #[tokio::test]
937    async fn openai_embedder_sends_auth_header_when_keyed() {
938        let (url, handle) = spawn_mock(openai_batch_response(&[(0, &[0.1, 0.2])]));
939        let embedder = OpenAICompatEmbedder::new(url, "sk-x".into(), "m".into(), 2);
940        embedder
941            .embed_batch(&["x".to_owned()], None)
942            .await
943            .expect("keyed embed should succeed");
944        let req = String::from_utf8(handle.join().unwrap())
945            .unwrap()
946            .to_ascii_lowercase();
947        assert!(
948            req.contains("authorization: bearer sk-x"),
949            "keyed request must send bearer auth: {req}"
950        );
951    }
952
    // ── keyring-encrypted embedding key round-trip ──
    //
    // These tests use the OS keyring when one is available. On CI / headless
    // environments (no Secret Service, no Windows Credential Manager) the
    // keyring falls back to the path-derived master key, which is still
    // deterministic — so the round-trip should work on both Windows and
    // Linux dev boxes. They are marked `#[ignore]` nonetheless so they never
    // block headless test runs that lack access to any credential backend
    // (e.g. sandboxed CI containers where `dirs::home_dir` is unavailable).
    // Run locally with:
    //   cargo test -p difflore-core embedding_key -- --ignored
964    #[test]
965    #[ignore = "requires OS keyring or stable home dir; run with --ignored"]
966    fn store_and_load_embedding_key_round_trip() {
967        let plaintext = "sk-test-abcdef123456";
968        let storage_key = store_embedding_key(plaintext).expect("store should succeed");
969        assert_ne!(
970            storage_key, plaintext,
971            "stored value must not equal plaintext"
972        );
973        assert!(
974            !storage_key.is_empty(),
975            "storage key should be non-empty hex"
976        );
977        let recovered = load_embedding_key(&storage_key).expect("load should succeed");
978        assert_eq!(recovered, plaintext);
979    }
980
    // ── CloudEmbedder ───────────────────────────────────────────
    //
    // Tests use a tiny TcpListener-backed HTTP/1.1 mock — adding a real
    // mock-server crate (wiremock / mockito) just for these would bloat
    // the dev-dep tree more than the test gains. Each mock connection reads
    // one request, sends back a fixed response, and closes; `spawn_mock`
    // serves a single connection, while `spawn_mock_sequence` serves one
    // connection per queued response.
987
988    use std::io::{Read, Write};
989    use std::net::TcpListener;
990    use std::thread;
991
992    fn spawn_mock(response: &'static str) -> (String, thread::JoinHandle<Vec<u8>>) {
993        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
994        let addr = listener.local_addr().unwrap();
995        let url = format!("http://{addr}");
996        let handle = thread::spawn(move || {
997            let (mut sock, _) = listener.accept().expect("accept");
998            // Read up to the headers + body. Quick-and-dirty: read once
999            // — for our small JSON requests it fits in a single recv.
1000            let mut buf = [0u8; 4096];
1001            let n = sock.read(&mut buf).unwrap_or(0);
1002            sock.write_all(response.as_bytes()).ok();
1003            sock.flush().ok();
1004            buf[..n].to_vec()
1005        });
1006        (url, handle)
1007    }
1008
1009    fn spawn_mock_sequence(
1010        responses: Vec<&'static str>,
1011    ) -> (String, thread::JoinHandle<Vec<Vec<u8>>>) {
1012        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
1013        let addr = listener.local_addr().unwrap();
1014        let url = format!("http://{addr}");
1015        let handle = thread::spawn(move || {
1016            let mut requests = Vec::new();
1017            for response in responses {
1018                let (mut sock, _) = listener.accept().expect("accept");
1019                let mut buf = [0u8; 4096];
1020                let n = sock.read(&mut buf).unwrap_or(0);
1021                sock.write_all(response.as_bytes()).ok();
1022                sock.flush().ok();
1023                requests.push(buf[..n].to_vec());
1024            }
1025            requests
1026        });
1027        (url, handle)
1028    }
1029
1030    #[tokio::test]
1031    async fn cloud_embedder_returns_first_vector_on_success() {
1032        let body = serde_json::json!({
1033            "vectors": [[0.1, 0.2, 0.3]],
1034            "model": "text-embedding-3-small",
1035            "dim": 1536,
1036        })
1037        .to_string();
1038        let response = format!(
1039            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
1040            body.len(),
1041            body
1042        );
1043        // Leak so the closure's 'static bound is satisfied.
1044        let response_static: &'static str = Box::leak(response.into_boxed_str());
1045        let (url, handle) = spawn_mock(response_static);
1046        let embedder = CloudEmbedder::with_model(url, "tok".into(), "m".into(), 1536);
1047        let v = embedder.embed("hello").await.expect("embed");
1048        assert_eq!(v.len(), 3);
1049        assert!((v[0] - 0.1).abs() < 1e-4);
1050        let req = handle.join().unwrap();
1051        let req_str = String::from_utf8_lossy(&req);
1052        // HTTP/1.1 headers are case-insensitive; reqwest may emit
1053        // "authorization:" lower-cased depending on the version. Compare
1054        // case-insensitively.
1055        let req_lower = req_str.to_ascii_lowercase();
1056        assert!(
1057            req_lower.contains("authorization: bearer tok"),
1058            "auth header missing in: {req_str}"
1059        );
1060        assert!(req_str.contains("\"texts\""));
1061        assert!(req_str.contains("hello"));
1062    }
1063
1064    #[tokio::test]
1065    async fn cloud_embedder_maps_5xx_to_core_error() {
1066        let response =
1067            "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 4\r\nConnection: close\r\n\r\nfail";
1068        let (url, handle) = spawn_mock_sequence(vec![response, response, response, response]);
1069        let embedder = CloudEmbedder::with_model(url, "t".into(), "m".into(), 1536);
1070        let err = embedder.embed("x").await.expect_err("should fail");
1071        match err {
1072            CoreError::Internal(msg) => assert!(msg.contains("502"), "msg: {msg}"),
1073            other => panic!("unexpected err: {other:?}"),
1074        }
1075        assert_eq!(handle.join().unwrap().len(), 4);
1076    }
1077
1078    #[tokio::test]
1079    async fn cloud_embedder_retries_transient_5xx_once() {
1080        let ok_body = serde_json::json!({
1081            "vectors": [[0.4, 0.5]],
1082            "model": "text-embedding-3-small",
1083            "dim": 1536,
1084        })
1085        .to_string();
1086        let fail = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 4\r\nConnection: close\r\n\r\nfail";
1087        let ok = format!(
1088            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
1089            ok_body.len(),
1090            ok_body
1091        );
1092        let ok_static: &'static str = Box::leak(ok.into_boxed_str());
1093        let (url, handle) = spawn_mock_sequence(vec![fail, ok_static]);
1094        let embedder = CloudEmbedder::with_model(url, "tok".into(), "m".into(), 1536);
1095        let v = embedder.embed("hello").await.expect("embed after retry");
1096        assert_eq!(v, vec![0.4, 0.5]);
1097        assert_eq!(handle.join().unwrap().len(), 2);
1098    }
1099
1100    #[tokio::test]
1101    async fn cloud_embedder_dim_is_reported() {
1102        let embedder = CloudEmbedder::new("http://example.invalid".into(), "t".into());
1103        assert_eq!(embedder.dim(), DEFAULT_OPENAI_EMBEDDING_DIM);
1104        assert!(embedder.is_semantic());
1105    }
1106
1107    #[test]
1108    fn cloud_embedder_endpoint_handles_trailing_slash() {
1109        let a = CloudEmbedder::new("http://h/api".into(), "t".into());
1110        let b = CloudEmbedder::new("http://h/api/".into(), "t".into());
1111        assert_eq!(a.endpoint(), "http://h/api/embeddings");
1112        assert_eq!(b.endpoint(), "http://h/api/embeddings");
1113    }
1114
1115    #[test]
1116    fn url_host_strips_scheme_and_path() {
1117        assert_eq!(
1118            url_host("https://api.openai.com/v1"),
1119            Some("api.openai.com")
1120        );
1121        assert_eq!(url_host("http://localhost:8080/x"), Some("localhost:8080"));
1122        assert_eq!(url_host("noscheme/path"), Some("noscheme"));
1123        assert_eq!(url_host(""), None);
1124    }
1125
1126    #[test]
1127    fn load_embedding_key_rejects_invalid_storage_key() {
1128        // Invalid hex / too-short ciphertext must produce an error, never
1129        // panic. This path does NOT touch the keyring (the validation fires
1130        // first inside `from_hex` / length check), so it's safe to run in
1131        // headless environments.
1132        let err = load_embedding_key("not-valid-hex-$$").unwrap_err();
1133        match err {
1134            CoreError::Internal(msg) => assert!(msg.contains("failed to decrypt")),
1135            other => panic!("unexpected error variant: {other:?}"),
1136        }
1137
1138        let err2 = load_embedding_key("abcd").unwrap_err();
1139        assert!(matches!(err2, CoreError::Internal(_)));
1140    }
1141}