ai_memory/reranker.rs
1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Cross-encoder reranking for search results.
5//!
6//! A cross-encoder takes a (query, document) pair and produces a relevance
7//! score. This is more accurate than cosine similarity of independent
8//! embeddings but slower since it must run for each candidate.
9//!
10//! **Two implementations:**
11//! - `CrossEncoder::Lexical` — lightweight term-overlap scorer (default).
12//! - `CrossEncoder::Neural` — BERT-based cross-encoder loaded via candle
13//! from `cross-encoder/ms-marco-MiniLM-L-6-v2` (~80 MB, ONNX-free).
14
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::sync::mpsc::{Sender, sync_channel};
17use std::sync::{Arc, Mutex};
18use std::thread::{self, JoinHandle};
19use std::time::{Duration, Instant};
20
21use anyhow::{Context, Result};
22use candle_core::{Device, Tensor};
23use candle_nn::VarBuilder;
24use candle_transformers::models::bert::{BertModel, Config as BertConfig};
25use hf_hub::{Repo, RepoType, api::sync::Api};
26use tokenizers::Tokenizer;
27
28use crate::models::Memory;
29
30// ---------------------------------------------------------------------------
31// v0.7.0 (issue #518) — session-aware recall recency boost
32// ---------------------------------------------------------------------------
33
/// Additive boost applied to a recall candidate that appears in the
/// session's recently-accessed set. Sits at +0.05 — small enough that
/// a low-relevance candidate cannot leapfrog a substantially-better
/// match, large enough to break ties in favour of memories the agent
/// just touched in the same session.
pub const SESSION_RECENCY_BOOST: f64 = 0.05;

/// Per-session cap on the recently-accessed ring buffer. Whenever an
/// append takes a ring past the cap, the oldest entries are evicted
/// (FIFO) until it is back at the cap. Keeps the substrate memory cost
/// bounded at `O(SESSIONS * 50)` ids regardless of recall traffic.
pub const SESSION_RECENT_CAP: usize = 50;

/// v0.7.0 (issue #518) — process-global tracker mapping `session_id`
/// to its FIFO ring buffer of recently-accessed memory ids.
///
/// The tracker is consulted by [`apply_session_recency_boost`] after
/// the rerank stage of `handle_recall` (MCP) and `recall_response`
/// (HTTP). Each call:
///
/// 1. Reads the per-session set BEFORE assembling the boost so the
///    candidates already touched in this session lift in rank.
/// 2. Appends every recall hit's id INTO the per-session ring (FIFO
///    eviction past [`SESSION_RECENT_CAP`]) so subsequent recalls in
///    the same session reuse the new context.
///
/// The tracker uses a single `Mutex` because contention is dominated
/// by the per-recall work itself (FTS + semantic + rerank), making
/// the lock-acquire/-release cost noise; the implementation can swap
/// to per-shard locking if a future profile shows otherwise.
///
/// **Poisoning.** A `std::sync::Mutex` stays poisoned forever once a
/// holder panics, and the most likely holder to panic is the
/// caller-supplied callback run under the lock by
/// [`SessionRecallTracker::with_recent_ids`]. Treating poison as
/// "tracker unavailable" would therefore disable the boost for the rest
/// of the process. Instead every method recovers the guard (see
/// `lock_sessions`). This is sound because no method leaves a ring
/// half-updated at a point where foreign code can panic.
#[derive(Debug, Default)]
pub struct SessionRecallTracker {
    inner: Mutex<HashMap<String, VecDeque<String>>>,
}

impl SessionRecallTracker {
    /// Construct an empty tracker. Test code uses this directly; the
    /// production code path goes through the process-global
    /// [`global_session_recall_tracker`] accessor.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Lock the session map, recovering it if a previous holder panicked.
    ///
    /// Poison is sticky on `std::sync::Mutex`, so bailing out on `Err`
    /// would silently turn the tracker off for the rest of the process.
    /// The map is always consistent at every point where a panic can
    /// occur:
    /// - `record` finishes each retain/push/trim step before pulling the
    ///   next id from the caller's iterator;
    /// - `with_recent_ids` only reads.
    ///
    /// Recovering the inner guard is therefore safe.
    fn lock_sessions(&self) -> std::sync::MutexGuard<'_, HashMap<String, VecDeque<String>>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Return the set of recently-accessed memory ids for `session_id`,
    /// or an empty set if the session is unknown.
    ///
    /// v0.7.0 #1091 — kept for the public API contract (test code +
    /// callers outside the hot path use it). The boost site
    /// [`apply_session_recency_boost`] uses
    /// [`SessionRecallTracker::with_recent_ids`] instead, to avoid the
    /// per-recall `HashSet` allocation.
    #[must_use]
    pub fn recent_ids(&self, session_id: &str) -> HashSet<String> {
        self.lock_sessions()
            .get(session_id)
            .map(|ring| ring.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// v0.7.0 #1091 — allocation-free per-id membership lookup against
    /// the per-session ring.
    ///
    /// `f` is invoked exactly once with a membership predicate and its
    /// result is returned. For an unknown session the predicate always
    /// answers `false`.
    ///
    /// The mutex is held for the whole callback. Keep `f` short, and
    /// never call back into this tracker from inside it: re-locking a
    /// `std::sync::Mutex` the current thread already holds may deadlock
    /// or panic.
    ///
    /// The predicate scans the ring linearly. The ring is capped at
    /// [`SESSION_RECENT_CAP`], so K lookups cost O(K·N) with N ≤ 50.
    /// That is comparable to the pre-#1091 path, which built an O(N)
    /// `HashSet` and then did K O(1) lookups.
    pub fn with_recent_ids<R>(
        &self,
        session_id: &str,
        f: impl FnOnce(&dyn Fn(&str) -> bool) -> R,
    ) -> R {
        let guard = self.lock_sessions();
        match guard.get(session_id) {
            None => f(&|_id: &str| false),
            Some(ring) => f(&|id: &str| ring.iter().any(|existing| existing == id)),
        }
    }

    /// Record the ids of memories returned by the just-completed
    /// recall into the per-session ring. FIFO eviction past
    /// [`SESSION_RECENT_CAP`] keeps the per-session set bounded.
    ///
    /// A duplicate id (a memory recalled twice in the same session) is
    /// removed from its old position and re-appended at the BACK (newest
    /// end) of the ring. Eviction pops from the front (oldest end), so
    /// the most-recently-touched ids are the last to be evicted.
    pub fn record(&self, session_id: &str, ids: impl IntoIterator<Item = String>) {
        let mut guard = self.lock_sessions();
        let ring = guard.entry(session_id.to_string()).or_default();
        for id in ids {
            // De-dupe so the newest landing position wins.
            ring.retain(|existing| existing != &id);
            ring.push_back(id);
            while ring.len() > SESSION_RECENT_CAP {
                ring.pop_front();
            }
        }
    }

    /// Diagnostic: number of tracked sessions. Used by tests and the
    /// `/metrics` surface (future).
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.lock_sessions().len()
    }
}
162
163/// Process-global [`SessionRecallTracker`] used by every recall hot
164/// path. Lazily initialised on first access; never reset within a
165/// process lifetime (per-process state by design — operator restart
166/// clears every session's recent set).
167///
168/// v0.7.x (issue #1174 follow-up #1196) — the tracker lives on
169/// [`crate::runtime_context::RuntimeContext::recall_tracker`]. The
170/// returned `&'static` reference is stable because
171/// `RuntimeContext::global()` itself is a `OnceLock`-backed
172/// process-wide singleton; the `Arc<SessionRecallTracker>` inside it
173/// is allocated once and outlives every caller.
174#[must_use]
175pub fn global_session_recall_tracker() -> &'static SessionRecallTracker {
176 &crate::runtime_context::RuntimeContext::global().recall_tracker
177}
178
179/// v0.7.0 (issue #518) — apply the per-session recently-accessed boost
180/// to a scored recall result vector AND record the post-boost hit set
181/// back into the session's ring buffer.
182///
183/// `session_id` is the caller-supplied per-session identifier. When
184/// `None` or empty, the function is a no-op (returns the input
185/// unchanged). When set:
186///
187/// 1. Every candidate whose id is in the tracker's per-session set
188/// gets `SESSION_RECENCY_BOOST` ADDED to its score.
189/// 2. The vector is re-sorted descending by the boosted score.
190/// 3. The post-boost id list is appended into the session ring (FIFO
191/// eviction past [`SESSION_RECENT_CAP`]).
192///
193/// The boost is *additive* (not multiplicative) so its effect is
194/// independent of the absolute score magnitude — the +0.05 always
195/// breaks ties at the same delta regardless of whether scores are on
196/// the 0..1 cosine band or the 0..2 blended hybrid band.
197pub fn apply_session_recency_boost(
198 results: Vec<(Memory, f64)>,
199 session_id: Option<&str>,
200 tracker: &SessionRecallTracker,
201) -> Vec<(Memory, f64)> {
202 let Some(sid) = session_id else {
203 return results;
204 };
205 if sid.is_empty() {
206 return results;
207 }
208 // v0.7.0 #1091 — drop the per-recall `HashSet<String>` allocation
209 // (50 clones at the cap) by using the membership-callback variant.
210 // The closure owns the inner mutex for the boost-apply pass; the
211 // membership predicate runs O(N) per id against the (≤ 50 entry)
212 // ring, giving the same overall complexity as the pre-#1091
213 // (HashSet-build + lookup) path without the allocation.
214 let mut boosted: Vec<(Memory, f64)> = tracker.with_recent_ids(sid, |is_recent| {
215 results
216 .into_iter()
217 .map(|(mem, score)| {
218 let bumped = if is_recent(&mem.id) {
219 score + SESSION_RECENCY_BOOST
220 } else {
221 score
222 };
223 (mem, bumped)
224 })
225 .collect()
226 });
227 // Re-sort descending — boosted candidates may move past their
228 // pre-boost neighbours.
229 boosted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
230 // v0.7.0 #1091 — record the post-boost id list into the session
231 // ring via an iterator clone-on-demand path so we don't allocate
232 // a Vec<String> just to hand to `record` (which itself iterates).
233 tracker.record(sid, boosted.iter().map(|(m, _)| m.id.clone()));
234 boosted
235}
236
/// Share of the incoming (embedding/FTS) score in the rerank blend.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Share of the cross-encoder score in the rerank blend.
const CROSS_ENCODER_WEIGHT: f64 = 0.4;

/// #1531 M13 — replace a non-finite blended score with the ranking floor.
///
/// Before this clamp, a NaN final score broke the rerank sort. NaN can
/// come from a poisoned caller `original_score` or from a corrupt weights
/// file yielding a NaN logit. The sort compares with
/// `partial_cmp(..).unwrap_or(Equal)`, so NaN ranked "equal" to every
/// other score. The stable sort then left the row wherever it started,
/// and a corrupt candidate could nondeterministically hold the top rank.
///
/// NaN and ±∞ are now mapped to `f64::MIN`, which deterministically sinks
/// them to the bottom. Finite scores are returned unchanged, so ordinary
/// ranking is byte-identical.
fn finite_or_floor(score: f64) -> f64 {
    if !score.is_finite() {
        return f64::MIN;
    }
    score
}
255
256/// #1597 — split a candidate pool into `(head, tail)` at
257/// [`RERANK_POOL_MAX`].
258///
259/// Pools at or under the cap come back whole (`tail` empty, input order
260/// preserved — the degenerate full-rerank case). Larger pools are
261/// sorted by the incoming blended score descending (total order via
262/// [`f64::total_cmp`] so a NaN-poisoned score cannot destabilise the
263/// sort; NaN sorts into the head, is cross-encoded, and then sinks via
264/// [`finite_or_floor`] exactly as pre-#1597) and split after the cap,
265/// so both halves come back internally sorted descending.
266fn split_rerank_pool(
267 mut candidates: Vec<(Memory, f64)>,
268) -> (Vec<(Memory, f64)>, Vec<(Memory, f64)>) {
269 let tail = if candidates.len() > RERANK_POOL_MAX {
270 candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
271 candidates.split_off(RERANK_POOL_MAX)
272 } else {
273 Vec::new()
274 };
275 (candidates, tail)
276}
277
278/// #1597 — hard cap on how many candidates receive a cross-encoder
279/// score per rerank call.
280///
281/// The Phase-3 dogfood run measured autonomous-tier recall at
282/// 2823-7737 ms/call on CPU vs 14-32 ms at the semantic tier: the
283/// pre-#1597 [`CrossEncoder::rerank`] ran one full BERT forward pass
284/// per (query, candidate) pair, sequentially, over the entire
285/// post-blend candidate pool (up to 50 rows from the recall SQL cap).
286/// Only the strongest `RERANK_POOL_MAX` candidates by incoming blended
287/// score are cross-encoded (in ONE batched forward pass); the
288/// remainder keep their blended scores and sort below the reranked
289/// head. 20 keeps the cross-encoder's precision win where it matters
290/// (the head the caller actually reads) while bounding the worst-case
291/// forward-pass cost at ~40% of the pre-fix pool.
292pub const RERANK_POOL_MAX: usize = 20;
293
/// Hugging Face Hub repository of the neural cross-encoder fetched by
/// `CrossEncoder::load_neural`.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Bare configured-model spelling for the default reranker — shared with
/// the `ai-memory config migrate` template (#1558 batch 6).
pub(crate) const DEFAULT_RERANKER_MODEL: &str = "ms-marco-MiniLM-L-6-v2";
/// Model-architecture ceiling on the cross-encoder input sequence.
/// Per-consumer truncation (e.g. the #1604 rerank cap below) may go
/// tighter, never looser — the resolver clamps against this value.
pub const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Width of the classifier head's input. `load_neural` reads
/// `classifier.weight` with shape `(1, CROSS_ENCODER_HIDDEN_DIM)`, so this
/// must match the checkpoint's hidden size.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;

/// #1604 — compiled default for the tokenized length of **rerank**
/// inputs. It is applied in `CrossEncoder::neural_score_pairs` (the
/// #1597 batched-forward path) in place of the architecture ceiling
/// [`CROSS_ENCODER_MAX_SEQ`].
///
/// Why 256: the #1588 dogfood RE-RUN measured the residual #1597 latency.
/// - Warm autonomous-tier recall took ~4,013 ms on a real (long-content)
///   corpus vs ~533 ms on short-content rows.
/// - The cost was the [batch=20, seq=512] candle CPU forward, not pool
///   size or batching.
/// - BERT attention is O(n²) in sequence length, so halving the cap to
///   256 cuts the forward ~4× while keeping the title + lead content that
///   carries the relevance signal for memory rows.
///
/// Other cross-encoder consumers (the single-pair `CrossEncoder::score`)
/// keep the full [`CROSS_ENCODER_MAX_SEQ`].
///
/// Operator override ladder, highest priority first:
/// 1. `AI_MEMORY_RERANK_MAX_SEQ` env;
/// 2. `[reranker].max_seq_tokens` config;
/// 3. this compiled default.
///
/// The ladder is resolved by `AppConfig::resolve_reranker()` at boot and
/// seeded here via [`set_rerank_max_seq`]. Values that are zero,
/// unparseable, or above [`CROSS_ENCODER_MAX_SEQ`] fall through to the
/// next ladder layer.
pub const RERANK_MAX_SEQ_DEFAULT: usize = 256;

/// Process-wide resolved rerank sequence cap, seeded once at boot from
/// `AppConfig::resolve_reranker()`.
///
/// A global is used because the scoring paths run deep in the recall
/// pipeline where no `AppConfig` is in scope; this follows the
/// `crate::storage::set_db_mmap_size` `OnceLock` precedent. Unseeded
/// processes (unit tests, library embedders that bypass the CLI boot
/// path) fall through to [`RERANK_MAX_SEQ_DEFAULT`].
static RERANK_MAX_SEQ: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Seed the process-wide rerank sequence cap for every subsequent
/// batched rerank forward.
///
/// First writer wins; later calls are no-ops (matches
/// `crate::storage::set_db_mmap_size`).
///
/// The value is stored as given. [`rerank_max_seq`] ignores out-of-range
/// seeds (zero, or above [`CROSS_ENCODER_MAX_SEQ`]), so a direct caller
/// that bypasses the boot-time resolver cannot push the model past its
/// architecture limits.
pub fn set_rerank_max_seq(tokens: usize) {
    let _ = RERANK_MAX_SEQ.set(tokens);
}

/// The effective rerank sequence cap for this process.
///
/// Returns the seeded value when it lies in `1..=CROSS_ENCODER_MAX_SEQ`,
/// otherwise [`RERANK_MAX_SEQ_DEFAULT`].
fn rerank_max_seq() -> usize {
    effective_rerank_max_seq(RERANK_MAX_SEQ.get().copied())
}

/// Validate a (possibly absent) seeded cap against the documented ladder.
///
/// A missing seed, a zero cap (which would truncate every input to
/// nothing), or a cap above [`CROSS_ENCODER_MAX_SEQ`] (which the
/// architecture ceiling forbids) all fall through to
/// [`RERANK_MAX_SEQ_DEFAULT`].
fn effective_rerank_max_seq(seeded: Option<usize>) -> usize {
    seeded
        .filter(|tokens| (1..=CROSS_ENCODER_MAX_SEQ).contains(tokens))
        .unwrap_or(RERANK_MAX_SEQ_DEFAULT)
}
346
/// v0.7.0 L2-8 — default multiplicative boost applied to `Reflection`-kind
/// memories AFTER cross-encoder reranking.
///
/// Reflections summarise multiple observations, so abstraction-shaped
/// queries ("what patterns...", "what are recurring themes...") should
/// preferentially surface them. The value `1.2` sits in the band where:
/// - a reflection whose base score equals its source observations'
///   consistently lifts into the top-5;
/// - mediocre reflections are not dragged above well-matched
///   observations.
pub const DEFAULT_REFLECTION_BOOST: f32 = 1.2;

/// v0.7.0 L2-8 — default per-depth additional multiplier increment.
/// `per_depth_factor = 1.0 + per_depth_increment * reflection_depth`.
/// Deeper reflections (reflections-on-reflections) compress more
/// observations, so a small per-depth bump is justified.
pub const DEFAULT_REFLECTION_PER_DEPTH_INCREMENT: f32 = 0.05;

