mag/memory_core/
reranker.rs1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4
5use crate::app_paths;
6
7#[cfg(feature = "real-embeddings")]
8const CROSS_ENCODER_MODEL_NAME: &str = "ms-marco-MiniLM-L-6-v2";
9#[cfg(feature = "real-embeddings")]
10const CROSS_ENCODER_MODEL_URL: &str =
11 "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model.onnx";
12#[cfg(feature = "real-embeddings")]
13const CROSS_ENCODER_TOKENIZER_URL: &str =
14 "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json";
15
16#[cfg(feature = "real-embeddings")]
17const IDLE_TIMEOUT_SECS: u64 = 600; #[cfg(feature = "real-embeddings")]
/// Lazily-initialized cross-encoder reranker (ms-marco-MiniLM-L-6-v2 via ONNX).
/// The heavy runtime (session + tokenizer) is loaded on first use and dropped
/// again after `IDLE_TIMEOUT_SECS` of inactivity.
#[derive(Debug)]
pub struct CrossEncoderReranker {
    // Directory expected to contain model.onnx and tokenizer.json
    // (downloaded on demand).
    model_dir: PathBuf,
    // Lazily-created ONNX session + tokenizer; `None` when not yet loaded
    // or after an idle unload.
    runtime: std::sync::Mutex<Option<CrossEncoderRuntime>>,
    // Unix-epoch seconds of the last warmup/scoring; 0 means "never used".
    last_used: std::sync::atomic::AtomicU64,
}
26
/// The loaded inference state: an ONNX session plus its matching tokenizer.
/// Dropping this releases the model's memory.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
struct CrossEncoderRuntime {
    session: ort::session::Session,
    tokenizer: tokenizers::Tokenizer,
}
33
/// Resolved on-disk locations for the cross-encoder artifacts.
#[cfg(feature = "real-embeddings")]
#[derive(Debug, Clone)]
struct ModelFiles {
    // Directory containing both files below.
    directory: PathBuf,
    // Path to model.onnx inside `directory`.
    model_path: PathBuf,
    // Path to tokenizer.json inside `directory`.
    tokenizer_path: PathBuf,
}
41
#[cfg(feature = "real-embeddings")]
impl CrossEncoderReranker {
    /// Creates a reranker handle without loading the model; the ONNX session
    /// and tokenizer are initialized lazily on first use (or via `warmup`).
    pub fn new() -> Result<Self> {
        Ok(Self {
            model_dir: default_cross_encoder_dir()?,
            runtime: std::sync::Mutex::new(None),
            last_used: std::sync::atomic::AtomicU64::new(0),
        })
    }

    /// Current wall-clock time as whole seconds since the Unix epoch.
    /// Falls back to 0 if the system clock reads before the epoch.
    fn epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Records "now" as the last-used timestamp; drives idle unloading.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Eagerly loads the session and tokenizer on a blocking thread so the
    /// first `score_batch` call does not pay the initialization (and possible
    /// download) cost. No-op when the runtime is already loaded.
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        // Fast path: peek under the lock and bail out if already initialized.
        {
            let guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
            if guard.is_some() {
                return Ok(());
            }
        }
        // Slow path: init_runtime may download files and build the ONNX
        // session, so run it off the async executor. Re-check under the lock
        // because another caller may have initialized it in the meantime.
        let this = std::sync::Arc::clone(self);
        tokio::task::spawn_blocking(move || {
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }

    /// Drops the loaded runtime when it has been idle for at least
    /// `IDLE_TIMEOUT_SECS`. Returns `true` only if an unload happened.
    pub fn try_unload_if_idle(&self) -> bool {
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        // 0 means "never used": nothing to unload on our account.
        if last == 0 {
            return false;
        }
        if Self::epoch_secs().saturating_sub(last) < IDLE_TIMEOUT_SECS {
            return false;
        }
        // NOTE: a poisoned mutex silently skips unloading (falls to `false`).
        if let Ok(mut guard) = self.runtime.lock()
            && guard.is_some()
        {
            // Re-read the timestamp under the lock: a concurrent score_batch
            // may have refreshed it between the check above and acquiring the
            // mutex.
            let fresh = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
            if Self::epoch_secs().saturating_sub(fresh) < IDLE_TIMEOUT_SECS {
                return false;
            }
            *guard = None;
            tracing::info!("unloaded idle cross-encoder session after {IDLE_TIMEOUT_SECS}s");
            return true;
        }
        false
    }

    /// Periodic maintenance hook: runs the idle check on a blocking thread
    /// (the mutex may be held by an in-flight inference). Join errors are
    /// deliberately ignored.
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let _ = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
    }

