Skip to main content

ripvec_core/
rerank.rs

//! Cross-encoder reranker for top-K refinement.
//!
//! ## Why this module exists
//!
//! ripvec's bi-encoder retrieval (BERT or semble) embeds query and
//! documents into a shared vector space and ranks by cosine
//! similarity. That's cheap to scale, but the model can't express
//! cross-token interactions between query and document — each side is
//! encoded independently. On natural-language and prose corpora this
//! caps quality.
//!
//! A cross-encoder concatenates the pair `[CLS] query [SEP] doc [SEP]`
//! and runs full attention across both, producing a single relevance
//! score. Quality is meaningfully higher but cost is O(candidates)
//! forward passes per query, so it's used only as a reranker on the
//! bi-encoder's top-K.
//!
//! ## Architecture
//!
//! This module is a thin orchestrator: tokenize `(query, doc)` pairs,
//! delegate scoring to a [`RerankBackend`](crate::backend::RerankBackend)
//! (currently [`crate::backend::cpu::CpuRerankBackend`] — same BERT
//! trunk as the bi-encoder, plus a `Linear(hidden -> 1)` classifier
//! head). Scores are surfaced as raw logits with no sigmoid applied;
//! see [`Reranker::score_pairs`].
//!
//! Adding GPU rerankers later is mechanical: implement
//! `RerankBackend` for Metal/CUDA/MLX, mirror `load_reranker_cpu` in
//! `backend/mod.rs`, route through `Reranker::from_pretrained`.
28
29use anyhow::anyhow;
30use tokenizers::{Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy};
31
32use crate::backend::{Encoding, RerankBackend};
33
/// Default cross-encoder model.
///
/// `cross-encoder/ms-marco-TinyBERT-L-2-v2` (~5 MB, 2-layer
/// distilled-from-BERT-base) replaced the prior MiniLM-L-12-v2
/// default after a model sweep on the gutenberg prose benchmark
/// (15 NL queries) showed it bit-identical on NDCG@10 / recall@10
/// while running 20x faster at the warm-query path:
///
/// ```text
///   model                              NDCG@10  recall@10  p50
///   ms-marco-MiniLM-L-12-v2 (old)      1.0000   1.000      671 ms
///   ms-marco-MiniLM-L-6-v2             1.0000   1.000      344 ms
///   ms-marco-MiniLM-L-2-v2             0.9508   1.000      125 ms  <- quality drop
///   ms-marco-TinyBERT-L-2-v2 (new)     1.0000   1.000       33 ms
/// ```
///
/// The distinction is distillation: TinyBERT-L-2 was trained with
/// teacher-distillation to preserve the larger model's behavior at
/// 2 layers, whereas plain MiniLM-L-2 sheds layers without that
/// regularization and loses precision. Two layers vs twelve cuts
/// inference cost ~6x; combined with smaller embedding dim it lands
/// at 20x in practice. Override via the CLI flag or
/// `Reranker::from_pretrained` directly when a corpus needs more
/// capacity (e.g. fine-grained domain reranking).
pub const DEFAULT_RERANK_MODEL: &str = "cross-encoder/ms-marco-TinyBERT-L-2-v2";
58
/// Default cap on candidates passed to the reranker.
///
/// Cost is linear in candidates. The retrieve-then-rerank literature
/// suggests 100 as a safe upper bound, but empirically — on the
/// gutenberg prose benchmark with the L-12 ms-marco cross-encoder —
/// NDCG@10 is bit-identical from K=100 all the way down to K=20
/// (recall stays at 1.000, the bi-encoder + ranking layer already
/// puts the relevant doc at rank 1 in every test query, so the
/// rerank's job is confirmation rather than reordering). 50 is a
/// 2x speedup over the literature default with enough headroom for
/// corpora where the bi-encoder is less confident; users on
/// high-confidence corpora can drop further (CLI: `--candidates 30`).
///
/// Bench (gutenberg, 15 NL queries, scope=docs, NDCG=1.000 throughout):
///
/// ```text
/// K=100  p50 1335 ms
/// K=50   p50  676 ms
/// K=30   p50  418 ms
/// K=20   p50  275 ms
/// ```
pub const DEFAULT_RERANK_CANDIDATES: usize = 50;
81
82/// Cross-encoder reranker orchestrator.
83///
84/// Owns a `RerankBackend` (model trunk + classifier head) and the
85/// tokenizer that produced the encodings the backend expects.
86///
87/// Construct via [`Self::from_pretrained`]. Use [`score_pairs`] to
88/// rank candidate `(query, doc)` text pairs.
89///
90/// [`score_pairs`]: Self::score_pairs
91pub struct Reranker {
92    backend: Box<dyn RerankBackend>,
93    tokenizer: Tokenizer,
94}
95
96impl Reranker {
97    /// Load a cross-encoder by `HuggingFace` repo ID.
98    ///
99    /// Routes through [`crate::backend::load_reranker_cpu`] for now;
100    /// GPU paths slot in here as feature-gated branches when added.
101    /// The tokenizer is downloaded via the same `hf-hub` cache, so
102    /// multiple sub-agent MCP processes share weights through
103    /// `~/.cache/huggingface/hub/`.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the model can't be downloaded, lacks a
108    /// classifier head (i.e., a bi-encoder was supplied by mistake),
109    /// or fails to load.
110    pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
111        let backend = crate::backend::load_reranker_cpu(model_repo)?;
112        let mut tokenizer = crate::tokenize::load_tokenizer(model_repo)?;
113        // Configure `LongestFirst` truncation against the model's
114        // declared max sequence length. Without this the tokenizer
115        // returns full-length encodings and ripvec used to head-truncate
116        // the already-joined `[CLS] q [SEP] d [SEP]` sequence, which
117        // can drop the trailing `[SEP]` and let the doc tail overflow
118        // into garbage on long inputs. With `LongestFirst` the
119        // tokenizer trims whichever of (query, doc) is longer until
120        // the joined sequence fits, preserving special tokens.
121        let max_tokens = backend.max_tokens();
122        tokenizer
123            .with_truncation(Some(TruncationParams {
124                max_length: max_tokens,
125                strategy: TruncationStrategy::LongestFirst,
126                stride: 0,
127                direction: TruncationDirection::Right,
128            }))
129            .map_err(|e| crate::Error::Other(anyhow!("rerank tokenizer truncation: {e}")))?;
130        Ok(Self { backend, tokenizer })
131    }
132
133    /// Score a batch of `(query, document)` pairs.
134    ///
135    /// Returns raw logits (sentence-transformers `Identity` activation —
136    /// the canonical public score for ms-marco cross-encoders), one
137    /// per input pair, in input order. Tokenizes with a `(query, doc)`
138    /// tuple so `token_type_ids` are 0 for the query side, 1 for the
139    /// doc side — the convention BERT cross-encoders are trained on.
140    /// The tokenizer is pre-configured with `LongestFirst` truncation
141    /// at the model's `max_position_embeddings`, so callers don't need
142    /// to clip outputs.
143    ///
144    /// # Errors
145    ///
146    /// Propagates tokenization or forward-pass errors.
147    pub fn score_pairs(&self, pairs: &[(&str, &str)]) -> crate::Result<Vec<f32>> {
148        if pairs.is_empty() {
149            return Ok(Vec::new());
150        }
151        let encodings: crate::Result<Vec<Encoding>> = pairs
152            .iter()
153            .map(|(q, d)| {
154                // The tokenizer is configured with LongestFirst
155                // truncation in from_pretrained; the returned encoding
156                // already fits within max_position_embeddings and
157                // preserves [CLS] / [SEP] tokens at the correct
158                // positions.
159                let enc = self
160                    .tokenizer
161                    .encode((*q, *d), true)
162                    .map_err(|e| crate::Error::Other(anyhow!("rerank tokenize failed: {e}")))?;
163                Ok(Encoding {
164                    input_ids: enc.get_ids().iter().map(|&x| i64::from(x)).collect(),
165                    attention_mask: enc
166                        .get_attention_mask()
167                        .iter()
168                        .map(|&x| i64::from(x))
169                        .collect(),
170                    token_type_ids: enc.get_type_ids().iter().map(|&x| i64::from(x)).collect(),
171                })
172            })
173            .collect();
174        let encodings = encodings?;
175        self.backend.score_batch(&encodings)
176    }
177
178    /// Max sequence length supported by the underlying model.
179    #[must_use]
180    pub fn max_tokens(&self) -> usize {
181        self.backend.max_tokens()
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// `Reranker::from_pretrained` works end-to-end on the default model.
    /// Gated `--ignored` since it downloads weights from `HuggingFace`.
    ///
    /// Verifies the two structural claims:
    /// 1. The cross-encoder ranks a relevant doc higher than an
    ///    irrelevant one for the same query.
    /// 2. Scores span a meaningful range (raw logits — the reference
    ///    spread for this model is roughly [-11, +5]).
    ///
    /// Both are covered by the single margin assertion below: a gap
    /// larger than 1.0 establishes the ordering, and is only possible
    /// if the scores are unbounded logits rather than values squashed
    /// into [0, 1].
    #[test]
    #[ignore = "requires network + model download (~22MB)"]
    fn loads_and_ranks_default_cross_encoder() {
        let rr = Reranker::from_pretrained(DEFAULT_RERANK_MODEL)
            .expect("default cross-encoder should load");
        let scores = rr
            .score_pairs(&[
                (
                    "how to make pasta",
                    "Boil water, add salt, cook pasta for 8 minutes.",
                ),
                (
                    "how to make pasta",
                    "The mitochondria is the powerhouse of the cell.",
                ),
            ])
            .expect("scoring should succeed");
        // One score per input pair, in input order.
        assert_eq!(scores.len(), 2);
        assert!(
            scores[0] > scores[1] + 1.0,
            "relevant doc ({}) should beat irrelevant ({}) by a clear logit margin",
            scores[0],
            scores[1]
        );
    }
}