/// v0.7.0 L2-8 — default depth cap mirrored from
/// [`GovernancePolicy::effective_max_reflection_depth`].
///
/// Past this depth the per-depth multiplier stops growing; reflections
/// deeper than the cap still receive the cap-evaluated boost. Operator
/// policy may refuse such a write entirely, but the reranker side never
/// produces an unbounded multiplier.
pub const DEFAULT_REFLECTION_MAX_DEPTH_CAP: u32 = 3;

/// v0.7.0 L2-8 — tuning for the reflection-aware reranker boost.
///
/// The boost runs AFTER the cross-encoder blend; it takes no part in the
/// `0.6 * original + 0.4 * cross_encoder` formula. Shape:
///
/// ```text
/// per_depth_factor = 1.0 + per_depth_increment * min(reflection_depth, max_depth_cap)
/// final_score = base_score * (kind == Reflection ? boost * per_depth_factor : 1.0)
/// ```
///
/// The default factor is `1.2` (see [`DEFAULT_REFLECTION_BOOST`]). Setting
/// `boost = 1.0` reproduces the pre-L2-8 reranker exactly — a deliberate
/// kill-switch for the recall regression suite.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ReflectionBoostConfig {
    /// Multiplicative boost applied to `Reflection`-kind memories.
    /// Default `1.2`. `1.0` disables the boost.
    pub boost: f32,
    /// Per-depth additional multiplier increment. Default `0.05`.
    pub per_depth_increment: f32,
    /// Depth cap for the per-depth multiplier. Default `3`, mirroring the
    /// compiled-in default of
    /// `GovernancePolicy::effective_max_reflection_depth`. Larger
    /// `reflection_depth` values are clamped to this cap so the reranker
    /// never produces an unbounded multiplier.
    pub max_depth_cap: u32,
}

impl Default for ReflectionBoostConfig {
    /// The shipped tuning: [`DEFAULT_REFLECTION_BOOST`],
    /// [`DEFAULT_REFLECTION_PER_DEPTH_INCREMENT`] and
    /// [`DEFAULT_REFLECTION_MAX_DEPTH_CAP`].
    fn default() -> Self {
        let max_depth_cap = DEFAULT_REFLECTION_MAX_DEPTH_CAP;
        let per_depth_increment = DEFAULT_REFLECTION_PER_DEPTH_INCREMENT;
        let boost = DEFAULT_REFLECTION_BOOST;
        Self {
            boost,
            per_depth_increment,
            max_depth_cap,
        }
    }
}
409
410impl ReflectionBoostConfig {
411 /// Pin to pre-L2-8 behavior: `boost = 1.0` ⇒ multiplier is always
412 /// `1.0` regardless of memory kind or depth. Used by the regression
413 /// test that proves the new pathway is a *pure addition* over the RC
414 /// behavior.
415 #[must_use]
416 pub const fn disabled() -> Self {
417 Self {
418 boost: 1.0,
419 per_depth_increment: 0.0,
420 max_depth_cap: 0,
421 }
422 }
423
424 /// Compute the multiplicative factor for a given memory. Returns
425 /// `1.0` for non-reflections; `boost * per_depth_factor` for
426 /// reflections (with `reflection_depth` clamped to `max_depth_cap`).
427 ///
428 /// Pulled out so the same arithmetic is shared by both the per-query
429 /// `rerank` and the G9 batched `rerank_batch` codepaths — there is
430 /// exactly one place to audit the multiplier shape.
431 #[must_use]
432 pub fn factor_for(&self, mem: &Memory) -> f64 {
433 if !matches!(mem.memory_kind, crate::models::MemoryKind::Reflection) {
434 return 1.0;
435 }
436 // `reflection_depth` is stored as i32 (SQL signed) but the
437 // governance accessor returns u32; the column DEFAULT is 0 and
438 // negative values would already have been rejected by the
439 // `memory_reflect` write path. Clamp to non-negative defensively
440 // so a bad write upstream can't produce a negative multiplier.
441 let depth = u32::try_from(mem.reflection_depth.max(0)).unwrap_or(0);
442 let depth_clamped = depth.min(self.max_depth_cap);
443 let per_depth_factor =
444 f64::from(self.per_depth_increment).mul_add(f64::from(depth_clamped), 1.0);
445 f64::from(self.boost) * per_depth_factor
446 }
447}
448
/// Cross-encoder for (query, document) relevance scoring.
///
/// Construct with [`CrossEncoder::new`] (lexical) or
/// [`CrossEncoder::new_neural`] (neural, falling back to a degraded
/// lexical variant if the model cannot be loaded).
pub enum CrossEncoder {
    /// Lightweight lexical cross-encoder using term overlap signals.
    ///
    /// `degraded` distinguishes two cases:
    /// - `true`: this variant exists because a configured neural
    ///   cross-encoder failed to initialise (HF Hub unreachable, model
    ///   checksum mismatch, etc.) and the runtime fell back.
    /// - `false`: the originally-configured lexical tier (the operator
    ///   chose a tier without neural cross-encoder reranking).
    ///
    /// v0.7.0 R3-S2 — the recall response surfaces the distinction as
    /// `meta.reranker_used = "degraded_lexical"` vs `"lexical"`. That gives
    /// MCP and HTTP clients an in-band signal when their reranker was
    /// downgraded. The original G8 fix landed only a `tracing::warn!`; G8
    /// closure per the playbook required an in-response field, which the
    /// prior implementation overstated.
    Lexical { degraded: bool },
    /// Neural BERT-based cross-encoder (ms-marco-MiniLM-L-6-v2).
    ///
    /// v0.7.0 #1084 — `model` is `Arc<BertModel>` (no mutex), the same
    /// pattern as `Embedder::Local`.
    /// - The pre-#1084 design held an `Arc<Mutex<BertModel>>` and locked
    ///   across the full neural rerank forward pass, serialising every
    ///   rerank-tier recall on a single global mutex.
    /// - Candle's `BertModel::forward` takes `&self` (inference-only;
    ///   weights are read-only), so the mutex was unnecessary.
    Neural {
        /// BERT encoder, shared across threads without locking.
        model: Arc<BertModel>,
        /// Pair tokenizer. `load_neural` configures it to truncate at
        /// [`CROSS_ENCODER_MAX_SEQ`] tokens with padding disabled.
        tokenizer: Arc<Tokenizer>,
        /// Classification-head weight, loaded by `load_neural` with shape
        /// `(1, CROSS_ENCODER_HIDDEN_DIM)`.
        classifier_weight: Tensor,
        /// Classification-head bias, loaded by `load_neural` with shape `(1,)`.
        classifier_bias: Tensor,
        /// Device the tensors live on; `load_neural` always uses `Device::Cpu`.
        device: Device,
    },
}
485
486impl CrossEncoder {
487 /// Create a new lexical cross-encoder (no model download required).
488 ///
489 /// This is the "originally lexical" path — the operator either
490 /// chose keyword-/semantic-tier (no cross-encoder reranking) or
491 /// explicitly opted into the lexical variant. Use
492 /// [`Self::new_neural`] to attempt the neural path with
493 /// fall-back-to-lexical semantics.
494 pub fn new() -> Self {
495 Self::Lexical { degraded: false }
496 }
497
498 /// Create a neural cross-encoder by downloading ms-marco-MiniLM-L-6-v2.
499 ///
500 /// Falls back to lexical if download or loading fails. The
501 /// fallback is marked `degraded: true` so the recall response
502 /// surfaces `reranker_used = "degraded_lexical"` per R3-S2 — an
503 /// in-band signal that v0.7.0 promises but pre-R3 only emitted
504 /// as a `tracing::warn!` (a tracing-event-only fallback is not
505 /// the same as a per-response field operators can branch on).
506 ///
507 /// v0.6.3.1 (P3, G8): when the neural path fails (e.g. HF Hub
508 /// unreachable, model checksum mismatch), emit a structured tracing
509 /// event `reranker.fallback` so operators see the silent
510 /// neural→lexical degrade. The eprintln remains for backward-compat
511 /// startup logs.
512 pub fn new_neural() -> Self {
513 match Self::load_neural() {
514 Ok(ce) => ce,
515 Err(e) => {
516 tracing::warn!(
517 target: "reranker.fallback",
518 from = "neural",
519 to = "lexical",
520 reason = %e,
521 "cross-encoder fell back to lexical: neural init failed"
522 );
523 eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
524 Self::Lexical { degraded: true }
525 }
526 }
527 }
528
    /// Fetch and assemble the neural cross-encoder.
    ///
    /// Steps:
    /// 1. Fetch `config.json`, `tokenizer.json` and `model.safetensors` for
    ///    [`CROSS_ENCODER_MODEL_ID`] through the Hugging Face Hub API
    ///    (`repo.get`), using the file names from `crate::embeddings`.
    /// 2. Build the BERT encoder plus a 1-logit classification head on CPU.
    ///
    /// # Errors
    /// Returns an error, with context describing the failing step, if the
    /// Hub API cannot be initialised, any file cannot be fetched or read,
    /// the config does not parse, the tokenizer cannot be loaded or
    /// configured, or the weights / classifier tensors cannot be loaded.
    fn load_neural() -> Result<Self> {
        let device = Device::Cpu;

        let api = Api::new().context("failed to init HuggingFace Hub API")?;
        let repo = api.repo(Repo::new(
            CROSS_ENCODER_MODEL_ID.to_string(),
            RepoType::Model,
        ));

        let config_path = repo
            .get(crate::embeddings::HF_CONFIG_FILE)
            .context("failed to download config.json")?;
        let tokenizer_path = repo
            .get(crate::embeddings::HF_TOKENIZER_FILE)
            .context("failed to download tokenizer.json")?;
        let weights_path = repo
            .get(crate::embeddings::HF_WEIGHTS_FILE)
            .context("failed to download model.safetensors")?;

        // Load BERT config
        let config_data = std::fs::read_to_string(&config_path)
            .context("failed to read cross-encoder config.json")?;
        let config: BertConfig = serde_json::from_str(&config_data)
            .context("failed to parse cross-encoder config.json")?;

        // Load tokenizer. Truncate (query, document) pairs at the
        // architecture ceiling and disable padding: `neural_score` encodes
        // one pair at a time, so no padding to a common length is needed.
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
        let truncation = tokenizers::TruncationParams {
            max_length: CROSS_ENCODER_MAX_SEQ,
            ..Default::default()
        };
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
        tokenizer.with_padding(None);

        // Load model weights.
        //
        // SAFETY (#1456): `from_mmaped_safetensors` memory-maps the
        // weights file. The mmap is unsound only if the backing file is
        // mutated or truncated by another process while it is mapped.
        // `weights_path` resolves to a trusted, immutable safetensors
        // artifact in the daemon-owned HuggingFace cache (downloaded and
        // not subsequently written by us); it is never a caller-supplied
        // path at request time. The mapping lives only for the duration
        // of weight loading below.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .context("failed to load cross-encoder weights")?
        };

        let model = BertModel::load(vb.clone(), &config)
            .context("failed to build cross-encoder BertModel")?;

        // Load the classification head: classifier.weight [1, hidden_dim] and classifier.bias [1]
        let classifier_weight = vb
            .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
            .context("failed to load classifier.weight")?;
        let classifier_bias = vb
            .get(1, "classifier.bias")
            .context("failed to load classifier.bias")?;

        Ok(Self::Neural {
            model: Arc::new(model),
            tokenizer: Arc::new(tokenizer),
            classifier_weight,
            classifier_bias,
            device,
        })
    }
