Skip to main content

ai_memory/
reranker.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Cross-encoder reranking for search results.
5//!
6//! A cross-encoder takes a (query, document) pair and produces a relevance
7//! score. This is more accurate than cosine similarity of independent
8//! embeddings but slower since it must run for each candidate.
9//!
10//! **Two implementations:**
11//! - `CrossEncoder::Lexical` — lightweight term-overlap scorer (default).
12//! - `CrossEncoder::Neural` — BERT-based cross-encoder loaded via candle
13//!   from `cross-encoder/ms-marco-MiniLM-L-6-v2` (~80 MB, ONNX-free).
14
15use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex};
17
18use anyhow::{Context, Result};
19use candle_core::{Device, Tensor};
20use candle_nn::VarBuilder;
21use candle_transformers::models::bert::{BertModel, Config as BertConfig};
22use hf_hub::{Repo, RepoType, api::sync::Api};
23use tokenizers::Tokenizer;
24
25use crate::models::Memory;
26
/// Blend weight applied to the original (embedding/FTS) score.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Blend weight applied to the cross-encoder score.
const CROSS_ENCODER_WEIGHT: f64 = 0.4;

/// HuggingFace Hub repository id the neural cross-encoder files are fetched from.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Maximum sequence length passed to the tokenizer's truncation settings.
const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Second dimension of `classifier.weight`, which is loaded as
/// `[1, CROSS_ENCODER_HIDDEN_DIM]`.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;
35
/// Cross-encoder for (query, document) relevance scoring.
pub enum CrossEncoder {
    /// Lightweight lexical cross-encoder using term overlap signals.
    Lexical,
    /// Neural BERT-based cross-encoder (ms-marco-MiniLM-L-6-v2).
    Neural {
        /// BERT encoder; the mutex is held for the duration of each scoring call.
        model: Arc<Mutex<BertModel>>,
        /// Tokenizer used to encode each (query, document) pair.
        tokenizer: Arc<Tokenizer>,
        /// Classification-head weight, loaded with shape `[1, hidden_dim]`.
        classifier_weight: Tensor,
        /// Classification-head bias, loaded with shape `[1]`.
        classifier_bias: Tensor,
        /// Device on which the input tensors are created.
        device: Device,
    },
}
49
50impl CrossEncoder {
51    /// Create a new lexical cross-encoder (no model download required).
52    pub fn new() -> Self {
53        Self::Lexical
54    }
55
56    /// Create a neural cross-encoder by downloading ms-marco-MiniLM-L-6-v2.
57    ///
58    /// Falls back to lexical if download or loading fails.
59    pub fn new_neural() -> Self {
60        match Self::load_neural() {
61            Ok(ce) => ce,
62            Err(e) => {
63                eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
64                Self::Lexical
65            }
66        }
67    }
68
69    fn load_neural() -> Result<Self> {
70        let device = Device::Cpu;
71
72        let api = Api::new().context("failed to init HuggingFace Hub API")?;
73        let repo = api.repo(Repo::new(
74            CROSS_ENCODER_MODEL_ID.to_string(),
75            RepoType::Model,
76        ));
77
78        let config_path = repo
79            .get("config.json")
80            .context("failed to download config.json")?;
81        let tokenizer_path = repo
82            .get("tokenizer.json")
83            .context("failed to download tokenizer.json")?;
84        let weights_path = repo
85            .get("model.safetensors")
86            .context("failed to download model.safetensors")?;
87
88        // Load BERT config
89        let config_data = std::fs::read_to_string(&config_path)
90            .context("failed to read cross-encoder config.json")?;
91        let config: BertConfig = serde_json::from_str(&config_data)
92            .context("failed to parse cross-encoder config.json")?;
93
94        // Load tokenizer
95        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
96            .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
97        let truncation = tokenizers::TruncationParams {
98            max_length: CROSS_ENCODER_MAX_SEQ,
99            ..Default::default()
100        };
101        tokenizer
102            .with_truncation(Some(truncation))
103            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
104        tokenizer.with_padding(None);
105
106        // Load model weights
107        let vb = unsafe {
108            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
109                .context("failed to load cross-encoder weights")?
110        };
111
112        let model = BertModel::load(vb.clone(), &config)
113            .context("failed to build cross-encoder BertModel")?;
114
115        // Load the classification head: classifier.weight [1, hidden_dim] and classifier.bias [1]
116        let classifier_weight = vb
117            .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
118            .context("failed to load classifier.weight")?;
119        let classifier_bias = vb
120            .get(1, "classifier.bias")
121            .context("failed to load classifier.bias")?;
122
123        Ok(Self::Neural {
124            model: Arc::new(Mutex::new(model)),
125            tokenizer: Arc::new(tokenizer),
126            classifier_weight,
127            classifier_bias,
128            device,
129        })
130    }
131
132    /// Score a single (query, document) pair.
133    ///
134    /// Returns a relevance score in `0.0..=1.0`.
135    pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
136        match self {
137            Self::Lexical => lexical_score(query, title, content),
138            Self::Neural {
139                model,
140                tokenizer,
141                classifier_weight,
142                classifier_bias,
143                device,
144            } => {
145                let model_guard = match model.lock() {
146                    Ok(g) => g,
147                    Err(e) => {
148                        tracing::warn!("cross-encoder model lock poisoned: {e}");
149                        return lexical_score(query, title, content);
150                    }
151                };
152                match Self::neural_score(
153                    &model_guard,
154                    tokenizer,
155                    classifier_weight,
156                    classifier_bias,
157                    device,
158                    query,
159                    title,
160                    content,
161                ) {
162                    Ok(s) => s,
163                    Err(e) => {
164                        tracing::warn!(
165                            "neural cross-encoder score failed: {e}, using lexical fallback"
166                        );
167                        lexical_score(query, title, content)
168                    }
169                }
170            }
171        }
172    }
173
174    #[allow(clippy::too_many_arguments)]
175    fn neural_score(
176        model: &BertModel,
177        tokenizer: &Tokenizer,
178        classifier_weight: &Tensor,
179        classifier_bias: &Tensor,
180        device: &Device,
181        query: &str,
182        title: &str,
183        content: &str,
184    ) -> Result<f32> {
185        // Cross-encoder input: "[CLS] query [SEP] title content [SEP]"
186        let document = format!("{title} {content}");
187
188        let encoding = tokenizer
189            .encode((query, document.as_str()), true)
190            .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
191
192        let input_ids = encoding.get_ids();
193        let attention_mask = encoding.get_attention_mask();
194        let token_type_ids = encoding.get_type_ids();
195        let seq_len = input_ids.len();
196
197        let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
198        let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
199        let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
200
201        // Forward pass through BERT → [1, seq_len, 384]
202        let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
203
204        // Take [CLS] token (first token) → [1, 384]
205        let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
206
207        // Classification head: logit = cls @ weight^T + bias → [1, 1]
208        let logit = cls
209            .matmul(&classifier_weight.t()?)?
210            .broadcast_add(classifier_bias)?;
211
212        // Extract scalar logit and apply sigmoid to get [0, 1] score
213        let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
214        let score = 1.0 / (1.0 + (-logit_val).exp());
215
216        Ok(score)
217    }
218
219    /// Whether this is a neural cross-encoder.
220    pub fn is_neural(&self) -> bool {
221        matches!(self, Self::Neural { .. })
222    }
223
224    /// Rerank a set of candidates by blending their original scores with
225    /// cross-encoder scores.
226    ///
227    /// **Blend formula:** `final = 0.6 * original + 0.4 * cross_encoder`
228    ///
229    /// Results are returned sorted by `final_score` descending.
230    pub fn rerank(&self, query: &str, mut candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
231        let mut scored: Vec<(Memory, f64)> = candidates
232            .drain(..)
233            .map(|(mem, original_score)| {
234                let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
235                let final_score =
236                    ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
237                (mem, final_score)
238            })
239            .collect();
240
241        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
242        scored
243    }
244}
245
246impl Default for CrossEncoder {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252// ---------------------------------------------------------------------------
253// Lexical cross-encoder (original implementation)
254// ---------------------------------------------------------------------------
255
256fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
257    let query_terms = tokenize(query);
258    if query_terms.is_empty() {
259        return 0.0;
260    }
261
262    let title_terms = tokenize(title);
263    let content_terms = tokenize(content);
264
265    let doc_terms: HashSet<&str> = title_terms
266        .iter()
267        .chain(content_terms.iter())
268        .copied()
269        .collect();
270    let query_set: HashSet<&str> = query_terms.iter().copied().collect();
271
272    // 1. Jaccard term overlap
273    #[allow(clippy::cast_precision_loss)]
274    let intersection = query_set.intersection(&doc_terms).count() as f32;
275    #[allow(clippy::cast_precision_loss)]
276    let union = query_set.union(&doc_terms).count() as f32;
277    let jaccard = if union > 0.0 {
278        intersection / union
279    } else {
280        0.0
281    };
282
283    // 2. TF-IDF-like term weighting
284    let doc_all: Vec<&str> = title_terms
285        .iter()
286        .chain(content_terms.iter())
287        .copied()
288        .collect();
289    let tf_idf = tfidf_score(&query_terms, &doc_all);
290
291    // 3. Bigram overlap bonus
292    let query_bigrams = bigrams(&query_terms);
293    let doc_bigrams = bigrams(&doc_all);
294    let bigram_overlap = if query_bigrams.is_empty() {
295        0.0
296    } else {
297        let doc_bigram_set: HashSet<(&str, &str)> = doc_bigrams.into_iter().collect();
298        #[allow(clippy::cast_precision_loss)]
299        let hits = query_bigrams
300            .iter()
301            .filter(|b| doc_bigram_set.contains(b))
302            .count() as f32;
303        #[allow(clippy::cast_precision_loss)]
304        let query_bigrams_len = query_bigrams.len() as f32;
305        hits / query_bigrams_len
306    };
307
308    // 4. Title match bonus
309    let title_set: HashSet<&str> = title_terms.iter().copied().collect();
310    #[allow(clippy::cast_precision_loss)]
311    let title_hits = query_set.intersection(&title_set).count() as f32;
312    #[allow(clippy::cast_precision_loss)]
313    let title_bonus = if query_set.is_empty() {
314        0.0
315    } else {
316        title_hits / query_set.len() as f32
317    };
318
319    let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
320    raw.clamp(0.0, 1.0)
321}
322
323// ---------------------------------------------------------------------------
324// Internal helpers
325// ---------------------------------------------------------------------------
326
/// Split `text` into word tokens.
///
/// A token is a maximal run of alphanumeric characters and apostrophes; any
/// other character separates tokens. Case is left untouched and empty pieces
/// are discarded.
fn tokenize(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || ch == '\'');

    let mut words = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            words.push(piece);
        }
    }
    words
}
332
/// TF-IDF-like relevance of `query_terms` against a single document.
///
/// Query terms are matched case-insensitively. For each query term found in
/// the document, the contribution is `tf_norm * idf` where
/// - `tf_norm` = occurrences of the term (all spellings) / total tokens,
/// - `idf` = `ln(unique / (1 + spellings)) + 1`, with `unique` the number of
///   distinct case-sensitive tokens and `spellings` the number of those that
///   fold to the query term.
///
/// The contributions are averaged over the number of query terms and clamped
/// to `0.0..=1.0`. Returns `0.0` when either input is empty.
// The casts below convert token counts to f32 for ratio arithmetic.
#[allow(clippy::cast_precision_loss)]
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Occurrence count per case-sensitive spelling.
    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }

    let total = doc_tokens.len() as f32;
    let unique = tf_map.len() as f32;

    // Case-folded index built once: lowercase term -> (occurrences, spellings).
    // This replaces a full lowercase scan of `tf_map` for every query term.
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (spelling, &count) in &tf_map {
        let slot = folded.entry(spelling.to_lowercase()).or_insert((0, 0));
        slot.0 += count;
        slot.1 += 1;
    }

    let mut score_sum: f32 = 0.0;
    for qt in query_terms {
        // Terms absent from the document contribute nothing.
        let Some(&(occurrences, spellings)) = folded.get(&qt.to_lowercase()) else {
            continue;
        };

        let tf_norm = occurrences as f32 / total;
        let idf = (unique / (1.0 + spellings as f32)).ln() + 1.0;
        score_sum += tf_norm * idf;
    }

    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}
