Skip to main content

ai_memory/
reranker.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Cross-encoder reranking for search results.
5//!
6//! A cross-encoder takes a (query, document) pair and produces a relevance
7//! score. This is more accurate than cosine similarity of independent
8//! embeddings but slower since it must run for each candidate.
9//!
10//! **Two implementations:**
11//! - `CrossEncoder::Lexical` — lightweight term-overlap scorer (default).
12//! - `CrossEncoder::Neural` — BERT-based cross-encoder loaded via candle
13//!   from `cross-encoder/ms-marco-MiniLM-L-6-v2` (~80 MB, ONNX-free).
14
15use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex};
17
18use anyhow::{Context, Result};
19use candle_core::{Device, Tensor};
20use candle_nn::VarBuilder;
21use candle_transformers::models::bert::{BertModel, Config as BertConfig};
22use hf_hub::{Repo, RepoType, api::sync::Api};
23use tokenizers::Tokenizer;
24
25use crate::models::Memory;
26
/// Share of the blended score taken from the upstream (embedding/FTS) score.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Share of the blended score taken from the cross-encoder.
///
/// Written as the complement of [`ORIGINAL_WEIGHT`] so the pair always sums
/// to 1.0; `1.0 - 0.6` evaluates to exactly the `f64` value `0.4`.
const CROSS_ENCODER_WEIGHT: f64 = 1.0 - ORIGINAL_WEIGHT;