600
601 /// Score a single (query, document) pair.
602 ///
603 /// Returns a relevance score in `0.0..=1.0`.
604 pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
605 match self {
606 Self::Lexical { .. } => lexical_score(query, title, content),
607 Self::Neural {
608 model,
609 tokenizer,
610 classifier_weight,
611 classifier_bias,
612 device,
613 } => {
614 // v0.7.0 #1084 — no mutex acquisition: `Arc<BertModel>`
615 // shared across threads; `BertModel::forward(&self, ...)`
616 // is inference-only and safe to call concurrently.
617 match Self::neural_score(
618 model,
619 tokenizer,
620 classifier_weight,
621 classifier_bias,
622 device,
623 query,
624 title,
625 content,
626 ) {
627 Ok(s) => s,
628 Err(e) => {
629 tracing::warn!(
630 "neural cross-encoder score failed: {e}, using lexical fallback"
631 );
632 lexical_score(query, title, content)
633 }
634 }
635 }
636 }
637 }
638
639 #[allow(clippy::too_many_arguments)]
640 fn neural_score(
641 model: &BertModel,
642 tokenizer: &Tokenizer,
643 classifier_weight: &Tensor,
644 classifier_bias: &Tensor,
645 device: &Device,
646 query: &str,
647 title: &str,
648 content: &str,
649 ) -> Result<f32> {
650 // Cross-encoder input: "[CLS] query [SEP] title content [SEP]"
651 let document = crate::embeddings::embedding_document(title, content);
652
653 let encoding = tokenizer
654 .encode((query, document.as_str()), true)
655 .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
656
657 let input_ids = encoding.get_ids();
658 let attention_mask = encoding.get_attention_mask();
659 let token_type_ids = encoding.get_type_ids();
660 let seq_len = input_ids.len();
661
662 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
663 let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
664 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
665
666 // Forward pass through BERT → [1, seq_len, 384]
667 let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
668
669 // Take [CLS] token (first token) → [1, 384]
670 let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
671
672 // Classification head: logit = cls @ weight^T + bias → [1, 1]
673 let logit = cls
674 .matmul(&classifier_weight.t()?)?
675 .broadcast_add(classifier_bias)?;
676
677 // Extract scalar logit and apply sigmoid to get [0, 1] score
678 let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
679 let score = 1.0 / (1.0 + (-logit_val).exp());
680
681 Ok(score)
682 }
683
684 /// Whether this is a neural cross-encoder.
685 pub fn is_neural(&self) -> bool {
686 matches!(self, Self::Neural { .. })
687 }
688
689 /// v0.7.0 R3-S2 — whether this cross-encoder is a *degraded*
690 /// lexical fallback (i.e., a neural variant was attempted at
691 /// startup or mid-flight and the runtime fell back). `false` for
692 /// `Neural` and for the originally-configured `Lexical` (operator
693 /// opted into keyword-/semantic-tier without cross-encoder
694 /// reranking). The recall response surfaces this distinction as
695 /// `meta.reranker_used = "degraded_lexical"` so clients can
696 /// detect the silent downgrade in-band — closing the G8 closure
697 /// claim that tracing-event-only signalling had overstated.
698 #[must_use]
699 pub fn is_degraded_lexical(&self) -> bool {
700 matches!(self, Self::Lexical { degraded: true })
701 }
702
703 /// Rerank a set of candidates by blending their original scores with
704 /// cross-encoder scores.
705 ///
706 /// **Blend formula:** `final = 0.6 * original + 0.4 * cross_encoder`
707 ///
708 /// **#1597 pool cap:** only the strongest [`RERANK_POOL_MAX`]
709 /// candidates by incoming blended score are cross-encoded; the
710 /// remainder keep their blended scores and rank below the reranked
711 /// head (head sorted by `final_score` descending, tail sorted by
712 /// blended score descending — no candidate is dropped). A pool at
713 /// or under the cap is fully reranked and returned sorted by
714 /// `final_score` descending, as before.
715 ///
716 /// **v0.7.0 L2-8 contract:** the bare `rerank` is the *pre-L2-8*
717 /// behavior — no reflection boost is applied. Daemons that want
718 /// the reflection-aware boost must call
719 /// [`Self::rerank_with_reflection_boost`] (which is what
720 /// [`BatchedReranker`] does by default with
721 /// [`ReflectionBoostConfig::default`]). Keeping the bare method
722 /// boost-free is a deliberate regression-pin discipline: the L2-8
723 /// recall test for `boost = 1.0` uses
724 /// `rerank_with_reflection_boost(.., &ReflectionBoostConfig::disabled())`
725 /// and asserts byte-identical output to `rerank(..)`.
726 pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
727 // #1597 — delegate so the pool cap + batched forward pass live in
728 // exactly one place. `ReflectionBoostConfig::disabled()` yields a
729 // multiplier of exactly 1.0 for every candidate, so the output is
730 // byte-identical to the historical boost-free blend (the L2-8
731 // regression pin below asserts this equivalence directly).
732 self.rerank_with_reflection_boost(query, candidates, &ReflectionBoostConfig::disabled())
733 }
734
735 /// v0.7.0 L2-8 — rerank with a post-step reflection-aware boost.
736 ///
737 /// 1. Same blend as [`Self::rerank`] (`0.6 * original + 0.4 * ce`).
738 /// 2. **After** the blend, multiply each candidate's `final_score`
739 /// by [`ReflectionBoostConfig::factor_for`]. Observations get a
740 /// multiplier of `1.0` (unchanged); reflections get
741 /// `boost * (1.0 + per_depth_increment * clamp(depth, 0..=cap))`.
742 /// 3. Sort descending after the boost so the output ordering
743 /// reflects the post-boost ranking.
744 ///
745 /// Operationally this means: a reflection that the cross-encoder
746 /// scored at parity with its source observations *moves up*; the
747 /// movement is bounded (capped per-depth multiplier, single global
748 /// `boost` factor) so a mediocre reflection cannot leapfrog a
749 /// well-matched observation — the boost is a thumb-on-the-scale,
750 /// not a free pass.
751 /// **#1597 pool cap + batched forward pass.** Only the strongest
752 /// [`RERANK_POOL_MAX`] candidates by incoming blended score receive a
753 /// cross-encoder score (in one batched forward pass on the Neural
754 /// variant); the remainder keep their blended scores, internally
755 /// sorted descending, appended after the reranked head. No candidate
756 /// is ever dropped. A pool at or under the cap degenerates to the
757 /// historical full rerank.
758 pub fn rerank_with_reflection_boost(
759 &self,
760 query: &str,
761 candidates: Vec<(Memory, f64)>,
762 boost_config: &ReflectionBoostConfig,
763 ) -> Vec<(Memory, f64)> {
764 let (head, tail) = split_rerank_pool(candidates);
765
766 let ce_scores = self.pair_scores(query, &head);
767 let mut scored: Vec<(Memory, f64)> = head
768 .into_iter()
769 .zip(ce_scores)
770 .map(|((mem, original_score), ce_score)| {
771 let blended =
772 ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * f64::from(ce_score);
773 let factor = boost_config.factor_for(&mem);
774 (mem, finite_or_floor(blended * factor))
775 })
776 .collect();
777
778 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779 // #1597 — uncapped remainder: blended scores untouched, already
780 // sorted descending by `split_rerank_pool`, ranked below the
781 // cross-encoded head.
782 scored.extend(tail);
783 scored
784 }
785
786 /// #1597 — cross-encoder scores for an (already capped) candidate
787 /// slice, one score per candidate in input order.
788 ///
789 /// Neural variant: ONE batched tokenize + forward pass via
790 /// [`Self::neural_score_pairs`] (the same machinery the G9
791 /// [`Self::rerank_batch`] path uses) instead of a sequential
792 /// per-pair forward — the second half of the #1597 fix. Falls back
793 /// to per-pair lexical scoring if the batched forward fails.
794 fn pair_scores(&self, query: &str, candidates: &[(Memory, f64)]) -> Vec<f32> {
795 let lexical_fallback = |candidates: &[(Memory, f64)]| -> Vec<f32> {
796 candidates
797 .iter()
798 .map(|(mem, _)| lexical_score(query, &mem.title, &mem.content))
799 .collect()
800 };
801 match self {
802 Self::Lexical { .. } => lexical_fallback(candidates),
803 Self::Neural {
804 model,
805 tokenizer,
806 classifier_weight,
807 classifier_bias,
808 device,
809 } => {
810 let pairs: Vec<(&str, String)> = candidates
811 .iter()
812 .map(|(mem, _)| {
813 (
814 query,
815 crate::embeddings::embedding_document(&mem.title, &mem.content),
816 )
817 })
818 .collect();
819 match Self::neural_score_pairs(
820 model,
821 tokenizer,
822 classifier_weight,
823 classifier_bias,
824 device,
825 pairs,
826 ) {
827 Ok(scores) => scores,
828 Err(e) => {
829 tracing::warn!(
830 "neural cross-encoder batch score failed: {e}, using lexical fallback"
831 );
832 lexical_fallback(candidates)
833 }
834 }
835 }
836 }
837 }
838
839 /// v0.7 G9 — batched rerank for concurrent recall.
840 ///
841 /// Process all `(query, candidates)` jobs in a single tokenize + single
842 /// forward pass on the Neural variant, holding the BERT mutex once for
843 /// the whole batch instead of once per (query, candidate) pair.
844 ///
845 /// **Throughput target**: ~3× for parallel recall vs. per-query
846 /// `rerank()` calls.
847 ///
848 /// Output ordering: `result[i]` corresponds to `queries[i]`. Each
849 /// inner vector is sorted by descending blended score, identical to
850 /// `rerank()`. Lexical variant delegates per-query (no batching win
851 /// since lexical scoring is already CPU-trivial).
852 pub fn rerank_batch(
853 &self,
854 queries: Vec<(String, Vec<(Memory, f64)>)>,
855 ) -> Vec<Vec<(Memory, f64)>> {
856 // Boost-free legacy entry point — preserves the pre-L2-8 wire
857 // shape for callers that haven't migrated to the boost-aware
858 // variant. See `rerank_batch_with_reflection_boost` for the
859 // L2-8 path; here we delegate to it with the `disabled()`
860 // config so the implementation lives in one place.
861 self.rerank_batch_with_reflection_boost(queries, &ReflectionBoostConfig::disabled())
862 }
863
864 /// v0.7.0 L2-8 — batched rerank with a post-step reflection-aware
865 /// boost applied per candidate. Same boost arithmetic as
866 /// [`Self::rerank_with_reflection_boost`], factored so the boost
867 /// shape lives in a single helper.
868 pub fn rerank_batch_with_reflection_boost(
869 &self,
870 queries: Vec<(String, Vec<(Memory, f64)>)>,
871 boost_config: &ReflectionBoostConfig,
872 ) -> Vec<Vec<(Memory, f64)>> {
873 // Single-query short-circuit: avoid any batching overhead.
874 if queries.len() == 1 {
875 let mut iter = queries.into_iter();
876 let (q, cands) = iter.next().expect("len == 1");
877 return vec![self.rerank_with_reflection_boost(&q, cands, boost_config)];
878 }
879
880 match self {
881 Self::Lexical { .. } => queries
882 .into_iter()
883 .map(|(q, cands)| self.rerank_with_reflection_boost(&q, cands, boost_config))
884 .collect(),
885 Self::Neural {
886 model,
887 tokenizer,
888 classifier_weight,
889 classifier_bias,
890 device,
891 } => {
892 // #1597 — apply the per-query pool cap BEFORE the batched
893 // forward pass so a coalesced flush pays for at most
894 // `RERANK_POOL_MAX` forwards per job; each tail is
895 // reattached below its reranked head afterwards.
896 let mut tails: Vec<Vec<(Memory, f64)>> = Vec::with_capacity(queries.len());
897 let queries: Vec<(String, Vec<(Memory, f64)>)> = queries
898 .into_iter()
899 .map(|(q, cands)| {
900 let (head, tail) = split_rerank_pool(cands);
901 tails.push(tail);
902 (q, head)
903 })
904 .collect();
905 // v0.7.0 #1084 — no mutex acquisition: `Arc<BertModel>`
906 // shared across threads; `BertModel::forward(&self, ...)`
907 // is inference-only and safe to call concurrently. The
908 // pre-#1084 poisoned-lock fallback is now unreachable
909 // (no lock to poison); a runtime error in
910 // `neural_rerank_batch` still falls through to the
911 // lexical degrade via the `Err(_)` arm below.
912 match Self::neural_rerank_batch(
913 model,
914 tokenizer,
915 classifier_weight,
916 classifier_bias,
917 device,
918 &queries,
919 ) {
920 Ok(scores) => {
921 // scores is a flat Vec<f32>, one per (query_idx,
922 // candidate_idx) in row-major order matching
923 // queries.iter().flat_map(|(_, cs)| cs).
924 let mut out = Vec::with_capacity(queries.len());
925 let mut cursor = 0usize;
926 for ((_query, cands), tail) in queries.into_iter().zip(tails) {
927 let n = cands.len();
928 let mut scored: Vec<(Memory, f64)> = cands
929 .into_iter()
930 .enumerate()
931 .map(|(i, (mem, original))| {
932 let ce = f64::from(scores[cursor + i]);
933 let blended =
934 ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce;
935 let factor = boost_config.factor_for(&mem);
936 (mem, finite_or_floor(blended * factor))
937 })
938 .collect();
939 cursor += n;
940 scored.sort_by(|a, b| {
941 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
942 });
943 // #1597 — uncapped remainder ranks below the
944 // cross-encoded head, blended scores untouched.
945 scored.extend(tail);
946 out.push(scored);
947 }
948 out
949 }
950 Err(e) => {
951 tracing::warn!(
952 "neural rerank_batch failed: {e}, falling back to lexical per-query"
953 );
954 queries
955 .into_iter()
956 .zip(tails)
957 .map(|((q, cands), tail)| {
958 // Runtime degrade (forward-pass failure) —
959 // mark the variant degraded so the recall
960 // response can surface `degraded_lexical`.
961 let lex = Self::Lexical { degraded: true };
962 let mut scored =
963 lex.rerank_with_reflection_boost(&q, cands, boost_config);
964 scored.extend(tail);
965 scored
966 })
967 .collect()
968 }
969 }
970 }
971 }
972 }
973
974 /// One tokenize + one forward pass over a flat batch of (query, doc)
975 /// pairs. Returns a flat `Vec<f32>` of sigmoided logits in the same
976 /// row-major order the candidates appear in `queries`.
977 fn neural_rerank_batch(
978 model: &BertModel,
979 tokenizer: &Tokenizer,
980 classifier_weight: &Tensor,
981 classifier_bias: &Tensor,
982 device: &Device,
983 queries: &[(String, Vec<(Memory, f64)>)],
984 ) -> Result<Vec<f32>> {
985 // Build the flat (query, document) pair list.
986 let mut pairs: Vec<(&str, String)> = Vec::new();
987 for (q, cands) in queries {
988 for (mem, _) in cands {
989 let document = crate::embeddings::embedding_document(&mem.title, &mem.content);
990 pairs.push((q.as_str(), document));
991 }
992 }
993 Self::neural_score_pairs(
994 model,
995 tokenizer,
996 classifier_weight,
997 classifier_bias,
998 device,
999 pairs,
1000 )
1001 }
1002
    /// One tokenize + one forward pass over a flat list of
    /// (query, document) pairs — the shared batched-inference chokepoint
    /// (#1597) used by BOTH the G9 multi-query [`Self::neural_rerank_batch`]
    /// path and the per-call [`Self::pair_scores`] path.
    ///
    /// Each pair is encoded as a two-segment (`EncodeInput::Dual`) input.
    /// Inputs are truncated to `rerank_max_seq()` tokens and right-padded to
    /// the longest row of the batch. The `[CLS]` hidden state of every row
    /// goes through the classifier head (`cls · classifier_weightᵀ +
    /// classifier_bias`), then through a logistic sigmoid.
    ///
    /// Returns one score in `[0, 1]` per pair, in input order. An empty
    /// `pairs` returns an empty vector without touching the model.
    ///
    /// # Errors
    ///
    /// Propagates failures from any of these steps:
    /// - configuring truncation;
    /// - batch tokenization;
    /// - tensor construction;
    /// - the BERT forward pass;
    /// - the classifier matmul / bias add;
    /// - the final `squeeze` / `to_vec1` extraction.
    fn neural_score_pairs(
        model: &BertModel,
        tokenizer: &Tokenizer,
        classifier_weight: &Tensor,
        classifier_bias: &Tensor,
        device: &Device,
        pairs: Vec<(&str, String)>,
    ) -> Result<Vec<f32>> {
        // Nothing to score: skip tokenizer clone and forward pass entirely.
        if pairs.is_empty() {
            return Ok(Vec::new());
        }

        // Variable-length pairs require padding for a single forward pass.
        // Clone the tokenizer so we can mutate padding settings without
        // racing other threads on the shared `Arc<Tokenizer>`.
        let mut batch_tokenizer = tokenizer.clone();
        // Pad every row on the right to the longest row in this batch;
        // id 0 / "[PAD]" are the BERT vocab's padding token.
        let padding = tokenizers::PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            direction: tokenizers::PaddingDirection::Right,
            pad_id: 0,
            pad_type_id: 0,
            pad_token: "[PAD]".to_string(),
            ..Default::default()
        };
        batch_tokenizer.with_padding(Some(padding));
        // #1604 — rerank inputs truncate at the resolved rerank cap
        // (default RERANK_MAX_SEQ_DEFAULT), tighter than the
        // architecture-ceiling CROSS_ENCODER_MAX_SEQ the shared
        // tokenizer carries: long-content rows otherwise pad the whole
        // batch to 512 tokens and the candle CPU forward dominates
        // recall latency (~3.2 s/recall measured on the #1588 re-run).
        let truncation = tokenizers::TruncationParams {
            max_length: rerank_max_seq(),
            ..Default::default()
        };
        batch_tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| anyhow::anyhow!("failed to set rerank truncation: {e}"))?;

        // `true` = add special tokens ([CLS] ... [SEP] ... [SEP]).
        let encodings = batch_tokenizer
            .encode_batch(
                pairs
                    .into_iter()
                    .map(|(q, d)| tokenizers::EncodeInput::Dual(q.into(), d.into()))
                    .collect::<Vec<_>>(),
                true,
            )
            .map_err(|e| anyhow::anyhow!("cross-encoder batch tokenization failed: {e}"))?;

        let batch_size = encodings.len();
        // BatchLongest padding gives every row the same length, so the first
        // row's length is the batch's sequence length. If rows ever disagreed,
        // `Tensor::from_vec` below would fail on the element count.
        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);

        // Flatten the per-row ids / mask / segment ids row-major.
        let mut input_ids: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
        let mut attn_mask: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
        let mut token_types: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
        for enc in &encodings {
            input_ids.extend_from_slice(enc.get_ids());
            attn_mask.extend_from_slice(enc.get_attention_mask());
            token_types.extend_from_slice(enc.get_type_ids());
        }

        let input_ids = Tensor::from_vec(input_ids, (batch_size, seq_len), device)?;
        let attention_mask = Tensor::from_vec(attn_mask, (batch_size, seq_len), device)?;
        let token_type_ids = Tensor::from_vec(token_types, (batch_size, seq_len), device)?;

        // Forward pass → [batch, seq, 384]
        // (384 = MiniLM-L6 hidden size — NOTE(review): confirm against the
        // loaded config if the model id changes.)
        let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;

        // [CLS] token per row → [batch, 384]
        let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;

        // Classification head per row → [batch, 1]
        // (assumes classifier_weight is (1, hidden) — a single-logit head.)
        let logits = cls
            .matmul(&classifier_weight.t()?)?
            .broadcast_add(classifier_bias)?;

        // `squeeze(1)` requires exactly one logit per row; `to_vec1` into
        // `Vec<f32>` requires an f32 tensor.
        let logits_vec: Vec<f32> = logits.squeeze(1)?.to_vec1()?;
        // Logistic sigmoid maps each raw logit to a relevance score.
        Ok(logits_vec
            .into_iter()
            .map(|l| 1.0 / (1.0 + (-l).exp()))
            .collect())
    }
1090}
1091
1092impl Default for CrossEncoder {
1093 fn default() -> Self {
1094 Self::new()
1095 }
1096}
1097
1098// ---------------------------------------------------------------------------
1099// Lexical cross-encoder (original implementation)
1100// ---------------------------------------------------------------------------
1101
1102fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
1103 let query_terms = tokenize(query);
1104 if query_terms.is_empty() {
1105 return 0.0;
1106 }
1107
1108 let title_terms = tokenize(title);
1109 let content_terms = tokenize(content);
1110
1111 let doc_terms: HashSet<&str> = title_terms
1112 .iter()
1113 .chain(content_terms.iter())
1114 .copied()
1115 .collect();
1116 let query_set: HashSet<&str> = query_terms.iter().copied().collect();
1117
1118 // 1. Jaccard term overlap
1119 #[allow(clippy::cast_precision_loss)]
1120 let intersection = query_set.intersection(&doc_terms).count() as f32;
1121 #[allow(clippy::cast_precision_loss)]
1122 let union = query_set.union(&doc_terms).count() as f32;
1123 let jaccard = if union > 0.0 {
1124 intersection / union
1125 } else {
1126 0.0
1127 };
1128
1129 // 2. TF-IDF-like term weighting
1130 let doc_all: Vec<&str> = title_terms
1131 .iter()
1132 .chain(content_terms.iter())
1133 .copied()
1134 .collect();
1135 let tf_idf = tfidf_score(&query_terms, &doc_all);
1136
1137 // 3. Bigram overlap bonus
1138 let query_bigrams = bigrams(&query_terms);
1139 let doc_bigrams = bigrams(&doc_all);
1140 let bigram_overlap = if query_bigrams.is_empty() {
1141 0.0
1142 } else {
1143 let doc_bigram_set: HashSet<(&str, &str)> = doc_bigrams.into_iter().collect();
1144 #[allow(clippy::cast_precision_loss)]
1145 let hits = query_bigrams
1146 .iter()
1147 .filter(|b| doc_bigram_set.contains(b))
1148 .count() as f32;
1149 #[allow(clippy::cast_precision_loss)]
1150 let query_bigrams_len = query_bigrams.len() as f32;
1151 hits / query_bigrams_len
1152 };
1153
1154 // 4. Title match bonus
1155 let title_set: HashSet<&str> = title_terms.iter().copied().collect();
1156 #[allow(clippy::cast_precision_loss)]
1157 let title_hits = query_set.intersection(&title_set).count() as f32;
1158 #[allow(clippy::cast_precision_loss)]
1159 let title_bonus = if query_set.is_empty() {
1160 0.0
1161 } else {
1162 title_hits / query_set.len() as f32
1163 };
1164
1165 let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
1166 raw.clamp(0.0, 1.0)
1167}
1168
1169// ---------------------------------------------------------------------------
1170// Internal helpers
1171// ---------------------------------------------------------------------------
1172
/// Splits `text` into word tokens: maximal runs of alphanumeric characters
/// and apostrophes. Case is preserved and empty pieces are discarded.
fn tokenize(text: &str) -> Vec<&str> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '\'');
    text.split(is_separator)
        .filter(|word| !word.is_empty())
        .collect()
}
1178
/// TF-IDF-flavoured weight of `query_terms` against a single document's
/// tokens, clamped to `[0.0, 1.0]`.
///
/// Matching is case-insensitive (both sides go through `str::to_lowercase`).
/// For every query term `qt`, with duplicates counted each time:
/// - `tf` is the total number of document tokens whose lowercase form is
///   `qt`, divided by the document length;
/// - `idf` is `ln(unique / (1 + variants)) + 1`, where `unique` is the number
///   of distinct (case-sensitive) document tokens and `variants` is how many
///   of those distinct tokens lowercase to `qt`.
///
/// The per-term `tf * idf` products are averaged over `query_terms.len()`.
/// Terms absent from the document contribute `0`. Returns `0.0` when
/// either input is empty.
///
/// The case-folded index is built once per call, so cost is linear in the
/// document and query sizes. This replaces re-lowercasing every distinct
/// token for every query term.
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Exact-case occurrence counts for every document token.
    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }

    // Case-folded index: lowercase form ->
    // (occurrences summed over its case variants, number of case variants).
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (&tok, &count) in &tf_map {
        let (occurrences, variants) = folded.entry(tok.to_lowercase()).or_insert((0, 0));
        *occurrences += count;
        *variants += 1;
    }

    #[allow(clippy::cast_precision_loss)]
    let total = doc_tokens.len() as f32;
    #[allow(clippy::cast_precision_loss)]
    let unique = tf_map.len() as f32;

    let mut score_sum: f32 = 0.0;
    for term in query_terms {
        if let Some(&(occurrences, variants)) = folded.get(term.to_lowercase().as_str()) {
            #[allow(clippy::cast_precision_loss)]
            let tf_norm = occurrences as f32 / total;
            #[allow(clippy::cast_precision_loss)]
            let idf = (unique / (1.0 + variants as f32)).ln() + 1.0;
            score_sum += tf_norm * idf;
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}
1221
/// Adjacent token pairs `(tokens[i], tokens[i + 1])`, in order. Returns an
/// empty vector for fewer than two tokens.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    tokens
        .iter()
        .zip(tokens.iter().skip(1))
        .map(|(&first, &second)| (first, second))
        .collect()
}
1225
1226// ---------------------------------------------------------------------------
1227// v0.7 G9 — concurrent rerank coalescer
1228// ---------------------------------------------------------------------------
1229
/// Default cap on how many requests the worker coalesces into one batched
/// rerank call.
pub const DEFAULT_MAX_BATCH: usize = 32;

