//! mag/memory_core/reranker.rs
//!
//! Cross-encoder reranker (`ms-marco-MiniLM-L-6-v2` via ONNX) that scores
//! query/passage pairs for retrieval reranking.
use std::path::{Path, PathBuf};

use anyhow::{Context, Result, anyhow};

// Provides resolve_app_paths() / model_root, used to locate model files.
use crate::app_paths;

// Cross-encoder model used for reranking; also names the on-disk model dir.
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_MODEL_NAME: &str = "ms-marco-MiniLM-L-6-v2";
// Download locations on the Hugging Face hub for the ONNX weights and the
// tokenizer definition.
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_MODEL_URL: &str =
    "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model.onnx";
#[cfg(feature = "real-embeddings")]
const CROSS_ENCODER_TOKENIZER_URL: &str =
    "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json";

// How long the ONNX session may sit unused before try_unload_if_idle()
// drops it to reclaim memory.
#[cfg(feature = "real-embeddings")]
const IDLE_TIMEOUT_SECS: u64 = 600; // 10 minutes

/// Lazily-initialized cross-encoder reranker.
///
/// The ONNX session and tokenizer are created on first use (or eagerly via
/// `warmup`) and dropped again by `try_unload_if_idle` after an idle timeout.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
pub struct CrossEncoderReranker {
    // Directory expected to contain model.onnx and tokenizer.json.
    model_dir: PathBuf,
    // Loaded session + tokenizer; None until first use, or after idle unload.
    runtime: std::sync::Mutex<Option<CrossEncoderRuntime>>,
    // Unix epoch seconds of the last load/score; 0 means "never used".
    last_used: std::sync::atomic::AtomicU64,
}

/// The loaded inference state: an ONNX Runtime session plus the matching
/// Hugging Face tokenizer. Owned behind `CrossEncoderReranker::runtime`.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
struct CrossEncoderRuntime {
    session: ort::session::Session,
    tokenizer: tokenizers::Tokenizer,
}

/// Resolved on-disk layout for one model directory: the directory itself plus
/// the two files inside it (see `model_files_for_dir`).
#[cfg(feature = "real-embeddings")]
#[derive(Debug, Clone)]
struct ModelFiles {
    directory: PathBuf,
    model_path: PathBuf,      // <dir>/model.onnx
    tokenizer_path: PathBuf,  // <dir>/tokenizer.json
}

#[cfg(feature = "real-embeddings")]
impl CrossEncoderReranker {
    /// Creates a reranker pointing at the default model directory.
    ///
    /// Cheap: no files are downloaded and no ONNX session is created until
    /// `warmup` or the first `score_batch` call.
    pub fn new() -> Result<Self> {
        Ok(Self {
            model_dir: default_cross_encoder_dir()?,
            runtime: std::sync::Mutex::new(None),
            // 0 doubles as the "never used" sentinel checked by
            // try_unload_if_idle().
            last_used: std::sync::atomic::AtomicU64::new(0),
        })
    }

