mag/memory_core/
reranker.rs1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4
5use crate::app_paths;
6
/// Name of the cross-encoder model; also the directory name joined onto the
/// application's model root when locating the model files on disk.
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_MODEL_NAME: &str = "ms-marco-MiniLM-L-6-v2";

/// Download URL for the ONNX graph, saved locally as `model.onnx`.
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_MODEL_URL: &str =
    "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model.onnx";

/// Download URL for the tokenizer definition, saved locally as `tokenizer.json`.
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_TOKENIZER_URL: &str =
    "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json";
15
/// Idle period, in seconds, after which a loaded session becomes eligible for
/// unloading by `CrossEncoderReranker::try_unload_if_idle` (ten minutes).
#[cfg(feature = "real-embeddings")]
const IDLE_TIMEOUT_SECS: u64 = 10 * 60;

/// Cross-encoder reranker whose ONNX session is created on demand and can be
/// dropped again once it has been idle for [`IDLE_TIMEOUT_SECS`].
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
pub struct CrossEncoderReranker {
    /// Directory expected to contain `model.onnx` and `tokenizer.json`.
    model_dir: PathBuf,
    /// Loaded session and tokenizer; `None` while the model is not resident.
    runtime: std::sync::Mutex<Option<CrossEncoderRuntime>>,
    /// Unix time (seconds) of the most recent load or successful scoring call;
    /// `0` means the runtime has never been used.
    last_used: std::sync::atomic::AtomicU64,
}
26
/// Everything that must be resident in memory to score passages: the ONNX
/// inference session and the tokenizer that prepares its inputs.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
struct CrossEncoderRuntime {
    /// ONNX Runtime session created from `model.onnx`.
    session: ort::session::Session,
    /// Tokenizer loaded from `tokenizer.json`, configured to truncate inputs.
    tokenizer: tokenizers::Tokenizer,
}
33
/// Resolved on-disk locations of the cross-encoder artefacts.
#[cfg(feature = "real-embeddings")]
#[derive(Debug, Clone)]
struct ModelFiles {
    /// Directory that holds both files below.
    directory: PathBuf,
    /// `<directory>/model.onnx`.
    model_path: PathBuf,
    /// `<directory>/tokenizer.json`.
    tokenizer_path: PathBuf,
}
41
#[cfg(feature = "real-embeddings")]
impl CrossEncoderReranker {
    /// Creates a reranker pointing at the default model directory.
    ///
    /// Nothing is loaded here: the runtime slot starts empty and is filled by
    /// [`warmup`](Self::warmup) or by the first [`score_batch`](Self::score_batch) call.
    ///
    /// # Errors
    /// Fails if the default model directory cannot be resolved.
    pub fn new() -> Result<Self> {
        Ok(Self {
            model_dir: default_cross_encoder_dir()?,
            runtime: std::sync::Mutex::new(None),
            last_used: std::sync::atomic::AtomicU64::new(0),
        })
    }