/// Default flush latency (ms): how long the worker keeps collecting
/// requests after the first one before processing a non-full batch. 5 ms
/// keeps single-request latency negligible while still letting parallel
/// callers share a batch.
pub const DEFAULT_MAX_WAIT_MS: u64 = 5;

/// #1579 B10 — smallest number of in-flight rerank requests (counting the
/// current one) at which [`BatchedReranker::rerank`] sends a *neural*
/// rerank through the coalescing worker. With fewer requests there is
/// nothing to batch with, so a lone caller would pay the worker channel
/// round-trip plus up to [`DEFAULT_MAX_WAIT_MS`] of flush-window wait for
/// no amortisation.
///
/// **Criterion evidence (perf-audit P1, 2026-06, `cargo bench --bench
/// reranker_throughput`, lexical default):** with N=8 concurrent queries ×
/// 10 candidates, the batched path measured ~7.6 ms against ~0.65 ms for
/// direct calls. That is 12× SLOWER, because the 5 ms flush window dwarfs
/// the sub-millisecond lexical compute.
///
/// Lexical encoders therefore never use the worker, at any N: there is no
/// shared model to amortise (see [`BatchedReranker::rerank`]). Neural
/// encoders use it once concurrency reaches this threshold. There the G9
/// measurement showed a ~3× throughput gain from scoring a whole batch per
/// model call instead of one (query, candidate) at a time.
pub const BATCHED_RERANK_MIN_CONCURRENCY: usize = 2;

/// #1579 B10 — the auto-select predicate behind [`BatchedReranker::rerank`].
/// It is a free `const fn` so the threshold logic is unit-testable without
/// a worker thread or model weights.
///
/// Returns `true` (route through the coalescing worker) only for a neural
/// encoder with at least [`BATCHED_RERANK_MIN_CONCURRENCY`] requests in
/// flight. `false` means call the encoder directly.
#[must_use]
pub const fn use_batched_rerank_path(encoder_is_neural: bool, inflight_now: usize) -> bool {
    if !encoder_is_neural {
        return false;
    }
    inflight_now >= BATCHED_RERANK_MIN_CONCURRENCY
}
1266
/// Job submitted to the coalescer worker: one rerank request plus the
/// channel its result is returned on.
struct RerankJob {
    /// Query text the candidates are scored against.
    query: String,
    /// `(memory, incoming blended score)` pairs to rerank.
    candidates: Vec<(Memory, f64)>,
    /// Sender the worker uses to hand the reranked list back to the waiting
    /// caller (presumably a per-request bounded `sync_channel` — NOTE(review):
    /// confirm in `BatchedReranker::rerank`).
    reply: std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>,
}
1273
1274/// Concurrent rerank coalescer.
1275///
1276/// Wraps a `CrossEncoder` and serializes concurrent recall reranks through
1277/// a single worker thread. The worker buffers up to `max_batch` requests
1278/// or waits up to `max_wait_ms` (whichever first), then issues one
1279/// `rerank_batch` call. The Mutex around the BERT model is held for the
1280/// whole batch instead of once per (query, candidate) — the throughput
1281/// fix mandated by G9.
1282///
1283/// **Single-request latency**: the worker flushes immediately when the
1284/// queue is empty after pulling the first job, so a lone request only
1285/// pays one `recv_timeout(0)` round-trip — no artificial waiting.
1286pub struct BatchedReranker {
1287 sender: Option<Sender<RerankJob>>,
1288 /// H2 (v0.7.0 round-2) — explicit one-shot shutdown signal. The
1289 /// worker thread selects on BOTH the work channel and this
1290 /// shutdown channel; receiving on the shutdown channel makes the
1291 /// worker exit its loop deterministically, even if a holder of
1292 /// `sender` happens to outlive `Drop` (e.g. the test harness
1293 /// stashed a `Sender` clone). `Drop` triggers this BEFORE dropping
1294 /// `sender`, so a worker that is currently blocked in
1295 /// `rx.recv()` wakes up via the shutdown channel without waiting
1296 /// for the work-channel disconnect.
1297 shutdown: Option<std::sync::mpsc::Sender<()>>,
1298 worker: Option<JoinHandle<()>>,
1299 /// Direct handle to the underlying encoder, used for the single-query
1300 /// short-circuit and for callers that explicitly want non-batched
1301 /// behavior (tests, benchmarks).
1302 encoder: Arc<CrossEncoder>,
1303 /// v0.7.0 L2-8 — reflection-aware boost config the worker hands
1304 /// down to every batched `rerank` call. Defaults to
1305 /// [`ReflectionBoostConfig::default`] (boost = 1.2) so the daemon
1306 /// flow ships the boost; explicit configuration goes through
1307 /// [`Self::with_reflection_boost`] before the worker starts taking
1308 /// jobs.
1309 reflection_boost: ReflectionBoostConfig,
1310 /// v0.7.0 #1319 — opt-in noise floor applied AFTER the blend
1311 /// (`0.6 * original + 0.4 * ce_score`) and AFTER the reflection
1312 /// boost. Default is [`RerankerScoreFloor::Off`] so existing
1313 /// callers see byte-identical output to pre-#1319. Operators that
1314 /// observed the cross-encoder false-positive ordering on
1315 /// disjoint-vocab paraphrase queries (the v1 P5 probe — an Apollo
1316 /// 11 row at 0.479 surfacing above a substantively-relevant hit at
1317 /// 0.363) opt in via [`Self::with_score_floor`] to drop the
1318 /// low-confidence tail entirely.
1319 score_floor: RerankerScoreFloor,
1320 /// #1579 B10 — number of rerank requests currently inside
1321 /// [`Self::rerank`] (incremented on entry, decremented on exit).
1322 /// Drives the auto-select between the direct encoder call and the
1323 /// coalescing worker; see [`use_batched_rerank_path`].
1324 inflight: std::sync::atomic::AtomicUsize,
1325 /// #1579 B10 — observability counter: how many jobs this wrapper
1326 /// has submitted to the coalescing worker over its lifetime. The
1327 /// auto-select regression tests pin "lexical / lone-caller traffic
1328 /// never reaches the worker" on this counter.
1329 worker_submissions: std::sync::atomic::AtomicUsize,
1330}
1331
/// v0.7.0 #1319 — post-blend score floor applied by [`BatchedReranker`].
///
/// **Default is [`Self::Off`]**, so every existing caller observes
/// byte-identical pre-#1319 output. Operators who hit the paraphrase /
/// disjoint-vocab noise band can turn it on in two ways:
/// - the constructor knob [`BatchedReranker::with_score_floor`];
/// - a config / env string parsed by [`RerankerScoreFloor::parse`]
///   (`[reranker].score_floor`, `AI_MEMORY_RERANK_SCORE_FLOOR`).
///
/// **Why two shapes.** [`Self::Absolute`] is the literal "drop anything
/// below 0.5" handle suggested by the recall caller's documentation.
/// [`Self::RelativeToTop`] always keeps the top of the list, which helps
/// when the corpus is small (a 3-row recall shouldn't return zero results
/// just because every row scored `0.42`) and the operator only wants a
/// "tail cleaner".
///
/// Both thresholding variants compare the **final blended score** (after
/// the L2-8 reflection boost), not the raw cross-encoder logit. The floor
/// is therefore comparable to the values an operator reads off
/// `recall.memories[].score`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RerankerScoreFloor {
    /// No floor — pre-#1319 behavior. Every blended candidate is kept,
    /// regardless of score. Default.
    #[default]
    Off,
    /// Drop every candidate whose final blended score falls strictly
    /// below the supplied absolute value. Clamped at runtime to
    /// `[0.0, 1.0]`.
    Absolute(f64),
    /// Drop every candidate whose final blended score falls strictly below
    /// `top_score * ratio`, where `top_score` is the first row after
    /// sorting. Clamped at runtime to `[0.0, 1.0]`. The top row itself is
    /// never dropped, so operators get at least one result even when the
    /// entire ranked set is in the noise band.
    RelativeToTop(f64),
}
1373
1374impl RerankerScoreFloor {
1375 /// #1691/n14 — parse an operator config / env string into a score
1376 /// floor so the (previously dead) [`BatchedReranker::with_score_floor`]
1377 /// capability is reachable from `[reranker].score_floor` and
1378 /// `AI_MEMORY_RERANK_SCORE_FLOOR`.
1379 ///
1380 /// Grammar (case-insensitive, whitespace-trimmed):
1381 /// - `off` → [`RerankerScoreFloor::Off`]
1382 /// - `absolute:<f>` (alias `abs:<f>`) → [`RerankerScoreFloor::Absolute`]
1383 /// - `relative:<f>` (aliases `rel:<f>`, `relative_to_top:<f>`) →
1384 /// [`RerankerScoreFloor::RelativeToTop`]
1385 ///
1386 /// Returns `None` on any unparseable value so resolvers fall through
1387 /// to the next precedence layer. The numeric is clamped to
1388 /// `[0.0, 1.0]` at [`apply`](Self::apply) time, so an out-of-range
1389 /// value still parses (and is clamped on use) rather than erroring.
1390 #[must_use]
1391 pub fn parse(s: &str) -> Option<Self> {
1392 let s = s.trim();
1393 if s.eq_ignore_ascii_case("off") {
1394 return Some(Self::Off);
1395 }
1396 let (kind, value) = s.split_once(':')?;
1397 let v: f64 = value.trim().parse().ok()?;
1398 if !v.is_finite() {
1399 return None;
1400 }
1401 match kind.trim().to_ascii_lowercase().as_str() {
1402 "absolute" | "abs" => Some(Self::Absolute(v)),
1403 "relative" | "rel" | "relative_to_top" => Some(Self::RelativeToTop(v)),
1404 _ => None,
1405 }
1406 }
1407
1408 /// Apply the floor in-place to a pre-sorted (descending) vector
1409 /// of `(Memory, blended_score)` candidates. The implementation is
1410 /// extracted as a free helper so unit tests can pin the cutoff
1411 /// arithmetic without spinning up a [`BatchedReranker`].
1412 ///
1413 /// The top row is always preserved (so a tiny corpus never
1414 /// returns zero results) — see [`RerankerScoreFloor::RelativeToTop`]
1415 /// documentation for the rationale.
1416 fn apply(&self, scored: &mut Vec<(Memory, f64)>) {
1417 if scored.is_empty() {
1418 return;
1419 }
1420 let cutoff: f64 = match *self {
1421 Self::Off => return,
1422 Self::Absolute(v) => v.clamp(0.0, 1.0),
1423 Self::RelativeToTop(ratio) => {
1424 let top = scored.first().map(|(_, s)| *s).unwrap_or(0.0);
1425 top * ratio.clamp(0.0, 1.0)
1426 }
1427 };
1428 // Walk index-first so we can preserve the top row even when
1429 // its score sits below `cutoff` (small-corpus invariant: the
1430 // floor is a tail cleaner, not a "return nothing" knob).
1431 let mut keep = Vec::with_capacity(scored.len());
1432 for (idx, (_, score)) in scored.iter().enumerate() {
1433 if idx == 0 || *score >= cutoff {
1434 keep.push(idx);
1435 }
1436 }
1437 // `keep` is monotonically increasing; iterate in reverse and
1438 // remove dropped indices so the Vec retains the descending
1439 // sort order from the upstream rerank.
1440 let mut next_keep = keep.iter().rev().copied();
1441 let mut want = next_keep.next();
1442 let mut idx = scored.len();
1443 while idx > 0 {
1444 idx -= 1;
1445 match want {
1446 Some(k) if k == idx => {
1447 want = next_keep.next();
1448 }
1449 _ => {
1450 scored.remove(idx);
1451 }
1452 }
1453 }
1454 }
1455}
1456
1457impl BatchedReranker {
1458 /// Wrap an existing `CrossEncoder` with the default batching parameters
1459 /// (`max_batch = 32`, `max_wait_ms = 5`).
1460 pub fn new(encoder: CrossEncoder) -> Self {
1461 Self::with_params(encoder, DEFAULT_MAX_BATCH, DEFAULT_MAX_WAIT_MS)
1462 }
1463
1464 /// Wrap an existing `CrossEncoder` with custom batching parameters.
1465 pub fn with_params(encoder: CrossEncoder, max_batch: usize, max_wait_ms: u64) -> Self {
1466 Self::with_full_params(
1467 encoder,
1468 max_batch,
1469 max_wait_ms,
1470 ReflectionBoostConfig::default(),
1471 RerankerScoreFloor::Off,
1472 )
1473 }
1474
1475 /// v0.7.0 L2-8 — wrap an existing `CrossEncoder` with a custom
1476 /// reflection-boost config alongside default batching parameters.
1477 /// Used by the recall integration tests to pin specific boost shapes
1478 /// (e.g. `disabled()` for the regression test).
1479 pub fn with_reflection_boost(encoder: CrossEncoder, boost: ReflectionBoostConfig) -> Self {
1480 Self::with_full_params(
1481 encoder,
1482 DEFAULT_MAX_BATCH,
1483 DEFAULT_MAX_WAIT_MS,
1484 boost,
1485 RerankerScoreFloor::Off,
1486 )
1487 }
1488
1489 /// v0.7.0 #1319 — wrap a `CrossEncoder` with a post-blend score
1490 /// floor. The reflection-boost knob is left at the daemon default
1491 /// (`1.2`); use [`Self::with_full_params`] to set both at once.
1492 /// **Default constructors leave the floor `Off`** — flipping it on
1493 /// here is an explicit operator-opt-in.
1494 #[must_use]
1495 pub fn with_score_floor(encoder: CrossEncoder, floor: RerankerScoreFloor) -> Self {
1496 Self::with_full_params(
1497 encoder,
1498 DEFAULT_MAX_BATCH,
1499 DEFAULT_MAX_WAIT_MS,
1500 ReflectionBoostConfig::default(),
1501 floor,
1502 )
1503 }
1504
    /// Internal constructor — all knobs visible.
    ///
    /// Spawns the `ai-memory-reranker-batcher` worker thread and returns
    /// the wrapper that feeds it.
    ///
    /// * `encoder` — wrapped in an `Arc` and shared between the wrapper's
    ///   direct path and the worker.
    /// * `max_batch` — maximum number of jobs the worker collects per flush.
    /// * `max_wait_ms` — coalescing window, measured from the moment the
    ///   first job of a batch is pulled.
    /// * `reflection_boost` — copied into the worker (passed to every
    ///   `process_batch` call) and also stored on the wrapper.
    /// * `score_floor` — stored on the wrapper; the worker does not use it.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned.
    fn with_full_params(
        encoder: CrossEncoder,
        max_batch: usize,
        max_wait_ms: u64,
        reflection_boost: ReflectionBoostConfig,
        score_floor: RerankerScoreFloor,
    ) -> Self {
        let encoder = Arc::new(encoder);
        // Work channel: an unbounded `mpsc::channel`, so submitting a job
        // never blocks the caller.
        let (tx, rx) = std::sync::mpsc::channel::<RerankJob>();
        // H2 (v0.7.0 round-2) — one-shot shutdown channel. The std
        // mpsc channel is used as a "oneshot": we never send more
        // than one value, and the worker exits on the first
        // `try_recv()` success OR on disconnect (Drop of the holder
        // closes the sender side, which also surfaces as a recv
        // outcome the worker can branch on).
        let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();
        // The worker gets its own `Arc` handle to the same encoder.
        let worker_encoder = Arc::clone(&encoder);
        // Copy for the worker; `reflection_boost` itself is still stored on
        // `Self` below (so `ReflectionBoostConfig` is `Copy`).
        let worker_boost = reflection_boost;
        let max_wait = Duration::from_millis(max_wait_ms);

        let worker = thread::Builder::new()
            .name("ai-memory-reranker-batcher".into())
            .spawn(move || {
                // H2 polling cadence: while idle (waiting for the first
                // job of a batch) the worker blocks in `recv_timeout` for
                // at most this long before re-checking the shutdown
                // channel. `recv_timeout` returns as soon as a job
                // arrives, so this interval adds no latency to requests;
                // it only bounds how long an idle worker takes to notice
                // shutdown. 100ms keeps `test_drop_terminates_worker`
                // comfortably inside its 500ms budget.
                const SHUTDOWN_POLL: Duration = Duration::from_millis(100);
                'outer: loop {
                    // Phase 1 — block until the first job arrives OR the
                    // shutdown signal fires OR the sender drops.
                    let first = loop {
                        // Cheap non-blocking shutdown check first so a
                        // signal that arrived between iterations is
                        // observed even if the work channel had a job
                        // queued before the signal landed. On exit, jobs
                        // still queued in `rx` are not processed.
                        match shutdown_rx.try_recv() {
                            Ok(()) | Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                                break 'outer;
                            }
                            Err(std::sync::mpsc::TryRecvError::Empty) => {}
                        }
                        match rx.recv_timeout(SHUTDOWN_POLL) {
                            Ok(job) => break job,
                            // Idle poll elapsed: loop back and re-check shutdown.
                            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
                            // Every work `Sender` is gone: nothing more can arrive.
                            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                                break 'outer;
                            }
                        }
                    };

                    let mut batch: Vec<RerankJob> = Vec::with_capacity(max_batch);
                    batch.push(first);

                    // Phase 2 — coalesce additional jobs that arrive within
                    // the window, up to the batch cap. The window runs from
                    // here, so even a lone job waits up to `max_wait` unless
                    // `max_batch <= 1`. Shutdown is not consulted here: a
                    // batch that has started coalescing is always processed.
                    let deadline = Instant::now() + max_wait;
                    while batch.len() < max_batch {
                        let now = Instant::now();
                        if now >= deadline {
                            break;
                        }
                        match rx.recv_timeout(deadline - now) {
                            Ok(j) => batch.push(j),
                            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => break,
                            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                                // Drain the current batch then exit.
                                process_batch(&worker_encoder, batch, &worker_boost);
                                break 'outer;
                            }
                        }
                    }

                    // Phase 3 — hand the whole batch to `process_batch`.
                    process_batch(&worker_encoder, batch, &worker_boost);
                }
            })
            .expect("failed to spawn rerank batcher worker");

        Self {
            sender: Some(tx),
            shutdown: Some(shutdown_tx),
            worker: Some(worker),
            encoder,
            reflection_boost,
            score_floor,
            // #1579 B10 counters start at zero.
            inflight: std::sync::atomic::AtomicUsize::new(0),
            worker_submissions: std::sync::atomic::AtomicUsize::new(0),
        }
    }