    /// Current wall-clock time as seconds since the Unix epoch.
    /// Falls back to 0 if the system clock reads before the epoch, so the
    /// idle arithmetic never panics.
    fn epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Records "the runtime was used just now" for the idle-unload timer.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Eagerly load the ONNX session for warm start.
    ///
    /// The load runs on tokio's blocking pool because `init_runtime` does
    /// file I/O (and possibly a model download). If another caller loaded
    /// the runtime first, this is a no-op.
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        {
            // Fast path: already loaded, nothing to do.
            let guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
            if guard.is_some() {
                return Ok(());
            }
        }
        let this = std::sync::Arc::clone(self);
        tokio::task::spawn_blocking(move || {
            // Re-check under the lock: a concurrent warmup()/score_batch()
            // may have initialized the runtime while we were queued.
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }

    /// Drops the ONNX session if idle for longer than the timeout.
    /// Returns `true` only when a loaded session was actually dropped.
    pub fn try_unload_if_idle(&self) -> bool {
        // Quick pre-check without lock to avoid contention in the common case.
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if last == 0 {
            // Never used: nothing has been loaded worth unloading.
            return false;
        }
        if Self::epoch_secs().saturating_sub(last) < IDLE_TIMEOUT_SECS {
            return false;
        }
        // Re-check after acquiring the mutex — score_batch() may have updated
        // last_used while we were waiting for the lock.
        if let Ok(mut guard) = self.runtime.lock()
            && guard.is_some()
        {
            let fresh = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
            if Self::epoch_secs().saturating_sub(fresh) < IDLE_TIMEOUT_SECS {
                return false;
            }
            // Dropping the runtime releases the ONNX session + tokenizer.
            *guard = None;
            tracing::info!("unloaded idle cross-encoder session after {IDLE_TIMEOUT_SECS}s");
            return true;
        }
        // Mutex poisoned, or no runtime is currently loaded.
        false
    }

    /// Periodic maintenance entry-point.
    ///
    /// Runs the idle check on the blocking pool so a contended runtime mutex
    /// (e.g. an in-flight score_batch) cannot stall the async executor; join
    /// errors are deliberately ignored.
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let _ = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
    }

    /// Builds the ONNX session and tokenizer, fetching the model files first
    /// if they are missing on disk.
    ///
    /// Blocking. Both call sites (warmup's spawn_blocking closure and
    /// score_batch) invoke this while holding the `runtime` mutex, so other
    /// users wait for the load to finish.
    ///
    /// NOTE(review): ensure_cross_encoder_files_blocking() may spin up a
    /// temporary tokio runtime for the download — confirm this is never
    /// reached directly from an async executor thread.
    fn init_runtime(&self) -> Result<CrossEncoderRuntime> {
        let files = ensure_cross_encoder_files_blocking(self.model_dir.clone())?;
        // CPU-only inference; arena allocator disabled.
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create cross-encoder ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load cross-encoder tokenizer: {e}"))?;
        // Cap encoded sequences at 512 tokens (presumably the model's maximum
        // input length — TODO confirm against the model config).
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure cross-encoder tokenizer truncation: {e}"))?;
        Ok(CrossEncoderRuntime { session, tokenizer })
    }