    /// Wall-clock time as whole seconds since the Unix epoch, or `0` when the
    /// system clock reports a time before the epoch.
    fn epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0)
    }

    /// Records "now" as the moment the runtime was last used.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Ensures the ONNX session and tokenizer are loaded, doing the expensive
    /// work on tokio's blocking pool.
    ///
    /// Returns immediately when the runtime is already resident.
    ///
    /// # Errors
    /// Fails if the runtime mutex is poisoned, the blocking task cannot be
    /// joined, or loading the model/tokenizer fails.
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        // Fast path. `try_lock` is used instead of `lock` because `score_batch`
        // holds this mutex for a whole inference; a blocking `lock` here would
        // stall the async worker thread for that long. On contention we simply
        // fall through and wait on the blocking pool instead.
        let already_loaded = match self.runtime.try_lock() {
            Ok(guard) => guard.is_some(),
            Err(std::sync::TryLockError::WouldBlock) => false,
            Err(std::sync::TryLockError::Poisoned(_)) => {
                return Err(anyhow!("cross-encoder runtime mutex poisoned"));
            }
        };
        if already_loaded {
            return Ok(());
        }

        let this = std::sync::Arc::clone(self);
        tokio::task::spawn_blocking(move || {
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
            // Re-check under the lock: another caller may have loaded the
            // runtime between the fast path above and this task running.
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }

    /// Drops the loaded runtime if it has been unused for at least
    /// [`IDLE_TIMEOUT_SECS`].
    ///
    /// Returns `true` only when a runtime was actually unloaded. A poisoned
    /// mutex is treated as "nothing to unload". This call blocks while another
    /// thread holds the runtime mutex.
    pub fn try_unload_if_idle(&self) -> bool {
        let idle_since =
            |stamp: u64| Self::epoch_secs().saturating_sub(stamp) >= IDLE_TIMEOUT_SECS;

        // Cheap pre-check without the lock; `0` means "never used".
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if last == 0 || !idle_since(last) {
            return false;
        }

        let Ok(mut guard) = self.runtime.lock() else {
            return false;
        };
        if guard.is_none() {
            return false;
        }

        // Re-read under the lock: a scoring call may have finished (and
        // refreshed the timestamp) while we were waiting for the mutex.
        let fresh = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if !idle_since(fresh) {
            return false;
        }

        *guard = None;
        tracing::info!("unloaded idle cross-encoder session after {IDLE_TIMEOUT_SECS}s");
        true
    }

    /// Runs [`try_unload_if_idle`](Self::try_unload_if_idle) on tokio's
    /// blocking pool so that waiting for the runtime mutex never stalls an
    /// async worker. Failures of the background task are logged, not returned.
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let outcome = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
        if let Err(join_err) = outcome {
            tracing::warn!("cross-encoder idle-unload task failed: {join_err}");
        }
    }

    /// Builds a fresh runtime: makes sure the model files exist (downloading
    /// them if necessary), creates a CPU ONNX session and loads the tokenizer
    /// with truncation at 512 tokens.
    fn init_runtime(&self) -> Result<CrossEncoderRuntime> {
        let files = ensure_cross_encoder_files_blocking(self.model_dir.clone())?;
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create cross-encoder ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure cross-encoder tokenizer truncation: {e}"))?;
        Ok(CrossEncoderRuntime { session, tokenizer })
    }

    /// Scores every passage against `query` and returns one sigmoid-squashed
    /// logit per passage, in the same order as `passages`.
    ///
    /// An empty `passages` slice yields an empty vector without touching the
    /// runtime. Otherwise the runtime is loaded on first use and the runtime
    /// mutex is held for the whole call.
    ///
    /// NOTE(review): this is a synchronous call that may load (and, via
    /// `ensure_cross_encoder_files_blocking`, download) the model; async
    /// callers presumably need `spawn_blocking` — confirm at call sites.
    ///
    /// # Errors
    /// Fails if the mutex is poisoned, the runtime cannot be initialised,
    /// tokenization or inference fails, or the `logits` output is missing or
    /// has a shape other than `[batch]` / `[batch, 1]`.
    pub fn score_batch(&self, query: &str, passages: &[&str]) -> Result<Vec<f32>> {
        // Writes `values` into the front of `row`; the tail keeps its zero padding.
        fn fill_row(row: &mut [i64], values: &[u32]) {
            for (slot, &value) in row.iter_mut().zip(values) {
                *slot = i64::from(value);
            }
        }

        if passages.is_empty() {
            return Ok(Vec::new());
        }

        let mut rt_guard = self
            .runtime
            .lock()
            .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
        if rt_guard.is_none() {
            *rt_guard = Some(self.init_runtime()?);
            self.touch_last_used();
        }
        let runtime = rt_guard
            .as_mut()
            .ok_or_else(|| anyhow!("cross-encoder runtime missing after init"))?;

        // Encode each (query, passage) pair with special tokens added.
        let encodings: Vec<tokenizers::Encoding> = passages
            .iter()
            .map(|passage| {
                runtime
                    .tokenizer
                    .encode((query, *passage), true)
                    .map_err(|e| anyhow!("cross-encoder tokenization failed: {e}"))
            })
            .collect::<Result<Vec<_>>>()?;

        let max_len = encodings
            .iter()
            .map(|enc| enc.get_ids().len())
            .max()
            .ok_or_else(|| anyhow!("empty encodings in cross-encoder batch"))?;
        if max_len == 0 {
            return Err(anyhow!(
                "all cross-encoder tokenizations produced zero-length sequences"
            ));
        }

        let batch_size = encodings.len();

        // Row-major [batch_size, max_len] buffers, zero-filled so that shorter
        // sequences are right-padded with 0 in all three inputs.
        let mut flat_input_ids = vec![0_i64; batch_size * max_len];
        let mut flat_attention_mask = vec![0_i64; batch_size * max_len];
        let mut flat_token_type_ids = vec![0_i64; batch_size * max_len];

        // `max_len > 0` was checked above, so `chunks_exact_mut` cannot panic
        // and yields exactly one row per encoding.
        let rows = flat_input_ids
            .chunks_exact_mut(max_len)
            .zip(flat_attention_mask.chunks_exact_mut(max_len))
            .zip(flat_token_type_ids.chunks_exact_mut(max_len))
            .zip(&encodings);
        for (((ids_row, mask_row), type_row), enc) in rows {
            fill_row(ids_row, enc.get_ids());
            fill_row(mask_row, enc.get_attention_mask());
            fill_row(type_row, enc.get_type_ids());
        }

        let input_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                .context("failed to create cross-encoder input_ids value")?;
        let attention_mask_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                .context("failed to create cross-encoder attention_mask value")?;
        let token_type_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_token_type_ids))
                .context("failed to create cross-encoder token_type_ids value")?;

        let outputs = runtime
            .session
            .run(ort::inputs![
                input_ids_value,
                attention_mask_value,
                token_type_ids_value
            ])
            .context("cross-encoder ONNX inference failed")?;

        let logits_output = outputs
            .get("logits")
            .ok_or_else(|| anyhow!("missing cross-encoder output tensor 'logits'"))?;
        let (shape, logits) = logits_output
            .try_extract_tensor::<f32>()
            .context("failed to extract cross-encoder logits tensor")?;

        // Accept exactly one logit per row: shape [batch] or [batch, 1].
        let shape_dims = shape.as_ref();
        let bs = batch_size as i64;
        let valid = match shape_dims {
            [n] => *n == bs,
            [n, 1] => *n == bs,
            _ => false,
        };
        if !valid {
            return Err(anyhow!(
                "unexpected cross-encoder logits shape {shape_dims:?}, expected [{batch_size}] or [{batch_size}, 1]"
            ));
        }

        let scores: Vec<f32> = logits
            .iter()
            .take(batch_size)
            .copied()
            .map(sigmoid)
            .collect();

        // Only a successful scoring call counts as "use" for the idle timer.
        self.touch_last_used();
        Ok(scores)
    }
}
259
/// Logistic function `1 / (1 + e^(-x))`, mapping a raw logit to a score
/// between 0 and 1.
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
263
/// Returns the directory where the cross-encoder files are expected:
/// the application's model root joined with [`CROSS_ENCODER_MODEL_NAME`].
///
/// # Errors
/// Propagates any failure from resolving the application paths.
#[cfg(feature = "real-embeddings")]
pub fn cross_encoder_model_dir() -> Result<PathBuf> {
    let paths = app_paths::resolve_app_paths()?;
    Ok(paths.model_root.join(CROSS_ENCODER_MODEL_NAME))
}
270
/// Private alias for [`cross_encoder_model_dir`], used wherever this module
/// needs the default model location.
#[cfg(feature = "real-embeddings")]
fn default_cross_encoder_dir() -> Result<PathBuf> {
    let dir = cross_encoder_model_dir()?;
    Ok(dir)
}
275
/// Synchronous wrapper around [`ensure_cross_encoder_files_async`].
///
/// When both files are already present the paths are returned directly.
/// Otherwise a temporary current-thread tokio runtime is built and the async
/// download routine is driven to completion on the calling thread.
///
/// # Errors
/// Fails if the temporary runtime cannot be created or the download fails.
#[cfg(feature = "real-embeddings")]
fn ensure_cross_encoder_files_blocking(model_dir: PathBuf) -> Result<ModelFiles> {
    if !model_files_exist(&model_dir) {
        let download_rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("failed to create temporary tokio runtime for cross-encoder model download")?;
        return download_rt.block_on(ensure_cross_encoder_files_async(model_dir));
    }

    Ok(model_files_for_dir(model_dir))
}
288
/// Makes sure the cross-encoder files are present in the default model
/// directory, downloading whatever is missing, and returns that directory.
///
/// # Errors
/// Fails if the directory cannot be resolved or created, or a download fails.
#[cfg(feature = "real-embeddings")]
pub async fn download_cross_encoder_model() -> Result<PathBuf> {
    let target_dir = default_cross_encoder_dir()?;
    let ModelFiles { directory, .. } = ensure_cross_encoder_files_async(target_dir).await?;
    Ok(directory)
}
295
/// Ensures `model.onnx` and `tokenizer.json` exist under `model_dir`.
///
/// Returns immediately when both are present. Otherwise the directory is
/// created and each missing file is downloaded — the model first, then the
/// tokenizer. A file that already exists is not downloaded again.
///
/// # Errors
/// Fails if the directory cannot be created, a path cannot be checked, or a
/// download fails.
#[cfg(feature = "real-embeddings")]
async fn ensure_cross_encoder_files_async(model_dir: PathBuf) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir);
    if model_files_exist(&files.directory) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create cross-encoder model directory {}",
                files.directory.display()
            )
        })?;

    // (source URL, destination, error context for the existence check)
    let wanted = [
        (
            CROSS_ENCODER_MODEL_URL,
            &files.model_path,
            "failed to check cross-encoder model.onnx path",
        ),
        (
            CROSS_ENCODER_TOKENIZER_URL,
            &files.tokenizer_path,
            "failed to check cross-encoder tokenizer.json path",
        ),
    ];
    for (url, destination, check_context) in wanted {
        let already_present = tokio::fs::try_exists(destination)
            .await
            .context(check_context)?;
        if !already_present {
            super::embedder::download_file(url, destination).await?;
        }
    }

    Ok(files)
}
327
/// Reports whether both `model.onnx` and `tokenizer.json` exist in `model_dir`.
#[cfg(feature = "real-embeddings")]
fn model_files_exist(model_dir: &Path) -> bool {
    let ModelFiles {
        model_path,
        tokenizer_path,
        ..
    } = model_files_for_dir(model_dir.to_path_buf());
    // `all` short-circuits, so the tokenizer is only probed when the model exists.
    [model_path, tokenizer_path].iter().all(|path| path.exists())
}
333
/// Builds the [`ModelFiles`] layout for `model_dir` without touching the
/// filesystem: `model.onnx` and `tokenizer.json` directly inside it.
#[cfg(feature = "real-embeddings")]
fn model_files_for_dir(model_dir: PathBuf) -> ModelFiles {
    let model_path = model_dir.join("model.onnx");
    let tokenizer_path = model_dir.join("tokenizer.json");
    ModelFiles {
        directory: model_dir,
        model_path,
        tokenizer_path,
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// The model directory lives at `<home>/.mag/models/<model name>`.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn cross_encoder_model_dir_returns_expected_path() {
        crate::test_helpers::with_temp_home(|home| {
            let resolved = cross_encoder_model_dir()
                .expect("cross_encoder_model_dir() should succeed with a valid HOME");
            let wanted = home
                .join(".mag")
                .join("models")
                .join("ms-marco-MiniLM-L-6-v2");
            assert_eq!(resolved, wanted);
        });
    }

    /// A zero logit maps to the midpoint.
    #[test]
    fn test_sigmoid_zero() {
        let midpoint = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-6);
    }

    /// Large positive logits saturate towards 1.
    #[test]
    fn test_sigmoid_large_positive() {
        let high = sigmoid(10.0);
        assert!((high - 1.0).abs() < 1e-4);
    }

    /// Large negative logits saturate towards 0.
    #[test]
    fn test_sigmoid_large_negative() {
        let low = sigmoid(-10.0);
        assert!(low < 1e-4);
    }

    /// Larger logits always produce larger scores.
    #[test]
    fn test_sigmoid_monotonic() {
        let below = sigmoid(-1.0);
        let middle = sigmoid(0.0);
        let above = sigmoid(1.0);
        assert!(below < middle);
        assert!(middle < above);
    }
}