ai_memory/reranker.rs
1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! Cross-encoder reranking for search results.
5//!
6//! A cross-encoder takes a (query, document) pair and produces a relevance
7//! score. This is more accurate than cosine similarity of independent
8//! embeddings but slower since it must run for each candidate.
9//!
10//! **Two implementations:**
11//! - `CrossEncoder::Lexical` — lightweight term-overlap scorer (default).
12//! - `CrossEncoder::Neural` — BERT-based cross-encoder loaded via candle
13//! from `cross-encoder/ms-marco-MiniLM-L-6-v2` (~80 MB, ONNX-free).
14
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::sync::mpsc::{Sender, sync_channel};
17use std::sync::{Arc, Mutex};
18use std::thread::{self, JoinHandle};
19use std::time::{Duration, Instant};
20
21use anyhow::{Context, Result};
22use candle_core::{Device, Tensor};
23use candle_nn::VarBuilder;
24use candle_transformers::models::bert::{BertModel, Config as BertConfig};
25use hf_hub::{Repo, RepoType, api::sync::Api};
26use tokenizers::Tokenizer;
27
28use crate::models::Memory;
29
30// ---------------------------------------------------------------------------
31// v0.7.0 (issue #518) — session-aware recall recency boost
32// ---------------------------------------------------------------------------
33
/// Additive score bonus given to a recall candidate whose id is in the
/// calling session's recently-accessed set.
///
/// The value is deliberately small (+0.05): it cannot lift a
/// low-relevance candidate past a substantially better match, but it is
/// enough to break near-ties in favour of memories the agent touched
/// earlier in the same session. Consumed by
/// [`apply_session_recency_boost`].
pub const SESSION_RECENCY_BOOST: f64 = 0.05;
40
/// Per-session cap on the recently-accessed ring buffer. When the
/// buffer is at the cap, the oldest entry is evicted (FIFO) before the
/// newest entry is kept. Bounds the memory cost at
/// `O(SESSIONS * SESSION_RECENT_CAP)` ids regardless of recall traffic.
pub const SESSION_RECENT_CAP: usize = 50;

/// v0.7.0 (issue #518) — tracker mapping `session_id` to its FIFO ring
/// buffer of recently-accessed memory ids.
///
/// The tracker is consulted by [`apply_session_recency_boost`] after
/// the rerank stage of a recall. Each call:
///
/// 1. Reads the per-session ring BEFORE recording anything, so only
///    candidates touched by *earlier* recalls in the session are lifted.
/// 2. Records every recall hit's id into the per-session ring (FIFO
///    eviction past [`SESSION_RECENT_CAP`]) so subsequent recalls in
///    the same session see the new context.
///
/// A single `Mutex` guards the whole map. Every method treats a
/// poisoned mutex as "no session state": reads report nothing recent
/// and writes are dropped, so the recall path never panics because of
/// the tracker.
#[derive(Debug, Default)]
pub struct SessionRecallTracker {
    /// `session_id` → ring of memory ids, oldest at the front, newest
    /// at the back.
    inner: Mutex<HashMap<String, VecDeque<String>>>,
}

impl SessionRecallTracker {
    /// Construct an empty tracker. Test code uses this directly; the
    /// production code path goes through the process-global
    /// [`global_session_recall_tracker`] accessor.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Return the set of recently-accessed memory ids for `session_id`,
    /// or an empty set if the session is unknown or the mutex is
    /// poisoned.
    ///
    /// Allocates a fresh `HashSet<String>` (one clone per tracked id).
    /// Hot paths should prefer [`SessionRecallTracker::with_recent_ids`].
    #[must_use]
    pub fn recent_ids(&self, session_id: &str) -> HashSet<String> {
        let Ok(guard) = self.inner.lock() else {
            // Poisoned mutex (a panic happened while another thread held
            // the lock). Report an empty set so the recall path stays
            // infallible — the boost simply does not fire this call.
            return HashSet::new();
        };
        guard
            .get(session_id)
            .map(|ring| ring.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// v0.7.0 #1091 — allocation-free membership lookup against the
    /// per-session ring.
    ///
    /// `f` is invoked exactly once with a predicate `is_recent(id)` and
    /// its result is returned. For an unknown session or a poisoned
    /// mutex the predicate answers `false` for every id.
    ///
    /// **Cost.** The predicate is a linear scan of the ring (at most
    /// [`SESSION_RECENT_CAP`] entries), so a K-result recall performs
    /// O(K·N) string comparisons. That is asymptotically more than the
    /// HashSet path (O(N) build + O(K) lookups) but avoids cloning N
    /// strings and building a set per recall; with N ≤ 50 the scan is
    /// the cheaper option in practice.
    ///
    /// **Locking.** `f` runs while the tracker mutex is held. It must
    /// not call back into this tracker (`record`, `recent_ids`, …):
    /// `std::sync::Mutex` is not re-entrant, so that would deadlock or
    /// panic.
    pub fn with_recent_ids<R>(
        &self,
        session_id: &str,
        f: impl FnOnce(&dyn Fn(&str) -> bool) -> R,
    ) -> R {
        let Ok(guard) = self.inner.lock() else {
            // Poisoned mutex: every id misses the boost. Same posture
            // as the empty-set fallback in `recent_ids`.
            return f(&|_id: &str| false);
        };
        match guard.get(session_id) {
            None => f(&|_id: &str| false),
            Some(ring) => f(&|id: &str| ring.iter().any(|existing| existing == id)),
        }
    }

    /// Record the ids of memories returned by the just-completed recall
    /// into the per-session ring.
    ///
    /// * An id already present is removed and re-appended at the back
    ///   (the newest end), so eviction always drops the
    ///   least-recently-touched id first.
    /// * After each append the ring is trimmed from the front down to
    ///   [`SESSION_RECENT_CAP`] entries.
    /// * An empty `ids` iterator is a no-op: no ring is created for the
    ///   session, so recalls that return nothing do not grow the map.
    /// * A poisoned mutex silently drops the write.
    pub fn record(&self, session_id: &str, ids: impl IntoIterator<Item = String>) {
        // Peek before touching the map so a zero-hit recall does not
        // register an empty ring under a new session key.
        let mut ids = ids.into_iter().peekable();
        if ids.peek().is_none() {
            return;
        }
        let Ok(mut guard) = self.inner.lock() else {
            return;
        };
        let ring = guard.entry(session_id.to_string()).or_default();
        for id in ids {
            // De-dupe by removing any existing occurrence so the newest
            // landing position wins.
            ring.retain(|existing| existing != &id);
            ring.push_back(id);
            while ring.len() > SESSION_RECENT_CAP {
                ring.pop_front();
            }
        }
    }

    /// Diagnostic: number of sessions that currently have a ring.
    /// Returns `0` when the mutex is poisoned.
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.inner.lock().map(|g| g.len()).unwrap_or(0)
    }
}
162
163/// Process-global [`SessionRecallTracker`] used by every recall hot
164/// path. Lazily initialised on first access; never reset within a
165/// process lifetime (per-process state by design — operator restart
166/// clears every session's recent set).
167///
168/// v0.7.x (issue #1174 follow-up #1196) — the tracker lives on
169/// [`crate::runtime_context::RuntimeContext::recall_tracker`]. The
170/// returned `&'static` reference is stable because
171/// `RuntimeContext::global()` itself is a `OnceLock`-backed
172/// process-wide singleton; the `Arc<SessionRecallTracker>` inside it
173/// is allocated once and outlives every caller.
174#[must_use]
175pub fn global_session_recall_tracker() -> &'static SessionRecallTracker {
176 &crate::runtime_context::RuntimeContext::global().recall_tracker
177}
178
179/// v0.7.0 (issue #518) — apply the per-session recently-accessed boost
180/// to a scored recall result vector AND record the post-boost hit set
181/// back into the session's ring buffer.
182///
183/// `session_id` is the caller-supplied per-session identifier. When
184/// `None` or empty, the function is a no-op (returns the input
185/// unchanged). When set:
186///
187/// 1. Every candidate whose id is in the tracker's per-session set
188/// gets `SESSION_RECENCY_BOOST` ADDED to its score.
189/// 2. The vector is re-sorted descending by the boosted score.
190/// 3. The post-boost id list is appended into the session ring (FIFO
191/// eviction past [`SESSION_RECENT_CAP`]).
192///
193/// The boost is *additive* (not multiplicative) so its effect is
194/// independent of the absolute score magnitude — the +0.05 always
195/// breaks ties at the same delta regardless of whether scores are on
196/// the 0..1 cosine band or the 0..2 blended hybrid band.
197pub fn apply_session_recency_boost(
198 results: Vec<(Memory, f64)>,
199 session_id: Option<&str>,
200 tracker: &SessionRecallTracker,
201) -> Vec<(Memory, f64)> {
202 let Some(sid) = session_id else {
203 return results;
204 };
205 if sid.is_empty() {
206 return results;
207 }
208 // v0.7.0 #1091 — drop the per-recall `HashSet<String>` allocation
209 // (50 clones at the cap) by using the membership-callback variant.
210 // The closure owns the inner mutex for the boost-apply pass; the
211 // membership predicate runs O(N) per id against the (≤ 50 entry)
212 // ring, giving the same overall complexity as the pre-#1091
213 // (HashSet-build + lookup) path without the allocation.
214 let mut boosted: Vec<(Memory, f64)> = tracker.with_recent_ids(sid, |is_recent| {
215 results
216 .into_iter()
217 .map(|(mem, score)| {
218 let bumped = if is_recent(&mem.id) {
219 score + SESSION_RECENCY_BOOST
220 } else {
221 score
222 };
223 (mem, bumped)
224 })
225 .collect()
226 });
227 // Re-sort descending — boosted candidates may move past their
228 // pre-boost neighbours.
229 boosted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
230 // v0.7.0 #1091 — record the post-boost id list into the session
231 // ring via an iterator clone-on-demand path so we don't allocate
232 // a Vec<String> just to hand to `record` (which itself iterates).
233 tracker.record(sid, boosted.iter().map(|(m, _)| m.id.clone()));
234 boosted
235}
236
/// Blend weight applied to the original (embedding/FTS) score in the
/// rerank formula `final = ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce`.
const ORIGINAL_WEIGHT: f64 = 0.6;
/// Blend weight applied to the cross-encoder score in the same formula.
/// Together with [`ORIGINAL_WEIGHT`] the two weights sum to 1.0.
const CROSS_ENCODER_WEIGHT: f64 = 0.4;
241
/// #1531 M13 — replace a non-finite score with the ranking floor.
///
/// The rerank sort compares with `partial_cmp(..).unwrap_or(Equal)`. A
/// NaN final score (from a poisoned caller `original_score` or a NaN
/// logit out of a corrupt weights file) would compare `Equal` to every
/// other score and could keep whatever rank it had in the input.
/// Mapping NaN and both infinities to `f64::MIN` makes such candidates
/// sort last among the scores being compared. Finite inputs are
/// returned unchanged, so ordinary ranking is unaffected.
fn finite_or_floor(score: f64) -> f64 {
    if !score.is_finite() {
        return f64::MIN;
    }
    score
}
255
256/// #1597 — split a candidate pool into `(head, tail)` at
257/// [`RERANK_POOL_MAX`].
258///
259/// Pools at or under the cap come back whole (`tail` empty, input order
260/// preserved — the degenerate full-rerank case). Larger pools are
261/// sorted by the incoming blended score descending (total order via
262/// [`f64::total_cmp`] so a NaN-poisoned score cannot destabilise the
263/// sort; NaN sorts into the head, is cross-encoded, and then sinks via
264/// [`finite_or_floor`] exactly as pre-#1597) and split after the cap,
265/// so both halves come back internally sorted descending.
266fn split_rerank_pool(
267 mut candidates: Vec<(Memory, f64)>,
268) -> (Vec<(Memory, f64)>, Vec<(Memory, f64)>) {
269 let tail = if candidates.len() > RERANK_POOL_MAX {
270 candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
271 candidates.split_off(RERANK_POOL_MAX)
272 } else {
273 Vec::new()
274 };
275 (candidates, tail)
276}
277
/// #1597 — hard cap on how many candidates receive a cross-encoder
/// score per rerank call.
///
/// Background: the Phase-3 dogfood run measured autonomous-tier recall
/// at 2823-7737 ms/call on CPU vs 14-32 ms at the semantic tier. The
/// pre-#1597 [`CrossEncoder::rerank`] ran one full BERT forward pass
/// per (query, candidate) pair, sequentially, over the entire
/// post-blend candidate pool (up to 50 rows from the recall SQL cap).
///
/// Now only the strongest `RERANK_POOL_MAX` candidates by incoming
/// blended score are cross-encoded (in ONE batched forward pass); the
/// remainder keep their blended scores and sort below the reranked
/// head. 20 keeps the cross-encoder's precision win where it matters
/// (the head the caller actually reads) while bounding the worst-case
/// forward-pass cost at ~40% of the pre-fix pool (20 of 50 rows).
pub const RERANK_POOL_MAX: usize = 20;
293
/// HuggingFace Hub repository id of the neural cross-encoder.
const CROSS_ENCODER_MODEL_ID: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Bare configured-model spelling for the default reranker — shared with
/// the `ai-memory config migrate` template (#1558 batch 6).
pub(crate) const DEFAULT_RERANKER_MODEL: &str = "ms-marco-MiniLM-L-6-v2";
/// Model-architecture ceiling on the cross-encoder input sequence.
/// Per-consumer truncation (e.g. the #1604 rerank cap below) may go
/// tighter, never looser.
pub const CROSS_ENCODER_MAX_SEQ: usize = 512;
/// Hidden size of the encoder; the classifier head weight is loaded
/// with shape `[1, CROSS_ENCODER_HIDDEN_DIM]`.
const CROSS_ENCODER_HIDDEN_DIM: usize = 384;

/// #1604 — compiled default for the tokenized length of **rerank**
/// inputs, used by the #1597 batched-forward path instead of the
/// architecture ceiling [`CROSS_ENCODER_MAX_SEQ`].
///
/// The #1588 dogfood RE-RUN measured the residual #1597 latency:
/// warm autonomous-tier recall was ~4,013 ms on a real (long-content)
/// corpus vs ~533 ms on short-content rows — the [batch=20, seq=512]
/// candle CPU forward, not pool size or batching, was the cost. BERT
/// attention is O(n²) in sequence length, so halving the cap to 256
/// cuts the forward ~4× while keeping the title + lead content that
/// carries the relevance signal for memory rows. The single-pair
/// [`CrossEncoder::score`] keeps the full [`CROSS_ENCODER_MAX_SEQ`].
///
/// Operator override ladder (resolved by
/// `AppConfig::resolve_reranker()` at boot and seeded here via
/// [`set_rerank_max_seq`]): `AI_MEMORY_RERANK_MAX_SEQ` env >
/// `[reranker].max_seq_tokens` config > this compiled default. Values
/// that are zero, unparseable, or above [`CROSS_ENCODER_MAX_SEQ`]
/// fall through to the next ladder layer.
pub const RERANK_MAX_SEQ_DEFAULT: usize = 256;

/// Process-wide resolved rerank sequence cap, seeded at most once (the
/// `crate::storage::set_db_mmap_size` OnceLock precedent — the scoring
/// paths run deep in the recall pipeline where no `AppConfig` is in
/// scope). Unseeded processes (unit tests, library embedders that
/// bypass the CLI boot path) fall through to
/// [`RERANK_MAX_SEQ_DEFAULT`].
static RERANK_MAX_SEQ: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Seed the process-wide rerank sequence cap for every subsequent
/// batched rerank forward.
///
/// * Accepted range is `1..=CROSS_ENCODER_MAX_SEQ`. A value outside it
///   is ignored, mirroring the "fall through to the next ladder layer"
///   rule: the cell stays unseeded, so either a later valid call or the
///   compiled default applies. Without this guard an invalid first call
///   would permanently pin a cap the model cannot serve.
/// * First *valid* writer wins; later calls are no-ops (matches
///   `crate::storage::set_db_mmap_size`).
pub fn set_rerank_max_seq(tokens: usize) {
    if !(1..=CROSS_ENCODER_MAX_SEQ).contains(&tokens) {
        return;
    }
    // `set` returns Err when already seeded; that is the documented
    // first-writer-wins no-op, so the result is intentionally dropped.
    let _ = RERANK_MAX_SEQ.set(tokens);
}

/// The effective rerank sequence cap for this process: the seeded value
/// if one was accepted, otherwise [`RERANK_MAX_SEQ_DEFAULT`].
fn rerank_max_seq() -> usize {
    RERANK_MAX_SEQ
        .get()
        .copied()
        .unwrap_or(RERANK_MAX_SEQ_DEFAULT)
}
346
/// v0.7.0 L2-8 — default multiplicative boost applied to
/// `Reflection`-kind memories AFTER cross-encoder reranking.
///
/// Reflections summarise multiple observations, so abstraction-shaped
/// queries ("what patterns...", "what are recurring themes...") should
/// preferentially surface them. `1.2` sits in the band where a
/// reflection with a base score equal to its source observations
/// consistently lifts into the top-5 without dragging mediocre
/// reflections above well-matched observations.
pub const DEFAULT_REFLECTION_BOOST: f32 = 1.2;

/// v0.7.0 L2-8 — default per-depth additional multiplier increment,
/// used as `per_depth_factor = 1.0 + per_depth_increment * reflection_depth`.
///
/// Deeper reflections (reflections-on-reflections) compress more
/// observations, so a small per-depth bump is justified.
pub const DEFAULT_REFLECTION_PER_DEPTH_INCREMENT: f32 = 0.05;