    /// Score a batch of query-passage pairs. Returns a relevance score (0-1) for each.
    ///
    /// Lazily initializes the runtime on first use (which may download model
    /// files). The `runtime` mutex is held for the whole call, so concurrent
    /// batches serialize.
    pub fn score_batch(&self, query: &str, passages: &[&str]) -> Result<Vec<f32>> {
        if passages.is_empty() {
            return Ok(Vec::new());
        }

        let mut rt_guard = self
            .runtime
            .lock()
            .map_err(|_| anyhow!("cross-encoder runtime mutex poisoned"))?;
        if rt_guard.is_none() {
            *rt_guard = Some(self.init_runtime()?);
            self.touch_last_used();
        }
        let runtime = rt_guard
            .as_mut()
            .ok_or_else(|| anyhow!("cross-encoder runtime missing after init"))?;

        // Tokenize all query-passage pairs
        let encodings: Vec<tokenizers::Encoding> = passages
            .iter()
            .map(|passage| {
                runtime
                    .tokenizer
                    .encode((query, *passage), true)
                    .map_err(|e| anyhow!("cross-encoder tokenization failed: {e}"))
            })
            .collect::<Result<Vec<_>>>()?;

        // Pad every sequence to the longest one in the batch.
        let max_len = encodings
            .iter()
            .map(|enc| enc.get_ids().len())
            .max()
            .ok_or_else(|| anyhow!("empty encodings in cross-encoder batch"))?;
        if max_len == 0 {
            return Err(anyhow!(
                "all cross-encoder tokenizations produced zero-length sequences"
            ));
        }

        let batch_size = encodings.len();

        // Build padded flat tensors (row-major [batch_size, max_len]); the
        // zero fill doubles as padding with attention_mask = 0.
        let mut flat_input_ids = vec![0_i64; batch_size * max_len];
        let mut flat_attention_mask = vec![0_i64; batch_size * max_len];
        let mut flat_token_type_ids = vec![0_i64; batch_size * max_len];

        for (b, enc) in encodings.iter().enumerate() {
            let ids = enc.get_ids();
            let mask = enc.get_attention_mask();
            let type_ids = enc.get_type_ids();
            let seq_len = ids.len();
            let offset = b * max_len;
            for j in 0..seq_len {
                flat_input_ids[offset + j] = ids[j] as i64;
                flat_attention_mask[offset + j] = mask[j] as i64;
                flat_token_type_ids[offset + j] = type_ids[j] as i64;
            }
        }

        let input_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                .context("failed to create cross-encoder input_ids value")?;
        let attention_mask_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                .context("failed to create cross-encoder attention_mask value")?;
        let token_type_ids_value =
            ort::value::Value::from_array(([batch_size, max_len], flat_token_type_ids))
                .context("failed to create cross-encoder token_type_ids value")?;

        // NOTE(review): inputs are passed positionally; assumes the model's
        // input order is (input_ids, attention_mask, token_type_ids) — verify
        // against the exported ONNX graph.
        let outputs = runtime
            .session
            .run(ort::inputs![
                input_ids_value,
                attention_mask_value,
                token_type_ids_value
            ])
            .context("cross-encoder ONNX inference failed")?;

        // Cross-encoder output: logits of shape [batch_size, 1] or [batch_size]
        let logits_output = outputs
            .get("logits")
            .ok_or_else(|| anyhow!("missing cross-encoder output tensor 'logits'"))?;
        let (shape, logits) = logits_output
            .try_extract_tensor::<f32>()
            .context("failed to extract cross-encoder logits tensor")?;

        // Accept either [n] or [n, 1] as long as n matches the batch size.
        let shape_dims = shape.as_ref();
        let bs = batch_size as i64;
        let valid = match shape_dims {
            [n] => *n == bs,
            [n, 1] => *n == bs,
            _ => false,
        };
        if !valid {
            return Err(anyhow!(
                "unexpected cross-encoder logits shape {shape_dims:?}, expected [{batch_size}] or [{batch_size}, 1]"
            ));
        }

        // Apply sigmoid to each logit — tensor is contiguous so flat indexing works for both shapes
        let scores: Vec<f32> = (0..batch_size).map(|i| sigmoid(logits[i])).collect();

        self.touch_last_used();
        Ok(scores)
    }
}

/// Logistic function: maps a raw logit onto (0, 1).
///
/// Written as `1 / (1 + e^-x)`; for large |x| the f32 arithmetic saturates
/// cleanly to 0.0 or 1.0 without producing NaN.
fn sigmoid(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}

/// Directory where the cross-encoder model files live:
/// `<model_root>/ms-marco-MiniLM-L-6-v2`.
#[cfg(feature = "real-embeddings")]
pub fn cross_encoder_model_dir() -> Result<PathBuf> {
    let paths = app_paths::resolve_app_paths()?;
    Ok(paths.model_root.join(CROSS_ENCODER_MODEL_NAME))
}

/// Default model directory for `CrossEncoderReranker::new`.
/// Currently just delegates to [`cross_encoder_model_dir`] — presumably kept
/// as a seam for future overrides; confirm before collapsing the indirection.
#[cfg(feature = "real-embeddings")]
fn default_cross_encoder_dir() -> Result<PathBuf> {
    cross_encoder_model_dir()
}