375
/// Pair every token with its immediate successor, preserving order.
///
/// Produces `len - 1` pairs for two or more tokens and nothing otherwise.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    let successors = tokens.iter().skip(1);
    tokens
        .iter()
        .zip(successors)
        .map(|(&left, &right)| (left, right))
        .collect()
}
379
380// ---------------------------------------------------------------------------
381// Tests
382// ---------------------------------------------------------------------------
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Memory, Tier};

    /// Build a mid-tier fixture memory with the given title and content;
    /// every other field holds a fixed placeholder value.
    fn make_memory(title: &str, content: &str) -> Memory {
        let stamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: stamp.to_owned(),
            updated_at: stamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn lexical_score_returns_zero_for_empty_query() {
        let score = lexical_score("", "some title", "some content");
        assert_eq!(score, 0.0);
    }

    #[test]
    fn lexical_score_returns_zero_for_no_overlap() {
        let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
        assert!(s < 0.05, "expected near-zero, got {s}");
    }

    #[test]
    fn lexical_score_rewards_title_match() {
        let query = "network configuration";
        let body = "This document discusses network configuration for LAN setups.";
        let s_title_match = lexical_score(query, "Network Configuration Guide", body);
        let s_no_title = lexical_score(query, "Unrelated Title", body);
        assert!(
            s_title_match > s_no_title,
            "title match ({s_title_match}) should beat no title match ({s_no_title})"
        );
    }

    #[test]
    fn lexical_score_is_bounded_zero_one() {
        let query = "the quick brown fox jumps over the lazy dog";
        let title = "the quick brown fox";
        let body = "the quick brown fox jumps over the lazy dog and more words";
        let s = lexical_score(query, title, body);
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn rerank_reorders_candidates() {
        let encoder = CrossEncoder::new();
        let relevant = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
        let irrelevant = make_memory("Grocery list", "milk eggs bread butter cheese");
        // The irrelevant memory starts ahead on its original score.
        let ranked = encoder.rerank(
            "cross-encoder reranking",
            vec![(irrelevant, 0.55), (relevant, 0.45)],
        );
        assert_eq!(ranked[0].0.title, "Rust cross-encoder");
    }

    #[test]
    fn rerank_preserves_candidate_count() {
        let encoder = CrossEncoder::new();
        let pool: Vec<(Memory, f64)> = [("A", "alpha", 0.5), ("B", "beta", 0.6), ("C", "gamma", 0.7)]
            .into_iter()
            .map(|(title, body, score)| (make_memory(title, body), score))
            .collect();
        let ranked = encoder.rerank("alpha", pool);
        assert_eq!(ranked.len(), 3);
    }

    #[test]
    fn bigram_overlap_boosts_phrase_match() {
        let query = "network adapter";
        // Both documents contain both terms; only the first keeps them adjacent.
        let s_phrase = lexical_score(query, "title", "the network adapter is connected to the LAN");
        let s_scattered = lexical_score(
            query,
            "title",
            "the adapter handles the network traffic independently",
        );
        assert!(
            s_phrase > s_scattered,
            "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
        );
    }

    // -----------------------------------------------------------------
    // W11/S11b — rerank() must return exactly as many items as it is given
    // -----------------------------------------------------------------

    #[test]
    fn test_rerank_preserves_input_count_heuristic() {
        let encoder = CrossEncoder::new();
        // Five distinct candidates whose original scores step up by 0.1.
        let mut pool: Vec<(Memory, f64)> = Vec::with_capacity(5);
        for i in 0..5 {
            let memory = make_memory(
                &format!("title {i}"),
                &format!("content body number {i} with some words"),
            );
            pool.push((memory, f64::from(i) * 0.1));
        }
        let reranked = encoder.rerank("title content body", pool);
        assert_eq!(
            reranked.len(),
            5,
            "heuristic rerank must preserve candidate count, got {} = {:?}",
            reranked.len(),
            reranked
                .iter()
                .map(|(m, s)| (&m.title, *s))
                .collect::<Vec<_>>()
        );
        // The rerank contract: output is ordered by final score, descending.
        for pair in reranked.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "rerank output must be descending by score: {} < {}",
                pair[0].1,
                pair[1].1
            );
        }
    }

    #[test]
    fn test_rerank_zero_candidates_returns_empty_heuristic() {
        let ranked = CrossEncoder::new().rerank("query", Vec::new());
        assert!(ranked.is_empty());
    }

    // Neural variant: compiled only with `--features test-with-models` so
    // that ordinary test runs do not download the ~80 MB BERT weights.
    #[cfg(feature = "test-with-models")]
    #[test]
    fn test_rerank_preserves_input_count_neural_if_available() {
        let encoder = CrossEncoder::new_neural();
        let mut pool: Vec<(Memory, f64)> = Vec::with_capacity(5);
        for i in 0..5 {
            pool.push((make_memory(&format!("t{i}"), &format!("body {i}")), 0.5));
        }
        let ranked = encoder.rerank("body", pool);
        assert_eq!(ranked.len(), 5);
    }

    // -----------------------------------------------------------------
    // W12-E — branch coverage for the heuristic (Lexical) path
    //
    // Only the Lexical variant is exercised here. The Neural variant
    // needs 80+ MB of BERT weights from HuggingFace Hub and sits behind
    // `feature = "test-with-models"`.
    // -----------------------------------------------------------------

    #[test]
    fn w12e_default_is_lexical() {
        let encoder = CrossEncoder::default();
        assert!(!encoder.is_neural(), "Default::default() must return Lexical");
    }

    #[test]
    fn w12e_new_returns_lexical() {
        assert!(!CrossEncoder::new().is_neural());
    }

    #[test]
    fn w12e_score_dispatch_lexical_matches_helper() {
        // For the Lexical variant, score() must hand off to lexical_score();
        // both routes are evaluated on the same input and compared.
        let query = "rust async runtime";
        let title = "Tokio: Rust async runtime";
        let body = "Tokio is an async runtime for the Rust programming language.";
        let direct = lexical_score(query, title, body);
        let via_dispatcher = CrossEncoder::new().score(query, title, body);
        assert!((via_dispatcher - direct).abs() < f32::EPSILON);
    }

    #[test]
    fn w12e_score_empty_inputs_safe() {
        let encoder = CrossEncoder::new();
        // An empty query short-circuits to 0.0.
        assert_eq!(encoder.score("", "title", "content"), 0.0);
        // A non-empty query against an empty document must not panic.
        let blank_doc = encoder.score("query", "", "");
        assert!((0.0..=1.0).contains(&blank_doc));
        // Queries that tokenize to nothing behave like the empty query:
        // whitespace only, then punctuation only.
        for degenerate in ["   \t\n", "!?.,;:"] {
            assert_eq!(encoder.score(degenerate, "title", "content"), 0.0);
        }
    }

    #[test]
    fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
        // Accented letters and apostrophes inside tokens.
        let s_unicode = lexical_score(
            "café résumé d'oeuvre",
            "Le Café d'Oeuvre",
            "résumé du café avec d'oeuvre noté",
        );
        assert!(
            (0.0..=1.0).contains(&s_unicode),
            "unicode score {s_unicode} out of bounds"
        );

        // A 10 000-token document exercises the length normalization.
        let huge = "alpha beta gamma delta ".repeat(2_500);
        let s_long = lexical_score("alpha gamma", "headline", &huge);
        assert!(
            (0.0..=1.0).contains(&s_long),
            "long score {s_long} out of bounds"
        );
    }

    #[test]
    fn w12e_lexical_score_perfect_overlap_high() {
        // Every query term appears in both title and content, so the score
        // should be high while staying within the upper bound.
        let phrase = "alpha beta gamma";
        let s = lexical_score(phrase, phrase, "alpha beta gamma alpha beta gamma");
        assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
        assert!(s <= 1.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_doc_returns_zero() {
        // Empty document → early return of 0.0.
        let no_tokens: Vec<&str> = Vec::new();
        let query = vec!["alpha", "beta"];
        assert_eq!(tfidf_score(&query, &no_tokens), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_query_returns_zero() {
        // Empty query → early return of 0.0.
        let no_terms: Vec<&str> = Vec::new();
        let doc = vec!["alpha", "beta", "gamma"];
        assert_eq!(tfidf_score(&no_terms, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_no_matching_terms() {
        // No query term occurs in the document, so nothing is accumulated.
        let query = vec!["xenon", "kryptonite"];
        let doc = vec!["alpha", "beta", "gamma"];
        assert_eq!(tfidf_score(&query, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_partial_match_bounded() {
        // One term present, one absent: positive but still bounded.
        let query = vec!["alpha", "missing"];
        let doc = vec!["alpha", "alpha", "beta", "gamma"];
        let s = tfidf_score(&query, &doc);
        assert!((0.0..=1.0).contains(&s));
        assert!(s > 0.0);
    }

    #[test]
    fn w12e_bigrams_empty_and_single_and_multi() {
        // No tokens → no bigrams.
        let none: Vec<&str> = Vec::new();
        assert!(bigrams(&none).is_empty());

        // One token → still no bigrams.
        let single = vec!["solo"];
        assert!(bigrams(&single).is_empty());

        // N tokens → N-1 adjacent pairs.
        let triple = vec!["a", "b", "c"];
        assert_eq!(bigrams(&triple), vec![("a", "b"), ("b", "c")]);
    }

    #[test]
    fn w12e_tokenize_handles_apostrophe_and_unicode() {
        // Apostrophes stay inside tokens; other punctuation splits them.
        let toks = tokenize("don't stop, I won't!");
        for expected in ["don't", "won't", "stop", "I"] {
            assert!(toks.contains(&expected));
        }

        // Punctuation alone produces nothing.
        assert!(tokenize("!!!,,,;;;").is_empty());

        // Neither does the empty string.
        assert!(tokenize("").is_empty());

        // Accented letters count as alphanumeric, so both words survive.
        assert_eq!(tokenize("café résumé").len(), 2);
    }

    #[test]
    fn w12e_rerank_single_candidate_keeps_it() {
        let encoder = CrossEncoder::new();
        let lone = make_memory("solo title", "solo content body");
        let out = encoder.rerank("solo", vec![(lone, 0.42)]);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0.title, "solo title");
        // Both blended components are non-negative, so the result is too.
        assert!(out[0].1 >= 0.0);
    }

    #[test]
    fn w12e_rerank_identical_originals_stable_under_score() {
        // With equal original scores, the cross-encoder score alone decides
        // the order, so the memory overlapping the query must come first.
        let encoder = CrossEncoder::new();
        let on_topic = make_memory("rust async runtime", "rust async runtime tokio");
        let off_topic = make_memory("grocery", "milk eggs bread");
        let out = encoder.rerank("rust async", vec![(off_topic, 0.5), (on_topic, 0.5)]);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0.title, "rust async runtime");
    }

    #[test]
    fn w12e_rerank_descending_invariant_holds_across_shapes() {
        // Property-style check: a mix of empty and non-empty titles and
        // bodies must still come back sorted descending.
        let encoder = CrossEncoder::new();
        let shapes = [
            ("a", "alpha words", 0.10),
            ("b", "beta words", 0.95),
            ("c", "gamma alpha", 0.55),
            ("d", "", 0.0),
            ("", "empty title doc", 0.30),
        ];
        let cands: Vec<(Memory, f64)> = shapes
            .into_iter()
            .map(|(title, body, score)| (make_memory(title, body), score))
            .collect();
        let out = encoder.rerank("alpha", cands);
        assert_eq!(out.len(), 5);
        for pair in out.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "non-descending pair: {} then {}",
                pair[0].1,
                pair[1].1
            );
        }
    }

    #[test]
    fn w12e_lexical_score_no_title_branch_via_empty_title() {
        // An empty title contributes no title hits, so adding a matching
        // title can only keep or raise the score.
        let query = "alpha beta";
        let body = "alpha beta gamma";
        let s_empty_title = lexical_score(query, "", body);
        let s_with_title = lexical_score(query, "alpha beta", body);
        assert!(s_with_title >= s_empty_title);
        assert!((0.0..=1.0).contains(&s_empty_title));
    }

    #[test]
    fn w12e_lexical_score_query_terms_only_in_title() {
        // The query words occur (capitalized) in the title and nowhere in
        // the body; the score must still be positive and bounded.
        let s = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
        assert!(s > 0.0);
        assert!(s <= 1.0);
    }
}
773
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Mock neural cross-encoder for testing. Returns deterministic scores
    /// based on (query, title, content) without loading BERT.
    pub struct MockCrossEncoder {
        /// Selects the hash-based "neural" formula instead of `lexical_score`.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Create a mock lexical encoder (like CrossEncoder::new()).
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Create a mock neural encoder (like CrossEncoder::new_neural()).
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Mock score: deterministic value in [0, 1].
        ///
        /// The lexical mock delegates to the real `lexical_score`. The neural
        /// mock hashes the bytes of `query` followed by `title` and ignores
        /// `content`; when `title` contains `query` as a substring the value
        /// is lifted into the upper half of the range.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                // Lexical path uses the real lexical_score
                return lexical_score(query, title, content);
            }

            // Multiplicative hash over query bytes then title bytes; chaining
            // the byte iterators avoids building a concatenated String.
            let hash = query.bytes().chain(title.bytes()).fold(0u32, |acc, b| {
                acc.wrapping_mul(31).wrapping_add(u32::from(b))
            });
            let base = ((hash % 1000) as f32) / 1000.0;

            // Boost for exact title matches
            if title.contains(query) {
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Whether this is a neural mock.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Rerank candidates with the blend
        /// `ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * score`.
        ///
        /// Output is sorted by blended score, descending. NaN scores are
        /// placed after every ordered score; ties keep their input order.
        pub fn rerank(
            &self,
            query: &str,
            candidates: Vec<(Memory, f64)>,
        ) -> Vec<(Memory, f64)> {
            use std::cmp::Ordering;

            let mut scored: Vec<(Memory, f64)> = candidates
                .into_iter()
                .map(|(mem, original_score)| {
                    let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
                    let final_score =
                        ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                    (mem, final_score)
                })
                .collect();

            // Treating NaN as "equal to everything" is not a consistent order
            // and leaves the sort result unspecified, so NaN is ranked last.
            scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            });
            scored
        }
    }

    impl Default for MockCrossEncoder {
        fn default() -> Self {
            Self::new()
        }
    }
}
855
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use crate::models::{Memory, Tier};

    /// Build a `Memory` fixture; only `title` and `content` vary between calls.
    fn make_memory(title: &str, content: &str) -> Memory {
        let timestamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: timestamp.to_owned(),
            updated_at: timestamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    /// `new()` produces the non-neural mock.
    #[test]
    fn mock_lexical_new() {
        assert!(!MockCrossEncoder::new().is_neural());
    }

    /// `new_neural()` produces the neural mock.
    #[test]
    fn mock_neural_new() {
        assert!(MockCrossEncoder::new_neural().is_neural());
    }

    /// Scoring the same inputs twice gives the same value.
    #[test]
    fn mock_neural_score_deterministic() {
        let encoder = MockCrossEncoder::new_neural();
        let first = encoder.score("query", "title", "content");
        let second = encoder.score("query", "title", "content");
        assert_eq!(first, second);
    }

    /// A title containing the query outscores an unrelated title.
    ///
    /// The un-boosted score is hash-derived, so this relies on the specific
    /// fixture strings used here.
    #[test]
    fn mock_neural_score_title_match_boost() {
        let encoder = MockCrossEncoder::new_neural();
        let content = "delicious dessert";
        let s_title_contains = encoder.score("apple", "apple pie recipe", content);
        let s_no_match = encoder.score("apple", "unrelated", content);
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    /// Neural mock scores stay within `[0, 1]` for a grid of queries and titles.
    #[test]
    fn mock_neural_score_bounded() {
        let encoder = MockCrossEncoder::new_neural();
        let queries = ["test", "neural", "reranker", "machine learning"];
        let titles = ["a", "b", "the quick brown"];
        for query in queries {
            for title in titles {
                let s = encoder.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    /// The memory whose title matches the query ends up first, even though it
    /// starts with the lower original score.
    #[test]
    fn mock_neural_rerank_reorders() {
        let encoder = MockCrossEncoder::new_neural();
        let relevant = make_memory("neural network", "deep learning with transformers");
        let unrelated = make_memory("grocery list", "milk eggs bread butter");
        let reranked = encoder.rerank("neural network", vec![(unrelated, 0.3), (relevant, 0.2)]);
        // Neural encoder should boost the neural-network-titled memory
        assert_eq!(reranked[0].0.title, "neural network");
    }

    /// Reranking neither drops nor adds candidates.
    #[test]
    fn mock_neural_rerank_preserves_count() {
        let encoder = MockCrossEncoder::new_neural();
        let reranked = encoder.rerank(
            "test",
            vec![
                (make_memory("A", "content a"), 0.5),
                (make_memory("B", "content b"), 0.4),
                (make_memory("C", "content c"), 0.6),
            ],
        );
        assert_eq!(reranked.len(), 3);
    }

    /// The lexical mock returns a score within `[0, 1]`.
    #[test]
    fn mock_lexical_path_via_mock() {
        let s = MockCrossEncoder::new().score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&s));
    }

    /// Lexical and neural mocks disagree on the same input.
    #[test]
    fn mock_neural_different_from_lexical() {
        let (query, title, content) = ("machine learning", "ML title", "neural networks");
        let s_lex = MockCrossEncoder::new().score(query, title, content);
        let s_neu = MockCrossEncoder::new_neural().score(query, title, content);
        // They should use different scoring formulas
        assert_ne!(s_lex, s_neu);
    }
}
967
/// An empty query must produce a lexical score of exactly zero.
#[test]
fn score_handles_empty_query_string() {
    let (title, content) = ("Document Title", "This is document content");
    assert_eq!(
        lexical_score("", title, content),
        0.0,
        "empty query must return 0.0"
    );
}
973
/// An accented query and its plain-ASCII counterpart each score above zero
/// against a title and content containing the very same word.
#[test]
fn score_handles_unicode_normalization() {
    // The two scores are not compared with each other; each only has to be
    // positive.
    // NOTE(review): query and document appear to spell "café" identically, so
    // mixed composed/decomposed forms are not exercised here — confirm.
    let accented = lexical_score("café", "café", "the café is open");
    let ascii = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(accented > 0.0);
    assert!(ascii > 0.0);
}
983
/// Very long content still yields a score inside `[0, 1]`.
#[test]
fn score_handles_very_long_content_truncation() {
    // 10,000 repetitions of "word " = 50,000 bytes of content.
    let long_content = "word ".repeat(10_000);
    let unit_interval = 0.0..=1.0;
    let s = lexical_score("word", "title", &long_content);
    assert!(unit_interval.contains(&s), "score must be bounded [0, 1]");
}
991
/// A one-word query must score without panicking and stay inside `[0, 1]`.
#[test]
fn bigram_score_with_single_token_query() {
    // A single word cannot form a word pair.
    let unit_interval = 0.0..=1.0;
    let s = lexical_score("query", "Single Token Title", "single token content");
    assert!(unit_interval.contains(&s));
}