/// v0.7.0 L2-8 — default depth cap mirrored from
/// [`GovernancePolicy::effective_max_reflection_depth`].
///
/// Past this depth the per-depth multiplier stops growing; reflections
/// deeper than the cap still receive the cap-evaluated boost (operator
/// policy may refuse the write entirely, but the reranker side never
/// produces an unbounded multiplier).
pub const DEFAULT_REFLECTION_MAX_DEPTH_CAP: u32 = 3;
370
371/// v0.7.0 L2-8 — configuration for the reflection-aware reranker boost.
372///
373/// The boost is applied AFTER the cross-encoder blend (i.e. it does NOT
374/// participate in the `0.6 * original + 0.4 * cross_encoder` scoring
375/// formula). Boost shape:
376///
377/// ```text
378/// per_depth_factor = 1.0 + per_depth_increment * min(reflection_depth, max_depth_cap)
379/// final_score = base_score * (kind == Reflection ? boost * per_depth_factor : 1.0)
380/// ```
381///
382/// Default factor = `1.2` (see [`DEFAULT_REFLECTION_BOOST`]). Setting
383/// `boost = 1.0` makes the reranker reproduce its pre-L2-8 behavior
384/// exactly — a deliberate kill-switch for the recall regression suite.
385#[derive(Debug, Clone, Copy, PartialEq)]
386pub struct ReflectionBoostConfig {
387 /// Multiplicative boost applied to `Reflection`-kind memories.
388 /// Default `1.2`. `1.0` disables the boost.
389 pub boost: f32,
390 /// Per-depth additional multiplier increment. Default `0.05`.
391 pub per_depth_increment: f32,
392 /// Depth cap for the per-depth multiplier. Default `3` (mirrors
393 /// the compiled-in default of
394 /// `GovernancePolicy::effective_max_reflection_depth`). Larger
395 /// `reflection_depth` values are clamped to this cap so the
396 /// reranker never produces an unbounded multiplier.
397 pub max_depth_cap: u32,
398}
399
400impl Default for ReflectionBoostConfig {
401 fn default() -> Self {
402 Self {
403 boost: DEFAULT_REFLECTION_BOOST,
404 per_depth_increment: DEFAULT_REFLECTION_PER_DEPTH_INCREMENT,
405 max_depth_cap: DEFAULT_REFLECTION_MAX_DEPTH_CAP,
406 }
407 }
408}
409
410impl ReflectionBoostConfig {
411 /// Pin to pre-L2-8 behavior: `boost = 1.0` ⇒ multiplier is always
412 /// `1.0` regardless of memory kind or depth. Used by the regression
413 /// test that proves the new pathway is a *pure addition* over the RC
414 /// behavior.
415 #[must_use]
416 pub const fn disabled() -> Self {
417 Self {
418 boost: 1.0,
419 per_depth_increment: 0.0,
420 max_depth_cap: 0,
421 }
422 }
423
424 /// Compute the multiplicative factor for a given memory. Returns
425 /// `1.0` for non-reflections; `boost * per_depth_factor` for
426 /// reflections (with `reflection_depth` clamped to `max_depth_cap`).
427 ///
428 /// Pulled out so the same arithmetic is shared by both the per-query
429 /// `rerank` and the G9 batched `rerank_batch` codepaths — there is
430 /// exactly one place to audit the multiplier shape.
431 #[must_use]
432 pub fn factor_for(&self, mem: &Memory) -> f64 {
433 if !matches!(mem.memory_kind, crate::models::MemoryKind::Reflection) {
434 return 1.0;
435 }
436 // `reflection_depth` is stored as i32 (SQL signed) but the
437 // governance accessor returns u32; the column DEFAULT is 0 and
438 // negative values would already have been rejected by the
439 // `memory_reflect` write path. Clamp to non-negative defensively
440 // so a bad write upstream can't produce a negative multiplier.
441 let depth = u32::try_from(mem.reflection_depth.max(0)).unwrap_or(0);
442 let depth_clamped = depth.min(self.max_depth_cap);
443 let per_depth_factor =
444 f64::from(self.per_depth_increment).mul_add(f64::from(depth_clamped), 1.0);
445 f64::from(self.boost) * per_depth_factor
446 }
447}
448
/// Cross-encoder for (query, document) relevance scoring.
///
/// Two variants: a dependency-free lexical scorer and a neural
/// BERT-based scorer. [`CrossEncoder::new_neural`] yields the lexical
/// variant (flagged `degraded`) when the neural model cannot be loaded.
pub enum CrossEncoder {
    /// Lightweight lexical cross-encoder using term overlap signals.
    ///
    /// `degraded` is `true` when this variant exists because a
    /// configured neural cross-encoder failed to initialise (HF Hub
    /// unreachable, model checksum mismatch, etc.) and the runtime
    /// fell back. `false` is the originally-configured lexical tier
    /// (operator opted in to keyword-tier or smart-tier without
    /// cross-encoder reranking).
    ///
    /// v0.7.0 R3-S2 — the distinction surfaces in the recall
    /// response's `meta.reranker_used` field as
    /// `"degraded_lexical"` vs `"lexical"`, so an in-band signal
    /// tells clients (MCP + HTTP) when their reranker downgraded.
    /// The original G8 fix only emitted a `tracing::warn!`; closing G8
    /// per the playbook required this in-response field.
    Lexical { degraded: bool },
    /// Neural BERT-based cross-encoder (ms-marco-MiniLM-L-6-v2).
    ///
    /// v0.7.0 #1084 — `model` is `Arc<BertModel>` (no mutex), same
    /// pattern as `Embedder::Local`. The pre-#1084 design held an
    /// `Arc<Mutex<BertModel>>` and locked across the full neural
    /// rerank forward pass, serialising every rerank-tier recall on
    /// a single global mutex. Candle's `BertModel::forward` takes
    /// `&self` (inference-only; weights are read-only) so the
    /// mutex was unnecessary.
    ///
    /// `classifier_weight` / `classifier_bias` are the classification
    /// head applied to the `[CLS]` hidden state; `device` is where the
    /// input tensors are created.
    Neural {
        model: Arc<BertModel>,
        tokenizer: Arc<Tokenizer>,
        classifier_weight: Tensor,
        classifier_bias: Tensor,
        device: Device,
    },
}
485
486impl CrossEncoder {
487 /// Create a new lexical cross-encoder (no model download required).
488 ///
489 /// This is the "originally lexical" path — the operator either
490 /// chose keyword-/semantic-tier (no cross-encoder reranking) or
491 /// explicitly opted into the lexical variant. Use
492 /// [`Self::new_neural`] to attempt the neural path with
493 /// fall-back-to-lexical semantics.
494 pub fn new() -> Self {
495 Self::Lexical { degraded: false }
496 }
497
498 /// Create a neural cross-encoder by downloading ms-marco-MiniLM-L-6-v2.
499 ///
500 /// Falls back to lexical if download or loading fails. The
501 /// fallback is marked `degraded: true` so the recall response
502 /// surfaces `reranker_used = "degraded_lexical"` per R3-S2 — an
503 /// in-band signal that v0.7.0 promises but pre-R3 only emitted
504 /// as a `tracing::warn!` (a tracing-event-only fallback is not
505 /// the same as a per-response field operators can branch on).
506 ///
507 /// v0.6.3.1 (P3, G8): when the neural path fails (e.g. HF Hub
508 /// unreachable, model checksum mismatch), emit a structured tracing
509 /// event `reranker.fallback` so operators see the silent
510 /// neural→lexical degrade. The eprintln remains for backward-compat
511 /// startup logs.
512 pub fn new_neural() -> Self {
513 match Self::load_neural() {
514 Ok(ce) => ce,
515 Err(e) => {
516 tracing::warn!(
517 target: "reranker.fallback",
518 from = "neural",
519 to = "lexical",
520 reason = %e,
521 "cross-encoder fell back to lexical: neural init failed"
522 );
523 eprintln!("ai-memory: neural cross-encoder failed ({e}), using lexical fallback");
524 Self::Lexical { degraded: true }
525 }
526 }
527 }
528
/// Download (or reuse from the HuggingFace cache) and assemble the
/// neural cross-encoder on the CPU device.
///
/// Steps, in order: fetch `config.json`, `tokenizer.json` and the
/// safetensors weights from [`CROSS_ENCODER_MODEL_ID`]; parse the BERT
/// config; load the tokenizer with truncation at
/// [`CROSS_ENCODER_MAX_SEQ`] and padding disabled; memory-map the
/// weights as `F32`; build the `BertModel`; load the classification
/// head tensors.
///
/// # Errors
/// Returns an error, with context naming the failing step, if any
/// download, file read, parse, or tensor lookup fails.
fn load_neural() -> Result<Self> {
    let device = Device::Cpu;

    let api = Api::new().context("failed to init HuggingFace Hub API")?;
    let repo = api.repo(Repo::new(
        CROSS_ENCODER_MODEL_ID.to_string(),
        RepoType::Model,
    ));

    // Resolve the three model artifacts to local paths.
    let config_path = repo
        .get(crate::embeddings::HF_CONFIG_FILE)
        .context("failed to download config.json")?;
    let tokenizer_path = repo
        .get(crate::embeddings::HF_TOKENIZER_FILE)
        .context("failed to download tokenizer.json")?;
    let weights_path = repo
        .get(crate::embeddings::HF_WEIGHTS_FILE)
        .context("failed to download model.safetensors")?;

    // Load BERT config
    let config_data = std::fs::read_to_string(&config_path)
        .context("failed to read cross-encoder config.json")?;
    let config: BertConfig = serde_json::from_str(&config_data)
        .context("failed to parse cross-encoder config.json")?;

    // Load tokenizer. Truncate at the architecture ceiling and disable
    // padding: the single-pair scoring path encodes one sequence at a
    // time.
    let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
    let truncation = tokenizers::TruncationParams {
        max_length: CROSS_ENCODER_MAX_SEQ,
        ..Default::default()
    };
    tokenizer
        .with_truncation(Some(truncation))
        .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
    tokenizer.with_padding(None);

    // Load model weights.
    //
    // SAFETY (#1456): `from_mmaped_safetensors` memory-maps the
    // weights file. The mmap is unsound only if the backing file is
    // mutated or truncated by another process while it is mapped.
    // `weights_path` resolves to a trusted, immutable safetensors
    // artifact in the daemon-owned HuggingFace cache (downloaded and
    // not subsequently written by us); it is never a caller-supplied
    // path at request time. The mapping lives only for the duration
    // of weight loading below.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
            .context("failed to load cross-encoder weights")?
    };

    // `vb` is cloned because it is needed again below for the
    // classifier head tensors.
    let model = BertModel::load(vb.clone(), &config)
        .context("failed to build cross-encoder BertModel")?;

    // Load the classification head: classifier.weight [1, hidden_dim] and classifier.bias [1]
    let classifier_weight = vb
        .get((1, CROSS_ENCODER_HIDDEN_DIM), "classifier.weight")
        .context("failed to load classifier.weight")?;
    let classifier_bias = vb
        .get(1, "classifier.bias")
        .context("failed to load classifier.bias")?;

    Ok(Self::Neural {
        model: Arc::new(model),
        tokenizer: Arc::new(tokenizer),
        classifier_weight,
        classifier_bias,
        device,
    })
}
600
601 /// Score a single (query, document) pair.
602 ///
603 /// Returns a relevance score in `0.0..=1.0`.
604 pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
605 match self {
606 Self::Lexical { .. } => lexical_score(query, title, content),
607 Self::Neural {
608 model,
609 tokenizer,
610 classifier_weight,
611 classifier_bias,
612 device,
613 } => {
614 // v0.7.0 #1084 — no mutex acquisition: `Arc<BertModel>`
615 // shared across threads; `BertModel::forward(&self, ...)`
616 // is inference-only and safe to call concurrently.
617 match Self::neural_score(
618 model,
619 tokenizer,
620 classifier_weight,
621 classifier_bias,
622 device,
623 query,
624 title,
625 content,
626 ) {
627 Ok(s) => s,
628 Err(e) => {
629 tracing::warn!(
630 "neural cross-encoder score failed: {e}, using lexical fallback"
631 );
632 lexical_score(query, title, content)
633 }
634 }
635 }
636 }
637 }
638
639 #[allow(clippy::too_many_arguments)]
640 fn neural_score(
641 model: &BertModel,
642 tokenizer: &Tokenizer,
643 classifier_weight: &Tensor,
644 classifier_bias: &Tensor,
645 device: &Device,
646 query: &str,
647 title: &str,
648 content: &str,
649 ) -> Result<f32> {
650 // Cross-encoder input: "[CLS] query [SEP] title content [SEP]"
651 let document = crate::embeddings::embedding_document(title, content);
652
653 let encoding = tokenizer
654 .encode((query, document.as_str()), true)
655 .map_err(|e| anyhow::anyhow!("cross-encoder tokenization failed: {e}"))?;
656
657 let input_ids = encoding.get_ids();
658 let attention_mask = encoding.get_attention_mask();
659 let token_type_ids = encoding.get_type_ids();
660 let seq_len = input_ids.len();
661
662 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
663 let attention_mask = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
664 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
665
666 // Forward pass through BERT → [1, seq_len, 384]
667 let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
668
669 // Take [CLS] token (first token) → [1, 384]
670 let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
671
672 // Classification head: logit = cls @ weight^T + bias → [1, 1]
673 let logit = cls
674 .matmul(&classifier_weight.t()?)?
675 .broadcast_add(classifier_bias)?;
676
677 // Extract scalar logit and apply sigmoid to get [0, 1] score
678 let logit_val: f32 = logit.squeeze(0)?.squeeze(0)?.to_scalar()?;
679 let score = 1.0 / (1.0 + (-logit_val).exp());
680
681 Ok(score)
682 }
683
684 /// Whether this is a neural cross-encoder.
685 pub fn is_neural(&self) -> bool {
686 matches!(self, Self::Neural { .. })
687 }
688
689 /// v0.7.0 R3-S2 — whether this cross-encoder is a *degraded*
690 /// lexical fallback (i.e., a neural variant was attempted at
691 /// startup or mid-flight and the runtime fell back). `false` for
692 /// `Neural` and for the originally-configured `Lexical` (operator
693 /// opted into keyword-/semantic-tier without cross-encoder
694 /// reranking). The recall response surfaces this distinction as
695 /// `meta.reranker_used = "degraded_lexical"` so clients can
696 /// detect the silent downgrade in-band — closing the G8 closure
697 /// claim that tracing-event-only signalling had overstated.
698 #[must_use]
699 pub fn is_degraded_lexical(&self) -> bool {
700 matches!(self, Self::Lexical { degraded: true })
701 }
702
703 /// Rerank a set of candidates by blending their original scores with
704 /// cross-encoder scores.
705 ///
706 /// **Blend formula:** `final = 0.6 * original + 0.4 * cross_encoder`
707 ///
708 /// **#1597 pool cap:** only the strongest [`RERANK_POOL_MAX`]
709 /// candidates by incoming blended score are cross-encoded; the
710 /// remainder keep their blended scores and rank below the reranked
711 /// head (head sorted by `final_score` descending, tail sorted by
712 /// blended score descending — no candidate is dropped). A pool at
713 /// or under the cap is fully reranked and returned sorted by
714 /// `final_score` descending, as before.
715 ///
716 /// **v0.7.0 L2-8 contract:** the bare `rerank` is the *pre-L2-8*
717 /// behavior — no reflection boost is applied. Daemons that want
718 /// the reflection-aware boost must call
719 /// [`Self::rerank_with_reflection_boost`] (which is what
720 /// [`BatchedReranker`] does by default with
721 /// [`ReflectionBoostConfig::default`]). Keeping the bare method
722 /// boost-free is a deliberate regression-pin discipline: the L2-8
723 /// recall test for `boost = 1.0` uses
724 /// `rerank_with_reflection_boost(.., &ReflectionBoostConfig::disabled())`
725 /// and asserts byte-identical output to `rerank(..)`.
726 pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
727 // #1597 — delegate so the pool cap + batched forward pass live in
728 // exactly one place. `ReflectionBoostConfig::disabled()` yields a
729 // multiplier of exactly 1.0 for every candidate, so the output is
730 // byte-identical to the historical boost-free blend (the L2-8
731 // regression pin below asserts this equivalence directly).
732 self.rerank_with_reflection_boost(query, candidates, &ReflectionBoostConfig::disabled())
733 }
734
735 /// v0.7.0 L2-8 — rerank with a post-step reflection-aware boost.
736 ///
737 /// 1. Same blend as [`Self::rerank`] (`0.6 * original + 0.4 * ce`).
738 /// 2. **After** the blend, multiply each candidate's `final_score`
739 /// by [`ReflectionBoostConfig::factor_for`]. Observations get a
740 /// multiplier of `1.0` (unchanged); reflections get
741 /// `boost * (1.0 + per_depth_increment * clamp(depth, 0..=cap))`.
742 /// 3. Sort descending after the boost so the output ordering
743 /// reflects the post-boost ranking.
744 ///
745 /// Operationally this means: a reflection that the cross-encoder
746 /// scored at parity with its source observations *moves up*; the
747 /// movement is bounded (capped per-depth multiplier, single global
748 /// `boost` factor) so a mediocre reflection cannot leapfrog a
749 /// well-matched observation — the boost is a thumb-on-the-scale,
750 /// not a free pass.
751 /// **#1597 pool cap + batched forward pass.** Only the strongest
752 /// [`RERANK_POOL_MAX`] candidates by incoming blended score receive a
753 /// cross-encoder score (in one batched forward pass on the Neural
754 /// variant); the remainder keep their blended scores, internally
755 /// sorted descending, appended after the reranked head. No candidate
756 /// is ever dropped. A pool at or under the cap degenerates to the
757 /// historical full rerank.
758 pub fn rerank_with_reflection_boost(
759 &self,
760 query: &str,
761 candidates: Vec<(Memory, f64)>,
762 boost_config: &ReflectionBoostConfig,
763 ) -> Vec<(Memory, f64)> {
764 let (head, tail) = split_rerank_pool(candidates);
765
766 let ce_scores = self.pair_scores(query, &head);
767 let mut scored: Vec<(Memory, f64)> = head
768 .into_iter()
769 .zip(ce_scores)
770 .map(|((mem, original_score), ce_score)| {
771 let blended =
772 ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * f64::from(ce_score);
773 let factor = boost_config.factor_for(&mem);
774 (mem, finite_or_floor(blended * factor))
775 })
776 .collect();
777
778 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779 // #1597 — uncapped remainder: blended scores untouched, already
780 // sorted descending by `split_rerank_pool`, ranked below the
781 // cross-encoded head.
782 scored.extend(tail);
783 scored
784 }
785
786 /// #1597 — cross-encoder scores for an (already capped) candidate
787 /// slice, one score per candidate in input order.
788 ///
789 /// Neural variant: ONE batched tokenize + forward pass via
790 /// [`Self::neural_score_pairs`] (the same machinery the G9
791 /// [`Self::rerank_batch`] path uses) instead of a sequential
792 /// per-pair forward — the second half of the #1597 fix. Falls back
793 /// to per-pair lexical scoring if the batched forward fails.
794 fn pair_scores(&self, query: &str, candidates: &[(Memory, f64)]) -> Vec<f32> {
795 let lexical_fallback = |candidates: &[(Memory, f64)]| -> Vec<f32> {
796 candidates
797 .iter()
798 .map(|(mem, _)| lexical_score(query, &mem.title, &mem.content))
799 .collect()
800 };
801 match self {
802 Self::Lexical { .. } => lexical_fallback(candidates),
803 Self::Neural {
804 model,
805 tokenizer,
806 classifier_weight,
807 classifier_bias,
808 device,
809 } => {
810 let pairs: Vec<(&str, String)> = candidates
811 .iter()
812 .map(|(mem, _)| {
813 (
814 query,
815 crate::embeddings::embedding_document(&mem.title, &mem.content),
816 )
817 })
818 .collect();
819 match Self::neural_score_pairs(
820 model,
821 tokenizer,
822 classifier_weight,
823 classifier_bias,
824 device,
825 pairs,
826 ) {
827 Ok(scores) => scores,
828 Err(e) => {
829 tracing::warn!(
830 "neural cross-encoder batch score failed: {e}, using lexical fallback"
831 );
832 lexical_fallback(candidates)
833 }
834 }
835 }
836 }
837 }
838
839 /// v0.7 G9 — batched rerank for concurrent recall.
840 ///
841 /// Process all `(query, candidates)` jobs in a single tokenize + single
842 /// forward pass on the Neural variant, holding the BERT mutex once for
843 /// the whole batch instead of once per (query, candidate) pair.
844 ///
845 /// **Throughput target**: ~3× for parallel recall vs. per-query
846 /// `rerank()` calls.
847 ///
848 /// Output ordering: `result[i]` corresponds to `queries[i]`. Each
849 /// inner vector is sorted by descending blended score, identical to
850 /// `rerank()`. Lexical variant delegates per-query (no batching win
851 /// since lexical scoring is already CPU-trivial).
852 pub fn rerank_batch(
853 &self,
854 queries: Vec<(String, Vec<(Memory, f64)>)>,
855 ) -> Vec<Vec<(Memory, f64)>> {
856 // Boost-free legacy entry point — preserves the pre-L2-8 wire
857 // shape for callers that haven't migrated to the boost-aware
858 // variant. See `rerank_batch_with_reflection_boost` for the
859 // L2-8 path; here we delegate to it with the `disabled()`
860 // config so the implementation lives in one place.
861 self.rerank_batch_with_reflection_boost(queries, &ReflectionBoostConfig::disabled())
862 }
863
864 /// v0.7.0 L2-8 — batched rerank with a post-step reflection-aware
865 /// boost applied per candidate. Same boost arithmetic as
866 /// [`Self::rerank_with_reflection_boost`], factored so the boost
867 /// shape lives in a single helper.
868 pub fn rerank_batch_with_reflection_boost(
869 &self,
870 queries: Vec<(String, Vec<(Memory, f64)>)>,
871 boost_config: &ReflectionBoostConfig,
872 ) -> Vec<Vec<(Memory, f64)>> {
873 // Single-query short-circuit: avoid any batching overhead.
874 if queries.len() == 1 {
875 let mut iter = queries.into_iter();
876 let (q, cands) = iter.next().expect("len == 1");
877 return vec![self.rerank_with_reflection_boost(&q, cands, boost_config)];
878 }
879
880 match self {
881 Self::Lexical { .. } => queries
882 .into_iter()
883 .map(|(q, cands)| self.rerank_with_reflection_boost(&q, cands, boost_config))
884 .collect(),
885 Self::Neural {
886 model,
887 tokenizer,
888 classifier_weight,
889 classifier_bias,
890 device,
891 } => {
892 // #1597 — apply the per-query pool cap BEFORE the batched
893 // forward pass so a coalesced flush pays for at most
894 // `RERANK_POOL_MAX` forwards per job; each tail is
895 // reattached below its reranked head afterwards.
896 let mut tails: Vec<Vec<(Memory, f64)>> = Vec::with_capacity(queries.len());
897 let queries: Vec<(String, Vec<(Memory, f64)>)> = queries
898 .into_iter()
899 .map(|(q, cands)| {
900 let (head, tail) = split_rerank_pool(cands);
901 tails.push(tail);
902 (q, head)
903 })
904 .collect();
905 // v0.7.0 #1084 — no mutex acquisition: `Arc<BertModel>`
906 // shared across threads; `BertModel::forward(&self, ...)`
907 // is inference-only and safe to call concurrently. The
908 // pre-#1084 poisoned-lock fallback is now unreachable
909 // (no lock to poison); a runtime error in
910 // `neural_rerank_batch` still falls through to the
911 // lexical degrade via the `Err(_)` arm below.
912 match Self::neural_rerank_batch(
913 model,
914 tokenizer,
915 classifier_weight,
916 classifier_bias,
917 device,
918 &queries,
919 ) {
920 Ok(scores) => {
921 // scores is a flat Vec<f32>, one per (query_idx,
922 // candidate_idx) in row-major order matching
923 // queries.iter().flat_map(|(_, cs)| cs).
924 let mut out = Vec::with_capacity(queries.len());
925 let mut cursor = 0usize;
926 for ((_query, cands), tail) in queries.into_iter().zip(tails) {
927 let n = cands.len();
928 let mut scored: Vec<(Memory, f64)> = cands
929 .into_iter()
930 .enumerate()
931 .map(|(i, (mem, original))| {
932 let ce = f64::from(scores[cursor + i]);
933 let blended =
934 ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * ce;
935 let factor = boost_config.factor_for(&mem);
936 (mem, finite_or_floor(blended * factor))
937 })
938 .collect();
939 cursor += n;
940 scored.sort_by(|a, b| {
941 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
942 });
943 // #1597 — uncapped remainder ranks below the
944 // cross-encoded head, blended scores untouched.
945 scored.extend(tail);
946 out.push(scored);
947 }
948 out
949 }
950 Err(e) => {
951 tracing::warn!(
952 "neural rerank_batch failed: {e}, falling back to lexical per-query"
953 );
954 queries
955 .into_iter()
956 .zip(tails)
957 .map(|((q, cands), tail)| {
958 // Runtime degrade (forward-pass failure) —
959 // mark the variant degraded so the recall
960 // response can surface `degraded_lexical`.
961 let lex = Self::Lexical { degraded: true };
962 let mut scored =
963 lex.rerank_with_reflection_boost(&q, cands, boost_config);
964 scored.extend(tail);
965 scored
966 })
967 .collect()
968 }
969 }
970 }
971 }
972 }
973
974 /// One tokenize + one forward pass over a flat batch of (query, doc)
975 /// pairs. Returns a flat `Vec<f32>` of sigmoided logits in the same
976 /// row-major order the candidates appear in `queries`.
977 fn neural_rerank_batch(
978 model: &BertModel,
979 tokenizer: &Tokenizer,
980 classifier_weight: &Tensor,
981 classifier_bias: &Tensor,
982 device: &Device,
983 queries: &[(String, Vec<(Memory, f64)>)],
984 ) -> Result<Vec<f32>> {
985 // Build the flat (query, document) pair list.
986 let mut pairs: Vec<(&str, String)> = Vec::new();
987 for (q, cands) in queries {
988 for (mem, _) in cands {
989 let document = crate::embeddings::embedding_document(&mem.title, &mem.content);
990 pairs.push((q.as_str(), document));
991 }
992 }
993 Self::neural_score_pairs(
994 model,
995 tokenizer,
996 classifier_weight,
997 classifier_bias,
998 device,
999 pairs,
1000 )
1001 }
1002
1003 /// One tokenize + one forward pass over a flat list of
1004 /// (query, document) pairs — the shared batched-inference chokepoint
1005 /// (#1597) used by BOTH the G9 multi-query [`Self::neural_rerank_batch`]
1006 /// path and the per-call [`Self::pair_scores`] path. Returns one
1007 /// sigmoided logit per pair, in input order.
1008 fn neural_score_pairs(
1009 model: &BertModel,
1010 tokenizer: &Tokenizer,
1011 classifier_weight: &Tensor,
1012 classifier_bias: &Tensor,
1013 device: &Device,
1014 pairs: Vec<(&str, String)>,
1015 ) -> Result<Vec<f32>> {
1016 if pairs.is_empty() {
1017 return Ok(Vec::new());
1018 }
1019
1020 // Variable-length pairs require padding for a single forward pass.
1021 // Clone the tokenizer so we can mutate padding settings without
1022 // racing other threads on the shared `Arc<Tokenizer>`.
1023 let mut batch_tokenizer = tokenizer.clone();
1024 let padding = tokenizers::PaddingParams {
1025 strategy: tokenizers::PaddingStrategy::BatchLongest,
1026 direction: tokenizers::PaddingDirection::Right,
1027 pad_id: 0,
1028 pad_type_id: 0,
1029 pad_token: "[PAD]".to_string(),
1030 ..Default::default()
1031 };
1032 batch_tokenizer.with_padding(Some(padding));
1033 // #1604 — rerank inputs truncate at the resolved rerank cap
1034 // (default RERANK_MAX_SEQ_DEFAULT), tighter than the
1035 // architecture-ceiling CROSS_ENCODER_MAX_SEQ the shared
1036 // tokenizer carries: long-content rows otherwise pad the whole
1037 // batch to 512 tokens and the candle CPU forward dominates
1038 // recall latency (~3.2 s/recall measured on the #1588 re-run).
1039 let truncation = tokenizers::TruncationParams {
1040 max_length: rerank_max_seq(),
1041 ..Default::default()
1042 };
1043 batch_tokenizer
1044 .with_truncation(Some(truncation))
1045 .map_err(|e| anyhow::anyhow!("failed to set rerank truncation: {e}"))?;
1046
1047 let encodings = batch_tokenizer
1048 .encode_batch(
1049 pairs
1050 .into_iter()
1051 .map(|(q, d)| tokenizers::EncodeInput::Dual(q.into(), d.into()))
1052 .collect::<Vec<_>>(),
1053 true,
1054 )
1055 .map_err(|e| anyhow::anyhow!("cross-encoder batch tokenization failed: {e}"))?;
1056
1057 let batch_size = encodings.len();
1058 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
1059
1060 let mut input_ids: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
1061 let mut attn_mask: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
1062 let mut token_types: Vec<u32> = Vec::with_capacity(batch_size * seq_len);
1063 for enc in &encodings {
1064 input_ids.extend_from_slice(enc.get_ids());
1065 attn_mask.extend_from_slice(enc.get_attention_mask());
1066 token_types.extend_from_slice(enc.get_type_ids());
1067 }
1068
1069 let input_ids = Tensor::from_vec(input_ids, (batch_size, seq_len), device)?;
1070 let attention_mask = Tensor::from_vec(attn_mask, (batch_size, seq_len), device)?;
1071 let token_type_ids = Tensor::from_vec(token_types, (batch_size, seq_len), device)?;
1072
1073 // Forward pass → [batch, seq, 384]
1074 let hidden = model.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
1075
1076 // [CLS] token per row → [batch, 384]
1077 let cls = hidden.narrow(1, 0, 1)?.squeeze(1)?;
1078
1079 // Classification head per row → [batch, 1]
1080 let logits = cls
1081 .matmul(&classifier_weight.t()?)?
1082 .broadcast_add(classifier_bias)?;
1083
1084 let logits_vec: Vec<f32> = logits.squeeze(1)?.to_vec1()?;
1085 Ok(logits_vec
1086 .into_iter()
1087 .map(|l| 1.0 / (1.0 + (-l).exp()))
1088 .collect())
1089 }
1090}
1091
1092impl Default for CrossEncoder {
1093 fn default() -> Self {
1094 Self::new()
1095 }
1096}
1097
1098// ---------------------------------------------------------------------------
1099// Lexical cross-encoder (original implementation)
1100// ---------------------------------------------------------------------------
1101
1102fn lexical_score(query: &str, title: &str, content: &str) -> f32 {
1103 let query_terms = tokenize(query);
1104 if query_terms.is_empty() {
1105 return 0.0;
1106 }
1107
1108 let title_terms = tokenize(title);
1109 let content_terms = tokenize(content);
1110
1111 let doc_terms: HashSet<&str> = title_terms
1112 .iter()
1113 .chain(content_terms.iter())
1114 .copied()
1115 .collect();
1116 let query_set: HashSet<&str> = query_terms.iter().copied().collect();
1117
1118 // 1. Jaccard term overlap
1119 #[allow(clippy::cast_precision_loss)]
1120 let intersection = query_set.intersection(&doc_terms).count() as f32;
1121 #[allow(clippy::cast_precision_loss)]
1122 let union = query_set.union(&doc_terms).count() as f32;
1123 let jaccard = if union > 0.0 {
1124 intersection / union
1125 } else {
1126 0.0
1127 };
1128
1129 // 2. TF-IDF-like term weighting
1130 let doc_all: Vec<&str> = title_terms
1131 .iter()
1132 .chain(content_terms.iter())
1133 .copied()
1134 .collect();
1135 let tf_idf = tfidf_score(&query_terms, &doc_all);
1136
1137 // 3. Bigram overlap bonus
1138 let query_bigrams = bigrams(&query_terms);
1139 let doc_bigrams = bigrams(&doc_all);
1140 let bigram_overlap = if query_bigrams.is_empty() {
1141 0.0
1142 } else {
1143 let doc_bigram_set: HashSet<(&str, &str)> = doc_bigrams.into_iter().collect();
1144 #[allow(clippy::cast_precision_loss)]
1145 let hits = query_bigrams
1146 .iter()
1147 .filter(|b| doc_bigram_set.contains(b))
1148 .count() as f32;
1149 #[allow(clippy::cast_precision_loss)]
1150 let query_bigrams_len = query_bigrams.len() as f32;
1151 hits / query_bigrams_len
1152 };
1153
1154 // 4. Title match bonus
1155 let title_set: HashSet<&str> = title_terms.iter().copied().collect();
1156 #[allow(clippy::cast_precision_loss)]
1157 let title_hits = query_set.intersection(&title_set).count() as f32;
1158 #[allow(clippy::cast_precision_loss)]
1159 let title_bonus = if query_set.is_empty() {
1160 0.0
1161 } else {
1162 title_hits / query_set.len() as f32
1163 };
1164
1165 let raw = 0.30 * jaccard + 0.30 * tf_idf + 0.20 * bigram_overlap + 0.20 * title_bonus;
1166 raw.clamp(0.0, 1.0)
1167}
1168
1169// ---------------------------------------------------------------------------
1170// Internal helpers
1171// ---------------------------------------------------------------------------
1172
/// Split `text` into tokens: maximal runs of alphanumeric characters
/// and apostrophes. Every other character is a separator, and empty
/// pieces between adjacent separators are discarded. Case is preserved;
/// the tokens borrow from `text`.
fn tokenize(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || ch == '\'');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
1178
/// TF-IDF-like score of `query_terms` against `doc_tokens`, in
/// `[0.0, 1.0]`.
///
/// Matching is case-insensitive. For every query term (duplicates
/// included) that occurs in the document, the contribution is
/// `tf_norm * idf`, where
///
/// - `tf_norm` = occurrences of the term (all case variants) divided by
///   the total number of document tokens;
/// - `idf` = `ln(unique / (1 + variants)) + 1`, with `unique` the
///   number of distinct case-sensitive document tokens and `variants`
///   the number of distinct case-sensitive spellings of the term.
///
/// The sum is divided by the number of query terms and clamped to
/// `[0.0, 1.0]`. Returns `0.0` if either slice is empty.
///
/// The document terms are case-folded once up front, so each query term
/// costs one hash lookup instead of a scan that re-lowercases every
/// document term.
fn tfidf_score(query_terms: &[&str], doc_tokens: &[&str]) -> f32 {
    if doc_tokens.is_empty() || query_terms.is_empty() {
        return 0.0;
    }

    // Case-sensitive term frequencies; the number of keys is `unique`.
    let mut tf_map: HashMap<&str, usize> = HashMap::new();
    for &tok in doc_tokens {
        *tf_map.entry(tok).or_insert(0) += 1;
    }

    // Casts: token counts are far below f32's exact-integer range.
    #[allow(clippy::cast_precision_loss)]
    let total = doc_tokens.len() as f32;
    #[allow(clippy::cast_precision_loss)]
    let unique = tf_map.len() as f32;

    // Case-folded view: lowercase term -> (total occurrences, number of
    // distinct case-sensitive spellings).
    let mut folded: HashMap<String, (usize, usize)> = HashMap::with_capacity(tf_map.len());
    for (surface, &count) in &tf_map {
        let entry = folded.entry(surface.to_lowercase()).or_insert((0, 0));
        entry.0 += count;
        entry.1 += 1;
    }

    let mut score_sum: f32 = 0.0;
    for term in query_terms {
        // A term absent from the document contributes nothing.
        if let Some(&(occurrences, variants)) = folded.get(&term.to_lowercase()) {
            #[allow(clippy::cast_precision_loss)]
            let tf_norm = occurrences as f32 / total;
            #[allow(clippy::cast_precision_loss)]
            let doc_freq = variants as f32;
            let idf = (unique / (1.0 + doc_freq)).ln() + 1.0;
            score_sum += tf_norm * idf;
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let max_possible = query_terms.len() as f32;
    (score_sum / max_possible).clamp(0.0, 1.0)
}
1221
/// Adjacent token pairs of `tokens`, in order: `(t0, t1), (t1, t2), …`.
/// Fewer than two tokens yield an empty vector.
fn bigrams<'a>(tokens: &'a [&str]) -> Vec<(&'a str, &'a str)> {
    tokens
        .iter()
        .zip(tokens.iter().skip(1))
        .map(|(&left, &right)| (left, right))
        .collect()
}
1225
1226// ---------------------------------------------------------------------------
1227// v0.7 G9 — concurrent rerank coalescer
1228// ---------------------------------------------------------------------------
1229
/// Default upper bound on how many rerank jobs the coalescer worker
/// buffers into a single batch (one BERT call) before flushing.
pub const DEFAULT_MAX_BATCH: usize = 32;

/// Default flush latency (ms) — how long the worker keeps waiting for
/// more requests, measured from the arrival of a batch's first job,
/// before it processes a non-full batch. 5ms keeps single-request
/// latency negligible while still benefiting parallel callers.
pub const DEFAULT_MAX_WAIT_MS: u64 = 5;

/// #1579 B10 — minimum number of in-flight rerank requests (including
/// the current one) before [`BatchedReranker::rerank`] routes through
/// the coalescing worker on a *neural* encoder. Below this threshold
/// there is nothing to coalesce WITH: the lone caller pays the worker
/// channel round-trip plus up to [`DEFAULT_MAX_WAIT_MS`] of flush-window
/// wait for zero amortisation gain.
///
/// **Criterion evidence (perf-audit P1, 2026-06, `cargo bench --bench
/// reranker_throughput`, lexical default):** at N=8 concurrent queries
/// × 10 candidates the batched path measured ~7.6 ms vs ~0.65 ms direct
/// — 12× SLOWER, because the per-batch flush window (5 ms) dwarfs the
/// sub-millisecond lexical compute. The lexical variant therefore NEVER
/// routes through the worker (it runs no shared model, so coalescing
/// has nothing to amortise at ANY N — see
/// [`BatchedReranker::rerank`]); the neural variant keeps the batched
/// path at concurrency ≥ this threshold, where the G9 measurement
/// showed ~3× throughput gain from serving a whole batch with one BERT
/// call instead of one per (query, candidate).
///
/// NOTE(review): the G9 figure predates #1084; per the comment in
/// `CrossEncoder::rerank_batch_with_reflection_boost` the neural path
/// no longer takes a model mutex, so re-confirm that `2` is still the
/// right crossover.
pub const BATCHED_RERANK_MIN_CONCURRENCY: usize = 2;
1257
1258/// #1579 B10 — the auto-select predicate, extracted as a free function
1259/// so the threshold arithmetic is unit-testable without standing up a
1260/// worker thread or downloading model weights. `true` ⇒ route through
1261/// the coalescing worker; `false` ⇒ direct encoder call.
1262#[must_use]
1263pub const fn use_batched_rerank_path(encoder_is_neural: bool, inflight_now: usize) -> bool {
1264 encoder_is_neural && inflight_now >= BATCHED_RERANK_MIN_CONCURRENCY
1265}
1266
/// Job submitted to the coalescer worker: one rerank request together
/// with the channel on which its result is handed back.
struct RerankJob {
    /// Query text the candidates are scored against.
    query: String,
    /// Candidate pool as `(memory, incoming score)` pairs.
    candidates: Vec<(Memory, f64)>,
    /// Reply channel for the reranked candidates.
    /// NOTE(review): the sending side lives in `process_batch`, which is
    /// outside this view — confirm how a dropped receiver is handled.
    reply: std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>,
}
1273
/// Concurrent rerank coalescer.
///
/// Wraps a `CrossEncoder` and funnels concurrent recall reranks through
/// a single worker thread. After taking the first job of a batch the
/// worker keeps collecting jobs until `max_batch` are buffered or
/// `max_wait_ms` has elapsed (whichever first), then hands the whole
/// batch to `process_batch` — the throughput fix mandated by G9 (one
/// batched encoder call per flush instead of one call per request).
///
/// NOTE(review): earlier revisions of this comment described a Mutex
/// around the BERT model being held for the whole batch. Per the #1084
/// comment in `CrossEncoder::rerank_batch_with_reflection_boost` the
/// neural path no longer takes a model mutex.
///
/// **Single-request latency**: a job that reaches the worker alone
/// still waits out the coalescing window (up to `max_wait_ms`) before
/// its batch is flushed. [`Self::rerank`] documents an auto-select
/// (#1579 B10) that bypasses the worker for lone callers and lexical
/// encoders so they do not pay that wait.
pub struct BatchedReranker {
    /// Producer side of the work channel the worker receives
    /// [`RerankJob`]s on.
    sender: Option<Sender<RerankJob>>,
    /// H2 (v0.7.0 round-2) — explicit one-shot shutdown signal. Before
    /// each wait for the first job of a batch the worker does a
    /// non-blocking check of this channel and exits its loop on a
    /// message or on disconnect; while idle it repeats that check on
    /// every poll timeout of the work channel. The worker therefore
    /// stops deterministically even if some holder of a work-channel
    /// `Sender` clone outlives `Drop` (e.g. a test harness that stashed
    /// one).
    /// NOTE(review): the `Drop` impl is outside this view; per the H2
    /// design it signals shutdown BEFORE dropping `sender` — confirm.
    shutdown: Option<std::sync::mpsc::Sender<()>>,
    /// Join handle of the coalescer worker thread.
    worker: Option<JoinHandle<()>>,
    /// Direct handle to the underlying encoder, used for the single-query
    /// short-circuit and for callers that explicitly want non-batched
    /// behavior (tests, benchmarks). The worker thread holds its own
    /// `Arc` clone of the same encoder.
    encoder: Arc<CrossEncoder>,
    /// v0.7.0 L2-8 — reflection-aware boost config; a copy is moved
    /// into the worker and passed to every batch it processes. The
    /// public constructors default it to
    /// [`ReflectionBoostConfig::default`] (boost = 1.2) so the daemon
    /// flow ships the boost; [`Self::with_reflection_boost`] builds a
    /// wrapper with an explicit config.
    reflection_boost: ReflectionBoostConfig,
    /// v0.7.0 #1319 — opt-in noise floor applied AFTER the blend
    /// (`0.6 * original + 0.4 * ce_score`) and AFTER the reflection
    /// boost. Default is [`RerankerScoreFloor::Off`] so existing
    /// callers see byte-identical output to pre-#1319. Operators that
    /// observed the cross-encoder false-positive ordering on
    /// disjoint-vocab paraphrase queries (the v1 P5 probe — an Apollo
    /// 11 row at 0.479 surfacing above a substantively-relevant hit at
    /// 0.363) opt in via [`Self::with_score_floor`] to drop the
    /// low-confidence tail.
    score_floor: RerankerScoreFloor,
    /// #1579 B10 — number of rerank requests currently inside
    /// [`Self::rerank`] (incremented on entry, decremented on exit).
    /// Drives the auto-select between the direct encoder call and the
    /// coalescing worker; see [`use_batched_rerank_path`].
    inflight: std::sync::atomic::AtomicUsize,
    /// #1579 B10 — observability counter: how many jobs this wrapper
    /// has submitted to the coalescing worker over its lifetime. The
    /// auto-select regression tests pin "lexical / lone-caller traffic
    /// never reaches the worker" on this counter.
    worker_submissions: std::sync::atomic::AtomicUsize,
}
1331
/// v0.7.0 #1319 — post-blend score floor applied by [`BatchedReranker`].
///
/// **Default is [`Self::Off`]** — every existing caller observes
/// byte-identical pre-#1319 output. Operators who hit the
/// paraphrase / disjoint-vocab noise band turn it on via
/// [`BatchedReranker::with_score_floor`] (constructor knob) or
/// through the resolver-side `[reranker].score_floor*` config fields
/// once they land.
///
/// **Why two shapes.** [`Self::Absolute`] is the literal "drop
/// anything below 0.5" handle the recall caller's documentation
/// suggests. [`Self::RelativeToTop`] scales the cutoff with the best
/// score — useful when the corpus is small and the operator just wants
/// a "tail cleaner" rather than a fixed bar.
///
/// **The first (top-ranked) row is always kept, for both shapes** — a
/// 3-row recall never returns zero results just because every row
/// scored `0.42`.
///
/// Both variants compare the **final blended score** (after the L2-8
/// reflection boost), not the raw cross-encoder logit, so the floor
/// is comparable to the values an operator reads off `recall.memories[].score`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RerankerScoreFloor {
    /// No floor — pre-#1319 behavior. Every blended candidate is kept,
    /// regardless of score. Default.
    Off,
    /// Drop every candidate after the first whose final blended score
    /// falls strictly below the supplied absolute value. The value is
    /// clamped at runtime to `[0.0, 1.0]`.
    Absolute(f64),
    /// Drop every candidate after the first whose final blended score
    /// falls strictly below `top_score * ratio` (where `top_score` is
    /// the first row after sorting). The ratio is clamped at runtime to
    /// `[0.0, 1.0]`. The top row itself is never dropped — operators
    /// get at least one result even when the entire ranked set is in
    /// the noise band.
    RelativeToTop(f64),
}
1367
1368impl Default for RerankerScoreFloor {
1369 fn default() -> Self {
1370 Self::Off
1371 }
1372}
1373
1374impl RerankerScoreFloor {
1375 /// Apply the floor in-place to a pre-sorted (descending) vector
1376 /// of `(Memory, blended_score)` candidates. The implementation is
1377 /// extracted as a free helper so unit tests can pin the cutoff
1378 /// arithmetic without spinning up a [`BatchedReranker`].
1379 ///
1380 /// The top row is always preserved (so a tiny corpus never
1381 /// returns zero results) — see [`RerankerScoreFloor::RelativeToTop`]
1382 /// documentation for the rationale.
1383 fn apply(&self, scored: &mut Vec<(Memory, f64)>) {
1384 if scored.is_empty() {
1385 return;
1386 }
1387 let cutoff: f64 = match *self {
1388 Self::Off => return,
1389 Self::Absolute(v) => v.clamp(0.0, 1.0),
1390 Self::RelativeToTop(ratio) => {
1391 let top = scored.first().map(|(_, s)| *s).unwrap_or(0.0);
1392 top * ratio.clamp(0.0, 1.0)
1393 }
1394 };
1395 // Walk index-first so we can preserve the top row even when
1396 // its score sits below `cutoff` (small-corpus invariant: the
1397 // floor is a tail cleaner, not a "return nothing" knob).
1398 let mut keep = Vec::with_capacity(scored.len());
1399 for (idx, (_, score)) in scored.iter().enumerate() {
1400 if idx == 0 || *score >= cutoff {
1401 keep.push(idx);
1402 }
1403 }
1404 // `keep` is monotonically increasing; iterate in reverse and
1405 // remove dropped indices so the Vec retains the descending
1406 // sort order from the upstream rerank.
1407 let mut next_keep = keep.iter().rev().copied();
1408 let mut want = next_keep.next();
1409 let mut idx = scored.len();
1410 while idx > 0 {
1411 idx -= 1;
1412 match want {
1413 Some(k) if k == idx => {
1414 want = next_keep.next();
1415 }
1416 _ => {
1417 scored.remove(idx);
1418 }
1419 }
1420 }
1421 }
1422}
1423
1424impl BatchedReranker {
1425 /// Wrap an existing `CrossEncoder` with the default batching parameters
1426 /// (`max_batch = 32`, `max_wait_ms = 5`).
1427 pub fn new(encoder: CrossEncoder) -> Self {
1428 Self::with_params(encoder, DEFAULT_MAX_BATCH, DEFAULT_MAX_WAIT_MS)
1429 }
1430
1431 /// Wrap an existing `CrossEncoder` with custom batching parameters.
1432 pub fn with_params(encoder: CrossEncoder, max_batch: usize, max_wait_ms: u64) -> Self {
1433 Self::with_full_params(
1434 encoder,
1435 max_batch,
1436 max_wait_ms,
1437 ReflectionBoostConfig::default(),
1438 RerankerScoreFloor::Off,
1439 )
1440 }
1441
1442 /// v0.7.0 L2-8 — wrap an existing `CrossEncoder` with a custom
1443 /// reflection-boost config alongside default batching parameters.
1444 /// Used by the recall integration tests to pin specific boost shapes
1445 /// (e.g. `disabled()` for the regression test).
1446 pub fn with_reflection_boost(encoder: CrossEncoder, boost: ReflectionBoostConfig) -> Self {
1447 Self::with_full_params(
1448 encoder,
1449 DEFAULT_MAX_BATCH,
1450 DEFAULT_MAX_WAIT_MS,
1451 boost,
1452 RerankerScoreFloor::Off,
1453 )
1454 }
1455
1456 /// v0.7.0 #1319 — wrap a `CrossEncoder` with a post-blend score
1457 /// floor. The reflection-boost knob is left at the daemon default
1458 /// (`1.2`); use [`Self::with_full_params`] to set both at once.
1459 /// **Default constructors leave the floor `Off`** — flipping it on
1460 /// here is an explicit operator-opt-in.
1461 #[must_use]
1462 pub fn with_score_floor(encoder: CrossEncoder, floor: RerankerScoreFloor) -> Self {
1463 Self::with_full_params(
1464 encoder,
1465 DEFAULT_MAX_BATCH,
1466 DEFAULT_MAX_WAIT_MS,
1467 ReflectionBoostConfig::default(),
1468 floor,
1469 )
1470 }
1471
/// Internal constructor — all knobs visible.
///
/// Puts `encoder` behind an `Arc`, creates the work and shutdown
/// channels, spawns the coalescer worker thread
/// (`ai-memory-reranker-batcher`) and returns the wrapper that owns
/// the sending halves and the join handle.
///
/// Worker loop, per batch:
/// 1. wait for the first job, checking the shutdown channel before
///    every wait (at most `SHUTDOWN_POLL` apart while idle);
/// 2. keep receiving until `max_batch` jobs are buffered or `max_wait_ms`
///    has elapsed since the first job arrived;
/// 3. hand the batch to `process_batch` together with the encoder and
///    the reflection-boost config.
///
/// `score_floor` is only stored on the wrapper; it is not passed to the
/// worker.
///
/// # Panics
///
/// Panics if the OS refuses to spawn the worker thread.
fn with_full_params(
    encoder: CrossEncoder,
    max_batch: usize,
    max_wait_ms: u64,
    reflection_boost: ReflectionBoostConfig,
    score_floor: RerankerScoreFloor,
) -> Self {
    let encoder = Arc::new(encoder);
    let (tx, rx) = std::sync::mpsc::channel::<RerankJob>();
    // H2 (v0.7.0 round-2) — one-shot shutdown channel. The std
    // mpsc channel is used as a "oneshot": we never send more
    // than one value, and the worker exits on the first
    // `try_recv()` success OR on disconnect (dropping the sender
    // half also surfaces as an outcome the worker branches on).
    let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();
    // The worker gets its own handle to the encoder and its own copy
    // of the boost config; the originals are stored on `Self` below.
    let worker_encoder = Arc::clone(&encoder);
    let worker_boost = reflection_boost;
    let max_wait = Duration::from_millis(max_wait_ms);

    let worker = thread::Builder::new()
        .name("ai-memory-reranker-batcher".into())
        .spawn(move || {
            // H2 polling cadence: while waiting for the first job of
            // a batch the worker uses `recv_timeout` so it wakes up
            // periodically to check the shutdown signal. 100ms bounds
            // how long an idle worker takes to notice shutdown (the
            // `test_drop_terminates_worker` test budgets 500ms). The
            // poll only applies while the worker is idle, so it adds
            // nothing to the `max_wait` coalescing window of a batch
            // that is already being filled (no cost to the hot path).
            const SHUTDOWN_POLL: Duration = Duration::from_millis(100);
            'outer: loop {
                // Block until the first job arrives OR the
                // shutdown signal fires OR the sender drops.
                let first = loop {
                    // Cheap non-blocking shutdown check first so a
                    // signal that arrived between iterations is
                    // observed even if the work channel had a job
                    // queued before the signal landed.
                    match shutdown_rx.try_recv() {
                        Ok(()) | Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                            break 'outer;
                        }
                        Err(std::sync::mpsc::TryRecvError::Empty) => {}
                    }
                    match rx.recv_timeout(SHUTDOWN_POLL) {
                        Ok(job) => break job,
                        // Idle: loop back to re-check the shutdown channel.
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
                        // Every work-channel sender is gone: no more jobs.
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                            break 'outer;
                        }
                    }
                };

                let mut batch: Vec<RerankJob> = Vec::with_capacity(max_batch);
                batch.push(first);

                // Coalesce additional jobs that arrive within the
                // window, up to the batch cap. The window is measured
                // from the arrival of the first job; the shutdown
                // channel is not consulted while a batch is filling.
                let deadline = Instant::now() + max_wait;
                while batch.len() < max_batch {
                    let now = Instant::now();
                    if now >= deadline {
                        break;
                    }
                    match rx.recv_timeout(deadline - now) {
                        Ok(j) => batch.push(j),
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => break,
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                            // Drain the current batch then exit.
                            process_batch(&worker_encoder, batch, &worker_boost);
                            break 'outer;
                        }
                    }
                }

                // Window elapsed or batch full: flush, then wait for
                // the next batch's first job.
                process_batch(&worker_encoder, batch, &worker_boost);
            }
        })
        .expect("failed to spawn rerank batcher worker");

    Self {
        sender: Some(tx),
        shutdown: Some(shutdown_tx),
        worker: Some(worker),
        encoder,
        reflection_boost,
        score_floor,
        inflight: std::sync::atomic::AtomicUsize::new(0),
        worker_submissions: std::sync::atomic::AtomicUsize::new(0),
    }
}
1565
1566 /// Submit a single rerank request. Blocks until the result is
1567 /// available.
1568 ///
1569 /// #1579 B10 — **auto-select.** The wrapper keeps BOTH execution
1570 /// paths and picks per call via [`use_batched_rerank_path`]:
1571 ///
1572 /// - **Direct** (no worker round-trip) when the encoder is
1573 /// lexical / degraded-lexical (no shared-model mutex to
1574 /// amortise — criterion proved the coalescing flush window made
1575 /// the batched path 12× slower at N=8: ~7.6 ms vs ~0.65 ms), or
1576 /// when fewer than [`BATCHED_RERANK_MIN_CONCURRENCY`] requests
1577 /// are in flight (nothing to coalesce with).
1578 /// - **Coalesced** (worker thread, one `rerank_batch` per flush)
1579 /// for neural encoders under real concurrency — the G9 win
1580 /// (~3× at N=8 neural) is preserved.
1581 ///
1582 /// If the worker is unavailable for any reason (channel closed),
1583 /// falls back to a direct `rerank` call on the underlying encoder
1584 /// (with the wrapper's configured reflection boost applied).
1585 pub fn rerank(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
1586 let mut scored = self.rerank_unfloored(query, candidates);
1587 // v0.7.0 #1319 — post-blend score floor (default Off; opt-in
1588 // via `with_score_floor`). Applies to the already-sorted
1589 // descending vector returned by the encoder/worker.
1590 self.score_floor.apply(&mut scored);
1591 scored
1592 }
1593
1594 /// #1579 B10 — force the COALESCED (worker) path regardless of the
1595 /// auto-select. Kept public so the throughput bench
1596 /// (`benches/reranker_throughput.rs`) and regression tests can keep
1597 /// measuring the raw batched machinery after `rerank` started
1598 /// auto-selecting away from it at small N. Applies the same
1599 /// post-blend score floor as [`Self::rerank`].
1600 #[must_use]
1601 pub fn rerank_coalesced(
1602 &self,
1603 query: &str,
1604 candidates: Vec<(Memory, f64)>,
1605 ) -> Vec<(Memory, f64)> {
1606 let mut scored = self.rerank_coalesced_unfloored(query, candidates);
1607 self.score_floor.apply(&mut scored);
1608 scored
1609 }
1610
1611 /// Internal — same shape as [`Self::rerank`] but skips the
1612 /// post-blend score floor. Pre-#1319 callsites that explicitly
1613 /// want the raw blended output (regression tests, the byte-equal
1614 /// pin in `g9_batched_reranker_serial_calls_match_rerank`) call
1615 /// this directly.
1616 fn rerank_unfloored(&self, query: &str, candidates: Vec<(Memory, f64)>) -> Vec<(Memory, f64)> {
1617 use std::sync::atomic::Ordering;
1618 // #1579 B10 — RAII in-flight guard so a panicking encoder call
1619 // can't leak the counter and wedge the auto-select high.
1620 struct InflightGuard<'a>(&'a std::sync::atomic::AtomicUsize);
1621 impl Drop for InflightGuard<'_> {
1622 fn drop(&mut self) {
1623 self.0.fetch_sub(1, Ordering::Relaxed);
1624 }
1625 }
1626 let inflight_now = self.inflight.fetch_add(1, Ordering::Relaxed) + 1;
1627 let _guard = InflightGuard(&self.inflight);
1628
1629 if use_batched_rerank_path(self.encoder.is_neural(), inflight_now) {
1630 self.rerank_coalesced_unfloored(query, candidates)
1631 } else {
1632 self.rerank_direct_unfloored(query, candidates)
1633 }
1634 }
1635
1636 /// #1579 B10 — the DIRECT path: one synchronous encoder call on the
1637 /// caller's thread, no worker round-trip, no flush-window wait.
1638 fn rerank_direct_unfloored(
1639 &self,
1640 query: &str,
1641 candidates: Vec<(Memory, f64)>,
1642 ) -> Vec<(Memory, f64)> {
1643 self.encoder
1644 .rerank_with_reflection_boost(query, candidates, &self.reflection_boost)
1645 }
1646
1647 /// The COALESCED path: submit to the worker thread and block for
1648 /// the reply. Concurrent callers are coalesced into a single
1649 /// `rerank_batch` call inside the worker. (Pre-#1579-B10 this was
1650 /// the body of `rerank_unfloored`.)
1651 fn rerank_coalesced_unfloored(
1652 &self,
1653 query: &str,
1654 candidates: Vec<(Memory, f64)>,
1655 ) -> Vec<(Memory, f64)> {
1656 let Some(sender) = self.sender.as_ref() else {
1657 return self.rerank_direct_unfloored(query, candidates);
1658 };
1659 let (reply_tx, reply_rx) = sync_channel::<Vec<(Memory, f64)>>(1);
1660 let job = RerankJob {
1661 query: query.to_string(),
1662 candidates,
1663 reply: reply_tx,
1664 };
1665 if sender.send(job).is_err() {
1666 return self.encoder.rerank_with_reflection_boost(
1667 query,
1668 Vec::new(),
1669 &self.reflection_boost,
1670 );
1671 }
1672 self.worker_submissions
1673 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1674 reply_rx.recv().unwrap_or_else(|_| {
1675 self.encoder
1676 .rerank_with_reflection_boost(query, Vec::new(), &self.reflection_boost)
1677 })
1678 }
1679
1680 /// #1579 B10 — lifetime count of jobs submitted to the coalescing
1681 /// worker. Observability hook for the auto-select regression tests
1682 /// ("lexical traffic never reaches the worker") and operator
1683 /// diagnostics.
1684 #[must_use]
1685 pub fn worker_submissions(&self) -> usize {
1686 self.worker_submissions
1687 .load(std::sync::atomic::Ordering::Relaxed)
1688 }
1689
1690 /// v0.7.0 #1319 — expose the configured score floor for the
1691 /// `memory_capabilities` reporter and for operator-facing
1692 /// diagnostics.
1693 #[must_use]
1694 pub fn score_floor(&self) -> RerankerScoreFloor {
1695 self.score_floor
1696 }
1697
1698 /// v0.7.0 L2-8 — expose the configured boost for the
1699 /// `memory_capabilities` reporter.
1700 #[must_use]
1701 pub fn reflection_boost(&self) -> &ReflectionBoostConfig {
1702 &self.reflection_boost
1703 }
1704
1705 /// Direct access to the wrapped encoder. Useful for callers that
1706 /// want to bypass the coalescer (tests, benchmarks).
1707 pub fn encoder(&self) -> &CrossEncoder {
1708 &self.encoder
1709 }
1710
1711 /// Convenience shortcut for `self.encoder().is_neural()`. Most
1712 /// callers in the recall pipeline only need to check the variant
1713 /// for capability reporting.
1714 pub fn is_neural(&self) -> bool {
1715 self.encoder.is_neural()
1716 }
1717
1718 /// v0.7.0 R3-S2 — shortcut for `self.encoder().is_degraded_lexical()`.
1719 /// The recall path reads this to drive the in-band `reranker_used`
1720 /// signal exposed via `RecallMeta`.
1721 #[must_use]
1722 pub fn is_degraded_lexical(&self) -> bool {
1723 self.encoder.is_degraded_lexical()
1724 }
1725}
1726
1727impl Drop for BatchedReranker {
1728 fn drop(&mut self) {
1729 // H2 (v0.7.0 round-2): two-step termination.
1730 //
1731 // 1. Fire the explicit shutdown signal FIRST so the worker
1732 // observes it even when another holder of `Sender`
1733 // (e.g. a test that cloned the work channel) would
1734 // otherwise keep the work channel alive.
1735 // 2. Then drop the work-channel sender — a worker that was
1736 // blocked in `rx.recv_timeout(...)` wakes up either via
1737 // the shutdown poll OR the disconnect, whichever
1738 // happens first.
1739 //
1740 // Joining the worker after BOTH signals fire bounds shutdown
1741 // by the SHUTDOWN_POLL cadence (100ms) in the absolute worst
1742 // case, well inside the 500ms budget exercised by
1743 // `test_drop_terminates_worker`.
1744 if let Some(shutdown) = self.shutdown.take() {
1745 let _ = shutdown.send(());
1746 }
1747 self.sender.take();
1748 if let Some(handle) = self.worker.take() {
1749 let _ = handle.join();
1750 }
1751 }
1752}
1753
1754fn process_batch(
1755 encoder: &CrossEncoder,
1756 batch: Vec<RerankJob>,
1757 boost_config: &ReflectionBoostConfig,
1758) {
1759 if batch.is_empty() {
1760 return;
1761 }
1762
1763 // Single-request fast path: bypass the batched API to avoid the
1764 // padding overhead and any latency regression on lone callers.
1765 if batch.len() == 1 {
1766 let mut iter = batch.into_iter();
1767 let job = iter.next().expect("len == 1");
1768 let result = encoder.rerank_with_reflection_boost(&job.query, job.candidates, boost_config);
1769 let _ = job.reply.send(result);
1770 return;
1771 }
1772
1773 // Build the input vector for the batched call. Use placeholder
1774 // `Memory` clones via `take` to avoid copying — we move out.
1775 let mut queries: Vec<(String, Vec<(Memory, f64)>)> = Vec::with_capacity(batch.len());
1776 let mut replies: Vec<std::sync::mpsc::SyncSender<Vec<(Memory, f64)>>> =
1777 Vec::with_capacity(batch.len());
1778 for job in batch {
1779 queries.push((job.query, job.candidates));
1780 replies.push(job.reply);
1781 }
1782
1783 let outputs = encoder.rerank_batch_with_reflection_boost(queries, boost_config);
1784 for (out, reply) in outputs.into_iter().zip(replies.into_iter()) {
1785 let _ = reply.send(out);
1786 }
1787}
1788
1789// ---------------------------------------------------------------------------
1790// Tests
1791// ---------------------------------------------------------------------------
1792
1793#[cfg(test)]
1794mod tests {
1795 use super::*;
1796 use crate::models::{Memory, Tier};
1797
/// #1604 — seeding of the process-wide rerank sequence cap: the
/// first [`set_rerank_max_seq`] writer wins, later writes are no-ops.
///
/// The test is order-independent by construction. Other tests in
/// this binary may legitimately seed the process-wide `OnceLock`
/// first (anything that walks the `daemon_runtime` boot ladder
/// does), so only the post-seed immutability contract is asserted:
/// seed (or observe the earlier seed), then prove that a second
/// write cannot change the value. The unseeded-default fallback is
/// pinned by `resolve_reranker_1604_max_seq_ladder` (resolver layer,
/// no OnceLock). The pre-fix form asserted the unseeded default
/// first and was order-dependent — green locally, red under CI's
/// impact-aware test ordering.
#[test]
fn rerank_max_seq_1604_seed_once_semantics() {
    // First write: either seeds the cap or lands behind an earlier seed.
    set_rerank_max_seq(192);
    let settled = rerank_max_seq();
    assert!(
        settled > 0,
        "settled value must be a real cap (ours or an earlier boot seed), got {settled}"
    );

    // Second write with a different value must leave the cap untouched.
    set_rerank_max_seq(64);
    assert_eq!(
        rerank_max_seq(),
        settled,
        "first writer must win — a later set_rerank_max_seq call must be a no-op"
    );
}
1826
1827 fn make_memory(title: &str, content: &str) -> Memory {
1828 Memory {
1829 id: "test-id".to_string(),
1830 tier: Tier::Mid,
1831 namespace: "test".to_string(),
1832 title: title.to_string(),
1833 content: content.to_string(),
1834 tags: vec![],
1835 priority: 5,
1836 confidence: 1.0,
1837 source: "test".to_string(),
1838 access_count: 0,
1839 created_at: "2026-01-01T00:00:00Z".to_string(),
1840 updated_at: "2026-01-01T00:00:00Z".to_string(),
1841 last_accessed_at: None,
1842 expires_at: None,
1843 metadata: serde_json::json!({}),
1844 reflection_depth: 0,
1845 memory_kind: crate::models::MemoryKind::Observation,
1846 entity_id: None,
1847 persona_version: None,
1848 citations: Vec::new(),
1849 source_uri: None,
1850 source_span: None,
1851 confidence_source: crate::models::ConfidenceSource::CallerProvided,
1852 confidence_signals: None,
1853 confidence_decayed_at: None,
1854 version: 1,
1855 }
1856 }
1857
/// #1531 M13 — a NaN original score must not nondeterministically
/// hold the top rank. Before the fix, the blended NaN compared
/// `Equal` to every finite score under
/// `partial_cmp(..).unwrap_or(Equal)`, so the stable sort left the
/// poisoned candidate where it entered (here: first). After the fix,
/// non-finite scores clamp to `f64::MIN` and sink to the bottom.
#[test]
fn nan_scored_candidate_sinks_to_bottom_m13() {
    let ce = CrossEncoder::Lexical { degraded: false };
    // Fresh input per call: the NaN-poisoned candidate goes in first.
    let poisoned_first = || {
        vec![
            (make_memory("poisoned", "irrelevant body"), f64::NAN),
            (
                make_memory("network configuration", "network configuration body"),
                0.9,
            ),
        ]
    };

    let out = ce.rerank("network configuration", poisoned_first());
    assert_eq!(
        out[0].0.title, "network configuration",
        "finite-scored candidate must outrank the NaN-poisoned one"
    );
    assert_eq!(out[1].0.title, "poisoned");
    assert_eq!(
        out[1].1,
        f64::MIN,
        "non-finite blended score must clamp to the ranking floor"
    );

    // Boost-aware path takes the same clamp.
    let boost_off = ReflectionBoostConfig::disabled();
    let out =
        ce.rerank_with_reflection_boost("network configuration", poisoned_first(), &boost_off);
    assert_eq!(out[0].0.title, "network configuration");
    assert_eq!(out[1].1, f64::MIN);
}
1895
/// An empty query scores exactly zero.
#[test]
fn lexical_score_returns_zero_for_empty_query() {
    let score = lexical_score("", "some title", "some content");
    assert_eq!(score, 0.0);
}
1900
/// A query sharing no terms with title or content scores near zero.
#[test]
fn lexical_score_returns_zero_for_no_overlap() {
    let (query, title, content) = ("quantum physics", "grocery list", "milk eggs bread butter");
    let s = lexical_score(query, title, content);
    assert!(s < 0.05, "expected near-zero, got {s}");
}
1906
/// With identical content, a title containing the query terms must
/// score strictly higher than an unrelated title.
#[test]
fn lexical_score_rewards_title_match() {
    let query = "network configuration";
    let content = "This document discusses network configuration for LAN setups.";
    let score_with_title = |title: &str| lexical_score(query, title, content);

    let s_title_match = score_with_title("Network Configuration Guide");
    let s_no_title = score_with_title("Unrelated Title");
    assert!(
        s_title_match > s_no_title,
        "title match ({s_title_match}) should beat no title match ({s_no_title})"
    );
}
1921
/// Heavy query/title/content overlap must still score within [0, 1].
#[test]
fn lexical_score_is_bounded_zero_one() {
    let query = "the quick brown fox jumps over the lazy dog";
    let title = "the quick brown fox";
    let content = "the quick brown fox jumps over the lazy dog and more words";
    let s = lexical_score(query, title, content);
    assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
}
1931
/// The on-topic candidate enters second with the lower original
/// score and must come out first after reranking.
#[test]
fn rerank_reorders_candidates() {
    let on_topic = make_memory("Rust cross-encoder", "cross-encoder reranking for search");
    let off_topic = make_memory("Grocery list", "milk eggs bread butter cheese");
    let ce = CrossEncoder::new();
    let reranked = ce.rerank(
        "cross-encoder reranking",
        vec![(off_topic, 0.55), (on_topic, 0.45)],
    );
    assert_eq!(reranked[0].0.title, "Rust cross-encoder");
}
1941
/// Reranking three candidates returns three candidates.
#[test]
fn rerank_preserves_candidate_count() {
    let ce = CrossEncoder::new();
    let mut candidates: Vec<(Memory, f64)> = Vec::with_capacity(3);
    candidates.push((make_memory("A", "alpha"), 0.5));
    candidates.push((make_memory("B", "beta"), 0.6));
    candidates.push((make_memory("C", "gamma"), 0.7));
    let reranked = ce.rerank("alpha", candidates);
    assert_eq!(reranked.len(), 3);
}
1953
/// Content holding the query terms as an adjacent phrase must score
/// higher than content holding the same terms scattered.
#[test]
fn bigram_overlap_boosts_phrase_match() {
    let score_content = |content: &str| lexical_score("network adapter", "title", content);

    let s_phrase = score_content("the network adapter is connected to the LAN");
    let s_scattered = score_content("the adapter handles the network traffic independently");
    assert!(
        s_phrase > s_scattered,
        "phrase match ({s_phrase}) should beat scattered ({s_scattered})"
    );
}
1971
1972 // -----------------------------------------------------------------
1973 // W11/S11b — input-count invariants for the rerank() API
1974 // -----------------------------------------------------------------
1975
/// Five distinct candidates with varied original scores: the
/// heuristic rerank must return all five, sorted descending.
#[test]
fn test_rerank_preserves_input_count_heuristic() {
    let ce = CrossEncoder::new();
    let mut candidates: Vec<(Memory, f64)> = Vec::with_capacity(5);
    for i in 0..5 {
        let memory = make_memory(
            &format!("title {i}"),
            &format!("content body number {i} with some words"),
        );
        candidates.push((memory, f64::from(i) * 0.1));
    }

    let reranked = ce.rerank("title content body", candidates);

    // Title/score summary used only in the failure message.
    let summary: Vec<_> = reranked.iter().map(|(m, s)| (&m.title, *s)).collect();
    assert_eq!(
        reranked.len(),
        5,
        "heuristic rerank must preserve candidate count, got {} = {:?}",
        reranked.len(),
        summary
    );

    // Sorted descending by final score (rerank contract).
    reranked.windows(2).for_each(|w| {
        assert!(
            w[0].1 >= w[1].1,
            "rerank output must be descending by score: {} < {}",
            w[0].1,
            w[1].1
        );
    });
}
2013
/// Reranking an empty candidate list yields an empty list.
#[test]
fn test_rerank_zero_candidates_returns_empty_heuristic() {
    let no_candidates: Vec<(Memory, f64)> = Vec::new();
    let reranked = CrossEncoder::new().rerank("query", no_candidates);
    assert!(reranked.is_empty());
}
2020
// Neural variant: feature-gated so a plain test run does not pull
// the 80MB BERT weights. Run with `--features test-with-models` once
// the cross-encoder feature exists upstream.
#[cfg(feature = "test-with-models")]
#[test]
fn test_rerank_preserves_input_count_neural_if_available() {
    let ce = CrossEncoder::new_neural();
    let mut candidates: Vec<(Memory, f64)> = Vec::with_capacity(5);
    for i in 0..5 {
        candidates.push((make_memory(&format!("t{i}"), &format!("body {i}")), 0.5));
    }
    let reranked = ce.rerank("body", candidates);
    assert_eq!(reranked.len(), 5);
}
2034
2035 // -----------------------------------------------------------------
2036 // W12-E — heuristic-path branch coverage for reranker.rs
2037 //
2038 // Targets the Lexical variant only. The Neural variant requires
2039 // downloading 80+ MB of BERT weights from HuggingFace Hub and is
2040 // gated behind `feature = "test-with-models"`.
2041 // -----------------------------------------------------------------
2042
/// `Default` must produce the non-neural (lexical) encoder.
#[test]
fn w12e_default_is_lexical() {
    let ce: CrossEncoder = Default::default();
    let neural = ce.is_neural();
    assert!(!neural, "Default::default() must return Lexical");
}
2048
/// `CrossEncoder::new()` must produce a non-neural encoder.
#[test]
fn w12e_new_returns_lexical() {
    let neural = CrossEncoder::new().is_neural();
    assert!(!neural);
}
2054
/// For the Lexical variant, `CrossEncoder::score()` must delegate to
/// `lexical_score()`. Both are computed on the same inputs and must
/// agree to within `f32::EPSILON`.
#[test]
fn w12e_score_dispatch_lexical_matches_helper() {
    let (q, title, content) = (
        "rust async runtime",
        "Tokio: Rust async runtime",
        "Tokio is an async runtime for the Rust programming language.",
    );
    let via_dispatcher = CrossEncoder::new().score(q, title, content);
    let direct = lexical_score(q, title, content);
    let delta = (via_dispatcher - direct).abs();
    assert!(delta < f32::EPSILON);
}
2067
/// Degenerate inputs must neither panic nor leave [0, 1].
#[test]
fn w12e_score_empty_inputs_safe() {
    let ce = CrossEncoder::new();

    // Empty query → 0.0 by short-circuit in lexical_score.
    let s_empty = ce.score("", "title", "content");
    assert_eq!(s_empty, 0.0);

    // Empty title and content with non-empty query — must not panic.
    let s_blank_doc = ce.score("query", "", "");
    assert!((0.0..=1.0).contains(&s_blank_doc));

    // Whitespace-only query is treated as empty after tokenization.
    let s_ws = ce.score(" \t\n", "title", "content");
    assert_eq!(s_ws, 0.0);

    // Punctuation-only query also yields no tokens.
    let s_punct = ce.score("!?.,;:", "title", "content");
    assert_eq!(s_punct, 0.0);
}
2083
/// Scores stay within [0, 1] for mixed-Unicode text and for very
/// long content.
#[test]
fn w12e_lexical_score_is_bounded_for_unicode_and_long() {
    let unit = 0.0..=1.0;

    // Mixed Unicode tokens with apostrophes and accents.
    let s_unicode = lexical_score(
        "café résumé d'oeuvre",
        "Le Café d'Oeuvre",
        "résumé du café avec d'oeuvre noté",
    );
    assert!(
        unit.contains(&s_unicode),
        "unicode score {s_unicode} out of bounds"
    );

    // Very long content stresses the length-normalization branches.
    let huge = "alpha beta gamma delta ".repeat(2_500);
    let s_long = lexical_score("alpha gamma", "headline", &huge);
    assert!(unit.contains(&s_long), "long score {s_long} out of bounds");
}
2105
/// 100% query overlap with both title and content should produce a
/// high score that is still capped at 1.0.
#[test]
fn w12e_lexical_score_perfect_overlap_high() {
    let terms = "alpha beta gamma";
    let s = lexical_score(terms, terms, "alpha beta gamma alpha beta gamma");
    assert!(s > 0.5, "expected high score for perfect overlap, got {s}");
    assert!(s <= 1.0);
}
2118
/// Branch: `doc_tokens.is_empty()` → 0.0 short-circuit.
#[test]
fn w12e_tfidf_score_empty_doc_returns_zero() {
    let query_terms = vec!["alpha", "beta"];
    let empty_doc: Vec<&str> = vec![];
    let score = tfidf_score(&query_terms, &empty_doc);
    assert_eq!(score, 0.0);
}
2126
/// Branch: `query_terms.is_empty()` → 0.0 short-circuit.
#[test]
fn w12e_tfidf_score_empty_query_returns_zero() {
    let empty_query: Vec<&str> = vec![];
    let doc_tokens = vec!["alpha", "beta", "gamma"];
    let score = tfidf_score(&empty_query, &doc_tokens);
    assert_eq!(score, 0.0);
}
2134
/// Query terms entirely absent from the doc → `tf == 0` continue
/// branch; the score is exactly zero.
#[test]
fn w12e_tfidf_score_no_matching_terms() {
    let query_terms = vec!["xenon", "kryptonite"];
    let doc_tokens = vec!["alpha", "beta", "gamma"];
    assert_eq!(tfidf_score(&query_terms, &doc_tokens), 0.0);
}
2143
/// Mixed presence/absence of query terms (clamp branch reachable):
/// the score is positive and within [0, 1].
#[test]
fn w12e_tfidf_score_partial_match_bounded() {
    let query_terms = vec!["alpha", "missing"];
    let doc_tokens = vec!["alpha", "alpha", "beta", "gamma"];
    let s = tfidf_score(&query_terms, &doc_tokens);
    assert!((0.0..=1.0).contains(&s));
    assert!(s > 0.0);
}
2153
/// `bigrams` over zero, one and three tokens.
#[test]
fn w12e_bigrams_empty_and_single_and_multi() {
    // Empty input → empty bigram list.
    let no_tokens: Vec<&str> = vec![];
    assert!(bigrams(&no_tokens).is_empty());

    // Single token → no bigrams (windows(2) yields nothing).
    let single = vec!["solo"];
    assert!(bigrams(&single).is_empty());

    // Multi-token → N-1 bigrams, in order.
    let triple = vec!["a", "b", "c"];
    let pairs = bigrams(&triple);
    assert_eq!(pairs, vec![("a", "b"), ("b", "c")]);
}
2169
/// Tokenizer behaviour around apostrophes, punctuation, empty input
/// and non-ASCII letters.
#[test]
fn w12e_tokenize_handles_apostrophe_and_unicode() {
    // Apostrophes are preserved (e.g., "don't"); other punctuation splits.
    let toks = tokenize("don't stop, I won't!");
    for expected in ["don't", "won't", "stop", "I"] {
        assert!(toks.contains(&expected));
    }

    // Pure punctuation yields no tokens.
    assert!(tokenize("!!!,,,;;;").is_empty());

    // Empty string yields no tokens.
    assert!(tokenize("").is_empty());

    // Unicode alphanumerics survive: two words, two tokens.
    let accented = tokenize("café résumé");
    assert_eq!(accented.len(), 2);
}
2191
/// A lone candidate is returned as-is with a non-negative score.
#[test]
fn w12e_rerank_single_candidate_keeps_it() {
    let candidates = vec![(make_memory("solo title", "solo content body"), 0.42)];
    let out = CrossEncoder::new().rerank("solo", candidates);
    assert_eq!(out.len(), 1);

    let (memory, final_score) = &out[0];
    assert_eq!(memory.title, "solo title");
    // Final score is a blend of original and CE score, both nonneg.
    assert!(*final_score >= 0.0);
}
2202
/// With identical original scores the CE score decides the order:
/// the candidate whose title/content overlaps the query must rank
/// first even though it enters second.
#[test]
fn w12e_rerank_identical_originals_stable_under_score() {
    let candidates = vec![
        (make_memory("grocery", "milk eggs bread"), 0.5),
        (
            make_memory("rust async runtime", "rust async runtime tokio"),
            0.5,
        ),
    ];
    let ce = CrossEncoder::new();
    let out = ce.rerank("rust async", candidates);
    assert_eq!(out.len(), 2);
    assert_eq!(out[0].0.title, "rust async runtime");
}
2218
/// Property-style: whatever the input shape (empty title, empty
/// content, unsorted originals), the output is sorted descending.
#[test]
fn w12e_rerank_descending_invariant_holds_across_shapes() {
    let ce = CrossEncoder::new();
    let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
    cands.push((make_memory("a", "alpha words"), 0.10));
    cands.push((make_memory("b", "beta words"), 0.95));
    cands.push((make_memory("c", "gamma alpha"), 0.55));
    cands.push((make_memory("d", ""), 0.0));
    cands.push((make_memory("", "empty title doc"), 0.30));

    let out = ce.rerank("alpha", cands);
    assert_eq!(out.len(), 5);
    out.windows(2).for_each(|w| {
        assert!(
            w[0].1 >= w[1].1,
            "non-descending pair: {} then {}",
            w[0].1,
            w[1].1
        );
    });
}
2241
/// An empty title means an empty title set, so the title bonus is
/// 0.0; the query set is non-empty, so the else branch
/// (title_hits / |Q|) runs. A matching title must never score lower.
#[test]
fn w12e_lexical_score_no_title_branch_via_empty_title() {
    let score_title = |title: &str| lexical_score("alpha beta", title, "alpha beta gamma");
    let s_empty_title = score_title("");
    let s_with_title = score_title("alpha beta");
    assert!(s_with_title >= s_empty_title);
    assert!((0.0..=1.0).contains(&s_empty_title));
}
2251
/// The title contains every query term while the content has none:
/// the score must be positive and at most 1.0.
#[test]
fn w12e_lexical_score_query_terms_only_in_title() {
    let (query, title, content) = ("rust crate", "Rust Crate Index", "unrelated body text");
    let s = lexical_score(query, title, content);
    assert!(s > 0.0);
    assert!(s <= 1.0);
}
2259
2260 // PR-9i — buffer coverage uplift.
2261
/// Exercises `CrossEncoder::new_neural()`. The outcome depends on
/// the environment: with an HF cache or network access the call
/// yields the Neural variant, otherwise it falls back to Lexical
/// through the documented eprintln + tracing warn pathway. Either
/// outcome is acceptable — the point is that the dispatch is hit
/// and that the resulting score lies within [0.0, 1.0].
#[test]
fn pr9i_new_neural_dual_outcome() {
    let s = CrossEncoder::new_neural().score("query", "title", "content");
    let in_bounds = (0.0..=1.0).contains(&s);
    assert!(in_bounds, "score {s} out of bounds");
}
2275
2276 // -----------------------------------------------------------------
2277 // v0.7 G9 — batched rerank parity + coalescer smoke tests
2278 // -----------------------------------------------------------------
2279
/// Spec: 3 queries × 5 candidates. Batched output must match the
/// per-query `rerank()` output exactly on the deterministic Lexical
/// path. (Neural parity is gated behind `test-with-models`; the
/// implementation is symmetric — same blend, same sort.)
#[test]
fn g9_rerank_batch_matches_per_query_rerank_lexical() {
    let ce = CrossEncoder::new();
    let queries = vec!["alpha gamma", "beta words", "rust async"];

    // For every query build its candidates, record the per-query
    // result, and queue the same candidates for the batched call.
    let (jobs, expected): (
        Vec<(String, Vec<(Memory, f64)>)>,
        Vec<Vec<(Memory, f64)>>,
    ) = queries
        .iter()
        .map(|q| {
            let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
            for i in 0..5 {
                let memory = make_memory(
                    &format!("title-{i}-{q}"),
                    &format!("alpha beta gamma rust async body {i} {q}"),
                );
                cands.push((memory, f64::from(i) * 0.1));
            }
            let per_query = ce.rerank(q, cands.clone());
            (((*q).to_string(), cands), per_query)
        })
        .unzip();

    let batched = ce.rerank_batch(jobs);
    assert_eq!(batched.len(), expected.len());
    for (b, e) in batched.iter().zip(expected.iter()) {
        assert_eq!(b.len(), e.len());
        for (bi, ei) in b.iter().zip(e.iter()) {
            assert_eq!(bi.0.id, ei.0.id);
            assert_eq!(bi.0.title, ei.0.title);
            assert!(
                (bi.1 - ei.1).abs() < 1e-12,
                "blended score mismatch: batched={} per-query={}",
                bi.1,
                ei.1
            );
        }
    }
}
2322
/// A single-query batch must not regress against `rerank()`: it
/// takes the single-query short-circuit and must return the same
/// ids and scores.
#[test]
fn g9_rerank_batch_single_query_short_circuits() {
    let ce = CrossEncoder::new();
    let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
    for i in 0..5 {
        cands.push((make_memory(&format!("t{i}"), &format!("body {i}")), 0.5));
    }

    let direct = ce.rerank("body", cands.clone());
    let batched = ce.rerank_batch(vec![(String::from("body"), cands)]);

    assert_eq!(batched.len(), 1);
    let only = &batched[0];
    assert_eq!(only.len(), direct.len());
    for (a, b) in only.iter().zip(direct.iter()) {
        assert_eq!(a.0.id, b.0.id);
        assert!((a.1 - b.1).abs() < 1e-12);
    }
}
2340
/// Empty batches: no queries at all, and several queries that each
/// carry zero candidates.
#[test]
fn g9_rerank_batch_empty_inputs() {
    let ce = CrossEncoder::new();

    let no_jobs: Vec<(String, Vec<(Memory, f64)>)> = Vec::new();
    assert!(ce.rerank_batch(no_jobs).is_empty());

    // Multi-query but each has zero candidates.
    let hollow_jobs: Vec<(String, Vec<(Memory, f64)>)> = vec![
        (String::from("q1"), Vec::new()),
        (String::from("q2"), Vec::new()),
    ];
    let out2 = ce.rerank_batch(hollow_jobs);
    assert_eq!(out2.len(), 2);
    assert!(out2.iter().all(std::vec::Vec::is_empty));
}
2355
/// A serial call through `BatchedReranker::rerank` must return the
/// same ids and scores as a bare `CrossEncoder::rerank`.
#[test]
fn g9_batched_reranker_serial_calls_match_rerank() {
    use super::BatchedReranker;

    let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(4);
    for i in 0..4 {
        let memory = make_memory(
            &format!("t{i}"),
            &format!("alpha gamma body {i} content words"),
        );
        cands.push((memory, f64::from(i) * 0.1));
    }

    let batched = BatchedReranker::new(CrossEncoder::new());
    let direct = CrossEncoder::new().rerank("alpha", cands.clone());
    let via_batcher = batched.rerank("alpha", cands);

    assert_eq!(via_batcher.len(), direct.len());
    for (a, b) in via_batcher.iter().zip(direct.iter()) {
        assert_eq!(a.0.id, b.0.id);
        assert!((a.1 - b.1).abs() < 1e-12);
    }
}
2379
/// Eight threads share one `BatchedReranker`; every call must
/// return all five candidates sorted descending.
#[test]
fn g9_batched_reranker_concurrent_calls_all_succeed() {
    use super::BatchedReranker;
    use std::sync::Arc;

    let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
    let handles: Vec<_> = (0..8)
        .map(|i| {
            let b = Arc::clone(&batched);
            std::thread::spawn(move || {
                let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
                for j in 0..5 {
                    let memory =
                        make_memory(&format!("t{i}-{j}"), &format!("body {j} alpha gamma rust"));
                    cands.push((memory, 0.5));
                }
                let q = format!("alpha {i}");
                let out = b.rerank(&q, cands);
                assert_eq!(out.len(), 5);
                // Output is sorted descending.
                for w in out.windows(2) {
                    assert!(w[0].1 >= w[1].1);
                }
            })
        })
        .collect();

    for h in handles {
        h.join().expect("worker thread panicked");
    }
}
2413
/// #1579 B10 — the auto-select predicate. Lexical NEVER batches
/// (criterion: batched 7.6 ms vs direct 0.65 ms at N=8 — a 12×
/// inversion caused by the flush window); neural batches only at
/// concurrency ≥ `BATCHED_RERANK_MIN_CONCURRENCY`.
#[test]
fn issue_1579_b10_auto_select_predicate() {
    use super::{BATCHED_RERANK_MIN_CONCURRENCY, use_batched_rerank_path};

    // Lexical: direct at every concurrency level.
    for inflight in [1, 8, 1024] {
        assert!(!use_batched_rerank_path(false, inflight));
    }

    // Neural: a lone caller goes direct (nothing to coalesce with)…
    assert!(!use_batched_rerank_path(true, 1));

    // …while real concurrency keeps the G9 batched win.
    for inflight in [BATCHED_RERANK_MIN_CONCURRENCY, 8] {
        assert!(use_batched_rerank_path(true, inflight));
    }
}
2434
/// #1579 B10 — behavioral pin: a lexical `BatchedReranker` sends
/// every call (serial AND concurrent) down the DIRECT path, so the
/// coalescing worker never sees a job. Before the fix all 8
/// concurrent lexical calls funneled through the worker and paid
/// the 5 ms flush window per batch.
#[test]
fn issue_1579_b10_lexical_rerank_never_reaches_worker() {
    use super::BatchedReranker;
    use std::sync::Arc;

    let batched = Arc::new(BatchedReranker::new(CrossEncoder::new()));
    let handles: Vec<_> = (0..8)
        .map(|i| {
            let b = Arc::clone(&batched);
            std::thread::spawn(move || {
                let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(5);
                for j in 0..5 {
                    let memory =
                        make_memory(&format!("b10-{i}-{j}"), &format!("body {j} alpha gamma"));
                    cands.push((memory, 0.5));
                }
                let out = b.rerank(&format!("alpha {i}"), cands);
                assert_eq!(out.len(), 5);
            })
        })
        .collect();

    for h in handles {
        h.join().expect("worker thread panicked");
    }

    assert_eq!(
        batched.worker_submissions(),
        0,
        "lexical rerank must auto-select the direct path (no worker jobs)"
    );
}
2470
/// #1579 B10 — the forced coalesced path stays alive (both paths are
/// kept per the remediation contract) and, on a lexical encoder,
/// produces the same ids and scores as the auto-selected direct
/// path.
#[test]
fn issue_1579_b10_forced_coalesced_path_matches_direct() {
    use super::BatchedReranker;

    let mut cands: Vec<(Memory, f64)> = Vec::with_capacity(4);
    for i in 0..4 {
        let memory = make_memory(
            &format!("b10-forced-{i}"),
            &format!("alpha gamma body {i} content words"),
        );
        cands.push((memory, f64::from(i) * 0.1));
    }

    let batched = BatchedReranker::new(CrossEncoder::new());
    let direct = batched.rerank("alpha", cands.clone());
    let coalesced = batched.rerank_coalesced("alpha", cands);

    // Only the forced call may have been submitted to the worker.
    assert_eq!(
        batched.worker_submissions(),
        1,
        "rerank_coalesced must route through the worker"
    );
    assert_eq!(coalesced.len(), direct.len());
    for (a, b) in coalesced.iter().zip(direct.iter()) {
        assert_eq!(a.0.id, b.0.id);
        assert!((a.1 - b.1).abs() < 1e-12);
    }
}
2502
/// Pins the `CrossEncoder::rerank` contract for whichever scorer
/// `new_neural()` hands back (per the original note it may fall back to
/// lexical): both candidates come back, every blended score is finite,
/// and the output is ordered best-first.
#[test]
fn pr9i_rerank_via_score_returns_blend() {
    /// Fully-populated fixture row; only `id`, `title` and `content`
    /// differ between the two candidates used below.
    fn fixture(id: &str, title: &str, content: &str) -> Memory {
        Memory {
            id: id.to_string(),
            tier: Tier::Mid,
            namespace: "ns".to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
            reflection_depth: 0,
            memory_kind: crate::models::MemoryKind::Observation,
            entity_id: None,
            persona_version: None,
            citations: Vec::new(),
            source_uri: None,
            source_span: None,
            confidence_source: crate::models::ConfidenceSource::CallerProvided,
            confidence_signals: None,
            confidence_decayed_at: None,
            version: 1,
        }
    }

    let encoder = CrossEncoder::new_neural();
    let pool = vec![
        (fixture("a", "rust async runtime", "tokio rust async"), 0.6),
        (fixture("b", "grocery list", "milk eggs"), 0.4),
    ];

    let ranked = encoder.rerank("rust async", pool);

    assert_eq!(ranked.len(), 2);
    for score in ranked.iter().map(|entry| entry.1) {
        assert!(score.is_finite());
    }
    // Sort contract: the first entry's blended score is >= the second's.
    assert!(ranked[0].1 >= ranked[1].1);
}
2581
2582 // ---------- Issue #1319 — reranker score floor (calibration) -----------
2583
/// Issue #1319 — the default `RerankerScoreFloor` is `Off`, and applying
/// it changes nothing: the list keeps its length, the titles stay in the
/// same positions, and every score matches its pre-apply value within
/// `f64::EPSILON`.
#[test]
fn reranker_score_floor_default_is_off_1319() {
    let floor = RerankerScoreFloor::default();
    assert_eq!(floor, RerankerScoreFloor::Off);

    let snapshot = vec![
        (make_memory("a", "x"), 0.9_f64),
        (make_memory("b", "y"), 0.4_f64),
        (make_memory("c", "z"), 0.1_f64),
    ];
    let mut scored = snapshot.clone();
    floor.apply(&mut scored);

    // Length is checked first, so the pairwise zip below covers every row.
    assert_eq!(scored.len(), snapshot.len());
    for ((mem, score), (old_mem, old_score)) in scored.iter().zip(snapshot.iter()) {
        assert_eq!(mem.title, old_mem.title);
        assert!((score - old_score).abs() < f64::EPSILON);
    }
}
2604
/// Issue #1319 — an absolute floor of 0.5 trims the tail: of four rows
/// scored 0.90 / 0.60 / 0.30 / 0.10 only the first two survive, in their
/// original order. (The top row here is already above the floor; the
/// below-floor top-row case has its own test.)
#[test]
fn reranker_score_floor_absolute_drops_tail_1319() {
    let rows = vec![
        ("top", "x", 0.90_f64),
        ("mid", "y", 0.60_f64),
        ("low", "z", 0.30_f64),
        ("noise", "n", 0.10_f64),
    ];
    let mut scored: Vec<(Memory, f64)> = rows
        .into_iter()
        .map(|(title, content, score)| (make_memory(title, content), score))
        .collect();

    RerankerScoreFloor::Absolute(0.5).apply(&mut scored);

    // "top" and "mid" are kept; "low" and "noise" are gone.
    let survivors: Vec<&str> = scored.iter().map(|(mem, _)| mem.title.as_str()).collect();
    assert_eq!(survivors, vec!["top", "mid"]);
}
2622
/// Issue #1319 — a relative floor keeps the head and removes every row
/// scoring below `top_score * ratio`. With a top score of 0.80 and a
/// ratio of 0.5 the cutoff is 0.40, so 0.50 stays while 0.35 and 0.20 go.
#[test]
fn reranker_score_floor_relative_drops_tail_1319() {
    let rows = vec![
        ("top", "x", 0.80_f64),
        ("kept", "y", 0.50_f64),
        ("dropped_1", "z", 0.35_f64),
        ("dropped_2", "z", 0.20_f64),
    ];
    let mut scored: Vec<(Memory, f64)> = rows
        .into_iter()
        .map(|(title, content, score)| (make_memory(title, content), score))
        .collect();

    RerankerScoreFloor::RelativeToTop(0.5).apply(&mut scored);

    let survivors: Vec<&str> = scored.iter().map(|(mem, _)| mem.title.as_str()).collect();
    assert_eq!(survivors, vec!["top", "kept"]);
}
2639
/// Issue #1319 — when the absolute floor (0.5) sits above every blended
/// score (0.20 and 0.10), the first row is still kept: a tiny corpus
/// must surface its best hit rather than return nothing.
#[test]
fn reranker_score_floor_preserves_top_row_when_everything_below_1319() {
    let best = (make_memory("apollo", "moon landing"), 0.20_f64);
    let runner_up = (make_memory("recall", "blends fts and semantic"), 0.10_f64);
    let mut scored = vec![best, runner_up];

    RerankerScoreFloor::Absolute(0.5).apply(&mut scored);

    assert_eq!(scored.len(), 1);
    assert_eq!(scored[0].0.title, "apollo");
}
2655
/// Issue #1319 — applying a floor to an empty list must neither panic
/// nor add rows; the list stays empty.
#[test]
fn reranker_score_floor_handles_empty_1319() {
    let mut scored: Vec<(Memory, f64)> = Vec::new();
    RerankerScoreFloor::Absolute(0.5).apply(&mut scored);
    assert!(scored.is_empty());
}
2664
/// Issue #1319 — regression for the v1 P5 probe, where an Apollo-11 row
/// (upstream score 0.479) ranked above a substantively-relevant
/// recall-mechanics row (0.363) with nothing visible to the operator
/// that explained the ordering.
///
/// The test rebuilds that situation with a paraphrase query and the
/// lexical cross-encoder, then checks two things:
///
/// 1. Before any floor, Apollo is still the top row and its blended
///    score sits in the noise band (below 0.30). The substrate cannot
///    reorder candidates the upstream retrieval mis-ranked when the
///    cross-encoder gives it nothing to work with.
/// 2. With an operator-opt-in `RerankerScoreFloor::Absolute(0.40)`,
///    everything below the threshold is dropped except the top row
///    (small-corpus safety). The operator is left with ONE result,
///    Apollo, instead of a two-row list that looks like a ranking.
///
/// Note that the floor does not remove the Apollo row itself; it removes
/// the tail beneath it. What the floor provides is a handle for cutting
/// the low-confidence tail at a threshold the operator chose.
#[test]
fn reranker_v1_p5_paraphrase_noise_dropped_by_floor_1319() {
    let ce = CrossEncoder::new(); // lexical, deterministic.
    let apollo = make_memory(
        "Apollo 11 moon landing",
        "Neil Armstrong walked on the moon in 1969.",
    );
    let recall_b = make_memory(
        "Recall blends FTS and semantic scores",
        "The hybrid pipeline weighs cosine vs BM25 then reranks the top-k.",
    );

    // Empirical pre-#1319 shape: upstream hybrid retrieval scored Apollo
    // above recall_b. The numbers mirror the v1 P5 probe so the test
    // reads as the operator-observed evidence on the issue.
    let candidates = vec![(apollo, 0.479_f64), (recall_b, 0.363_f64)];

    // Operator query: a paraphrase meant to miss both candidates
    // lexically.
    // NOTE(review): "recall" also occurs in recall_b's title, so whether
    // the lexical score for that row is exactly zero depends on
    // `lexical_score`, which is outside this view. The assertions below
    // only rely on Apollo staying on top with a sub-0.30 blended score.
    let query = "what makes a recall implementation good?";

    // Sanity: pre-floor, Apollo still sits on top. With no lexical
    // overlap for Apollo the blend reduces to 0.6 * 0.479 = 0.2874.
    let pre = ce.rerank(query, candidates);
    assert_eq!(pre[0].0.title, "Apollo 11 moon landing");
    assert!(pre[0].1 < 0.30, "top score in noise band: {}", pre[0].1);

    // Post-#1319 with an absolute floor at 0.40: every row is below the
    // floor, so the whole tail is dropped and only the top row survives.
    // The recall_b row (about 0.6 * 0.363 = 0.218 if its lexical score
    // is zero) no longer sits beneath the top hit.
    let mut post = pre;
    RerankerScoreFloor::Absolute(0.40).apply(&mut post);
    assert_eq!(
        post.len(),
        1,
        "floor at 0.40 must drop tail when blended scores in noise band: {post:?}"
    );
    // Top preserved.
    assert_eq!(post[0].0.title, "Apollo 11 moon landing");
}
2732
/// Issue #1319 — a `BatchedReranker` built with `with_score_floor`
/// reports that floor from `score_floor()` and applies it to what
/// `rerank` returns. The check goes through the wrapper so a later
/// refactor of the rerank pipeline cannot silently bypass the floor.
#[test]
fn batched_reranker_score_floor_plumbed_end_to_end_1319() {
    use super::BatchedReranker;

    let candidates = vec![
        (
            make_memory("Apollo 11 moon landing", "Armstrong, 1969"),
            0.479_f64,
        ),
        (
            make_memory(
                "Recall blends FTS and semantic scores",
                "hybrid pipeline weighs cosine vs BM25",
            ),
            0.363_f64,
        ),
    ];

    // Only this opt-in constructor carries a floor; `BatchedReranker::new`
    // leaves it `Off`, which a sibling test pins separately.
    let batched = BatchedReranker::with_score_floor(
        CrossEncoder::new(),
        RerankerScoreFloor::Absolute(0.40),
    );
    assert_eq!(batched.score_floor(), RerankerScoreFloor::Absolute(0.40));

    let out = batched.rerank("paraphrase miss query", candidates);
    assert_eq!(out.len(), 1, "score floor must drop tail: {out:?}");
}
2758
/// Issue #1319 — a reranker built through plain `BatchedReranker::new`
/// reports `RerankerScoreFloor::Off`, so daemons that never opted in
/// keep their pre-#1319 output.
#[test]
fn batched_reranker_default_constructor_leaves_floor_off_1319() {
    use super::BatchedReranker;

    let default_built = BatchedReranker::new(CrossEncoder::new());
    let reported = default_built.score_floor();
    assert_eq!(reported, RerankerScoreFloor::Off);
}
2768}
2769
// The lint allowances below cover the mock's method shapes (`&self` on
// methods that ignore it, by-value parameters) — presumably kept so the
// mock can mirror `CrossEncoder`'s surface; confirm before tightening.
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Stand-in cross-encoder for tests. Produces deterministic scores
    /// from `(query, title, content)` without loading BERT.
    pub struct MockCrossEncoder {
        /// `true` selects the hash-based "neural" formula; `false`
        /// delegates to the real `lexical_score`.
        pub use_neural: bool,
    }

    impl MockCrossEncoder {
        /// Lexical mock (the counterpart of `CrossEncoder::new()`).
        pub fn new() -> Self {
            MockCrossEncoder { use_neural: false }
        }

        /// Neural mock (the counterpart of `CrossEncoder::new_neural()`).
        pub fn new_neural() -> Self {
            MockCrossEncoder { use_neural: true }
        }

        /// Deterministic mock score.
        ///
        /// * Lexical mode: whatever `lexical_score` returns.
        /// * Neural mode: a base-31 rolling hash over the bytes of
        ///   `query` followed by `title`, reduced to `[0, 0.999]`. When
        ///   `title` contains `query` the value is lifted into the upper
        ///   half (`base * 0.5 + 0.5`, capped at 1.0). `content` is not
        ///   consulted in this mode.
        pub fn score(&self, query: &str, title: &str, content: &str) -> f32 {
            if !self.use_neural {
                // Lexical path uses the real scorer.
                return lexical_score(query, title, content);
            }

            // Hashing the two byte streams back to back yields the same
            // digest as hashing their concatenation.
            let digest = query
                .bytes()
                .chain(title.bytes())
                .fold(0u32, |acc, byte| {
                    acc.wrapping_mul(31).wrapping_add(u32::from(byte))
                });
            let base = ((digest % 1000) as f32) / 1000.0;

            if title.contains(query) {
                // Boost for titles that contain the query verbatim.
                (base * 0.5 + 0.5).min(1.0)
            } else {
                base
            }
        }

        /// Reports whether this mock runs the neural formula.
        pub fn is_neural(&self) -> bool {
            self.use_neural
        }

        /// Re-scores every candidate as
        /// `ORIGINAL_WEIGHT * original + CROSS_ENCODER_WEIGHT * mock_score`
        /// and returns them ordered best-first. The sort is stable, and
        /// pairs that cannot be compared are treated as equal.
        pub fn rerank(
            &self,
            query: &str,
            mut candidates: Vec<(Memory, f64)>,
        ) -> Vec<(Memory, f64)> {
            let mut blended: Vec<(Memory, f64)> = Vec::with_capacity(candidates.len());
            for (mem, original_score) in candidates.drain(..) {
                let ce_score = f64::from(self.score(query, &mem.title, &mem.content));
                let final_score =
                    ORIGINAL_WEIGHT * original_score + CROSS_ENCODER_WEIGHT * ce_score;
                blended.push((mem, final_score));
            }

            blended.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            blended
        }
    }

    impl Default for MockCrossEncoder {
        /// Same as [`MockCrossEncoder::new`]: the lexical mock.
        fn default() -> Self {
            Self::new()
        }
    }
}
2851
#[cfg(test)]
mod mock_tests {
    //! Tests for `MockCrossEncoder` plus the H2 worker-shutdown pin for
    //! `BatchedReranker`.

