Skip to main content

inference/
reranker.rs

1//! Cross-encoder reranker for improving recall precision.
2//!
3//! Uses BAAI/bge-reranker-base (Xenova ONNX INT8 quantized) to score
4//! (query, passage) pairs for relevance. More accurate than bi-encoder
5//! vector similarity but slower — used as a second-stage reranker after
6//! ANN candidate retrieval.
7//!
8//! # Architecture
9//!
10//! ```text
11//! query + passage → [CLS] query [SEP] passage [SEP]
12//!                       ↓ BERT forward pass
13//!                   logits [batch, 1]
14//!                       ↓ sigmoid
15//!                   relevance scores ∈ [0, 1]
16//! ```
17//!
18//! # Session Pool
19//!
20//! The engine maintains `RERANKER_POOL_SIZE` independent ONNX sessions.
21//! Large candidate lists are split into chunks of `RERANKER_CHUNK_SIZE` and
22//! dispatched in parallel across the pool, eliminating head-of-line blocking
23//! when multiple recall calls arrive concurrently (DAK-5873).
24//!
25//! # ONNX Mini-Batching
26//!
27//! Within each dispatched chunk, pairs are further split into mini-batches of
28//! `RERANKER_ONNX_BATCH_SIZE` for a single `session.run()` call. Smaller mini-batches
29//! reduce sequence-padding overhead: each mini-batch pads to its own maximum
30//! sequence length rather than the chunk maximum, cutting wasted compute when
31//! passage lengths vary widely (DAK-5883).
32
33use crate::engine::EmbeddingEngine;
34use crate::error::{InferenceError, Result};
35use ort::execution_providers::CUDAExecutionProvider;
36use ort::inputs;
37use ort::session::builder::GraphOptimizationLevel;
38use ort::session::Session;
39use ort::value::Tensor;
40use parking_lot::Mutex;
41use std::path::PathBuf;
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::{
45    EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
46};
47use tracing::{info, instrument, warn};
48
/// HuggingFace repo ID of the reranker model (Xenova's ONNX INT8 export).
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the quantized ONNX graph inside the reranker repo.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Upper bound, in tokens, on the combined query + passage input; the tokenizer
/// is configured to truncate to this length.
const MAX_SEQ_LENGTH: usize = 512;
/// Size of the reranker's pool of independent ONNX sessions.
///
/// With two sessions, concurrent recall requests can rerank side by side instead
/// of serialising behind one mutex. Every session is built with
/// `intra_threads=4`, so the pair together occupies the 8 vCPUs of the
/// production CPX32 host (DAK-5873).
const RERANKER_POOL_SIZE: usize = 2;

/// Ceiling on simultaneously admitted `score_pairs` calls; the next caller is
/// rejected with `Overloaded`.
///
/// DAK-5893 root cause: with two sessions and 8 concurrent bench recall
/// requests, the 7th and 8th callers waited more than 120 s for a session mutex,
/// the client timed out, its 8-attempt retry loop kicked in, and the service
/// stalled for 19 minutes. A cap of `3 × POOL_SIZE = 6` keeps the wait queue
/// shallow — for single-chunk requests spread round-robin, at most two others
/// sit ahead of an admitted request on its session — while the 7th+ caller is
/// turned away immediately and the API serves unranked results instead.
const RERANKER_MAX_CONCURRENT: usize = 3 * RERANKER_POOL_SIZE;

/// Largest candidate slice handed to one session (the parallel dispatch unit).
///
/// Long candidate lists — e.g. temporal `fetch_n = top_k × 8 = 160` — are cut
/// into slices of this many passages and spread over the pool. For 160
/// candidates, two sessions and a chunk size of 32 that yields 5 chunks: one
/// session runs chunks 0, 2 and 4, the other runs 1 and 3 (which session gets
/// which depends on the round-robin start), i.e. 3 serial chunks on the busier
/// session instead of 5.
const RERANKER_CHUNK_SIZE: usize = 32;
/// Largest number of pairs passed to a single ONNX `session.run()` (inner batch).
///
/// Each dispatched chunk is processed in mini-batches of this size, and every
/// mini-batch is padded only to its own longest sequence — so when passage
/// lengths vary, padding waste is bounded by the longest passage in the
/// mini-batch rather than in the whole chunk.
///
/// Tuning guide:
/// - Smaller (8): less padding waste, more `session.run()` calls per chunk.
/// - Larger (32, equal to CHUNK_SIZE): one call per chunk, i.e. the
///   pre-DAK-5883 behaviour with inner batching effectively disabled.
/// - Default 16: halves padding overhead on mixed-length candidate lists at the
///   cost of two `session.run()` calls per full chunk.
const RERANKER_ONNX_BATCH_SIZE: usize = 16;
93
/// Scope guard owning one unit of the shared `active_requests` counter.
///
/// Dropping the guard gives the unit back with a single `fetch_sub(1)`, so a
/// slot taken with `fetch_add(1)` is returned on every exit path — normal
/// return, `?` early-return, or unwinding.
struct ActiveGuard(Arc<AtomicUsize>);

