reflex/embedding/reranker/mod.rs

1//! Cross-encoder reranker used for L3 verification.
2
3/// Reranker configuration.
4pub mod config;
5/// Reranker error types.
6pub mod error;
7
8#[cfg(test)]
9mod tests;
10
11pub use config::{DEFAULT_THRESHOLD, MAX_SEQ_LEN, RerankerConfig};
12pub use error::RerankerError;
13
14use crate::embedding::bert::BertClassifier;
15use candle_core::Tensor;
16use tokenizers::Tokenizer;
17use tracing::{debug, info};
18
19use crate::embedding::device::select_device;
20use crate::embedding::utils::load_tokenizer_with_truncation;
21
22/// Cross-encoder model used to rerank candidates (stub mode supported).
23pub struct Reranker {
24    device: candle_core::Device,
25    config: RerankerConfig,
26    model_loaded: bool,
27    model: Option<BertClassifier>,
28    tokenizer: Option<Tokenizer>,
29}
30
31impl std::fmt::Debug for Reranker {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("Reranker")
34            .field("device", &format!("{:?}", self.device))
35            .field("config", &self.config)
36            .field("model_loaded", &self.model_loaded)
37            .finish()
38    }
39}
40
41impl Reranker {
42    /// Loads a reranker from config (or creates a stub if no model is configured).
43    pub fn load(config: RerankerConfig) -> Result<Self, RerankerError> {
44        if let Err(msg) = config.validate() {
45            return Err(RerankerError::InvalidConfig { reason: msg });
46        }
47
48        let device = select_device()?;
49        debug!(?device, "Selected compute device for reranker");
50
51        if let Some(ref model_path) = config.model_path {
52            if !model_path.exists() {
53                return Err(RerankerError::ModelLoadFailed {
54                    reason: format!("Reranker model path not found: {}", model_path.display()),
55                });
56            }
57
58            let config_path = model_path.join("config.json");
59            if !config_path.exists() {
60                return Err(RerankerError::ModelLoadFailed {
61                    reason: format!("Missing config.json in {}", model_path.display()),
62                });
63            }
64
65            let weights_path = model_path.join("model.safetensors");
66            if !weights_path.exists() {
67                return Err(RerankerError::ModelLoadFailed {
68                    reason: format!("Missing model.safetensors in {}", model_path.display()),
69                });
70            }
71
72            info!(
73                model_path = %model_path.display(),
74                threshold = config.threshold,
75                "Loading reranker model"
76            );
77
78            let model = BertClassifier::load(model_path, &device).map_err(|e| {
79                RerankerError::ModelLoadFailed {
80                    reason: format!("Failed to load BERT model: {}", e),
81                }
82            })?;
83
84            let tokenizer =
85                load_tokenizer_with_truncation(model_path, MAX_SEQ_LEN).map_err(|e| {
86                    RerankerError::ModelLoadFailed {
87                        reason: format!("Failed to load tokenizer: {}", e),
88                    }
89                })?;
90
91            info!(
92                threshold = config.threshold,
93                "Reranker model loaded successfully"
94            );
95
96            Ok(Self {
97                device,
98                config,
99                model_loaded: true,
100                model: Some(model),
101                tokenizer: Some(tokenizer),
102            })
103        } else {
104            info!("No reranker model path configured, operating in stub mode");
105            Ok(Self::create_stub(device, config))
106        }
107    }
108
109    /// Creates a stub reranker.
110    pub fn stub() -> Result<Self, RerankerError> {
111        Self::load(RerankerConfig::stub())
112    }
113
114    fn create_stub(device: candle_core::Device, config: RerankerConfig) -> Self {
115        Self {
116            device,
117            config,
118            model_loaded: false,
119            model: None,
120            tokenizer: None,
121        }
122    }
123
124    /// Scores a single query/candidate pair.
125    pub fn score(&self, query: &str, candidate: &str) -> Result<f32, RerankerError> {
126        debug!(
127            query_len = query.len(),
128            candidate_len = candidate.len(),
129            model_loaded = self.model_loaded,
130            "Scoring query-candidate pair"
131        );
132
133        if let (Some(model), Some(tokenizer)) = (&self.model, &self.tokenizer) {
134            let tokens = tokenizer.encode((query, candidate), true).map_err(|e| {
135                RerankerError::TokenizationFailed {
136                    reason: e.to_string(),
137                }
138            })?;
139
140            let token_ids = tokens.get_ids();
141            let token_ids = Tensor::new(token_ids, &self.device)
142                .map_err(RerankerError::from)?
143                .unsqueeze(0)
144                .map_err(RerankerError::from)?;
145
146            let type_ids = tokens.get_type_ids();
147            let type_ids = Tensor::new(type_ids, &self.device)
148                .map_err(RerankerError::from)?
149                .unsqueeze(0)
150                .map_err(RerankerError::from)?;
151
152            // Use the tokenizer's attention mask to properly handle padding tokens.
153            // Previously used ones_like() which is incorrect when padding is present.
154            let attention_mask_data = tokens.get_attention_mask();
155            let attention_mask = Tensor::new(attention_mask_data, &self.device)
156                .map_err(RerankerError::from)?
157                .unsqueeze(0)
158                .map_err(RerankerError::from)?;
159
160            let logits = model
161                .forward(&token_ids, &type_ids, Some(&attention_mask))
162                .map_err(|e| RerankerError::InferenceFailed {
163                    reason: e.to_string(),
164                })?;
165
166            let score = logits
167                .flatten_all()
168                .map_err(RerankerError::from)?
169                .to_vec1::<f32>()
170                .map_err(RerankerError::from)?[0];
171            return Ok(score);
172        }
173
174        let score = self.compute_placeholder_score(query, candidate);
175
176        debug!(score = score, "Computed score (stub)");
177
178        Ok(score)
179    }
180
181    /// Scores and sorts candidates best-first.
182    pub fn rerank(
183        &self,
184        query: &str,
185        candidates: &[&str],
186    ) -> Result<Vec<(usize, f32)>, RerankerError> {
187        debug!(
188            query_len = query.len(),
189            num_candidates = candidates.len(),
190            "Reranking candidates"
191        );
192
193        let mut scored: Vec<(usize, f32)> = candidates
194            .iter()
195            .enumerate()
196            .map(|(idx, candidate)| {
197                let score = self.score(query, candidate)?;
198                Ok((idx, score))
199            })
200            .collect::<Result<Vec<_>, RerankerError>>()?;
201
202        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
203
204        debug!(
205            top_score = scored.first().map(|(_, s)| *s),
206            "Reranking complete"
207        );
208
209        Ok(scored)
210    }
211
212    /// Reranks and filters to scores above the configured threshold.
213    pub fn rerank_with_threshold(
214        &self,
215        query: &str,
216        candidates: &[&str],
217    ) -> Result<Vec<(usize, f32)>, RerankerError> {
218        let ranked = self.rerank(query, candidates)?;
219        let threshold = self.config.threshold;
220
221        let filtered: Vec<_> = ranked
222            .into_iter()
223            .filter(|(_, score)| *score > threshold)
224            .collect();
225
226        debug!(
227            threshold = threshold,
228            hits = filtered.len(),
229            total = candidates.len(),
230            "Filtered by threshold"
231        );
232
233        Ok(filtered)
234    }
235
236    /// Returns `true` if a model was loaded (vs stub mode).
237    pub fn is_model_loaded(&self) -> bool {
238        self.model_loaded
239    }
240
241    /// Returns the configured verification threshold.
242    pub fn threshold(&self) -> f32 {
243        self.config.threshold
244    }
245
246    /// Returns the active configuration.
247    pub fn config(&self) -> &RerankerConfig {
248        &self.config
249    }
250
251    /// Returns the compute device.
252    pub fn device(&self) -> &candle_core::Device {
253        &self.device
254    }
255
256    /// Returns `true` if `score` exceeds the configured threshold.
257    pub fn is_hit(&self, score: f32) -> bool {
258        score > self.config.threshold
259    }
260
261    fn compute_placeholder_score(&self, query: &str, candidate: &str) -> f32 {
262        use std::collections::HashSet;
263
264        let stop_words: HashSet<&str> = [
265            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
266            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
267            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
268            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
269            "below", "between", "under", "again", "further", "then", "once", "here", "there",
270            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
271            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
272            "and", "but", "if", "or", "because", "until", "while", "what", "which", "who", "whom",
273            "this", "that", "these", "those", "am", "it", "its",
274        ]
275        .into_iter()
276        .collect();
277
278        let query_lower = query.to_lowercase();
279        let query_words: HashSet<&str> = query_lower
280            .split(|c: char| !c.is_alphanumeric())
281            .filter(|w| !w.is_empty() && !stop_words.contains(w))
282            .collect();
283
284        let candidate_lower = candidate.to_lowercase();
285        let candidate_words: HashSet<&str> = candidate_lower
286            .split(|c: char| !c.is_alphanumeric())
287            .filter(|w| !w.is_empty() && !stop_words.contains(w))
288            .collect();
289
290        if query_words.is_empty() {
291            let len_ratio = (query.len().min(candidate.len()) as f32)
292                / (query.len().max(candidate.len()).max(1) as f32);
293            return len_ratio * 0.3;
294        }
295
296        let matches = query_words.intersection(&candidate_words).count();
297        let recall = matches as f32 / query_words.len() as f32;
298
299        let union = query_words.union(&candidate_words).count();
300        let jaccard = if union > 0 {
301            matches as f32 / union as f32
302        } else {
303            0.0
304        };
305
306        let base_score = 0.6 * recall + 0.4 * jaccard;
307
308        let normalized = 1.0 / (1.0 + (-8.0 * (base_score - 0.5)).exp());
309
310        normalized.clamp(0.0, 1.0)
311    }
312}