    use super::test_support::*;
    use super::{BatchedReranker, CrossEncoder};
    use crate::models::{Memory, Tier};
    use std::time::{Duration, Instant};

    /// Fixture row with the given title and content; every other field
    /// carries a fixed test value (the id is always `"test-id"`).
    fn make_memory(title: &str, content: &str) -> Memory {
        let stamp = "2026-01-01T00:00:00Z";
        Memory {
            id: String::from("test-id"),
            tier: Tier::Mid,
            namespace: String::from("test"),
            title: title.to_owned(),
            content: content.to_owned(),
            tags: Vec::new(),
            priority: 5,
            confidence: 1.0,
            source: String::from("test"),
            access_count: 0,
            created_at: stamp.to_owned(),
            updated_at: stamp.to_owned(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({}),
            reflection_depth: 0,
            memory_kind: crate::models::MemoryKind::Observation,
            entity_id: None,
            persona_version: None,
            citations: Vec::new(),
            source_uri: None,
            source_span: None,
            confidence_source: crate::models::ConfidenceSource::CallerProvided,
            confidence_signals: None,
            confidence_decayed_at: None,
            version: 1,
        }
    }

    /// `new()` yields the lexical mock.
    #[test]
    fn mock_lexical_new() {
        let mock = MockCrossEncoder::new();
        assert!(!mock.is_neural());
    }

    /// `new_neural()` yields the neural mock.
    #[test]
    fn mock_neural_new() {
        let mock = MockCrossEncoder::new_neural();
        assert!(mock.is_neural());
    }

    /// Scoring the same triple twice gives the same value.
    #[test]
    fn mock_neural_score_deterministic() {
        let mock = MockCrossEncoder::new_neural();
        let first = mock.score("query", "title", "content");
        let second = mock.score("query", "title", "content");
        assert_eq!(first, second);
    }

    /// A title containing the query outranks an unrelated title.
    #[test]
    fn mock_neural_score_title_match_boost() {
        let mock = MockCrossEncoder::new_neural();
        let body = "delicious dessert";
        let s_title_contains = mock.score("apple", "apple pie recipe", body);
        let s_no_match = mock.score("apple", "unrelated", body);
        assert!(
            s_title_contains > s_no_match,
            "title match ({s_title_contains}) should beat no match ({s_no_match})"
        );
    }

    /// Every query/title combination scores inside `[0, 1]`.
    #[test]
    fn mock_neural_score_bounded() {
        let mock = MockCrossEncoder::new_neural();
        let queries = ["test", "neural", "reranker", "machine learning"];
        let titles = ["a", "b", "the quick brown"];
        let pairs = queries
            .iter()
            .flat_map(|query| titles.iter().map(move |title| (*query, *title)));
        for (query, title) in pairs {
            let s = mock.score(query, title, "content");
            assert!((0.0..=1.0).contains(&s), "score {s} out of bounds");
        }
    }

    /// The candidate whose title equals the query ends up first even
    /// though it starts with the lower original score.
    #[test]
    fn mock_neural_rerank_reorders() {
        let mock = MockCrossEncoder::new_neural();
        let on_topic = make_memory("neural network", "deep learning with transformers");
        let off_topic = make_memory("grocery list", "milk eggs bread butter");
        let reranked = mock.rerank("neural network", vec![(off_topic, 0.3), (on_topic, 0.2)]);
        // Neural encoder should boost the neural-network-titled memory.
        assert_eq!(reranked[0].0.title, "neural network");
    }

    /// Reranking neither drops nor adds candidates.
    #[test]
    fn mock_neural_rerank_preserves_count() {
        let mock = MockCrossEncoder::new_neural();
        let pool = vec![
            (make_memory("A", "content a"), 0.5),
            (make_memory("B", "content b"), 0.4),
            (make_memory("C", "content c"), 0.6),
        ];
        assert_eq!(mock.rerank("test", pool).len(), 3);
    }

    /// The lexical mock's score stays inside `[0, 1]`.
    #[test]
    fn mock_lexical_path_via_mock() {
        let score = MockCrossEncoder::new().score(
            "network adapter",
            "Network Configuration",
            "the network adapter is connected",
        );
        assert!((0.0..=1.0).contains(&score));
    }

    /// Lexical and neural mocks disagree on the same input, i.e. they
    /// really do use different scoring formulas.
    #[test]
    fn mock_neural_different_from_lexical() {
        let (query, title, content) = ("machine learning", "ML title", "neural networks");
        let s_lex = MockCrossEncoder::new().score(query, title, content);
        let s_neu = MockCrossEncoder::new_neural().score(query, title, content);
        assert_ne!(s_lex, s_neu);
    }

    // -----------------------------------------------------------------
    // H2 (v0.7.0 round-2) — worker-thread shutdown discipline.
    //
    // Contract: once a `BatchedReranker` is told to shut down, its
    // worker thread must terminate within a bounded wall-clock window.
    // Per the original note, without an explicit shutdown channel a
    // worker blocked in `rx.recv()` would only exit on sender
    // disconnect; the explicit signal closes the worst case (e.g. a
    // stashed `Sender` clone) and bounds the shutdown latency by the
    // worker's SHUTDOWN_POLL cadence.
    // -----------------------------------------------------------------
    #[test]
    fn h2_drop_terminates_worker_within_500ms() {
        let mut reranker = BatchedReranker::new(CrossEncoder::new());

        // Pull the shutdown sender and the JoinHandle out BEFORE the
        // reranker is dropped, so thread termination can be observed
        // from outside. The steps below replay the Drop sequence by
        // hand: release the work sender, fire shutdown, join on a budget.
        let shutdown_tx = reranker.shutdown.take().expect("shutdown sender present");
        let worker_handle = reranker.worker.take().expect("worker handle present");
        // Releasing the work-channel sender first mimics the disconnect
        // semantics of the production Drop sequence.
        drop(reranker.sender.take());

        let started = Instant::now();
        let _ = shutdown_tx.send(());

        // `JoinHandle::join` takes no timeout, so the join runs on a
        // side thread and completion is awaited with `recv_timeout`.
        let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();
        std::thread::spawn(move || {
            let _ = worker_handle.join();
            let _ = done_tx.send(());
        });

        let observed = done_rx
            .recv_timeout(Duration::from_millis(500))
            .map(|()| started.elapsed());
        assert!(
            observed.is_ok(),
            "BatchedReranker worker did not terminate within 500ms after \
             explicit shutdown — observed: {observed:?}"
        );
    }
}
3023
/// An empty query scores exactly `0.0`, whatever the document says.
#[test]
fn score_handles_empty_query_string() {
    let (title, content) = ("Document Title", "This is document content");
    let empty_query_score = lexical_score("", title, content);
    assert_eq!(empty_query_score, 0.0, "empty query must return 0.0");
}
3029
/// A term matched against a document containing that same term scores
/// positively, both for the accented spelling ("café") and for the plain
/// ASCII one ("cafe"). The two scores are not required to be equal.
// NOTE(review): each query is compared against its own spelling only;
// no accented-vs-unaccented (or composed-vs-decomposed) pair is
// exercised here despite the test's name.
#[test]
fn score_handles_unicode_normalization() {
    let accented = lexical_score("café", "café", "the café is open");
    let plain = lexical_score("cafe", "cafe", "the cafe is open");
    assert!(accented > 0.0);
    assert!(plain > 0.0);
}
3039
/// A very long document (10 000 repetitions of "word ", i.e. 50 000
/// characters) still produces a score inside `[0, 1]`.
#[test]
fn score_handles_very_long_content_truncation() {
    let mut long_content = String::new();
    for _ in 0..10_000 {
        long_content.push_str("word ");
    }
    let score = lexical_score("word", "title", &long_content);
    let bounded = (0.0..=1.0).contains(&score);
    assert!(bounded, "score must be bounded [0, 1]");
}
3047
/// A one-token query (which has no bigrams) must not crash the scorer
/// and must still yield a score inside `[0, 1]`.
#[test]
fn bigram_score_with_single_token_query() {
    let (title, content) = ("Single Token Title", "single token content");
    let score = lexical_score("query", title, content);
    assert!((0.0..=1.0).contains(&score));
}
3054
#[cfg(test)]
mod issue_1597_tests {
    //! #1597 — rerank pool cap + batched cross-encoder forward pass.
    //!
    //! The counting-mock route is unavailable: `MockCrossEncoder` is a
    //! standalone test struct, not a pluggable `CrossEncoder` variant,
    //! so call counts cannot be observed through the production enum.
    //! Instead the cap is pinned via score mutation: with a query that
    //! shares zero tokens with every candidate, the lexical
    //! cross-encoder scores every scored pair `0.0`, so a cross-encoded
    //! candidate's final score becomes EXACTLY `ORIGINAL_WEIGHT * orig`
    //! while an uncapped candidate keeps `orig` bit-for-bit — making
    //! "exactly RERANK_POOL_MAX candidates were cross-encoded"
    //! observable from the output alone.
    //!
    //! All expected positions are derived from [`OVER_CAP_POOL`] and
    //! `RERANK_POOL_MAX`, so the tests follow the constant if the cap is
    //! ever retuned (as long as it stays below the over-cap pool size).

