use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::mpsc::{Sender, sync_channel};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{Repo, RepoType, api::sync::Api};
use tokenizers::Tokenizer;
use crate::models::Memory;
/// Score bonus added to a result whose id was recently returned in the same
/// session (see `apply_session_recency_boost`).
pub const SESSION_RECENCY_BOOST: f64 = 0.05;

/// Maximum number of memory ids remembered per session.
pub const SESSION_RECENT_CAP: usize = 50;

/// Remembers, per session id, the most recently recorded memory ids.
///
/// Each session owns a deque ordered oldest-to-newest and capped at
/// [`SESSION_RECENT_CAP`] entries. Every method treats a poisoned mutex as
/// "no data": readers see an empty history and writers do nothing.
#[derive(Debug, Default)]
pub struct SessionRecallTracker {
    inner: Mutex<HashMap<String, VecDeque<String>>>,
}

impl SessionRecallTracker {
    /// Creates a tracker with no sessions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns an owned copy of the ids currently remembered for `session_id`.
    ///
    /// The set is empty when the session is unknown or the lock is poisoned.
    #[must_use]
    pub fn recent_ids(&self, session_id: &str) -> HashSet<String> {
        match self.inner.lock() {
            Ok(sessions) => match sessions.get(session_id) {
                Some(ring) => ring.iter().cloned().collect(),
                None => HashSet::new(),
            },
            Err(_) => HashSet::new(),
        }
    }

    /// Runs `f` with a membership predicate over the ids remembered for
    /// `session_id`, without copying them.
    ///
    /// The internal lock is held for the whole duration of `f`. For an unknown
    /// session or a poisoned lock the predicate always answers `false`.
    pub fn with_recent_ids<R>(
        &self,
        session_id: &str,
        f: impl FnOnce(&dyn Fn(&str) -> bool) -> R,
    ) -> R {
        let never_recent = |_id: &str| false;
        let Ok(sessions) = self.inner.lock() else {
            return f(&never_recent);
        };
        if let Some(ring) = sessions.get(session_id) {
            let seen_recently = |id: &str| ring.iter().any(|existing| existing == id);
            f(&seen_recently)
        } else {
            f(&never_recent)
        }
    }

    /// Appends `ids` to the session's history, creating the session on first use.
    ///
    /// An id that is already present is moved to the newest position instead
    /// of being duplicated; the oldest entries are evicted once the history
    /// exceeds [`SESSION_RECENT_CAP`].
    pub fn record(&self, session_id: &str, ids: impl IntoIterator<Item = String>) {
        if let Ok(mut sessions) = self.inner.lock() {
            let ring = sessions.entry(session_id.to_owned()).or_default();
            for id in ids {
                ring.retain(|existing| *existing != id);
                ring.push_back(id);
                let overflow = ring.len().saturating_sub(SESSION_RECENT_CAP);
                for _ in 0..overflow {
                    ring.pop_front();
                }
            }
        }
    }

    /// Number of sessions that currently have a history (0 if the lock is poisoned).
    #[must_use]
    pub fn session_count(&self) -> usize {
        match self.inner.lock() {
            Ok(sessions) => sessions.len(),
            Err(_) => 0,
        }
    }
}
/// Returns the process-wide tracker, i.e. the `recall_tracker` field of
/// `RuntimeContext::global()`.
#[must_use]
pub fn global_session_recall_tracker() -> &'static SessionRecallTracker {
&crate::runtime_context::RuntimeContext::global().recall_tracker
}
/// Adds [`SESSION_RECENCY_BOOST`] to every result the session has seen
/// recently, re-sorts by descending score, and records the returned ids.
///
/// When `session_id` is `None` or empty the input is returned untouched and
/// nothing is recorded. Incomparable scores (NaN) are treated as equal by the
/// sort. Every id in the output is recorded, in output order.
pub fn apply_session_recency_boost(
    results: Vec<(Memory, f64)>,
    session_id: Option<&str>,
    tracker: &SessionRecallTracker,
) -> Vec<(Memory, f64)> {
    let sid = match session_id {
        Some(s) if !s.is_empty() => s,
        _ => return results,
    };

    let mut boosted: Vec<(Memory, f64)> = tracker.with_recent_ids(sid, |is_recent| {
        let mut out = Vec::with_capacity(results.len());
        for (mem, score) in results {
            let bumped = if is_recent(&mem.id) {
                score + SESSION_RECENCY_BOOST
            } else {
                score
            };
            out.push((mem, bumped));
        }
        out
    });

    boosted.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let returned_ids = boosted.iter().map(|(mem, _)| mem.id.clone());
    tracker.record(sid, returned_ids);
    boosted
}
/// Weight of a candidate's incoming score when it is blended with the
/// cross-encoder score during reranking.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Weight of the cross-encoder score in that same blend.
const CROSS_ENCODER_WEIGHT: f64 = 0.4;
/// Passes finite scores through unchanged and maps NaN / ±infinity to
/// `f64::MIN`, so such candidates sort last in a descending ranking.
fn finite_or_floor(score: f64) -> f64 {
    if !score.is_finite() {
        return f64::MIN;
    }
    score
}
/// Splits `candidates` into the pool that will be scored and an overflow tail.
///
/// With at most [`RERANK_POOL_MAX`] candidates the input is returned as the
/// pool, in its original order, with an empty tail. Otherwise the candidates
/// are sorted by descending score (`f64::total_cmp`) and everything after the
/// first `RERANK_POOL_MAX` entries becomes the tail.
fn split_rerank_pool(
    mut candidates: Vec<(Memory, f64)>,
) -> (Vec<(Memory, f64)>, Vec<(Memory, f64)>) {
    if candidates.len() <= RERANK_POOL_MAX {
        return (candidates, Vec::new());
    }
    candidates.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
    let overflow = candidates.split_off(RERANK_POOL_MAX);
    (candidates, overflow)
}
/// Maximum number of candidates per query that are scored by the
/// cross-encoder; `split_rerank_pool` sets the rest aside as an unscored tail.
pub const RERANK_POOL_MAX: usize = 20;
/// Hugging Face Hub repository id downloaded by `CrossEncoder::load_neural`.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Short model name.
/// NOTE(review): not referenced in the visible part of this file — presumably
/// read by configuration code elsewhere in the crate; confirm.
pub(crate) const DEFAULT_RERANKER_MODEL: &str = "ms-marco-MiniLM-L-6-v2";
/// Truncation length (tokens) configured on the tokenizer in `load_neural`.
pub const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Second dimension of the `classifier.weight` tensor requested in
/// `load_neural` (shape `(1, CROSS_ENCODER_HIDDEN_DIM)`).
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;
/// Truncation length (tokens) used for batched reranking when no override
/// has been seeded through [`set_rerank_max_seq`].
pub const RERANK_MAX_SEQ_DEFAULT: usize = 256;

/// Write-once override for the batched-rerank truncation length.
static RERANK_MAX_SEQ: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Seeds the truncation length used by batched reranking.
///
/// Only the first call has any effect; later calls are silently ignored.
pub fn set_rerank_max_seq(tokens: usize) {
    // `get_or_init` stores `tokens` only while the cell is still empty, which
    // gives the same first-writer-wins behaviour as a discarded `set`.
    let _ = RERANK_MAX_SEQ.get_or_init(|| tokens);
}

/// Current truncation length: the seeded value, or [`RERANK_MAX_SEQ_DEFAULT`].
fn rerank_max_seq() -> usize {
    RERANK_MAX_SEQ
        .get()
        .copied()
        .unwrap_or(RERANK_MAX_SEQ_DEFAULT)
}
/// Default value of `ReflectionBoostConfig::boost`.
pub const DEFAULT_REFLECTION_BOOST: f32 = 1.2;
/// Default value of `ReflectionBoostConfig::per_depth_increment`.
pub const DEFAULT_REFLECTION_PER_DEPTH_INCREMENT: f32 = 0.05;
/// Default value of `ReflectionBoostConfig::max_depth_cap`.
pub const DEFAULT_REFLECTION_MAX_DEPTH_CAP: u32 = 3;
/// Parameters of the score multiplier applied to `MemoryKind::Reflection`
/// memories during reranking (see `factor_for`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ReflectionBoostConfig {
/// Base multiplier applied to every reflection memory.
pub boost: f32,
/// Extra multiplier per level of (clamped) reflection depth: the base boost
/// is scaled by `1 + per_depth_increment * depth`.
pub per_depth_increment: f32,
/// Upper bound applied to the reflection depth before it is used.
pub max_depth_cap: u32,
}
impl Default for ReflectionBoostConfig {
    /// Builds the configuration from the three `DEFAULT_REFLECTION_*` constants.
    fn default() -> Self {
        let boost = DEFAULT_REFLECTION_BOOST;
        let per_depth_increment = DEFAULT_REFLECTION_PER_DEPTH_INCREMENT;
        let max_depth_cap = DEFAULT_REFLECTION_MAX_DEPTH_CAP;
        Self {
            boost,
            per_depth_increment,
            max_depth_cap,
        }
    }
}
impl ReflectionBoostConfig {
    /// A configuration under which `factor_for` returns exactly `1.0` for
    /// every memory (unit boost, zero increment, depth capped at 0).
    #[must_use]
    pub const fn disabled() -> Self {
        Self {
            boost: 1.0,
            per_depth_increment: 0.0,
            max_depth_cap: 0,
        }
    }

