Skip to main content

inference/
reranker.rs

1//! Cross-encoder reranker for improving recall precision.
2//!
3//! Uses BAAI/bge-reranker-base (Xenova ONNX INT8 quantized) to score
4//! (query, passage) pairs for relevance. More accurate than bi-encoder
5//! vector similarity but slower — used as a second-stage reranker after
6//! ANN candidate retrieval.
7//!
8//! # Architecture
9//!
10//! ```text
11//! query + passage → [CLS] query [SEP] passage [SEP]
12//!                       ↓ BERT forward pass
13//!                   logits [batch, 1]
14//!                       ↓ sigmoid
15//!                   relevance scores ∈ [0, 1]
16//! ```
17//!
18//! # Session Pool
19//!
20//! The engine maintains `RERANKER_POOL_SIZE` independent ONNX sessions.
21//! Large candidate lists are split into chunks of `RERANKER_CHUNK_SIZE` and
22//! dispatched in parallel across the pool, eliminating head-of-line blocking
23//! when multiple recall calls arrive concurrently (DAK-5873).
24//!
25//! # ONNX Mini-Batching
26//!
27//! Within each dispatched chunk, pairs are further split into mini-batches of
28//! `RERANKER_ONNX_BATCH_SIZE` for a single `session.run()` call. Smaller mini-batches
29//! reduce sequence-padding overhead: each mini-batch pads to its own maximum
30//! sequence length rather than the chunk maximum, cutting wasted compute when
31//! passage lengths vary widely (DAK-5883).
32
33use crate::engine::EmbeddingEngine;
34use crate::error::{InferenceError, Result};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::path::PathBuf;
41use std::sync::atomic::{AtomicUsize, Ordering};
42use std::sync::Arc;
43use tokenizers::{
44    EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
45};
46use tracing::{info, instrument, warn};
47
/// HuggingFace repo ID of the reranker model (Xenova ONNX INT8 export).
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the quantized ONNX model file inside the repo.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Maximum token length for cross-encoder input (query + passage combined).
const MAX_SEQ_LENGTH: usize = 512;
/// Number of independent ONNX sessions in the reranker pool.
///
/// Two sessions allow concurrent recall requests to rerank in parallel instead
/// of queueing on one session mutex. Each session uses `intra_threads=4`; two
/// sessions occupy all 8 vCPUs on the production CPX32 server (DAK-5873).
const RERANKER_POOL_SIZE: usize = 2;
/// Maximum candidates per session sub-batch (parallel dispatch unit).
///
/// Large candidate lists (e.g. temporal `fetch_n = top_k × 8 = 160`) are split
/// into chunks of this size and dispatched concurrently across the pool. With
/// `RERANKER_POOL_SIZE=2` and `RERANKER_CHUNK_SIZE=32`, a 160-candidate list
/// produces 5 chunks. For a caller whose round-robin start slot is 0,
/// sessions[0] handles chunks 0/2/4 and sessions[1] handles chunks 1/3 —
/// 3 serial chunks on the busier session vs. 5 serial before.
const RERANKER_CHUNK_SIZE: usize = 32;
/// Maximum candidate pairs per single ONNX `session.run()` call (inner batch).
///
/// Within each dispatched chunk, pairs are further split into mini-batches of
/// this size. Each mini-batch is padded to its own maximum sequence length,
/// reducing wasted computation when passage lengths vary (padding overhead on
/// shorter passages is bounded by the max within the mini-batch, not the chunk).
///
/// Tuning guide:
/// - Smaller (8): less padding waste, more `session.run()` calls per chunk.
/// - Larger (32, equal to CHUNK_SIZE): behaves like the pre-DAK-5883 single-call
///   mode, effectively disabling inner batching.
/// - Default 16: halves padding overhead on mixed-length candidate lists while
///   keeping `session.run()` call count to 2 per full chunk.
const RERANKER_ONNX_BATCH_SIZE: usize = 16;