    use super::*;
    use crate::models::Memory;

    /// Query with zero token overlap against [`pool_memory`] docs —
    /// lexical cross-encoder score is exactly 0.0 for every pair.
    const NO_OVERLAP_QUERY: &str = "zzz qqq www";

    /// Number of candidates in the "larger than the cap" pools, as the
    /// `i32` used for candidate numbering.
    const OVER_CAP_POOL: i32 = 50;

    /// [`OVER_CAP_POOL`] as a length. The cast is lossless: the value is
    /// a small positive literal.
    const OVER_CAP_LEN: usize = OVER_CAP_POOL as usize;

    /// Candidate number `i`: id `cand-{i}`, with title and content that
    /// share no token with [`NO_OVERLAP_QUERY`].
    fn pool_memory(i: i32) -> Memory {
        Memory {
            id: format!("cand-{i}"),
            title: format!("alpha {i}"),
            content: format!("beta gamma {i}"),
            ..Memory::default()
        }
    }

    /// Original (pre-rerank) score of candidate `i`: `0.01 * (i + 1)`.
    /// This is the single source for both building pools and checking
    /// results, so "bit-for-bit" comparisons are sound.
    fn orig_score(i: i32) -> f64 {
        f64::from(i + 1) * 0.01
    }

    /// `n` candidates with distinct ascending original scores, supplied
    /// in ASCENDING order so the cap's pre-sort is load-bearing (not a
    /// pass-through of input order).
    fn pool(n: i32) -> Vec<(Memory, f64)> {
        (0..n).map(|i| (pool_memory(i), orig_score(i))).collect()
    }

