Skip to main content

mnem_core/
rerank.rs

1//! Reranker trait: post-fusion rescoring by a model that jointly
2//! encodes `(query, candidate)` pairs.
3//!
4//! # Why
5//!
6//! The retrieve pipeline's base rankers do not read `(query, candidate)`
7//! jointly:
8//!
//! - The dense vector ranker is a bi-encoder. It embeds the whole query
//!   into one vector and each doc summary into another, then compares
//!   the two by cosine similarity. Each embedding captures phrase-level
//!   context on its own side, but query and doc are encoded
//!   independently and never read together, so the ranker yields only a
//!   single similarity score per doc.
14//! - The learned-sparse ranker scores via a sparse dot product over a
15//!   shared vocabulary. It too is a bi-encoder.
16//!
17//! For compositional paraphrase like "father's sister == aunt", you
18//! need a model that reads the query and a candidate side-by-side and
19//! scores their relevance as a pair. That is a cross-encoder.
20//!
21//! # What this module provides
22//!
23//! A [`Reranker`] trait that adapter crates implement. Industry
24//! cross-encoder providers (Cohere rerank, Voyage rerank, Jina rerank,
25//! local BGE-reranker ONNX) all fit this shape.
26//!
27//! Mnem-core stays tokio-free; adapter crates live next to
28//! [`mnem-embed-providers`](https://github.com/Uranid/mnem) and
29//! do the HTTP work. This file contains only the trait, the error
30//! type, and a deterministic mock for tests.
31//!
32//! # How it plugs in
33//!
34//! [`crate::retrieve::Retriever::with_reranker`] takes a
35//! `Arc<dyn Reranker>`. If set, the retriever re-scores the top-K of
36//! the fused list before budget packing. Failures fall back to the
37//! original fused order (same graceful-degrade policy as the embedder
38//! auto-fuse in the CLI).
39//!
40//!
41use std::fmt::Debug;
42
43use thiserror::Error;
44
/// Error surface for cross-encoder reranker adapters, returned by
/// [`Reranker::rerank`].
///
/// Marked `#[non_exhaustive]` so provider crates can grow their own
/// failure modes without a breaking change here. As a consequence,
/// `match`es on this enum outside the defining crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RerankError {
    /// TLS / TCP / DNS / timeout failure reaching the provider.
    #[error("network error: {0}")]
    Network(String),
    /// Provider rejected credentials.
    #[error("authentication failed: {0}")]
    Auth(String),
    /// Provider rate-limited the request.
    #[error("rate limited: {0}")]
    RateLimited(String),
    /// 4xx from the provider.
    ///
    /// NOTE(review): credential and rate-limit rejections are presumably
    /// reported via `Auth` / `RateLimited` rather than here; confirm
    /// against the adapter crates.
    #[error("bad request ({status}): {body}")]
    BadRequest {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// 5xx from the provider.
    #[error("server error ({status}): {body}")]
    Server {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// Response decoder failed (malformed JSON, missing score field, ...).
    #[error("decode error: {0}")]
    Decode(String),
    /// Adapter config invalid (bad URL, missing env var, etc.).
    #[error("config error: {0}")]
    Config(String),
    /// Model / tokenizer / ONNX session runtime failure, distinct from
    /// config-time validation. Mirrors [`crate::sparse::SparseError::Inference`]
    /// so sibling provider traits surface runtime failures with a
    /// consistent shape.
    #[error("inference error: {0}")]
    Inference(String),
    /// Provider returned a different number of scores than candidates.
    /// Implementations MUST reject this up front; the retriever would
    /// otherwise zip mismatched pairs.
    #[error("score count mismatch: expected {expected}, got {got}")]
    ScoreCountMismatch {
        /// Number of candidates sent.
        expected: usize,
        /// Number of scores returned.
        got: usize,
    },
}
100
/// Cross-encoder-style reranker: given a query and a list of
/// candidate texts, return one relevance score per candidate
/// (higher is better).
///
/// # Contract
///
/// - The returned `Vec<f32>` MUST have the SAME length and SAME order
///   as the input `candidates`: `scores[i]` is the score for
///   `candidates[i]`. An empty `candidates` slice therefore yields an
///   empty `Vec`. The retriever sorts by score and zips back to node
///   ids, so a provider that returns the wrong number of scores should
///   be reported as [`RerankError::ScoreCountMismatch`] instead.
/// - Score range is implementation-defined. Cohere returns logits,
///   Voyage returns [0, 1], and local ONNX depends on the head. Callers
///   who need to mix scores from different rerankers should normalise.
/// - Implementations handle internal batching if the provider has a
///   per-request cap; the caller passes the full candidate slice.
///
/// The `Send + Sync` supertraits allow a single instance to be shared
/// as an `Arc<dyn Reranker>`, which is how the retriever holds it.
/// `Debug` is also a supertrait, so every implementation must derive or
/// implement it.
pub trait Reranker: Send + Sync + Debug {
    /// Provider + model identifier. Lowercase, colon-separated by
    /// convention (e.g. `"cohere:rerank-v3.5"`,
    /// `"local:bge-reranker-v2-m3"`). Used for logging and cache keys.
    fn model(&self) -> &str;

    /// Re-score `candidates` against `query`, returning one score per
    /// candidate in input order (see the trait-level contract).
    ///
    /// # Errors
    ///
    /// Any [`RerankError`] the adapter surfaces. The retriever
    /// gracefully falls back to the fused order on error; it does
    /// not propagate the failure to the user.
    fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError>;
}
129
130/// Deterministic test-only reranker that scores candidates by their
131/// token-overlap Jaccard similarity to the query.
132///
133/// Useful for Retriever-integration tests where a real cross-encoder
134/// is not available; the score is meaningful (shared-tokens ratio)
135/// so rerank-changes-top-1 tests can rely on predictable behaviour.
136/// It is **not** a substitute for a real cross-encoder on
137/// compositional-paraphrase queries - Jaccard is a keyword metric,
138/// so "father's sister" will still not match "aunt."
139#[derive(Debug, Clone, Default)]
140pub struct MockJaccardReranker;
141
142impl Reranker for MockJaccardReranker {
143    fn model(&self) -> &str {
144        "mock:jaccard"
145    }
146
147    fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
148        let q = token_set(query);
149        Ok(candidates
150            .iter()
151            .map(|c| {
152                let c_tokens = token_set(c);
153                let inter = q.intersection(&c_tokens).count() as f32;
154                let union_ = q.union(&c_tokens).count() as f32;
155                if union_ == 0.0 { 0.0 } else { inter / union_ }
156            })
157            .collect())
158    }
159}
160
161/// Test-only reranker that always errors. Proves the graceful
162/// fallback path in [`crate::retrieve::Retriever::execute`].
163#[derive(Debug, Clone, Default)]
164pub struct AlwaysFailReranker;
165
166impl Reranker for AlwaysFailReranker {
167    fn model(&self) -> &str {
168        "mock:always-fail"
169    }
170
171    fn rerank(&self, _query: &str, _candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
172        Err(RerankError::Network(
173            "intentional failure for test".to_string(),
174        ))
175    }
176}
177
/// Splits `text` into a set of lowercase tokens for the mock reranker.
///
/// A token is a maximal run of alphanumeric characters, as decided by the
/// Unicode-aware [`char::is_alphanumeric`]; every other character acts
/// as a separator. Each character is lowercased individually with
/// [`char::to_lowercase`], which may expand one character into several.
/// Empty runs are discarded, so text without alphanumerics yields an
/// empty set. Kept private because the mock reranker is its only caller.
fn token_set(text: &str) -> std::collections::HashSet<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|run| !run.is_empty())
        .map(|run| run.chars().flat_map(char::to_lowercase).collect::<String>())
        .collect()
}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mock_jaccard_identical_query_and_candidate_scores_one() {
        let r = MockJaccardReranker;
        let s = r.rerank("alice bob", &["alice bob"]).unwrap();
        assert!((s[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn mock_jaccard_disjoint_scores_zero() {
        let r = MockJaccardReranker;
        let s = r.rerank("alice", &["zed"]).unwrap();
        assert_eq!(s[0], 0.0);
    }

    #[test]
    fn mock_jaccard_partial_overlap() {
        let r = MockJaccardReranker;
        // Query tokens: {alice, bob}. Candidate: {alice, carol}.
        // Inter = 1, Union = 3, score = 1/3.
        let s = r.rerank("alice bob", &["alice carol"]).unwrap();
        assert!((s[0] - (1.0 / 3.0)).abs() < 1e-6);
    }

    #[test]
    fn mock_jaccard_ignores_case_and_punctuation() {
        // The tokenizer lowercases and splits on non-alphanumerics, so
        // these two strings produce the same token set.
        let r = MockJaccardReranker;
        let s = r.rerank("Alice, BOB!", &["alice bob"]).unwrap();
        assert!((s[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn mock_jaccard_both_empty_scores_zero_not_nan() {
        // Neither side has tokens, so the union is empty; the guard must
        // return 0.0 instead of computing 0/0.
        let r = MockJaccardReranker;
        let s = r.rerank("", &["  !? "]).unwrap();
        assert_eq!(s, vec![0.0]);
    }

    #[test]
    fn mock_jaccard_length_matches_candidates_len() {
        let r = MockJaccardReranker;
        let s = r
            .rerank("alpha", &["alpha beta", "alpha gamma", "delta"])
            .unwrap();
        assert_eq!(s.len(), 3);
    }

    #[test]
    fn mock_jaccard_scores_follow_candidate_order() {
        // The Reranker contract requires scores[i] to belong to
        // candidates[i]; pin that alignment.
        let r = MockJaccardReranker;
        let s = r.rerank("alpha", &["delta", "alpha"]).unwrap();
        assert_eq!(s[0], 0.0);
        assert!((s[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn mock_jaccard_empty_candidates_empty_output() {
        let r = MockJaccardReranker;
        let s = r.rerank("alpha", &[]).unwrap();
        assert!(s.is_empty());
    }

    #[test]
    fn always_fail_reranker_returns_err() {
        // Pin the concrete variant, not just "some error".
        let r = AlwaysFailReranker;
        assert!(matches!(
            r.rerank("q", &["c"]),
            Err(RerankError::Network(_))
        ));
    }

    #[test]
    fn score_count_mismatch_display_names_both_counts() {
        let e = RerankError::ScoreCountMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(e.to_string(), "score count mismatch: expected 3, got 2");
    }

    #[test]
    fn model_id_contains_provider_prefix() {
        assert_eq!(MockJaccardReranker.model(), "mock:jaccard");
        assert_eq!(AlwaysFailReranker.model(), "mock:always-fail");
    }
}