1use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, Mutex};
17
18use anyhow::{Context, Result};
19use candle_core::{Device, Tensor};
20use candle_nn::VarBuilder;
21use candle_transformers::models::bert::{BertModel, Config as BertConfig};
22use hf_hub::{Repo, RepoType, api::sync::Api};
23use tokenizers::Tokenizer;
24
25use crate::models::Memory;
26
/// Weight of the first-stage (retrieval) score in the blended rerank score.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Weight of the cross-encoder score in the blended rerank score; together
/// with [`ORIGINAL_WEIGHT`] the two weights sum to 1.0.
const CROSS_ENCODER_WEIGHT: f64 = 0.4;

/// HuggingFace Hub repository id of the neural cross-encoder checkpoint.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Maximum tokenised (query, document) pair length; longer pairs are truncated.
const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Hidden width of the checkpoint, i.e. the classifier's input dimension.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;
35
36pub enum CrossEncoder {
38 Lexical,
40 Neural {
42 model: Arc<Mutex<BertModel>>,
43 tokenizer: Arc<Tokenizer>,
44 classifier_weight: Tensor,
45 classifier_bias: Tensor,
46 device: Device,
47 },
48}
49
50impl CrossEncoder {
51 pub fn new() -> Self {
53 Self::Lexical
54 }
55
56 pub fn new_neural() -> Self {
66 match Self::load_neural() {
67 Ok(ce) => ce,
68 Err(e) => {
69 tracing::warn!(
70 target: "reranker.fallback",
71 from = "neural",
72 to = "lexical",
73 reason = %e,
74 "cross-encoder fell back to lexical: neural init failed"
75 );
76 eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
77 Self::Lexical
78 }
79 }
80 }
81
82 fn load_neural() -> Result<Self> {
83 let device = Device::Cpu;
84
85 let api = Api::new().context("failed to init HuggingFace Hub API")?;
86 let repo = api.repo(Repo::new(
87 CROSS_ENCODER_MODEL_ID.to_string(),
88 RepoType::Model,
89 ));
90
91 let config_path = repo
92 .get("config.json")
93 .context("failed to download config.json")?;
94 let tokenizer_path = repo
95 .get("tokenizer.json")
96 .context("failed to download tokenizer.json")?;
97 let weights_path = repo
98 .get("model.safetensors")
99 .context("failed to download model.safetensors")?;
100
101 let config_data = std::fs::read_to_string(&config_path)
103 .context("failed to read cross-encoder config.json")?;
104 let config: BertConfig = serde_json::from_str(&config_data)
105 .context("failed to parse cross-encoder config.json")?;
106
107 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
109 .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
110 let truncation = tokenizers::TruncationParams {
111 max_length: CROSS_ENCODER_MAX_SEQ,
112 ..Default::default()
113 };
114 tokenizer
115 .with_truncation(Some(truncation))
116 .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
117 tokenizer.with_padding(None);
118
119 let vb = unsafe {
121 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
122 .context("failed to load cross-encoder weights")?
123 };
124
125 let model = BertModel::load(vb.clone(), &config)
126 .context("failed to build cross-encoder BertModel")?;
127
128 let classifier_weight = vb
130 .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
131 .context("failed to load classifier.weight")?;
132 let classifier_bias = vb
133 .get(1, "classifier.bias")
134 .context("failed to load classifier.bias")?;
135
136 Ok(Self::Neural {
137 model: Arc::new(Mutex::new(model)),
138 tokenizer: Arc::new(tokenizer),
139 classifier_weight,
140 classifier_bias,
141 device,
142 })
143 }
144
145 pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
149 match self {
150 Self::Lexical => lexical_score(query, title, content),
151 Self::Neural {
152 model,
153 tokenizer,
154 classifier_weight,
155 classifier_bias,
156 device,
157 } => {
158 let model_guard = match model.lock() {
159 Ok(g) => g,
160 Err(e) => {
161 tracing::warn!("cross-encoder model lock poisoned: {e}");
162 return lexical_score(query, title, content);
163 }
164 };
165 match Self::neural_score(
166 &model_guard,
167 tokenizer,
168 classifier_weight,
169 classifier_bias,
170 device,
171 query,
172 title,
173 content,
174 ) {
175 Ok(s) => s,
176 Err(e) => {
177 tracing::warn!(
178 "neural cross-encoder score failed: {e}, using lexical fallback"
179 );
180 lexical_score(query, title, content)
181 }
182 }
183 }
184 }
185 }
186
187 #[allow(clippy::too_many_arguments)]
188 fn neural_score(
189 model: &BertModel,
190 tokenizer: &Tokenizer,
191 classifier_weight: &Tensor,
192 classifier_bias: &Tensor,
193 device: &Device,
194 query: &str,
195 title: &str,
196 content: &str,
197 ) -> Result<f32> {
198 let document = format!("{title} {content}");
200
201 let encoding = tokenizer
202 .encode((query, document.as_str()), true)
203 .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
204
205 let input_ids = encoding.get_ids();
206 let attention_mask = encoding.get_attention_mask();
207 let token_type_ids = encoding.get_type_ids();
208 let seq_len = input_ids.len();
209
210 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
211 let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
212 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
213
214 let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
216
217 let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
219
220 let logit = cls
222 .matmul(&classifier_weight.t()?)?
223 .broadcast_add(classifier_bias)?;
224
225 let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
227 let score = 1.0 / (1.0 + (-logit_val).exp());
228
229 Ok(score)
230 }
231
232 pub fn is_neural(&self) -> bool {
234 matches!(self, Self::Neural { .. })
235 }
236
237 pub fn rerank(&self, query: &str, mut candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
244 let mut scored: Vec<(Memory, f64)> = candidates
245 .drain(..)
246 .map(|(mem, original_score)| {
247 let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
248 let final_score =
249 ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
250 (mem, final_score)
251 })
252 .collect();
253
254 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
255 scored
256 }
257}
258
259impl Default for CrossEncoder {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
/// Heuristic relevance of `title` + `content` to `query`, in `[0.0, 1.0]`.
///
/// Four signals are blended (weights in parentheses):
/// - Jaccard overlap of query terms vs. document terms (0.30)
/// - length-normalised term-frequency score from [`tfidf_score`] (0.30)
/// - fraction of query bigrams that occur verbatim in the document (0.20)
/// - fraction of distinct query terms that occur in the title (0.20)
///
/// Matching is case-insensitive: all inputs are lowercased before
/// tokenising, so e.g. the query `rust crate` earns the title bonus against
/// the title `Rust Crate Index`. Returns `0.0` when the query has no tokens.
fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
    // Fold case once so every signal compares like with like. Previously only
    // `tfidf_score` ignored case, which silently disabled the title/bigram
    // signals for capitalised titles.
    let query = query.to_lowercase();
    let title = title.to_lowercase();
    let content = content.to_lowercase();

    let query_terms = tokenize(&query);
    if query_terms.is_empty() {
        return 0.0;
    }
    let title_terms = tokenize(&title);
    let content_terms = tokenize(&content);

    // Title tokens followed by content tokens, in order (bigrams depend on it).
    let doc_all: Vec<&str> = title_terms.iter().chain(&content_terms).copied().collect();
    let doc_terms: HashSet<&str> = doc_all.iter().copied().collect();
    let query_set: HashSet<&str> = query_terms.iter().copied().collect();

    let jaccard = ratio(
        query_set.intersection(&doc_terms).count(),
        query_set.union(&doc_terms).count(),
    );

    let tf_idf = tfidf_score(&query_terms, &doc_all);

    let query_bigrams = bigrams(&query_terms);
    let bigram_overlap = if query_bigrams.is_empty() {
        0.0
    } else {
        let doc_bigrams: HashSet<(&str, &str)> = bigrams(&doc_all).into_iter().collect();
        let hits = query_bigrams
            .iter()
            .filter(|pair| doc_bigrams.contains(*pair))
            .count();
        ratio(hits, query_bigrams.len())
    };

    let title_set: HashSet<&str> = title_terms.iter().copied().collect();
    let title_bonus = ratio(query_set.intersection(&title_set).count(), query_set.len());

    let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
    raw.clamp(0.0, 1.0)
}