impl Drop for ActiveGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
102
103/// Cross-encoder reranking engine.
104///
105/// Thread-safe — shared via `Arc`. Maintains a pool of independent ONNX sessions
106/// so concurrent rerank calls never contend on a single mutex.
107pub struct CrossEncoderEngine {
108    /// Pool of independent ONNX sessions (round-robin dispatch).
109    sessions: Vec<Arc<Mutex<Session>>>,
110    tokenizer: Arc<Tokenizer>,
111    /// Whether the loaded ONNX model expects a `token_type_ids` input tensor.
112    /// bge-reranker-base only has `input_ids` + `attention_mask`; some other
113    /// cross-encoders include `token_type_ids`. Determined at load time.
114    has_token_type_ids: bool,
115    /// Round-robin counter for session assignment.
116    next_session: AtomicUsize,
117    /// Active concurrent callers of `score_pairs`. When this reaches
118    /// `RERANKER_MAX_CONCURRENT`, new callers return `Overloaded` immediately
119    /// so the API can fall back to unranked results rather than queuing
120    /// indefinitely — root cause fix for DAK-5893 SIGTERM stall.
121    active_requests: Arc<AtomicUsize>,
122}
123
124impl CrossEncoderEngine {
125    /// Load or download the reranker model.
126    ///
127    /// Downloads `Xenova/bge-reranker-base` ONNX INT8 model from HuggingFace Hub
128    /// if not already cached. Builds `RERANKER_POOL_SIZE` independent sessions.
129    #[instrument(skip_all)]
130    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
131        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
132
133        let (tokenizer_path, onnx_path) =
134            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
135                .await
136                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
137                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
138
139        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
140        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
141            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
142
143        // Configure padding + truncation for uniform batch shapes
144        let padding = PaddingParams {
145            strategy: PaddingStrategy::BatchLongest,
146            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
147            pad_token: tokenizer
148                .get_padding()
149                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
150            ..Default::default()
151        };
152        tokenizer.with_padding(Some(padding));
153        let truncation = TruncationParams {
154            max_length: MAX_SEQ_LENGTH,
155            ..Default::default()
156        };
157        let _ = tokenizer.with_truncation(Some(truncation));
158
159        info!(
160            "Loading reranker ONNX model from {:?} (pool_size={}, onnx_batch_size={})",
161            onnx_path, RERANKER_POOL_SIZE, RERANKER_ONNX_BATCH_SIZE
162        );
163
164        let use_gpu = std::env::var("DAKERA_USE_GPU")
165            .map(|v| v == "1")
166            .unwrap_or(false);
167        if use_gpu {
168            info!("CUDA execution provider enabled for reranker (DAKERA_USE_GPU=1)");
169        }
170
171        // Build pool of independent ONNX sessions — each has its own ORT context
172        // so pool members never block each other under concurrent rerank calls.
173        let (sessions, has_token_type_ids) =
174            tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
175                let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
176                    .map(|_| {
177                        let builder = Session::builder()
178                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
179                            .with_optimization_level(GraphOptimizationLevel::Level3)
180                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
181                            .with_intra_threads(4)
182                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
183
184                        let mut builder = if use_gpu {
185                            builder
186                                .with_execution_providers(
187                                    [CUDAExecutionProvider::default().build()],
188                                )
189                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
190                        } else {
191                            builder
192                        };
193
194                        builder
195                            .commit_from_file(&onnx_path)
196                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
197                    })
198                    .collect();
199                let raw = raw?;
200                // Inspect first session to detect optional token_type_ids input.
201                let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
202                let sessions: Vec<Arc<Mutex<Session>>> =
203                    raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
204                Ok((sessions, has_tti))
205            })
206            .await
207            .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
208            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
209
210        info!(
211            has_token_type_ids,
212            pool_size = sessions.len(),
213            onnx_batch_size = RERANKER_ONNX_BATCH_SIZE,
214            "Cross-encoder reranker loaded successfully"
215        );
216
217        Ok(Self {
218            sessions,
219            tokenizer: Arc::new(tokenizer),
220            has_token_type_ids,
221            next_session: AtomicUsize::new(0),
222            active_requests: Arc::new(AtomicUsize::new(0)),
223        })
224    }
225
226    /// Score a batch of (query, passage) pairs.
227    ///
228    /// Passages are split into chunks of [`RERANKER_CHUNK_SIZE`] and dispatched
229    /// in parallel across the session pool (round-robin). Within each chunk,
230    /// pairs are processed in mini-batches of [`RERANKER_ONNX_BATCH_SIZE`] to
231    /// reduce sequence-padding overhead. Chunk results are reassembled in input
232    /// order.
233    ///
234    /// Returns `Err(InferenceError::Overloaded)` immediately when more than
235    /// `RERANKER_MAX_CONCURRENT` callers are active — the API layer falls back
236    /// to unranked results rather than queuing indefinitely (DAK-5893 fix).
237    ///
238    /// Returns a relevance score in `[0, 1]` for each passage.
239    /// Higher scores indicate greater relevance to the query.
240    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
241    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
242        if passages.is_empty() {
243            return Ok(Vec::new());
244        }
245
246        // Concurrency gate: if already at capacity, return Overloaded so the API
247        // falls back to unranked results instead of queuing for >120s (DAK-5893).
248        let prev = self.active_requests.fetch_add(1, Ordering::SeqCst);
249        if prev >= RERANKER_MAX_CONCURRENT {
250            self.active_requests.fetch_sub(1, Ordering::SeqCst);
251            warn!(
252                active = prev,
253                max = RERANKER_MAX_CONCURRENT,
254                "Cross-encoder at capacity — returning Overloaded (API will use unranked results)"
255            );
256            return Err(InferenceError::Overloaded {
257                active: prev,
258                max: RERANKER_MAX_CONCURRENT,
259            });
260        }
261        // RAII decrement: always release the slot on return, including errors.
262        let _guard = ActiveGuard(Arc::clone(&self.active_requests));
263
264        let pool_len = self.sessions.len();
265        // Round-robin start: each concurrent caller gets a different initial slot
266        // so concurrent requests don't all contend on sessions[0].
267        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
268        let tokenizer = Arc::clone(&self.tokenizer);
269        let has_token_type_ids = self.has_token_type_ids;
270        let query_str = query.to_string();
271
272        // Split candidates into RERANKER_CHUNK_SIZE sub-batches.
273        let chunks: Vec<Vec<String>> = passages
274            .chunks(RERANKER_CHUNK_SIZE)
275            .map(<[String]>::to_vec)
276            .collect();
277
278        // Spawn all chunks concurrently; each acquires its own session slot.
279        let mut handles = Vec::with_capacity(chunks.len());
280        for (i, chunk) in chunks.into_iter().enumerate() {
281            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
282            let tok = Arc::clone(&tokenizer);
283            let q = query_str.clone();
284            handles.push(tokio::task::spawn_blocking(move || {
285                score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
286            }));
287        }
288
289        // Collect results in chunk order to preserve passage ordering.
290        let mut scores = Vec::with_capacity(passages.len());
291        for handle in handles {
292            let chunk_scores = handle
293                .await
294                .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
295            scores.extend(chunk_scores);
296        }
297
298        Ok(scores)
299    }
300
301    /// Number of parallel ONNX sessions in the pool.
302    pub fn pool_size(&self) -> usize {
303        self.sessions.len()
304    }
305
306    /// Configured ONNX mini-batch size (pairs per `session.run()` call).
307    pub fn onnx_batch_size(&self) -> usize {
308        RERANKER_ONNX_BATCH_SIZE
309    }
310
311    /// Current number of active concurrent `score_pairs` calls.
312    /// Used by metrics and health checks (DAK-5893).
313    pub fn active_requests_count(&self) -> usize {
314        self.active_requests.load(Ordering::Relaxed)
315    }
316
317    /// Maximum concurrent `score_pairs` calls before `Overloaded` is returned.
318    pub fn max_concurrent(&self) -> usize {
319        RERANKER_MAX_CONCURRENT
320    }
321}
322
323/// Blocking cross-encoder inference for one chunk — runs inside `spawn_blocking`.
324///
325/// The chunk is processed as a sequence of mini-batches of [`RERANKER_ONNX_BATCH_SIZE`]
326/// pairs. Each mini-batch issues one `session.run()` call and pads to its own
327/// maximum sequence length, reducing waste compared to padding the full chunk.
328/// The session mutex is held for all mini-batches in the chunk to avoid per-mini-batch
329/// acquire/release overhead.
330fn score_pairs_blocking(
331    session: &Arc<Mutex<Session>>,
332    tokenizer: &Tokenizer,
333    query: &str,
334    passages: &[String],
335    has_token_type_ids: bool,
336) -> Result<Vec<f32>> {
337    let total = passages.len();
338    if total == 0 {
339        return Ok(Vec::new());
340    }
341
342    let mut all_scores = Vec::with_capacity(total);
343    // Hold the lock for the entire chunk to eliminate per-mini-batch
344    // acquire/release cost. Total lock duration is unchanged vs. the
345    // pre-DAK-5883 single-call approach since total compute is the same.
346    let mut sess = session.lock();
347
348    for mini_batch in passages.chunks(RERANKER_ONNX_BATCH_SIZE) {
349        let batch_size = mini_batch.len();
350
351        // Build EncodeInput pairs: [CLS] query [SEP] passage [SEP]
352        let inputs: Vec<EncodeInput> = mini_batch
353            .iter()
354            .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
355            .collect();
356
357        let encodings = tokenizer
358            .encode_batch(inputs, true)
359            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
360
361        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
362        if seq_len == 0 {
363            all_scores.extend(std::iter::repeat_n(0.5f32, batch_size));
364            continue;
365        }
366
367        // Flatten to i64 arrays (ORT BERT models expect int64)
368        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
369        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
370        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
371
372        for enc in &encodings {
373            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
374            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
375            let type_ids = enc.get_type_ids();
376            if type_ids.is_empty() {
377                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
378            } else {
379                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
380            }
381        }
382
383        // Build ORT tensors
384        let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
385            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
386        let attention_mask_tensor =
387            Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
388                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
389        let token_type_ids_tensor =
390            Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
391                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
392
393        // Run inference in a scoped block so `outputs` drops before the next
394        // mini-batch iteration reuses the session — the borrow on `sess` from
395        // `SessionOutputs` must end before the next `sess.run()` call.
396        let mini_scores: Vec<f32> = {
397            let outputs = if has_token_type_ids {
398                sess.run(inputs![
399                    "input_ids" => input_ids_tensor,
400                    "attention_mask" => attention_mask_tensor,
401                    "token_type_ids" => token_type_ids_tensor
402                ])
403                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
404            } else {
405                sess.run(inputs![
406                    "input_ids" => input_ids_tensor,
407                    "attention_mask" => attention_mask_tensor
408                ])
409                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
410            };
411
412            // Extract logits — bge-reranker-base output shape is [batch_size, 1]
413            let (out_shape, logits_slice) = outputs[0]
414                .try_extract_tensor::<f32>()
415                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
416
417            if out_shape.is_empty() || out_shape[0] as usize != batch_size {
418                warn!(
419                    "Reranker output shape mismatch: expected [{}, 1], got {:?}",
420                    batch_size, out_shape
421                );
422            }
423
424            // Apply sigmoid → owned Vec<f32>; borrow on outputs/sess ends here.
425            logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
426            // outputs drops here (in reverse declaration order: logits_slice, outputs)
427        };
428
429        let n_scores = mini_scores.len();
430        if n_scores != batch_size {
431            warn!(
432                "Reranker score count mismatch: expected {}, got {}",
433                batch_size, n_scores
434            );
435            let mut padded = mini_scores;
436            padded.resize(batch_size, 0.5);
437            all_scores.extend(padded);
438        } else {
439            all_scores.extend(mini_scores);
440        }
441    }
442    // sess drops here, releasing the mutex
443
444    Ok(all_scores)
445}
446
/// Logistic function σ(x) = 1 / (1 + e^(−x)), mapping a raw logit into `[0, 1]`.
///
/// Large-magnitude inputs saturate cleanly: e^(−x) overflows to `+∞` (giving 0)
/// or underflows to 0 (giving 1), so no NaN is produced for finite input.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
452
453/// Download tokenizer and ONNX model files for the reranker.
454/// Reuses `EmbeddingEngine::download_hf_file_pub` for redirect-aware caching.
455fn download_reranker_files(
456    cache_dir: Option<String>,
457) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
458    let cache = match cache_dir {
459        Some(dir) => {
460            let p = PathBuf::from(dir);
461            std::fs::create_dir_all(&p)
462                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
463            p
464        }
465        None => {
466            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
467            PathBuf::from(home)
468                .join(".cache")
469                .join("huggingface")
470                .join("dakera")
471                .join(RERANKER_REPO_ID.replace('/', "--"))
472        }
473    };
474
475    std::fs::create_dir_all(&cache)
476        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
477
478    let files = [
479        "tokenizer.json",
480        "tokenizer_config.json",
481        "special_tokens_map.json",
482        RERANKER_ONNX_FILE,
483    ];
484
485    for filename in &files {
486        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
487            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
488    }
489
490    let tokenizer_path = cache.join("tokenizer.json");
491    let onnx_path = cache.join(RERANKER_ONNX_FILE);
492    Ok((tokenizer_path, onnx_path))
493}
494
495impl std::fmt::Debug for CrossEncoderEngine {
496    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497        f.debug_struct("CrossEncoderEngine")
498            .field("model", &RERANKER_REPO_ID)
499            .field("pool_size", &self.sessions.len())
500            .field("onnx_batch_size", &RERANKER_ONNX_BATCH_SIZE)
501            .field(
502                "active_requests",
503                &self.active_requests.load(Ordering::Relaxed),
504            )
505            .field("max_concurrent", &RERANKER_MAX_CONCURRENT)
506            .finish()
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises every test in this module that mutates `DAKERA_USE_GPU`.
    ///
    /// This must be ONE module-level lock: a `static` declared inside each test
    /// function is a distinct mutex per test, so tests running on parallel
    /// harness threads would still race on the process-wide environment.
    /// NOTE(review): tests in other modules touching the same variable are not
    /// covered by this lock.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning so one failing env test
    /// does not cascade into spurious failures in the others.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    // ── GPU env-var read path ────────────────────────────────────────────────

    #[test]
    fn test_use_gpu_default_is_false() {
        // Without DAKERA_USE_GPU set, use_gpu must resolve to false (CPU default).
        let _guard = env_lock();
        // SAFETY: env mutation within this module is serialised by ENV_LOCK.
        unsafe { std::env::remove_var("DAKERA_USE_GPU") };
        let use_gpu = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);
        assert!(
            !use_gpu,
            "expected CPU default when DAKERA_USE_GPU is unset"
        );
    }

    #[test]
    fn test_use_gpu_enabled_when_env_var_is_1() {
        let _guard = env_lock();
        // SAFETY: env mutation within this module is serialised by ENV_LOCK.
        unsafe { std::env::set_var("DAKERA_USE_GPU", "1") };
        let use_gpu = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);
        // SAFETY: as above.
        unsafe { std::env::remove_var("DAKERA_USE_GPU") };
        assert!(use_gpu, "expected GPU mode when DAKERA_USE_GPU=1");
    }

    #[test]
    fn test_use_gpu_not_enabled_for_other_values() {
        let _guard = env_lock();
        for val in ["0", "true", "yes", "gpu", ""] {
            // SAFETY: env mutation within this module is serialised by ENV_LOCK.
            unsafe { std::env::set_var("DAKERA_USE_GPU", val) };
            let use_gpu = std::env::var("DAKERA_USE_GPU")
                .map(|v| v == "1")
                .unwrap_or(false);
            // SAFETY: as above.
            unsafe { std::env::remove_var("DAKERA_USE_GPU") };
            assert!(
                !use_gpu,
                "expected CPU when DAKERA_USE_GPU={val:?} (only '1' enables GPU)"
            );
        }
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_chunk_count_exact() {
        // 64 passages / chunk_size=32 → exactly 2 chunks
        let passages: Vec<String> = (0..64).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 32);
    }

    #[test]
    fn test_chunk_count_remainder() {
        // 50 passages / chunk_size=32 → 2 chunks (32 + 18)
        let passages: Vec<String> = (0..50).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 18);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        // 10 passages → single chunk, no splitting overhead
        let passages: Vec<String> = (0..10).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10);
    }

    #[test]
    fn test_chunk_order_preserved() {
        // Chunk splitting must preserve passage order for score reassembly.
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_max_concurrent_exceeds_pool_size() {
        // RERANKER_MAX_CONCURRENT must be strictly greater than RERANKER_POOL_SIZE
        // so that the pool can be utilised at all before the gate fires (DAK-5893).
        const { assert!(RERANKER_MAX_CONCURRENT > RERANKER_POOL_SIZE) };
        // Must also be reasonable — less than 20 so a 20-request burst gets shed.
        const { assert!(RERANKER_MAX_CONCURRENT < 20) };
    }

    #[test]
    fn test_active_guard_decrements() {
        let counter = Arc::new(AtomicUsize::new(1));
        {
            let _g = ActiveGuard(Arc::clone(&counter));
            assert_eq!(counter.load(Ordering::SeqCst), 1);
        }
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_round_robin_wraps() {
        let pool_len = RERANKER_POOL_SIZE;
        // Simulate 10 concurrent callers; each gets a different start_idx.
        // Verify no start_idx exceeds pool_len when used with modulo.
        for start in 0usize..10 {
            let idx = start % pool_len;
            assert!(idx < pool_len);
        }
    }

    // ── ONNX mini-batch tests (DAK-5883) ────────────────────────────────────

    #[test]
    fn test_onnx_batch_size_constant_invariants() {
        // ONNX batch size must be positive and no larger than the chunk size.
        // If ONNX_BATCH_SIZE > CHUNK_SIZE the inner loop always produces one
        // mini-batch (identical to pre-DAK-5883 behaviour), which is allowed
        // but defeats the purpose of the constant.
        const { assert!(RERANKER_ONNX_BATCH_SIZE >= 1) };
        const { assert!(RERANKER_ONNX_BATCH_SIZE <= RERANKER_CHUNK_SIZE) };
    }

    #[test]
    fn test_onnx_mini_batch_count_full_chunk() {
        // A full chunk (32 passages) with ONNX_BATCH_SIZE=16 → exactly 2 mini-batches.
        let passages: Vec<String> = (0..RERANKER_CHUNK_SIZE).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        let expected = RERANKER_CHUNK_SIZE.div_ceil(RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(mini_batches.len(), expected);
        // Each full mini-batch has exactly ONNX_BATCH_SIZE items.
        for mb in &mini_batches[..mini_batches.len() - 1] {
            assert_eq!(mb.len(), RERANKER_ONNX_BATCH_SIZE);
        }
    }

    #[test]
    fn test_onnx_mini_batch_count_partial_chunk() {
        // ONNX_BATCH_SIZE + 1 passages → 2 mini-batches (full + remainder of 1).
        let n = RERANKER_ONNX_BATCH_SIZE + 1;
        let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        assert_eq!(mini_batches.len(), 2);
        assert_eq!(mini_batches[0].len(), RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(mini_batches[1].len(), 1);
    }

    #[test]
    fn test_onnx_mini_batch_count_smaller_than_batch_size() {
        // Fewer passages than ONNX_BATCH_SIZE → single mini-batch (no overhead).
        let n = RERANKER_ONNX_BATCH_SIZE / 2;
        let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        assert_eq!(mini_batches.len(), 1);
        assert_eq!(mini_batches[0].len(), n);
    }

    #[test]
    fn test_onnx_mini_batch_order_preserved() {
        // Mini-batch splitting and reassembly must preserve input order exactly.
        // This guards against score[i] ↔ passage[j] mismatches in reassembly.
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_ONNX_BATCH_SIZE)
            .flat_map(|mb| mb.to_vec())
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_onnx_mini_batch_total_score_count_matches_input() {
        // Whatever the input size, total score count after reassembly must equal
        // the number of input passages (covers exact multiples and remainders).
        for n in [1, 8, 15, 16, 17, 32, 33, 47, 64] {
            let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
            let total: usize = passages
                .chunks(RERANKER_ONNX_BATCH_SIZE)
                .map(|mb| mb.len())
                .sum();
            assert_eq!(total, n, "score count mismatch for n={n}");
        }
    }

    #[test]
    fn test_onnx_batch_size_accessor() {
        // Verify that onnx_batch_size() returns the compile-time constant.
        // Requires constructing a CrossEncoderEngine — only possible in integration
        // tests with a real model. Instead, verify the constant directly via the
        // public constant's value relationship.
        assert_eq!(RERANKER_ONNX_BATCH_SIZE, 16);
    }
}