    /// Score multiplier for `mem`.
    ///
    /// Memories that are not `MemoryKind::Reflection` get `1.0`. For
    /// reflections the result is `boost * (1 + per_depth_increment * depth)`,
    /// where `depth` is `mem.reflection_depth` with negative values treated
    /// as 0 and the result clamped to `max_depth_cap`.
    #[must_use]
    pub fn factor_for(&self, mem: &Memory) -> f64 {
        let is_reflection = matches!(mem.memory_kind, crate::models::MemoryKind::Reflection);
        if !is_reflection {
            return 1.0;
        }

        let raw_depth = u32::try_from(mem.reflection_depth.max(0)).unwrap_or(0);
        let capped_depth = raw_depth.min(self.max_depth_cap);
        let increment = f64::from(self.per_depth_increment);
        // Fused multiply-add: increment * depth + 1.
        let depth_scale = increment.mul_add(f64::from(capped_depth), 1.0);
        f64::from(self.boost) * depth_scale
    }
}
/// Scores (query, document) pairs for reranking, either with a lexical
/// heuristic or with a neural BERT cross-encoder.
pub enum CrossEncoder {
/// Heuristic scorer backed by `lexical_score`. `degraded` is `true` when
/// this variant was produced as a fallback from a failed neural path
/// (see `new_neural`) rather than chosen directly via `new`.
Lexical { degraded: bool },
/// Neural scorer: a BERT encoder plus a one-output linear classification
/// head applied to the first token's hidden state.
Neural {
/// Encoder built from the downloaded weights.
model: Arc<BertModel>,
/// Tokenizer loaded alongside the model.
tokenizer: Arc<Tokenizer>,
/// Classification head weight, loaded with shape `(1, CROSS_ENCODER_HIDDEN_DIM)`.
classifier_weight: Tensor,
/// Classification head bias, loaded with shape `(1,)`.
classifier_bias: Tensor,
/// Device on which input tensors are created (`Device::Cpu` in `load_neural`).
device: Device,
},
}
impl CrossEncoder {
/// Creates the lexical (heuristic) encoder, not marked as degraded.
pub fn new() -> Self {
Self::Lexical { degraded: false }
}
/// Tries to load the neural cross-encoder and never fails.
///
/// If loading fails, the error is reported through both `tracing` and
/// stderr and a lexical encoder flagged as `degraded` is returned instead.
pub fn new_neural() -> Self {
    Self::load_neural().unwrap_or_else(|e| {
        tracing::warn!(
            target: "reranker.fallback",
            from = "neural",
            to = "lexical",
            reason = %e,
            "cross-encoder fell back to lexical: neural init failed"
        );
        eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
        Self::Lexical { degraded: true }
    })
}
/// Fetches the cross-encoder files from the Hugging Face Hub and assembles
/// the `Neural` variant on the CPU device.
///
/// # Errors
/// Returns an error if the Hub API cannot be initialised, any of the three
/// files cannot be fetched, the config or tokenizer cannot be read/parsed,
/// or the weights / classifier tensors cannot be loaded.
fn load_neural() -> Result<Self> {
let device = Device::Cpu;
let api = Api::new().context("failed to init HuggingFace Hub API")?;
let repo = api.repo(Repo::new(
CROSS_ENCODER_MODEL_ID.to_string(),
RepoType::Model,
));
// File names are shared with the embeddings module.
let config_path = repo
.get(crate::embeddings::HF_CONFIG_FILE)
.context("failed to download config.json")?;
let tokenizer_path = repo
.get(crate::embeddings::HF_TOKENIZER_FILE)
.context("failed to download tokenizer.json")?;
let weights_path = repo
.get(crate::embeddings::HF_WEIGHTS_FILE)
.context("failed to download model.safetensors")?;
let config_data = std::fs::read_to_string(&config_path)
.context("failed to read cross-encoder config.json")?;
let config: BertConfig = serde_json::from_str(&config_data)
.context("failed to parse cross-encoder config.json")?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
// Truncate to CROSS_ENCODER_MAX_SEQ and disable padding on the stored
// tokenizer; the batched path configures its own padding and truncation on
// a clone.
let truncation = tokenizers::TruncationParams {
max_length: CROSS_ENCODER_MAX_SEQ,
..Default::default()
};
tokenizer
.with_truncation(Some(truncation))
.map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
tokenizer.with_padding(None);
// SAFETY: NOTE(review) — memory-mapping the weights file is only sound while
// the file is not modified or truncated for the lifetime of the mapping.
// Nothing visible in this function enforces that; confirm the Hub cache
// directory is not rewritten while the process is running.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
.context("failed to load cross-encoder weights")?
};
let model = BertModel::load(vb.clone(), &config)
.context("failed to build cross-encoder BertModel")?;
// The classification head is not part of `BertModel`, so its tensors are
// fetched by name. The weight shape is hard-coded to
// (1, CROSS_ENCODER_HIDDEN_DIM) rather than read from `config`.
let classifier_weight = vb
.get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
.context("failed to load classifier.weight")?;
let classifier_bias = vb
.get(1, "classifier.bias")
.context("failed to load classifier.bias")?;
Ok(Self::Neural {
model: Arc::new(model),
tokenizer: Arc::new(tokenizer),
classifier_weight,
classifier_bias,
device,
})
}
/// Scores one (query, document) pair.
///
/// The lexical variant returns `lexical_score` directly. The neural variant
/// runs the model and, if that fails, logs a warning and returns the
/// lexical score instead, so this method never fails.
pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
    let Self::Neural {
        model,
        tokenizer,
        classifier_weight,
        classifier_bias,
        device,
    } = self
    else {
        return lexical_score(query, title, content);
    };

    Self::neural_score(
        model,
        tokenizer,
        classifier_weight,
        classifier_bias,
        device,
        query,
        title,
        content,
    )
    .unwrap_or_else(|e| {
        tracing::warn!(
            "neural cross-encoder score failed: {e}, using lexical fallback"
        );
        lexical_score(query, title, content)
    })
}
/// Runs the neural cross-encoder on a single (query, document) pair and
/// returns the sigmoid of the classifier logit.
///
/// The document text is produced by `crate::embeddings::embedding_document`
/// from `title` and `content`.
///
/// # Errors
/// Returns an error if tokenization or any tensor operation fails.
// The model pieces are passed individually so the method can be called with
// the fields destructured from `Self::Neural`.
#[allow(clippy::too_many_arguments)]
fn neural_score(
model: &BertModel,
tokenizer: &Tokenizer,
classifier_weight: &Tensor,
classifier_bias: &Tensor,
device: &Device,
query: &str,
title: &str,
content: &str,
) -> Result<f32> {
let document = crate::embeddings::embedding_document(title, content);
// Encode query and document together as one sentence pair.
let encoding = tokenizer
.encode((query, document.as_str()), true)
.map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
let input_ids = encoding.get_ids();
let attention_mask = encoding.get_attention_mask();
let token_type_ids = encoding.get_type_ids();
let seq_len = input_ids.len();
// Add a leading batch dimension of 1.
let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
// Keep only position 0 along dimension 1 (the first token of the sequence).
let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
// Linear head: cls · Wᵀ + b.
let logit = cls
.matmul(&classifier_weight.t()?)?
.broadcast_add(classifier_bias)?;
let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
// Logistic sigmoid maps the logit into (0, 1).
let score = 1.0 / (1.0 + (-logit_val).exp());
Ok(score)
}
/// Returns `true` when this encoder is the `Neural` variant.
pub fn is_neural(&self) -> bool {
matches!(self, Self::Neural { .. })
}
/// Returns `true` only for a `Lexical` encoder whose `degraded` flag is set,
/// i.e. one created as a fallback rather than requested directly.
#[must_use]
pub fn is_degraded_lexical(&self) -> bool {
matches!(self, Self::Lexical { degraded: true })
}
/// Reranks `candidates` for `query` with the reflection boost disabled.
/// See `rerank_with_reflection_boost` for the scoring and ordering rules.
pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
self.rerank_with_reflection_boost(query, candidates, &ReflectionBoostConfig::disabled())
}
/// Reranks `candidates` for `query`, applying `boost_config` to reflections.
///
/// At most `RERANK_POOL_MAX` candidates (the top ones by incoming score) are
/// rescored; each gets
/// `(ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce) * factor`, with
/// non-finite results replaced by `f64::MIN`. The rescored pool is sorted by
/// descending score and any overflow candidates are appended after it with
/// their incoming scores unchanged.
pub fn rerank_with_reflection_boost(
    &self,
    query: &str,
    candidates: Vec<(Memory, f64)>,
    boost_config: &ReflectionBoostConfig,
) -> Vec<(Memory, f64)> {
    let (head, tail) = split_rerank_pool(candidates);
    let ce_scores = self.pair_scores(query, &head);

    let mut scored: Vec<(Memory, f64)> = Vec::with_capacity(head.len() + tail.len());
    for ((mem, original_score), ce_score) in head.into_iter().zip(ce_scores) {
        let blended =
            ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * f64::from(ce_score);
        let factor = boost_config.factor_for(&mem);
        scored.push((mem, finite_or_floor(blended * factor)));
    }

    scored.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored.extend(tail);
    scored
}
/// Returns one cross-encoder score per candidate, in candidate order.
///
/// The lexical variant scores each candidate with `lexical_score`. The
/// neural variant scores all candidates in one batch and, if the batch
/// fails, logs a warning and falls back to the lexical scores.
fn pair_scores(&self, query: &str, candidates: &[(Memory, f64)]) -> Vec<f32> {
    // Used by the lexical variant and by the neural error path.
    let score_lexically = || -> Vec<f32> {
        let mut scores = Vec::with_capacity(candidates.len());
        for (mem, _) in candidates {
            scores.push(lexical_score(query, &mem.title, &mem.content));
        }
        scores
    };

    let Self::Neural {
        model,
        tokenizer,
        classifier_weight,
        classifier_bias,
        device,
    } = self
    else {
        return score_lexically();
    };

    let mut pairs: Vec<(&str, String)> = Vec::with_capacity(candidates.len());
    for (mem, _) in candidates {
        let document = crate::embeddings::embedding_document(&mem.title, &mem.content);
        pairs.push((query, document));
    }

    Self::neural_score_pairs(
        model,
        tokenizer,
        classifier_weight,
        classifier_bias,
        device,
        pairs,
    )
    .unwrap_or_else(|e| {
        tracing::warn!(
            "neural cross-encoder batch score failed: {e}, using lexical fallback"
        );
        score_lexically()
    })
}
/// Reranks several queries at once with the reflection boost disabled.
/// See `rerank_batch_with_reflection_boost`.
pub fn rerank_batch(
&self,
queries: Vec<(String, Vec<(Memory, f64)>)>,
) -> Vec<Vec<(Memory, f64)>> {
self.rerank_batch_with_reflection_boost(queries, &ReflectionBoostConfig::disabled())
}
/// Reranks several `(query, candidates)` entries and returns one result list
/// per entry, in input order.
///
/// A single entry is delegated to `rerank_with_reflection_boost`. The lexical
/// variant reranks each entry independently. The neural variant scores the
/// candidate pools of all entries in one model call; if that call fails it
/// logs a warning and reranks each entry lexically instead.
pub fn rerank_batch_with_reflection_boost(
&self,
queries: Vec<(String, Vec<(Memory, f64)>)>,
boost_config: &ReflectionBoostConfig,
) -> Vec<Vec<(Memory, f64)>> {
// Single-entry fast path: no cross-query batching to gain.
if queries.len() == 1 {
let mut iter = queries.into_iter();
let (q, cands) = iter.next().expect("len == 1");
return vec![self.rerank_with_reflection_boost(&q, cands, boost_config)];
}
match self {
Self::Lexical { .. } => queries
.into_iter()
.map(|(q, cands)| self.rerank_with_reflection_boost(&q, cands, boost_config))
.collect(),
Self::Neural {
model,
tokenizer,
classifier_weight,
classifier_bias,
device,
} => {
// Split every entry into its scored pool and its overflow tail. `tails[i]`
// belongs to `queries[i]` after this step.
let mut tails: Vec<Vec<(Memory, f64)>> = Vec::with_capacity(queries.len());
let queries: Vec<(String, Vec<(Memory, f64)>)> = queries
.into_iter()
.map(|(q, cands)| {
let (head, tail) = split_rerank_pool(cands);
tails.push(tail);
(q, head)
})
.collect();
match Self::neural_rerank_batch(
model,
tokenizer,
classifier_weight,
classifier_bias,
device,
&queries,
) {
Ok(scores) => {
let mut out = Vec::with_capacity(queries.len());
// `scores` is flat: entry 0's candidates first, then entry 1's, and so on.
// `cursor` is the offset of the current entry's first score.
let mut cursor = 0usize;
for ((_query, cands), tail) in queries.into_iter().zip(tails) {
let n = cands.len();
let mut scored: Vec<(Memory, f64)> = cands
.into_iter()
.enumerate()
.map(|(i, (mem, original))| {
let ce = f64::from(scores[cursor + i]);
let blended =
ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce;
let factor = boost_config.factor_for(&mem);
(mem, finite_or_floor(blended * factor))
})
.collect();
cursor += n;
scored.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
// Overflow candidates keep their incoming scores and follow the pool.
scored.extend(tail);
out.push(scored);
}
out
}
Err(e) => {
tracing::warn!(
"neural rerank_batch failed: {e}, falling back to lexical per-query"
);
queries
.into_iter()
.zip(tails)
.map(|((q, cands), tail)| {
// `cands` is already the trimmed pool, so the lexical rerank will not
// split it again; the tail saved above is re-attached afterwards.
let lex = Self::Lexical { degraded: true };
let mut scored =
lex.rerank_with_reflection_boost(&q, cands, boost_config);
scored.extend(tail);
scored
})
.collect()
}
}
}
}
}
/// Scores every candidate of every entry in `queries` with one model call.
///
/// The returned scores are flat and follow input order: all candidates of
/// the first entry, then all candidates of the second, and so on.
///
/// # Errors
/// Propagates any error from `neural_score_pairs`.
fn neural_rerank_batch(
    model: &BertModel,
    tokenizer: &Tokenizer,
    classifier_weight: &Tensor,
    classifier_bias: &Tensor,
    device: &Device,
    queries: &[(String, Vec<(Memory, f64)>)],
) -> Result<Vec<f32>> {
    let pairs: Vec<(&str, String)> = queries
        .iter()
        .flat_map(|(q, cands)| {
            cands.iter().map(move |(mem, _)| {
                let document = crate::embeddings::embedding_document(&mem.title, &mem.content);
                (q.as_str(), document)
            })
        })
        .collect();

    Self::neural_score_pairs(
        model,
        tokenizer,
        classifier_weight,
        classifier_bias,
        device,
        pairs,
    )
}
/// Scores a batch of (query, document) pairs with the neural cross-encoder
/// and returns the sigmoid of each pair's classifier logit.
///
/// An empty `pairs` yields an empty result without touching the model.
///
/// # Errors
/// Returns an error if truncation cannot be configured, tokenization fails,
/// or any tensor operation fails.
fn neural_score_pairs(
model: &BertModel,
tokenizer: &Tokenizer,
classifier_weight: &Tensor,
classifier_bias: &Tensor,
device: &Device,
pairs: Vec<(&str, String)>,
) -> Result<Vec<f32>> {
if pairs.is_empty() {
return Ok(Vec::new());
}
// The shared tokenizer is stored with padding disabled, so batch padding
// and truncation are configured on a per-call clone.
let mut batch_tokenizer = tokenizer.clone();
let padding = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
direction: tokenizers::PaddingDirection::Right,
pad_id: 0,
pad_type_id: 0,
pad_token: "[PAD]".to_string(),
..Default::default()
};
batch_tokenizer.with_padding(Some(padding));
// Uses the runtime-configurable cap from `rerank_max_seq()` instead of the
// CROSS_ENCODER_MAX_SEQ value set when the model was loaded.
let truncation = tokenizers::TruncationParams {
max_length: rerank_max_seq(),
..Default::default()
};
batch_tokenizer
.with_truncation(Some(truncation))
.map_err(|e| anyhow::anyhow!("failed to set rerank truncation: {e}"))?;
let encodings = batch_tokenizer
.encode_batch(
pairs
.into_iter()
.map(|(q, d)| tokenizers::EncodeInput::Dual(q.into(), d.into()))
.collect::<Vec<_>>(),
true,
)
.map_err(|e| anyhow::anyhow!("cross-encoder batch tokenization failed: {e}"))?;
let batch_size = encodings.len();
// Takes the row length from the first encoding; this assumes BatchLongest
// padding made every encoding the same length — TODO confirm.
let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
// Flatten all rows into contiguous buffers for (batch_size, seq_len) tensors.
let mut input_ids: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
let mut attn_mask: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
let mut token_types: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
for enc in &encodings {
input_ids.extend_from_slice(enc.get_ids());
attn_mask.extend_from_slice(enc.get_attention_mask());
token_types.extend_from_slice(enc.get_type_ids());
}
let input_ids = Tensor::from_vec(input_ids, (batch_size, seq_len), device)?;
let attention_mask = Tensor::from_vec(attn_mask, (batch_size, seq_len), device)?;
let token_type_ids = Tensor::from_vec(token_types, (batch_size, seq_len), device)?;
let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
// Keep only position 0 along dimension 1 (the first token of each row).
let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
// Linear head: cls · Wᵀ + b, one logit per row.
let logits = cls
.matmul(&classifier_weight.t()?)?
.broadcast_add(classifier_bias)?;
let logits_vec: Vec<f32> = logits.squeeze(1)?.to_vec1()?;
// Logistic sigmoid maps each logit into (0, 1).
Ok(logits_vec
.into_iter()
.map(|l| 1.0 / (1.0 + (-l).exp()))
.collect())
}
}
impl Default for CrossEncoder {
/// Same as `CrossEncoder::new`: the lexical encoder, so no model is loaded.
fn default() -> Self {
Self::new()
}
}
/// Heuristic relevance score in `[0, 1]` of a document for a query.
///
/// The score is a weighted sum of four signals computed over the tokens of
/// `title` followed by `content`:
/// - 0.30 × Jaccard overlap of the query and document term sets,
/// - 0.30 × `tfidf_score`,
/// - 0.20 × fraction of query bigrams that also occur in the document,
/// - 0.20 × fraction of distinct query terms that occur in the title.
///
/// Set and bigram comparisons use the tokens exactly as written
/// (case-sensitive); only `tfidf_score` folds case. A query with no tokens
/// scores `0.0`.
fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
    let query_terms = tokenize(query);
    if query_terms.is_empty() {
        return 0.0;
    }
    let title_terms = tokenize(title);
    let content_terms = tokenize(content);

    // Title tokens followed by content tokens, in order.
    let doc_all: Vec<&str> = title_terms
        .iter()
        .chain(content_terms.iter())
        .copied()
        .collect();
    let doc_terms: HashSet<&str> = doc_all.iter().copied().collect();
    let query_set: HashSet<&str> = query_terms.iter().copied().collect();

    // Signal 1: Jaccard overlap of term sets.
    #[allow(clippy::cast_precision_loss)]
    let shared = query_set.intersection(&doc_terms).count() as f32;
    #[allow(clippy::cast_precision_loss)]
    let combined = query_set.union(&doc_terms).count() as f32;
    let jaccard = if combined > 0.0 { shared / combined } else { 0.0 };

    // Signal 2: normalised TF-IDF.
    let tf_idf = tfidf_score(&query_terms, &doc_all);

    // Signal 3: share of query bigrams present in the document.
    let query_bigrams = bigrams(&query_terms);
    let bigram_overlap = if query_bigrams.is_empty() {
        0.0
    } else {
        let doc_bigram_set: HashSet<(&str, &str)> = bigrams(&doc_all).into_iter().collect();
        #[allow(clippy::cast_precision_loss)]
        let hits = query_bigrams
            .iter()
            .filter(|pair| doc_bigram_set.contains(pair))
            .count() as f32;
        #[allow(clippy::cast_precision_loss)]
        let query_bigram_count = query_bigrams.len() as f32;
        hits / query_bigram_count
    };

    // Signal 4: share of distinct query terms found in the title.
    let title_set: HashSet<&str> = title_terms.iter().copied().collect();
    #[allow(clippy::cast_precision_loss)]
    let title_hits = query_set.intersection(&title_set).count() as f32;
    #[allow(clippy::cast_precision_loss)]
    let title_bonus = if query_set.is_empty() {
        0.0
    } else {
        title_hits / query_set.len() as f32
    };

    let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
    raw.clamp(0.0, 1.0)
}
/// Splits `text` into tokens made of alphanumeric characters and
/// apostrophes; every other character is a separator. Empty tokens are
/// dropped and case is left unchanged.
fn tokenize(text: &str) -> Vec<&str> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '\'');
    let mut words = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            words.push(piece);
        }
    }
    words
}
/// Normalised TF-IDF-style score in `[0, 1]` of `doc_tokens` for `query_terms`.
///
/// Matching is case-insensitive: a query term matches every document token
/// whose lowercase form equals the term's lowercase form. For each matching
/// query term the contribution is `tf / total * (ln(unique / (1 + variants)) + 1)`
/// where `tf` is the number of matching tokens, `total` the number of
/// document tokens, `unique` the number of distinct (case-sensitive) document
/// tokens, and `variants` the number of distinct tokens that matched. The sum
/// is divided by the number of query terms (duplicates included) and clamped.
///
/// Returns `0.0` when either input is empty.
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Case-sensitive term frequencies; their count defines `unique`.
    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }
    #[allow(clippy::cast_precision_loss)]
    let total = doc_tokens.len() as f32;
    #[allow(clippy::cast_precision_loss)]
    let unique = tf_map.len() as f32;

    // Fold case variants once: lowercase form -> (total occurrences, number of
    // distinct spellings). This replaces a full rescan (and re-lowercasing) of
    // every document key for every query term.
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (tok, &count) in &tf_map {
        let entry = folded.entry(tok.to_lowercase()).or_insert((0, 0));
        entry.0 += count;
        entry.1 += 1;
    }

    let mut score_sum: f32 = 0.0;
    for term in query_terms {
        // A term missing from the table has zero frequency and contributes nothing.
        let Some(&(occurrences, variants)) = folded.get(&term.to_lowercase()) else {
            continue;
        };
        #[allow(clippy::cast_precision_loss)]
        let tf = occurrences as f32;
        let tf_norm = tf / total;
        #[allow(clippy::cast_precision_loss)]
        let doc_freq = variants as f32;
        let idf = (unique / (1.0 + doc_freq)).ln() + 1.0;
        score_sum += tf_norm * idf;
    }

    #[allow(clippy::cast_precision_loss)]
    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}
