// inference/reranker.rs

//! Cross-encoder reranker for improving the precision of recalled results.
//!
//! Uses BAAI/bge-reranker-base (Xenova ONNX export, INT8 quantized) to score
//! (query, passage) pairs for relevance. It is more accurate than bi-encoder
//! vector similarity but slower, so it is used as a second-stage reranker
//! after ANN candidate retrieval.
//!
//! # Architecture
//!
//! ```text
//! query + passage → [CLS] query [SEP] passage [SEP]
//!                       ↓ BERT forward pass
//!                   logits [batch, 1]
//!                       ↓ sigmoid
//!                   relevance scores ∈ [0, 1]
//! ```
//!
//! The special tokens shown above are schematic; the actual ones are inserted
//! by the model's tokenizer.

18use crate::engine::EmbeddingEngine;
19use crate::error::{InferenceError, Result};
20use ort::inputs;
21use ort::session::builder::GraphOptimizationLevel;
22use ort::session::Session;
23use ort::value::Tensor;
24use parking_lot::Mutex;
25use std::path::PathBuf;
26use std::sync::Arc;
27use tokenizers::{
28    EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
29};
30use tracing::{info, instrument, warn};
31
/// HuggingFace Hub repository holding the Xenova ONNX (INT8) export of the reranker.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Location of the quantized ONNX model file, relative to the repository root.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Upper bound on the token length of one cross-encoder input, i.e. the query
/// and the passage combined; longer pairs are truncated by the tokenizer.
const MAX_SEQ_LENGTH: usize = 512;
38
39/// Cross-encoder reranking engine.
40///
41/// Thread-safe — can be wrapped in `Arc` and shared across tasks.
42pub struct CrossEncoderEngine {
43    session: Arc<Mutex<Session>>,
44    tokenizer: Arc<Tokenizer>,
45    /// Whether the loaded ONNX model expects a `token_type_ids` input tensor.
46    /// bge-reranker-base only has `input_ids` + `attention_mask`; some other
47    /// cross-encoders include `token_type_ids`. Determined at load time.
48    has_token_type_ids: bool,
49}
50
51impl CrossEncoderEngine {
52    /// Load or download the reranker model.
53    ///
54    /// Downloads `Xenova/bge-reranker-base` ONNX INT8 model from HuggingFace Hub
55    /// if not already cached.
56    #[instrument(skip_all)]
57    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
58        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
59
60        let (tokenizer_path, onnx_path) =
61            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
62                .await
63                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
64                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
65
66        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
67        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
68            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
69
70        // Configure padding + truncation for uniform batch shapes
71        let padding = PaddingParams {
72            strategy: PaddingStrategy::BatchLongest,
73            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
74            pad_token: tokenizer
75                .get_padding()
76                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
77            ..Default::default()
78        };
79        tokenizer.with_padding(Some(padding));
80        let truncation = TruncationParams {
81            max_length: MAX_SEQ_LENGTH,
82            ..Default::default()
83        };
84        let _ = tokenizer.with_truncation(Some(truncation));
85
86        info!("Loading reranker ONNX model from {:?}", onnx_path);
87        let session = Session::builder()
88            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
89            .with_optimization_level(GraphOptimizationLevel::Level3)
90            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
91            .with_intra_threads(4)
92            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
93            .commit_from_file(&onnx_path)
94            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
95
96        // Inspect model inputs to determine if token_type_ids is required.
97        // bge-reranker-base (Xenova ONNX) only has input_ids + attention_mask.
98        let has_token_type_ids = session
99            .inputs()
100            .iter()
101            .any(|i| i.name() == "token_type_ids");
102        info!(
103            has_token_type_ids,
104            "Cross-encoder reranker loaded successfully"
105        );
106        Ok(Self {
107            session: Arc::new(Mutex::new(session)),
108            tokenizer: Arc::new(tokenizer),
109            has_token_type_ids,
110        })
111    }
112
113    /// Score a batch of (query, passage) pairs.
114    ///
115    /// Returns a relevance score in `[0, 1]` for each passage.
116    /// Higher scores indicate greater relevance to the query.
117    ///
118    /// Each pair is tokenized as `[CLS] query [SEP] passage [SEP]`.
119    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
120    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
121        if passages.is_empty() {
122            return Ok(Vec::new());
123        }
124
125        let query = query.to_string();
126        let passages = passages.to_vec();
127        let tokenizer = Arc::clone(&self.tokenizer);
128        let session = Arc::clone(&self.session);
129        let has_token_type_ids = self.has_token_type_ids;
130
131        tokio::task::spawn_blocking(move || {
132            score_pairs_blocking(&session, &tokenizer, &query, &passages, has_token_type_ids)
133        })
134        .await
135        .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))?
136    }
137}
138
139/// Blocking cross-encoder inference — runs inside `spawn_blocking`.
140fn score_pairs_blocking(
141    session: &Arc<Mutex<Session>>,
142    tokenizer: &Tokenizer,
143    query: &str,
144    passages: &[String],
145    has_token_type_ids: bool,
146) -> Result<Vec<f32>> {
147    let batch_size = passages.len();
148
149    // Build EncodeInput pairs: [CLS] query [SEP] passage [SEP]
150    let inputs: Vec<EncodeInput> = passages
151        .iter()
152        .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
153        .collect();
154
155    let encodings = tokenizer
156        .encode_batch(inputs, true)
157        .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
158
159    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
160    if seq_len == 0 {
161        return Ok(vec![0.5; batch_size]);
162    }
163
164    // Flatten to i64 arrays (ORT BERT models expect int64)
165    let mut input_ids = Vec::with_capacity(batch_size * seq_len);
166    let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
167    let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
168
169    for enc in &encodings {
170        input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
171        attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
172        let type_ids = enc.get_type_ids();
173        if type_ids.is_empty() {
174            token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
175        } else {
176            token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
177        }
178    }
179
180    // Build ORT tensors
181    let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
182        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
183    let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
184        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
185    let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
186        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
187
188    // Run inference and extract scores in one scoped block so `sess` and `outputs`
189    // are dropped before we return (avoids session borrow escaping the mutex guard).
190    let scores: Vec<f32> = {
191        let mut sess = session.lock();
192        let outputs = if has_token_type_ids {
193            sess.run(inputs![
194                "input_ids" => input_ids_tensor,
195                "attention_mask" => attention_mask_tensor,
196                "token_type_ids" => token_type_ids_tensor
197            ])
198            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
199        } else {
200            sess.run(inputs![
201                "input_ids" => input_ids_tensor,
202                "attention_mask" => attention_mask_tensor
203            ])
204            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
205        };
206
207        // Extract logits — bge-reranker-base output shape is [batch_size, 1]
208        let (out_shape, logits_slice) = outputs[0]
209            .try_extract_tensor::<f32>()
210            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
211
212        if out_shape.is_empty() || out_shape[0] as usize != batch_size {
213            warn!(
214                "Reranker output shape mismatch: expected [{}, 1], got {:?}",
215                batch_size, out_shape
216            );
217        }
218
219        // Apply sigmoid → owned Vec<f32> so the borrow on outputs/sess ends here
220        logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
221        // outputs and sess drop here in the correct order
222    };
223
224    if scores.len() != batch_size {
225        warn!(
226            "Reranker score count mismatch: expected {}, got {}",
227            batch_size,
228            scores.len()
229        );
230        let mut padded = scores;
231        padded.resize(batch_size, 0.5);
232        return Ok(padded);
233    }
234
235    Ok(scores)
236}
237
/// Logistic sigmoid, `1 / (1 + e^(-x))`: squashes a raw logit into `[0, 1]`.
///
/// For very negative inputs `e^(-x)` overflows to `+inf`, giving `0.0` rather
/// than NaN, so any finite logit maps to a value within the range.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
243
244/// Download tokenizer and ONNX model files for the reranker.
245/// Reuses `EmbeddingEngine::download_hf_file_pub` for redirect-aware caching.
246fn download_reranker_files(
247    cache_dir: Option<String>,
248) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
249    let cache = match cache_dir {
250        Some(dir) => {
251            let p = PathBuf::from(dir);
252            std::fs::create_dir_all(&p)
253                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
254            p
255        }
256        None => {
257            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
258            PathBuf::from(home)
259                .join(".cache")
260                .join("huggingface")
261                .join("dakera")
262                .join(RERANKER_REPO_ID.replace('/', "--"))
263        }
264    };
265
266    std::fs::create_dir_all(&cache)
267        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
268
269    let files = [
270        "tokenizer.json",
271        "tokenizer_config.json",
272        "special_tokens_map.json",
273        RERANKER_ONNX_FILE,
274    ];
275
276    for filename in &files {
277        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
278            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
279    }
280
281    let tokenizer_path = cache.join("tokenizer.json");
282    let onnx_path = cache.join(RERANKER_ONNX_FILE);
283    Ok((tokenizer_path, onnx_path))
284}
285
286impl std::fmt::Debug for CrossEncoderEngine {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("CrossEncoderEngine")
289            .field("model", &RERANKER_REPO_ID)
290            .finish()
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    /// Large-magnitude logits push `exp(-x)` to 0 or +inf in f32; the score
    /// must still saturate to a finite value in [0, 1] instead of NaN.
    #[test]
    fn test_sigmoid_saturates_at_extreme_logits() {
        let hi = sigmoid(100.0);
        let lo = sigmoid(-100.0);
        assert!(hi.is_finite() && (hi - 1.0).abs() < 1e-6);
        assert!(lo.is_finite() && (0.0..1e-6).contains(&lo));
    }

    /// sigmoid(x) + sigmoid(-x) == 1, so scores are symmetric around 0.5.
    #[test]
    fn test_sigmoid_symmetry() {
        for x in [0.5_f32, 1.0, 3.0, 7.5] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6);
        }
    }
}