/// Hugging Face Hub repository that hosts the neural cross-encoder weights.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Token budget for one (query, document) pair; longer pairs are truncated.
const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Width of each hidden state produced by the MiniLM-L-6 encoder.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;
35
36/// Cross-encoder for (query, document) relevance scoring.
37pub enum CrossEncoder {
38    /// Lightweight lexical cross-encoder using term overlap signals.
39    Lexical,
40    /// Neural BERT-based cross-encoder (ms-marco-MiniLM-L-6-v2).
41    Neural {
42        model: Arc<Mutex<BertModel>>,
43        tokenizer: Arc<Tokenizer>,
44        classifier_weight: Tensor,
45        classifier_bias: Tensor,
46        device: Device,
47    },
48}
49
50impl CrossEncoder {
51    /// Create a new lexical cross-encoder (no model download required).
52    pub fn new() -> Self {
53        Self::Lexical
54    }
55
56    /// Create a neural cross-encoder by downloading ms-marco-MiniLM-L-6-v2.
57    ///
58    /// Falls back to lexical if download or loading fails.
59    ///
60    /// v0.6.3.1 (P3, G8): when the neural path fails (e.g. HF Hub
61    /// unreachable, model checksum mismatch), emit a structured tracing
62    /// event `reranker.fallback` so operators see the silent
63    /// neural→lexical degrade. The eprintln remains for backward-compat
64    /// startup logs.
65    pub fn new_neural() -> Self {
66        match Self::load_neural() {
67            Ok(ce) => ce,
68            Err(e) => {
69                tracing::warn!(
70                    target: "reranker.fallback",
71                    from = "neural",
72                    to = "lexical",
73                    reason = %e,
74                    "cross-encoder fell back to lexical: neural init failed"
75                );
76                eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
77                Self::Lexical
78            }
79        }
80    }
81
82    fn load_neural() -> Result<Self> {
83        let device = Device::Cpu;
84
85        let api = Api::new().context("failed to init HuggingFace Hub API")?;
86        let repo = api.repo(Repo::new(
87            CROSS_ENCODER_MODEL_ID.to_string(),
88            RepoType::Model,
89        ));
90
91        let config_path = repo
92            .get("config.json")
93            .context("failed to download config.json")?;
94        let tokenizer_path = repo
95            .get("tokenizer.json")
96            .context("failed to download tokenizer.json")?;
97        let weights_path = repo
98            .get("model.safetensors")
99            .context("failed to download model.safetensors")?;
100
101        // Load BERT config
102        let config_data = std::fs::read_to_string(&config_path)
103            .context("failed to read cross-encoder config.json")?;
104        let config: BertConfig = serde_json::from_str(&config_data)
105            .context("failed to parse cross-encoder config.json")?;
106
107        // Load tokenizer
108        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
109            .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
110        let truncation = tokenizers::TruncationParams {
111            max_length: CROSS_ENCODER_MAX_SEQ,
112            ..Default::default()
113        };
114        tokenizer
115            .with_truncation(Some(truncation))
116            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
117        tokenizer.with_padding(None);
118
119        // Load model weights
120        let vb = unsafe {
121            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
122                .context("failed to load cross-encoder weights")?
123        };
124
125        let model = BertModel::load(vb.clone(), &config)
126            .context("failed to build cross-encoder BertModel")?;
127
128        // Load the classification head: classifier.weight [1, hidden_dim] and classifier.bias [1]
129        let classifier_weight = vb
130            .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
131            .context("failed to load classifier.weight")?;
132        let classifier_bias = vb
133            .get(1, "classifier.bias")
134            .context("failed to load classifier.bias")?;
135
136        Ok(Self::Neural {
137            model: Arc::new(Mutex::new(model)),
138            tokenizer: Arc::new(tokenizer),
139            classifier_weight,
140            classifier_bias,
141            device,
142        })
143    }
144
145    /// Score a single (query, document) pair.
146    ///
147    /// Returns a relevance score in `0.0..=1.0`.
148    pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
149        match self {
150            Self::Lexical => lexical_score(query, title, content),
151            Self::Neural {
152                model,
153                tokenizer,
154                classifier_weight,
155                classifier_bias,
156                device,
157            } => {
158                let model_guard = match model.lock() {
159                    Ok(g) => g,
160                    Err(e) => {
161                        tracing::warn!("cross-encoder model lock poisoned: {e}");
162                        return lexical_score(query, title, content);
163                    }
164                };
165                match Self::neural_score(
166                    &model_guard,
167                    tokenizer,
168                    classifier_weight,
169                    classifier_bias,
170                    device,
171                    query,
172                    title,
173                    content,
174                ) {
175                    Ok(s) => s,
176                    Err(e) => {
177                        tracing::warn!(
178                            "neural cross-encoder score failed: {e}, using lexical fallback"
179                        );
180                        lexical_score(query, title, content)
181                    }
182                }
183            }
184        }
185    }
186
187    #[allow(clippy::too_many_arguments)]
188    fn neural_score(
189        model: &BertModel,
190        tokenizer: &Tokenizer,
191        classifier_weight: &Tensor,
192        classifier_bias: &Tensor,
193        device: &Device,
194        query: &str,
195        title: &str,
196        content: &str,
197    ) -> Result<f32> {
198        // Cross-encoder input: "[CLS] query [SEP] title content [SEP]"
199        let document = format!("{title} {content}");
200
201        let encoding = tokenizer
202            .encode((query, document.as_str()), true)
203            .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
204
205        let input_ids = encoding.get_ids();
206        let attention_mask = encoding.get_attention_mask();
207        let token_type_ids = encoding.get_type_ids();
208        let seq_len = input_ids.len();
209
210        let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
211        let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
212        let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
213
214        // Forward pass through BERT → [1, seq_len, 384]
215        let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
216
217        // Take [CLS] token (first token) → [1, 384]
218        let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
219
220        // Classification head: logit = cls @ weight^T + bias → [1, 1]
221        let logit = cls
222            .matmul(&classifier_weight.t()?)?
223            .broadcast_add(classifier_bias)?;
224
225        // Extract scalar logit and apply sigmoid to get [0, 1] score
226        let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
227        let score = 1.0 / (1.0 + (-logit_val).exp());
228
229        Ok(score)
230    }
231
232    /// Whether this is a neural cross-encoder.
233    pub fn is_neural(&self) -> bool {
234        matches!(self, Self::Neural { .. })
235    }
236
237    /// Rerank a set of candidates by blending their original scores with
238    /// cross-encoder scores.
239    ///
240    /// **Blend formula:** `final = 0.6 * original + 0.4 * cross_encoder`
241    ///
242    /// Results are returned sorted by `final_score` descending.
243    pub fn rerank(&self, query: &str, mut candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
244        let mut scored: Vec<(Memory, f64)> = candidates
245            .drain(..)
246            .map(|(mem, original_score)| {
247                let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
248                let final_score =
249                    ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
250                (mem, final_score)
251            })
252            .collect();
253
254        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
255        scored
256    }
257}
258
259impl Default for CrossEncoder {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Lexical cross-encoder (original implementation)
267// ---------------------------------------------------------------------------
268
/// Score a (query, title + content) pair from lexical overlap signals alone.
///
/// Tokens are case-folded before any comparison, so all four signals agree on
/// what counts as a match. Previously only the TF-IDF component ignored case,
/// so a query `network` earned no title bonus against the title
/// `Network Configuration Guide`.
///
/// Returns `0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap +
/// 0.20 * title_bonus`, clamped to `0.0..=1.0`. Returns `0.0` when the query
/// contains no tokens.
fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
    let query_terms = folded_tokens(query);
    if query_terms.is_empty() {
        return 0.0;
    }
    let title_terms = folded_tokens(title);
    let content_terms = folded_tokens(content);

    let query_all: Vec<&str> = query_terms.iter().map(String::as_str).collect();
    let doc_all: Vec<&str> = title_terms
        .iter()
        .chain(content_terms.iter())
        .map(String::as_str)
        .collect();

    let query_set: HashSet<&str> = query_all.iter().copied().collect();
    let doc_set: HashSet<&str> = doc_all.iter().copied().collect();
    let title_set: HashSet<&str> = title_terms.iter().map(String::as_str).collect();

    // 1. Jaccard term overlap. The union always contains the (non-empty) query
    //    set, so the denominator is at least 1.
    #[allow(clippy::cast_precision_loss)]
    let jaccard = query_set.intersection(&doc_set).count() as f32
        / query_set.union(&doc_set).count() as f32;

    // 2. TF-IDF-like term weighting.
    let tf_idf = tfidf_score(&query_all, &doc_all);

    // 3. Bigram overlap: share of query bigrams that also occur in the document.
    let query_bigrams = bigrams(&query_all);
    #[allow(clippy::cast_precision_loss)]
    let bigram_overlap = if query_bigrams.is_empty() {
        0.0
    } else {
        let doc_bigram_set: HashSet<(&str, &str)> = bigrams(&doc_all).into_iter().collect();
        let hits = query_bigrams
            .iter()
            .filter(|&pair| doc_bigram_set.contains(pair))
            .count();
        hits as f32 / query_bigrams.len() as f32
    };

    // 4. Title bonus: share of distinct query terms present in the title.
    #[allow(clippy::cast_precision_loss)]
    let title_bonus =
        query_set.intersection(&title_set).count() as f32 / query_set.len() as f32;

    let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
    raw.clamp(0.0, 1.0)
}