    /// A pool of [`OVER_CAP_POOL`] candidates, after checking that it
    /// really is larger than the cap (otherwise there is no tail to
    /// assert on).
    fn over_cap_pool() -> Vec<(Memory, f64)> {
        assert!(
            RERANK_POOL_MAX < OVER_CAP_LEN,
            "over-cap pool must be larger than RERANK_POOL_MAX"
        );
        pool(OVER_CAP_POOL)
    }

    /// `RERANK_POOL_MAX` as an `i32`, for arithmetic on candidate numbers.
    fn cap_i32() -> i32 {
        i32::try_from(RERANK_POOL_MAX).expect("cap fits i32")
    }

    /// Candidate number expected at `rank` within the reranked head of
    /// an over-cap pool: the strongest candidates, descending.
    fn expected_head_candidate(rank: usize) -> i32 {
        OVER_CAP_POOL - 1 - i32::try_from(rank).expect("rank fits i32")
    }

    /// Candidate number expected at `off` within the untouched tail of
    /// an over-cap pool: everything the cap excluded, descending.
    fn expected_tail_candidate(off: usize) -> i32 {
        OVER_CAP_POOL - 1 - cap_i32() - i32::try_from(off).expect("offset fits i32")
    }

    /// Over-cap pool → exactly [`RERANK_POOL_MAX`] candidates get
    /// cross-encoder scores (their final scores move to
    /// `ORIGINAL_WEIGHT * orig`); the rest keep their incoming scores
    /// bit-for-bit and are placed after the reranked head. No candidate
    /// is lost.
    #[test]
    fn rerank_pool_cap_honored_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, over_cap_pool());

