mnem_core/rerank.rs
1//! Reranker trait: post-fusion rescoring by a model that jointly
2//! encodes `(query, candidate)` pairs.
3//!
4//! # Why
5//!
6//! The retrieve pipeline's base rankers do not read `(query, candidate)`
7//! jointly:
8//!
9//! - The dense vector ranker is a bi-encoder. It embeds the whole query
10//! into one vector and each doc summary into one vector, then compares
11//! cosine similarity. It sees phrases but produces only one score per
12//! doc; the embeddings are encoded independently and never read
13//! together.
14//! - The learned-sparse ranker scores via a sparse dot product over a
15//! shared vocabulary. It too is a bi-encoder.
16//!
17//! For compositional paraphrase like "father's sister == aunt", you
18//! need a model that reads the query and a candidate side-by-side and
19//! scores their relevance as a pair. That is a cross-encoder.
20//!
21//! # What this module provides
22//!
23//! A [`Reranker`] trait that adapter crates implement. Industry
24//! cross-encoder providers (Cohere rerank, Voyage rerank, Jina rerank,
25//! local BGE-reranker ONNX) all fit this shape.
26//!
27//! Mnem-core stays tokio-free; adapter crates live next to
28//! [`mnem-embed-providers`](https://github.com/Uranid/mnem) and
29//! do the HTTP work. This file contains only the trait, the error
30//! type, and a deterministic mock for tests.
31//!
32//! # How it plugs in
33//!
//! [`crate::retrieve::Retriever::with_reranker`] takes an
//! `Arc<dyn Reranker>`. If set, the retriever re-scores the top-K of
36//! the fused list before budget packing. Failures fall back to the
37//! original fused order (same graceful-degrade policy as the embedder
38//! auto-fuse in the CLI).
39//!
40//!
41use std::fmt::Debug;
42
43use thiserror::Error;
44
45/// Error surface for cross-encoder reranker adapters.
46///
47/// Marked `#[non_exhaustive]` so provider crates can grow their own
48/// failure modes without a breaking change here.
49#[derive(Debug, Error)]
50#[non_exhaustive]
51pub enum RerankError {
52 /// TLS / TCP / DNS / timeout failure reaching the provider.
53 #[error("network error: {0}")]
54 Network(String),
55 /// Provider rejected credentials.
56 #[error("authentication failed: {0}")]
57 Auth(String),
58 /// Provider rate-limited the request.
59 #[error("rate limited: {0}")]
60 RateLimited(String),
61 /// 4xx from the provider.
62 #[error("bad request ({status}): {body}")]
63 BadRequest {
64 /// HTTP status code.
65 status: u16,
66 /// Response body or best-effort error string.
67 body: String,
68 },
69 /// 5xx from the provider.
70 #[error("server error ({status}): {body}")]
71 Server {
72 /// HTTP status code.
73 status: u16,
74 /// Response body or best-effort error string.
75 body: String,
76 },
77 /// Response decoder failed (malformed JSON, missing score field, ...).
78 #[error("decode error: {0}")]
79 Decode(String),
80 /// Adapter config invalid (bad URL, missing env var, etc.).
81 #[error("config error: {0}")]
82 Config(String),
83 /// Model / tokenizer / ONNX session runtime failure, distinct from
84 /// config-time validation. Mirrors [`crate::sparse::SparseError::Inference`]
85 /// so sibling provider traits surface runtime failures with a
86 /// consistent shape.
87 #[error("inference error: {0}")]
88 Inference(String),
89 /// Provider returned a different number of scores than candidates.
90 /// Implementations MUST reject this up front; the retriever would
91 /// otherwise zip mismatched pairs.
92 #[error("score count mismatch: expected {expected}, got {got}")]
93 ScoreCountMismatch {
94 /// Number of candidates sent.
95 expected: usize,
96 /// Number of scores returned.
97 got: usize,
98 },
99}
100
101/// Cross-encoder-style reranker: given a query and a list of
102/// candidate texts, return one relevance score per candidate
103/// (higher is better).
104///
105/// The returned `Vec<f32>` MUST be in the SAME order and same length
106/// as the input `candidates`; the retriever sorts by score and zips
107/// back to node ids. Score range is implementation-defined (Cohere
108/// returns logits; Voyage returns [0, 1]; local ONNX depends on the
109/// head). Callers who need to mix scores from different rerankers
110/// should normalise.
111///
112/// Implementations handle internal batching if the provider has a
113/// per-request cap; the caller passes the full candidate slice.
114pub trait Reranker: Send + Sync + Debug {
115 /// Provider + model identifier. Lowercase, colon-separated by
116 /// convention (e.g. `"cohere:rerank-v3.5"`,
117 /// `"local:bge-reranker-v2-m3"`). Used for logging and cache keys.
118 fn model(&self) -> &str;
119
120 /// Re-score `candidates` against `query`.
121 ///
122 /// # Errors
123 ///
124 /// Any [`RerankError`] the adapter surfaces. The retriever
125 /// gracefully falls back to the fused order on error; it does
126 /// not propagate the failure to the user.
127 fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError>;
128}
129
130/// Deterministic test-only reranker that scores candidates by their
131/// token-overlap Jaccard similarity to the query.
132///
133/// Useful for Retriever-integration tests where a real cross-encoder
134/// is not available; the score is meaningful (shared-tokens ratio)
135/// so rerank-changes-top-1 tests can rely on predictable behaviour.
136/// It is **not** a substitute for a real cross-encoder on
137/// compositional-paraphrase queries - Jaccard is a keyword metric,
138/// so "father's sister" will still not match "aunt."
139#[derive(Debug, Clone, Default)]
140pub struct MockJaccardReranker;
141
142impl Reranker for MockJaccardReranker {
143 fn model(&self) -> &str {
144 "mock:jaccard"
145 }
146
147 fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
148 let q = token_set(query);
149 Ok(candidates
150 .iter()
151 .map(|c| {
152 let c_tokens = token_set(c);
153 let inter = q.intersection(&c_tokens).count() as f32;
154 let union_ = q.union(&c_tokens).count() as f32;
155 if union_ == 0.0 { 0.0 } else { inter / union_ }
156 })
157 .collect())
158 }
159}
160
161/// Test-only reranker that always errors. Proves the graceful
162/// fallback path in [`crate::retrieve::Retriever::execute`].
163#[derive(Debug, Clone, Default)]
164pub struct AlwaysFailReranker;
165
166impl Reranker for AlwaysFailReranker {
167 fn model(&self) -> &str {
168 "mock:always-fail"
169 }
170
171 fn rerank(&self, _query: &str, _candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
172 Err(RerankError::Network(
173 "intentional failure for test".to_string(),
174 ))
175 }
176}
177
/// Tokenizer behind the deterministic mock reranker; private because
/// the mock is its only caller.
///
/// Splits `text` on every non-alphanumeric character (Unicode-aware),
/// discards the empty pieces produced by runs of separators, lowercases
/// what remains, and returns the distinct tokens.
fn token_set(text: &str) -> std::collections::HashSet<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|piece| !piece.is_empty())
        // Lowercase one char at a time: `str::to_lowercase` applies a
        // context-sensitive rule (word-final sigma) that a per-char
        // mapping does not.
        .map(|piece| {
            piece
                .chars()
                .flat_map(char::to_lowercase)
                .collect::<String>()
        })
        .collect()
}
199
#[cfg(test)]
mod tests {
    use super::*;