1598
1599 /// Submit a single rerank request. Blocks until the result is
1600 /// available.
1601 ///
1602 /// #1579 B10 — **auto-select.** The wrapper keeps BOTH execution
1603 /// paths and picks per call via [`use_batched_rerank_path`]:
1604 ///
1605 /// - **Direct** (no worker round-trip) when the encoder is
1606 /// lexical / degraded-lexical (no shared-model mutex to
1607 /// amortise — criterion proved the coalescing flush window made
1608 /// the batched path 12× slower at N=8: ~7.6 ms vs ~0.65 ms), or
1609 /// when fewer than [`BATCHED_RERANK_MIN_CONCURRENCY`] requests
1610 /// are in flight (nothing to coalesce with).
1611 /// - **Coalesced** (worker thread, one `rerank_batch` per flush)
1612 /// for neural encoders under real concurrency — the G9 win
1613 /// (~3× at N=8 neural) is preserved.
1614 ///
1615 /// If the worker is unavailable for any reason (channel closed),
1616 /// falls back to a direct `rerank` call on the underlying encoder
1617 /// (with the wrapper's configured reflection boost applied).
1618 pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
1619 let mut scored = self.rerank_unfloored(query, candidates);
1620 // v0.7.0 #1319 — post-blend score floor (default Off; opt-in
1621 // via `with_score_floor`). Applies to the already-sorted
1622 // descending vector returned by the encoder/worker.
1623 self.score_floor.apply(&mut scored);
1624 scored
1625 }
1626
1627 /// #1579 B10 — force the COALESCED (worker) path regardless of the
1628 /// auto-select. Kept public so the throughput bench
1629 /// (`benches/reranker_throughput.rs`) and regression tests can keep
1630 /// measuring the raw batched machinery after `rerank` started
1631 /// auto-selecting away from it at small N. Applies the same
1632 /// post-blend score floor as [`Self::rerank`].
1633 #[must_use]
1634 pub fn rerank_coalesced(
1635 &self,
1636 query: &str,
1637 candidates: Vec<(Memory, f64)>,
1638 ) -> Vec<(Memory, f64)> {
1639 let mut scored = self.rerank_coalesced_unfloored(query, candidates);
1640 self.score_floor.apply(&mut scored);
1641 scored
1642 }
1643
1644 /// Internal — same shape as [`Self::rerank`] but skips the
1645 /// post-blend score floor. Pre-#1319 callsites that explicitly
1646 /// want the raw blended output (regression tests, the byte-equal
1647 /// pin in `g9_batched_reranker_serial_calls_match_rerank`) call
1648 /// this directly.
1649 fn rerank_unfloored(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
1650 use std::sync::atomic::Ordering;
1651 // #1579 B10 — RAII in-flight guard so a panicking encoder call
1652 // can't leak the counter and wedge the auto-select high.
1653 struct InflightGuard<'a>(&'a std::sync::atomic::AtomicUsize);
1654 impl Drop for InflightGuard<'_> {
1655 fn drop(&mut self) {
1656 self.0.fetch_sub(1, Ordering::Relaxed);
1657 }
1658 }
1659 let inflight_now = self.inflight.fetch_add(1, Ordering::Relaxed) + 1;
1660 let _guard = InflightGuard(&self.inflight);
1661
1662 if use_batched_rerank_path(self.encoder.is_neural(), inflight_now) {
1663 self.rerank_coalesced_unfloored(query, candidates)
1664 } else {
1665 self.rerank_direct_unfloored(query, candidates)
1666 }
1667 }
1668
1669 /// #1579 B10 — the DIRECT path: one synchronous encoder call on the
1670 /// caller's thread, no worker round-trip, no flush-window wait.
1671 fn rerank_direct_unfloored(
1672 &self,
1673 query: &str,
1674 candidates: Vec<(Memory, f64)>,
1675 ) -> Vec<(Memory, f64)> {
1676 self.encoder
1677 .rerank_with_reflection_boost(query, candidates, &self.reflection_boost)
1678 }
1679
1680 /// The COALESCED path: submit to the worker thread and block for
1681 /// the reply. Concurrent callers are coalesced into a single
1682 /// `rerank_batch` call inside the worker. (Pre-#1579-B10 this was
1683 /// the body of `rerank_unfloored`.)
1684 fn rerank_coalesced_unfloored(
1685 &self,
1686 query: &str,
1687 candidates: Vec<(Memory, f64)>,
1688 ) -> Vec<(Memory, f64)> {
1689 let Some(sender) = self.sender.as_ref() else {
1690 return self.rerank_direct_unfloored(query, candidates);
1691 };
1692 let (reply_tx, reply_rx) = sync_channel::<Vec<(Memory, f64)>>(1);
1693 let job = RerankJob {
1694 query: query.to_string(),
1695 candidates,
1696 reply: reply_tx,
1697 };
1698 if sender.send(job).is_err() {
1699 return self.encoder.rerank_with_reflection_boost(
1700 query,
1701 Vec::new(),
1702 &self.reflection_boost,
1703 );
1704 }
1705 self.worker_submissions
1706 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1707 reply_rx.recv().unwrap_or_else(|_| {
1708 self.encoder
1709 .rerank_with_reflection_boost(query, Vec::new(), &self.reflection_boost)
1710 })
1711 }
1712
1713 /// #1579 B10 — lifetime count of jobs submitted to the coalescing
1714 /// worker. Observability hook for the auto-select regression tests
1715 /// ("lexical traffic never reaches the worker") and operator
1716 /// diagnostics.
1717 #[must_use]
1718 pub fn worker_submissions(&self) -> usize {
1719 self.worker_submissions
1720 .load(std::sync::atomic::Ordering::Relaxed)
1721 }
1722
1723 /// v0.7.0 #1319 — accessor for the configured score floor, used by
1724 /// operator-facing diagnostics. NOTE (n22): the `memory_capabilities`
1725 /// envelope does not currently surface this value; wiring the floor
1726 /// through config and exposing it in capabilities is tracked under
1727 /// #1319 / n14.
1728 #[must_use]
1729 pub fn score_floor(&self) -> RerankerScoreFloor {
1730 self.score_floor
1731 }
1732
1733 /// v0.7.0 L2-8 — expose the configured boost for the
1734 /// `memory_capabilities` reporter.
1735 #[must_use]
1736 pub fn reflection_boost(&self) -> &ReflectionBoostConfig {
1737 &self.reflection_boost
1738 }
1739
1740 /// Direct access to the wrapped encoder. Useful for callers that
1741 /// want to bypass the coalescer (tests, benchmarks).
1742 pub fn encoder(&self) -> &CrossEncoder {
1743 &self.encoder
1744 }
1745
1746 /// Convenience shortcut for `self.encoder().is_neural()`. Most
1747 /// callers in the recall pipeline only need to check the variant
1748 /// for capability reporting.
1749 pub fn is_neural(&self) -> bool {
1750 self.encoder.is_neural()
1751 }
1752
1753 /// v0.7.0 R3-S2 — shortcut for `self.encoder().is_degraded_lexical()`.
1754 /// The recall path reads this to drive the in-band `reranker_used`
1755 /// signal exposed via `RecallMeta`.
1756 #[must_use]
1757 pub fn is_degraded_lexical(&self) -> bool {
1758 self.encoder.is_degraded_lexical()
1759 }
1760}
1761
1762impl Drop for BatchedReranker {
1763 fn drop(&mut self) {
1764 // H2 (v0.7.0 round-2): two-step termination.
1765 //
1766 // 1. Fire the explicit shutdown signal FIRST so the worker
1767 // observes it even when another holder of `Sender`
1768 // (e.g. a test that cloned the work channel) would
1769 // otherwise keep the work channel alive.
1770 // 2. Then drop the work-channel sender — a worker that was
1771 // blocked in `rx.recv_timeout(...)` wakes up either via
1772 // the shutdown poll OR the disconnect, whichever
1773 // happens first.
1774 //
1775 // Joining the worker after BOTH signals fire bounds shutdown
1776 // by the SHUTDOWN_POLL cadence (100ms) in the absolute worst
1777 // case, well inside the 500ms budget exercised by
1778 // `test_drop_terminates_worker`.
1779 if let Some(shutdown) = self.shutdown.take() {
1780 let _ = shutdown.send(());
1781 }
1782 self.sender.take();
1783 if let Some(handle) = self.worker.take() {
1784 let _ = handle.join();
1785 }
1786 }
1787}
1788
1789fn process_batch(
1790 encoder: &CrossEncoder,
1791 batch: Vec<RerankJob>,
1792 boost_config: &ReflectionBoostConfig,
1793) {
1794 if batch.is_empty() {
1795 return;
1796 }
1797
1798 // Single-request fast path: bypass the batched API to avoid the
1799 // padding overhead and any latency regression on lone callers.
1800 if batch.len() == 1 {
1801 let mut iter = batch.into_iter();
1802 let job = iter.next().expect("len == 1");
1803 let result = encoder.rerank_with_reflection_boost(&job.query, job.candidates, boost_config);
1804 let _ = job.reply.send(result);
1805 return;
1806 }
1807
1808 // Build the input vector for the batched call. Use placeholder
1809 // `Memory` clones via `take` to avoid copying — we move out.
1810 let mut queries: Vec<(String, Vec<(Memory, f64)>)> = Vec::with_capacity(batch.len());
1811 let mut replies: Vec<std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>> =
1812 Vec::with_capacity(batch.len());
1813 for job in batch {
1814 queries.push((job.query, job.candidates));
1815 replies.push(job.reply);
1816 }
1817
1818 let outputs = encoder.rerank_batch_with_reflection_boost(queries, boost_config);
1819 for (out, reply) in outputs.into_iter().zip(replies.into_iter()) {
1820 let _ = reply.send(out);
1821 }
1822}
1823
1824// ---------------------------------------------------------------------------
1825// Tests
1826// ---------------------------------------------------------------------------
1827
1828#[cfg(test)]
1829mod tests {
1830 use super::*;
1831 use crate::models::{Memory, Tier};
1832
1833 /// #1604 — process-wide rerank sequence-cap seeding: the first
1834 /// [`set_rerank_max_seq`] writer wins and later writes are no-ops.
1835 ///
1836 /// Order-independent by construction: other tests in this binary
1837 /// may legitimately seed the process-wide `OnceLock` first (any
1838 /// test that walks the `daemon_runtime` boot ladder does), so this
1839 /// test asserts only the post-seed immutability contract — it
1840 /// seeds (or observes the earlier seed), then proves a second
1841 /// write cannot change the value. The unseeded-default fallback
1842 /// is pinned by `resolve_reranker_1604_max_seq_ladder` (resolver
1843 /// layer, no OnceLock) instead. The pre-fix form asserted the
1844 /// unseeded default first and was order-dependent — green locally,
1845 /// red under CI's impact-aware test ordering.
1846 #[test]
1847 fn rerank_max_seq_1604_seed_once_semantics() {
1848 set_rerank_max_seq(192);
1849 let settled = rerank_max_seq();
1850 assert!(
1851 settled > 0,
1852 "settled value must be a real cap (ours or an earlier boot seed), got {settled}"
1853 );
1854 set_rerank_max_seq(64);
1855 assert_eq!(
1856 rerank_max_seq(),
1857 settled,
1858 "first writer must win — a later set_rerank_max_seq call must be a no-op"
1859 );
1860 }
1861
1862 fn make_memory(title: &str, content: &str) -> Memory {
1863 Memory {
1864 id: "test-id".to_string(),
1865 tier: Tier::Mid,
1866 namespace: "test".to_string(),
1867 title: title.to_string(),
1868 content: content.to_string(),
1869 tags: vec![],
1870 priority: 5,
1871 confidence: 1.0,
1872 source: "test".to_string(),
1873 access_count: 0,
1874 created_at: "2026-01-01T00:00:00Z".to_string(),
1875 updated_at: "2026-01-01T00:00:00Z".to_string(),
1876 last_accessed_at: None,
1877 expires_at: None,
1878 metadata: serde_json::json!({}),
1879 reflection_depth: 0,
1880 memory_kind: crate::models::MemoryKind::Observation,
1881 entity_id: None,
1882 persona_version: None,
1883 citations: Vec::new(),
1884 source_uri: None,
1885 source_span: None,
1886 confidence_source: crate::models::ConfidenceSource::CallerProvided,
1887 confidence_signals: None,
1888 confidence_decayed_at: None,
1889 version: 1,
1890 }
1891 }
1892
1893 /// #1531 M13 — a NaN original score must not nondeterministically
1894 /// hold the top rank. Pre-fix, the blended NaN compared `Equal` to
1895 /// every finite score under `partial_cmp(..).unwrap_or(Equal)`, so
1896 /// the stable sort left the poisoned candidate in its input
1897 /// position (here: first). Post-fix non-finite scores clamp to
1898 /// `f64::MIN` and sink to the bottom.
1899 #[test]
1900 fn nan_scored_candidate_sinks_to_bottom_m13() {
1901 let ce = CrossEncoder::Lexical { degraded: false };
1902 let poisoned = make_memory("poisoned", "irrelevant body");
1903 let good = make_memory("network configuration", "network configuration body");
1904 let out = ce.rerank(
1905 "network configuration",
1906 vec![(poisoned, f64::NAN), (good, 0.9)],
1907 );
1908 assert_eq!(
1909 out[0].0.title, "network configuration",
1910 "finite-scored candidate must outrank the NaN-poisoned one"
1911 );
1912 assert_eq!(out[1].0.title, "poisoned");
1913 assert_eq!(
1914 out[1].1,
1915 f64::MIN,
1916 "non-finite blended score must clamp to the ranking floor"
1917 );
1918
1919 // Boost-aware path takes the same clamp.
1920 let poisoned = make_memory("poisoned", "irrelevant body");
1921 let good = make_memory("network configuration", "network configuration body");
1922 let out = ce.rerank_with_reflection_boost(
1923 "network configuration",
1924 vec![(poisoned, f64::NAN), (good, 0.9)],
1925 &ReflectionBoostConfig::disabled(),
1926 );
1927 assert_eq!(out[0].0.title, "network configuration");
1928 assert_eq!(out[1].1, f64::MIN);
1929 }
1930
1931 #[test]
1932 fn lexical_score_returns_zero_for_empty_query() {
1933 assert_eq!(lexical_score("", "some title", "some content"), 0.0);
1934 }
1935
1936 #[test]
1937 fn lexical_score_returns_zero_for_no_overlap() {
1938 let s = lexical_score("quantum physics", "grocery list", "milk eggs bread butter");
1939 assert!(s < 0.05, "expected near-zero, got {s}");
1940 }
1941
1942 #[test]
1943 fn lexical_score_rewards_title_match() {
1944 let content = "This document discusses network configuration for LAN setups.";
1945 let s_title_match = lexical_score(
1946 "network configuration",
1947 "Network Configuration Guide",
1948 content,
1949 );
1950 let s_no_title = lexical_score("network configuration", "Unrelated Title", content);
1951 assert!(
1952 s_title_match > s_no_title,
1953 "title match ({s_title_match}) should beat no title match ({s_no_title})"
1954 );
1955 }
1956
1957 #[test]
1958 fn lexical_score_is_bounded_zero_one() {
1959 let s = lexical_score(
1960 "the quick brown fox jumps over the lazy dog",
1961 "the quick brown fox",
1962 "the quick brown fox jumps over the lazy dog and more words",
1963 );
1964 assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
1965 }
1966
1967 #[test]
1968 fn rerank_reorders_candidates() {
1969 let ce = CrossEncoder::new();
1970 let a = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
1971 let b = make_memory("Grocery list", "milk eggs bread butter cheese");
1972 let candidates = vec![(b.clone(), 0.55), (a.clone(), 0.45)];
1973 let reranked = ce.rerank("cross-encoder reranking", candidates);
1974 assert_eq!(reranked[0].0.title, "Rust cross-encoder");
1975 }
1976
1977 #[test]
1978 fn rerank_preserves_candidate_count() {
1979 let ce = CrossEncoder::new();
1980 let candidates = vec![
1981 (make_memory("A", "alpha"), 0.5),
1982 (make_memory("B", "beta"), 0.6),
1983 (make_memory("C", "gamma"), 0.7),
1984 ];
1985 let reranked = ce.rerank("alpha", candidates);
1986 assert_eq!(reranked.len(), 3);
1987 }
1988
1989 #[test]
1990 fn bigram_overlap_boosts_phrase_match() {
1991 let s_phrase = lexical_score(
1992 "network adapter",
1993 "title",
1994 "the network adapter is connected to the LAN",
1995 );
1996 let s_scattered = lexical_score(
1997 "network adapter",
1998 "title",
1999 "the adapter handles the network traffic independently",
2000 );
2001 assert!(
2002 s_phrase > s_scattered,
2003 "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
2004 );
2005 }
2006
2007 // -----------------------------------------------------------------
2008 // W11/S11b — input-count invariants for the rerank() API
2009 // -----------------------------------------------------------------
2010
2011 #[test]
2012 fn test_rerank_preserves_input_count_heuristic() {
2013 let ce = CrossEncoder::new();
2014 // Build 5 distinct candidates with varied original scores.
2015 let candidates: Vec<(Memory, f64)> = (0..5)
2016 .map(|i| {
2017 (
2018 make_memory(
2019 &format!("title {i}"),
2020 &format!("content body number {i} with some words"),
2021 ),
2022 f64::from(i) * 0.1,
2023 )
2024 })
2025 .collect();
2026 let query = "title content body";
2027 let reranked = ce.rerank(query, candidates);
2028 assert_eq!(
2029 reranked.len(),
2030 5,
2031 "heuristic rerank must preserve candidate count, got {} = {:?}",
2032 reranked.len(),
2033 reranked
2034 .iter()
2035 .map(|(m, s)| (&m.title, *s))
2036 .collect::<Vec<_>>()
2037 );
2038 // Sorted descending by final score (rerank contract).
2039 for w in reranked.windows(2) {
2040 assert!(
2041 w[0].1 >= w[1].1,
2042 "rerank output must be descending by score: {} < {}",
2043 w[0].1,
2044 w[1].1
2045 );
2046 }
2047 }
2048
2049 #[test]
2050 fn test_rerank_zero_candidates_returns_empty_heuristic() {
2051 let ce = CrossEncoder::new();
2052 let reranked = ce.rerank("query", Vec::new());
2053 assert!(reranked.is_empty());
2054 }
2055
2056 // Neural variant: gated to avoid pulling 80MB BERT weights at test time.
2057 // Run with `--features test-with-models` once the cross-encoder feature
2058 // exists upstream.
2059 #[cfg(feature = "test-with-models")]
2060 #[test]
2061 fn test_rerank_preserves_input_count_neural_if_available() {
2062 let ce = CrossEncoder::new_neural();
2063 let candidates: Vec<(Memory, f64)> = (0..5)
2064 .map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
2065 .collect();
2066 let reranked = ce.rerank("body", candidates);
2067 assert_eq!(reranked.len(), 5);
2068 }
2069
2070 // -----------------------------------------------------------------
2071 // W12-E — heuristic-path branch coverage for reranker.rs
2072 //
2073 // Targets the Lexical variant only. The Neural variant requires
2074 // downloading 80+ MB of BERT weights from HuggingFace Hub and is
2075 // gated behind `feature = "test-with-models"`.
2076 // -----------------------------------------------------------------
2077
2078 #[test]
2079 fn w12e_default_is_lexical() {
2080 let ce = CrossEncoder::default();
2081 assert!(!ce.is_neural(), "Default::default() must return Lexical");
2082 }
2083
2084 #[test]
2085 fn w12e_new_returns_lexical() {
2086 let ce = CrossEncoder::new();
2087 assert!(!ce.is_neural());
2088 }
2089
2090 #[test]
2091 fn w12e_score_dispatch_lexical_matches_helper() {
2092 // The CrossEncoder::score() dispatcher must delegate to lexical_score()
2093 // for the Lexical variant. Compute both and assert exact equality.
2094 let ce = CrossEncoder::new();
2095 let q = "rust async runtime";
2096 let title = "Tokio: Rust async runtime";
2097 let content = "Tokio is an async runtime for the Rust programming language.";
2098 let via_dispatcher = ce.score(q, title, content);
2099 let direct = lexical_score(q, title, content);
2100 assert!((via_dispatcher - direct).abs() < f32::EPSILON);
2101 }
2102
2103 #[test]
2104 fn w12e_score_empty_inputs_safe() {
2105 let ce = CrossEncoder::new();
2106 // Empty query → 0.0 by short-circuit in lexical_score
2107 assert_eq!(ce.score("", "title", "content"), 0.0);
2108 // Empty title and content with non-empty query — must not panic
2109 let s = ce.score("query", "", "");
2110 assert!((0.0..=1.0).contains(&s));
2111 // Whitespace-only query treated as empty after tokenization
2112 let s_ws = ce.score(" \t\n", "title", "content");
2113 assert_eq!(s_ws, 0.0);
2114 // Punctuation-only query also yields no tokens
2115 let s_punct = ce.score("!?.,;:", "title", "content");
2116 assert_eq!(s_punct, 0.0);
2117 }
2118
2119 #[test]
2120 fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
2121 // Mixed Unicode tokens with apostrophes, accents, emoji boundaries.
2122 let s_unicode = lexical_score(
2123 "café résumé d'oeuvre",
2124 "Le Café d'Oeuvre",
2125 "résumé du café avec d'oeuvre noté",
2126 );
2127 assert!(
2128 (0.0..=1.0).contains(&s_unicode),
2129 "unicode score {s_unicode} out of bounds"
2130 );
2131
2132 // Very long content stresses the length-normalization branches.
2133 let huge = "alpha beta gamma delta ".repeat(2_500);
2134 let s_long = lexical_score("alpha gamma", "headline", &huge);
2135 assert!(
2136 (0.0..=1.0).contains(&s_long),
2137 "long score {s_long} out of bounds"
2138 );
2139 }
2140
2141 #[test]
2142 fn w12e_lexical_score_perfect_overlap_high() {
2143 // 100% query overlap with title and content should produce a high
2144 // (but bounded) score.
2145 let s = lexical_score(
2146 "alpha beta gamma",
2147 "alpha beta gamma",
2148 "alpha beta gamma alpha beta gamma",
2149 );
2150 assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
2151 assert!(s <= 1.0);
2152 }
2153
2154 #[test]
2155 fn w12e_tfidf_score_empty_doc_returns_zero() {
2156 // Branch: doc_tokens.is_empty() → 0.0 short-circuit.
2157 let q = vec!["alpha", "beta"];
2158 let doc: Vec<&str> = Vec::new();
2159 assert_eq!(tfidf_score(&q, &doc), 0.0);
2160 }
2161
2162 #[test]
2163 fn w12e_tfidf_score_empty_query_returns_zero() {
2164 // Branch: query_terms.is_empty() → 0.0 short-circuit.
2165 let q: Vec<&str> = Vec::new();
2166 let doc = vec!["alpha", "beta", "gamma"];
2167 assert_eq!(tfidf_score(&q, &doc), 0.0);
2168 }
2169
2170 #[test]
2171 fn w12e_tfidf_score_no_matching_terms() {
2172 // Query terms entirely absent from doc → tf == 0 continue branch.
2173 let q = vec!["xenon", "kryptonite"];
2174 let doc = vec!["alpha", "beta", "gamma"];
2175 let s = tfidf_score(&q, &doc);
2176 assert_eq!(s, 0.0);
2177 }
2178
2179 #[test]
2180 fn w12e_tfidf_score_partial_match_bounded() {
2181 // Mixed presence/absence; clamp branch reachable.
2182 let q = vec!["alpha", "missing"];
2183 let doc = vec!["alpha", "alpha", "beta", "gamma"];
2184 let s = tfidf_score(&q, &doc);
2185 assert!((0.0..=1.0).contains(&s));
2186 assert!(s > 0.0);
2187 }
2188
2189 #[test]
2190 fn w12e_bigrams_empty_and_single_and_multi() {
2191 // Empty input → empty bigram list.
2192 let empty: Vec<&str> = Vec::new();
2193 assert!(bigrams(&empty).is_empty());
2194
2195 // Single token → no bigrams (windows(2) yields nothing).
2196 let one = vec!["solo"];
2197 assert!(bigrams(&one).is_empty());
2198
2199 // Multi-token → N-1 bigrams.
2200 let three = vec!["a", "b", "c"];
2201 let bg = bigrams(&three);
2202 assert_eq!(bg, vec![("a", "b"), ("b", "c")]);
2203 }
2204
2205 #[test]
2206 fn w12e_tokenize_handles_apostrophe_and_unicode() {
2207 // Apostrophes are preserved (e.g., "don't"), other punctuation splits.
2208 let toks = tokenize("don't stop, I won't!");
2209 assert!(toks.contains(&"don't"));
2210 assert!(toks.contains(&"won't"));
2211 assert!(toks.contains(&"stop"));
2212 assert!(toks.contains(&"I"));
2213
2214 // Pure-punctuation yields no tokens.
2215 let none = tokenize("!!!,,,;;;");
2216 assert!(none.is_empty());
2217
2218 // Empty string yields no tokens.
2219 let empty = tokenize("");
2220 assert!(empty.is_empty());
2221
2222 // Unicode alphanumerics survive (café = 4 alphanumeric chars).
2223 let unicode = tokenize("café résumé");
2224 assert_eq!(unicode.len(), 2);
2225 }
2226
2227 #[test]
2228 fn w12e_rerank_single_candidate_keeps_it() {
2229 let ce = CrossEncoder::new();
2230 let only = make_memory("solo title", "solo content body");
2231 let out = ce.rerank("solo", vec![(only.clone(), 0.42)]);
2232 assert_eq!(out.len(), 1);
2233 assert_eq!(out[0].0.title, "solo title");
2234 // Final score is a blend of original and CE score, both nonneg.
2235 assert!(out[0].1 >= 0.0);
2236 }
2237
2238 #[test]
2239 fn w12e_rerank_identical_originals_stable_under_score() {
2240 // When original scores are identical, ordering is determined by the
2241 // CE score. The candidate whose title/content overlaps the query
2242 // should rank first.
2243 let ce = CrossEncoder::new();
2244 let on_topic = make_memory("rust async runtime", "rust async runtime tokio");
2245 let off_topic = make_memory("grocery", "milk eggs bread");
2246 let out = ce.rerank(
2247 "rust async",
2248 vec![(off_topic.clone(), 0.5), (on_topic.clone(), 0.5)],
2249 );
2250 assert_eq!(out.len(), 2);
2251 assert_eq!(out[0].0.title, "rust async runtime");
2252 }
2253
2254 #[test]
2255 fn w12e_rerank_descending_invariant_holds_across_shapes() {
2256 // Property-style: irrespective of input shape, output is sorted desc.
2257 let ce = CrossEncoder::new();
2258 let cands: Vec<(Memory, f64)> = vec![
2259 (make_memory("a", "alpha words"), 0.10),
2260 (make_memory("b", "beta words"), 0.95),
2261 (make_memory("c", "gamma alpha"), 0.55),
2262 (make_memory("d", ""), 0.0),
2263 (make_memory("", "empty title doc"), 0.30),
2264 ];
2265 let out = ce.rerank("alpha", cands);
2266 assert_eq!(out.len(), 5);
2267 for w in out.windows(2) {
2268 assert!(
2269 w[0].1 >= w[1].1,
2270 "non-descending pair: {} then {}",
2271 w[0].1,
2272 w[1].1
2273 );
2274 }
2275 }
2276
2277 #[test]
2278 fn w12e_lexical_score_no_title_branch_via_empty_title() {
2279 // Empty title means title_set is empty; title_bonus == 0.0.
2280 // query_set non-empty so the else branch (title_hits / |Q|) runs.
2281 let s_empty_title = lexical_score("alpha beta", "", "alpha beta gamma");
2282 let s_with_title = lexical_score("alpha beta", "alpha beta", "alpha beta gamma");
2283 assert!(s_with_title >= s_empty_title);
2284 assert!((0.0..=1.0).contains(&s_empty_title));
2285 }
2286
2287 #[test]
2288 fn w12e_lexical_score_query_terms_only_in_title() {
2289 // Title contains all query terms; content has none.
2290 let s = lexical_score("rust crate", "Rust Crate Index", "unrelated body text");
2291 assert!(s > 0.0);
2292 assert!(s <= 1.0);
2293 }
2294
2295 // PR-9i — buffer coverage uplift.
2296
2297 #[test]
2298 fn pr9i_new_neural_dual_outcome() {
2299 // Exercises CrossEncoder::new_neural() (lines 65-79). Behavior is
2300 // environment-dependent: with an HF cache or network the call
2301 // succeeds and returns Self::Neural; without either it falls back
2302 // to Self::Lexical via the documented eprintln + tracing warn
2303 // pathway. Both outcomes are acceptable — what matters is the
2304 // dispatch is hit. Functionally, both variants score within
2305 // [0.0, 1.0].
2306 let ce = CrossEncoder::new_neural();
2307 let s = ce.score("query", "title", "content");
2308 assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
2309 }
2310
2311 // -----------------------------------------------------------------
2312 // v0.7 G9 — batched rerank parity + coalescer smoke tests
2313 // -----------------------------------------------------------------
2314
2315 #[test]
2316 fn g9_rerank_batch_matches_per_query_rerank_lexical() {
2317 // Spec: 3 queries × 5 candidates. Batched output must match
2318 // per-query rerank() output exactly for the deterministic Lexical
2319 // path. (Neural parity is gated behind `test-with-models`; the
2320 // implementation is symmetric — same blend, same sort.)
2321 let ce = CrossEncoder::new();
2322 let queries = vec!["alpha gamma", "beta words", "rust async"];
2323 let mut jobs: Vec<(String, Vec<(Memory, f64)>)> = Vec::new();
2324 let mut expected: Vec<Vec<(Memory, f64)>> = Vec::new();
2325 for q in &queries {
2326 let cands: Vec<(Memory, f64)> = (0..5)
2327 .map(|i| {
2328 (
2329 make_memory(
2330 &format!("title-{i}-{q}"),
2331 &format!("alpha beta gamma rust async body {i} {q}"),
2332 ),
2333 f64::from(i) * 0.1,
2334 )
2335 })
2336 .collect();
2337 expected.push(ce.rerank(q, cands.clone()));
2338 jobs.push(((*q).to_string(), cands));
2339 }
2340
2341 let batched = ce.rerank_batch(jobs);
2342 assert_eq!(batched.len(), expected.len());
2343 for (b, e) in batched.iter().zip(expected.iter()) {
2344 assert_eq!(b.len(), e.len());
2345 for (bi, ei) in b.iter().zip(e.iter()) {
2346 assert_eq!(bi.0.id, ei.0.id);
2347 assert_eq!(bi.0.title, ei.0.title);
2348 assert!(
2349 (bi.1 - ei.1).abs() < 1e-12,
2350 "blended score mismatch: batched={} per-query={}",
2351 bi.1,
2352 ei.1
2353 );
2354 }
2355 }
2356 }
2357
2358 #[test]
2359 fn g9_rerank_batch_single_query_short_circuits() {
2360 // Single-query batches must not regress vs rerank() — use the
2361 // single-query short-circuit path.
2362 let ce = CrossEncoder::new();
2363 let cands: Vec<(Memory, f64)> = (0..5)
2364 .map(|i| (make_memory(&format!("t{i}"), &format!("body {i}")), 0.5))
2365 .collect();
2366 let direct = ce.rerank("body", cands.clone());
2367 let batched = ce.rerank_batch(vec![("body".to_string(), cands)]);
2368 assert_eq!(batched.len(), 1);
2369 assert_eq!(batched[0].len(), direct.len());
2370 for (a, b) in batched[0].iter().zip(direct.iter()) {
2371 assert_eq!(a.0.id, b.0.id);
2372 assert!((a.1 - b.1).abs() < 1e-12);
2373 }
2374 }
2375
2376 #[test]
2377 fn g9_rerank_batch_empty_inputs() {
2378 let ce = CrossEncoder::new();
2379 let out = ce.rerank_batch(Vec::new());
2380 assert!(out.is_empty());
2381
2382 // Multi-query but each has zero candidates.
2383 let out2 = ce.rerank_batch(vec![
2384 ("q1".to_string(), Vec::new()),
2385 ("q2".to_string(), Vec::new()),
2386 ]);
2387 assert_eq!(out2.len(), 2);
2388 assert!(out2.iter().all(std::vec::Vec::is_empty));
2389 }
2390
2391 #[test]
2392 fn g9_batched_reranker_serial_calls_match_rerank() {
2393 use super::BatchedReranker;
2394 let batched = BatchedReranker::new(CrossEncoder::new());
2395 let cands: Vec<(Memory, f64)> = (0..4)
2396 .map(|i| {
2397 (
2398 make_memory(
2399 &format!("t{i}"),
2400 &format!("alpha gamma body {i} content words"),
2401 ),
2402 f64::from(i) * 0.1,
2403 )
2404 })
2405 .collect();
2406 let direct = CrossEncoder::new().rerank("alpha", cands.clone());
2407 let via_batcher = batched.rerank("alpha", cands);
2408 assert_eq!(via_batcher.len(), direct.len());
2409 for (a, b) in via_batcher.iter().zip(direct.iter()) {
2410 assert_eq!(a.0.id, b.0.id);
2411 assert!((a.1 - b.1).abs() < 1e-12);
2412 }
2413 }
2414
2415 #[test]
2416 fn g9_batched_reranker_concurrent_calls_all_succeed() {
2417 use super::BatchedReranker;
2418 use std::sync::Arc;
2419 let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
2420 let mut handles = Vec::new();
2421 for i in 0..8 {
2422 let b = Arc::clone(&batched);
2423 handles.push(std::thread::spawn(move || {
2424 let cands: Vec<(Memory, f64)> = (0..5)
2425 .map(|j| {
2426 (
2427 make_memory(
2428 &format!("t{i}-{j}"),
2429 &format!("body {j} alpha gamma rust"),
2430 ),
2431 0.5,
2432 )
2433 })
2434 .collect();
2435 let q = format!("alpha {i}");
2436 let out = b.rerank(&q, cands);
2437 assert_eq!(out.len(), 5);
2438 // Output is sorted descending.
2439 for w in out.windows(2) {
2440 assert!(w[0].1 >= w[1].1);
2441 }
2442 }));
2443 }
2444 for h in handles {
2445 h.join().expect("worker thread panicked");
2446 }
2447 }
2448
2449 /// #1579 B10 — the auto-select predicate: lexical NEVER batches
2450 /// (criterion: batched 7.6 ms vs direct 0.65 ms at N=8 — 12×
2451 /// inversion from the flush window); neural batches only at
2452 /// concurrency ≥ `BATCHED_RERANK_MIN_CONCURRENCY`.
2453 #[test]
2454 fn issue_1579_b10_auto_select_predicate() {
2455 use super::{BATCHED_RERANK_MIN_CONCURRENCY, use_batched_rerank_path};
2456 // Lexical: direct at every concurrency level.
2457 assert!(!use_batched_rerank_path(false, 1));
2458 assert!(!use_batched_rerank_path(false, 8));
2459 assert!(!use_batched_rerank_path(false, 1024));
2460 // Neural: lone caller goes direct (nothing to coalesce with)…
2461 assert!(!use_batched_rerank_path(true, 1));
2462 // …real concurrency keeps the G9 batched win.
2463 assert!(use_batched_rerank_path(
2464 true,
2465 BATCHED_RERANK_MIN_CONCURRENCY
2466 ));
2467 assert!(use_batched_rerank_path(true, 8));
2468 }
2469
2470 /// #1579 B10 — behavioral pin: a lexical `BatchedReranker` routes
2471 /// every call (serial AND concurrent) down the DIRECT path; the
2472 /// coalescing worker never sees a job. Pre-fix, all 8 concurrent
2473 /// lexical calls funneled through the worker and paid the 5 ms
2474 /// flush window per batch.
2475 #[test]
2476 fn issue_1579_b10_lexical_rerank_never_reaches_worker() {
2477 use super::BatchedReranker;
2478 use std::sync::Arc;
2479 let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
2480 let mut handles = Vec::new();
2481 for i in 0..8 {
2482 let b = Arc::clone(&batched);
2483 handles.push(std::thread::spawn(move || {
2484 let cands: Vec<(Memory, f64)> = (0..5)
2485 .map(|j| {
2486 (
2487 make_memory(&format!("b10-{i}-{j}"), &format!("body {j} alpha gamma")),
2488 0.5,
2489 )
2490 })
2491 .collect();
2492 let out = b.rerank(&format!("alpha {i}"), cands);
2493 assert_eq!(out.len(), 5);
2494 }));
2495 }
2496 for h in handles {
2497 h.join().expect("worker thread panicked");
2498 }
2499 assert_eq!(
2500 batched.worker_submissions(),
2501 0,
2502 "lexical rerank must auto-select the direct path (no worker jobs)"
2503 );
2504 }
2505
2506 /// #1579 B10 — the forced coalesced path stays alive (both paths
2507 /// are kept per the remediation contract) and produces output
2508 /// byte-equal to the direct path on a lexical encoder.
2509 #[test]
2510 fn issue_1579_b10_forced_coalesced_path_matches_direct() {
2511 use super::BatchedReranker;
2512 let batched = BatchedReranker::new(CrossEncoder::new());
2513 let cands: Vec<(Memory, f64)> = (0..4)
2514 .map(|i| {
2515 (
2516 make_memory(
2517 &format!("b10-forced-{i}"),
2518 &format!("alpha gamma body {i} content words"),
2519 ),
2520 f64::from(i) * 0.1,
2521 )
2522 })
2523 .collect();
2524 let direct = batched.rerank("alpha", cands.clone());
2525 let coalesced = batched.rerank_coalesced("alpha", cands);
2526 assert_eq!(
2527 batched.worker_submissions(),
2528 1,
2529 "rerank_coalesced must route through the worker"
2530 );
2531 assert_eq!(coalesced.len(), direct.len());
2532 for (a, b) in coalesced.iter().zip(direct.iter()) {
2533 assert_eq!(a.0.id, b.0.id);
2534 assert!((a.1 - b.1).abs() < 1e-12);
2535 }
2536 }
2537
2538 #[test]
2539 fn pr9i_rerank_via_score_returns_blend() {
2540 // Even when new_neural() falls back to lexical, rerank() must
2541 // still produce a deterministic [0..1] blend. Pins the contract
2542 // for both branches of CrossEncoder::score().
2543 let ce = CrossEncoder::new_neural();
2544 let cands = vec![
2545 (
2546 Memory {
2547 id: "a".to_string(),
2548 tier: Tier::Mid,
2549 namespace: "ns".to_string(),
2550 title: "rust async runtime".to_string(),
2551 content: "tokio rust async".to_string(),
2552 tags: vec![],
2553 priority: 5,
2554 confidence: 1.0,
2555 source: "test".to_string(),
2556 access_count: 0,
2557 created_at: "2026-01-01T00:00:00Z".to_string(),
2558 updated_at: "2026-01-01T00:00:00Z".to_string(),
2559 last_accessed_at: None,
2560 expires_at: None,
2561 metadata: serde_json::json!({}),
2562 reflection_depth: 0,
2563 memory_kind: crate::models::MemoryKind::Observation,
2564 entity_id: None,
2565 persona_version: None,
2566 citations: Vec::new(),
2567 source_uri: None,
2568 source_span: None,
2569 confidence_source: crate::models::ConfidenceSource::CallerProvided,
2570 confidence_signals: None,
2571 confidence_decayed_at: None,
2572 version: 1,
2573 },
2574 0.6,
2575 ),
2576 (
2577 Memory {
2578 id: "b".to_string(),
2579 tier: Tier::Mid,
2580 namespace: "ns".to_string(),
2581 title: "grocery list".to_string(),
2582 content: "milk eggs".to_string(),
2583 tags: vec![],
2584 priority: 5,
2585 confidence: 1.0,
2586 source: "test".to_string(),
2587 access_count: 0,
2588 created_at: "2026-01-01T00:00:00Z".to_string(),
2589 updated_at: "2026-01-01T00:00:00Z".to_string(),
2590 last_accessed_at: None,
2591 expires_at: None,
2592 metadata: serde_json::json!({}),
2593 reflection_depth: 0,
2594 memory_kind: crate::models::MemoryKind::Observation,
2595 entity_id: None,
2596 persona_version: None,
2597 citations: Vec::new(),
2598 source_uri: None,
2599 source_span: None,
2600 confidence_source: crate::models::ConfidenceSource::CallerProvided,
2601 confidence_signals: None,
2602 confidence_decayed_at: None,
2603 version: 1,
2604 },
2605 0.4,
2606 ),
2607 ];
2608 let out = ce.rerank("rust async", cands);
2609 assert_eq!(out.len(), 2);
2610 for (_, score) in &out {
2611 assert!(score.is_finite());
2612 }
2613 // First entry's blended score >= second by sort contract.
2614 assert!(out[0].1 >= out[1].1);
2615 }
2616
2617 // ---------- Issue #1319 — reranker score floor (calibration) -----------
2618
2619 #[test]
2620 fn issue_1691_n14_score_floor_parse_grammar() {
2621 // #1691/n14 — the config/env parser that finally makes the
2622 // with_score_floor capability operator-reachable.
2623 assert_eq!(
2624 RerankerScoreFloor::parse("off"),
2625 Some(RerankerScoreFloor::Off)
2626 );
2627 assert_eq!(
2628 RerankerScoreFloor::parse(" OFF "),
2629 Some(RerankerScoreFloor::Off)
2630 );
2631 assert_eq!(
2632 RerankerScoreFloor::parse("absolute:0.3"),
2633 Some(RerankerScoreFloor::Absolute(0.3))
2634 );
2635 assert_eq!(
2636 RerankerScoreFloor::parse("ABS: 0.25"),
2637 Some(RerankerScoreFloor::Absolute(0.25))
2638 );
2639 assert_eq!(
2640 RerankerScoreFloor::parse("relative:0.5"),
2641 Some(RerankerScoreFloor::RelativeToTop(0.5))
2642 );
2643 assert_eq!(
2644 RerankerScoreFloor::parse("relative_to_top:0.8"),
2645 Some(RerankerScoreFloor::RelativeToTop(0.8))
2646 );
2647 // Unparseable values fall through (resolver then uses the next
2648 // precedence layer / the Off default).
2649 assert_eq!(RerankerScoreFloor::parse(""), None);
2650 assert_eq!(RerankerScoreFloor::parse("absolute"), None);
2651 assert_eq!(RerankerScoreFloor::parse("absolute:notanumber"), None);
2652 assert_eq!(RerankerScoreFloor::parse("bogus:0.5"), None);
2653 assert_eq!(RerankerScoreFloor::parse("absolute:inf"), None);
2654 }
2655
2656 /// Issue #1319 — `RerankerScoreFloor::Off` is the default and a
2657 /// no-op. Pre-#1319 callers see byte-identical output through the
2658 /// new `apply` helper.
2659 #[test]
2660 fn reranker_score_floor_default_is_off_1319() {
2661 let floor = RerankerScoreFloor::default();
2662 assert_eq!(floor, RerankerScoreFloor::Off);
2663 let mut scored = vec![
2664 (make_memory("a", "x"), 0.9_f64),
2665 (make_memory("b", "y"), 0.4_f64),
2666 (make_memory("c", "z"), 0.1_f64),
2667 ];
2668 let before = scored.clone();
2669 floor.apply(&mut scored);
2670 assert_eq!(scored.len(), before.len());
2671 for (i, (mem, s)) in scored.iter().enumerate() {
2672 assert_eq!(mem.title, before[i].0.title);
2673 assert!((s - before[i].1).abs() < f64::EPSILON);
2674 }
2675 }
2676
2677 /// Issue #1319 — absolute floor drops the tail. Top row is
2678 /// preserved even when its score happens to fall below the floor
2679 /// (small-corpus safety so a 1-row recall never returns nothing).
2680 #[test]
2681 fn reranker_score_floor_absolute_drops_tail_1319() {
2682 let floor = RerankerScoreFloor::Absolute(0.5);
2683 let mut scored = vec![
2684 (make_memory("top", "x"), 0.90_f64),
2685 (make_memory("mid", "y"), 0.60_f64),
2686 (make_memory("low", "z"), 0.30_f64),
2687 (make_memory("noise", "n"), 0.10_f64),
2688 ];
2689 floor.apply(&mut scored);
2690 // top + mid kept; low + noise dropped.
2691 let titles: Vec<&str> = scored.iter().map(|(m, _)| m.title.as_str()).collect();
2692 assert_eq!(titles, vec!["top", "mid"]);
2693 }
2694
2695 /// Issue #1319 — relative floor preserves the head and drops
2696 /// candidates below `top_score * ratio`.
2697 #[test]
2698 fn reranker_score_floor_relative_drops_tail_1319() {
2699 let floor = RerankerScoreFloor::RelativeToTop(0.5);
2700 // top_score = 0.80, cutoff = 0.40.
2701 let mut scored = vec![
2702 (make_memory("top", "x"), 0.80_f64),
2703 (make_memory("kept", "y"), 0.50_f64),
2704 (make_memory("dropped_1", "z"), 0.35_f64),
2705 (make_memory("dropped_2", "z"), 0.20_f64),
2706 ];
2707 floor.apply(&mut scored);
2708 let titles: Vec<&str> = scored.iter().map(|(m, _)| m.title.as_str()).collect();
2709 assert_eq!(titles, vec!["top", "kept"]);
2710 }
2711
2712 /// Issue #1319 — top row is preserved even when the absolute
2713 /// floor sits above every blended score. A tiny corpus that all
2714 /// scored at 0.20 must still surface its top hit, not return
2715 /// empty.
2716 #[test]
2717 fn reranker_score_floor_preserves_top_row_when_everything_below_1319() {
2718 let floor = RerankerScoreFloor::Absolute(0.5);
2719 let mut scored = vec![
2720 (make_memory("apollo", "moon landing"), 0.20_f64),
2721 (make_memory("recall", "blends fts and semantic"), 0.10_f64),
2722 ];
2723 floor.apply(&mut scored);
2724 assert_eq!(scored.len(), 1);
2725 assert_eq!(scored[0].0.title, "apollo");
2726 }
2727
2728 /// Issue #1319 — empty input is a no-op (no panic on `.first()`).
2729 #[test]
2730 fn reranker_score_floor_handles_empty_1319() {
2731 let floor = RerankerScoreFloor::Absolute(0.5);
2732 let mut scored: Vec<(Memory, f64)> = vec![];
2733 floor.apply(&mut scored);
2734 assert!(scored.is_empty());
2735 }
2736
2737 /// Issue #1319 — v1 P5 probe surfaced a paraphrase-aware corpus
2738 /// where an Apollo-11 row scored 0.479 above a
2739 /// substantively-relevant recall-mechanics row at 0.363 with
2740 /// nothing visible to the operator that would have explained the
2741 /// ordering. This regression test reconstructs the empirical
2742 /// situation (disjoint-vocab paraphrase query — query terms appear
2743 /// in neither candidate's title or content) and asserts that, with
2744 /// an operator-opt-in `RerankerScoreFloor::Absolute(0.40)`, the
2745 /// Apollo-11 false positive is dropped while the head ranking is
2746 /// preserved.
2747 ///
2748 /// **Why the floor matters here.** With the lexical CE, both
2749 /// candidates score 0.0 on the paraphrase query (disjoint vocab).
2750 /// The blend `0.6 * original + 0.4 * 0.0` reduces to `0.6 * original`,
2751 /// so the empirical ordering is set entirely by the upstream
2752 /// `original` score. The substrate cannot reorder them away from
2753 /// the noise — but it CAN expose an operator handle that drops
2754 /// the entire tail below a threshold the operator chose. That's
2755 /// what `RerankerScoreFloor` provides.
2756 #[test]
2757 fn reranker_v1_p5_paraphrase_noise_dropped_by_floor_1319() {
2758 let ce = CrossEncoder::new(); // lexical, deterministic.
2759 let apollo = make_memory(
2760 "Apollo 11 moon landing",
2761 "Neil Armstrong walked on the moon in 1969.",
2762 );
2763 let recall_b = make_memory(
2764 "Recall blends FTS and semantic scores",
2765 "The hybrid pipeline weighs cosine vs BM25 then reranks the top-k.",
2766 );
2767
2768 // Empirical pre-#1319 shape: upstream hybrid retrieval scored
2769 // Apollo above recall_b. The exact numbers mirror the v1 P5
2770 // probe (Apollo 0.479, recall_b 0.363) so the test reads as
2771 // the operator-observed evidence on the issue.
2772 let candidates = vec![(apollo.clone(), 0.479_f64), (recall_b.clone(), 0.363_f64)];
2773
2774 // Operator query: a paraphrase that lexically misses both
2775 // candidates ("what makes a recall implementation good?").
2776 // Lexical CE produces 0 for both, so the blend reduces to
2777 // `0.6 * original`.
2778 let query = "what makes a recall implementation good?";
2779
2780 // Sanity: pre-floor, Apollo still sits on top — the
2781 // substrate has no way to reorder paraphrase-disjoint
2782 // candidates without semantic input from upstream.
2783 let pre = ce.rerank(query, candidates.clone());
2784 assert_eq!(pre[0].0.title, "Apollo 11 moon landing");
2785 // Blended top score = 0.6 * 0.479 = 0.2874 (paraphrase noise band).
2786 assert!(pre[0].1 < 0.30, "top score in noise band: {}", pre[0].1);
2787
2788 // Post-#1319 with absolute floor at 0.40: the entire tail is
2789 // dropped EXCEPT the top row (preserved per the small-corpus
2790 // safety rule). The operator now sees a single result and can
2791 // judge "noise" vs "this is genuinely the best the substrate
2792 // has" without an Apollo-11 false positive sitting beneath it
2793 // at 0.218.
2794 let mut post = pre.clone();
2795 RerankerScoreFloor::Absolute(0.40).apply(&mut post);
2796 assert_eq!(
2797 post.len(),
2798 1,
2799 "floor at 0.40 must drop tail when blended scores in noise band: {post:?}"
2800 );
2801 // Top preserved.
2802 assert_eq!(post[0].0.title, "Apollo 11 moon landing");
2803 }
2804
2805 /// Issue #1319 — `BatchedReranker::with_score_floor` plumbs the
2806 /// operator-opt-in floor end-to-end through the batched worker.
2807 /// Pinned via the wrapper so future refactors of the worker
2808 /// pipeline can't silently bypass the floor.
2809 #[test]
2810 fn batched_reranker_score_floor_plumbed_end_to_end_1319() {
2811 use super::BatchedReranker;
2812 let batched = BatchedReranker::with_score_floor(
2813 CrossEncoder::new(),
2814 RerankerScoreFloor::Absolute(0.40),
2815 );
2816 assert_eq!(batched.score_floor(), RerankerScoreFloor::Absolute(0.40));
2817
2818 let apollo = make_memory("Apollo 11 moon landing", "Armstrong, 1969");
2819 let recall_b = make_memory(
2820 "Recall blends FTS and semantic scores",
2821 "hybrid pipeline weighs cosine vs BM25",
2822 );
2823 let candidates = vec![(apollo, 0.479_f64), (recall_b, 0.363_f64)];
2824 let out = batched.rerank("paraphrase miss query", candidates);
2825 // Default daemon path uses `BatchedReranker::new` (floor Off),
2826 // so existing behavior is preserved — only the opt-in
2827 // constructor plumbs the floor.
2828 assert_eq!(out.len(), 1, "score floor must drop tail: {out:?}");
2829 }
2830
2831 /// Issue #1319 — the existing `BatchedReranker::new` path leaves
2832 /// the floor at `Off`, preserving pre-#1319 byte-equality for
2833 /// every daemon that has not opted in.
2834 #[test]
2835 fn batched_reranker_default_constructor_leaves_floor_off_1319() {
2836 use super::BatchedReranker;
2837 let batched = BatchedReranker::new(CrossEncoder::new());
2838 assert_eq!(batched.score_floor(), RerankerScoreFloor::Off);
2839 }
2840}
2841
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Test double for the cross-encoder that never loads BERT.
    ///
    /// The lexical flavour delegates to the real `lexical_score`. The
    /// "neural" flavour derives a deterministic pseudo-score from the query
    /// and title only (content is ignored), boosted when the title contains
    /// the query verbatim.
    pub struct MockCrossEncoder {
        /// `true` selects the hash-based "neural" scorer.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Lexical mock (mirrors `CrossEncoder::new()`).
        pub fn new() -> Self {
            Self { use_neural: false }
        }

        /// Hash-based "neural" mock (mirrors `CrossEncoder::new_neural()`).
        pub fn new_neural() -> Self {
            Self { use_neural: true }
        }

        /// Deterministic score in `[0, 1]` for a (query, document) pair.
        ///
        /// The two flavours use different formulas so tests can tell them
        /// apart.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                // Lexical flavour: defer to the production lexical scorer.
                return lexical_score(query, title, content);
            }
            // Base-31 rolling hash over the query bytes followed by the
            // title bytes.
            let mut hash: u32 = 0;
            for byte in query.bytes().chain(title.bytes()) {
                hash = hash.wrapping_mul(31).wrapping_add(u32::from(byte));
            }
            let base = ((hash % 1000) as f32) / 1000.0;
            let title_hit = title.contains(query);
            if title_hit {
                // Verbatim title hit: lift the score into the upper half.
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Whether this mock uses the hash-based "neural" scorer.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Blends each candidate's original score with the mock score using
        /// the production weights, then sorts by descending blended score.
        pub fn rerank(
            &self,
            query: &str,
            candidates: Vec<(Memory, f64)>,
        ) -> Vec<(Memory, f64)> {
            let mut blended: Vec<(Memory, f64)> = Vec::with_capacity(candidates.len());
            for (mem, original) in candidates {
                let ce = f64::from(self.score(query, &mem.title, &mem.content));
                blended.push((mem, ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce));
            }
            blended.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            blended
        }
    }

    impl Default for MockCrossEncoder {
        /// Defaults to the lexical flavour, like `MockCrossEncoder::new`.
        fn default() -> Self {
            Self { use_neural: false }
        }
    }
}
2923
#[cfg(test)]
mod mock_tests {
    //! Tests for `MockCrossEncoder`, plus the H2 worker-shutdown check
    //! for `BatchedReranker`.