// Compile-time guards. A zero pool size would make the round-robin modulo and
// the first-session lookup panic; a zero chunk or mini-batch size would make
// `slice::chunks` panic. Rejecting these at build time keeps every build safe,
// not only builds that run the unit tests.
const _: () = assert!(
    RERANKER_POOL_SIZE >= 1,
    "RERANKER_POOL_SIZE must be at least 1"
);
const _: () = assert!(
    RERANKER_CHUNK_SIZE >= 1,
    "RERANKER_CHUNK_SIZE must be at least 1"
);
const _: () = assert!(
    RERANKER_ONNX_BATCH_SIZE >= 1,
    "RERANKER_ONNX_BATCH_SIZE must be at least 1"
);
82
83/// Cross-encoder reranking engine.
84///
85/// Thread-safe — shared via `Arc`. Maintains a pool of independent ONNX sessions
86/// so concurrent rerank calls never contend on a single mutex.
87pub struct CrossEncoderEngine {
88    /// Pool of independent ONNX sessions (round-robin dispatch).
89    sessions: Vec<Arc<Mutex<Session>>>,
90    tokenizer: Arc<Tokenizer>,
91    /// Whether the loaded ONNX model expects a `token_type_ids` input tensor.
92    /// bge-reranker-base only has `input_ids` + `attention_mask`; some other
93    /// cross-encoders include `token_type_ids`. Determined at load time.
94    has_token_type_ids: bool,
95    /// Round-robin counter for session assignment.
96    next_session: AtomicUsize,
97}
98
99impl CrossEncoderEngine {
100    /// Load or download the reranker model.
101    ///
102    /// Downloads `Xenova/bge-reranker-base` ONNX INT8 model from HuggingFace Hub
103    /// if not already cached. Builds `RERANKER_POOL_SIZE` independent sessions.
104    #[instrument(skip_all)]
105    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
106        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
107
108        let (tokenizer_path, onnx_path) =
109            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
110                .await
111                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
112                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
113
114        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
115        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
116            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
117
118        // Configure padding + truncation for uniform batch shapes
119        let padding = PaddingParams {
120            strategy: PaddingStrategy::BatchLongest,
121            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
122            pad_token: tokenizer
123                .get_padding()
124                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
125            ..Default::default()
126        };
127        tokenizer.with_padding(Some(padding));
128        let truncation = TruncationParams {
129            max_length: MAX_SEQ_LENGTH,
130            ..Default::default()
131        };
132        let _ = tokenizer.with_truncation(Some(truncation));
133
134        info!(
135            "Loading reranker ONNX model from {:?} (pool_size={}, onnx_batch_size={})",
136            onnx_path, RERANKER_POOL_SIZE, RERANKER_ONNX_BATCH_SIZE
137        );
138
139        // Build pool of independent ONNX sessions — each has its own ORT context
140        // so pool members never block each other under concurrent rerank calls.
141        let (sessions, has_token_type_ids) =
142            tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
143                let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
144                    .map(|_| {
145                        Session::builder()
146                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
147                            .with_optimization_level(GraphOptimizationLevel::Level3)
148                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
149                            .with_intra_threads(4)
150                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
151                            .commit_from_file(&onnx_path)
152                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
153                    })
154                    .collect();
155                let raw = raw?;
156                // Inspect first session to detect optional token_type_ids input.
157                let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
158                let sessions: Vec<Arc<Mutex<Session>>> =
159                    raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
160                Ok((sessions, has_tti))
161            })
162            .await
163            .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
164            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
165
166        info!(
167            has_token_type_ids,
168            pool_size = sessions.len(),
169            onnx_batch_size = RERANKER_ONNX_BATCH_SIZE,
170            "Cross-encoder reranker loaded successfully"
171        );
172
173        Ok(Self {
174            sessions,
175            tokenizer: Arc::new(tokenizer),
176            has_token_type_ids,
177            next_session: AtomicUsize::new(0),
178        })
179    }
180
181    /// Score a batch of (query, passage) pairs.
182    ///
183    /// Passages are split into chunks of [`RERANKER_CHUNK_SIZE`] and dispatched
184    /// in parallel across the session pool (round-robin). Within each chunk,
185    /// pairs are processed in mini-batches of [`RERANKER_ONNX_BATCH_SIZE`] to
186    /// reduce sequence-padding overhead. Chunk results are reassembled in input
187    /// order.
188    ///
189    /// Returns a relevance score in `[0, 1]` for each passage.
190    /// Higher scores indicate greater relevance to the query.
191    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
192    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
193        if passages.is_empty() {
194            return Ok(Vec::new());
195        }
196
197        let pool_len = self.sessions.len();
198        // Round-robin start: each concurrent caller gets a different initial slot
199        // so concurrent requests don't all contend on sessions[0].
200        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
201        let tokenizer = Arc::clone(&self.tokenizer);
202        let has_token_type_ids = self.has_token_type_ids;
203        let query_str = query.to_string();
204
205        // Split candidates into RERANKER_CHUNK_SIZE sub-batches.
206        let chunks: Vec<Vec<String>> = passages
207            .chunks(RERANKER_CHUNK_SIZE)
208            .map(<[String]>::to_vec)
209            .collect();
210
211        // Spawn all chunks concurrently; each acquires its own session slot.
212        let mut handles = Vec::with_capacity(chunks.len());
213        for (i, chunk) in chunks.into_iter().enumerate() {
214            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
215            let tok = Arc::clone(&tokenizer);
216            let q = query_str.clone();
217            handles.push(tokio::task::spawn_blocking(move || {
218                score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
219            }));
220        }
221
222        // Collect results in chunk order to preserve passage ordering.
223        let mut scores = Vec::with_capacity(passages.len());
224        for handle in handles {
225            let chunk_scores = handle
226                .await
227                .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
228            scores.extend(chunk_scores);
229        }
230
231        Ok(scores)
232    }
233
234    /// Number of parallel ONNX sessions in the pool.
235    pub fn pool_size(&self) -> usize {
236        self.sessions.len()
237    }
238
239    /// Configured ONNX mini-batch size (pairs per `session.run()` call).
240    pub fn onnx_batch_size(&self) -> usize {
241        RERANKER_ONNX_BATCH_SIZE
242    }
243}
244
245/// Blocking cross-encoder inference for one chunk — runs inside `spawn_blocking`.
246///
247/// The chunk is processed as a sequence of mini-batches of [`RERANKER_ONNX_BATCH_SIZE`]
248/// pairs. Each mini-batch issues one `session.run()` call and pads to its own
249/// maximum sequence length, reducing waste compared to padding the full chunk.
250/// The session mutex is held for all mini-batches in the chunk to avoid per-mini-batch
251/// acquire/release overhead.
252fn score_pairs_blocking(
253    session: &Arc<Mutex<Session>>,
254    tokenizer: &Tokenizer,
255    query: &str,
256    passages: &[String],
257    has_token_type_ids: bool,
258) -> Result<Vec<f32>> {
259    let total = passages.len();
260    if total == 0 {
261        return Ok(Vec::new());
262    }
263
264    let mut all_scores = Vec::with_capacity(total);
265    // Hold the lock for the entire chunk to eliminate per-mini-batch
266    // acquire/release cost. Total lock duration is unchanged vs. the
267    // pre-DAK-5883 single-call approach since total compute is the same.
268    let mut sess = session.lock();
269
270    for mini_batch in passages.chunks(RERANKER_ONNX_BATCH_SIZE) {
271        let batch_size = mini_batch.len();
272
273        // Build EncodeInput pairs: [CLS] query [SEP] passage [SEP]
274        let inputs: Vec<EncodeInput> = mini_batch
275            .iter()
276            .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
277            .collect();
278
279        let encodings = tokenizer
280            .encode_batch(inputs, true)
281            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
282
283        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
284        if seq_len == 0 {
285            all_scores.extend(std::iter::repeat_n(0.5f32, batch_size));
286            continue;
287        }
288
289        // Flatten to i64 arrays (ORT BERT models expect int64)
290        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
291        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
292        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
293
294        for enc in &encodings {
295            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
296            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
297            let type_ids = enc.get_type_ids();
298            if type_ids.is_empty() {
299                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
300            } else {
301                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
302            }
303        }
304
305        // Build ORT tensors
306        let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
307            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
308        let attention_mask_tensor =
309            Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
310                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
311        let token_type_ids_tensor =
312            Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
313                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
314
315        // Run inference in a scoped block so `outputs` drops before the next
316        // mini-batch iteration reuses the session — the borrow on `sess` from
317        // `SessionOutputs` must end before the next `sess.run()` call.
318        let mini_scores: Vec<f32> = {
319            let outputs = if has_token_type_ids {
320                sess.run(inputs![
321                    "input_ids" => input_ids_tensor,
322                    "attention_mask" => attention_mask_tensor,
323                    "token_type_ids" => token_type_ids_tensor
324                ])
325                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
326            } else {
327                sess.run(inputs![
328                    "input_ids" => input_ids_tensor,
329                    "attention_mask" => attention_mask_tensor
330                ])
331                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
332            };
333
334            // Extract logits — bge-reranker-base output shape is [batch_size, 1]
335            let (out_shape, logits_slice) = outputs[0]
336                .try_extract_tensor::<f32>()
337                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
338
339            if out_shape.is_empty() || out_shape[0] as usize != batch_size {
340                warn!(
341                    "Reranker output shape mismatch: expected [{}, 1], got {:?}",
342                    batch_size, out_shape
343                );
344            }
345
346            // Apply sigmoid → owned Vec<f32>; borrow on outputs/sess ends here.
347            logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
348            // outputs drops here (in reverse declaration order: logits_slice, outputs)
349        };
350
351        let n_scores = mini_scores.len();
352        if n_scores != batch_size {
353            warn!(
354                "Reranker score count mismatch: expected {}, got {}",
355                batch_size, n_scores
356            );
357            let mut padded = mini_scores;
358            padded.resize(batch_size, 0.5);
359            all_scores.extend(padded);
360        } else {
361            all_scores.extend(mini_scores);
362        }
363    }
364    // sess drops here, releasing the mutex
365
366    Ok(all_scores)
367}
368
/// Logistic function mapping a raw logit to `(0, 1)`: `1 / (1 + e^(-x))`.
///
/// Saturates to exactly `0.0` / `1.0` for logits of very large magnitude and
/// propagates NaN.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
374
375/// Download tokenizer and ONNX model files for the reranker.
376/// Reuses `EmbeddingEngine::download_hf_file_pub` for redirect-aware caching.
377fn download_reranker_files(
378    cache_dir: Option<String>,
379) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
380    let cache = match cache_dir {
381        Some(dir) => {
382            let p = PathBuf::from(dir);
383            std::fs::create_dir_all(&p)
384                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
385            p
386        }
387        None => {
388            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
389            PathBuf::from(home)
390                .join(".cache")
391                .join("huggingface")
392                .join("dakera")
393                .join(RERANKER_REPO_ID.replace('/', "--"))
394        }
395    };
396
397    std::fs::create_dir_all(&cache)
398        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
399
400    let files = [
401        "tokenizer.json",
402        "tokenizer_config.json",
403        "special_tokens_map.json",
404        RERANKER_ONNX_FILE,
405    ];
406
407    for filename in &files {
408        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
409            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
410    }
411
412    let tokenizer_path = cache.join("tokenizer.json");
413    let onnx_path = cache.join(RERANKER_ONNX_FILE);
414    Ok((tokenizer_path, onnx_path))
415}
416
417impl std::fmt::Debug for CrossEncoderEngine {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        f.debug_struct("CrossEncoderEngine")
420            .field("model", &RERANKER_REPO_ID)
421            .field("pool_size", &self.sessions.len())
422            .field("onnx_batch_size", &RERANKER_ONNX_BATCH_SIZE)
423            .finish()
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` distinct, order-revealing passages (`p000`, `p001`, …).
    fn make_passages(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("p{i:03}")).collect()
    }

    /// Lengths of the pieces obtained by splitting `n` passages into
    /// `size`-sized chunks, using the same `slice::chunks` call as the engine.
    fn piece_lens(n: usize, size: usize) -> Vec<usize> {
        make_passages(n).chunks(size).map(<[String]>::len).collect()
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_chunk_count_exact() {
        // An exact multiple of the chunk size → only full chunks.
        let lens = piece_lens(2 * RERANKER_CHUNK_SIZE, RERANKER_CHUNK_SIZE);
        assert_eq!(lens, vec![RERANKER_CHUNK_SIZE; 2]);
    }

    #[test]
    fn test_chunk_count_remainder() {
        // One passage past a full chunk → a full chunk plus a remainder of 1.
        let lens = piece_lens(RERANKER_CHUNK_SIZE + 1, RERANKER_CHUNK_SIZE);
        assert_eq!(lens, vec![RERANKER_CHUNK_SIZE, 1]);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        // Fewer passages than one chunk → single chunk, no splitting overhead.
        let n = RERANKER_CHUNK_SIZE.div_ceil(2);
        assert_eq!(piece_lens(n, RERANKER_CHUNK_SIZE), vec![n]);
    }

    #[test]
    fn test_chunk_order_preserved() {
        // Chunk splitting must preserve passage order for score reassembly.
        let passages = make_passages(70);
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_round_robin_wraps() {
        // Simulate the engine's dispatch: chunk `i` of a call that starts at
        // slot `start` goes to session `(start + i) % pool_len`. For every
        // start slot and chunk count the index must be in range (the indexing
        // below would panic otherwise) and the load must be spread evenly:
        // no session gets more than one chunk above any other.
        let pool_len = RERANKER_POOL_SIZE;
        for start in 0usize..10 {
            for n_chunks in 1usize..=8 {
                let mut per_session = vec![0usize; pool_len];
                for i in 0..n_chunks {
                    per_session[(start + i) % pool_len] += 1;
                }
                let busiest = per_session.iter().copied().max().unwrap_or(0);
                let idlest = per_session.iter().copied().min().unwrap_or(0);
                assert_eq!(per_session.iter().sum::<usize>(), n_chunks);
                assert!(
                    busiest - idlest <= 1,
                    "unbalanced dispatch for start={start}, n_chunks={n_chunks}: {per_session:?}"
                );
            }
        }
    }

    // ── ONNX mini-batch tests (DAK-5883) ────────────────────────────────────

    #[test]
    fn test_onnx_batch_size_constant_invariants() {
        // ONNX batch size must be positive and no larger than the chunk size.
        // A value above the chunk size would turn every chunk into a single
        // mini-batch (the pre-DAK-5883 behaviour), silently disabling inner
        // batching, so it is rejected here.
        const { assert!(RERANKER_ONNX_BATCH_SIZE >= 1) };
        const { assert!(RERANKER_ONNX_BATCH_SIZE <= RERANKER_CHUNK_SIZE) };
    }

    #[test]
    fn test_onnx_mini_batch_count_full_chunk() {
        // A full chunk splits into ceil(CHUNK_SIZE / ONNX_BATCH_SIZE) mini-batches.
        let lens = piece_lens(RERANKER_CHUNK_SIZE, RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(
            lens.len(),
            RERANKER_CHUNK_SIZE.div_ceil(RERANKER_ONNX_BATCH_SIZE)
        );
        // Every mini-batch except possibly the last is full …
        for len in &lens[..lens.len() - 1] {
            assert_eq!(*len, RERANKER_ONNX_BATCH_SIZE);
        }
        // … and together they cover the whole chunk.
        assert_eq!(lens.iter().sum::<usize>(), RERANKER_CHUNK_SIZE);
    }

    #[test]
    fn test_onnx_mini_batch_count_partial_chunk() {
        // ONNX_BATCH_SIZE + 1 passages → 2 mini-batches (full + remainder of 1).
        let lens = piece_lens(RERANKER_ONNX_BATCH_SIZE + 1, RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(lens, vec![RERANKER_ONNX_BATCH_SIZE, 1]);
    }

    #[test]
    fn test_onnx_mini_batch_count_smaller_than_batch_size() {
        // Fewer passages than ONNX_BATCH_SIZE → single mini-batch (no overhead).
        // `max(1)` keeps the input non-empty even for a batch size of 1.
        let n = (RERANKER_ONNX_BATCH_SIZE / 2).max(1);
        assert_eq!(piece_lens(n, RERANKER_ONNX_BATCH_SIZE), vec![n]);
    }

    #[test]
    fn test_onnx_mini_batch_order_preserved() {
        // Mini-batch splitting and reassembly must preserve input order exactly.
        // This guards against score[i] ↔ passage[j] mismatches in reassembly.
        let passages = make_passages(70);
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_ONNX_BATCH_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_onnx_mini_batch_total_score_count_matches_input() {
        // Whatever the input size, the mini-batch lengths must add up to the
        // number of input passages (covers exact multiples and remainders).
        for n in [1, 8, 15, 16, 17, 32, 33, 47, 64] {
            let total: usize = piece_lens(n, RERANKER_ONNX_BATCH_SIZE).iter().sum();
            assert_eq!(total, n, "score count mismatch for n={n}");
        }
    }

    #[test]
    fn test_onnx_batch_size_accessor() {
        // `onnx_batch_size()` simply returns the module constant, but calling
        // it needs a `CrossEncoderEngine`, which only integration tests with a
        // real model can build. Pin the constant's current value instead so an
        // accidental change is noticed; update this when retuning on purpose.
        assert_eq!(RERANKER_ONNX_BATCH_SIZE, 16);
    }
}