1use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex};
17
18use anyhow::{Context, Result};
19use candle_core::{Device, Tensor};
20use candle_nn::VarBuilder;
21use candle_transformers::models::bert::{BertModel, Config as BertConfig};
22use hf_hub::{Repo, RepoType, api::sync::Api};
23use tokenizers::Tokenizer;
24
25use crate::models::Memory;
26
/// Weight of the caller-supplied retrieval score in the reranking blend.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Weight of the cross-encoder score in the reranking blend. Defined as the
/// complement of `ORIGINAL_WEIGHT` so the two weights always sum to 1.0.
const CROSS_ENCODER_WEIGHT: f64 = 1.0 - ORIGINAL_WEIGHT;

/// HuggingFace Hub repository holding the neural cross-encoder.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Maximum number of tokens per (query, document) pair; longer pairs are truncated.
const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Hidden size of the encoder, used to shape the head weights when loading them.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;
35
36pub enum CrossEncoder {
38 Lexical,
40 Neural {
42 model: Arc<Mutex<BertModel>>,
43 tokenizer: Arc<Tokenizer>,
44 classifier_weight: Tensor,
45 classifier_bias: Tensor,
46 device: Device,
47 },
48}
49
50impl CrossEncoder {
51 pub fn new() -> Self {
53 Self::Lexical
54 }
55
56 pub fn new_neural() -> Self {
60 match Self::load_neural() {
61 Ok(ce) => ce,
62 Err(e) => {
63 eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
64 Self::Lexical
65 }
66 }
67 }
68
69 fn load_neural() -> Result<Self> {
70 let device = Device::Cpu;
71
72 let api = Api::new().context("failed to init HuggingFace Hub API")?;
73 let repo = api.repo(Repo::new(
74 CROSS_ENCODER_MODEL_ID.to_string(),
75 RepoType::Model,
76 ));
77
78 let config_path = repo
79 .get("config.json")
80 .context("failed to download config.json")?;
81 let tokenizer_path = repo
82 .get("tokenizer.json")
83 .context("failed to download tokenizer.json")?;
84 let weights_path = repo
85 .get("model.safetensors")
86 .context("failed to download model.safetensors")?;
87
88 let config_data = std::fs::read_to_string(&config_path)
90 .context("failed to read cross-encoder config.json")?;
91 let config: BertConfig = serde_json::from_str(&config_data)
92 .context("failed to parse cross-encoder config.json")?;
93
94 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
96 .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
97 let truncation = tokenizers::TruncationParams {
98 max_length: CROSS_ENCODER_MAX_SEQ,
99 ..Default::default()
100 };
101 tokenizer
102 .with_truncation(Some(truncation))
103 .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
104 tokenizer.with_padding(None);
105
106 let vb = unsafe {
108 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
109 .context("failed to load cross-encoder weights")?
110 };
111
112 let model = BertModel::load(vb.clone(), &config)
113 .context("failed to build cross-encoder BertModel")?;
114
115 let classifier_weight = vb
117 .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
118 .context("failed to load classifier.weight")?;
119 let classifier_bias = vb
120 .get(1, "classifier.bias")
121 .context("failed to load classifier.bias")?;
122
123 Ok(Self::Neural {
124 model: Arc::new(Mutex::new(model)),
125 tokenizer: Arc::new(tokenizer),
126 classifier_weight,
127 classifier_bias,
128 device,
129 })
130 }
131
132 pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
136 match self {
137 Self::Lexical => lexical_score(query, title, content),
138 Self::Neural {
139 model,
140 tokenizer,
141 classifier_weight,
142 classifier_bias,
143 device,
144 } => {
145 let model_guard = match model.lock() {
146 Ok(g) => g,
147 Err(e) => {
148 tracing::warn!("cross-encoder model lock poisoned: {e}");
149 return lexical_score(query, title, content);
150 }
151 };
152 match Self::neural_score(
153 &model_guard,
154 tokenizer,
155 classifier_weight,
156 classifier_bias,
157 device,
158 query,
159 title,
160 content,
161 ) {
162 Ok(s) => s,
163 Err(e) => {
164 tracing::warn!(
165 "neural cross-encoder score failed: {e}, using lexical fallback"
166 );
167 lexical_score(query, title, content)
168 }
169 }
170 }
171 }
172 }
173
174 #[allow(clippy::too_many_arguments)]
175 fn neural_score(
176 model: &BertModel,
177 tokenizer: &Tokenizer,
178 classifier_weight: &Tensor,
179 classifier_bias: &Tensor,
180 device: &Device,
181 query: &str,
182 title: &str,
183 content: &str,
184 ) -> Result<f32> {
185 let document = format!("{title} {content}");
187
188 let encoding = tokenizer
189 .encode((query, document.as_str()), true)
190 .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
191
192 let input_ids = encoding.get_ids();
193 let attention_mask = encoding.get_attention_mask();
194 let token_type_ids = encoding.get_type_ids();
195 let seq_len = input_ids.len();
196
197 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
198 let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
199 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
200
201 let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
203
204 let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
206
207 let logit = cls
209 .matmul(&classifier_weight.t()?)?
210 .broadcast_add(classifier_bias)?;
211
212 let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
214 let score = 1.0 / (1.0 + (-logit_val).exp());
215
216 Ok(score)
217 }
218
219 pub fn is_neural(&self) -> bool {
221 matches!(self, Self::Neural { .. })
222 }
223
224 pub fn rerank(&self, query: &str, mut candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
231 let mut scored: Vec<(Memory, f64)> = candidates
232 .drain(..)
233 .map(|(mem, original_score)| {
234 let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
235 let final_score =
236 ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
237 (mem, final_score)
238 })
239 .collect();
240
241 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
242 scored
243 }
244}
245
246impl Default for CrossEncoder {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
/// Heuristic relevance of `title` + `content` to `query`, in `[0.0, 1.0]`.
///
/// All tokens are lowercased first, so every signal is case-insensitive. The
/// document is the title tokens followed by the content tokens. The score blends:
/// - 30%: Jaccard similarity of the distinct query terms and document terms;
/// - 30%: the TF-style weighting from `tfidf_score`;
/// - 20%: fraction of query bigrams (adjacent term pairs) present in the document;
/// - 20%: fraction of distinct query terms that appear in the title.
///
/// Returns `0.0` when the query contains no terms.
fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
    let query_owned = lowercase_tokens(query);
    if query_owned.is_empty() {
        return 0.0;
    }
    let title_owned = lowercase_tokens(title);
    let content_owned = lowercase_tokens(content);

    let query_terms: Vec<&str> = query_owned.iter().map(String::as_str).collect();
    let title_terms: Vec<&str> = title_owned.iter().map(String::as_str).collect();
    // Title first, then content; bigrams may therefore span the boundary.
    let doc_all: Vec<&str> = title_terms
        .iter()
        .copied()
        .chain(content_owned.iter().map(String::as_str))
        .collect();

    let query_set: HashSet<&str> = query_terms.iter().copied().collect();
    let doc_set: HashSet<&str> = doc_all.iter().copied().collect();
    let title_set: HashSet<&str> = title_terms.iter().copied().collect();

    let jaccard = ratio(
        query_set.intersection(&doc_set).count(),
        query_set.union(&doc_set).count(),
    );

    let tf_idf = tfidf_score(&query_terms, &doc_all);

    let query_bigrams = bigrams(&query_terms);
    let bigram_overlap = if query_bigrams.is_empty() {
        0.0
    } else {
        let doc_bigrams: HashSet<(&str, &str)> = bigrams(&doc_all).into_iter().collect();
        let hits = query_bigrams
            .iter()
            .filter(|pair| doc_bigrams.contains(*pair))
            .count();
        ratio(hits, query_bigrams.len())
    };

    // `query_set` is non-empty here because the query produced tokens.
    let title_bonus = ratio(query_set.intersection(&title_set).count(), query_set.len());

    let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
    raw.clamp(0.0, 1.0)
}