    use super::test_support::*;
    use super::{BatchedReranker, CrossEncoder};
    use crate::models::{Memory, Tier};
    use std::time::Duration;

    /// Builds a fixture memory. Every fixture shares the id `"test-id"`;
    /// these tests identify rows by title.
    fn make_memory(title: &str, content: &str) -> Memory {
        Memory {
            id: "test-id".to_string(),
            tier: Tier::Mid,
            namespace: "test".to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
            reflection_depth: 0,
            memory_kind: crate::models::MemoryKind::Observation,
            entity_id: None,
            persona_version: None,
            citations: Vec::new(),
            source_uri: None,
            source_span: None,
            confidence_source: crate::models::ConfidenceSource::CallerProvided,
            confidence_signals: None,
            confidence_decayed_at: None,
            version: 1,
        }
    }

    #[test]
    fn mock_lexical_new() {
        let ce = MockCrossEncoder::new();
        assert!(!ce.is_neural());
    }

    #[test]
    fn mock_neural_new() {
        let ce = MockCrossEncoder::new_neural();
        assert!(ce.is_neural());
    }

    #[test]
    fn mock_neural_score_deterministic() {
        let ce = MockCrossEncoder::new_neural();
        let s1 = ce.score("query", "title", "content");
        let s2 = ce.score("query", "title", "content");
        assert_eq!(s1, s2);
    }