    /// Ensures the model files exist on disk (downloading if needed), then
    /// builds the ONNX session and tokenizer. Blocking; callers are expected
    /// to invoke this off the async executor.
    fn init_runtime(&self) -> Result<CrossEncoderRuntime> {
        let files = ensure_cross_encoder_files_blocking(self.model_dir.clone())?;
        // CPU-only inference; arena allocator disabled so memory is actually
        // returned to the OS when the session is dropped on idle unload.
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create cross-encoder ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
        // Cap sequences at 512 tokens — the model's maximum input length.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure cross-encoder tokenizer truncation: {e}"))?;
        Ok(CrossEncoderRuntime { session, tokenizer })
    }

    /// Scores each passage's relevance to `query`, returning one sigmoid-
    /// squashed score in (0, 1) per passage, in input order. Lazily
    /// initializes the runtime on first call and holds the runtime mutex for
    /// the duration of inference.
    pub fn score_batch(&self, query: &str, passages: &[&str]) -> Result<Vec<f32>> {
        if passages.is_empty() {
            return Ok(Vec::new());
        }

        let mut rt_guard = self
            .runtime
            .lock()
            .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
        if rt_guard.is_none() {
            *rt_guard = Some(self.init_runtime()?);
            self.touch_last_used();
        }
        let runtime = rt_guard
            .as_mut()
            .ok_or_else(|| anyhow!("cross-encoder runtime missing after init"))?;

        // Tokenize each (query, passage) pair jointly — cross-encoder input.
        let encodings: Vec<tokenizers::Encoding> = passages
            .iter()
            .map(|passage| {
                runtime
                    .tokenizer
                    .encode((query, *passage), true)
                    .map_err(|e| anyhow!("cross-encoder tokenization failed: {e}"))
            })
            .collect::<Result<Vec<_>>>()?;

        // Pad every sequence to the longest in the batch.
        let max_len = encodings
            .iter()
            .map(|enc| enc.get_ids().len())
            .max()
            .ok_or_else(|| anyhow!("empty encodings in cross-encoder batch"))?;
        if max_len == 0 {
            return Err(anyhow!(
                "all cross-encoder tokenizations produced zero-length sequences"
            ));
        }

        let batch_size = encodings.len();

        // Flat row-major [batch_size, max_len] buffers, zero-padded
        // (zeros double as padding token ids / mask / type ids).
        let mut flat_input_ids = vec![0_i64; batch_size * max_len];
        let mut flat_attention_mask = vec![0_i64; batch_size * max_len];
        let mut flat_token_type_ids = vec![0_i64; batch_size * max_len];

        for (b, enc) in encodings.iter().enumerate() {
            let ids = enc.get_ids();
            let mask = enc.get_attention_mask();
            let type_ids = enc.get_type_ids();
            let seq_len = ids.len();
            let offset = b * max_len;
            for j in 0..seq_len {
                flat_input_ids[offset + j] = ids[j] as i64;
                flat_attention_mask[offset + j] = mask[j] as i64;
                flat_token_type_ids[offset + j] = type_ids[j] as i64;
            }
        }

        let input_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                .context("failed to create cross-encoder input_ids value")?;
        let attention_mask_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                .context("failed to create cross-encoder attention_mask value")?;
        let token_type_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_token_type_ids))
                .context("failed to create cross-encoder token_type_ids value")?;

        // Inputs are positional; order must match the model's input signature.
        let outputs = runtime
            .session
            .run(ort::inputs![
                input_ids_value,
                attention_mask_value,
                token_type_ids_value
            ])
            .context("cross-encoder ONNX inference failed")?;

        let logits_output = outputs
            .get("logits")
            .ok_or_else(|| anyhow!("missing cross-encoder output tensor 'logits'"))?;
        let (shape, logits) = logits_output
            .try_extract_tensor::<f32>()
            .context("failed to extract cross-encoder logits tensor")?;

        // Accept either [batch] or [batch, 1] logits — both are contiguous,
        // so indexing the flat slice by row works for either shape.
        let shape_dims = shape.as_ref();
        let bs = batch_size as i64;
        let valid = match shape_dims {
            [n] => *n == bs,
            [n, 1] => *n == bs,
            _ => false,
        };
        if !valid {
            return Err(anyhow!(
                "unexpected cross-encoder logits shape {shape_dims:?}, expected [{batch_size}] or [{batch_size}, 1]"
            ));
        }

        let scores: Vec<f32> = (0..batch_size).map(|i| sigmoid(logits[i])).collect();

        self.touch_last_used();
        Ok(scores)
    }
}
259
/// Logistic sigmoid: maps a raw logit to (0, 1).
///
/// Well-behaved at `f32` extremes: for very negative `x`, `(-x).exp()`
/// overflows to `inf` and the result rounds to 0; for very positive `x` it
/// underflows to 0 and the result rounds to 1.
// Gated like its only callers (the cfg-gated `score_batch` and the unit
// tests); without the gate a `--no-default-features` non-test build emits a
// dead_code warning.
#[cfg(any(test, feature = "real-embeddings"))]
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}
263
/// Default on-disk directory for the cross-encoder artifacts:
/// `<model_root>/<CROSS_ENCODER_MODEL_NAME>`.
#[cfg(feature = "real-embeddings")]
fn default_cross_encoder_dir() -> Result<PathBuf> {
    let paths = app_paths::resolve_app_paths()?;
    Ok(paths.model_root.join(CROSS_ENCODER_MODEL_NAME))
}
270
/// Synchronous counterpart of [`ensure_cross_encoder_files_async`].
///
/// Returns immediately when both artifacts are already on disk; otherwise a
/// throwaway current-thread tokio runtime drives the async download to
/// completion.
#[cfg(feature = "real-embeddings")]
fn ensure_cross_encoder_files_blocking(model_dir: PathBuf) -> Result<ModelFiles> {
    if !model_files_exist(&model_dir) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("failed to create temporary tokio runtime for cross-encoder model download")?;
        return rt.block_on(ensure_cross_encoder_files_async(model_dir));
    }
    Ok(model_files_for_dir(model_dir))
}
283
/// Ensures the cross-encoder model is present in its default location
/// (downloading anything missing) and returns the directory containing it.
#[cfg(feature = "real-embeddings")]
pub async fn download_cross_encoder_model() -> Result<PathBuf> {
    let dir = default_cross_encoder_dir()?;
    let ensured = ensure_cross_encoder_files_async(dir).await?;
    Ok(ensured.directory)
}
290
/// Makes sure `model.onnx` and `tokenizer.json` exist under `model_dir`,
/// downloading whichever are missing, and returns the resolved paths.
#[cfg(feature = "real-embeddings")]
async fn ensure_cross_encoder_files_async(model_dir: PathBuf) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir);
    if model_files_exist(&files.directory) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create cross-encoder model directory {}",
                files.directory.display()
            )
        })?;

    // Fetch each artifact only if it is not already present; existing files
    // are never re-downloaded or overwritten.
    let targets = [
        (CROSS_ENCODER_MODEL_URL, &files.model_path, "model.onnx"),
        (CROSS_ENCODER_TOKENIZER_URL, &files.tokenizer_path, "tokenizer.json"),
    ];
    for (url, path, label) in targets {
        let present = tokio::fs::try_exists(path)
            .await
            .with_context(|| format!("failed to check cross-encoder {label} path"))?;
        if !present {
            super::embedder::download_file(url, path).await?;
        }
    }

    Ok(files)
}
322
/// True only when both expected artifacts already exist under `model_dir`.
#[cfg(feature = "real-embeddings")]
fn model_files_exist(model_dir: &Path) -> bool {
    let ModelFiles {
        model_path,
        tokenizer_path,
        ..
    } = model_files_for_dir(model_dir.to_path_buf());
    model_path.exists() && tokenizer_path.exists()
}
328
/// Derives the expected artifact paths inside `model_dir`.
#[cfg(feature = "real-embeddings")]
fn model_files_for_dir(model_dir: PathBuf) -> ModelFiles {
    let model_path = model_dir.join("model.onnx");
    let tokenizer_path = model_dir.join("tokenizer.json");
    ModelFiles {
        directory: model_dir,
        model_path,
        tokenizer_path,
    }
}
337
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity checks on the sigmoid squashing function used for logit scores.

    #[test]
    fn test_sigmoid_zero() {
        let at_zero = sigmoid(0.0);
        assert!((at_zero - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_large_positive() {
        let near_one = sigmoid(10.0);
        assert!((near_one - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_sigmoid_large_negative() {
        let near_zero = sigmoid(-10.0);
        assert!(near_zero < 1e-4);
    }

    #[test]
    fn test_sigmoid_monotonic() {
        let samples = [sigmoid(-1.0), sigmoid(0.0), sigmoid(1.0)];
        assert!(samples[0] < samples[1]);
        assert!(samples[1] < samples[2]);
    }
}