/// `num / den` as `f32`, or `0.0` when `den` is zero.
// Counts here are token counts; f32 precision is ample for a score.
#[allow(clippy::cast_precision_loss)]
fn ratio(num: usize, den: usize) -> f32 {
    if den == 0 {
        0.0
    } else {
        num as f32 / den as f32
    }
}

/// Splits `text` into word tokens.
///
/// A token is a maximal run of alphanumeric characters and apostrophes, with
/// leading/trailing apostrophes removed. Internal apostrophes are kept
/// (`don't`), while quote marks do not stick to words (`'rust'` → `rust`).
/// Case is preserved.
fn tokenize(text: &str) -> Vec<&str> {
    text.split(|c: char| !c.is_alphanumeric() && c != '\'')
        .map(|w| w.trim_matches('\''))
        .filter(|w| !w.is_empty())
        .collect()
}

/// Single-document, case-insensitive term-frequency score in `[0.0, 1.0]`.
///
/// For each query term that occurs in `doc_tokens` (ignoring case), this adds
/// `tf / total * (ln(unique / (1 + spellings)) + 1)`. Here `tf` counts the
/// matching tokens, `total` is `doc_tokens.len()`, `unique` is the number of
/// distinct exact spellings in the document, and `spellings` is how many of
/// those spellings fold to the term. The sum is divided by the number of query
/// terms and clamped. Returns `0.0` if either input is empty.
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Occurrences per exact spelling.
    let mut raw_counts: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *raw_counts.entry(tok).or_insert(0) += 1;
    }

    // Lowercased form -> (total occurrences, distinct exact spellings). Built
    // once so each query term is a single lookup instead of a vocabulary scan.
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(raw_counts.len());
    for (tok, &count) in &raw_counts {
        let slot = folded.entry(tok.to_lowercase()).or_insert((0, 0));
        slot.0 += count;
        slot.1 += 1;
    }

    #[allow(clippy::cast_precision_loss)]
    let total = doc_tokens.len() as f32;
    #[allow(clippy::cast_precision_loss)]
    let unique = raw_counts.len() as f32;

    let mut score_sum: f32 = 0.0;
    for term in query_terms {
        if let Some(&(tf, spellings)) = folded.get(&term.to_lowercase()) {
            #[allow(clippy::cast_precision_loss)]
            let tf_norm = tf as f32 / total;
            #[allow(clippy::cast_precision_loss)]
            let idf = (unique / (1.0 + spellings as f32)).ln() + 1.0;
            score_sum += tf_norm * idf;
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}

/// Adjacent token pairs of `tokens`, in order; empty for fewer than two tokens.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    tokens.windows(2).map(|w| (w[0], w[1])).collect()
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Memory, Tier};

    /// Builds a fixture `Memory`; only the identity and text fields vary,
    /// everything else is a fixed neutral value.
    fn memory_with(id: &str, namespace: &str, title: &str, content: &str) -> Memory {
        Memory {
            id: id.to_string(),
            tier: Tier::Mid,
            namespace: namespace.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    /// Fixture in the `test` namespace with id `test-id`.
    fn make_memory(title: &str, content: &str) -> Memory {
        memory_with("test-id", "test", title, content)
    }

    #[test]
    fn lexical_score_returns_zero_for_empty_query() {
        let score = lexical_score("", "some title", "some content");
        assert_eq!(score, 0.0);
    }

    #[test]
    fn lexical_score_returns_zero_for_no_overlap() {
        let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
        assert!(s < 0.05, "expected near-zero, got {s}");
    }

    #[test]
    fn lexical_score_rewards_title_match() {
        let query = "network configuration";
        let content = "This document discusses network configuration for LAN setups.";
        let s_title_match = lexical_score(query, "Network Configuration Guide", content);
        let s_no_title = lexical_score(query, "Unrelated Title", content);
        assert!(
            s_title_match > s_no_title,
            "title match ({s_title_match}) should beat no title match ({s_no_title})"
        );
    }

    #[test]
    fn lexical_score_is_bounded_zero_one() {
        let query = "the quick brown fox jumps over the lazy dog";
        let title = "the quick brown fox";
        let body = "the quick brown fox jumps over the lazy dog and more words";
        let s = lexical_score(query, title, body);
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn rerank_reorders_candidates() {
        let encoder = CrossEncoder::new();
        let relevant = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
        let irrelevant = make_memory("Grocery list", "milk eggs bread butter cheese");
        let reranked = encoder.rerank(
            "cross-encoder reranking",
            vec![(irrelevant, 0.55), (relevant, 0.45)],
        );
        assert_eq!(reranked[0].0.title, "Rust cross-encoder");
    }

    #[test]
    fn rerank_preserves_candidate_count() {
        let encoder = CrossEncoder::new();
        let candidates: Vec<(Memory, f64)> = [("A", "alpha", 0.5), ("B", "beta", 0.6), ("C", "gamma", 0.7)]
            .iter()
            .map(|&(title, body, score)| (make_memory(title, body), score))
            .collect();
        let reranked = encoder.rerank("alpha", candidates);
        assert_eq!(reranked.len(), 3);
    }

    #[test]
    fn bigram_overlap_boosts_phrase_match() {
        let query = "network adapter";
        let s_phrase = lexical_score(query, "title", "the network adapter is connected to the LAN");
        let s_scattered = lexical_score(
            query,
            "title",
            "the adapter handles the network traffic independently",
        );
        assert!(
            s_phrase > s_scattered,
            "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
        );
    }

    #[test]
    fn test_rerank_preserves_input_count_heuristic() {
        let encoder = CrossEncoder::new();
        let mut candidates: Vec<(Memory, f64)> = Vec::with_capacity(5);
        for i in 0..5 {
            let memory = make_memory(
                &format!("title {i}"),
                &format!("content body number {i} with some words"),
            );
            candidates.push((memory, f64::from(i) * 0.1));
        }
        let reranked = encoder.rerank("title content body", candidates);
        assert_eq!(
            reranked.len(),
            5,
            "heuristic rerank must preserve candidate count, got {} = {:?}",
            reranked.len(),
            reranked
                .iter()
                .map(|(m, s)| (&m.title, *s))
                .collect::<Vec<_>>()
        );
        for pair in reranked.windows(2) {
            let (hi, lo) = (pair[0].1, pair[1].1);
            assert!(
                hi >= lo,
                "rerank output must be descending by score: {} < {}",
                hi,
                lo
            );
        }
    }

    #[test]
    fn test_rerank_zero_candidates_returns_empty_heuristic() {
        let reranked = CrossEncoder::new().rerank("query", Vec::new());
        assert!(reranked.is_empty());
    }

    /// Only runs when real model weights are available.
    #[cfg(feature = "test-with-models")]
    #[test]
    fn test_rerank_preserves_input_count_neural_if_available() {
        let encoder = CrossEncoder::new_neural();
        let candidates: Vec<(Memory, f64)> = (0..5)
            .map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
            .collect();
        assert_eq!(encoder.rerank("body", candidates).len(), 5);
    }

    #[test]
    fn w12e_default_is_lexical() {
        let encoder = CrossEncoder::default();
        assert!(!encoder.is_neural(), "Default::default() must return Lexical");
    }

    #[test]
    fn w12e_new_returns_lexical() {
        assert!(!CrossEncoder::new().is_neural());
    }

    #[test]
    fn w12e_score_dispatch_lexical_matches_helper() {
        let encoder = CrossEncoder::new();
        let (q, title, content) = (
            "rust async runtime",
            "Tokio: Rust async runtime",
            "Tokio is an async runtime for the Rust programming language.",
        );
        let via_dispatcher = encoder.score(q, title, content);
        let direct = lexical_score(q, title, content);
        assert!((via_dispatcher - direct).abs() < f32::EPSILON);
    }

    #[test]
    fn w12e_score_empty_inputs_safe() {
        let encoder = CrossEncoder::new();
        assert_eq!(encoder.score("", "title", "content"), 0.0);

        let blank_doc = encoder.score("query", "", "");
        assert!((0.0..=1.0).contains(&blank_doc));

        let whitespace_query = encoder.score(" \t\n", "title", "content");
        assert_eq!(whitespace_query, 0.0);

        let punctuation_query = encoder.score("!?.,;:", "title", "content");
        assert_eq!(punctuation_query, 0.0);
    }

    #[test]
    fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
        let s_unicode = lexical_score(
            "café résumé d'oeuvre",
            "Le Café d'Oeuvre",
            "résumé du café avec d'oeuvre noté",
        );
        assert!(
            (0.0..=1.0).contains(&s_unicode),
            "unicode score {s_unicode} out of bounds"
        );

        let huge = "alpha beta gamma delta ".repeat(2_500);
        let s_long = lexical_score("alpha gamma", "headline", &huge);
        assert!(
            (0.0..=1.0).contains(&s_long),
            "long score {s_long} out of bounds"
        );
    }

    #[test]
    fn w12e_lexical_score_perfect_overlap_high() {
        let terms = "alpha beta gamma";
        let s = lexical_score(terms, terms, "alpha beta gamma alpha beta gamma");
        assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
        assert!(s <= 1.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_doc_returns_zero() {
        let query = ["alpha", "beta"];
        let doc: [&str; 0] = [];
        assert_eq!(tfidf_score(&query, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_empty_query_returns_zero() {
        let query: [&str; 0] = [];
        let doc = ["alpha", "beta", "gamma"];
        assert_eq!(tfidf_score(&query, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_no_matching_terms() {
        let query = ["xenon", "kryptonite"];
        let doc = ["alpha", "beta", "gamma"];
        assert_eq!(tfidf_score(&query, &doc), 0.0);
    }

    #[test]
    fn w12e_tfidf_score_partial_match_bounded() {
        let query = ["alpha", "missing"];
        let doc = ["alpha", "alpha", "beta", "gamma"];
        let score = tfidf_score(&query, &doc);
        assert!((0.0..=1.0).contains(&score));
        assert!(score > 0.0);
    }

    #[test]
    fn w12e_bigrams_empty_and_single_and_multi() {
        let nothing: Vec<&str> = Vec::new();
        assert!(bigrams(&nothing).is_empty());

        let single = ["solo"];
        assert!(bigrams(&single).is_empty());

        let triple = ["a", "b", "c"];
        assert_eq!(bigrams(&triple), vec![("a", "b"), ("b", "c")]);
    }

    #[test]
    fn w12e_tokenize_handles_apostrophe_and_unicode() {
        let words = tokenize("don't stop, I won't!");
        for expected in ["don't", "won't", "stop", "I"].iter() {
            assert!(words.contains(expected));
        }

        assert!(tokenize("!!!,,,;;;").is_empty());
        assert!(tokenize("").is_empty());
        assert_eq!(tokenize("café résumé").len(), 2);
    }

    #[test]
    fn w12e_rerank_single_candidate_keeps_it() {
        let encoder = CrossEncoder::new();
        let lone = make_memory("solo title", "solo content body");
        let out = encoder.rerank("solo", vec![(lone, 0.42)]);
        assert_eq!(out.len(), 1);
        let (memory, score) = &out[0];
        assert_eq!(memory.title, "solo title");
        assert!(*score >= 0.0);
    }

    #[test]
    fn w12e_rerank_identical_originals_stable_under_score() {
        // Equal first-stage scores: the cross-encoder alone must decide the order.
        let encoder = CrossEncoder::new();
        let on_topic = make_memory("rust async runtime", "rust async runtime tokio");
        let off_topic = make_memory("grocery", "milk eggs bread");
        let out = encoder.rerank("rust async", vec![(off_topic, 0.5), (on_topic, 0.5)]);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0.title, "rust async runtime");
    }

    #[test]
    fn w12e_rerank_descending_invariant_holds_across_shapes() {
        let encoder = CrossEncoder::new();
        let shapes = [
            ("a", "alpha words", 0.10),
            ("b", "beta words", 0.95),
            ("c", "gamma alpha", 0.55),
            ("d", "", 0.0),
            ("", "empty title doc", 0.30),
        ];
        let candidates: Vec<(Memory, f64)> = shapes
            .iter()
            .map(|&(title, body, score)| (make_memory(title, body), score))
            .collect();
        let out = encoder.rerank("alpha", candidates);
        assert_eq!(out.len(), 5);
        for pair in out.windows(2) {
            let (first, second) = (pair[0].1, pair[1].1);
            assert!(
                first >= second,
                "non-descending pair: {} then {}",
                first,
                second
            );
        }
    }

    #[test]
    fn w12e_lexical_score_no_title_branch_via_empty_title() {
        let body = "alpha beta gamma";
        let s_empty_title = lexical_score("alpha beta", "", body);
        let s_with_title = lexical_score("alpha beta", "alpha beta", body);
        assert!(s_with_title >= s_empty_title);
        assert!((0.0..=1.0).contains(&s_empty_title));
    }

    #[test]
    fn w12e_lexical_score_query_terms_only_in_title() {
        let score = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn pr9i_new_neural_dual_outcome() {
        // Whichever variant `new_neural` yields, scoring must stay in range.
        let encoder = CrossEncoder::new_neural();
        let s = encoder.score("query", "title", "content");
        assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
    }

    #[test]
    fn pr9i_rerank_via_score_returns_blend() {
        let encoder = CrossEncoder::new_neural();
        let candidates = vec![
            (memory_with("a", "ns", "rust async runtime", "tokio rust async"), 0.6),
            (memory_with("b", "ns", "grocery list", "milk eggs"), 0.4),
        ];
        let out = encoder.rerank("rust async", candidates);
        assert_eq!(out.len(), 2);
        assert!(out.iter().all(|(_, score)| score.is_finite()));
        assert!(out[0].1 >= out[1].1);
    }
}
859
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Test double for `CrossEncoder` that never downloads or runs a model.
    ///
    /// With `use_neural == false` it delegates to the real `lexical_score`.
    /// With `use_neural == true` it returns a deterministic hash-based
    /// pseudo-score, boosted when the title contains the query verbatim.
    pub struct MockCrossEncoder {
        /// Selects the hash-based pseudo-neural scorer.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Mock that scores lexically.
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Mock that uses the deterministic pseudo-neural scorer.
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Returns a score in `[0.0, 1.0]`.
        ///
        /// In pseudo-neural mode the score is derived from a 31-multiplier
        /// rolling hash of `query` + `title`. If `title` contains `query` as a
        /// substring, the score is lifted into `[0.5, 1.0]`. `content` is ignored
        /// in that mode.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                return lexical_score(query, title, content);
            }
            let combined = format!("{query}{title}");
            let hash = combined.bytes().fold(0u32, |acc, b| {
                acc.wrapping_mul(31).wrapping_add(u32::from(b))
            });
            let base = ((hash % 1000) as f32) / 1000.0;
            if title.contains(query) {
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Reports which scoring mode is active.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Same blending and ordering contract as `CrossEncoder::rerank`.
        /// The final score is `ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * score`,
        /// sorted best-first (stable). A NaN score ranks last, so the comparator
        /// is a total order and `sort_by` cannot panic or misorder.
        pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
            let mut scored: Vec<(Memory, f64)> = candidates
                .into_iter()
                .map(|(mem, original_score)| {
                    let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
                    let final_score =
                        ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                    (mem, final_score)
                })
                .collect();

            scored.sort_by(|a, b| {
                a.1.is_nan()
                    .cmp(&b.1.is_nan())
                    .then_with(|| b.1.total_cmp(&a.1))
            });
            scored
        }
    }

    impl Default for MockCrossEncoder {
        fn default() -> Self {
            Self::new()
        }
    }
}
941
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use crate::models::{Memory, Tier};

    /// Fixture `Memory` with neutral metadata; only title and content vary.
    fn make_memory(title: &str, content: &str) -> Memory {
        let fixed_time = "2026-01-01T00:00:00Z";
        Memory {
            id: "test-id".to_string(),
            tier: Tier::Mid,
            namespace: "test".to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: fixed_time.to_string(),
            updated_at: fixed_time.to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn mock_lexical_new() {
        assert!(!MockCrossEncoder::new().is_neural());
    }

    #[test]
    fn mock_neural_new() {
        assert!(MockCrossEncoder::new_neural().is_neural());
    }

    #[test]
    fn mock_neural_score_deterministic() {
        let mock = MockCrossEncoder::new_neural();
        let first = mock.score("query", "title", "content");
        let second = mock.score("query", "title", "content");
        assert_eq!(first, second);
    }

    #[test]
    fn mock_neural_score_title_match_boost() {
        let mock = MockCrossEncoder::new_neural();
        let s_title_contains = mock.score("apple", "apple pie recipe", "delicious dessert");
        let s_no_match = mock.score("apple", "unrelated", "delicious dessert");
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    #[test]
    fn mock_neural_score_bounded() {
        let mock = MockCrossEncoder::new_neural();
        let queries = ["test", "neural", "reranker", "machine learning"];
        let titles = ["a", "b", "the quick brown"];
        for query in queries.iter() {
            for title in titles.iter() {
                let s = mock.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    #[test]
    fn mock_neural_rerank_reorders() {
        let mock = MockCrossEncoder::new_neural();
        let relevant = make_memory("neural network", "deep learning with transformers");
        let irrelevant = make_memory("grocery list", "milk eggs bread butter");
        let reranked = mock.rerank("neural network", vec![(irrelevant, 0.3), (relevant, 0.2)]);
        assert_eq!(reranked[0].0.title, "neural network");
    }

    #[test]
    fn mock_neural_rerank_preserves_count() {
        let mock = MockCrossEncoder::new_neural();
        let candidates: Vec<(Memory, f64)> = [("A", "content a", 0.5), ("B", "content b", 0.4), ("C", "content c", 0.6)]
            .iter()
            .map(|&(title, body, score)| (make_memory(title, body), score))
            .collect();
        assert_eq!(mock.rerank("test", candidates).len(), 3);
    }

    #[test]
    fn mock_lexical_path_via_mock() {
        let mock = MockCrossEncoder::new();
        let score = mock.score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn mock_neural_different_from_lexical() {
        let (query, title, body) = ("machine learning", "ML title", "neural networks");
        let s_lex = MockCrossEncoder::new().score(query, title, body);
        let s_neu = MockCrossEncoder::new_neural().score(query, title, body);
        assert_ne!(s_lex, s_neu);
    }
}
1053
/// An empty query has no terms to match, so the lexical score is exactly zero.
#[test]
fn score_handles_empty_query_string() {
    let doc_title = "Document Title";
    let doc_body = "This is document content";
    assert_eq!(
        lexical_score("", doc_title, doc_body),
        0.0,
        "empty query must return 0.0"
    );
}
1059
/// Accented and unaccented spellings each score above zero against matching
/// text. This does not assert that `café` and `cafe` are treated as
/// equivalent; no Unicode normalisation or accent folding is checked here.
#[test]
fn score_handles_unicode_normalization() {
    let accented = lexical_score("café", "café", "the café is open");
    let plain = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(accented > 0.0);
    assert!(plain > 0.0);
}
1069
/// A 10 000-word body must still produce a score inside `[0, 1]`.
#[test]
fn score_handles_very_long_content_truncation() {
    let long_content = "word ".repeat(10000);
    let score = lexical_score("word", "title", &long_content);
    assert!((0.0..=1.0).contains(&score), "score must be bounded [0, 1]");
}
1077
/// A one-token query has no bigrams; scoring must still stay in `[0, 1]`.
#[test]
fn bigram_score_with_single_token_query() {
    let score = lexical_score("query", "Single Token Title", "single token content");
    assert!((0.0..=1.0).contains(&score));
}