    #[test]
    fn mock_neural_score_title_match_boost() {
        let ce = MockCrossEncoder::new_neural();
        let s_title_contains = ce.score("apple", "apple pie recipe", "delicious dessert");
        let s_no_match = ce.score("apple", "unrelated", "delicious dessert");
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    #[test]
    fn mock_neural_score_bounded() {
        let ce = MockCrossEncoder::new_neural();
        for query in &["test", "neural", "reranker", "machine learning"] {
            for title in &["a", "b", "the quick brown"] {
                let s = ce.score(query, title, "content");
                assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
            }
        }
    }

    #[test]
    fn mock_neural_rerank_reorders() {
        let ce = MockCrossEncoder::new_neural();
        let a = make_memory("neural network", "deep learning with transformers");
        let b = make_memory("grocery list", "milk eggs bread butter");
        let candidates = vec![(b, 0.3), (a, 0.2)];
        let reranked = ce.rerank("neural network", candidates);
        // The title-match boost must lift the neural-network memory to the top.
        assert_eq!(reranked[0].0.title, "neural network");
    }

    #[test]
    fn mock_neural_rerank_preserves_count() {
        let ce = MockCrossEncoder::new_neural();
        let candidates = vec![
            (make_memory("A", "content a"), 0.5),
            (make_memory("B", "content b"), 0.4),
            (make_memory("C", "content c"), 0.6),
        ];
        let reranked = ce.rerank("test", candidates);
        assert_eq!(reranked.len(), 3);
    }

    #[test]
    fn mock_lexical_path_via_mock() {
        let ce = MockCrossEncoder::new();
        let s = ce.score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&s));
    }

    #[test]
    fn mock_neural_different_from_lexical() {
        let lexical = MockCrossEncoder::new();
        let neural = MockCrossEncoder::new_neural();
        let s_lex = lexical.score("machine learning", "ML title", "neural networks");
        let s_neu = neural.score("machine learning", "ML title", "neural networks");
        // The two flavours use different scoring formulas.
        assert_ne!(s_lex, s_neu);
    }

    // -----------------------------------------------------------------
    // H2 (v0.7.0 round-2) — worker-thread shutdown discipline.
    //
    // Contract: a `BatchedReranker` whose shutdown is signalled must
    // terminate its worker thread within a bounded wall-clock window.
    // Without an explicit shutdown channel, a worker blocked in
    // `rx.recv()` would only exit on sender disconnect. The explicit
    // signal covers the worst case (e.g. a stashed `Sender` clone) and
    // bounds shutdown latency by the worker's SHUTDOWN_POLL cadence.
    // -----------------------------------------------------------------
    #[test]
    fn h2_drop_terminates_worker_within_500ms() {
        use std::time::Instant;
        // Take the JoinHandle and shutdown sender out of the reranker
        // BEFORE it drops, so thread termination can be observed from
        // outside. The shutdown steps are replayed inline: drop the work
        // sender, signal shutdown, then join under a wall-clock budget.
        let mut r = BatchedReranker::new(CrossEncoder::new());
        let shutdown = r.shutdown.take().expect("shutdown sender present");
        let worker = r.worker.take().expect("worker handle present");
        // Drop the work-channel sender first so the worker also sees the
        // channel disconnect.
        r.sender.take();
        let start = Instant::now();
        let _ = shutdown.send(());
        // `JoinHandle::join` has no timeout. Join on a helper thread and
        // wait for its completion signal with `recv_timeout` instead.
        let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();
        std::thread::spawn(move || {
            let _ = worker.join();
            let _ = done_tx.send(());
        });
        let observed = done_rx
            .recv_timeout(Duration::from_millis(500))
            .map(|()| start.elapsed());
        assert!(
            observed.is_ok(),
            "BatchedReranker worker did not terminate within 500ms after \
             explicit shutdown — observed: {observed:?}"
        );
    }
}
3095
#[test]
fn score_handles_empty_query_string() {
    // An empty query has nothing to match, whatever the document contains.
    let score = lexical_score("", "Document Title", "This is document content");
    assert_eq!(score, 0.0, "empty query must return 0.0");
}
3101
#[test]
fn score_handles_unicode_normalization() {
    // Accented text must tokenize and match like ASCII text. Both Unicode
    // forms of "café" are exercised as self-matches:
    //   * precomposed (NFC): "caf" + U+00E9
    //   * decomposed (NFD):  "cafe" + U+0301 COMBINING ACUTE ACCENT
    // This does NOT assert that the two forms match EACH OTHER; that would
    // require explicit normalization in the tokenizer.
    let composed = "caf\u{e9}";
    let decomposed = "cafe\u{301}";
    let s_composed = lexical_score(composed, composed, &format!("the {composed} is open"));
    let s_decomposed =
        lexical_score(decomposed, decomposed, &format!("the {decomposed} is open"));
    let s_ascii = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(s_composed > 0.0, "precomposed self-match scored {s_composed}");
    assert!(s_decomposed > 0.0, "decomposed self-match scored {s_decomposed}");
    assert!(s_ascii > 0.0, "ascii self-match scored {s_ascii}");
}
3111
#[test]
fn score_handles_very_long_content_truncation() {
    // 10,000 repetitions of "word " give 50,000 chars of content. The
    // lexical tokenizer must cope and still return a score inside [0, 1].
    const REPEATS: usize = 10_000;
    let long_content = "word ".repeat(REPEATS);
    let score = lexical_score("word", "title", &long_content);
    assert!((0.0..=1.0).contains(&score), "score must be bounded [0, 1]");
}
3119
#[test]
fn bigram_score_with_single_token_query() {
    // A one-token query produces no query bigrams. Scoring must neither
    // panic nor leave [0, 1].
    let score = lexical_score("query", "Single Token Title", "single token content");
    assert!((0.0..=1.0).contains(&score));
}
3126
#[cfg(test)]
mod issue_1597_tests {
    //! #1597 — rerank pool cap + batched cross-encoder forward pass.
    //!
    //! The counting-mock route is unavailable: `MockCrossEncoder` is a
    //! standalone test struct, not a pluggable `CrossEncoder` variant,
    //! so call counts cannot be observed through the production enum.
    //! Instead the cap is pinned via score mutation. With a query that
    //! shares zero tokens with every candidate, the lexical cross-encoder
    //! scores every pair `0.0`. A cross-encoded candidate's final score
    //! therefore becomes EXACTLY `ORIGINAL_WEIGHT * orig`, while an
    //! uncapped candidate keeps `orig` bit-for-bit. That makes "exactly
    //! RERANK_POOL_MAX candidates were cross-encoded" observable from the
    //! output alone.
    //!
    //! Expected indices are derived from [`RERANK_POOL_MAX`] rather than
    //! hard-coded. The over-cap tests also assert that an uncapped tail
    //! exists, so a future change to the cap fails loudly instead of
    //! silently turning the tail checks into empty loops.

