ripvec_core/rerank.rs
1//! Cross-encoder reranker for top-K refinement.
2//!
3//! ## Why this module exists
4//!
5//! ripvec's bi-encoder retrieval (BERT or semble) embeds query and
6//! documents into a shared vector space and ranks by cosine. That's
7//! cheap to scale, but the model can't express cross-token
8//! interactions between query and document — each side is encoded
9//! independently. On natural-language and prose corpora this caps
10//! quality.
11//!
12//! A cross-encoder concatenates the pair `[CLS] query [SEP] doc [SEP]`
13//! and runs full attention across both, producing a single relevance
14//! score. Quality is meaningfully higher but cost is O(candidates),
15//! so it's used only as a reranker on the bi-encoder's top-K.
16//!
17//! ## Architecture
18//!
19//! This module is a thin orchestrator: tokenize `(query, doc)` pairs,
20//! delegate scoring to a [`RerankBackend`](crate::backend::RerankBackend)
//! (currently [`crate::backend::cpu::CpuRerankBackend`] — same BERT
//! trunk as the bi-encoder, plus a `Linear(hidden -> 1)` classifier
//! head whose raw logit is returned as the relevance score; no sigmoid
//! is applied).
24//!
25//! Only the CPU rerank backend is wired today. Adding GPU rerankers
26//! later would require implementing `RerankBackend` for the target
27//! device, mirroring `load_reranker_cpu` in `backend/mod.rs`, and
28//! routing through `Reranker::from_pretrained`.
29
30use anyhow::anyhow;
31use tokenizers::{Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy};
32
33use crate::backend::{Encoding, RerankBackend};
34
/// Default cross-encoder model: `cross-encoder/ms-marco-TinyBERT-L-2-v2`.
///
/// This 2-layer model, distilled from BERT-base, replaced the previous
/// `ms-marco-MiniLM-L-12-v2` default after a model sweep on the gutenberg
/// prose benchmark (15 natural-language queries). It matched the old
/// default bit-for-bit on NDCG@10 and recall@10 while running about 20x
/// faster on the warm-query path:
///
/// ```text
/// model                            NDCG@10  recall@10  p50
/// ms-marco-MiniLM-L-12-v2  (old)   1.0000   1.000      671 ms
/// ms-marco-MiniLM-L-6-v2           1.0000   1.000      344 ms
/// ms-marco-MiniLM-L-2-v2           0.9508   1.000      125 ms  <- quality drop
/// ms-marco-TinyBERT-L-2-v2 (new)   1.0000   1.000       33 ms
/// ```
///
/// Why TinyBERT-L-2 keeps quality where MiniLM-L-2 does not: it was trained
/// with teacher distillation to preserve the larger model's behavior at two
/// layers, whereas plain MiniLM-L-2 sheds layers without that regularization
/// and loses precision. Dropping from twelve layers to two cuts inference
/// cost roughly 6x; the smaller embedding dimension brings the observed
/// speedup to about 20x.
///
/// When a corpus needs more capacity (e.g. fine-grained domain reranking),
/// override the model via the CLI flag or by calling
/// [`Reranker::from_pretrained`] directly.
pub const DEFAULT_RERANK_MODEL: &str = "cross-encoder/ms-marco-TinyBERT-L-2-v2";
59
/// Default maximum number of bi-encoder candidates handed to the reranker.
///
/// Reranking cost grows linearly with the candidate count. The
/// retrieve-then-rerank literature treats 100 as a safe upper bound. However,
/// on the gutenberg prose benchmark with the L-12 ms-marco cross-encoder,
/// NDCG@10 was bit-identical for every K from 100 down to 20.
///
/// Recall stayed at 1.000 because the bi-encoder plus ranking layer already
/// placed the relevant doc at rank 1 for every test query. The reranker was
/// confirming the order rather than changing it.
///
/// 50 halves the cost of the literature default while leaving headroom for
/// corpora where the bi-encoder is less confident. High-confidence corpora
/// can go lower (CLI: `--candidates 30`).
///
/// Latency by K (gutenberg, 15 NL queries, scope=docs, NDCG=1.000 at every
/// K). These figures were measured with the L-12 cross-encoder, not the
/// current [`DEFAULT_RERANK_MODEL`]. Absolute latencies with TinyBERT-L-2 are
/// much lower; what carries over is the linear scaling in K:
///
/// ```text
/// K=100  p50 1335 ms
/// K=50   p50  676 ms
/// K=30   p50  418 ms
/// K=20   p50  275 ms
/// ```
pub const DEFAULT_RERANK_CANDIDATES: usize = 50;
82
83/// Cross-encoder reranker orchestrator.
84///
85/// Owns a `RerankBackend` (model trunk + classifier head) and the
86/// tokenizer that produced the encodings the backend expects.
87///
88/// Construct via [`Self::from_pretrained`]. Use [`score_pairs`] to
89/// rank candidate `(query, doc)` text pairs.
90///
91/// ## cfg-gating
92///
93/// The `backend` field type is cfg-gated by the `collapse-rerank-trait`
94/// Cargo feature per the mandated pattern in
95/// `docs/surgery/backend_trait_microbench.md` Section 4
96/// (@Lampson (1983) "Hints for Computer System Design"):
97///
98/// - **default** (`collapse-rerank-trait` off): `Box<dyn RerankBackend>` —
99/// heap-allocated vtable dispatch; future GPU rerankers slot in here.
100/// - **Variant C** (`collapse-rerank-trait` on): [`crate::backend::cpu::CpuRerankBackend`]
101/// held directly — monomorphic static dispatch; LLVM may inline through.
102///
103/// The call site `self.backend.score_batch(...)` is identical in source for
104/// both variants; the compiler generates an indirect vtable call for Variant T
105/// and a direct call for Variant C. This structural difference is what the
106/// microbench measures. Anti-patterns (enum wrapping, type alias, two parallel
107/// structs) are explicitly avoided per Section 4 to prevent LLVM from
108/// collapsing both variants to zero overhead by construction.
109///
110/// [`score_pairs`]: Self::score_pairs
111pub struct Reranker {
112 /// Under default build: heap-allocated trait object for future GPU reranker extensibility.
113 /// Under `collapse-rerank-trait`: concrete `CpuRerankBackend` for direct static dispatch.
114 #[cfg(not(feature = "collapse-rerank-trait"))]
115 backend: Box<dyn RerankBackend>,
116 #[cfg(feature = "collapse-rerank-trait")]
117 backend: crate::backend::cpu::CpuRerankBackend,
118 tokenizer: Tokenizer,
119}
120
121impl Reranker {
122 /// Load a cross-encoder by `HuggingFace` repo ID.
123 ///
124 /// Under default builds routes through [`crate::backend::load_reranker_cpu`]
125 /// which boxes the result as `Box<dyn RerankBackend>`. Under
126 /// `collapse-rerank-trait` calls [`crate::backend::cpu::CpuRerankBackend::load`]
127 /// directly, storing the concrete type without boxing — this is the structural
128 /// difference the `collapse-rerank-trait` microbench exploits
129 /// (@Lampson (1983) "Hints for Computer System Design").
130 ///
131 /// The tokenizer is downloaded via the same `hf-hub` cache, so
132 /// multiple sub-agent MCP processes share weights through
133 /// `~/.cache/huggingface/hub/`.
134 ///
135 /// # Errors
136 ///
137 /// Returns an error if the model can't be downloaded, lacks a
138 /// classifier head (i.e., a bi-encoder was supplied by mistake),
139 /// or fails to load.
140 pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
141 #[cfg(not(feature = "collapse-rerank-trait"))]
142 let backend = crate::backend::load_reranker_cpu(model_repo)?;
143 #[cfg(feature = "collapse-rerank-trait")]
144 let backend = crate::backend::cpu::CpuRerankBackend::load(model_repo)?;
145 let mut tokenizer = crate::tokenize::load_tokenizer(model_repo)?;
146 // Configure `LongestFirst` truncation against the model's
147 // declared max sequence length. Without this the tokenizer
148 // returns full-length encodings and ripvec used to head-truncate
149 // the already-joined `[CLS] q [SEP] d [SEP]` sequence, which
150 // can drop the trailing `[SEP]` and let the doc tail overflow
151 // into garbage on long inputs. With `LongestFirst` the
152 // tokenizer trims whichever of (query, doc) is longer until
153 // the joined sequence fits, preserving special tokens.
154 let max_tokens = backend.max_tokens();
155 tokenizer
156 .with_truncation(Some(TruncationParams {
157 max_length: max_tokens,
158 strategy: TruncationStrategy::LongestFirst,
159 stride: 0,
160 direction: TruncationDirection::Right,
161 }))
162 .map_err(|e| crate::Error::Other(anyhow!("rerank tokenizer truncation: {e}")))?;
163 Ok(Self { backend, tokenizer })
164 }
165
166 /// Score a batch of `(query, document)` pairs.
167 ///
168 /// Returns raw logits (sentence-transformers `Identity` activation —
169 /// the canonical public score for ms-marco cross-encoders), one
170 /// per input pair, in input order. Tokenizes with a `(query, doc)`
171 /// tuple so `token_type_ids` are 0 for the query side, 1 for the
172 /// doc side — the convention BERT cross-encoders are trained on.
173 /// The tokenizer is pre-configured with `LongestFirst` truncation
174 /// at the model's `max_position_embeddings`, so callers don't need
175 /// to clip outputs.
176 ///
177 /// # Errors
178 ///
179 /// Propagates tokenization or forward-pass errors.
180 pub fn score_pairs(&self, pairs: &[(&str, &str)]) -> crate::Result<Vec<f32>> {
181 if pairs.is_empty() {
182 return Ok(Vec::new());
183 }
184 let encodings: crate::Result<Vec<Encoding>> = pairs
185 .iter()
186 .map(|(q, d)| {
187 // The tokenizer is configured with LongestFirst
188 // truncation in from_pretrained; the returned encoding
189 // already fits within max_position_embeddings and
190 // preserves [CLS] / [SEP] tokens at the correct
191 // positions.
192 let enc = self
193 .tokenizer
194 .encode((*q, *d), true)
195 .map_err(|e| crate::Error::Other(anyhow!("rerank tokenize failed: {e}")))?;
196 Ok(Encoding {
197 input_ids: enc.get_ids().iter().map(|&x| i64::from(x)).collect(),
198 attention_mask: enc
199 .get_attention_mask()
200 .iter()
201 .map(|&x| i64::from(x))
202 .collect(),
203 token_type_ids: enc.get_type_ids().iter().map(|&x| i64::from(x)).collect(),
204 })
205 })
206 .collect();
207 let encodings = encodings?;
208 self.backend.score_batch(&encodings)
209 }
210
211 /// Max sequence length supported by the underlying model.
212 #[must_use]
213 pub fn max_tokens(&self) -> usize {
214 self.backend.max_tokens()
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check that [`Reranker::from_pretrained`] loads the
    /// default model and ranks sensibly.
    ///
    /// Marked `#[ignore]` because it downloads weights from `HuggingFace`;
    /// run it with `cargo test -- --ignored`.
    ///
    /// For one query, the relevant document must out-score the irrelevant
    /// one by more than 1.0. A sigmoid-squashed score in (0, 1) can never
    /// differ by that much, so the same margin also confirms that
    /// `score_pairs` returns raw logits. The reference spread for this model
    /// is roughly [-11, +5].
    #[test]
    #[ignore = "requires network + model download (~22MB)"]
    fn loads_and_ranks_default_cross_encoder() {
        let reranker = Reranker::from_pretrained(DEFAULT_RERANK_MODEL)
            .expect("default cross-encoder should load");

        let query = "how to make pasta";
        let relevant = "Boil water, add salt, cook pasta for 8 minutes.";
        let irrelevant = "The mitochondria is the powerhouse of the cell.";

        let scores = reranker
            .score_pairs(&[(query, relevant), (query, irrelevant)])
            .expect("scoring should succeed");
        assert_eq!(scores.len(), 2);

        let (relevant_score, irrelevant_score) = (scores[0], scores[1]);
        assert!(
            relevant_score > irrelevant_score + 1.0,
            "relevant doc ({}) should beat irrelevant ({}) by a clear logit margin",
            relevant_score,
            irrelevant_score
        );
    }
}