        assert_eq!(out.len(), OVER_CAP_LEN, "no candidate may be lost");
        let ids: std::collections::HashSet<&str> =
            out.iter().map(|(m, _)| m.id.as_str()).collect();
        assert_eq!(ids.len(), OVER_CAP_LEN, "no duplicate / dropped ids");

        // Head: the top RERANK_POOL_MAX by original score, descending,
        // each cross-encoded → ORIGINAL_WEIGHT * orig.
        for (rank, (mem, score)) in out.iter().take(RERANK_POOL_MAX).enumerate() {
            let i = expected_head_candidate(rank);
            assert_eq!(mem.id, format!("cand-{i}"), "head rank {rank}");
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "head rank {rank} must carry the cross-encoded blend"
            );
        }

        // Tail: everything else, descending, scores untouched
        // (bit-for-bit the input score).
        for (off, (mem, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            let i = expected_tail_candidate(off);
            assert_eq!(mem.id, format!("cand-{i}"), "tail offset {off}");
            assert_eq!(
                *score,
                orig_score(i),
                "tail offset {off} must keep its blended score untouched"
            );
        }
    }

    /// Order correctness: reranked head internally sorted descending,
    /// tail internally sorted descending, and the tail holds only
    /// candidates whose original score is below every head member's.
    #[test]
    fn rerank_pool_cap_order_correctness_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let out = ce.rerank(NO_OVERLAP_QUERY, over_cap_pool());
        let (head, tail) = out.split_at(RERANK_POOL_MAX);
        assert!(
            head.windows(2).all(|w| w[0].1 >= w[1].1),
            "reranked head must be sorted descending"
        );
        assert!(
            tail.windows(2).all(|w| w[0].1 >= w[1].1),
            "uncapped tail must be sorted descending"
        );
        // The weakest head member is candidate `OVER_CAP_POOL - cap`;
        // tail scores are untouched originals, so they compare directly
        // against that candidate's original score.
        let min_head_orig = orig_score(OVER_CAP_POOL - cap_i32());
        assert!(
            tail.iter().all(|(_, s)| *s < min_head_orig),
            "tail must hold only candidates the cap excluded"
        );
    }

    /// Pool exactly at the cap → full rerank (tail empty): every
    /// candidate is cross-encoded.
    #[test]
    fn rerank_pool_at_cap_fully_cross_encoded_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let n = cap_i32();
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(n));
        assert_eq!(out.len(), RERANK_POOL_MAX);
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = n - 1 - i32::try_from(rank).expect("rank fits i32");
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "at-cap pool: rank {rank} must be cross-encoded"
            );
        }
    }

    /// Cap > pool size degenerates to the historical full rerank.
    #[test]
    fn rerank_cap_gt_pool_degenerates_to_full_rerank_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let n = 5;
        let out = ce.rerank(NO_OVERLAP_QUERY, pool(n));
        assert_eq!(out.len(), 5);
        for (rank, (_, score)) in out.iter().enumerate() {
            let i = n - 1 - i32::try_from(rank).expect("rank fits i32");
            assert!(
                (score - ORIGINAL_WEIGHT * orig_score(i)).abs() < f64::EPSILON,
                "small pool: rank {rank} must be cross-encoded (no tail)"
            );
        }
    }

    /// The G9 multi-query batch path applies the cap per query job.
    #[test]
    fn rerank_batch_applies_pool_cap_per_query_1597() {
        let ce = CrossEncoder::Lexical { degraded: false };
        let jobs = vec![
            (NO_OVERLAP_QUERY.to_string(), over_cap_pool()),
            (NO_OVERLAP_QUERY.to_string(), over_cap_pool()),
        ];
        let outs = ce.rerank_batch(jobs);
        assert_eq!(outs.len(), 2);
        for out in &outs {
            assert_eq!(
                out.len(),
                OVER_CAP_LEN,
                "per-job candidate count preserved"
            );
            for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
                let i = expected_tail_candidate(off);
                assert_eq!(
                    *score,
                    orig_score(i),
                    "per-job tail must keep blended scores untouched"
                );
            }
        }
    }

    /// The `BatchedReranker` production wrapper inherits the cap via
    /// the direct encoder path (lexical traffic never reaches the
    /// coalescing worker per #1579 B10).
    #[test]
    fn batched_reranker_inherits_pool_cap_1597() {
        let br = BatchedReranker::with_reflection_boost(
            CrossEncoder::Lexical { degraded: false },
            ReflectionBoostConfig::disabled(),
        );
        let out = br.rerank(NO_OVERLAP_QUERY, over_cap_pool());
        assert_eq!(out.len(), OVER_CAP_LEN);
        for (off, (_, score)) in out.iter().skip(RERANK_POOL_MAX).enumerate() {
            let i = expected_tail_candidate(off);
            assert_eq!(*score, orig_score(i), "wrapper tail untouched");
        }
    }

    /// #1597 bench evidence — manual run against the REAL neural
    /// cross-encoder (resolves from the local HF cache; downloads
    /// ~80 MB on a cold host):
    ///
    /// ```bash
    /// AI_MEMORY_NO_CONFIG=1 cargo test --release --lib \
    ///   issue_1597_neural_rerank_timing_evidence -- --ignored --nocapture
    /// ```
    ///
    /// Prints BEFORE (sequential per-pair forward over the full
    /// 50-candidate pool — the pre-#1597 `rerank` shape) vs AFTER
    /// (capped pool + one batched forward — the shipped path).
    #[test]
    #[ignore = "#1597 manual bench evidence: loads the real neural cross-encoder"]
    fn issue_1597_neural_rerank_timing_evidence() {
        let ce = CrossEncoder::new_neural();
        assert!(
            ce.is_neural(),
            "neural encoder failed to load; timing evidence invalid"
        );
        let bench_pool: Vec<(Memory, f64)> = (0..50)
            .map(|i| {
                let m = Memory {
                    id: format!("bench-{i}"),
                    title: format!("benchmark candidate number {i} recall pipeline"),
                    content: format!(
                        "long-form benchmark document body number {i} with enough \
                         material to exercise the cross-encoder, covering recall \
                         pipeline reranking, cross encoder scoring, candidate \
                         blending and ordering semantics for run {i}"
                    ),
                    ..Memory::default()
                };
                (m, f64::from(i) * 0.01)
            })
            .collect();
        let query = "how does the recall pipeline rerank candidates";

        // Warm-up (first forward pays one-time allocation cost).
        let _ = ce.score(query, "warmup", "warmup body");

        // BEFORE shape: one full forward per (query, candidate) pair,
        // sequentially, over the entire 50-candidate pool.
        let t0 = Instant::now();
        for (m, _) in &bench_pool {
            let _ = ce.score(query, &m.title, &m.content);
        }
        let before = t0.elapsed();

        // AFTER: shipped path — cap at RERANK_POOL_MAX + single
        // batched forward.
        let t1 = Instant::now();
        let out = ce.rerank(query, bench_pool.clone());
        let after = t1.elapsed();

        assert_eq!(out.len(), 50, "no candidate lost on the neural path");
        eprintln!(
            "#1597 timing (50-candidate pool, CPU): BEFORE sequential-full = {before:?}; \
             AFTER capped+batched = {after:?}"
        );
    }
}