    use super::*;
    use crate::models::Memory;

    /// Query with zero token overlap against [`pool_memory`] docs —
    /// lexical cross-encoder score is exactly 0.0 for every pair.
    const NO_OVERLAP_QUERY: &str = "zzz qqq www";

    /// Pool size for the over-cap tests. Must exceed [`RERANK_POOL_MAX`]
    /// so an uncapped tail exists (enforced by [`tail_len`]).
    const OVER_CAP_POOL: i32 = 50;

    fn pool_memory(i: i32) -> Memory {
        Memory {
            id: format!("cand-{i}"),
            title: format!("alpha {i}"),
            content: format!("beta gamma {i}"),
            ..Memory::default()
        }
    }

    /// Original (pre-rerank) score of candidate `i`: `0.01 * (i + 1)`.
    fn orig_score(i: i32) -> f64 {
        f64::from(i + 1) * 0.01
    }

    /// `n` candidates with distinct ascending original scores, supplied
    /// in ASCENDING order so the cap's pre-sort is load-bearing (not a
    /// pass-through of input order).
    fn pool(n: i32) -> Vec<(Memory, f64)> {
        (0..n).map(|i| (pool_memory(i), orig_score(i))).collect()
    }

    /// Converts a rank/offset/count to `i32` for index arithmetic.
    fn as_i32(n: usize) -> i32 {
        i32::try_from(n).expect("value fits i32")
    }

    /// Converts a non-negative `i32` count to `usize`.
    fn as_usize(n: i32) -> usize {
        usize::try_from(n).expect("count is non-negative")
    }

    /// [`RERANK_POOL_MAX`] as `i32`.
    fn cap() -> i32 {
        as_i32(RERANK_POOL_MAX)
    }

    /// Number of candidates in an [`OVER_CAP_POOL`] pool that the cap
    /// excludes. Panics if a future cap leaves no tail to test.
    fn tail_len() -> i32 {
        let cap = cap();
        let tail = OVER_CAP_POOL - cap;
        assert!(
            tail > 0,
            "RERANK_POOL_MAX ({cap}) must be below OVER_CAP_POOL ({OVER_CAP_POOL}) \
             so the over-cap tests exercise an uncapped tail"
        );
        tail
    }

    /// Candidate index expected `off` positions into the tail of an
    /// [`OVER_CAP_POOL`] rerank. The tail holds indices
    /// `tail_len() - 1` down to `0`, in descending order.
    fn tail_index(off: usize) -> i32 {
        tail_len() - 1 - as_i32(off)
    }

    /// Pool over the cap → exactly [`RERANK_POOL_MAX`] candidates get
    /// cross-encoder scores (their final scores move to
    /// `ORIGINAL_WEIGHT * orig`). The rest keep their blended scores
    /// bit-for-bit and sort below the reranked head. No candidate is lost.
    #[test]
    fn rerank_pool_cap_honored_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let n = OVER_CAP_POOL;
        let total = as_usize(n);
        let tail = as_usize(tail_len());
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(n));

        assert_eq!(out.len(), total, "no candidate may be lost");
        let ids: std::collections::HashSet<&str> = out.iter().map(|(m, _)| m.id.as_str()).collect();
        assert_eq!(ids.len(), total, "no duplicate / dropped ids");
        assert_eq!(out.len() - RERANK_POOL_MAX, tail, "tail holds every excluded candidate");

        // Head: the top RERANK_POOL_MAX by original score (i = n-1 down to
        // n-cap), each cross-encoded → ORIGINAL_WEIGHT * orig.
        for (rank, (mem, score)) in out.iter().take(RERANK_POOL_MAX).enumerate() {
            let i = n - 1 - as_i32(rank);
            assert_eq!(mem.id, format!("cand-{i}"), "head rank {rank}");
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "head rank {rank} must carry the cross-encoded blend"
            );
        }

        // Tail: the remaining candidates (i = tail_len-1 down to 0), with
        // blended scores untouched (bit-for-bit the input score).
        for (off, (mem, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            let i = tail_index(off);
            assert_eq!(mem.id, format!("cand-{i}"), "tail offset {off}");
            assert_eq!(
                *score,
                orig_score(i),
                "tail offset {off} must keep its blended score untouched"
            );
        }
    }

    /// Order correctness: the reranked head is internally sorted
    /// descending, the tail is internally sorted descending, and the tail
    /// comes strictly after the head.
    #[test]
    fn rerank_pool_cap_order_correctness_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(OVER_CAP_POOL));
        let head = &out[..RERANK_POOL_MAX];
        let tail = &out[RERANK_POOL_MAX..];
        assert_eq!(tail.len(), as_usize(tail_len()), "tail must be non-empty");
        assert!(
            head.windows(2).all(|w| w[0].1 >= w[1].1),
            "reranked head must be sorted descending"
        );
        assert!(
            tail.windows(2).all(|w| w[0].1 >= w[1].1),
            "uncapped tail must be sorted descending"
        );
        // Head candidates are indices tail_len..n-1, so the smallest
        // original score in the head belongs to index `tail_len()`. Every
        // tail member must sit below it (the cap kept the strongest pool).
        let min_head_orig = orig_score(tail_len());
        assert!(
            tail.iter().all(|(_, s)| *s < min_head_orig),
            "tail must hold only candidates the cap excluded"
        );
    }

    /// Pool exactly at the cap → full rerank (tail empty): every candidate
    /// is cross-encoded.
    #[test]
    fn rerank_pool_at_cap_fully_cross_encoded_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let n = cap();
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(n));
        assert_eq!(out.len(), RERANK_POOL_MAX);
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = n - 1 - as_i32(rank);
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "at-cap pool: rank {rank} must be cross-encoded"
            );
        }
    }

    /// Cap > pool size degenerates to the historical full rerank.
    #[test]
    fn rerank_cap_gt_pool_degenerates_to_full_rerank_1597() {
        const SMALL_POOL: i32 = 5;
        assert!(
            SMALL_POOL < cap(),
            "this test needs RERANK_POOL_MAX > {SMALL_POOL}"
        );
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(SMALL_POOL));
        assert_eq!(out.len(), as_usize(SMALL_POOL));
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = SMALL_POOL - 1 - as_i32(rank);
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "small pool: rank {rank} must be cross-encoded (no tail)"
            );
        }
    }

    /// The G9 multi-query batch path applies the cap per query job.
    #[test]
    fn rerank_batch_applies_pool_cap_per_query_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let total = as_usize(OVER_CAP_POOL);
        let tail = as_usize(tail_len());
        let jobs = vec![
            (NO_OVERLAP_QUERY.to_string(), pool(OVER_CAP_POOL)),
            (NO_OVERLAP_QUERY.to_string(), pool(OVER_CAP_POOL)),
        ];
        let outs = ce.rerank_batch(jobs);
        assert_eq!(outs.len(), 2);
        for out in &outs {
            assert_eq!(out.len(), total, "per-job candidate count preserved");
            assert_eq!(out.len() - RERANK_POOL_MAX, tail, "per-job tail present");
            for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
                assert_eq!(
                    *score,
                    orig_score(tail_index(off)),
                    "per-job tail must keep blended scores untouched"
                );
            }
        }
    }

    /// The `BatchedReranker` production wrapper inherits the cap via the
    /// direct encoder path (lexical traffic never reaches the coalescing
    /// worker per #1579 B10).
    #[test]
    fn batched_reranker_inherits_pool_cap_1597() {
        let br = BatchedReranker::with_reflection_boost(
            CrossEncoder::Lexical { degraded: false },
            ReflectionBoostConfig::disabled(),
        );
        let out = br.rerank(NO_OVERLAP_QUERY, pool(OVER_CAP_POOL));
        assert_eq!(out.len(), as_usize(OVER_CAP_POOL));
        assert_eq!(out.len() - RERANK_POOL_MAX, as_usize(tail_len()));
        for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            assert_eq!(*score, orig_score(tail_index(off)), "wrapper tail untouched");
        }
    }

    /// #1597 bench evidence — manual run against the REAL neural
    /// cross-encoder (resolves from the local HF cache; downloads ~80 MB
    /// on a cold host):
    ///
    /// ```bash
    /// AI_MEMORY_NO_CONFIG=1 cargo test --release --lib \
    ///   issue_1597_neural_rerank_timing_evidence -- --ignored --nocapture
    /// ```
    ///
    /// Prints BEFORE (a sequential per-pair forward over the full
    /// 50-candidate pool, i.e. the pre-#1597 `rerank` shape) vs AFTER (a
    /// capped pool plus one batched forward, i.e. the shipped path). The
    /// pool is moved into `rerank` rather than cloned, so the AFTER window
    /// does not include a 50-memory clone.
    #[test]
    #[ignore = "#1597 manual bench evidence: loads the real neural cross-encoder"]
    fn issue_1597_neural_rerank_timing_evidence() {
        let ce = CrossEncoder::new_neural();
        assert!(
            ce.is_neural(),
            "neural encoder failed to load; timing evidence invalid"
        );
        let bench_pool: Vec<(Memory, f64)> = (0..50)
            .map(|i| {
                let m = Memory {
                    id: format!("bench-{i}"),
                    title: format!("benchmark candidate number {i} recall pipeline"),
                    content: format!(
                        "long-form benchmark document body number {i} with enough \
                         material to exercise the cross-encoder, covering recall \
                         pipeline reranking, cross encoder scoring, candidate \
                         blending and ordering semantics for run {i}"
                    ),
                    ..Memory::default()
                };
                (m, f64::from(i) * 0.01)
            })
            .collect();
        let query = "how does the recall pipeline rerank candidates";

        // Warm-up (first forward pays one-time allocation cost).
        let _ = ce.score(query, "warmup", "warmup body");

        // BEFORE shape: one full forward per (query, candidate) pair,
        // sequentially, over the entire 50-candidate pool.
        let t0 = Instant::now();
        for (m, _) in &bench_pool {
            let _ = ce.score(query, &m.title, &m.content);
        }
        let before = t0.elapsed();

        // AFTER: shipped path — cap at RERANK_POOL_MAX + single batched
        // forward. `bench_pool` is moved (not cloned) so only the rerank
        // itself is timed.
        let t1 = Instant::now();
        let out = ce.rerank(query, bench_pool);
        let after = t1.elapsed();

        assert_eq!(out.len(), 50, "no candidate lost on the neural path");
        eprintln!(
            "#1597 timing (50-candidate pool, CPU): BEFORE sequential-full = {before:?}; \
             AFTER capped+batched = {after:?}"
        );
    }
}