/// Tokenizes `text` with `tokenize` and lowercases every token.
fn lowercase_tokens(text: &str) -> Vec<String> {
    tokenize(text).into_iter().map(str::to_lowercase).collect()
}

/// `numerator / denominator` as `f32`, or `0.0` when `denominator` is zero.
// Counts convert exactly below 2^24; beyond that only precision is lost.
#[allow(clippy::cast_precision_loss)]
fn ratio(numerator: usize, denominator: usize) -> f32 {
    if denominator == 0 {
        0.0
    } else {
        numerator as f32 / denominator as f32
    }
}

/// Splits `text` into words: maximal runs of alphanumeric characters and
/// apostrophes (so "don't" stays one token). Case is preserved.
fn tokenize(text: &str) -> Vec<&str> {
    text.split(|c: char| !(c.is_alphanumeric() || c == '\''))
        .filter(|word| !word.is_empty())
        .collect()
}

/// TF-style weight of `query_terms` within a single document, in `[0.0, 1.0]`.
///
/// Matching is case-insensitive. For each query term (duplicates included) that
/// occurs `tf` times in the document, it adds
/// `(tf / doc_len) * (ln(unique / (1 + variants)) + 1)`, where `unique` is the
/// number of distinct (case-sensitive) document tokens and `variants` is the
/// number of distinct spellings of the term in the document. There is no corpus,
/// so the "idf" factor is a per-document heuristic. The sum is divided by the
/// number of query terms and clamped. Returns `0.0` if either input is empty.
// Token counts convert exactly below 2^24; beyond that only precision is lost.
#[allow(clippy::cast_precision_loss)]
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }

    // Fold case once: lowercase form -> (total occurrences, distinct spellings).
    // This avoids re-lowercasing the whole map for every query term.
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (tok, &count) in &tf_map {
        let slot = folded.entry(tok.to_lowercase()).or_insert((0, 0));
        slot.0 += count;
        slot.1 += 1;
    }

    let total = doc_tokens.len() as f32;
    let unique = tf_map.len() as f32;

    let mut score_sum: f32 = 0.0;
    for qt in query_terms {
        if let Some(&(tf, variants)) = folded.get(qt.to_lowercase().as_str()) {
            let tf_norm = tf as f32 / total;
            let idf = (unique / (1.0 + variants as f32)).ln() + 1.0;
            score_sum += tf_norm * idf;
        }
    }

    (score_sum / query_terms.len() as f32).clamp(0.0, 1.0)
}

