Skip to main content

ripvec_core/
rerank.rs

1//! Cross-encoder reranker for top-K refinement.
2//!
3//! ## Why this module exists
4//!
5//! ripvec's bi-encoder retrieval (BERT or semble) embeds query and
6//! documents into a shared vector space and ranks by cosine. That's
7//! cheap to scale, but the model can't express cross-token
8//! interactions between query and document — each side is encoded
9//! independently. On natural-language and prose corpora this caps
10//! quality.
11//!
12//! A cross-encoder concatenates the pair `[CLS] query [SEP] doc [SEP]`
13//! and runs full attention across both, producing a single relevance
14//! score. Quality is meaningfully higher but cost is O(candidates),
15//! so it's used only as a reranker on the bi-encoder's top-K.
16//!
17//! ## Architecture
18//!
//! This module is a thin orchestrator: tokenize `(query, doc)` pairs,
//! delegate scoring to a [`RerankBackend`](crate::backend::RerankBackend)
//! (currently [`crate::backend::cpu::CpuRerankBackend`] — same BERT
//! trunk as the bi-encoder, plus a `Linear(hidden -> 1)` classifier
//! head). Scores are surfaced as raw logits, with no sigmoid applied;
//! see [`Reranker::score_pairs`].
24//!
25//! Adding GPU rerankers later is mechanical: implement
26//! `RerankBackend` for Metal/CUDA/MLX, mirror `load_reranker_cpu` in
27//! `backend/mod.rs`, route through `Reranker::from_pretrained`.
28
29use anyhow::anyhow;
30use tokenizers::{Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy};
31
32use crate::backend::{Encoding, RerankBackend};
33
/// `HuggingFace` repo ID of the default cross-encoder.
///
/// `cross-encoder/ms-marco-MiniLM-L-12-v2` weighs 33MB, costs ~10ms
/// per query/doc pair on CPU, and reaches NDCG@10 = 74.5 on MS MARCO
/// dev. The smaller L-6 variant (22MB, NDCG 74.3) was passed over:
/// on the 4-corpus benchmark matrix L-12 gave a meaningful target-hit
/// lift on both prose (Gutenberg) and code (Tokio), and its extra
/// ~5ms/pair is invisible against the indexing budget on any
/// non-trivial corpus.
pub const DEFAULT_RERANK_MODEL: &str = "cross-encoder/ms-marco-MiniLM-L-12-v2";
42
/// Default cap on candidates passed to the reranker.
///
/// Cost is linear in candidates; 100 is the standard top-K in the
/// retrieve-then-rerank literature. At ~10ms/pair on the default
/// MiniLM-L-12 model ([`DEFAULT_RERANK_MODEL`]) this is ~1s total;
/// the smaller MiniLM-L-6 (~5ms/pair) lands at ~500ms, the upper
/// edge of interactive.
pub const DEFAULT_RERANK_CANDIDATES: usize = 100;
49
50/// Cross-encoder reranker orchestrator.
51///
52/// Owns a `RerankBackend` (model trunk + classifier head) and the
53/// tokenizer that produced the encodings the backend expects.
54///
55/// Construct via [`Self::from_pretrained`]. Use [`score_pairs`] to
56/// rank candidate `(query, doc)` text pairs.
57///
58/// [`score_pairs`]: Self::score_pairs
59pub struct Reranker {
60    backend: Box<dyn RerankBackend>,
61    tokenizer: Tokenizer,
62}
63
64impl Reranker {
65    /// Load a cross-encoder by `HuggingFace` repo ID.
66    ///
67    /// Routes through [`crate::backend::load_reranker_cpu`] for now;
68    /// GPU paths slot in here as feature-gated branches when added.
69    /// The tokenizer is downloaded via the same `hf-hub` cache, so
70    /// multiple sub-agent MCP processes share weights through
71    /// `~/.cache/huggingface/hub/`.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the model can't be downloaded, lacks a
76    /// classifier head (i.e., a bi-encoder was supplied by mistake),
77    /// or fails to load.
78    pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
79        let backend = crate::backend::load_reranker_cpu(model_repo)?;
80        let mut tokenizer = crate::tokenize::load_tokenizer(model_repo)?;
81        // Configure `LongestFirst` truncation against the model's
82        // declared max sequence length. Without this the tokenizer
83        // returns full-length encodings and ripvec used to head-truncate
84        // the already-joined `[CLS] q [SEP] d [SEP]` sequence, which
85        // can drop the trailing `[SEP]` and let the doc tail overflow
86        // into garbage on long inputs. With `LongestFirst` the
87        // tokenizer trims whichever of (query, doc) is longer until
88        // the joined sequence fits, preserving special tokens.
89        let max_tokens = backend.max_tokens();
90        tokenizer
91            .with_truncation(Some(TruncationParams {
92                max_length: max_tokens,
93                strategy: TruncationStrategy::LongestFirst,
94                stride: 0,
95                direction: TruncationDirection::Right,
96            }))
97            .map_err(|e| crate::Error::Other(anyhow!("rerank tokenizer truncation: {e}")))?;
98        Ok(Self { backend, tokenizer })
99    }
100
101    /// Score a batch of `(query, document)` pairs.
102    ///
103    /// Returns raw logits (sentence-transformers `Identity` activation —
104    /// the canonical public score for ms-marco cross-encoders), one
105    /// per input pair, in input order. Tokenizes with a `(query, doc)`
106    /// tuple so `token_type_ids` are 0 for the query side, 1 for the
107    /// doc side — the convention BERT cross-encoders are trained on.
108    /// The tokenizer is pre-configured with `LongestFirst` truncation
109    /// at the model's `max_position_embeddings`, so callers don't need
110    /// to clip outputs.
111    ///
112    /// # Errors
113    ///
114    /// Propagates tokenization or forward-pass errors.
115    pub fn score_pairs(&self, pairs: &[(&str, &str)]) -> crate::Result<Vec<f32>> {
116        if pairs.is_empty() {
117            return Ok(Vec::new());
118        }
119        let encodings: crate::Result<Vec<Encoding>> = pairs
120            .iter()
121            .map(|(q, d)| {
122                // The tokenizer is configured with LongestFirst
123                // truncation in from_pretrained; the returned encoding
124                // already fits within max_position_embeddings and
125                // preserves [CLS] / [SEP] tokens at the correct
126                // positions.
127                let enc = self
128                    .tokenizer
129                    .encode((*q, *d), true)
130                    .map_err(|e| crate::Error::Other(anyhow!("rerank tokenize failed: {e}")))?;
131                Ok(Encoding {
132                    input_ids: enc.get_ids().iter().map(|&x| i64::from(x)).collect(),
133                    attention_mask: enc
134                        .get_attention_mask()
135                        .iter()
136                        .map(|&x| i64::from(x))
137                        .collect(),
138                    token_type_ids: enc.get_type_ids().iter().map(|&x| i64::from(x)).collect(),
139                })
140            })
141            .collect();
142        let encodings = encodings?;
143        self.backend.score_batch(&encodings)
144    }
145
146    /// Max sequence length supported by the underlying model.
147    #[must_use]
148    pub fn max_tokens(&self) -> usize {
149        self.backend.max_tokens()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// `Reranker::from_pretrained` works end-to-end on the default model.
    /// Gated `--ignored` since it downloads weights from `HuggingFace`.
    ///
    /// A single margin assertion covers two structural claims:
    /// 1. The cross-encoder ranks a relevant doc higher than an
    ///    irrelevant one for the same query.
    /// 2. Scores are raw logits spanning a meaningful range (the
    ///    reference spread for this model is roughly [-11, +5]):
    ///    a gap above 1.0 is impossible for sigmoid outputs, which
    ///    are confined to (0, 1).
    #[test]
    #[ignore = "requires network + model download (~33MB)"]
    fn loads_and_ranks_default_cross_encoder() {
        let rr = Reranker::from_pretrained(DEFAULT_RERANK_MODEL)
            .expect("default cross-encoder should load");
        let scores = rr
            .score_pairs(&[
                (
                    "how to make pasta",
                    "Boil water, add salt, cook pasta for 8 minutes.",
                ),
                (
                    "how to make pasta",
                    "The mitochondria is the powerhouse of the cell.",
                ),
            ])
            .expect("scoring should succeed");
        assert_eq!(scores.len(), 2);
        assert!(
            scores[0] > scores[1] + 1.0,
            "relevant doc ({}) should beat irrelevant ({}) by a clear logit margin",
            scores[0],
            scores[1]
        );
    }
}