    // Identical token sets: intersection equals union, so the score is 1.
    #[test]
    fn mock_jaccard_identical_query_and_candidate_scores_one() {
        let scores = MockJaccardReranker
            .rerank("alice bob", &["alice bob"])
            .unwrap();
        assert!((scores[0] - 1.0).abs() < 1e-6);
    }

    // No shared token: the intersection is empty, so the score is 0.
    #[test]
    fn mock_jaccard_disjoint_scores_zero() {
        let scores = MockJaccardReranker.rerank("alice", &["zed"]).unwrap();
        assert_eq!(scores[0], 0.0);
    }

    // Query tokens {alice, bob} vs candidate tokens {alice, carol}:
    // one shared token out of three distinct ones gives 1/3.
    #[test]
    fn mock_jaccard_partial_overlap() {
        let scores = MockJaccardReranker
            .rerank("alice bob", &["alice carol"])
            .unwrap();
        let expected = 1.0 / 3.0;
        assert!((scores[0] - expected).abs() < 1e-6);
    }

    // One score per candidate, as the trait contract requires.
    #[test]
    fn mock_jaccard_length_matches_candidates_len() {
        let candidates = ["alpha beta", "alpha gamma", "delta"];
        let scores = MockJaccardReranker.rerank("alpha", &candidates).unwrap();
        assert_eq!(scores.len(), 3);
    }

    // No candidates in, no scores out.
    #[test]
    fn mock_jaccard_empty_candidates_empty_output() {
        let scores = MockJaccardReranker.rerank("alpha", &[]).unwrap();
        assert!(scores.is_empty());
    }

    // The failing mock must report an error for any input.
    #[test]
    fn always_fail_reranker_returns_err() {
        let outcome = AlwaysFailReranker.rerank("q", &["c"]);
        assert!(outcome.is_err());
    }

    // Both mocks expose their `provider:model` identifiers.
    #[test]
    fn model_id_contains_provider_prefix() {
        assert_eq!(MockJaccardReranker.model(), "mock:jaccard");
        assert_eq!(AlwaysFailReranker.model(), "mock:always-fail");
    }
}