// inference/reranker.rs

//! Cross-encoder reranker for improving recall precision.
//!
//! Uses BAAI/bge-reranker-base (Xenova ONNX INT8 quantized export) to score
//! (query, passage) pairs for relevance. It is more accurate than bi-encoder
//! vector similarity but slower, so it is used as a second-stage reranker
//! after ANN candidate retrieval.
//!
//! # Architecture
//!
//! ```text
//! query + passage → [CLS] query [SEP] passage [SEP]
//!                       ↓ BERT forward pass
//!                   logits [batch, 1]
//!                       ↓ sigmoid
//!                   relevance scores ∈ [0, 1]
//! ```
//!
//! # Session Pool
//!
//! The engine maintains `RERANKER_POOL_SIZE` independent ONNX sessions.
//! Large candidate lists are split into chunks of `RERANKER_CHUNK_SIZE` and
//! dispatched in parallel across the pool, reducing head-of-line blocking
//! when multiple recall calls arrive concurrently (DAK-5873).
24
25use crate::engine::EmbeddingEngine;
26use crate::error::{InferenceError, Result};
27use ort::inputs;
28use ort::session::builder::GraphOptimizationLevel;
29use ort::session::Session;
30use ort::value::Tensor;
31use parking_lot::Mutex;
32use std::path::PathBuf;
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use tokenizers::{
36    EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
37};
38use tracing::{info, instrument, warn};
39
/// Hugging Face Hub repository that hosts the Xenova ONNX (INT8) export of the reranker.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Location of the quantized ONNX graph, relative to the repository root.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Token budget for a single cross-encoder input; the query and passage
/// together are truncated to fit within it.
const MAX_SEQ_LENGTH: usize = 512;
/// Number of independent ONNX sessions kept by the engine.
///
/// With two sessions, concurrent recall requests can rerank side by side
/// instead of queueing on one mutex. Every session runs with
/// `intra_threads=4`, so the pair saturates the 8 vCPUs of the production
/// CPX32 server (DAK-5873).
const RERANKER_POOL_SIZE: usize = 2;
/// Largest number of candidates handed to one session call.
///
/// Long candidate lists (for example the temporal path's
/// `fetch_n = top_k × 8 = 160`) are cut into chunks of this size and spread
/// concurrently over the pool. For 160 candidates with
/// `RERANKER_POOL_SIZE=2` and `RERANKER_CHUNK_SIZE=32` that is 5 chunks:
/// sessions[0] takes chunks 0/2/4 and sessions[1] takes chunks 1/3, so the
/// busier session runs 3 chunks back to back instead of all 5.
const RERANKER_CHUNK_SIZE: usize = 32;
60
61/// Cross-encoder reranking engine.
62///
63/// Thread-safe — shared via `Arc`. Maintains a pool of independent ONNX sessions
64/// so concurrent rerank calls never contend on a single mutex.
65pub struct CrossEncoderEngine {
66    /// Pool of independent ONNX sessions (round-robin dispatch).
67    sessions: Vec<Arc<Mutex<Session>>>,
68    tokenizer: Arc<Tokenizer>,
69    /// Whether the loaded ONNX model expects a `token_type_ids` input tensor.
70    /// bge-reranker-base only has `input_ids` + `attention_mask`; some other
71    /// cross-encoders include `token_type_ids`. Determined at load time.
72    has_token_type_ids: bool,
73    /// Round-robin counter for session assignment.
74    next_session: AtomicUsize,
75}
76
77impl CrossEncoderEngine {
78    /// Load or download the reranker model.
79    ///
80    /// Downloads `Xenova/bge-reranker-base` ONNX INT8 model from HuggingFace Hub
81    /// if not already cached. Builds `RERANKER_POOL_SIZE` independent sessions.
82    #[instrument(skip_all)]
83    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
84        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
85
86        let (tokenizer_path, onnx_path) =
87            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
88                .await
89                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
90                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
91
92        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
93        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
94            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
95
96        // Configure padding + truncation for uniform batch shapes
97        let padding = PaddingParams {
98            strategy: PaddingStrategy::BatchLongest,
99            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
100            pad_token: tokenizer
101                .get_padding()
102                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
103            ..Default::default()
104        };
105        tokenizer.with_padding(Some(padding));
106        let truncation = TruncationParams {
107            max_length: MAX_SEQ_LENGTH,
108            ..Default::default()
109        };
110        let _ = tokenizer.with_truncation(Some(truncation));
111
112        info!(
113            "Loading reranker ONNX model from {:?} (pool_size={})",
114            onnx_path, RERANKER_POOL_SIZE
115        );
116
117        // Build pool of independent ONNX sessions — each has its own ORT context
118        // so pool members never block each other under concurrent rerank calls.
119        let (sessions, has_token_type_ids) =
120            tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
121                let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
122                    .map(|_| {
123                        Session::builder()
124                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
125                            .with_optimization_level(GraphOptimizationLevel::Level3)
126                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
127                            .with_intra_threads(4)
128                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
129                            .commit_from_file(&onnx_path)
130                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
131                    })
132                    .collect();
133                let raw = raw?;
134                // Inspect first session to detect optional token_type_ids input.
135                let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
136                let sessions: Vec<Arc<Mutex<Session>>> =
137                    raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
138                Ok((sessions, has_tti))
139            })
140            .await
141            .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
142            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
143
144        info!(
145            has_token_type_ids,
146            pool_size = sessions.len(),
147            "Cross-encoder reranker loaded successfully"
148        );
149
150        Ok(Self {
151            sessions,
152            tokenizer: Arc::new(tokenizer),
153            has_token_type_ids,
154            next_session: AtomicUsize::new(0),
155        })
156    }
157
158    /// Score a batch of (query, passage) pairs.
159    ///
160    /// Passages are split into chunks of [`RERANKER_CHUNK_SIZE`] and dispatched
161    /// in parallel across the session pool (round-robin). Chunk results are
162    /// reassembled in input order.
163    ///
164    /// Returns a relevance score in `[0, 1]` for each passage.
165    /// Higher scores indicate greater relevance to the query.
166    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
167    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
168        if passages.is_empty() {
169            return Ok(Vec::new());
170        }
171
172        let pool_len = self.sessions.len();
173        // Round-robin start: each concurrent caller gets a different initial slot
174        // so concurrent requests don't all contend on sessions[0].
175        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
176        let tokenizer = Arc::clone(&self.tokenizer);
177        let has_token_type_ids = self.has_token_type_ids;
178        let query_str = query.to_string();
179
180        // Split candidates into RERANKER_CHUNK_SIZE sub-batches.
181        let chunks: Vec<Vec<String>> = passages
182            .chunks(RERANKER_CHUNK_SIZE)
183            .map(<[String]>::to_vec)
184            .collect();
185
186        // Spawn all chunks concurrently; each acquires its own session slot.
187        let mut handles = Vec::with_capacity(chunks.len());
188        for (i, chunk) in chunks.into_iter().enumerate() {
189            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
190            let tok = Arc::clone(&tokenizer);
191            let q = query_str.clone();
192            handles.push(tokio::task::spawn_blocking(move || {
193                score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
194            }));
195        }
196
197        // Collect results in chunk order to preserve passage ordering.
198        let mut scores = Vec::with_capacity(passages.len());
199        for handle in handles {
200            let chunk_scores = handle
201                .await
202                .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
203            scores.extend(chunk_scores);
204        }
205
206        Ok(scores)
207    }
208
209    /// Number of parallel ONNX sessions in the pool.
210    pub fn pool_size(&self) -> usize {
211        self.sessions.len()
212    }
213}
214
215/// Blocking cross-encoder inference for one sub-batch — runs inside `spawn_blocking`.
216fn score_pairs_blocking(
217    session: &Arc<Mutex<Session>>,
218    tokenizer: &Tokenizer,
219    query: &str,
220    passages: &[String],
221    has_token_type_ids: bool,
222) -> Result<Vec<f32>> {
223    let batch_size = passages.len();
224
225    // Build EncodeInput pairs: [CLS] query [SEP] passage [SEP]
226    let inputs: Vec<EncodeInput> = passages
227        .iter()
228        .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
229        .collect();
230
231    let encodings = tokenizer
232        .encode_batch(inputs, true)
233        .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
234
235    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
236    if seq_len == 0 {
237        return Ok(vec![0.5; batch_size]);
238    }
239
240    // Flatten to i64 arrays (ORT BERT models expect int64)
241    let mut input_ids = Vec::with_capacity(batch_size * seq_len);
242    let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
243    let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
244
245    for enc in &encodings {
246        input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
247        attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
248        let type_ids = enc.get_type_ids();
249        if type_ids.is_empty() {
250            token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
251        } else {
252            token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
253        }
254    }
255
256    // Build ORT tensors
257    let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
258        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
259    let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
260        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
261    let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
262        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
263
264    // Run inference and extract scores in one scoped block so `sess` and `outputs`
265    // are dropped before we return (avoids session borrow escaping the mutex guard).
266    let scores: Vec<f32> = {
267        let mut sess = session.lock();
268        let outputs = if has_token_type_ids {
269            sess.run(inputs![
270                "input_ids" => input_ids_tensor,
271                "attention_mask" => attention_mask_tensor,
272                "token_type_ids" => token_type_ids_tensor
273            ])
274            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
275        } else {
276            sess.run(inputs![
277                "input_ids" => input_ids_tensor,
278                "attention_mask" => attention_mask_tensor
279            ])
280            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
281        };
282
283        // Extract logits — bge-reranker-base output shape is [batch_size, 1]
284        let (out_shape, logits_slice) = outputs[0]
285            .try_extract_tensor::<f32>()
286            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
287
288        if out_shape.is_empty() || out_shape[0] as usize != batch_size {
289            warn!(
290                "Reranker output shape mismatch: expected [{}, 1], got {:?}",
291                batch_size, out_shape
292            );
293        }
294
295        // Apply sigmoid → owned Vec<f32> so the borrow on outputs/sess ends here
296        logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
297        // outputs and sess drop here in the correct order
298    };
299
300    if scores.len() != batch_size {
301        warn!(
302            "Reranker score count mismatch: expected {}, got {}",
303            batch_size,
304            scores.len()
305        );
306        let mut padded = scores;
307        padded.resize(batch_size, 0.5);
308        return Ok(padded);
309    }
310
311    Ok(scores)
312}
313
/// Logistic sigmoid `σ(x) = 1 / (1 + e^(-x))`, mapping a raw logit onto `[0, 1]`.
///
/// Large-magnitude inputs saturate to exactly `0.0` or `1.0`, and a NaN input
/// propagates as NaN.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
319
320/// Download tokenizer and ONNX model files for the reranker.
321/// Reuses `EmbeddingEngine::download_hf_file_pub` for redirect-aware caching.
322fn download_reranker_files(
323    cache_dir: Option<String>,
324) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
325    let cache = match cache_dir {
326        Some(dir) => {
327            let p = PathBuf::from(dir);
328            std::fs::create_dir_all(&p)
329                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
330            p
331        }
332        None => {
333            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
334            PathBuf::from(home)
335                .join(".cache")
336                .join("huggingface")
337                .join("dakera")
338                .join(RERANKER_REPO_ID.replace('/', "--"))
339        }
340    };
341
342    std::fs::create_dir_all(&cache)
343        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
344
345    let files = [
346        "tokenizer.json",
347        "tokenizer_config.json",
348        "special_tokens_map.json",
349        RERANKER_ONNX_FILE,
350    ];
351
352    for filename in &files {
353        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
354            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
355    }
356
357    let tokenizer_path = cache.join("tokenizer.json");
358    let onnx_path = cache.join(RERANKER_ONNX_FILE);
359    Ok((tokenizer_path, onnx_path))
360}
361
362impl std::fmt::Debug for CrossEncoderEngine {
363    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        f.debug_struct("CrossEncoderEngine")
365            .field("model", &RERANKER_REPO_ID)
366            .field("pool_size", &self.sessions.len())
367            .finish()
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` distinct passages, `passage 0` … `passage {n-1}`.
    fn make_passages(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("passage {i}")).collect()
    }

    /// Split passages into `RERANKER_CHUNK_SIZE` sub-batches (same call as `score_pairs`).
    fn split(passages: &[String]) -> Vec<Vec<String>> {
        passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect()
    }

    /// Round-robin session slot for chunk `chunk` when the call's counter read
    /// `start`. Reducing `start` first keeps the sum from overflowing.
    fn slot(start: usize, chunk: usize, pool_len: usize) -> usize {
        (start % pool_len + chunk) % pool_len
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_sigmoid_symmetry() {
        // σ(x) + σ(-x) = 1 for every finite x.
        for x in [-3.0f32, -0.5, 0.25, 4.0] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_chunk_count_exact() {
        // 64 passages / chunk_size=32 → exactly 2 chunks
        let chunks = split(&make_passages(64));
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 32);
    }

    #[test]
    fn test_chunk_count_remainder() {
        // 50 passages / chunk_size=32 → 2 chunks (32 + 18)
        let chunks = split(&make_passages(50));
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 18);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        // 10 passages → single chunk, no splitting overhead
        let chunks = split(&make_passages(10));
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10);
    }

    #[test]
    fn test_chunk_order_preserved() {
        // Chunk splitting must preserve passage order for score reassembly.
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = split(&passages).into_iter().flatten().collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_round_robin_distribution() {
        // The temporal path's 160 candidates become 5 chunks; spreading them
        // consecutively over the pool leaves the busiest session with
        // ceil(chunks / pool) of them — 3 of 5 for a pool of 2.
        let pool_len = RERANKER_POOL_SIZE;
        let n_chunks = split(&make_passages(160)).len();
        assert_eq!(n_chunks, 5);

        // Include counter values at the wrap-around edge.
        for start in [0usize, 1, 7, usize::MAX - 1, usize::MAX] {
            let mut per_session = vec![0usize; pool_len];
            for chunk in 0..n_chunks {
                let s = slot(start, chunk, pool_len);
                assert!(s < pool_len);
                per_session[s] += 1;
            }
            assert_eq!(per_session.iter().sum::<usize>(), n_chunks);
            assert_eq!(
                per_session.iter().copied().max(),
                Some(n_chunks.div_ceil(pool_len))
            );
        }
    }
}