/// Adjacent token pairs in order: `[a, b, c]` -> `[(a, b), (b, c)]`.
/// Empty when there are fewer than two tokens.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    tokens.windows(2).map(|pair| (pair[0], pair[1])).collect()
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Memory, Tier};

    /// Minimal fixture: every field except `title` and `content` is a fixed placeholder.
    fn make_memory(title: &str, content: &str) -> Memory {
        let stamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: stamp.to_owned(),
            updated_at: stamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn lexical_score_returns_zero_for_empty_query() {
        let score = lexical_score("", "some title", "some content");
        assert_eq!(score, 0.0);
    }

    #[test]
    fn lexical_score_returns_zero_for_no_overlap() {
        let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
        assert!(s < 0.05, "expected near-zero, got {s}");
    }

    #[test]
    fn lexical_score_rewards_title_match() {
        let query = "network configuration";
        let body = "This document discusses network configuration for LAN setups.";
        let s_title_match = lexical_score(query, "Network Configuration Guide", body);
        let s_no_title = lexical_score(query, "Unrelated Title", body);
        assert!(
            s_title_match > s_no_title,
            "title match ({s_title_match}) should beat no title match ({s_no_title})"
        );
    }

    #[test]
    fn lexical_score_is_bounded_zero_one() {
        let sentence = "the quick brown fox jumps over the lazy dog";
        let body = format!("{sentence} and more words");
        let s = lexical_score(sentence, "the quick brown fox", &body);
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn rerank_reorders_candidates() {
        let encoder = CrossEncoder::new();
        let relevant = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
        let grocery = make_memory("Grocery list", "milk eggs bread butter cheese");
        // The grocery note starts ahead on the original score.
        let reranked =
            encoder.rerank("cross-encoder reranking", vec![(grocery, 0.55), (relevant, 0.45)]);
        assert_eq!(reranked[0].0.title, "Rust cross-encoder");
    }

    #[test]
    fn rerank_preserves_candidate_count() {
        let encoder = CrossEncoder::new();
        let candidates: Vec<(Memory, f64)> = [("A", "alpha", 0.5), ("B", "beta", 0.6), ("C", "gamma", 0.7)]
            .iter()
            .map(|&(title, content, score)| (make_memory(title, content), score))
            .collect();
        assert_eq!(encoder.rerank("alpha", candidates).len(), 3);
    }

    #[test]
    fn bigram_overlap_boosts_phrase_match() {
        let query = "network adapter";
        let s_phrase = lexical_score(query, "title", "the network adapter is connected to the LAN");
        let s_scattered =
            lexical_score(query, "title", "the adapter handles the network traffic independently");
        assert!(
            s_phrase > s_scattered,
            "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
        );
    }

    #[test]
    fn test_rerank_preserves_input_count_heuristic() {
        let encoder = CrossEncoder::new();
        let mut candidates: Vec<(Memory, f64)> = Vec::with_capacity(5);
        for i in 0..5 {
            let memory = make_memory(
                &format!("title {i}"),
                &format!("content body number {i} with some words"),
            );
            candidates.push((memory, f64::from(i) * 0.1));
        }
        let reranked = encoder.rerank("title content body", candidates);
        assert_eq!(
            reranked.len(),
            5,
            "heuristic rerank must preserve candidate count, got {} = {:?}",
            reranked.len(),
            reranked
                .iter()
                .map(|(m, s)| (&m.title, *s))
                .collect::<Vec<_>>()
        );
        for pair in reranked.windows(2) {
            let (higher, lower) = (pair[0].1, pair[1].1);
            assert!(
                higher >= lower,
                "rerank output must be descending by score: {} < {}",
                higher,
                lower
            );
        }
    }

    #[test]
    fn test_rerank_zero_candidates_returns_empty_heuristic() {
        let reranked = CrossEncoder::new().rerank("query", Vec::new());
        assert!(reranked.is_empty());
    }

    #[cfg(feature = "test-with-models")]
    #[test]
    fn test_rerank_preserves_input_count_neural_if_available() {
        let encoder = CrossEncoder::new_neural();
        let candidates: Vec<(Memory, f64)> = (0..5)
            .map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
            .collect();
        assert_eq!(encoder.rerank("body", candidates).len(), 5);
    }

    #[test]
    fn w12e_default_is_lexical() {
        let encoder = CrossEncoder::default();
        assert!(!encoder.is_neural(), "Default::default() must return Lexical");
    }

    #[test]
    fn w12e_new_returns_lexical() {
        assert!(!CrossEncoder::new().is_neural());
    }

    #[test]
    fn w12e_score_dispatch_lexical_matches_helper() {
        let (query, title, content) = (
            "rust async runtime",
            "Tokio: Rust async runtime",
            "Tokio is an async runtime for the Rust programming language.",
        );
        let via_dispatcher = CrossEncoder::new().score(query, title, content);
        let direct = lexical_score(query, title, content);
        assert!((via_dispatcher - direct).abs() < f32::EPSILON);
    }

    #[test]
    fn w12e_score_empty_inputs_safe() {
        let encoder = CrossEncoder::new();
        assert_eq!(encoder.score("", "title", "content"), 0.0);
        let empty_doc = encoder.score("query", "", "");
        assert!((0.0..=1.0).contains(&empty_doc));
        // Queries that tokenize to nothing must score exactly zero.
        for blank_query in [" \t\n", "!?.,;:"] {
            assert_eq!(encoder.score(blank_query, "title", "content"), 0.0);
        }
    }

    #[test]
    fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
        let s_unicode = lexical_score(
            "café résumé d'oeuvre",
            "Le Café d'Oeuvre",
            "résumé du café avec d'oeuvre noté",
        );
        assert!(
            (0.0..=1.0).contains(&s_unicode),
            "unicode score {s_unicode} out of bounds"
        );

        let huge = "alpha beta gamma delta ".repeat(2_500);
        let s_long = lexical_score("alpha gamma", "headline", &huge);
        assert!(
            (0.0..=1.0).contains(&s_long),
            "long score {s_long} out of bounds"
        );
    }

    #[test]
    fn w12e_lexical_score_perfect_overlap_high() {
        let terms = "alpha beta gamma";
        let body = format!("{terms} {terms}");
        let s = lexical_score(terms, terms, &body);
        assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
        assert!(s <= 1.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_doc_returns_zero() {
        let doc: [&str; 0] = [];
        assert_eq!(tfidf_score(&["alpha", "beta"], &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_query_returns_zero() {
        let query: [&str; 0] = [];
        assert_eq!(tfidf_score(&query, &["alpha", "beta", "gamma"]), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_no_matching_terms() {
        let s = tfidf_score(&["xenon", "kryptonite"], &["alpha", "beta", "gamma"]);
        assert_eq!(s, 0.0);
    }

    #[test]
    fn w12e_tfidf_score_partial_match_bounded() {
        let s = tfidf_score(&["alpha", "missing"], &["alpha", "alpha", "beta", "gamma"]);
        assert!((0.0..=1.0).contains(&s));
        assert!(s > 0.0);
    }

    #[test]
    fn w12e_bigrams_empty_and_single_and_multi() {
        let none: [&str; 0] = [];
        assert!(bigrams(&none).is_empty());
        assert!(bigrams(&["solo"]).is_empty());
        assert_eq!(bigrams(&["a", "b", "c"]), vec![("a", "b"), ("b", "c")]);
    }

    #[test]
    fn w12e_tokenize_handles_apostrophe_and_unicode() {
        let toks = tokenize("don't stop, I won't!");
        for expected in ["don't", "won't", "stop", "I"] {
            assert!(toks.contains(&expected));
        }
        assert!(tokenize("!!!,,,;;;").is_empty());
        assert!(tokenize("").is_empty());
        assert_eq!(tokenize("café résumé").len(), 2);
    }

    #[test]
    fn w12e_rerank_single_candidate_keeps_it() {
        let only = make_memory("solo title", "solo content body");
        let out = CrossEncoder::new().rerank("solo", vec![(only, 0.42)]);
        assert_eq!(out.len(), 1);
        let (memory, score) = &out[0];
        assert_eq!(memory.title, "solo title");
        assert!(*score >= 0.0);
    }

    #[test]
    fn w12e_rerank_identical_originals_stable_under_score() {
        let on_topic = make_memory("rust async runtime", "rust async runtime tokio");
        let off_topic = make_memory("grocery", "milk eggs bread");
        // Equal original scores: only the cross-encoder score can reorder them.
        let out = CrossEncoder::new().rerank("rust async", vec![(off_topic, 0.5), (on_topic, 0.5)]);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0.title, "rust async runtime");
    }

    #[test]
    fn w12e_rerank_descending_invariant_holds_across_shapes() {
        let cands: Vec<(Memory, f64)> = [
            ("a", "alpha words", 0.10),
            ("b", "beta words", 0.95),
            ("c", "gamma alpha", 0.55),
            ("d", "", 0.0),
            ("", "empty title doc", 0.30),
        ]
        .iter()
        .map(|&(title, content, score)| (make_memory(title, content), score))
        .collect();
        let out = CrossEncoder::new().rerank("alpha", cands);
        assert_eq!(out.len(), 5);
        for pair in out.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "non-descending pair: {} then {}",
                pair[0].1,
                pair[1].1
            );
        }
    }

    #[test]
    fn w12e_lexical_score_no_title_branch_via_empty_title() {
        let (query, content) = ("alpha beta", "alpha beta gamma");
        let s_empty_title = lexical_score(query, "", content);
        let s_with_title = lexical_score(query, "alpha beta", content);
        assert!(s_with_title >= s_empty_title);
        assert!((0.0..=1.0).contains(&s_empty_title));
    }

    #[test]
    fn w12e_lexical_score_query_terms_only_in_title() {
        let s = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
        assert!(s > 0.0 && s <= 1.0);
    }
}
773
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    //! Deterministic, model-free stand-in for `CrossEncoder` used by unit tests.
    use super::*;

    /// Mirrors the public API of `CrossEncoder` without loading any model.
    pub struct MockCrossEncoder {
        /// When `true`, `score` returns a deterministic hash-based pseudo-score;
        /// when `false`, it delegates to the real lexical heuristic.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Mock that delegates scoring to the lexical heuristic.
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Mock that emulates the neural backend with a hash-based score.
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Scores `(query, title, content)`.
        ///
        /// In neural mode, the score is a 31-multiplier polynomial hash of the
        /// bytes of `query` followed by `title`, reduced to `[0.0, 0.999]`. When
        /// `title` contains `query` verbatim (case-sensitive), it is lifted into
        /// `[0.5, 1.0]`. `content` is ignored in neural mode.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                return lexical_score(query, title, content);
            }
            let hash = query
                .bytes()
                .chain(title.bytes())
                .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(u32::from(b)));
            // `hash % 1000 < 1000`, so the conversion to f32 is exact.
            #[allow(clippy::cast_precision_loss)]
            let base = ((hash % 1000) as f32) / 1000.0;
            if title.contains(query) {
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// `true` when emulating the neural backend.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Blends each candidate's original score with `score` using the real
        /// reranking weights and returns the candidates best first.
        pub fn rerank(
            &self,
            query: &str,
            candidates: Vec<(Memory, f64)>,
        ) -> Vec<(Memory, f64)> {
            let mut scored: Vec<(Memory, f64)> = candidates
                .into_iter()
                .map(|(mem, original_score)| {
                    let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
                    let final_score =
                        ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                    (mem, final_score)
                })
                .collect();

            // Descending with NaNs last: a total order, as `sort_by` requires
            // (an inconsistent comparator may panic or misorder).
            scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
            });
            scored
        }
    }

    impl Default for MockCrossEncoder {
        fn default() -> Self {
            Self::new()
        }
    }
}
855
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use crate::models::{Memory, Tier};

    /// Minimal fixture: every field except `title` and `content` is a fixed placeholder.
    fn make_memory(title: &str, content: &str) -> Memory {
        let stamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: stamp.to_owned(),
            updated_at: stamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn mock_lexical_new() {
        assert!(!MockCrossEncoder::new().is_neural());
    }

    #[test]
    fn mock_neural_new() {
        assert!(MockCrossEncoder::new_neural().is_neural());
    }

    #[test]
    fn mock_neural_score_deterministic() {
        let encoder = MockCrossEncoder::new_neural();
        let first = encoder.score("query", "title", "content");
        let second = encoder.score("query", "title", "content");
        assert_eq!(first, second);
    }

    #[test]
    fn mock_neural_score_title_match_boost() {
        let encoder = MockCrossEncoder::new_neural();
        let s_title_contains = encoder.score("apple", "apple pie recipe", "delicious dessert");
        let s_no_match = encoder.score("apple", "unrelated", "delicious dessert");
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    #[test]
    fn mock_neural_score_bounded() {
        let encoder = MockCrossEncoder::new_neural();
        let queries = ["test", "neural", "reranker", "machine learning"];
        let titles = ["a", "b", "the quick brown"];
        for query in queries {
            for title in titles {
                let s = encoder.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    #[test]
    fn mock_neural_rerank_reorders() {
        let encoder = MockCrossEncoder::new_neural();
        let relevant = make_memory("neural network", "deep learning with transformers");
        let grocery = make_memory("grocery list", "milk eggs bread butter");
        let reranked = encoder.rerank("neural network", vec![(grocery, 0.3), (relevant, 0.2)]);
        assert_eq!(reranked[0].0.title, "neural network");
    }

    #[test]
    fn mock_neural_rerank_preserves_count() {
        let encoder = MockCrossEncoder::new_neural();
        let candidates: Vec<(Memory, f64)> = [
            ("A", "content a", 0.5),
            ("B", "content b", 0.4),
            ("C", "content c", 0.6),
        ]
        .iter()
        .map(|&(title, content, score)| (make_memory(title, content), score))
        .collect();
        assert_eq!(encoder.rerank("test", candidates).len(), 3);
    }

    #[test]
    fn mock_lexical_path_via_mock() {
        let s = MockCrossEncoder::new().score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&s));
    }

    #[test]
    fn mock_neural_different_from_lexical() {
        let (query, title, content) = ("machine learning", "ML title", "neural networks");
        let s_lex = MockCrossEncoder::new().score(query, title, content);
        let s_neu = MockCrossEncoder::new_neural().score(query, title, content);
        assert_ne!(s_lex, s_neu);
    }
}
967
#[test]
fn score_handles_empty_query_string() {
    let score = lexical_score("", "Document Title", "This is document content");
    assert_eq!(score, 0.0, "empty query must return 0.0");
}
973
#[test]
fn score_handles_unicode_normalization() {
    // Accented and unaccented spellings must each score positively on their own.
    let accented = lexical_score("café", "café", "the café is open");
    let plain = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(accented > 0.0);
    assert!(plain > 0.0);
}
983
#[test]
fn score_handles_very_long_content_truncation() {
    let long_content = "word ".repeat(10_000);
    let score = lexical_score("word", "title", &long_content);
    assert!((0.0..=1.0).contains(&score), "score must be bounded [0, 1]");
}
991
#[test]
fn bigram_score_with_single_token_query() {
    // A one-word query has no bigrams; the score must still stay in range.
    let score = lexical_score("query", "Single Token Title", "single token content");
    assert!((0.0..=1.0).contains(&score));
}