1//! # Next-Plaid ONNX
2//!
3//! Fast ColBERT inference using ONNX Runtime with automatic hardware acceleration.
4//!
5//! Also includes hierarchical clustering utilities compatible with scipy.
6//!
7//! ## Quick Start
8//!
9//! ```rust,ignore
10//! use next_plaid_onnx::Colbert;
11//!
12//! // Simple usage with defaults (auto-detects threads and hardware)
13//! let model = Colbert::new("models/GTE-ModernColBERT-v1")?;
14//!
15//! // Encode documents
16//! let doc_embeddings = model.encode_documents(&["Paris is the capital of France."], None)?;
17//!
18//! // Encode queries
19//! let query_embeddings = model.encode_queries(&["What is the capital of France?"])?;
20//! ```
21//!
22//! ## Configuration
23//!
24//! Use the builder pattern for advanced configuration:
25//!
26//! ```rust,ignore
27//! use next_plaid_onnx::{Colbert, ExecutionProvider};
28//!
29//! let model = Colbert::builder("models/GTE-ModernColBERT-v1")
30//!     .with_quantized(true)                              // Use INT8 model for ~2x speedup
31//!     .with_parallel(25)                                 // 25 parallel ONNX sessions
32//!     .with_batch_size(2)                                // Batch size per session
33//!     .with_execution_provider(ExecutionProvider::Cuda)  // Force CUDA
34//!     .build()?;
35//! ```
36//!
37//! ## Hardware Acceleration
38//!
39//! Enable GPU acceleration by adding the appropriate feature:
40//!
41//! - `cuda` - NVIDIA CUDA (Linux/Windows)
42//! - `tensorrt` - NVIDIA TensorRT (optimized CUDA)
43//! - `coreml` - Apple Silicon (macOS)
44//! - `directml` - Windows GPUs (DirectX 12)
45//!
46//! When GPU features are enabled, the library automatically uses GPU if available
47//! and falls back to CPU if not.
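//!
//! A minimal sketch of requesting a specific provider and falling back manually if it
//! cannot be configured (same example model directory as above):
//!
//! ```rust,ignore
//! use next_plaid_onnx::{Colbert, ExecutionProvider};
//!
//! // Prefer CUDA; fall back to a CPU-only session if CUDA cannot be configured.
//! let model = Colbert::builder("models/GTE-ModernColBERT-v1")
//!     .with_execution_provider(ExecutionProvider::Cuda)
//!     .build()
//!     .or_else(|_| {
//!         Colbert::builder("models/GTE-ModernColBERT-v1")
//!             .with_execution_provider(ExecutionProvider::Cpu)
//!             .build()
//!     })?;
//! ```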
48
49pub mod hierarchy;
50
51use anyhow::{Context, Result};
52use ndarray::Array2;
53use ort::session::builder::GraphOptimizationLevel;
54use ort::session::Session;
55use ort::value::Tensor;
56use serde::{Deserialize, Serialize};
57use std::collections::HashSet;
58use std::fs;
59use std::path::Path;
60use std::sync::Once;
61use std::sync::{Arc, Mutex};
62use tokenizers::Tokenizer;
63
64// Conditional imports for execution providers
65#[cfg(feature = "cuda")]
66use ort::ep::ExecutionProvider as ExecutionProviderTrait;
67#[cfg(feature = "cuda")]
68use ort::execution_providers::CUDAExecutionProvider;
69#[cfg(feature = "coreml")]
70use ort::execution_providers::CoreMLExecutionProvider;
71#[cfg(feature = "directml")]
72use ort::execution_providers::DirectMLExecutionProvider;
73#[cfg(feature = "tensorrt")]
74use ort::execution_providers::TensorRTExecutionProvider;
75
76use ort::session::builder::SessionBuilder;
77
78// =============================================================================
79// ONNX Runtime initialization (internal)
80// =============================================================================
81
82static ORT_INIT: Once = Once::new();
83
84/// Initialize ONNX Runtime by finding and loading the dynamic library.
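///
/// If `ORT_DYLIB_PATH` is already set it is used as-is, so you can point it at a
/// specific library to skip the search entirely, e.g.
/// `ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so` (illustrative path).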
85fn init_ort_runtime() {
86    ORT_INIT.call_once(|| {
87        // If ORT_DYLIB_PATH is already set, ort will use it
88        if std::env::var("ORT_DYLIB_PATH").is_ok() {
89            return;
90        }
91
92        // Try to find ONNX Runtime in common locations
93        if let Some(lib_path) = find_onnxruntime_library() {
94            std::env::set_var("ORT_DYLIB_PATH", &lib_path);
95        }
96    });
97}
98
99/// Find the ONNX Runtime library in common installation locations.
100fn find_onnxruntime_library() -> Option<String> {
101    let home = std::env::var("HOME").ok()?;
102
103    let search_patterns = vec![
104        // Python virtual environments (various Python versions)
105        format!(
106            "{}/.venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
107            home
108        ),
109        format!(
110            "{}/venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
111            home
112        ),
113        "python/.venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*".to_string(),
114        ".venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*".to_string(),
115        // User site-packages
116        format!(
117            "{}/.local/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
118            home
119        ),
120        // UV cache (common with uv package manager)
121        format!(
122            "{}/.cache/uv/archive-v*/*/onnxruntime/capi/libonnxruntime.so*",
123            home
124        ),
125        // Conda environments
126        format!("{}/anaconda3/lib/libonnxruntime.so*", home),
127        format!("{}/miniconda3/lib/libonnxruntime.so*", home),
128        // System locations
129        "/usr/local/lib/libonnxruntime.so*".to_string(),
130        "/usr/lib/libonnxruntime.so*".to_string(),
131        "/usr/lib/x86_64-linux-gnu/libonnxruntime.so*".to_string(),
132    ];
133
134    for pattern in search_patterns {
135        if let Ok(paths) = glob::glob(&pattern) {
136            for path in paths.flatten() {
137                if path.exists() && path.is_file() {
138                    let path_str = path.to_string_lossy();
139                    if path_str.contains(".so.") || path_str.ends_with(".so") {
140                        return Some(path.to_string_lossy().to_string());
141                    }
142                }
143            }
144        }
145    }
146
147    None
148}
149
150// =============================================================================
151// Execution Provider Configuration
152// =============================================================================
153
154/// Hardware acceleration provider for ONNX Runtime.
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
156pub enum ExecutionProvider {
157    /// Automatically detect and use the best available hardware.
158    /// Tries in order: CUDA > TensorRT > CoreML > DirectML > CPU
159    #[default]
160    Auto,
161    /// CPU execution only
162    Cpu,
163    /// CUDA execution (NVIDIA GPUs, requires `cuda` feature)
164    Cuda,
165    /// TensorRT execution (NVIDIA GPUs with TensorRT, requires `tensorrt` feature)
166    TensorRT,
167    /// CoreML execution (Apple Silicon, requires `coreml` feature)
168    CoreML,
169    /// DirectML execution (Windows GPUs, requires `directml` feature)
170    DirectML,
171}
172
173fn configure_execution_provider(
174    builder: SessionBuilder,
175    provider: ExecutionProvider,
176) -> Result<SessionBuilder> {
177    match provider {
178        ExecutionProvider::Auto => configure_auto_provider(builder),
179        ExecutionProvider::Cpu => Ok(builder),
180        ExecutionProvider::Cuda => configure_cuda(builder),
181        ExecutionProvider::TensorRT => configure_tensorrt(builder),
182        ExecutionProvider::CoreML => configure_coreml(builder),
183        ExecutionProvider::DirectML => configure_directml(builder),
184    }
185}
186
187/// Get CUDA device ID from environment or default to 0
188#[cfg(feature = "cuda")]
189fn get_cuda_device_id() -> i32 {
190    std::env::var("CUDA_VISIBLE_DEVICES")
191        .ok()
192        .and_then(|s| s.split(',').next().and_then(|id| id.parse::<i32>().ok()))
193        .unwrap_or(0)
194}
195
196/// Check if CPU-only mode is forced via environment variable.
197/// Set `NEXT_PLAID_FORCE_CPU=1` to completely disable CUDA and avoid
198/// any CUDA library loading overhead.
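///
/// A minimal sketch (the environment variable is read on every call):
///
/// ```rust,ignore
/// // Force CPU-only execution before constructing any sessions.
/// std::env::set_var("NEXT_PLAID_FORCE_CPU", "1");
/// assert!(next_plaid_onnx::is_force_cpu());
/// ```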
199pub fn is_force_cpu() -> bool {
200    std::env::var("NEXT_PLAID_FORCE_CPU")
201        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
202        .unwrap_or(false)
203}
204
205/// Check if CUDA execution provider is available AND a GPU is visible.
206/// Returns true if:
207/// - NEXT_PLAID_FORCE_CPU is NOT set
208/// - CUDA feature is enabled
209/// - At least one GPU is visible (CUDA_VISIBLE_DEVICES is not empty/-1)
210/// - CUDA EP is compiled in ONNX Runtime
211///
212/// IMPORTANT: Check CUDA_VISIBLE_DEVICES FIRST before calling .is_available()
213/// to avoid CUDA driver initialization overhead when GPUs are hidden.
214#[cfg(feature = "cuda")]
215pub fn is_cuda_available() -> bool {
216    // Check if CPU-only mode is forced via environment variable
217    // This completely bypasses all CUDA checks
218    if is_force_cpu() {
219        return false;
220    }
221
222    // Check if GPUs are visible via CUDA_VISIBLE_DEVICES FIRST
223    // This avoids triggering CUDA driver initialization when GPUs are hidden
224    //
225    // Note: When CUDA_VISIBLE_DEVICES is:
226    // - Not set: GPUs are visible (default CUDA behavior)
227    // - Empty string "": GPUs are hidden
228    // - "-1": GPUs are hidden
229    // - Valid device IDs: Only those GPUs are visible
230    if let Ok(devices) = std::env::var("CUDA_VISIBLE_DEVICES") {
231        // Empty string or "-1" means no GPUs visible
232        if devices.is_empty() || devices == "-1" {
233            return false;
234        }
235    }
236    // If CUDA_VISIBLE_DEVICES is not set, GPUs are visible by default
237
238    // Only now check if CUDA EP is compiled in (may trigger CUDA driver init)
239    CUDAExecutionProvider::default()
240        .is_available()
241        .unwrap_or(false)
242}
243
244/// Check if CUDA execution provider is available.
245/// Always returns false when CUDA feature is not enabled.
246#[cfg(not(feature = "cuda"))]
247pub fn is_cuda_available() -> bool {
248    false
249}
250
251fn configure_auto_provider(builder: SessionBuilder) -> Result<SessionBuilder> {
252    // Skip GPU providers entirely if CPU-only mode is forced
    #[cfg(any(feature = "cuda", feature = "tensorrt", feature = "coreml", feature = "directml"))]
254    let force_cpu = is_force_cpu();
255
256    #[cfg(feature = "cuda")]
257    if !force_cpu {
258        let device_id = get_cuda_device_id();
259        if let Ok(b) = builder
260            .clone()
261            .with_execution_providers([CUDAExecutionProvider::default()
262                .with_device_id(device_id)
263                .with_tf32(true)
264                .build()])
265        {
266            return Ok(b);
267        }
268    }
269
270    #[cfg(feature = "tensorrt")]
271    if !force_cpu {
272        if let Ok(b) = builder
273            .clone()
274            .with_execution_providers([TensorRTExecutionProvider::default().build()])
275        {
276            return Ok(b);
277        }
278    }
279
280    #[cfg(feature = "coreml")]
    if !force_cpu {
282        if let Ok(b) = builder
283            .clone()
284            .with_execution_providers([CoreMLExecutionProvider::default().build()])
285        {
286            return Ok(b);
287        }
288    }
289
290    #[cfg(feature = "directml")]
291    if !force_cpu {
292        if let Ok(b) = builder
293            .clone()
294            .with_execution_providers([DirectMLExecutionProvider::default().build()])
295        {
296            return Ok(b);
297        }
298    }
299
300    Ok(builder)
301}
302
303#[cfg(feature = "cuda")]
304fn configure_cuda(builder: SessionBuilder) -> Result<SessionBuilder> {
305    // If CPU-only mode is forced, return CPU provider instead
306    if is_force_cpu() {
307        return Ok(builder);
308    }
309
310    let device_id = get_cuda_device_id();
311    builder
312        .with_execution_providers([
313            CUDAExecutionProvider::default()
314                .with_device_id(device_id)
315                .with_tf32(true)
316                .build()
317        ])
318        .context("Failed to configure CUDA execution provider. Ensure CUDA toolkit and cuDNN are installed.")
319}
320
321#[cfg(not(feature = "cuda"))]
322fn configure_cuda(_builder: SessionBuilder) -> Result<SessionBuilder> {
323    anyhow::bail!("CUDA support not compiled. Enable the 'cuda' feature.")
324}
325
326#[cfg(feature = "tensorrt")]
327fn configure_tensorrt(builder: SessionBuilder) -> Result<SessionBuilder> {
328    builder
329        .with_execution_providers([TensorRTExecutionProvider::default().build()])
330        .context("Failed to configure TensorRT execution provider")
331}
332
333#[cfg(not(feature = "tensorrt"))]
334fn configure_tensorrt(_builder: SessionBuilder) -> Result<SessionBuilder> {
335    anyhow::bail!("TensorRT support not compiled. Enable the 'tensorrt' feature.")
336}
337
338#[cfg(feature = "coreml")]
339fn configure_coreml(builder: SessionBuilder) -> Result<SessionBuilder> {
340    builder
341        .with_execution_providers([CoreMLExecutionProvider::default().build()])
342        .context("Failed to configure CoreML execution provider")
343}
344
345#[cfg(not(feature = "coreml"))]
346fn configure_coreml(_builder: SessionBuilder) -> Result<SessionBuilder> {
347    anyhow::bail!("CoreML support not compiled. Enable the 'coreml' feature.")
348}
349
350#[cfg(feature = "directml")]
351fn configure_directml(builder: SessionBuilder) -> Result<SessionBuilder> {
352    builder
353        .with_execution_providers([DirectMLExecutionProvider::default().build()])
354        .context("Failed to configure DirectML execution provider")
355}
356
357#[cfg(not(feature = "directml"))]
358fn configure_directml(_builder: SessionBuilder) -> Result<SessionBuilder> {
359    anyhow::bail!("DirectML support not compiled. Enable the 'directml' feature.")
360}
361
362// =============================================================================
363// Configuration
364// =============================================================================
365
366/// Configuration for ColBERT model behavior.
367///
368/// This is automatically loaded from `onnx_config.json` when loading a model.
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ColbertConfig {
371    /// Prefix prepended to queries (e.g., "\[Q\] " or "\[unused0\]")
372    #[serde(default = "default_query_prefix")]
373    pub query_prefix: String,
374
375    /// Prefix prepended to documents (e.g., "\[D\] " or "\[unused1\]")
376    #[serde(default = "default_document_prefix")]
377    pub document_prefix: String,
378
379    /// Maximum sequence length for queries (typically 32-48)
380    #[serde(default = "default_query_length")]
381    pub query_length: usize,
382
383    /// Maximum sequence length for documents (typically 180-300)
384    #[serde(default = "default_document_length")]
385    pub document_length: usize,
386
387    /// Whether to expand queries with MASK tokens
388    #[serde(default = "default_do_query_expansion")]
389    pub do_query_expansion: bool,
390
391    /// Output embedding dimension
392    #[serde(default = "default_embedding_dim")]
393    pub embedding_dim: usize,
394
395    /// Whether the model uses token_type_ids (BERT does, ModernBERT doesn't)
396    #[serde(default = "default_uses_token_type_ids")]
397    pub uses_token_type_ids: bool,
398
399    /// MASK token ID for query expansion
400    #[serde(default = "default_mask_token_id")]
401    pub mask_token_id: u32,
402
403    /// PAD token ID
404    #[serde(default = "default_pad_token_id")]
405    pub pad_token_id: u32,
406
407    /// Words/punctuation to filter from document embeddings
408    #[serde(default)]
409    pub skiplist_words: Vec<String>,
410
411    // Internal fields
412    #[serde(default = "default_model_type")]
413    model_type: String,
414    #[serde(default)]
415    model_name: Option<String>,
416    #[serde(default)]
417    model_class: Option<String>,
418    #[serde(default)]
419    attend_to_expansion_tokens: bool,
420    query_prefix_id: Option<u32>,
421    document_prefix_id: Option<u32>,
422    /// Whether to lowercase text before tokenization (matches sentence-transformers preprocessing)
423    #[serde(default)]
424    pub do_lower_case: bool,
425}
426
427fn default_model_type() -> String {
428    "ColBERT".to_string()
429}
430fn default_uses_token_type_ids() -> bool {
431    true
432}
433fn default_query_prefix() -> String {
434    "[Q] ".to_string()
435}
436fn default_document_prefix() -> String {
437    "[D] ".to_string()
438}
439fn default_query_length() -> usize {
440    48
441}
442fn default_document_length() -> usize {
443    300
444}
445fn default_do_query_expansion() -> bool {
446    true
447}
448fn default_embedding_dim() -> usize {
449    128
450}
451fn default_mask_token_id() -> u32 {
452    103
453}
454fn default_pad_token_id() -> u32 {
455    0
456}
457
458impl Default for ColbertConfig {
459    fn default() -> Self {
460        Self {
461            model_type: default_model_type(),
462            model_name: None,
463            model_class: None,
464            uses_token_type_ids: default_uses_token_type_ids(),
465            query_prefix: default_query_prefix(),
466            document_prefix: default_document_prefix(),
467            query_length: default_query_length(),
468            document_length: default_document_length(),
469            do_query_expansion: default_do_query_expansion(),
470            attend_to_expansion_tokens: false,
471            skiplist_words: Vec::new(),
472            embedding_dim: default_embedding_dim(),
473            mask_token_id: default_mask_token_id(),
474            pad_token_id: default_pad_token_id(),
475            query_prefix_id: None,
476            document_prefix_id: None,
477            do_lower_case: false,
478        }
479    }
480}
481
482impl ColbertConfig {
483    /// Load config from a JSON file.
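    ///
    /// A minimal usage sketch (the path is illustrative):
    ///
    /// ```rust,ignore
    /// use next_plaid_onnx::ColbertConfig;
    ///
    /// let config = ColbertConfig::from_file("models/GTE-ModernColBERT-v1/onnx_config.json")?;
    /// println!("embedding_dim = {}", config.embedding_dim);
    /// ```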
484    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
485        let content = fs::read_to_string(path.as_ref())
486            .with_context(|| format!("Failed to read config from {:?}", path.as_ref()))?;
487        let config: ColbertConfig =
488            serde_json::from_str(&content).with_context(|| "Failed to parse onnx_config.json")?;
489        Ok(config)
490    }
491
492    fn from_model_dir<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
493        let onnx_config_path = model_dir.as_ref().join("onnx_config.json");
494        if onnx_config_path.exists() {
495            return Self::from_file(&onnx_config_path);
496        }
497
498        anyhow::bail!(
499            "onnx_config.json not found in {:?}. This file is required for ColBERT model configuration.",
500            model_dir.as_ref()
501        )
502    }
503
504    /// Get the model name (if specified in config).
505    pub fn model_name(&self) -> Option<&str> {
506        self.model_name.as_deref()
507    }
508}
509
510// =============================================================================
511// Colbert Model
512// =============================================================================
513
514/// Default batch size for CPU encoding.
515const DEFAULT_CPU_BATCH_SIZE: usize = 32;
516
517/// Default batch size for GPU encoding.
518const DEFAULT_GPU_BATCH_SIZE: usize = 64;
519
520/// Type alias for batch encoding data: (input_ids, attention_mask, token_type_ids, token_ids)
521type BatchEncoding = (Vec<i64>, Vec<i64>, Vec<i64>, Vec<u32>);
522
523/// ColBERT model for encoding documents and queries into multi-vector embeddings.
524///
525/// Supports both single-session and parallel multi-session encoding.
526///
527/// # Example
528///
529/// ```rust,ignore
530/// use next_plaid_onnx::Colbert;
531///
532/// // Simple usage
533/// let model = Colbert::new("models/GTE-ModernColBERT-v1")?;
534/// let docs = model.encode_documents(&["Hello world"], None)?;
535/// let queries = model.encode_queries(&["greeting"])?;
536///
537/// // With parallel sessions for high throughput
538/// let model = Colbert::builder("models/GTE-ModernColBERT-v1")
539///     .with_quantized(true)
540///     .with_parallel(25)
541///     .build()?;
542/// ```
543pub struct Colbert {
544    sessions: Vec<Mutex<Session>>,
545    tokenizer: Arc<Tokenizer>,
546    config: ColbertConfig,
547    skiplist_ids: HashSet<u32>,
548    batch_size: usize,
549}
550
551/// Builder for configuring [`Colbert`].
552///
553/// # Example
554///
555/// ```rust,ignore
556/// use next_plaid_onnx::{Colbert, ExecutionProvider};
557///
558/// // Simple usage with defaults
559/// let model = Colbert::builder("models/GTE-ModernColBERT-v1").build()?;
560///
561/// // Full configuration
562/// let model = Colbert::builder("models/GTE-ModernColBERT-v1")
563///     .with_quantized(true)                              // Use INT8 model
564///     .with_parallel(25)                                 // 25 parallel sessions
565///     .with_batch_size(2)                                // Batch size per session
566///     .with_execution_provider(ExecutionProvider::Cuda)  // Force CUDA
567///     .build()?;
568/// ```
569pub struct ColbertBuilder {
570    model_dir: std::path::PathBuf,
571    num_sessions: usize,
572    threads_per_session: usize,
573    batch_size: Option<usize>,
574    execution_provider: ExecutionProvider,
575    quantized: bool,
576    query_length: Option<usize>,
577    document_length: Option<usize>,
578}
579
580impl ColbertBuilder {
581    /// Create a new builder with default settings.
582    ///
583    /// Default configuration:
584    /// - Single session with auto-detected thread count
585    /// - No quantization (FP32 model)
586    /// - Auto execution provider (best available hardware)
587    pub fn new<P: AsRef<Path>>(model_dir: P) -> Self {
588        let num_threads = std::thread::available_parallelism()
589            .map(|p| p.get())
590            .unwrap_or(4);
591        Self {
592            model_dir: model_dir.as_ref().to_path_buf(),
593            num_sessions: 1,
594            threads_per_session: num_threads,
595            batch_size: None,
596            execution_provider: ExecutionProvider::Auto,
597            quantized: false,
598            query_length: None,
599            document_length: None,
600        }
601    }
602
603    /// Enable parallel encoding with multiple ONNX sessions.
604    ///
605    /// More sessions = more parallelism but also more memory.
606    /// When enabled, uses 1 thread per session (optimal for parallel execution).
607    ///
608    /// Recommended: 25 for large models, 8 for small models.
609    pub fn with_parallel(mut self, num_sessions: usize) -> Self {
610        self.num_sessions = num_sessions.max(1);
611        self.threads_per_session = 1; // Optimal for parallel sessions
612        self
613    }
614
615    /// Set the number of threads (for single-session mode).
616    ///
617    /// This is automatically set when using `with_parallel()`.
618    pub fn with_threads(mut self, num_threads: usize) -> Self {
619        self.threads_per_session = num_threads;
620        self
621    }
622
623    /// Set the batch size (documents processed per inference call).
624    ///
    /// Default: 2 per session when using parallel sessions; otherwise 32 for the CPU
    /// provider and 64 for all other providers (including `Auto`).
626    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
627        self.batch_size = Some(batch_size);
628        self
629    }
630
631    /// Set the hardware acceleration provider.
632    pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
633        self.execution_provider = provider;
634        self
635    }
636
637    /// Use INT8 quantized model (`model_int8.onnx`) for faster inference.
638    ///
639    /// Quantization provides ~2x speedup with minimal quality loss (>99% cosine similarity).
640    pub fn with_quantized(mut self, quantized: bool) -> Self {
641        self.quantized = quantized;
642        self
643    }
644
645    /// Set the maximum query length.
646    ///
647    /// If not set, uses `query_length` from `onnx_config.json` (default: 48).
648    /// Queries longer than this will be truncated.
649    pub fn with_query_length(mut self, query_length: usize) -> Self {
650        self.query_length = Some(query_length);
651        self
652    }
653
654    /// Set the maximum document length.
655    ///
656    /// If not set, uses `document_length` from `onnx_config.json` (default: 300).
657    /// Documents longer than this will be truncated.
658    pub fn with_document_length(mut self, document_length: usize) -> Self {
659        self.document_length = Some(document_length);
660        self
661    }
662
663    /// Build the Colbert model.
664    pub fn build(self) -> Result<Colbert> {
665        init_ort_runtime();
666
667        let model_dir = &self.model_dir;
668        let onnx_path = select_onnx_file(model_dir, self.quantized)?;
669        let tokenizer_path = model_dir.join("tokenizer.json");
670
671        let tokenizer = Tokenizer::from_file(&tokenizer_path)
672            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
673
674        let mut config = ColbertConfig::from_model_dir(model_dir)?;
675
676        // Set query_length and document_length:
677        // - If user provided a value, use it
678        // - Otherwise, use value from onnx_config.json
679        if let Some(query_length) = self.query_length {
680            config.query_length = query_length;
681        }
682        if let Some(document_length) = self.document_length {
683            config.document_length = document_length;
684        }
685
686        update_token_ids(&mut config, &tokenizer);
687        let skiplist_ids = build_skiplist(&config, &tokenizer);
688
689        // Create sessions
690        let mut sessions = Vec::with_capacity(self.num_sessions);
691        for _i in 0..self.num_sessions {
692            let builder = Session::builder()?
693                .with_optimization_level(GraphOptimizationLevel::Level3)?
694                .with_intra_threads(self.threads_per_session)?
695                .with_inter_threads(if self.num_sessions > 1 { 1 } else { 2 })?
696                // Disable memory pattern optimization for ~7% speedup on CPU
697                // (based on benchmarking - helps with variable-length sequences)
698                .with_memory_pattern(false)?;
699
700            let builder = configure_execution_provider(builder, self.execution_provider)?;
701
702            let session = builder
703                .commit_from_file(&onnx_path)
704                .context("Failed to load ONNX model")?;
705
706            sessions.push(Mutex::new(session));
707        }
708
709        // Determine batch size
710        let batch_size = self.batch_size.unwrap_or(if self.num_sessions > 1 {
711            2 // Small batches optimal for parallel sessions
712        } else {
713            match self.execution_provider {
714                ExecutionProvider::Cpu => DEFAULT_CPU_BATCH_SIZE,
715                _ => DEFAULT_GPU_BATCH_SIZE,
716            }
717        });
718
719        Ok(Colbert {
720            sessions,
721            tokenizer: Arc::new(tokenizer),
722            config,
723            skiplist_ids,
724            batch_size,
725        })
726    }
727}
728
729impl Colbert {
730    /// Load a ColBERT model with default settings.
731    ///
732    /// Uses auto-detected thread count and hardware acceleration.
733    ///
734    /// # Example
735    ///
736    /// ```rust,ignore
737    /// let model = Colbert::new("models/GTE-ModernColBERT-v1")?;
738    /// ```
739    pub fn new<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
740        ColbertBuilder::new(model_dir).build()
741    }
742
743    /// Create a builder for advanced configuration.
744    ///
745    /// # Example
746    ///
747    /// ```rust,ignore
748    /// let model = Colbert::builder("models/GTE-ModernColBERT-v1")
749    ///     .with_quantized(true)
750    ///     .with_parallel(25)
751    ///     .build()?;
752    /// ```
753    pub fn builder<P: AsRef<Path>>(model_dir: P) -> ColbertBuilder {
754        ColbertBuilder::new(model_dir)
755    }
756
757    /// Encode documents into ColBERT embeddings.
758    ///
759    /// Each document is encoded into a matrix of shape `[num_tokens, embedding_dim]`,
760    /// where `num_tokens` is the number of non-padding, non-skiplist tokens.
761    ///
762    /// # Arguments
763    /// * `documents` - The documents to encode
764    /// * `pool_factor` - Optional reduction factor for hierarchical pooling.
765    ///   - `None` or `Some(1)`: No pooling, return all token embeddings
766    ///   - `Some(2)`: Keep ~50% of tokens by clustering similar ones
767    ///   - `Some(3)`: Keep ~33% of tokens, etc.
768    ///
769    /// # Example
770    ///
771    /// ```rust,ignore
772    /// // Without pooling
773    /// let embeddings = model.encode_documents(&["Paris is the capital of France."], None)?;
774    ///
775    /// // With pooling (keep ~50% of tokens)
776    /// let embeddings = model.encode_documents(&["Paris is the capital of France."], Some(2))?;
777    /// ```
778    pub fn encode_documents(
779        &self,
780        documents: &[&str],
781        pool_factor: Option<usize>,
782    ) -> Result<Vec<Array2<f32>>> {
783        if documents.is_empty() {
784            return Ok(Vec::new());
785        }
786
787        let embeddings = if self.sessions.len() == 1 {
788            self.encode_single_session(documents, false, true)?
789        } else {
790            self.encode_parallel(documents, false, true)?
791        };
792
793        // Apply pooling if requested
794        match pool_factor {
795            Some(pf) if pf > 1 => {
796                let pooled: Vec<Array2<f32>> = embeddings
797                    .into_iter()
798                    .map(|emb| pool_embeddings_hierarchical(emb, pf, 1))
799                    .collect();
800                Ok(pooled)
801            }
802            _ => Ok(embeddings),
803        }
804    }
805
806    /// Encode queries into ColBERT embeddings.
807    ///
808    /// Each query is encoded into a matrix of shape `[query_length, embedding_dim]`.
809    /// Queries are padded with MASK tokens to enable query expansion.
810    ///
811    /// # Example
812    ///
813    /// ```rust,ignore
814    /// let embeddings = model.encode_queries(&["What is the capital of France?"])?;
815    /// ```
816    pub fn encode_queries(&self, queries: &[&str]) -> Result<Vec<Array2<f32>>> {
817        if queries.is_empty() {
818            return Ok(Vec::new());
819        }
820
821        if self.sessions.len() == 1 {
822            self.encode_single_session(queries, true, false)
823        } else {
824            self.encode_parallel(queries, true, false)
825        }
826    }
827
828    /// Get the model configuration.
829    pub fn config(&self) -> &ColbertConfig {
830        &self.config
831    }
832
833    /// Get the embedding dimension.
834    pub fn embedding_dim(&self) -> usize {
835        self.config.embedding_dim
836    }
837
838    /// Get the batch size used for encoding.
839    pub fn batch_size(&self) -> usize {
840        self.batch_size
841    }
842
843    /// Get the number of parallel sessions.
844    pub fn num_sessions(&self) -> usize {
845        self.sessions.len()
846    }
847
848    // =========================================================================
849    // Internal encoding implementations
850    // =========================================================================
851
852    fn encode_single_session(
853        &self,
854        texts: &[&str],
855        is_query: bool,
856        filter_skiplist: bool,
857    ) -> Result<Vec<Array2<f32>>> {
858        let mut all_embeddings = Vec::with_capacity(texts.len());
859
860        for chunk in texts.chunks(self.batch_size) {
861            let mut session = self.sessions[0].lock().unwrap();
862            let chunk_embeddings = encode_batch_with_session(
863                &mut session,
864                &self.tokenizer,
865                &self.config,
866                &self.skiplist_ids,
867                chunk,
868                is_query,
869                filter_skiplist,
870            )?;
871            all_embeddings.extend(chunk_embeddings);
872        }
873
874        Ok(all_embeddings)
875    }
876
877    fn encode_parallel(
878        &self,
879        texts: &[&str],
880        is_query: bool,
881        filter_skiplist: bool,
882    ) -> Result<Vec<Array2<f32>>> {
883        let num_sessions = self.sessions.len();
884
885        let chunks: Vec<Vec<&str>> = texts
886            .chunks(self.batch_size.max(1))
887            .map(|c| c.to_vec())
888            .collect();
889
890        let results: Vec<Result<Vec<Array2<f32>>>> = std::thread::scope(|s| {
891            let handles: Vec<_> = chunks
892                .iter()
893                .enumerate()
894                .map(|(i, chunk)| {
895                    let session_idx = i % num_sessions;
896                    let session_mutex = &self.sessions[session_idx];
897                    let tokenizer = &self.tokenizer;
898                    let config = &self.config;
899                    let skiplist_ids = &self.skiplist_ids;
900
901                    s.spawn(move || {
902                        let mut session = session_mutex.lock().unwrap();
903                        encode_batch_with_session(
904                            &mut session,
905                            tokenizer,
906                            config,
907                            skiplist_ids,
908                            chunk,
909                            is_query,
910                            filter_skiplist,
911                        )
912                    })
913                })
914                .collect();
915
916            handles.into_iter().map(|h| h.join().unwrap()).collect()
917        });
918
919        let mut all_embeddings = Vec::with_capacity(texts.len());
920        for result in results {
921            all_embeddings.extend(result?);
922        }
923
924        Ok(all_embeddings)
925    }
926}
927
928// =============================================================================
929// Helper functions
930// =============================================================================
931
932fn select_onnx_file<P: AsRef<Path>>(model_dir: P, quantized: bool) -> Result<std::path::PathBuf> {
933    let model_dir = model_dir.as_ref();
934
935    if quantized {
        // When quantization is requested (with_quantized(true)), always load model_int8.onnx specifically.
937        let q_path = model_dir.join("model_int8.onnx");
938        if q_path.exists() {
939            Ok(q_path)
940        } else {
941            anyhow::bail!(
                "INT8 quantized model not found at {:?}. Disable quantization (with_quantized(false)) to load model.onnx instead.",
943                q_path
944            )
945        }
946    } else {
        // When quantization is not requested, always load model.onnx specifically.
        // This prevents accidentally loading model_int8.onnx when model.onnx is missing.
949        let model_path = model_dir.join("model.onnx");
950        if model_path.exists() {
951            Ok(model_path)
952        } else {
953            anyhow::bail!(
                "Model not found at {:?}. Enable quantization (with_quantized(true)) to load model_int8.onnx instead.",
955                model_path
956            )
957        }
958    }
959}
960
961fn update_token_ids(config: &mut ColbertConfig, tokenizer: &Tokenizer) {
962    if config.mask_token_id == default_mask_token_id() {
963        if let Some(mask_id) = tokenizer.token_to_id("[MASK]") {
964            config.mask_token_id = mask_id;
965        } else if let Some(mask_id) = tokenizer.token_to_id("<mask>") {
966            config.mask_token_id = mask_id;
967        }
968    }
969    if config.pad_token_id == default_pad_token_id() {
970        if let Some(pad_id) = tokenizer.token_to_id("[PAD]") {
971            config.pad_token_id = pad_id;
972        } else if let Some(pad_id) = tokenizer.token_to_id("<pad>") {
973            config.pad_token_id = pad_id;
974        }
975    }
976}
977
978fn build_skiplist(config: &ColbertConfig, tokenizer: &Tokenizer) -> HashSet<u32> {
979    let mut skiplist_ids = HashSet::new();
980    for word in &config.skiplist_words {
981        if let Some(token_id) = tokenizer.token_to_id(word) {
982            skiplist_ids.insert(token_id);
983        }
984    }
985    skiplist_ids
986}
987
988/// Internal function to encode a batch using a specific session.
989///
990/// This function matches PyLate's tokenization approach:
991/// 1. Tokenize text WITHOUT the prefix (max_length - 1 tokens)
992/// 2. Insert the prefix token ID after [CLS] (position 1)
993///
994/// This ensures that long documents get the same number of content tokens
995/// as PyLate, where the prefix is inserted after initial tokenization.
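///
/// Illustrative layout for a document input (token strings shown for readability;
/// the code operates on token IDs):
/// - tokenizer output:       `[CLS] paris is the capital ... [SEP]`
/// - after prefix insertion: `[CLS] [D] paris is the capital ... [SEP]`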
996fn encode_batch_with_session(
997    session: &mut Session,
998    tokenizer: &Tokenizer,
999    config: &ColbertConfig,
1000    skiplist_ids: &HashSet<u32>,
1001    texts: &[&str],
1002    is_query: bool,
1003    filter_skiplist: bool,
1004) -> Result<Vec<Array2<f32>>> {
1005    if texts.is_empty() {
1006        return Ok(Vec::new());
1007    }
1008
1009    let (prefix_str, prefix_token_id_opt, max_length) = if is_query {
1010        (
1011            &config.query_prefix,
1012            config.query_prefix_id,
1013            config.query_length,
1014        )
1015    } else {
1016        (
1017            &config.document_prefix,
1018            config.document_prefix_id,
1019            config.document_length,
1020        )
1021    };
1022
1023    // Get the prefix token ID, either from config or by looking it up in the tokenizer
1024    let prefix_token_id: u32 = match prefix_token_id_opt {
1025        Some(id) => id,
1026        None => tokenizer.token_to_id(prefix_str).ok_or_else(|| {
1027            anyhow::anyhow!(
1028                "Prefix token '{}' not found in tokenizer vocabulary",
1029                prefix_str
1030            )
1031        })?,
1032    };
1033
1034    // Apply text preprocessing to match sentence-transformers behavior:
1035    // 1. Strip leading/trailing whitespace
1036    // 2. Lowercase if configured
1037    let processed_texts: Vec<String> = if config.do_lower_case {
1038        texts.iter().map(|t| t.trim().to_lowercase()).collect()
1039    } else {
1040        texts.iter().map(|t| t.trim().to_string()).collect()
1041    };
1042    let texts_to_encode: Vec<&str> = processed_texts.iter().map(|s| s.as_str()).collect();
1043
1044    // Tokenize texts WITHOUT the prefix first (matching PyLate's approach)
1045    // PyLate tokenizes with max_length - 1 to reserve space for the prefix token
1046    let batch_encodings = tokenizer
1047        .encode_batch(texts_to_encode, true)
1048        .map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;
1049
1050    let mut encodings: Vec<BatchEncoding> = Vec::with_capacity(texts.len());
1051    let mut batch_max_len = 0usize;
1052
1053    // Truncate limit is max_length - 1 to leave room for prefix token insertion
1054    let truncate_limit = max_length - 1;
1055
1056    for encoding in batch_encodings {
1057        let token_ids: Vec<u32> = encoding.get_ids().to_vec();
1058        let mut input_ids: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
1059        let mut attention_mask: Vec<i64> = encoding
1060            .get_attention_mask()
1061            .iter()
1062            .map(|&x| x as i64)
1063            .collect();
1064        let mut token_type_ids: Vec<i64> =
1065            encoding.get_type_ids().iter().map(|&x| x as i64).collect();
1066        let mut token_ids_vec = token_ids;
1067
1068        // Truncate to max_length - 1 to leave room for prefix token
1069        // IMPORTANT: Preserve [SEP] token at the end when truncating
1070        // PyLate truncates content but keeps [CLS] at start and [SEP] at end
1071        if input_ids.len() > truncate_limit {
1072            // Save the [SEP] token (last token)
1073            let sep_token = input_ids[input_ids.len() - 1];
1074            let sep_mask = attention_mask[attention_mask.len() - 1];
1075            let sep_type = token_type_ids[token_type_ids.len() - 1];
1076            let sep_token_id = token_ids_vec[token_ids_vec.len() - 1];
1077
1078            // Truncate content (keeping room for [SEP])
1079            input_ids.truncate(truncate_limit - 1);
1080            attention_mask.truncate(truncate_limit - 1);
1081            token_type_ids.truncate(truncate_limit - 1);
1082            token_ids_vec.truncate(truncate_limit - 1);
1083
1084            // Re-add [SEP] at the end
1085            input_ids.push(sep_token);
1086            attention_mask.push(sep_mask);
1087            token_type_ids.push(sep_type);
1088            token_ids_vec.push(sep_token_id);
1089        }
1090
1091        // Insert prefix token after [CLS] (position 1), matching PyLate's insert_prefix_token
1092        // PyLate does: torch.cat([input_ids[:, :1], prefix_tensor, input_ids[:, 1:]], dim=1)
1093        input_ids.insert(1, prefix_token_id as i64);
1094        attention_mask.insert(1, 1);
1095        token_type_ids.insert(1, 0);
1096        token_ids_vec.insert(1, prefix_token_id);
1097
1098        batch_max_len = batch_max_len.max(input_ids.len());
1099        encodings.push((input_ids, attention_mask, token_type_ids, token_ids_vec));
1100    }
1101
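    // Query expansion: pad every query in the batch out to the configured query_length so
    // the trailing positions become [MASK] expansion slots (filled in the padding loop below).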
1102    if is_query && config.do_query_expansion {
1103        batch_max_len = max_length;
1104    }
1105
1106    let batch_size = texts.len();
1107    let mut all_input_ids: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
1108    let mut all_attention_mask: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
1109    let mut all_token_type_ids: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
1110    let mut all_token_ids: Vec<Vec<u32>> = Vec::with_capacity(batch_size);
1111    let mut original_lengths: Vec<usize> = Vec::with_capacity(batch_size);
1112
1113    for (mut input_ids, mut attention_mask, mut token_type_ids, mut token_ids) in encodings {
1114        original_lengths.push(input_ids.len());
1115
1116        while input_ids.len() < batch_max_len {
1117            if is_query && config.do_query_expansion {
1118                input_ids.push(config.mask_token_id as i64);
1119                attention_mask.push(1);
1120                token_ids.push(config.mask_token_id);
1121            } else {
1122                input_ids.push(config.pad_token_id as i64);
1123                attention_mask.push(0);
1124                token_ids.push(config.pad_token_id);
1125            }
1126            token_type_ids.push(0);
1127        }
1128
1129        all_input_ids.extend(input_ids);
1130        all_attention_mask.extend(attention_mask);
1131        all_token_type_ids.extend(token_type_ids);
1132        all_token_ids.push(token_ids);
1133    }
1134
1135    let input_ids_tensor = Tensor::from_array(([batch_size, batch_max_len], all_input_ids))?;
1136    let attention_mask_tensor =
1137        Tensor::from_array(([batch_size, batch_max_len], all_attention_mask.clone()))?;
1138
1139    let token_type_ids_tensor = if config.uses_token_type_ids {
1140        Some(Tensor::from_array((
1141            [batch_size, batch_max_len],
1142            all_token_type_ids,
1143        ))?)
1144    } else {
1145        None
1146    };
1147
1148    let outputs = if let Some(token_type_ids_tensor) = token_type_ids_tensor {
1149        session.run(ort::inputs![
1150            "input_ids" => input_ids_tensor,
1151            "attention_mask" => attention_mask_tensor,
1152            "token_type_ids" => token_type_ids_tensor,
1153        ])?
1154    } else {
1155        session.run(ort::inputs![
1156            "input_ids" => input_ids_tensor,
1157            "attention_mask" => attention_mask_tensor,
1158        ])?
1159    };
1160
1161    let (output_shape, output_data) = outputs["output"]
1162        .try_extract_tensor::<f32>()
1163        .context("Failed to extract output tensor")?;
1164
1165    let shape_slice: Vec<i64> = output_shape.iter().copied().collect();
1166    let embedding_dim = shape_slice[2] as usize;
1167
1168    let mut all_embeddings = Vec::with_capacity(batch_size);
1169    for i in 0..batch_size {
1170        let batch_offset = i * batch_max_len * embedding_dim;
1171        let attention_offset = i * batch_max_len;
1172
1173        if is_query && config.do_query_expansion {
1174            let end = batch_offset + batch_max_len * embedding_dim;
1175            let flat: Vec<f32> = output_data[batch_offset..end].to_vec();
1176            let arr = Array2::from_shape_vec((batch_max_len, embedding_dim), flat)?;
1177            all_embeddings.push(arr);
1178        } else {
1179            let orig_len = original_lengths[i];
1180            let token_ids = &all_token_ids[i];
1181
1182            let valid_count = (0..orig_len)
1183                .filter(|&j| {
1184                    let mask = all_attention_mask[attention_offset + j];
1185                    let token_id = token_ids[j];
1186                    mask != 0 && !(filter_skiplist && skiplist_ids.contains(&token_id))
1187                })
1188                .count();
1189
1190            let mut flat: Vec<f32> = Vec::with_capacity(valid_count * embedding_dim);
1191            for j in 0..orig_len {
1192                let mask = all_attention_mask[attention_offset + j];
1193                let token_id = token_ids[j];
1194
1195                if mask == 0 {
1196                    continue;
1197                }
1198                if filter_skiplist && skiplist_ids.contains(&token_id) {
1199                    continue;
1200                }
1201
1202                let start = batch_offset + j * embedding_dim;
1203                flat.extend_from_slice(&output_data[start..start + embedding_dim]);
1204            }
1205
1206            let arr = Array2::from_shape_vec((valid_count, embedding_dim), flat)?;
1207            all_embeddings.push(arr);
1208        }
1209    }
1210
1211    Ok(all_embeddings)
1212}
1213
1214/// Pool embeddings using hierarchical clustering with Ward's method.
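///
/// `pool_factor` controls the reduction ratio (a factor of 2 keeps roughly half of the
/// poolable tokens); the first `protected_tokens` rows (e.g. the [CLS]/prefix embedding)
/// are passed through unchanged. Each cluster is replaced by the mean of its members.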
1215fn pool_embeddings_hierarchical(
1216    embeddings: Array2<f32>,
1217    pool_factor: usize,
1218    protected_tokens: usize,
1219) -> Array2<f32> {
1220    let n_tokens = embeddings.nrows();
1221    let n_features = embeddings.ncols();
1222
1223    if n_tokens <= protected_tokens + 1 {
1224        return embeddings;
1225    }
1226
1227    let tokens_to_pool = n_tokens - protected_tokens;
1228    let num_clusters = (tokens_to_pool / pool_factor).max(1);
1229
1230    if num_clusters >= tokens_to_pool {
1231        return embeddings;
1232    }
1233
1234    let to_pool = embeddings.slice(ndarray::s![protected_tokens.., ..]);
1235    let flat_embeddings: Vec<f32> = to_pool.iter().copied().collect();
1236
1237    let distances = crate::hierarchy::pdist_cosine(&flat_embeddings, tokens_to_pool, n_features);
1238
1239    let linkage_matrix = crate::hierarchy::linkage(
1240        &distances,
1241        tokens_to_pool,
1242        crate::hierarchy::LinkageMethod::Ward,
1243    );
1244
1245    let labels = crate::hierarchy::fcluster(
1246        &linkage_matrix,
1247        tokens_to_pool,
1248        crate::hierarchy::FclusterCriterion::MaxClust,
1249        num_clusters as f64,
1250    );
1251
1252    let mut pooled_rows: Vec<Vec<f32>> = Vec::with_capacity(num_clusters + protected_tokens);
1253
1254    for i in 0..protected_tokens {
1255        pooled_rows.push(embeddings.row(i).to_vec());
1256    }
1257
1258    for cluster_id in 1..=num_clusters {
1259        let mut sum = vec![0.0f32; n_features];
1260        let mut count = 0usize;
1261
1262        for (idx, &label) in labels.iter().enumerate() {
1263            if label == cluster_id {
1264                let row = to_pool.row(idx);
1265                for (s, &v) in sum.iter_mut().zip(row.iter()) {
1266                    *s += v;
1267                }
1268                count += 1;
1269            }
1270        }
1271
1272        if count > 0 {
1273            for s in &mut sum {
1274                *s /= count as f32;
1275            }
1276            pooled_rows.push(sum);
1277        }
1278    }
1279
1280    let n_pooled = pooled_rows.len();
1281    let flat: Vec<f32> = pooled_rows.into_iter().flatten().collect();
1282    Array2::from_shape_vec((n_pooled, n_features), flat)
1283        .expect("Shape mismatch in pooled embeddings")
1284}
1285
1286#[cfg(test)]
1287mod tests {
1288    use super::*;
1289
1290    // =========================================================================
1291    // ColbertConfig tests
1292    // =========================================================================
1293
1294    #[test]
1295    fn test_default_config() {
1296        let config = ColbertConfig::default();
1297        assert_eq!(config.query_length, 48);
1298        assert_eq!(config.document_length, 300);
1299        assert!(config.do_query_expansion);
1300        assert_eq!(config.embedding_dim, 128);
1301        assert_eq!(config.mask_token_id, 103);
1302        assert_eq!(config.pad_token_id, 0);
1303        assert!(config.uses_token_type_ids);
1304        assert_eq!(config.query_prefix, "[Q] ");
1305        assert_eq!(config.document_prefix, "[D] ");
1306        assert!(config.skiplist_words.is_empty());
1307    }
1308
1309    #[test]
1310    fn test_config_serialization_roundtrip() {
1311        let config = ColbertConfig::default();
1312        let json = serde_json::to_string(&config).unwrap();
1313        let parsed: ColbertConfig = serde_json::from_str(&json).unwrap();
1314
1315        assert_eq!(parsed.query_length, config.query_length);
1316        assert_eq!(parsed.document_length, config.document_length);
1317        assert_eq!(parsed.do_query_expansion, config.do_query_expansion);
1318        assert_eq!(parsed.embedding_dim, config.embedding_dim);
1319        assert_eq!(parsed.mask_token_id, config.mask_token_id);
1320        assert_eq!(parsed.pad_token_id, config.pad_token_id);
1321        assert_eq!(parsed.uses_token_type_ids, config.uses_token_type_ids);
1322    }
1323
1324    #[test]
1325    fn test_config_deserialization_with_custom_values() {
1326        let json = r#"{
1327            "query_length": 64,
1328            "document_length": 512,
1329            "do_query_expansion": false,
1330            "embedding_dim": 256,
1331            "mask_token_id": 4,
1332            "pad_token_id": 1,
1333            "uses_token_type_ids": false,
1334            "query_prefix": "[query]",
1335            "document_prefix": "[doc]",
1336            "skiplist_words": ["the", "a", "an"]
1337        }"#;
1338
1339        let config: ColbertConfig = serde_json::from_str(json).unwrap();
1340
1341        assert_eq!(config.query_length, 64);
1342        assert_eq!(config.document_length, 512);
1343        assert!(!config.do_query_expansion);
1344        assert_eq!(config.embedding_dim, 256);
1345        assert_eq!(config.mask_token_id, 4);
1346        assert_eq!(config.pad_token_id, 1);
1347        assert!(!config.uses_token_type_ids);
1348        assert_eq!(config.query_prefix, "[query]");
1349        assert_eq!(config.document_prefix, "[doc]");
1350        assert_eq!(config.skiplist_words, vec!["the", "a", "an"]);
1351    }
1352
1353    #[test]
1354    fn test_config_deserialization_with_defaults() {
1355        // Empty JSON should use all defaults
1356        let json = "{}";
1357        let config: ColbertConfig = serde_json::from_str(json).unwrap();
1358
1359        assert_eq!(config.query_length, 48);
1360        assert_eq!(config.document_length, 300);
1361        assert!(config.do_query_expansion);
1362    }
1363
1364    // =========================================================================
1365    // ColbertBuilder tests
1366    // =========================================================================
1367
1368    #[test]
1369    fn test_builder_defaults() {
1370        let builder = ColbertBuilder::new("test_model");
1371
1372        assert_eq!(builder.num_sessions, 1);
1373        assert!(!builder.quantized);
1374        assert!(builder.batch_size.is_none());
1375        assert_eq!(builder.execution_provider, ExecutionProvider::Auto);
1376        assert!(builder.query_length.is_none());
1377        assert!(builder.document_length.is_none());
1378    }
1379
1380    #[test]
1381    fn test_builder_with_parallel() {
1382        let builder = ColbertBuilder::new("test_model").with_parallel(25);
1383
1384        assert_eq!(builder.num_sessions, 25);
1385        assert_eq!(builder.threads_per_session, 1); // Auto-set to 1 for parallel
1386    }
1387
1388    #[test]
1389    fn test_builder_with_parallel_minimum() {
1390        // with_parallel(0) should be clamped to 1
1391        let builder = ColbertBuilder::new("test_model").with_parallel(0);
1392
1393        assert_eq!(builder.num_sessions, 1);
1394    }
1395
1396    #[test]
1397    fn test_builder_with_threads() {
1398        let builder = ColbertBuilder::new("test_model").with_threads(8);
1399
1400        assert_eq!(builder.threads_per_session, 8);
1401    }
1402
1403    #[test]
1404    fn test_builder_with_batch_size() {
1405        let builder = ColbertBuilder::new("test_model").with_batch_size(64);
1406
1407        assert_eq!(builder.batch_size, Some(64));
1408    }
1409
1410    #[test]
1411    fn test_builder_with_quantized() {
1412        let builder = ColbertBuilder::new("test_model").with_quantized(true);
1413
1414        assert!(builder.quantized);
1415    }
1416
1417    #[test]
1418    fn test_builder_with_execution_provider() {
1419        let builder =
1420            ColbertBuilder::new("test_model").with_execution_provider(ExecutionProvider::Cpu);
1421
1422        assert_eq!(builder.execution_provider, ExecutionProvider::Cpu);
1423    }
1424
1425    #[test]
1426    fn test_builder_with_query_length() {
1427        let builder = ColbertBuilder::new("test_model").with_query_length(64);
1428
1429        assert_eq!(builder.query_length, Some(64));
1430    }
1431
1432    #[test]
1433    fn test_builder_with_document_length() {
1434        let builder = ColbertBuilder::new("test_model").with_document_length(512);
1435
1436        assert_eq!(builder.document_length, Some(512));
1437    }
1438
1439    #[test]
1440    fn test_builder_chained_configuration() {
1441        let builder = ColbertBuilder::new("test_model")
1442            .with_quantized(true)
1443            .with_parallel(16)
1444            .with_batch_size(4)
1445            .with_execution_provider(ExecutionProvider::Cuda)
1446            .with_query_length(64)
1447            .with_document_length(512);
1448
1449        assert!(builder.quantized);
1450        assert_eq!(builder.num_sessions, 16);
1451        assert_eq!(builder.threads_per_session, 1);
1452        assert_eq!(builder.batch_size, Some(4));
1453        assert_eq!(builder.execution_provider, ExecutionProvider::Cuda);
1454        assert_eq!(builder.query_length, Some(64));
1455        assert_eq!(builder.document_length, Some(512));
1456    }
1457
1458    // =========================================================================
1459    // ExecutionProvider tests
1460    // =========================================================================
1461
1462    #[test]
1463    fn test_execution_provider_default() {
1464        let provider = ExecutionProvider::default();
1465        assert_eq!(provider, ExecutionProvider::Auto);
1466    }
1467
1468    #[test]
1469    fn test_execution_provider_variants() {
1470        // Ensure all variants are distinct
1471        assert_ne!(ExecutionProvider::Auto, ExecutionProvider::Cpu);
1472        assert_ne!(ExecutionProvider::Cpu, ExecutionProvider::Cuda);
1473        assert_ne!(ExecutionProvider::Cuda, ExecutionProvider::TensorRT);
1474        assert_ne!(ExecutionProvider::TensorRT, ExecutionProvider::CoreML);
1475        assert_ne!(ExecutionProvider::CoreML, ExecutionProvider::DirectML);
1476    }
1477
1478    #[test]
1479    fn test_execution_provider_clone() {
1480        let provider = ExecutionProvider::Cuda;
1481        let cloned = provider;
1482        assert_eq!(provider, cloned);
1483    }
1484
1485    #[test]
1486    fn test_execution_provider_debug() {
1487        let provider = ExecutionProvider::Cuda;
1488        let debug_str = format!("{:?}", provider);
1489        assert_eq!(debug_str, "Cuda");
1490    }
1491
1492    // =========================================================================
1493    // Pool embeddings tests
1494    // =========================================================================
1495
1496    #[test]
1497    fn test_pool_embeddings_no_pooling() {
1498        // Create a small embedding matrix
1499        let embeddings = Array2::from_shape_vec(
1500            (5, 4),
1501            vec![
1502                1.0, 0.0, 0.0, 0.0, // token 0 (protected)
1503                0.0, 1.0, 0.0, 0.0, // token 1
1504                0.0, 0.0, 1.0, 0.0, // token 2
1505                0.0, 0.0, 0.0, 1.0, // token 3
1506                0.5, 0.5, 0.0, 0.0, // token 4
1507            ],
1508        )
1509        .unwrap();
1510
1511        // pool_factor=1 should not pool
1512        let result = pool_embeddings_hierarchical(embeddings.clone(), 1, 1);
1513        assert_eq!(result.dim(), embeddings.dim());
1514    }
1515
1516    #[test]
1517    fn test_pool_embeddings_with_pooling() {
1518        // Create embeddings that will cluster together
1519        let embeddings = Array2::from_shape_vec(
1520            (5, 4),
1521            vec![
1522                1.0, 0.0, 0.0, 0.0, // token 0 (protected CLS)
1523                0.9, 0.1, 0.0, 0.0, // token 1 - similar to token 2
1524                0.85, 0.15, 0.0, 0.0, // token 2 - similar to token 1
1525                0.0, 0.0, 1.0, 0.0, // token 3 - different
1526                0.0, 0.0, 0.9, 0.1, // token 4 - similar to token 3
1527            ],
1528        )
1529        .unwrap();
1530
1531        // pool_factor=2 should reduce 4 tokens to ~2 clusters + 1 protected
1532        let result = pool_embeddings_hierarchical(embeddings, 2, 1);
1533
1534        // Should have fewer tokens than original
1535        assert!(result.nrows() < 5);
1536        // Protected token should be preserved
1537        assert!(result.nrows() >= 1);
1538        // Feature dimension should be preserved
1539        assert_eq!(result.ncols(), 4);
1540    }
1541
1542    #[test]
1543    fn test_pool_embeddings_too_few_tokens() {
1544        // Only 2 tokens - too few to pool
1545        let embeddings = Array2::from_shape_vec(
1546            (2, 4),
1547            vec![
1548                1.0, 0.0, 0.0, 0.0, // protected
1549                0.0, 1.0, 0.0, 0.0, // single token
1550            ],
1551        )
1552        .unwrap();
1553
1554        let result = pool_embeddings_hierarchical(embeddings.clone(), 2, 1);
1555
1556        // Should return unchanged
1557        assert_eq!(result.dim(), embeddings.dim());
1558    }
1559
1560    #[test]
1561    fn test_pool_embeddings_all_protected() {
1562        // All tokens protected
1563        let embeddings = Array2::from_shape_vec(
1564            (3, 4),
1565            vec![
1566                1.0, 0.0, 0.0, 0.0, //
1567                0.0, 1.0, 0.0, 0.0, //
1568                0.0, 0.0, 1.0, 0.0, //
1569            ],
1570        )
1571        .unwrap();
1572
1573        // With 3 protected tokens, nothing to pool
1574        let result = pool_embeddings_hierarchical(embeddings.clone(), 2, 3);
1575
1576        // Should return unchanged
1577        assert_eq!(result.dim(), embeddings.dim());
1578    }
1579
1580    // =========================================================================
1581    // Batch size defaults tests
1582    // =========================================================================
1583
1584    #[test]
1585    fn test_default_batch_sizes() {
1586        assert_eq!(DEFAULT_CPU_BATCH_SIZE, 32);
1587        assert_eq!(DEFAULT_GPU_BATCH_SIZE, 64);
1588    }
1589}