/// Tokenize `text` and lowercase every token for case-insensitive matching.
fn folded_tokens(text: &str) -> Vec<String> {
    tokenize(text).into_iter().map(str::to_lowercase).collect()
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Split `text` into words: runs of alphanumeric characters and apostrophes
/// (so `don't` stays one token). Empty pieces are dropped.
fn tokenize(text: &str) -> Vec<&str> {
    text.split(|c: char| !c.is_alphanumeric() && c != '\'')
        .filter(|w| !w.is_empty())
        .collect()
}

/// TF-IDF-like score of `doc_tokens` against `query_terms`, in `0.0..=1.0`.
///
/// Matching is case-insensitive. For each query term:
/// - `tf` is the total count of all its spellings in the document, divided by
///   the document length.
/// - The "idf" factor is `ln(unique / (1 + spellings)) + 1`, where `unique` is
///   the number of distinct (case-sensitive) document tokens and `spellings` is
///   how many distinct spellings of the term occur.
///
/// The per-term products are averaged over all query terms (duplicates
/// included) and clamped. Returns `0.0` if either input is empty.
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Exact (case-sensitive) term frequencies.
    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }

    // Case-folded view built once: lowercase term -> (occurrences across all
    // spellings, number of distinct spellings). A single lookup per query term
    // replaces rescanning and re-lowercasing every map key for each term.
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (&tok, &count) in &tf_map {
        let slot = folded.entry(tok.to_lowercase()).or_insert((0, 0));
        slot.0 += count;
        slot.1 += 1;
    }

    #[allow(clippy::cast_precision_loss)]
    let total = doc_tokens.len() as f32;
    #[allow(clippy::cast_precision_loss)]
    let unique = tf_map.len() as f32;

    let mut score_sum: f32 = 0.0;
    for qt in query_terms {
        if let Some(&(occurrences, spellings)) = folded.get(&qt.to_lowercase()) {
            #[allow(clippy::cast_precision_loss)]
            let tf_norm = occurrences as f32 / total;
            #[allow(clippy::cast_precision_loss)]
            let idf = (unique / (1.0 + spellings as f32)).ln() + 1.0;
            score_sum += tf_norm * idf;
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}

/// Adjacent token pairs of `tokens` (`n - 1` pairs for `n >= 2`, none otherwise).
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    tokens.windows(2).map(|w| (w[0], w[1])).collect()
}
392
393// ---------------------------------------------------------------------------
394// Tests
395// ---------------------------------------------------------------------------
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Memory, Tier};

    /// Build a `Memory` with the given id, namespace, title and content; every
    /// other field gets a fixed test default.
    fn make_memory_with(id: &str, namespace: &str, title: &str, content: &str) -> Memory {
        Memory {
            id: id.to_string(),
            tier: Tier::Mid,
            namespace: namespace.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    /// Build a `Memory` in the "test" namespace with id "test-id".
    fn make_memory(title: &str, content: &str) -> Memory {
        make_memory_with("test-id", "test", title, content)
    }

    #[test]
    fn lexical_score_returns_zero_for_empty_query() {
        assert_eq!(lexical_score("", "some title", "some content"), 0.0);
    }

    #[test]
    fn lexical_score_returns_zero_for_no_overlap() {
        let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
        assert!(s < 0.05, "expected near-zero, got {s}");
    }

    #[test]
    fn lexical_score_rewards_title_match() {
        let content = "This document discusses network configuration for LAN setups.";
        let s_title_match = lexical_score(
            "network configuration",
            "Network Configuration Guide",
            content,
        );
        let s_no_title = lexical_score("network configuration", "Unrelated Title", content);
        assert!(
            s_title_match > s_no_title,
            "title match ({s_title_match}) should beat no title match ({s_no_title})"
        );
    }

    #[test]
    fn lexical_score_is_bounded_zero_one() {
        let s = lexical_score(
            "the quick brown fox jumps over the lazy dog",
            "the quick brown fox",
            "the quick brown fox jumps over the lazy dog and more words",
        );
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn rerank_reorders_candidates() {
        let ce = CrossEncoder::new();
        let a = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
        let b = make_memory("Grocery list", "milk eggs bread butter cheese");
        let candidates = vec![(b, 0.55), (a, 0.45)];
        let reranked = ce.rerank("cross-encoder reranking", candidates);
        assert_eq!(reranked[0].0.title, "Rust cross-encoder");
    }

    #[test]
    fn rerank_preserves_candidate_count() {
        let ce = CrossEncoder::new();
        let candidates = vec![
            (make_memory("A", "alpha"), 0.5),
            (make_memory("B", "beta"), 0.6),
            (make_memory("C", "gamma"), 0.7),
        ];
        let reranked = ce.rerank("alpha", candidates);
        assert_eq!(reranked.len(), 3);
    }

    #[test]
    fn rerank_equal_final_scores_keep_input_order() {
        // Neither document shares a token with the query, so both get a CE
        // score of 0.0 and identical blended scores; the sort must be stable.
        let ce = CrossEncoder::new();
        let first = make_memory("first", "shared body");
        let second = make_memory("second", "shared body");
        let out = ce.rerank("zzz", vec![(first, 0.5), (second, 0.5)]);
        let titles: Vec<&str> = out.iter().map(|(m, _)| m.title.as_str()).collect();
        assert_eq!(titles, vec!["first", "second"]);
    }

    #[test]
    fn bigram_overlap_boosts_phrase_match() {
        let s_phrase = lexical_score(
            "network adapter",
            "title",
            "the network adapter is connected to the LAN",
        );
        let s_scattered = lexical_score(
            "network adapter",
            "title",
            "the adapter handles the network traffic independently",
        );
        assert!(
            s_phrase > s_scattered,
            "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
        );
    }

    // -----------------------------------------------------------------
    // W11/S11b — input-count invariants for the rerank() API
    // -----------------------------------------------------------------

    #[test]
    fn test_rerank_preserves_input_count_heuristic() {
        let ce = CrossEncoder::new();
        // Build 5 distinct candidates with varied original scores.
        let candidates: Vec<(Memory, f64)> = (0..5)
            .map(|i| {
                (
                    make_memory(
                        &format!("title {i}"),
                        &format!("content body number {i} with some words"),
                    ),
                    f64::from(i) * 0.1,
                )
            })
            .collect();
        let query = "title content body";
        let reranked = ce.rerank(query, candidates);
        assert_eq!(
            reranked.len(),
            5,
            "heuristic rerank must preserve candidate count, got {} = {:?}",
            reranked.len(),
            reranked
                .iter()
                .map(|(m, s)| (&m.title, *s))
                .collect::<Vec<_>>()
        );
        // Sorted descending by final score (rerank contract).
        for w in reranked.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "rerank output must be descending by score: {} < {}",
                w[0].1,
                w[1].1
            );
        }
    }

    #[test]
    fn test_rerank_zero_candidates_returns_empty_heuristic() {
        let ce = CrossEncoder::new();
        let reranked = ce.rerank("query", Vec::new());
        assert!(reranked.is_empty());
    }

    // Neural variant: gated so default test runs do not pull ~80 MB of BERT
    // weights. Run with `--features test-with-models` (needs network access or
    // a warm Hugging Face cache).
    #[cfg(feature = "test-with-models")]
    #[test]
    fn test_rerank_preserves_input_count_neural_if_available() {
        let ce = CrossEncoder::new_neural();
        let candidates: Vec<(Memory, f64)> = (0..5)
            .map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
            .collect();
        let reranked = ce.rerank("body", candidates);
        assert_eq!(reranked.len(), 5);
    }

    // -----------------------------------------------------------------
    // W12-E — heuristic-path branch coverage for reranker.rs
    //
    // Targets the Lexical variant only. The Neural variant requires
    // downloading 80+ MB of BERT weights from HuggingFace Hub and is
    // gated behind `feature = "test-with-models"`.
    // -----------------------------------------------------------------

    #[test]
    fn w12e_default_is_lexical() {
        let ce = CrossEncoder::default();
        assert!(!ce.is_neural(), "Default::default() must return Lexical");
    }

    #[test]
    fn w12e_new_returns_lexical() {
        let ce = CrossEncoder::new();
        assert!(!ce.is_neural());
    }

    #[test]
    fn w12e_score_dispatch_lexical_matches_helper() {
        // The CrossEncoder::score() dispatcher must delegate to lexical_score()
        // for the Lexical variant. Both calls are deterministic, so the results
        // must be exactly equal.
        let ce = CrossEncoder::new();
        let q = "rust async runtime";
        let title = "Tokio: Rust async runtime";
        let content = "Tokio is an async runtime for the Rust programming language.";
        let via_dispatcher = ce.score(q, title, content);
        let direct = lexical_score(q, title, content);
        assert_eq!(via_dispatcher, direct);
    }

    #[test]
    fn w12e_score_empty_inputs_safe() {
        let ce = CrossEncoder::new();
        // Empty query → 0.0 by short-circuit in lexical_score
        assert_eq!(ce.score("", "title", "content"), 0.0);
        // Empty title and content with non-empty query — must not panic
        let s = ce.score("query", "", "");
        assert!((0.0..=1.0).contains(&s));
        // Whitespace-only query treated as empty after tokenization
        let s_ws = ce.score("   \t\n", "title", "content");
        assert_eq!(s_ws, 0.0);
        // Punctuation-only query also yields no tokens
        let s_punct = ce.score("!?.,;:", "title", "content");
        assert_eq!(s_punct, 0.0);
    }

    #[test]
    fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
        // Mixed Unicode tokens with apostrophes, accents, emoji boundaries.
        let s_unicode = lexical_score(
            "café résumé d'oeuvre",
            "Le Café d'Oeuvre",
            "résumé du café avec d'oeuvre noté",
        );
        assert!(
            (0.0..=1.0).contains(&s_unicode),
            "unicode score {s_unicode} out of bounds"
        );

        // Very long content stresses the length-normalization branches.
        let huge = "alpha beta gamma delta ".repeat(2_500);
        let s_long = lexical_score("alpha gamma", "headline", &huge);
        assert!(
            (0.0..=1.0).contains(&s_long),
            "long score {s_long} out of bounds"
        );
    }

    #[test]
    fn w12e_lexical_score_perfect_overlap_high() {
        // 100% query overlap with title and content should produce a high
        // (but bounded) score.
        let s = lexical_score(
            "alpha beta gamma",
            "alpha beta gamma",
            "alpha beta gamma alpha beta gamma",
        );
        assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
        assert!(s <= 1.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_doc_returns_zero() {
        // Branch: doc_tokens.is_empty() → 0.0 short-circuit.
        let q = vec!["alpha", "beta"];
        let doc: Vec<&str> = Vec::new();
        assert_eq!(tfidf_score(&q, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_query_returns_zero() {
        // Branch: query_terms.is_empty() → 0.0 short-circuit.
        let q: Vec<&str> = Vec::new();
        let doc = vec!["alpha", "beta", "gamma"];
        assert_eq!(tfidf_score(&q, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_no_matching_terms() {
        // Query terms entirely absent from doc → no per-term contribution.
        let q = vec!["xenon", "kryptonite"];
        let doc = vec!["alpha", "beta", "gamma"];
        let s = tfidf_score(&q, &doc);
        assert_eq!(s, 0.0);
    }

    #[test]
    fn w12e_tfidf_score_partial_match_bounded() {
        // Mixed presence/absence; clamp branch reachable.
        let q = vec!["alpha", "missing"];
        let doc = vec!["alpha", "alpha", "beta", "gamma"];
        let s = tfidf_score(&q, &doc);
        assert!((0.0..=1.0).contains(&s));
        assert!(s > 0.0);
    }

    #[test]
    fn w12e_bigrams_empty_and_single_and_multi() {
        // Empty input → empty bigram list.
        let empty: Vec<&str> = Vec::new();
        assert!(bigrams(&empty).is_empty());

        // Single token → no bigrams (windows(2) yields nothing).
        let one = vec!["solo"];
        assert!(bigrams(&one).is_empty());

        // Multi-token → N-1 bigrams.
        let three = vec!["a", "b", "c"];
        let bg = bigrams(&three);
        assert_eq!(bg, vec![("a", "b"), ("b", "c")]);
    }

    #[test]
    fn w12e_tokenize_handles_apostrophe_and_unicode() {
        // Apostrophes are preserved (e.g., "don't"), other punctuation splits.
        let toks = tokenize("don't stop, I won't!");
        assert!(toks.contains(&"don't"));
        assert!(toks.contains(&"won't"));
        assert!(toks.contains(&"stop"));
        assert!(toks.contains(&"I"));

        // Pure-punctuation yields no tokens.
        let none = tokenize("!!!,,,;;;");
        assert!(none.is_empty());

        // Empty string yields no tokens.
        let empty = tokenize("");
        assert!(empty.is_empty());

        // Unicode alphanumerics survive (café = 4 alphanumeric chars).
        let unicode = tokenize("café résumé");
        assert_eq!(unicode.len(), 2);
    }

    #[test]
    fn w12e_rerank_single_candidate_keeps_it() {
        let ce = CrossEncoder::new();
        let only = make_memory("solo title", "solo content body");
        let out = ce.rerank("solo", vec![(only, 0.42)]);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0.title, "solo title");
        // Final score is a blend of original and CE score, both nonneg.
        assert!(out[0].1 >= 0.0);
    }

    #[test]
    fn w12e_rerank_identical_originals_stable_under_score() {
        // When original scores are identical, ordering is determined by the
        // CE score. The candidate whose title/content overlaps the query
        // should rank first.
        let ce = CrossEncoder::new();
        let on_topic = make_memory("rust async runtime", "rust async runtime tokio");
        let off_topic = make_memory("grocery", "milk eggs bread");
        let out = ce.rerank("rust async", vec![(off_topic, 0.5), (on_topic, 0.5)]);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0.title, "rust async runtime");
    }

    #[test]
    fn w12e_rerank_descending_invariant_holds_across_shapes() {
        // Property-style: irrespective of input shape, output is sorted desc.
        let ce = CrossEncoder::new();
        let cands: Vec<(Memory, f64)> = vec![
            (make_memory("a", "alpha words"), 0.10),
            (make_memory("b", "beta words"), 0.95),
            (make_memory("c", "gamma alpha"), 0.55),
            (make_memory("d", ""), 0.0),
            (make_memory("", "empty title doc"), 0.30),
        ];
        let out = ce.rerank("alpha", cands);
        assert_eq!(out.len(), 5);
        for w in out.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "non-descending pair: {} then {}",
                w[0].1,
                w[1].1
            );
        }
    }

    #[test]
    fn w12e_lexical_score_no_title_branch_via_empty_title() {
        // Empty title means title_set is empty, so title_bonus == 0.0, while
        // the non-empty query still exercises the title-bonus division.
        let s_empty_title = lexical_score("alpha beta", "", "alpha beta gamma");
        let s_with_title = lexical_score("alpha beta", "alpha beta", "alpha beta gamma");
        assert!(s_with_title >= s_empty_title);
        assert!((0.0..=1.0).contains(&s_empty_title));
    }

    #[test]
    fn w12e_lexical_score_query_terms_only_in_title() {
        // Title contains all query terms; content has none.
        let s = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
        assert!(s > 0.0);
        assert!(s <= 1.0);
    }

    // PR-9i — buffer coverage uplift.
    //
    // NOTE: these two tests are NOT feature-gated. With network access (or a
    // warm HF cache) `new_neural()` loads the real model, which may download
    // ~80 MB on first run.

    #[test]
    fn pr9i_new_neural_dual_outcome() {
        // Exercises `CrossEncoder::new_neural()`. Behavior is
        // environment-dependent: with an HF cache or network the call
        // succeeds and returns Self::Neural; without either it falls back
        // to Self::Lexical via the documented eprintln + tracing warn
        // pathway. Both outcomes are acceptable — what matters is the
        // dispatch is hit. Functionally, both variants score within
        // [0.0, 1.0].
        let ce = CrossEncoder::new_neural();
        let s = ce.score("query", "title", "content");
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn pr9i_rerank_via_score_returns_blend() {
        // Even when new_neural() falls back to lexical, rerank() must
        // still produce a deterministic [0..1] blend. Pins the contract
        // for both branches of CrossEncoder::score().
        let ce = CrossEncoder::new_neural();
        let cands = vec![
            (
                make_memory_with("a", "ns", "rust async runtime", "tokio rust async"),
                0.6,
            ),
            (make_memory_with("b", "ns", "grocery list", "milk eggs"), 0.4),
        ];
        let out = ce.rerank("rust async", cands);
        assert_eq!(out.len(), 2);
        for (_, score) in &out {
            assert!(score.is_finite());
        }
        // First entry's blended score >= second by sort contract.
        assert!(out[0].1 >= out[1].1);
    }
}
859
#[cfg(test)]
// The mock deliberately mirrors the real `CrossEncoder` surface (`&self`
// receivers, by-value candidate vectors, glob import of the parent module),
// which trips these pedantic lints; keeping the signatures parallel matters
// more here than lint purity.
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Deterministic stand-in for [`CrossEncoder`] that never downloads or
    /// loads a BERT model.
    ///
    /// The lexical mock delegates to the real `lexical_score`; the neural mock
    /// returns a cheap hash-derived score so tests can tell the paths apart.
    #[derive(Debug, Clone, Copy)]
    pub struct MockCrossEncoder {
        /// `true` selects the hash-based "neural" scoring path.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Create a mock lexical encoder (mirrors `CrossEncoder::new()`).
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Create a mock neural encoder (mirrors `CrossEncoder::new_neural()`)
        /// without any model download.
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Deterministic mock relevance score in `[0.0, 1.0]`.
        ///
        /// * Lexical mock: exactly `lexical_score(query, title, content)`.
        /// * Neural mock: a 31-multiplier rolling hash over the bytes of
        ///   `query` followed by `title`, reduced to `[0.0, 0.999]`. When
        ///   `title` contains `query` as a (case-sensitive) substring, the
        ///   value is lifted into `[0.5, 1.0)`. `content` is ignored here.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                // Use the real scorer so the lexical mock matches production.
                return lexical_score(query, title, content);
            }

            // Hash `query` then `title` byte by byte: identical to hashing
            // their concatenation, without allocating it.
            let hash = query
                .bytes()
                .chain(title.bytes())
                .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(u32::from(b)));
            // `hash % 1000` is < 1000, so the cast to f32 is exact.
            let base = ((hash % 1000) as f32) / 1000.0;

            if title.contains(query) {
                // Boost substring matches into the upper half of the range.
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Whether this is a neural mock.
        #[must_use]
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Rerank candidates using the same blend as the real `CrossEncoder`:
        /// `ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * mock_score`.
        ///
        /// Returns the candidates sorted by blended score, highest first.
        pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
            let mut scored: Vec<(Memory, f64)> = candidates
                .into_iter()
                .map(|(mem, original_score)| {
                    let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
                    let final_score =
                        ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                    (mem, final_score)
                })
                .collect();

            // Descending by blended score; incomparable (NaN) pairs compare equal.
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            scored
        }
    }

    impl Default for MockCrossEncoder {
        fn default() -> Self {
            Self::new()
        }
    }
}
941
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use crate::models::{Memory, Tier};

    /// Minimal `Memory` fixture. The id is derived from `title` so reranked
    /// output can be traced back to its inputs.
    fn make_memory(title: &str, content: &str) -> Memory {
        Memory {
            id: format!("test-{title}"),
            tier: Tier::Mid,
            namespace: "test".to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn mock_lexical_new() {
        let ce = MockCrossEncoder::new();
        assert!(!ce.is_neural());
    }

    #[test]
    fn mock_neural_new() {
        let ce = MockCrossEncoder::new_neural();
        assert!(ce.is_neural());
    }

    #[test]
    fn mock_neural_score_deterministic() {
        let ce = MockCrossEncoder::new_neural();
        let s1 = ce.score("query", "title", "content");
        let s2 = ce.score("query", "title", "content");
        assert_eq!(s1, s2);
    }

    #[test]
    fn mock_neural_score_title_match_boost() {
        let ce = MockCrossEncoder::new_neural();
        let s_title_contains = ce.score("apple", "apple pie recipe", "delicious dessert");
        let s_no_match = ce.score("apple", "unrelated", "delicious dessert");
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    #[test]
    fn mock_neural_score_bounded() {
        let ce = MockCrossEncoder::new_neural();
        for query in &["test", "neural", "reranker", "machine learning"] {
            for title in &["a", "b", "the quick brown"] {
                let s = ce.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    #[test]
    fn mock_neural_rerank_reorders() {
        let ce = MockCrossEncoder::new_neural();
        // The grocery memory starts with the higher original score.
        let candidates = vec![
            (make_memory("grocery list", "milk eggs bread butter"), 0.3),
            (make_memory("neural network", "deep learning with transformers"), 0.2),
        ];
        let reranked = ce.rerank("neural network", candidates);
        // Neural encoder should boost the neural-network-titled memory
        assert_eq!(reranked[0].0.title, "neural network");
    }

    #[test]
    fn mock_neural_rerank_preserves_count() {
        let ce = MockCrossEncoder::new_neural();
        let candidates = vec![
            (make_memory("A", "content a"), 0.5),
            (make_memory("B", "content b"), 0.4),
            (make_memory("C", "content c"), 0.6),
        ];
        let reranked = ce.rerank("test", candidates);
        assert_eq!(reranked.len(), 3);
        // Every input survives exactly once...
        let mut titles: Vec<&str> = reranked.iter().map(|(m, _)| m.title.as_str()).collect();
        titles.sort_unstable();
        assert_eq!(titles, ["A", "B", "C"]);
        // ...and the output is ordered highest blended score first.
        assert!(reranked.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn mock_lexical_path_via_mock() {
        let ce = MockCrossEncoder::new();
        let s = ce.score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&s));
        // The lexical mock must delegate to the real scorer unchanged.
        assert_eq!(
            s,
            super::lexical_score(
                "network adapter",
                "Network Configuration",
                "the network adapter is connected",
            )
        );
    }

    #[test]
    fn mock_neural_different_from_lexical() {
        let lexical = MockCrossEncoder::new();
        let neural = MockCrossEncoder::new_neural();
        let s_lex = lexical.score("machine learning", "ML title", "neural networks");
        let s_neu = neural.score("machine learning", "ML title", "neural networks");
        // They should use different scoring formulas
        assert_ne!(s_lex, s_neu);
    }
}
1053
#[test]
fn score_handles_empty_query_string() {
    // No query terms means nothing can match, so relevance must be exactly zero.
    assert_eq!(
        lexical_score("", "Document Title", "This is document content"),
        0.0,
        "empty query must return 0.0"
    );
}
1059
#[test]
fn score_handles_unicode_normalization() {
    // Checks that an accented (non-ASCII) word and its plain-ASCII spelling
    // each score positively when the same word appears in the query, title
    // and content. It does NOT assert that the two spellings score equally,
    // and it does not build explicit composed/decomposed (NFC vs NFD) pairs.
    let accented = lexical_score("café", "café", "the café is open");
    let plain = lexical_score("cafe", "cafe", "the cafe is open");
    for s in [accented, plain] {
        assert!(s > 0.0, "matching query must score positively, got {s}");
        assert!(s <= 1.0, "score {s} must not exceed 1.0");
    }
}
1069
#[test]
fn score_handles_very_long_content_truncation() {
    // 10,000 repetitions of "word " is 50,000 bytes of content; the lexical
    // scorer must cope with oversized documents and still return a bounded score.
    let oversized_content = "word ".repeat(10_000);
    let score = lexical_score("word", "title", &oversized_content);
    let bounded = (0.0..=1.0).contains(&score);
    assert!(bounded, "score must be bounded [0, 1]");
}
1077
#[test]
fn bigram_score_with_single_token_query() {
    // A one-word query has no adjacent term pairs, so any bigram component is
    // empty; scoring must neither panic nor leave [0, 1].
    let score = lexical_score("query", "Single Token Title", "single token content");
    let bounded = (0.0..=1.0).contains(&score);
    assert!(bounded);
}