/// Blocking variant of [`ensure_cross_encoder_files_async`]: returns the
/// model file paths, downloading any missing files first.
///
/// NOTE(review): the download path builds a temporary current-thread tokio
/// runtime and calls `block_on`; tokio panics if that happens on a thread
/// already driving a runtime, so this must only be reached from blocking
/// contexts (e.g. inside `spawn_blocking`) — confirm all callers.
#[cfg(feature = "real-embeddings")]
fn ensure_cross_encoder_files_blocking(model_dir: PathBuf) -> Result<ModelFiles> {
    // Fast path: everything already on disk, no runtime needed.
    if model_files_exist(&model_dir) {
        return Ok(model_files_for_dir(model_dir));
    }

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .context("failed to create temporary tokio runtime for cross-encoder model download")?;
    runtime.block_on(ensure_cross_encoder_files_async(model_dir))
}

/// Ensures the cross-encoder model files exist on disk (downloading them if
/// necessary) and returns the directory that contains them.
#[cfg(feature = "real-embeddings")]
pub async fn download_cross_encoder_model() -> Result<PathBuf> {
    let files = ensure_cross_encoder_files_async(default_cross_encoder_dir()?).await?;
    Ok(files.directory)
}

/// Makes sure `model.onnx` and `tokenizer.json` exist under `model_dir`,
/// creating the directory and downloading each missing file individually.
///
/// NOTE(review): files are fetched straight to their final paths with no
/// temp-file + rename step visible here; if `download_file` does not handle
/// that internally, a crashed/concurrent download could leave a truncated
/// file that passes the existence checks — verify.
#[cfg(feature = "real-embeddings")]
async fn ensure_cross_encoder_files_async(model_dir: PathBuf) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir);
    // Both files present: nothing to do.
    if model_files_exist(&files.directory) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create cross-encoder model directory {}",
                files.directory.display()
            )
        })?;

    // Fetch each file only if it is missing, so a partial prior run is
    // completed rather than redone.
    if !tokio::fs::try_exists(&files.model_path)
        .await
        .context("failed to check cross-encoder model.onnx path")?
    {
        super::embedder::download_file(CROSS_ENCODER_MODEL_URL, &files.model_path).await?;
    }
    if !tokio::fs::try_exists(&files.tokenizer_path)
        .await
        .context("failed to check cross-encoder tokenizer.json path")?
    {
        super::embedder::download_file(CROSS_ENCODER_TOKENIZER_URL, &files.tokenizer_path).await?;
    }

    Ok(files)
}

/// True when both expected model files are already present in `model_dir`.
#[cfg(feature = "real-embeddings")]
fn model_files_exist(model_dir: &Path) -> bool {
    let ModelFiles {
        model_path,
        tokenizer_path,
        ..
    } = model_files_for_dir(model_dir.to_path_buf());
    model_path.exists() && tokenizer_path.exists()
}

/// Computes the expected on-disk layout for a model directory.
#[cfg(feature = "real-embeddings")]
fn model_files_for_dir(model_dir: PathBuf) -> ModelFiles {
    let model_path = model_dir.join("model.onnx");
    let tokenizer_path = model_dir.join("tokenizer.json");
    ModelFiles {
        directory: model_dir,
        model_path,
        tokenizer_path,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The model directory resolves to `<home>/.mag/models/<model name>`.
    #[cfg(feature = "real-embeddings")]
    #[test]
    fn cross_encoder_model_dir_returns_expected_path() {
        crate::test_helpers::with_temp_home(|home| {
            let expected = home
                .join(".mag")
                .join("models")
                .join("ms-marco-MiniLM-L-6-v2");
            assert_eq!(
                cross_encoder_model_dir()
                    .expect("cross_encoder_model_dir() should succeed with a valid HOME"),
                expected
            );
        });
    }

    #[test]
    fn test_sigmoid_zero() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_large_positive() {
        assert!((sigmoid(10.0) - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_sigmoid_large_negative() {
        assert!(sigmoid(-10.0) < 1e-4);
    }

    #[test]
    fn test_sigmoid_monotonic() {
        // Scores must be strictly increasing in the logit.
        let scores: Vec<f32> = [-1.0_f32, 0.0, 1.0].iter().map(|&x| sigmoid(x)).collect();
        assert!(scores.windows(2).all(|pair| pair[0] < pair[1]));
    }
}