1use crate::engine::EmbeddingEngine;
19use crate::error::{InferenceError, Result};
20use ort::inputs;
21use ort::session::builder::GraphOptimizationLevel;
22use ort::session::Session;
23use ort::value::Tensor;
24use parking_lot::Mutex;
25use std::path::PathBuf;
26use std::sync::Arc;
27use tokenizers::{
28 EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
29};
30use tracing::{info, instrument, warn};
31
32const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
34const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
36const MAX_SEQ_LENGTH: usize = 512;
38
39pub struct CrossEncoderEngine {
43 session: Arc<Mutex<Session>>,
44 tokenizer: Arc<Tokenizer>,
45 has_token_type_ids: bool,
49}
50
51impl CrossEncoderEngine {
52 #[instrument(skip_all)]
57 pub async fn new(cache_dir: Option<String>) -> Result<Self> {
58 info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
59
60 let (tokenizer_path, onnx_path) =
61 tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
62 .await
63 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
64 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
65
66 info!("Loading reranker tokenizer from {:?}", tokenizer_path);
67 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
68 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
69
70 let padding = PaddingParams {
72 strategy: PaddingStrategy::BatchLongest,
73 pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
74 pad_token: tokenizer
75 .get_padding()
76 .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
77 ..Default::default()
78 };
79 tokenizer.with_padding(Some(padding));
80 let truncation = TruncationParams {
81 max_length: MAX_SEQ_LENGTH,
82 ..Default::default()
83 };
84 let _ = tokenizer.with_truncation(Some(truncation));
85
86 info!("Loading reranker ONNX model from {:?}", onnx_path);
87 let session = Session::builder()
88 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
89 .with_optimization_level(GraphOptimizationLevel::Level3)
90 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
91 .with_intra_threads(4)
92 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
93 .commit_from_file(&onnx_path)
94 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
95
96 let has_token_type_ids = session
99 .inputs()
100 .iter()
101 .any(|i| i.name() == "token_type_ids");
102 info!(
103 has_token_type_ids,
104 "Cross-encoder reranker loaded successfully"
105 );
106 Ok(Self {
107 session: Arc::new(Mutex::new(session)),
108 tokenizer: Arc::new(tokenizer),
109 has_token_type_ids,
110 })
111 }
112
113 #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
120 pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
121 if passages.is_empty() {
122 return Ok(Vec::new());
123 }
124
125 let query = query.to_string();
126 let passages = passages.to_vec();
127 let tokenizer = Arc::clone(&self.tokenizer);
128 let session = Arc::clone(&self.session);
129 let has_token_type_ids = self.has_token_type_ids;
130
131 tokio::task::spawn_blocking(move || {
132 score_pairs_blocking(&session, &tokenizer, &query, &passages, has_token_type_ids)
133 })
134 .await
135 .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))?
136 }
137}
138
139fn score_pairs_blocking(
141 session: &Arc<Mutex<Session>>,
142 tokenizer: &Tokenizer,
143 query: &str,
144 passages: &[String],
145 has_token_type_ids: bool,
146) -> Result<Vec<f32>> {
147 let batch_size = passages.len();
148
149 let inputs: Vec<EncodeInput> = passages
151 .iter()
152 .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
153 .collect();
154
155 let encodings = tokenizer
156 .encode_batch(inputs, true)
157 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
158
159 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
160 if seq_len == 0 {
161 return Ok(vec![0.5; batch_size]);
162 }
163
164 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
166 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
167 let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
168
169 for enc in &encodings {
170 input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
171 attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
172 let type_ids = enc.get_type_ids();
173 if type_ids.is_empty() {
174 token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
175 } else {
176 token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
177 }
178 }
179
180 let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
182 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
183 let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
184 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
185 let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
186 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
187
188 let scores: Vec<f32> = {
191 let mut sess = session.lock();
192 let outputs = if has_token_type_ids {
193 sess.run(inputs![
194 "input_ids" => input_ids_tensor,
195 "attention_mask" => attention_mask_tensor,
196 "token_type_ids" => token_type_ids_tensor
197 ])
198 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
199 } else {
200 sess.run(inputs![
201 "input_ids" => input_ids_tensor,
202 "attention_mask" => attention_mask_tensor
203 ])
204 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
205 };
206
207 let (out_shape, logits_slice) = outputs[0]
209 .try_extract_tensor::<f32>()
210 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
211
212 if out_shape.is_empty() || out_shape[0] as usize != batch_size {
213 warn!(
214 "Reranker output shape mismatch: expected [{}, 1], got {:?}",
215 batch_size, out_shape
216 );
217 }
218
219 logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
221 };
223
224 if scores.len() != batch_size {
225 warn!(
226 "Reranker score count mismatch: expected {}, got {}",
227 batch_size,
228 scores.len()
229 );
230 let mut padded = scores;
231 padded.resize(batch_size, 0.5);
232 return Ok(padded);
233 }
234
235 Ok(scores)
236}
237
/// Logistic function: squashes a raw model logit into `[0, 1]` (0 maps to 0.5).
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = f32::exp(-x);
    // `recip` is defined as `1.0 / self`, so this is bit-identical to `1 / (1 + e^-x)`.
    (1.0 + decay).recip()
}
243
244fn download_reranker_files(
247 cache_dir: Option<String>,
248) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
249 let cache = match cache_dir {
250 Some(dir) => {
251 let p = PathBuf::from(dir);
252 std::fs::create_dir_all(&p)
253 .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
254 p
255 }
256 None => {
257 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
258 PathBuf::from(home)
259 .join(".cache")
260 .join("huggingface")
261 .join("dakera")
262 .join(RERANKER_REPO_ID.replace('/', "--"))
263 }
264 };
265
266 std::fs::create_dir_all(&cache)
267 .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
268
269 let files = [
270 "tokenizer.json",
271 "tokenizer_config.json",
272 "special_tokens_map.json",
273 RERANKER_ONNX_FILE,
274 ];
275
276 for filename in &files {
277 EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
278 .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
279 }
280
281 let tokenizer_path = cache.join("tokenizer.json");
282 let onnx_path = cache.join(RERANKER_ONNX_FILE);
283 Ok((tokenizer_path, onnx_path))
284}
285
286impl std::fmt::Debug for CrossEncoderEngine {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("CrossEncoderEngine")
289 .field("model", &RERANKER_REPO_ID)
290 .finish()
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic anchor points of the logistic curve.
    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    /// Very large or very negative logits must not produce NaN/inf scores: the
    /// intermediate `exp` may overflow to infinity, but the result must stay in [0, 1].
    #[test]
    fn test_sigmoid_extreme_logits_stay_in_unit_interval() {
        for x in [-1000.0_f32, -100.0, 100.0, 1000.0] {
            let s = sigmoid(x);
            assert!(s.is_finite(), "sigmoid({x}) = {s} is not finite");
            assert!((0.0..=1.0).contains(&s), "sigmoid({x}) = {s} out of range");
        }
    }

    /// The logistic function satisfies sigmoid(x) + sigmoid(-x) = 1.
    #[test]
    fn test_sigmoid_is_symmetric() {
        for i in -50..=50 {
            let x = i as f32 / 10.0;
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6, "asymmetry at x = {x}");
        }
    }
}