/// Returns every pair of adjacent tokens, in order. Inputs with fewer than
/// two tokens produce an empty list.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    let mut pairs: Vec<(&'a str, &'a str)> = Vec::with_capacity(tokens.len().saturating_sub(1));
    for (left, right) in tokens.iter().zip(tokens.iter().skip(1)) {
        pairs.push((*left, *right));
    }
    pairs
}
/// Default maximum number of jobs the batcher worker collects into one batch.
pub const DEFAULT_MAX_BATCH: usize = 32;
/// Default time, in milliseconds, the worker keeps collecting further jobs
/// after the first job of a batch has arrived.
pub const DEFAULT_MAX_WAIT_MS: u64 = 5;
/// Minimum number of concurrent rerank calls at which requests are routed
/// through the batching worker.
pub const BATCHED_RERANK_MIN_CONCURRENCY: usize = 2;

/// Decides whether a rerank call should go through the batching worker.
///
/// That is the case only for a neural encoder with at least
/// [`BATCHED_RERANK_MIN_CONCURRENCY`] calls in flight (the current one included).
#[must_use]
pub const fn use_batched_rerank_path(encoder_is_neural: bool, inflight_now: usize) -> bool {
    if !encoder_is_neural {
        return false;
    }
    inflight_now >= BATCHED_RERANK_MIN_CONCURRENCY
}
/// One rerank request handed to the batcher worker thread.
struct RerankJob {
/// Query text the candidates are scored against.
query: String,
/// Candidates with their incoming scores.
candidates: Vec<(Memory, f64)>,
/// Channel on which the worker sends the reranked list back to the caller.
reply: std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>,
}
/// Reranker front-end that owns a background worker thread able to coalesce
/// concurrent requests into a single cross-encoder batch.
pub struct BatchedReranker {
/// Job queue into the worker; an `Option` so `Drop` can release it.
sender: Option<Sender<RerankJob>>,
/// Shutdown signal for the worker; an `Option` so `Drop` can take it.
shutdown: Option<std::sync::mpsc::Sender<()>>,
/// Handle of the worker thread, joined in `Drop`.
worker: Option<JoinHandle<()>>,
/// Encoder shared between the caller-side direct path and the worker.
encoder: Arc<CrossEncoder>,
/// Reflection boost applied on both the direct and the worker path.
reflection_boost: ReflectionBoostConfig,
/// Score floor applied to results by `rerank` and `rerank_coalesced`.
score_floor: RerankerScoreFloor,
/// Number of calls currently inside `rerank_unfloored`.
inflight: std::sync::atomic::AtomicUsize,
/// Number of jobs successfully handed to the worker so far.
worker_submissions: std::sync::atomic::AtomicUsize,
}
/// Policy for dropping low-scoring results after reranking. Whatever the
/// policy, the first result is always kept (see `apply`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RerankerScoreFloor {
/// Keep every result.
Off,
/// Drop results scoring below this value (clamped to `[0, 1]` when applied).
Absolute(f64),
/// Drop results scoring below `top_score * ratio`, with `ratio` clamped to
/// `[0, 1]` when applied.
RelativeToTop(f64),
}
impl Default for RerankerScoreFloor {
/// The default policy is `Off`: no results are dropped.
fn default() -> Self {
Self::Off
}
}
impl RerankerScoreFloor {
    /// Parses a floor policy from text.
    ///
    /// Accepted forms (surrounding whitespace ignored, keywords
    /// case-insensitive):
    /// - `off`
    /// - `absolute:<value>` or `abs:<value>`
    /// - `relative:<value>`, `rel:<value>` or `relative_to_top:<value>`
    ///
    /// `<value>` must parse as a finite `f64`. Anything else yields `None`.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        let s = s.trim();
        if s.eq_ignore_ascii_case("off") {
            return Some(Self::Off);
        }
        let (kind, value) = s.split_once(':')?;
        let v: f64 = value.trim().parse().ok()?;
        if !v.is_finite() {
            return None;
        }
        match kind.trim().to_ascii_lowercase().as_str() {
            "absolute" | "abs" => Some(Self::Absolute(v)),
            "relative" | "rel" | "relative_to_top" => Some(Self::RelativeToTop(v)),
            _ => None,
        }
    }

    /// Removes from `scored` every entry whose score is below the cutoff,
    /// preserving the order of the survivors.
    ///
    /// The first entry is always kept, so a non-empty list never becomes
    /// empty. For `RelativeToTop` the cutoff is derived from the first
    /// entry's score, i.e. the list is expected to be sorted best-first.
    /// Entries whose score does not compare `>=` to the cutoff (NaN included)
    /// are dropped.
    fn apply(&self, scored: &mut Vec<(Memory, f64)>) {
        if scored.is_empty() {
            return;
        }
        let cutoff: f64 = match *self {
            Self::Off => return,
            Self::Absolute(v) => v.clamp(0.0, 1.0),
            Self::RelativeToTop(ratio) => {
                let top = scored.first().map(|(_, s)| *s).unwrap_or(0.0);
                top * ratio.clamp(0.0, 1.0)
            }
        };

        // `retain` visits each element exactly once, in order, so a running
        // counter identifies the first entry. This is a single O(n) pass
        // instead of repeated `Vec::remove` calls.
        let mut position = 0usize;
        scored.retain(|(_, score)| {
            let keep = position == 0 || *score >= cutoff;
            position += 1;
            keep
        });
    }
}
impl BatchedReranker {
/// Creates a reranker with `DEFAULT_MAX_BATCH` and `DEFAULT_MAX_WAIT_MS`
/// (see `with_params` for the remaining defaults).
pub fn new(encoder: CrossEncoder) -> Self {
Self::with_params(encoder, DEFAULT_MAX_BATCH, DEFAULT_MAX_WAIT_MS)
}
/// Creates a reranker with explicit batching limits, the default reflection
/// boost and no score floor.
pub fn with_params(encoder: CrossEncoder, max_batch: usize, max_wait_ms: u64) -> Self {
Self::with_full_params(
encoder,
max_batch,
max_wait_ms,
ReflectionBoostConfig::default(),
RerankerScoreFloor::Off,
)
}
/// Creates a reranker with a custom reflection boost, default batching
/// limits and no score floor.
pub fn with_reflection_boost(encoder: CrossEncoder, boost: ReflectionBoostConfig) -> Self {
Self::with_full_params(
encoder,
DEFAULT_MAX_BATCH,
DEFAULT_MAX_WAIT_MS,
boost,
RerankerScoreFloor::Off,
)
}
/// Creates a reranker with a custom score floor, default batching limits and
/// the default reflection boost.
#[must_use]
pub fn with_score_floor(encoder: CrossEncoder, floor: RerankerScoreFloor) -> Self {
Self::with_full_params(
encoder,
DEFAULT_MAX_BATCH,
DEFAULT_MAX_WAIT_MS,
ReflectionBoostConfig::default(),
floor,
)
}
/// Builds the reranker and spawns its batching worker thread.
///
/// The worker blocks until a first job arrives, then keeps collecting jobs
/// until it holds `max_batch` of them or `max_wait_ms` has elapsed, and
/// finally processes the batch with `process_batch`.
///
/// # Panics
/// Panics if the worker thread cannot be spawned.
fn with_full_params(
encoder: CrossEncoder,
max_batch: usize,
max_wait_ms: u64,
reflection_boost: ReflectionBoostConfig,
score_floor: RerankerScoreFloor,
) -> Self {
let encoder = Arc::new(encoder);
let (tx, rx) = std::sync::mpsc::channel::<RerankJob>();
let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();
let worker_encoder = Arc::clone(&encoder);
let worker_boost = reflection_boost;
let max_wait = Duration::from_millis(max_wait_ms);
let worker = thread::Builder::new()
.name("ai-memory-reranker-batcher".into())
.spawn(move || {
// Upper bound on how long the idle worker waits for a job before it
// re-checks the shutdown channel.
const SHUTDOWN_POLL: Duration = Duration::from_millis(100);
'outer: loop {
// Phase 1: wait for the first job of the next batch, checking for a
// shutdown signal (or a dropped shutdown sender) before each wait.
let first = loop {
match shutdown_rx.try_recv() {
Ok(()) | Err(std::sync::mpsc::TryRecvError::Disconnected) => {
break 'outer;
}
Err(std::sync::mpsc::TryRecvError::Empty) => {}
}
match rx.recv_timeout(SHUTDOWN_POLL) {
Ok(job) => break job,
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
break 'outer;
}
}
};
// Phase 2: collect more jobs until the batch is full or the deadline passes.
let mut batch: Vec<RerankJob> = Vec::with_capacity(max_batch);
batch.push(first);
let deadline = Instant::now() + max_wait;
while batch.len() < max_batch {
let now = Instant::now();
if now >= deadline {
break;
}
match rx.recv_timeout(deadline - now) {
Ok(j) => batch.push(j),
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => break,
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
// All senders are gone: answer what was collected, then stop.
process_batch(&worker_encoder, batch, &worker_boost);
break 'outer;
}
}
}
process_batch(&worker_encoder, batch, &worker_boost);
}
})
.expect("failed to spawn rerank batcher worker");
Self {
sender: Some(tx),
shutdown: Some(shutdown_tx),
worker: Some(worker),
encoder,
reflection_boost,
score_floor,
inflight: std::sync::atomic::AtomicUsize::new(0),
worker_submissions: std::sync::atomic::AtomicUsize::new(0),
}
}
/// Reranks `candidates` for `query`, choosing between the direct and the
/// batched path based on current concurrency, then applies the score floor.
pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
let mut scored = self.rerank_unfloored(query, candidates);
self.score_floor.apply(&mut scored);
scored
}
/// Reranks `candidates` for `query` through the batching worker regardless
/// of current concurrency, then applies the score floor.
#[must_use]
pub fn rerank_coalesced(
&self,
query: &str,
candidates: Vec<(Memory, f64)>,
) -> Vec<(Memory, f64)> {
let mut scored = self.rerank_coalesced_unfloored(query, candidates);
self.score_floor.apply(&mut scored);
scored
}
/// Reranks without applying the score floor, routing the call through the
/// batching worker only when `use_batched_rerank_path` says so.
fn rerank_unfloored(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
use std::sync::atomic::Ordering;
// Decrements the in-flight counter when dropped, so the count is restored
// on every exit path of this function.
struct InflightGuard<'a>(&'a std::sync::atomic::AtomicUsize);
impl Drop for InflightGuard<'_> {
fn drop(&mut self) {
self.0.fetch_sub(1, Ordering::Relaxed);
}
}
// `fetch_add` returns the previous value; +1 makes the count include this call.
let inflight_now = self.inflight.fetch_add(1, Ordering::Relaxed) + 1;
let _guard = InflightGuard(&self.inflight);
if use_batched_rerank_path(self.encoder.is_neural(), inflight_now) {
self.rerank_coalesced_unfloored(query, candidates)
} else {
self.rerank_direct_unfloored(query, candidates)
}
}
/// Reranks on the calling thread with this reranker's reflection boost,
/// bypassing the worker. No score floor is applied.
fn rerank_direct_unfloored(
&self,
query: &str,
candidates: Vec<(Memory, f64)>,
) -> Vec<(Memory, f64)> {
self.encoder
.rerank_with_reflection_boost(query, candidates, &self.reflection_boost)
}
fn rerank_coalesced_unfloored(
&self,
query: &str,
candidates: Vec<(Memory, f64)>,
) -> Vec<(Memory, f64)> {
let Some(sender) = self.sender.as_ref() else {
return self.rerank_direct_unfloored(query, candidates);
};
let (reply_tx, reply_rx) = sync_channel::<Vec<(Memory, f64)>>(1);
let job = RerankJob {
query: query.to_string(),
candidates,
reply: reply_tx,
};
if sender.send(job).is_err() {
return self.encoder.rerank_with_reflection_boost(
query,
Vec::new(),
&self.reflection_boost,
);
}
self.worker_submissions
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
reply_rx.recv().unwrap_or_else(|_| {
self.encoder
.rerank_with_reflection_boost(query, Vec::new(), &self.reflection_boost)
})
}
/// Number of jobs that have been successfully handed to the batching worker.
#[must_use]
pub fn worker_submissions(&self) -> usize {
self.worker_submissions
.load(std::sync::atomic::Ordering::Relaxed)
}
/// The score-floor policy this reranker applies to its results.
#[must_use]
pub fn score_floor(&self) -> RerankerScoreFloor {
self.score_floor
}
/// The reflection-boost configuration this reranker was built with.
#[must_use]
pub fn reflection_boost(&self) -> &ReflectionBoostConfig {
&self.reflection_boost
}
/// Borrows the underlying cross-encoder.
pub fn encoder(&self) -> &CrossEncoder {
&self.encoder
}
/// Returns `true` when the underlying encoder is the neural variant.
pub fn is_neural(&self) -> bool {
self.encoder.is_neural()
}
/// Returns `true` when the underlying encoder is a lexical fallback created
/// after a neural failure.
#[must_use]
pub fn is_degraded_lexical(&self) -> bool {
self.encoder.is_degraded_lexical()
}
}
impl Drop for BatchedReranker {
/// Stops the worker thread and waits for it to exit.
fn drop(&mut self) {
// Signal shutdown first; a send error only means the worker is already gone.
if let Some(shutdown) = self.shutdown.take() {
let _ = shutdown.send(());
}
// Dropping the job sender disconnects the queue, which also ends the
// worker's receive loops.
self.sender.take();
// Join last, once both wake-up conditions are in place. A worker panic is
// ignored here.
if let Some(handle) = self.worker.take() {
let _ = handle.join();
}
}
}
/// Reranks every job in `batch` and sends each result to its requester.
///
/// A single job is reranked on its own; several jobs are scored together via
/// `rerank_batch_with_reflection_boost`. A requester that has stopped
/// listening is ignored.
fn process_batch(
    encoder: &CrossEncoder,
    batch: Vec<RerankJob>,
    boost_config: &ReflectionBoostConfig,
) {
    if batch.is_empty() {
        return;
    }

    if batch.len() == 1 {
        if let Some(job) = batch.into_iter().next() {
            let result =
                encoder.rerank_with_reflection_boost(&job.query, job.candidates, boost_config);
            let _ = job.reply.send(result);
        }
        return;
    }

    // Separate each job into its scoring input and its reply channel; both
    // lists stay in submission order so results can be paired back by position.
    let (queries, replies): (
        Vec<(String, Vec<(Memory, f64)>)>,
        Vec<std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>>,
    ) = batch
        .into_iter()
        .map(|job| ((job.query, job.candidates), job.reply))
        .unzip();

    let outputs = encoder.rerank_batch_with_reflection_boost(queries, boost_config);
    for (out, reply) in outputs.into_iter().zip(replies) {
        let _ = reply.send(out);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::{Memory, Tier};
#[test]
fn rerank_max_seq_1604_seed_once_semantics() {
set_rerank_max_seq(192);
let settled = rerank_max_seq();
assert!(
settled > 0,
"settled value must be a real cap (ours or an earlier boot seed), got {settled}"
);
set_rerank_max_seq(64);
assert_eq!(
rerank_max_seq(),
settled,
"first writer must win — a later set_rerank_max_seq call must be a no-op"
);
}
fn make_memory(title: &str, content: &str) -> Memory {
Memory {
id: "test-id".to_string(),
tier: Tier::Mid,
namespace: "test".to_string(),
title: title.to_string(),
content: content.to_string(),
tags: vec![],
priority: 5,
confidence: 1.0,
source: "test".to_string(),
access_count: 0,
created_at: "2026-01-01T00:00:00Z".to_string(),
updated_at: "2026-01-01T00:00:00Z".to_string(),
last_accessed_at: None,
expires_at: None,
metadata: serde_json::json!({}),
reflection_depth: 0,
memory_kind: crate::models::MemoryKind::Observation,
entity_id: None,
persona_version: None,
citations: Vec::new(),
source_uri: None,
source_span: None,
confidence_source: crate::models::ConfidenceSource::CallerProvided,
confidence_signals: None,
confidence_decayed_at: None,
version: 1,
}
}
#[test]
fn nan_scored_candidate_sinks_to_bottom_m13() {
let ce = CrossEncoder::Lexical { degraded: false };
let poisoned = make_memory("poisoned", "irrelevant body");
let good = make_memory("network configuration", "network configuration body");
let out = ce.rerank(
"network configuration",
vec![(poisoned, f64::NAN), (good, 0.9)],
);
assert_eq!(
out[0].0.title, "network configuration",
"finite-scored candidate must outrank the NaN-poisoned one"
);
assert_eq!(out[1].0.title, "poisoned");
assert_eq!(
out[1].1,
f64::MIN,
"non-finite blended score must clamp to the ranking floor"
);
let poisoned = make_memory("poisoned", "irrelevant body");
let good = make_memory("network configuration", "network configuration body");
let out = ce.rerank_with_reflection_boost(
"network configuration",
vec![(poisoned, f64::NAN), (good, 0.9)],
&ReflectionBoostConfig::disabled(),
);
assert_eq!(out[0].0.title, "network configuration");
assert_eq!(out[1].1, f64::MIN);
}
#[test]
fn lexical_score_returns_zero_for_empty_query() {
assert_eq!(lexical_score("", "some title", "some content"), 0.0);
}
#[test]
fn lexical_score_returns_zero_for_no_overlap() {
let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
assert!(s < 0.05, "expected near-zero, got {s}");
}
#[test]
fn lexical_score_rewards_title_match() {
let content = "This document discusses network configuration for LAN setups.";
let s_title_match = lexical_score(
"network configuration",
"Network Configuration Guide",
content,
);
let s_no_title = lexical_score("network configuration", "Unrelated Title", content);
assert!(
s_title_match > s_no_title,
"title match ({s_title_match}) should beat no title match ({s_no_title})"
);
}
#[test]
fn lexical_score_is_bounded_zero_one() {
let s = lexical_score(
"the quick brown fox jumps over the lazy dog",
"the quick brown fox",
"the quick brown fox jumps over the lazy dog and more words",
);
assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
}
#[test]
fn rerank_reorders_candidates() {
let ce = CrossEncoder::new();
let a = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
let b = make_memory("Grocery list", "milk eggs bread butter cheese");
let candidates = vec![(b.clone(), 0.55), (a.clone(), 0.45)];
let reranked = ce.rerank("cross-encoder reranking", candidates);
assert_eq!(reranked[0].0.title, "Rust cross-encoder");
}
#[test]
fn rerank_preserves_candidate_count() {
let ce = CrossEncoder::new();
let candidates = vec![
(make_memory("A", "alpha"), 0.5),
(make_memory("B", "beta"), 0.6),
(make_memory("C", "gamma"), 0.7),
];
let reranked = ce.rerank("alpha", candidates);
assert_eq!(reranked.len(), 3);
}
#[test]
fn bigram_overlap_boosts_phrase_match() {
let s_phrase = lexical_score(
"network adapter",
"title",
"the network adapter is connected to the LAN",
);
let s_scattered = lexical_score(
"network adapter",
"title",
"the adapter handles the network traffic independently",
);
assert!(
s_phrase > s_scattered,
"phrase match ({s_phrase}) should beat scattered ({s_scattered})"
);
}
#[test]
fn test_rerank_preserves_input_count_heuristic() {
let ce = CrossEncoder::new();
let candidates: Vec<(Memory, f64)> = (0..5)
.map(|i| {
(
make_memory(
&format!("title {i}"),
&format!("content body number {i} with some words"),
),
f64::from(i) * 0.1,
)
})
.collect();
let query = "title content body";
let reranked = ce.rerank(query, candidates);
assert_eq!(
reranked.len(),
5,
"heuristic rerank must preserve candidate count, got {} = {:?}",
reranked.len(),
reranked
.iter()
.map(|(m, s)| (&m.title, *s))
.collect::<Vec<_>>()
);
for w in reranked.windows(2) {
assert!(
w[0].1 >= w[1].1,
"rerank output must be descending by score: {} < {}",
w[0].1,
w[1].1
);
}
}
#[test]
fn test_rerank_zero_candidates_returns_empty_heuristic() {
let ce = CrossEncoder::new();
let reranked = ce.rerank("query", Vec::new());
assert!(reranked.is_empty());
}
#[cfg(feature = "test-with-models")]
#[test]
fn test_rerank_preserves_input_count_neural_if_available() {
let ce = CrossEncoder::new_neural();
let candidates: Vec<(Memory, f64)> = (0..5)
.map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
.collect();
let reranked = ce.rerank("body", candidates);
assert_eq!(reranked.len(), 5);
}
/// `CrossEncoder::default()` must report a non-neural encoder.
#[test]
fn w12e_default_is_lexical() {
    let is_neural = CrossEncoder::default().is_neural();
    assert!(!is_neural, "Default::default() must return Lexical");
}
/// `CrossEncoder::new()` must report a non-neural encoder.
#[test]
fn w12e_new_returns_lexical() {
    let encoder = CrossEncoder::new();
    let is_neural = encoder.is_neural();
    assert!(!is_neural);
}
/// Scoring through `CrossEncoder::new().score` must agree with calling
/// `lexical_score` directly on the same inputs.
#[test]
fn w12e_score_dispatch_lexical_matches_helper() {
    let (query, title, content) = (
        "rust async runtime",
        "Tokio: Rust async runtime",
        "Tokio is an async runtime for the Rust programming language.",
    );
    let expected = lexical_score(query, title, content);
    let actual = CrossEncoder::new().score(query, title, content);
    assert!((actual - expected).abs() < f32::EPSILON);
}
/// Degenerate inputs: empty, whitespace-only and punctuation-only queries
/// score exactly zero, and an empty document still yields a score in [0, 1].
#[test]
fn w12e_score_empty_inputs_safe() {
    let ce = CrossEncoder::new();
    for degenerate_query in ["", " \t\n", "!?.,;:"] {
        assert_eq!(ce.score(degenerate_query, "title", "content"), 0.0);
    }
    let empty_doc_score = ce.score("query", "", "");
    assert!((0.0..=1.0).contains(&empty_doc_score));
}
/// `lexical_score` stays within [0, 1] for accented text and for a very
/// long document (a four-word phrase repeated 2,500 times).
#[test]
fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
    let unit_interval = 0.0..=1.0;

    let s_unicode = lexical_score(
        "café résumé d'oeuvre",
        "Le Café d'Oeuvre",
        "résumé du café avec d'oeuvre noté",
    );
    assert!(
        unit_interval.contains(&s_unicode),
        "unicode score {s_unicode} out of bounds"
    );

    let huge = "alpha beta gamma delta ".repeat(2_500);
    let s_long = lexical_score("alpha gamma", "headline", &huge);
    assert!(
        unit_interval.contains(&s_long),
        "long score {s_long} out of bounds"
    );
}
/// When query, title and content all consist of the same words, the score
/// must exceed 0.5 without going above 1.0.
#[test]
fn w12e_lexical_score_perfect_overlap_high() {
    let phrase = "alpha beta gamma";
    let s = lexical_score(phrase, phrase, "alpha beta gamma alpha beta gamma");
    assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
    assert!(s <= 1.0);
}
/// `tfidf_score` against an empty document is exactly zero.
#[test]
fn w12e_tfidf_score_empty_doc_returns_zero() {
    let query_terms = vec!["alpha", "beta"];
    let empty_doc = Vec::<&str>::new();
    let score = tfidf_score(&query_terms, &empty_doc);
    assert_eq!(score, 0.0);
}
/// `tfidf_score` with no query terms is exactly zero.
#[test]
fn w12e_tfidf_score_empty_query_returns_zero() {
    let no_terms = Vec::<&str>::new();
    let doc_terms = vec!["alpha", "beta", "gamma"];
    let score = tfidf_score(&no_terms, &doc_terms);
    assert_eq!(score, 0.0);
}
/// `tfidf_score` is exactly zero when no query term occurs in the document.
#[test]
fn w12e_tfidf_score_no_matching_terms() {
    let query_terms = vec!["xenon", "kryptonite"];
    let doc_terms = vec!["alpha", "beta", "gamma"];
    assert_eq!(tfidf_score(&query_terms, &doc_terms), 0.0);
}
/// With one matching and one missing query term, `tfidf_score` is positive
/// and stays within [0, 1].
#[test]
fn w12e_tfidf_score_partial_match_bounded() {
    let query_terms = vec!["alpha", "missing"];
    let doc_terms = vec!["alpha", "alpha", "beta", "gamma"];
    let score = tfidf_score(&query_terms, &doc_terms);
    let unit_interval = 0.0..=1.0;
    assert!(unit_interval.contains(&score));
    assert!(score > 0.0);
}
/// `bigrams` produces nothing for zero or one token and adjacent pairs, in
/// order, for three tokens.
#[test]
fn w12e_bigrams_empty_and_single_and_multi() {
    let no_tokens: Vec<&str> = Vec::new();
    assert!(bigrams(&no_tokens).is_empty());

    let single_token = vec!["solo"];
    assert!(bigrams(&single_token).is_empty());

    let three_tokens = vec!["a", "b", "c"];
    let pairs = bigrams(&three_tokens);
    let expected_pairs = vec![("a", "b"), ("b", "c")];
    assert_eq!(pairs, expected_pairs);
}
/// `tokenize` keeps apostrophised words whole, drops pure punctuation,
/// returns nothing for an empty string, and splits two accented words into
/// two tokens.
#[test]
fn w12e_tokenize_handles_apostrophe_and_unicode() {
    let sentence_tokens = tokenize("don't stop, I won't!");
    for expected in ["don't", "won't", "stop", "I"] {
        assert!(sentence_tokens.contains(&expected));
    }

    assert!(tokenize("!!!,,,;;;").is_empty());
    assert!(tokenize("").is_empty());

    let accented_tokens = tokenize("café résumé");
    assert_eq!(accented_tokens.len(), 2);
}
/// A one-element pool comes back as that same element with a non-negative
/// score.
#[test]
fn w12e_rerank_single_candidate_keeps_it() {
    let encoder = CrossEncoder::new();
    let pool = vec![(make_memory("solo title", "solo content body"), 0.42)];
    let reranked = encoder.rerank("solo", pool);
    assert_eq!(reranked.len(), 1);
    let (memory, score) = &reranked[0];
    assert_eq!(memory.title, "solo title");
    assert!(*score >= 0.0);
}
/// When both candidates start with the same original score, the one whose
/// text matches the query must come first.
#[test]
fn w12e_rerank_identical_originals_stable_under_score() {
    let encoder = CrossEncoder::new();
    let relevant = make_memory("rust async runtime", "rust async runtime tokio");
    let irrelevant = make_memory("grocery", "milk eggs bread");
    // The irrelevant candidate is deliberately listed first.
    let pool = vec![(irrelevant, 0.5), (relevant, 0.5)];
    let reranked = encoder.rerank("rust async", pool);
    assert_eq!(reranked.len(), 2);
    assert_eq!(reranked[0].0.title, "rust async runtime");
}
/// The descending-score invariant must hold for a mixed pool that includes
/// an empty-content and an empty-title candidate.
#[test]
fn w12e_rerank_descending_invariant_holds_across_shapes() {
    let encoder = CrossEncoder::new();
    let pool: Vec<(Memory, f64)> = vec![
        ("a", "alpha words", 0.10),
        ("b", "beta words", 0.95),
        ("c", "gamma alpha", 0.55),
        ("d", "", 0.0),
        ("", "empty title doc", 0.30),
    ]
    .into_iter()
    .map(|(title, content, score)| (make_memory(title, content), score))
    .collect();
    let reranked = encoder.rerank("alpha", pool);
    assert_eq!(reranked.len(), 5);
    for (higher, lower) in reranked.iter().zip(reranked.iter().skip(1)) {
        assert!(
            higher.1 >= lower.1,
            "non-descending pair: {} then {}",
            higher.1,
            lower.1
        );
    }
}
/// An empty title must not score higher than a matching title, and the
/// empty-title score must stay within [0, 1].
#[test]
fn w12e_lexical_score_no_title_branch_via_empty_title() {
    let (query, content) = ("alpha beta", "alpha beta gamma");
    let without_title = lexical_score(query, "", content);
    let with_title = lexical_score(query, "alpha beta", content);
    assert!(with_title >= without_title);
    assert!((0.0..=1.0).contains(&without_title));
}
/// Query words that occur only in the title (differently cased) still
/// produce a positive score no greater than 1.0.
#[test]
fn w12e_lexical_score_query_terms_only_in_title() {
    let score = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
    assert!(score > 0.0);
    assert!(score <= 1.0);
}
/// Whatever encoder `CrossEncoder::new_neural` returns, its score must lie
/// in [0, 1].
// NOTE(review): the name suggests this covers both a successful model load
// and a fallback — confirm against `new_neural`.
#[test]
fn pr9i_new_neural_dual_outcome() {
    let s = CrossEncoder::new_neural().score("query", "title", "content");
    let unit_interval = 0.0..=1.0;
    assert!(unit_interval.contains(&s), "score {s} out of bounds");
}
/// `rerank_batch` over several (query, candidates) jobs must return, job by
/// job, the same ids, titles and scores (within 1e-12) as calling `rerank`
/// once per query.
#[test]
fn g9_rerank_batch_matches_per_query_rerank_lexical() {
    let ce = CrossEncoder::new();

    // Build one five-candidate job per query.
    let jobs: Vec<(String, Vec<(Memory, f64)>)> = ["alpha gamma", "beta words", "rust async"]
        .iter()
        .map(|q| {
            let cands: Vec<(Memory, f64)> = (0..5)
                .map(|i| {
                    let title = format!("title-{i}-{q}");
                    let body = format!("alpha beta gamma rust async body {i} {q}");
                    (make_memory(&title, &body), f64::from(i) * 0.1)
                })
                .collect();
            ((*q).to_string(), cands)
        })
        .collect();

    // Reference results: each job reranked on its own, from a clone.
    let expected: Vec<Vec<(Memory, f64)>> = jobs
        .iter()
        .map(|(query, cands)| ce.rerank(query, cands.clone()))
        .collect();

    let batched = ce.rerank_batch(jobs);
    assert_eq!(batched.len(), expected.len());
    for (batch_out, direct_out) in batched.iter().zip(&expected) {
        assert_eq!(batch_out.len(), direct_out.len());
        for (from_batch, from_direct) in batch_out.iter().zip(direct_out) {
            assert_eq!(from_batch.0.id, from_direct.0.id);
            assert_eq!(from_batch.0.title, from_direct.0.title);
            assert!(
                (from_batch.1 - from_direct.1).abs() < 1e-12,
                "blended score mismatch: batched={} per-query={}",
                from_batch.1,
                from_direct.1
            );
        }
    }
}
/// A one-job batch must give the same ids and scores as a direct `rerank`
/// call on the same candidates.
#[test]
fn g9_rerank_batch_single_query_short_circuits() {
    let encoder = CrossEncoder::new();
    let mut pool: Vec<(Memory, f64)> = Vec::with_capacity(5);
    for i in 0..5 {
        pool.push((make_memory(&format!("t{i}"), &format!("body {i}")), 0.5));
    }
    let direct = encoder.rerank("body", pool.clone());
    let batched = encoder.rerank_batch(vec![("body".to_string(), pool)]);
    assert_eq!(batched.len(), 1);
    let only_job = &batched[0];
    assert_eq!(only_job.len(), direct.len());
    for (from_batch, from_direct) in only_job.iter().zip(&direct) {
        assert_eq!(from_batch.0.id, from_direct.0.id);
        assert!((from_batch.1 - from_direct.1).abs() < 1e-12);
    }
}
/// `rerank_batch` returns nothing for no jobs, and one empty result per job
/// when every job has an empty candidate list.
#[test]
fn g9_rerank_batch_empty_inputs() {
    let encoder = CrossEncoder::new();

    let no_jobs = encoder.rerank_batch(Vec::new());
    assert!(no_jobs.is_empty());

    let empty_jobs = vec![
        ("q1".to_string(), Vec::new()),
        ("q2".to_string(), Vec::new()),
    ];
    let per_job = encoder.rerank_batch(empty_jobs);
    assert_eq!(per_job.len(), 2);
    assert!(per_job.iter().all(std::vec::Vec::is_empty));
}
/// A single call through `BatchedReranker::rerank` must agree (ids, scores
/// within 1e-12) with a plain `CrossEncoder::rerank` on the same pool.
#[test]
fn g9_batched_reranker_serial_calls_match_rerank() {
    use super::BatchedReranker;
    let wrapper = BatchedReranker::new(CrossEncoder::new());
    let pool: Vec<(Memory, f64)> = (0..4)
        .map(|i| {
            let title = format!("t{i}");
            let body = format!("alpha gamma body {i} content words");
            (make_memory(&title, &body), f64::from(i) * 0.1)
        })
        .collect();
    let direct = CrossEncoder::new().rerank("alpha", pool.clone());
    let wrapped = wrapper.rerank("alpha", pool);
    assert_eq!(wrapped.len(), direct.len());
    for (from_wrapper, from_direct) in wrapped.iter().zip(&direct) {
        assert_eq!(from_wrapper.0.id, from_direct.0.id);
        assert!((from_wrapper.1 - from_direct.1).abs() < 1e-12);
    }
}
/// Eight threads sharing one `BatchedReranker` must each get back all five
/// of their candidates in non-increasing score order.
#[test]
fn g9_batched_reranker_concurrent_calls_all_succeed() {
    use super::BatchedReranker;
    use std::sync::Arc;
    let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
    // `collect` is eager, so every thread is spawned before any is joined.
    let workers: Vec<_> = (0..8)
        .map(|i| {
            let shared = Arc::clone(&batched);
            std::thread::spawn(move || {
                let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
                for j in 0..5 {
                    let title = format!("t{i}-{j}");
                    let body = format!("body {j} alpha gamma rust");
                    cands.push((make_memory(&title, &body), 0.5));
                }
                let q = format!("alpha {i}");
                let out = shared.rerank(&q, cands);
                assert_eq!(out.len(), 5);
                for (higher, lower) in out.iter().zip(out.iter().skip(1)) {
                    assert!(higher.1 >= lower.1);
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().expect("worker thread panicked");
    }
}
/// Truth table for `use_batched_rerank_path`: false whenever the first flag
/// is false, false for `(true, 1)`, and true for
/// `(true, BATCHED_RERANK_MIN_CONCURRENCY)` and `(true, 8)`.
#[test]
fn issue_1579_b10_auto_select_predicate() {
    use super::{BATCHED_RERANK_MIN_CONCURRENCY, use_batched_rerank_path};
    for level in [1, 8, 1024] {
        assert!(!use_batched_rerank_path(false, level));
    }
    assert!(!use_batched_rerank_path(true, 1));
    for level in [BATCHED_RERANK_MIN_CONCURRENCY, 8] {
        assert!(use_batched_rerank_path(true, level));
    }
}
/// Eight concurrent `rerank` calls on a `BatchedReranker` wrapping
/// `CrossEncoder::new()` must leave `worker_submissions()` at zero.
#[test]
fn issue_1579_b10_lexical_rerank_never_reaches_worker() {
    use super::BatchedReranker;
    use std::sync::Arc;
    let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
    let workers: Vec<_> = (0..8)
        .map(|i| {
            let shared = Arc::clone(&batched);
            std::thread::spawn(move || {
                let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
                for j in 0..5 {
                    let title = format!("b10-{i}-{j}");
                    let body = format!("body {j} alpha gamma");
                    cands.push((make_memory(&title, &body), 0.5));
                }
                let out = shared.rerank(&format!("alpha {i}"), cands);
                assert_eq!(out.len(), 5);
            })
        })
        .collect();
    for worker in workers {
        worker.join().expect("worker thread panicked");
    }
    // The counter is read only after every caller thread has finished.
    let submissions = batched.worker_submissions();
    assert_eq!(
        submissions,
        0,
        "lexical rerank must auto-select the direct path (no worker jobs)"
    );
}
/// `rerank_coalesced` must register exactly one worker submission and give
/// the same ids and scores (within 1e-12) as the preceding `rerank` call.
#[test]
fn issue_1579_b10_forced_coalesced_path_matches_direct() {
    use super::BatchedReranker;
    let wrapper = BatchedReranker::new(CrossEncoder::new());
    let pool: Vec<(Memory, f64)> = (0..4)
        .map(|i| {
            let title = format!("b10-forced-{i}");
            let body = format!("alpha gamma body {i} content words");
            (make_memory(&title, &body), f64::from(i) * 0.1)
        })
        .collect();
    let direct = wrapper.rerank("alpha", pool.clone());
    let coalesced = wrapper.rerank_coalesced("alpha", pool);
    assert_eq!(
        wrapper.worker_submissions(),
        1,
        "rerank_coalesced must route through the worker"
    );
    assert_eq!(coalesced.len(), direct.len());
    for (via_worker, via_direct) in coalesced.iter().zip(&direct) {
        assert_eq!(via_worker.0.id, via_direct.0.id);
        assert!((via_worker.1 - via_direct.1).abs() < 1e-12);
    }
}
/// Reranking two hand-built memories through the encoder returned by
/// `CrossEncoder::new_neural` must keep both, give finite scores, and order
/// them by non-increasing score.
#[test]
fn pr9i_rerank_via_score_returns_blend() {
    // Both fixtures differ only in id, title and content; every other field
    // carries the same fixed value.
    let build = |id: &str, title: &str, content: &str| Memory {
        id: id.to_string(),
        tier: Tier::Mid,
        namespace: "ns".to_string(),
        title: title.to_string(),
        content: content.to_string(),
        tags: vec![],
        priority: 5,
        confidence: 1.0,
        source: "test".to_string(),
        access_count: 0,
        created_at: "2026-01-01T00:00:00Z".to_string(),
        updated_at: "2026-01-01T00:00:00Z".to_string(),
        last_accessed_at: None,
        expires_at: None,
        metadata: serde_json::json!({}),
        reflection_depth: 0,
        memory_kind: crate::models::MemoryKind::Observation,
        entity_id: None,
        persona_version: None,
        citations: Vec::new(),
        source_uri: None,
        source_span: None,
        confidence_source: crate::models::ConfidenceSource::CallerProvided,
        confidence_signals: None,
        confidence_decayed_at: None,
        version: 1,
    };

    let encoder = CrossEncoder::new_neural();
    let pool = vec![
        (build("a", "rust async runtime", "tokio rust async"), 0.6),
        (build("b", "grocery list", "milk eggs"), 0.4),
    ];
    let reranked = encoder.rerank("rust async", pool);
    assert_eq!(reranked.len(), 2);
    for (_, score) in &reranked {
        assert!(score.is_finite());
    }
    assert!(reranked[0].1 >= reranked[1].1);
}
/// Table-driven check of `RerankerScoreFloor::parse`: accepted spellings map
/// to the expected variant, malformed or non-finite inputs map to `None`.
#[test]
fn issue_1691_n14_score_floor_parse_grammar() {
    let accepted = [
        ("off", RerankerScoreFloor::Off),
        ("  OFF ", RerankerScoreFloor::Off),
        ("absolute:0.3", RerankerScoreFloor::Absolute(0.3)),
        ("ABS: 0.25", RerankerScoreFloor::Absolute(0.25)),
        ("relative:0.5", RerankerScoreFloor::RelativeToTop(0.5)),
        ("relative_to_top:0.8", RerankerScoreFloor::RelativeToTop(0.8)),
    ];
    for (input, expected) in accepted {
        assert_eq!(RerankerScoreFloor::parse(input), Some(expected));
    }

    let rejected = [
        "",
        "absolute",
        "absolute:notanumber",
        "bogus:0.5",
        "absolute:inf",
    ];
    for input in rejected {
        assert_eq!(RerankerScoreFloor::parse(input), None);
    }
}
/// The default floor is `Off`, and applying it leaves the list unchanged
/// (same length, same titles, same scores).
#[test]
fn reranker_score_floor_default_is_off_1319() {
    let floor = RerankerScoreFloor::default();
    assert_eq!(floor, RerankerScoreFloor::Off);

    let before = vec![
        (make_memory("a", "x"), 0.9_f64),
        (make_memory("b", "y"), 0.4_f64),
        (make_memory("c", "z"), 0.1_f64),
    ];
    let mut scored = before.clone();
    floor.apply(&mut scored);

    assert_eq!(scored.len(), before.len());
    for ((memory, score), (orig_memory, orig_score)) in scored.iter().zip(&before) {
        assert_eq!(memory.title, orig_memory.title);
        assert!((score - orig_score).abs() < f64::EPSILON);
    }
}
/// `Absolute(0.5)` keeps only the rows scoring 0.90 and 0.60 out of
/// 0.90 / 0.60 / 0.30 / 0.10.
#[test]
fn reranker_score_floor_absolute_drops_tail_1319() {
    let mut scored: Vec<(Memory, f64)> = vec![
        ("top", "x", 0.90_f64),
        ("mid", "y", 0.60_f64),
        ("low", "z", 0.30_f64),
        ("noise", "n", 0.10_f64),
    ]
    .into_iter()
    .map(|(title, content, score)| (make_memory(title, content), score))
    .collect();

    let floor = RerankerScoreFloor::Absolute(0.5);
    floor.apply(&mut scored);

    let survivors: Vec<&str> = scored
        .iter()
        .map(|(memory, _)| memory.title.as_str())
        .collect();
    assert_eq!(survivors, vec!["top", "mid"]);
}
/// `RelativeToTop(0.5)` with a top score of 0.80 keeps the 0.80 and 0.50
/// rows and drops the 0.35 and 0.20 rows.
#[test]
fn reranker_score_floor_relative_drops_tail_1319() {
    let mut scored: Vec<(Memory, f64)> = vec![
        ("top", "x", 0.80_f64),
        ("kept", "y", 0.50_f64),
        ("dropped_1", "z", 0.35_f64),
        ("dropped_2", "z", 0.20_f64),
    ]
    .into_iter()
    .map(|(title, content, score)| (make_memory(title, content), score))
    .collect();

    let floor = RerankerScoreFloor::RelativeToTop(0.5);
    floor.apply(&mut scored);

    let survivors: Vec<&str> = scored
        .iter()
        .map(|(memory, _)| memory.title.as_str())
        .collect();
    assert_eq!(survivors, vec!["top", "kept"]);
}
/// When every row is below the absolute floor, the first row must still
/// survive.
#[test]
fn reranker_score_floor_preserves_top_row_when_everything_below_1319() {
    let mut scored = vec![
        (make_memory("apollo", "moon landing"), 0.20_f64),
        (make_memory("recall", "blends fts and semantic"), 0.10_f64),
    ];
    let floor = RerankerScoreFloor::Absolute(0.5);
    floor.apply(&mut scored);
    assert_eq!(scored.len(), 1);
    let (kept, _) = &scored[0];
    assert_eq!(kept.title, "apollo");
}
/// Applying a floor to an empty list leaves it empty.
#[test]
fn reranker_score_floor_handles_empty_1319() {
    let mut scored = Vec::<(Memory, f64)>::new();
    let floor = RerankerScoreFloor::Absolute(0.5);
    floor.apply(&mut scored);
    assert!(scored.is_empty());
}
/// For a query where the rerank's top score is asserted to be below 0.30,
/// an `Absolute(0.40)` floor must trim the result to the single top row.
#[test]
fn reranker_v1_p5_paraphrase_noise_dropped_by_floor_1319() {
    let ce = CrossEncoder::new();
    let apollo = make_memory(
        "Apollo 11 moon landing",
        "Neil Armstrong walked on the moon in 1969.",
    );
    let recall_b = make_memory(
        "Recall blends FTS and semantic scores",
        "The hybrid pipeline weighs cosine vs BM25 then reranks the top-k.",
    );
    let candidates = vec![(apollo, 0.479_f64), (recall_b, 0.363_f64)];

    let pre = ce.rerank("what makes a recall implementation good?", candidates);
    assert_eq!(pre[0].0.title, "Apollo 11 moon landing");
    assert!(pre[0].1 < 0.30, "top score in noise band: {}", pre[0].1);

    let floor = RerankerScoreFloor::Absolute(0.40);
    let mut post = pre.clone();
    floor.apply(&mut post);
    assert_eq!(
        post.len(),
        1,
        "floor at 0.40 must drop tail when blended scores in noise band: {post:?}"
    );
    assert_eq!(post[0].0.title, "Apollo 11 moon landing");
}
/// A `BatchedReranker` built by `with_score_floor` must report that floor
/// and apply it: two candidates go in, one comes out.
#[test]
fn batched_reranker_score_floor_plumbed_end_to_end_1319() {
    use super::BatchedReranker;
    let configured_floor = RerankerScoreFloor::Absolute(0.40);
    let batched = BatchedReranker::with_score_floor(CrossEncoder::new(), configured_floor);
    assert_eq!(batched.score_floor(), RerankerScoreFloor::Absolute(0.40));

    let candidates = vec![
        (
            make_memory("Apollo 11 moon landing", "Armstrong, 1969"),
            0.479_f64,
        ),
        (
            make_memory(
                "Recall blends FTS and semantic scores",
                "hybrid pipeline weighs cosine vs BM25",
            ),
            0.363_f64,
        ),
    ];
    let out = batched.rerank("paraphrase miss query", candidates);
    assert_eq!(out.len(), 1, "score floor must drop tail: {out:?}");
}
/// `BatchedReranker::new` must leave the score floor at `Off`.
#[test]
fn batched_reranker_default_constructor_leaves_floor_off_1319() {
    use super::BatchedReranker;
    let configured = BatchedReranker::new(CrossEncoder::new()).score_floor();
    assert_eq!(configured, RerankerScoreFloor::Off);
}
}
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Test double for the cross-encoder.
    ///
    /// With `use_neural == false` it forwards to `lexical_score`; with
    /// `use_neural == true` it derives a deterministic pseudo-score from a
    /// hash of the query and title, so no model is needed.
    pub struct MockCrossEncoder {
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Mock in lexical mode.
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Mock in simulated-neural mode.
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Scores one (query, title, content) triple.
        ///
        /// In simulated-neural mode `content` is ignored: the bytes of
        /// `query` followed by `title` are folded into a multiply-by-31
        /// rolling hash, and `hash % 1000 / 1000` is the base score. When
        /// `title` contains `query` the base is remapped to `base * 0.5 + 0.5`
        /// (capped at 1.0).
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                return lexical_score(query, title, content);
            }
            // Hash the bytes of query followed by title.
            let mut hash = 0u32;
            for byte in query.bytes().chain(title.bytes()) {
                hash = hash.wrapping_mul(31).wrapping_add(u32::from(byte));
            }
            let base = ((hash % 1000) as f32) / 1000.0;
            if title.contains(query) {
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Reports which mode the mock was built in.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Blends each candidate's original score with the mock score using
        /// `ORIGINAL_WEIGHT` and `CROSS_ENCODER_WEIGHT`, then sorts by
        /// descending blended score. Incomparable pairs are treated as equal.
        pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
            let mut ranked: Vec<(Memory, f64)> = Vec::with_capacity(candidates.len());
            for (memory, original_score) in candidates {
                let ce_score = f64::from(self.score(query, &memory.title, &memory.content));
                let blended = ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                ranked.push((memory, blended));
            }
            ranked.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            ranked
        }
    }

    impl Default for MockCrossEncoder {
        /// Same as [`MockCrossEncoder::new`].
        fn default() -> Self {
            Self::new()
        }
    }
}
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use super::{BatchedReranker, CrossEncoder};
    use crate::models::{Memory, Tier};
    use std::time::Duration;

    /// Builds a fixture `Memory` with fixed boilerplate fields; only `title`
    /// and `content` vary, and the id is always `"test-id"`.
    fn make_memory(title: &str, content: &str) -> Memory {
        let fixed_timestamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: fixed_timestamp.to_owned(),
            updated_at: fixed_timestamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
            reflection_depth: 0,
            memory_kind: crate::models::MemoryKind::Observation,
            entity_id: None,
            persona_version: None,
            citations: Vec::new(),
            source_uri: None,
            source_span: None,
            confidence_source: crate::models::ConfidenceSource::CallerProvided,
            confidence_signals: None,
            confidence_decayed_at: None,
            version: 1,
        }
    }

    /// `MockCrossEncoder::new` builds a lexical-mode mock.
    #[test]
    fn mock_lexical_new() {
        assert!(!MockCrossEncoder::new().is_neural());
    }

    /// `MockCrossEncoder::new_neural` builds a neural-mode mock.
    #[test]
    fn mock_neural_new() {
        assert!(MockCrossEncoder::new_neural().is_neural());
    }

    /// Scoring the same triple twice gives the same value.
    #[test]
    fn mock_neural_score_deterministic() {
        let mock = MockCrossEncoder::new_neural();
        let first = mock.score("query", "title", "content");
        let second = mock.score("query", "title", "content");
        assert_eq!(first, second);
    }

    /// A title containing the query must outscore an unrelated title.
    #[test]
    fn mock_neural_score_title_match_boost() {
        let mock = MockCrossEncoder::new_neural();
        let body = "delicious dessert";
        let s_title_contains = mock.score("apple", "apple pie recipe", body);
        let s_no_match = mock.score("apple", "unrelated", body);
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    /// Neural-mode mock scores stay in [0, 1] over a small query/title grid.
    #[test]
    fn mock_neural_score_bounded() {
        let mock = MockCrossEncoder::new_neural();
        let queries = ["test", "neural", "reranker", "machine learning"];
        let titles = ["a", "b", "the quick brown"];
        for query in queries {
            for title in titles {
                let s = mock.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    /// The candidate whose title equals the query must move to the front even
    /// though it starts with the lower original score.
    #[test]
    fn mock_neural_rerank_reorders() {
        let mock = MockCrossEncoder::new_neural();
        let on_topic = make_memory("neural network", "deep learning with transformers");
        let off_topic = make_memory("grocery list", "milk eggs bread butter");
        let reranked = mock.rerank("neural network", vec![(off_topic, 0.3), (on_topic, 0.2)]);
        assert_eq!(reranked[0].0.title, "neural network");
    }

    /// Mock reranking returns as many candidates as it received.
    #[test]
    fn mock_neural_rerank_preserves_count() {
        let mock = MockCrossEncoder::new_neural();
        let pool: Vec<(Memory, f64)> = vec![
            ("A", "content a", 0.5),
            ("B", "content b", 0.4),
            ("C", "content c", 0.6),
        ]
        .into_iter()
        .map(|(title, content, score)| (make_memory(title, content), score))
        .collect();
        assert_eq!(mock.rerank("test", pool).len(), 3);
    }

    /// The lexical-mode mock returns a score in [0, 1].
    #[test]
    fn mock_lexical_path_via_mock() {
        let score = MockCrossEncoder::new().score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&score));
    }

    /// The two mock modes must disagree on this particular input.
    #[test]
    fn mock_neural_different_from_lexical() {
        let (query, title, content) = ("machine learning", "ML title", "neural networks");
        let s_lex = MockCrossEncoder::new().score(query, title, content);
        let s_neu = MockCrossEncoder::new_neural().score(query, title, content);
        assert_ne!(s_lex, s_neu);
    }

    /// After the job sender is dropped and a shutdown signal is sent, the
    /// worker thread must finish within 500ms.
    ///
    /// The join runs on a helper thread so the wait can be bounded with
    /// `recv_timeout`.
    #[test]
    fn h2_drop_terminates_worker_within_500ms() {
        use std::time::Instant;
        let mut r = BatchedReranker::new(CrossEncoder::new());
        // Take the pieces out so the test drives shutdown itself.
        let shutdown = r.shutdown.take().expect("shutdown sender present");
        let worker = r.worker.take().expect("worker handle present");
        drop(r.sender.take());

        let start = Instant::now();
        let _ = shutdown.send(());

        let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();
        std::thread::spawn(move || {
            let _ = worker.join();
            let _ = done_tx.send(());
        });

        let observed = done_rx
            .recv_timeout(Duration::from_millis(500))
            .map(|()| Instant::now().duration_since(start));
        assert!(
            observed.is_ok(),
            "BatchedReranker worker did not terminate within 500ms after \
             explicit shutdown — observed: {observed:?}"
        );
    }
}
/// `lexical_score` with an empty query is exactly zero.
#[test]
fn score_handles_empty_query_string() {
    let (title, content) = ("Document Title", "This is document content");
    let s = lexical_score("", title, content);
    assert_eq!(s, 0.0, "empty query must return 0.0");
}
/// Both the accented ("café") and unaccented ("cafe") spellings score
/// positively against documents using the same spelling.
// NOTE(review): despite the name, this does not compare accented against
// unaccented text, so it does not show that normalization happens.
#[test]
fn score_handles_unicode_normalization() {
    let accented = lexical_score("café", "café", "the café is open");
    let plain = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(accented > 0.0);
    assert!(plain > 0.0);
}
/// A 10,000-word document must still give a score in [0, 1].
#[test]
fn score_handles_very_long_content_truncation() {
    let long_content = "word ".repeat(10_000);
    let score = lexical_score("word", "title", &long_content);
    let unit_interval = 0.0..=1.0;
    assert!(unit_interval.contains(&score), "score must be bounded [0, 1]");
}
/// A single-word query must still give a score in [0, 1].
#[test]
fn bigram_score_with_single_token_query() {
    let score = lexical_score("query", "Single Token Title", "single token content");
    let unit_interval = 0.0..=1.0;
    assert!(unit_interval.contains(&score));
}
#[cfg(test)]
mod issue_1597_tests {
    use super::*;
    use crate::models::Memory;

    /// Query used by every capped-pool test. The head assertions below expect
    /// the blended score to reduce to `ORIGINAL_WEIGHT * original`.
    const NO_OVERLAP_QUERY: &str = "zzz qqq www";

    /// Size of the pool used by the over-cap tests.
    const OVERSIZED_POOL: usize = 50;

    /// Converts a small test index to `i32`, the type used for candidate ids.
    fn to_i32(n: usize) -> i32 {
        i32::try_from(n).expect("test index fits i32")
    }

    /// Candidate number `i`, with id `cand-{i}`.
    fn pool_memory(i: i32) -> Memory {
        Memory {
            id: format!("cand-{i}"),
            title: format!("alpha {i}"),
            content: format!("beta gamma {i}"),
            ..Memory::default()
        }
    }

    /// Original (pre-rerank) score of candidate `i`; strictly increasing in `i`.
    fn orig_score(i: i32) -> f64 {
        f64::from(i + 1) * 0.01
    }

    /// Pool of `n` candidates in ascending original-score order.
    fn pool(n: i32) -> Vec<(Memory, f64)> {
        (0..n).map(|i| (pool_memory(i), orig_score(i))).collect()
    }

    /// Candidate index expected at position `rank` of a descending ordering of
    /// a `pool_len`-sized pool.
    fn head_index(pool_len: usize, rank: usize) -> i32 {
        to_i32(pool_len) - 1 - to_i32(rank)
    }

    /// Candidate index expected `off` places into the uncapped tail of a
    /// `pool_len`-sized pool. Derived from `RERANK_POOL_MAX` so the tests do
    /// not hard-code the cap.
    fn tail_index(pool_len: usize, off: usize) -> i32 {
        to_i32(pool_len) - to_i32(RERANK_POOL_MAX) - 1 - to_i32(off)
    }

    /// Over-cap pool: nothing is lost, the first `RERANK_POOL_MAX` rows carry
    /// the blended score, and the rest keep their original score.
    #[test]
    fn rerank_pool_cap_honored_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(to_i32(OVERSIZED_POOL)));
        assert_eq!(out.len(), OVERSIZED_POOL, "no candidate may be lost");
        let ids: std::collections::HashSet<&str> =
            out.iter().map(|(m, _)| m.id.as_str()).collect();
        assert_eq!(ids.len(), OVERSIZED_POOL, "no duplicate / dropped ids");
        for (rank, (mem, score)) in out.iter().take(RERANK_POOL_MAX).enumerate() {
            let i = head_index(OVERSIZED_POOL, rank);
            assert_eq!(mem.id, format!("cand-{i}"), "head rank {rank}");
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "head rank {rank} must carry the cross-encoded blend"
            );
        }
        for (off, (mem, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            let i = tail_index(OVERSIZED_POOL, off);
            assert_eq!(mem.id, format!("cand-{i}"), "tail offset {off}");
            assert_eq!(
                *score,
                orig_score(i),
                "tail offset {off} must keep its blended score untouched"
            );
        }
    }

    /// Head and tail are each sorted descending, and every tail score is below
    /// the lowest original score that made it into the head.
    #[test]
    fn rerank_pool_cap_order_correctness_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(to_i32(OVERSIZED_POOL)));
        let (head, tail) = out.split_at(RERANK_POOL_MAX);
        assert!(
            head.windows(2).all(|w| w[0].1 >= w[1].1),
            "reranked head must be sorted descending"
        );
        assert!(
            tail.windows(2).all(|w| w[0].1 >= w[1].1),
            "uncapped tail must be sorted descending"
        );
        // Lowest-ranked head member is the candidate just above the tail.
        let min_head_orig = orig_score(to_i32(OVERSIZED_POOL) - to_i32(RERANK_POOL_MAX));
        assert!(
            tail.iter().all(|(_, s)| *s < min_head_orig),
            "tail must hold only candidates the cap excluded"
        );
    }

    /// A pool of exactly `RERANK_POOL_MAX` candidates is blended in full.
    #[test]
    fn rerank_pool_at_cap_fully_cross_encoded_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(to_i32(RERANK_POOL_MAX)));
        assert_eq!(out.len(), RERANK_POOL_MAX);
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = head_index(RERANK_POOL_MAX, rank);
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "at-cap pool: rank {rank} must be cross-encoded"
            );
        }
    }

    /// A five-candidate pool is blended in full; no row is left as tail.
    #[test]
    fn rerank_cap_gt_pool_degenerates_to_full_rerank_1597() {
        const SMALL_POOL: usize = 5;
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(to_i32(SMALL_POOL)));
        assert_eq!(out.len(), SMALL_POOL);
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = head_index(SMALL_POOL, rank);
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "small pool: rank {rank} must be cross-encoded (no tail)"
            );
        }
    }

    /// `rerank_batch` applies the cap to each job independently.
    #[test]
    fn rerank_batch_applies_pool_cap_per_query_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let jobs = vec![
            (NO_OVERLAP_QUERY.to_string(), pool(to_i32(OVERSIZED_POOL))),
            (NO_OVERLAP_QUERY.to_string(), pool(to_i32(OVERSIZED_POOL))),
        ];
        let outs = ce.rerank_batch(jobs);
        assert_eq!(outs.len(), 2);
        for out in &outs {
            assert_eq!(
                out.len(),
                OVERSIZED_POOL,
                "per-job candidate count preserved"
            );
            for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
                let i = tail_index(OVERSIZED_POOL, off);
                assert_eq!(
                    *score,
                    orig_score(i),
                    "per-job tail must keep blended scores untouched"
                );
            }
        }
    }

    /// The `BatchedReranker` wrapper (reflection boost disabled) shows the same
    /// untouched-tail behaviour.
    #[test]
    fn batched_reranker_inherits_pool_cap_1597() {
        let br = BatchedReranker::with_reflection_boost(
            CrossEncoder::Lexical { degraded: false },
            ReflectionBoostConfig::disabled(),
        );
        let out = br.rerank(NO_OVERLAP_QUERY, pool(to_i32(OVERSIZED_POOL)));
        assert_eq!(out.len(), OVERSIZED_POOL);
        for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            let i = tail_index(OVERSIZED_POOL, off);
            assert_eq!(*score, orig_score(i), "wrapper tail untouched");
        }
    }

    /// Manual benchmark (ignored by default): times 50 sequential `score`
    /// calls against one `rerank` call on the same 50-candidate pool and
    /// prints both durations to stderr.
    #[test]
    #[ignore = "#1597 manual bench evidence: loads the real neural cross-encoder"]
    fn issue_1597_neural_rerank_timing_evidence() {
        let ce = CrossEncoder::new_neural();
        assert!(
            ce.is_neural(),
            "neural encoder failed to load; timing evidence invalid"
        );
        let bench_pool: Vec<(Memory, f64)> = (0..50)
            .map(|i| {
                let m = Memory {
                    id: format!("bench-{i}"),
                    title: format!("benchmark candidate number {i} recall pipeline"),
                    content: format!(
                        "long-form benchmark document body number {i} with enough \
                         material to exercise the cross-encoder, covering recall \
                         pipeline reranking, cross encoder scoring, candidate \
                         blending and ordering semantics for run {i}"
                    ),
                    ..Memory::default()
                };
                (m, f64::from(i) * 0.01)
            })
            .collect();
        let query = "how does the recall pipeline rerank candidates";
        // One scoring call before the timed sections.
        let _ = ce.score(query, "warmup", "warmup body");
        let t0 = Instant::now();
        for (m, _) in &bench_pool {
            let _ = ce.score(query, &m.title, &m.content);
        }
        let before = t0.elapsed();
        let t1 = Instant::now();
        let out = ce.rerank(query, bench_pool.clone());
        let after = t1.elapsed();
        assert_eq!(out.len(), 50, "no candidate lost on the neural path");
        eprintln!(
            "#1597 timing (50-candidate pool, CPU): BEFORE sequential-full = {before:?}; \
             AFTER capped+batched = {after:?}"
        );
    }
}