1pub mod hierarchy;
50
51use anyhow::{Context, Result};
52use ndarray::Array2;
53use ort::session::builder::GraphOptimizationLevel;
54use ort::session::Session;
55use ort::value::Tensor;
56use serde::{Deserialize, Serialize};
57use std::collections::HashSet;
58use std::fs;
59use std::path::Path;
60use std::sync::Once;
61use std::sync::{Arc, Mutex};
62use tokenizers::Tokenizer;
63
64#[cfg(feature = "cuda")]
66use ort::ep::ExecutionProvider as ExecutionProviderTrait;
67#[cfg(feature = "cuda")]
68use ort::execution_providers::CUDAExecutionProvider;
69#[cfg(feature = "coreml")]
70use ort::execution_providers::CoreMLExecutionProvider;
71#[cfg(feature = "directml")]
72use ort::execution_providers::DirectMLExecutionProvider;
73#[cfg(feature = "tensorrt")]
74use ort::execution_providers::TensorRTExecutionProvider;
75
76use ort::session::builder::SessionBuilder;
77
/// Guards the one-time ONNX Runtime dynamic-library discovery performed by `init_ort_runtime`.
static ORT_INIT: Once = Once::new();
83
84fn init_ort_runtime() {
86 ORT_INIT.call_once(|| {
87 if std::env::var("ORT_DYLIB_PATH").is_ok() {
89 return;
90 }
91
92 if let Some(lib_path) = find_onnxruntime_library() {
94 std::env::set_var("ORT_DYLIB_PATH", &lib_path);
95 }
96 });
97}
98
99fn find_onnxruntime_library() -> Option<String> {
101 let home = std::env::var("HOME").ok()?;
102
103 let search_patterns = vec![
104 format!(
106 "{}/.venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
107 home
108 ),
109 format!(
110 "{}/venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
111 home
112 ),
113 "python/.venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*".to_string(),
114 ".venv/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*".to_string(),
115 format!(
117 "{}/.local/lib/python*/site-packages/onnxruntime/capi/libonnxruntime.so*",
118 home
119 ),
120 format!(
122 "{}/.cache/uv/archive-v*/*/onnxruntime/capi/libonnxruntime.so*",
123 home
124 ),
125 format!("{}/anaconda3/lib/libonnxruntime.so*", home),
127 format!("{}/miniconda3/lib/libonnxruntime.so*", home),
128 "/usr/local/lib/libonnxruntime.so*".to_string(),
130 "/usr/lib/libonnxruntime.so*".to_string(),
131 "/usr/lib/x86_64-linux-gnu/libonnxruntime.so*".to_string(),
132 ];
133
134 for pattern in search_patterns {
135 if let Ok(paths) = glob::glob(&pattern) {
136 for path in paths.flatten() {
137 if path.exists() && path.is_file() {
138 let path_str = path.to_string_lossy();
139 if path_str.contains(".so.") || path_str.ends_with(".so") {
140 return Some(path.to_string_lossy().to_string());
141 }
142 }
143 }
144 }
145 }
146
147 None
148}
149
/// Hardware execution backend requested for ONNX Runtime sessions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionProvider {
    /// Probe compiled-in accelerators in priority order, fall back to CPU.
    #[default]
    Auto,
    /// Plain CPU inference (no extra execution provider registered).
    Cpu,
    /// NVIDIA CUDA (requires the `cuda` cargo feature).
    Cuda,
    /// NVIDIA TensorRT (requires the `tensorrt` cargo feature).
    TensorRT,
    /// Apple CoreML (requires the `coreml` cargo feature).
    CoreML,
    /// Windows DirectML (requires the `directml` cargo feature).
    DirectML,
}
172
/// Attach the requested execution provider to a session builder.
///
/// `Auto` probes whatever accelerators were compiled in and silently
/// falls back to CPU; asking for a specific provider whose cargo feature
/// is disabled returns an error (see the `configure_*` stubs below).
fn configure_execution_provider(
    builder: SessionBuilder,
    provider: ExecutionProvider,
) -> Result<SessionBuilder> {
    match provider {
        ExecutionProvider::Auto => configure_auto_provider(builder),
        // CPU is ONNX Runtime's default — nothing to register.
        ExecutionProvider::Cpu => Ok(builder),
        ExecutionProvider::Cuda => configure_cuda(builder),
        ExecutionProvider::TensorRT => configure_tensorrt(builder),
        ExecutionProvider::CoreML => configure_coreml(builder),
        ExecutionProvider::DirectML => configure_directml(builder),
    }
}
186
/// CUDA device to use: the first entry of `CUDA_VISIBLE_DEVICES`, or 0
/// when the variable is unset or unparseable.
#[cfg(feature = "cuda")]
fn get_cuda_device_id() -> i32 {
    match std::env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices
            .split(',')
            .next()
            .and_then(|id| id.parse::<i32>().ok())
            .unwrap_or(0),
        Err(_) => 0,
    }
}
195
/// Whether CPU-only execution was forced via the `NEXT_PLAID_FORCE_CPU`
/// environment variable.
///
/// Accepts `1` or any casing of `true`; anything else (including an
/// unset variable) means "not forced".
pub fn is_force_cpu() -> bool {
    matches!(
        std::env::var("NEXT_PLAID_FORCE_CPU").as_deref(),
        Ok(v) if v == "1" || v.eq_ignore_ascii_case("true")
    )
}
204
/// Whether CUDA inference is actually usable at runtime.
///
/// Returns `false` when CPU is forced, when `CUDA_VISIBLE_DEVICES`
/// disables all GPUs, or when `ort` reports the CUDA provider
/// unavailable.
#[cfg(feature = "cuda")]
pub fn is_cuda_available() -> bool {
    // The explicit CPU override always wins.
    if is_force_cpu() {
        return false;
    }

    // CUDA_VISIBLE_DEVICES="" or "-1" conventionally hides every GPU.
    if let Ok(devices) = std::env::var("CUDA_VISIBLE_DEVICES") {
        if devices.is_empty() || devices == "-1" {
            return false;
        }
    }
    // Ask ort whether the CUDA provider can be registered on this machine.
    CUDAExecutionProvider::default()
        .is_available()
        .unwrap_or(false)
}
243
/// CUDA availability stub for builds without the `cuda` feature: always `false`.
#[cfg(not(feature = "cuda"))]
pub fn is_cuda_available() -> bool {
    false
}
250
251fn configure_auto_provider(builder: SessionBuilder) -> Result<SessionBuilder> {
252 #[cfg(any(feature = "cuda", feature = "tensorrt", feature = "coreml"))]
254 let force_cpu = is_force_cpu();
255
256 #[cfg(feature = "cuda")]
257 if !force_cpu {
258 let device_id = get_cuda_device_id();
259 if let Ok(b) = builder
260 .clone()
261 .with_execution_providers([CUDAExecutionProvider::default()
262 .with_device_id(device_id)
263 .with_tf32(true)
264 .build()])
265 {
266 return Ok(b);
267 }
268 }
269
270 #[cfg(feature = "tensorrt")]
271 if !force_cpu {
272 if let Ok(b) = builder
273 .clone()
274 .with_execution_providers([TensorRTExecutionProvider::default().build()])
275 {
276 return Ok(b);
277 }
278 }
279
280 #[cfg(feature = "coreml")]
281 {
282 if let Ok(b) = builder
283 .clone()
284 .with_execution_providers([CoreMLExecutionProvider::default().build()])
285 {
286 return Ok(b);
287 }
288 }
289
290 #[cfg(feature = "directml")]
291 if !force_cpu {
292 if let Ok(b) = builder
293 .clone()
294 .with_execution_providers([DirectMLExecutionProvider::default().build()])
295 {
296 return Ok(b);
297 }
298 }
299
300 Ok(builder)
301}
302
/// Register the CUDA execution provider (builds with the `cuda` feature).
///
/// Enables TF32 and selects the device from `CUDA_VISIBLE_DEVICES`.
/// When CPU is forced via `NEXT_PLAID_FORCE_CPU`, the builder is
/// returned untouched rather than erroring.
#[cfg(feature = "cuda")]
fn configure_cuda(builder: SessionBuilder) -> Result<SessionBuilder> {
    // Even an explicit CUDA request yields to the global force-CPU override.
    if is_force_cpu() {
        return Ok(builder);
    }

    let device_id = get_cuda_device_id();
    builder
        .with_execution_providers([
            CUDAExecutionProvider::default()
                .with_device_id(device_id)
                .with_tf32(true)
                .build()
        ])
        .context("Failed to configure CUDA execution provider. Ensure CUDA toolkit and cuDNN are installed.")
}

/// Stub for builds without the `cuda` feature: always errors.
#[cfg(not(feature = "cuda"))]
fn configure_cuda(_builder: SessionBuilder) -> Result<SessionBuilder> {
    anyhow::bail!("CUDA support not compiled. Enable the 'cuda' feature.")
}
325
/// Register the TensorRT execution provider (builds with the `tensorrt` feature).
#[cfg(feature = "tensorrt")]
fn configure_tensorrt(builder: SessionBuilder) -> Result<SessionBuilder> {
    builder
        .with_execution_providers([TensorRTExecutionProvider::default().build()])
        .context("Failed to configure TensorRT execution provider")
}

/// Stub for builds without the `tensorrt` feature: always errors.
#[cfg(not(feature = "tensorrt"))]
fn configure_tensorrt(_builder: SessionBuilder) -> Result<SessionBuilder> {
    anyhow::bail!("TensorRT support not compiled. Enable the 'tensorrt' feature.")
}
337
/// Register the CoreML execution provider (builds with the `coreml` feature).
#[cfg(feature = "coreml")]
fn configure_coreml(builder: SessionBuilder) -> Result<SessionBuilder> {
    builder
        .with_execution_providers([CoreMLExecutionProvider::default().build()])
        .context("Failed to configure CoreML execution provider")
}

/// Stub for builds without the `coreml` feature: always errors.
#[cfg(not(feature = "coreml"))]
fn configure_coreml(_builder: SessionBuilder) -> Result<SessionBuilder> {
    anyhow::bail!("CoreML support not compiled. Enable the 'coreml' feature.")
}
349
/// Register the DirectML execution provider (builds with the `directml` feature).
#[cfg(feature = "directml")]
fn configure_directml(builder: SessionBuilder) -> Result<SessionBuilder> {
    builder
        .with_execution_providers([DirectMLExecutionProvider::default().build()])
        .context("Failed to configure DirectML execution provider")
}

/// Stub for builds without the `directml` feature: always errors.
#[cfg(not(feature = "directml"))]
fn configure_directml(_builder: SessionBuilder) -> Result<SessionBuilder> {
    anyhow::bail!("DirectML support not compiled. Enable the 'directml' feature.")
}
361
/// Runtime configuration for a ColBERT model, typically deserialized from
/// the model directory's `onnx_config.json`.
///
/// Missing fields fall back to BERT-style defaults via the `default_*`
/// helpers below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertConfig {
    /// Marker text prepended to every query before tokenization.
    #[serde(default = "default_query_prefix")]
    pub query_prefix: String,

    /// Marker text prepended to every document before tokenization.
    #[serde(default = "default_document_prefix")]
    pub document_prefix: String,

    /// Maximum query token length (also the expansion target length).
    #[serde(default = "default_query_length")]
    pub query_length: usize,

    /// Maximum document token length.
    #[serde(default = "default_document_length")]
    pub document_length: usize,

    /// Pad queries to `query_length` with mask tokens (ColBERT query expansion).
    #[serde(default = "default_do_query_expansion")]
    pub do_query_expansion: bool,

    /// Dimensionality of each output token embedding.
    #[serde(default = "default_embedding_dim")]
    pub embedding_dim: usize,

    /// Whether the ONNX graph expects a `token_type_ids` input.
    #[serde(default = "default_uses_token_type_ids")]
    pub uses_token_type_ids: bool,

    /// Mask-token id; re-resolved from the tokenizer when left at the default.
    #[serde(default = "default_mask_token_id")]
    pub mask_token_id: u32,

    /// Padding-token id; re-resolved from the tokenizer when left at the default.
    #[serde(default = "default_pad_token_id")]
    pub pad_token_id: u32,

    /// Words whose token embeddings are dropped from document outputs.
    #[serde(default)]
    pub skiplist_words: Vec<String>,

    // Metadata carried through from the exported config.
    #[serde(default = "default_model_type")]
    model_type: String,
    #[serde(default)]
    model_name: Option<String>,
    #[serde(default)]
    model_class: Option<String>,
    // NOTE(review): parsed but not read by the encoding path in this file.
    #[serde(default)]
    attend_to_expansion_tokens: bool,
    // Pre-resolved prefix token ids; looked up in the tokenizer when `None`.
    query_prefix_id: Option<u32>,
    document_prefix_id: Option<u32>,
    /// Trim and lowercase input texts before tokenization.
    #[serde(default)]
    pub do_lower_case: bool,
}
426
// Serde default helpers for `ColbertConfig`. Values mirror the standard
// BERT/ColBERT conventions.
fn default_model_type() -> String {
    "ColBERT".to_string()
}
// BERT-style models take a `token_type_ids` input.
fn default_uses_token_type_ids() -> bool {
    true
}
fn default_query_prefix() -> String {
    "[Q] ".to_string()
}
fn default_document_prefix() -> String {
    "[D] ".to_string()
}
fn default_query_length() -> usize {
    48
}
fn default_document_length() -> usize {
    300
}
fn default_do_query_expansion() -> bool {
    true
}
fn default_embedding_dim() -> usize {
    128
}
// 103 is `[MASK]` in the standard BERT vocabulary.
fn default_mask_token_id() -> u32 {
    103
}
// 0 is `[PAD]` in the standard BERT vocabulary.
fn default_pad_token_id() -> u32 {
    0
}
457
458impl Default for ColbertConfig {
459 fn default() -> Self {
460 Self {
461 model_type: default_model_type(),
462 model_name: None,
463 model_class: None,
464 uses_token_type_ids: default_uses_token_type_ids(),
465 query_prefix: default_query_prefix(),
466 document_prefix: default_document_prefix(),
467 query_length: default_query_length(),
468 document_length: default_document_length(),
469 do_query_expansion: default_do_query_expansion(),
470 attend_to_expansion_tokens: false,
471 skiplist_words: Vec::new(),
472 embedding_dim: default_embedding_dim(),
473 mask_token_id: default_mask_token_id(),
474 pad_token_id: default_pad_token_id(),
475 query_prefix_id: None,
476 document_prefix_id: None,
477 do_lower_case: false,
478 }
479 }
480}
481
482impl ColbertConfig {
483 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
485 let content = fs::read_to_string(path.as_ref())
486 .with_context(|| format!("Failed to read config from {:?}", path.as_ref()))?;
487 let config: ColbertConfig =
488 serde_json::from_str(&content).with_context(|| "Failed to parse onnx_config.json")?;
489 Ok(config)
490 }
491
492 fn from_model_dir<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
493 let onnx_config_path = model_dir.as_ref().join("onnx_config.json");
494 if onnx_config_path.exists() {
495 return Self::from_file(&onnx_config_path);
496 }
497
498 anyhow::bail!(
499 "onnx_config.json not found in {:?}. This file is required for ColBERT model configuration.",
500 model_dir.as_ref()
501 )
502 }
503
504 pub fn model_name(&self) -> Option<&str> {
506 self.model_name.as_deref()
507 }
508}
509
/// Default inference batch size on CPU.
const DEFAULT_CPU_BATCH_SIZE: usize = 32;

/// Default inference batch size when a GPU/accelerator provider is active.
const DEFAULT_GPU_BATCH_SIZE: usize = 64;

/// Per-text tokenizer output: (input_ids, attention_mask, token_type_ids, raw u32 token ids).
type BatchEncoding = (Vec<i64>, Vec<i64>, Vec<i64>, Vec<u32>);
522
/// ONNX-backed ColBERT encoder producing one `(tokens, embedding_dim)`
/// matrix per input text. Construct via [`Colbert::new`] or
/// [`Colbert::builder`].
pub struct Colbert {
    /// One ONNX session per parallel worker; each is locked for the
    /// duration of a batch.
    sessions: Vec<Mutex<Session>>,
    /// Shared tokenizer handle, cheaply borrowable by worker threads.
    tokenizer: Arc<Tokenizer>,
    /// Model configuration (lengths, prefixes, special-token ids, ...).
    config: ColbertConfig,
    /// Token ids of `config.skiplist_words`, filtered from document outputs.
    skiplist_ids: HashSet<u32>,
    /// Number of texts encoded per inference call.
    batch_size: usize,
}
550
/// Fluent builder for [`Colbert`]; see [`ColbertBuilder::new`] for defaults.
pub struct ColbertBuilder {
    /// Directory holding `model.onnx`/`model_int8.onnx`, `tokenizer.json`,
    /// and `onnx_config.json`.
    model_dir: std::path::PathBuf,
    /// Number of ONNX sessions to create for parallel encoding.
    num_sessions: usize,
    /// Intra-op thread count for each session.
    threads_per_session: usize,
    /// Explicit batch size; `None` picks a provider-dependent default.
    batch_size: Option<usize>,
    /// Requested hardware backend.
    execution_provider: ExecutionProvider,
    /// Load the INT8-quantized weights instead of the full-precision model.
    quantized: bool,
    /// Optional override for `ColbertConfig::query_length`.
    query_length: Option<usize>,
    /// Optional override for `ColbertConfig::document_length`.
    document_length: Option<usize>,
}
579
580impl ColbertBuilder {
581 pub fn new<P: AsRef<Path>>(model_dir: P) -> Self {
588 let num_threads = std::thread::available_parallelism()
589 .map(|p| p.get())
590 .unwrap_or(4);
591 Self {
592 model_dir: model_dir.as_ref().to_path_buf(),
593 num_sessions: 1,
594 threads_per_session: num_threads,
595 batch_size: None,
596 execution_provider: ExecutionProvider::Auto,
597 quantized: false,
598 query_length: None,
599 document_length: None,
600 }
601 }
602
603 pub fn with_parallel(mut self, num_sessions: usize) -> Self {
610 self.num_sessions = num_sessions.max(1);
611 self.threads_per_session = 1; self
613 }
614
615 pub fn with_threads(mut self, num_threads: usize) -> Self {
619 self.threads_per_session = num_threads;
620 self
621 }
622
623 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
627 self.batch_size = Some(batch_size);
628 self
629 }
630
631 pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
633 self.execution_provider = provider;
634 self
635 }
636
637 pub fn with_quantized(mut self, quantized: bool) -> Self {
641 self.quantized = quantized;
642 self
643 }
644
645 pub fn with_query_length(mut self, query_length: usize) -> Self {
650 self.query_length = Some(query_length);
651 self
652 }
653
654 pub fn with_document_length(mut self, document_length: usize) -> Self {
659 self.document_length = Some(document_length);
660 self
661 }
662
663 pub fn build(self) -> Result<Colbert> {
665 init_ort_runtime();
666
667 let model_dir = &self.model_dir;
668 let onnx_path = select_onnx_file(model_dir, self.quantized)?;
669 let tokenizer_path = model_dir.join("tokenizer.json");
670
671 let tokenizer = Tokenizer::from_file(&tokenizer_path)
672 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
673
674 let mut config = ColbertConfig::from_model_dir(model_dir)?;
675
676 if let Some(query_length) = self.query_length {
680 config.query_length = query_length;
681 }
682 if let Some(document_length) = self.document_length {
683 config.document_length = document_length;
684 }
685
686 update_token_ids(&mut config, &tokenizer);
687 let skiplist_ids = build_skiplist(&config, &tokenizer);
688
689 let mut sessions = Vec::with_capacity(self.num_sessions);
691 for _i in 0..self.num_sessions {
692 let builder = Session::builder()?
693 .with_optimization_level(GraphOptimizationLevel::Level3)?
694 .with_intra_threads(self.threads_per_session)?
695 .with_inter_threads(if self.num_sessions > 1 { 1 } else { 2 })?
696 .with_memory_pattern(false)?;
699
700 let builder = configure_execution_provider(builder, self.execution_provider)?;
701
702 let session = builder
703 .commit_from_file(&onnx_path)
704 .context("Failed to load ONNX model")?;
705
706 sessions.push(Mutex::new(session));
707 }
708
709 let batch_size = self.batch_size.unwrap_or(if self.num_sessions > 1 {
711 2 } else {
713 match self.execution_provider {
714 ExecutionProvider::Cpu => DEFAULT_CPU_BATCH_SIZE,
715 _ => DEFAULT_GPU_BATCH_SIZE,
716 }
717 });
718
719 Ok(Colbert {
720 sessions,
721 tokenizer: Arc::new(tokenizer),
722 config,
723 skiplist_ids,
724 batch_size,
725 })
726 }
727}
728
impl Colbert {
    /// Load a ColBERT model from `model_dir` with default builder settings.
    pub fn new<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
        ColbertBuilder::new(model_dir).build()
    }

    /// Start a [`ColbertBuilder`] for custom configuration.
    pub fn builder<P: AsRef<Path>>(model_dir: P) -> ColbertBuilder {
        ColbertBuilder::new(model_dir)
    }

    /// Encode documents into per-token embedding matrices.
    ///
    /// Skiplist tokens are filtered from the output. With
    /// `pool_factor = Some(pf)` and `pf > 1`, each matrix is additionally
    /// compressed by hierarchical clustering to roughly `tokens / pf`
    /// rows (the first row is protected from pooling).
    pub fn encode_documents(
        &self,
        documents: &[&str],
        pool_factor: Option<usize>,
    ) -> Result<Vec<Array2<f32>>> {
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        // Single session: sequential chunking; multiple: scoped threads.
        let embeddings = if self.sessions.len() == 1 {
            self.encode_single_session(documents, false, true)?
        } else {
            self.encode_parallel(documents, false, true)?
        };

        match pool_factor {
            // pool_factor 0/1 would be a no-op, so only pool for pf > 1.
            Some(pf) if pf > 1 => {
                let pooled: Vec<Array2<f32>> = embeddings
                    .into_iter()
                    .map(|emb| pool_embeddings_hierarchical(emb, pf, 1))
                    .collect();
                Ok(pooled)
            }
            _ => Ok(embeddings),
        }
    }

    /// Encode queries into per-token embedding matrices.
    ///
    /// Queries keep all tokens (no skiplist filtering); when query
    /// expansion is enabled they are padded with mask tokens up to
    /// `config.query_length`.
    pub fn encode_queries(&self, queries: &[&str]) -> Result<Vec<Array2<f32>>> {
        if queries.is_empty() {
            return Ok(Vec::new());
        }

        if self.sessions.len() == 1 {
            self.encode_single_session(queries, true, false)
        } else {
            self.encode_parallel(queries, true, false)
        }
    }

    /// The active model configuration.
    pub fn config(&self) -> &ColbertConfig {
        &self.config
    }

    /// Dimensionality of each token embedding.
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }

    /// Number of texts encoded per inference call.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Number of parallel ONNX sessions.
    pub fn num_sessions(&self) -> usize {
        self.sessions.len()
    }

    /// Encode `texts` on the single session, one batch-sized chunk at a
    /// time. The session lock is re-acquired per chunk (uncontended with
    /// one session).
    fn encode_single_session(
        &self,
        texts: &[&str],
        is_query: bool,
        filter_skiplist: bool,
    ) -> Result<Vec<Array2<f32>>> {
        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            let mut session = self.sessions[0].lock().unwrap();
            let chunk_embeddings = encode_batch_with_session(
                &mut session,
                &self.tokenizer,
                &self.config,
                &self.skiplist_ids,
                chunk,
                is_query,
                filter_skiplist,
            )?;
            all_embeddings.extend(chunk_embeddings);
        }

        Ok(all_embeddings)
    }

    /// Encode `texts` across all sessions using scoped threads.
    ///
    /// One thread per batch-sized chunk; sessions are assigned
    /// round-robin (`i % num_sessions`), so threads sharing a session
    /// serialize on its mutex. Joining handles in spawn order keeps the
    /// output order identical to the input order.
    fn encode_parallel(
        &self,
        texts: &[&str],
        is_query: bool,
        filter_skiplist: bool,
    ) -> Result<Vec<Array2<f32>>> {
        let num_sessions = self.sessions.len();

        let chunks: Vec<Vec<&str>> = texts
            .chunks(self.batch_size.max(1))
            .map(|c| c.to_vec())
            .collect();

        let results: Vec<Result<Vec<Array2<f32>>>> = std::thread::scope(|s| {
            let handles: Vec<_> = chunks
                .iter()
                .enumerate()
                .map(|(i, chunk)| {
                    // Round-robin session assignment across chunks.
                    let session_idx = i % num_sessions;
                    let session_mutex = &self.sessions[session_idx];
                    let tokenizer = &self.tokenizer;
                    let config = &self.config;
                    let skiplist_ids = &self.skiplist_ids;

                    s.spawn(move || {
                        let mut session = session_mutex.lock().unwrap();
                        encode_batch_with_session(
                            &mut session,
                            tokenizer,
                            config,
                            skiplist_ids,
                            chunk,
                            is_query,
                            filter_skiplist,
                        )
                    })
                })
                .collect();

            // Join in spawn order to preserve input ordering.
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        });

        let mut all_embeddings = Vec::with_capacity(texts.len());
        for result in results {
            all_embeddings.extend(result?);
        }

        Ok(all_embeddings)
    }
}
927
928fn select_onnx_file<P: AsRef<Path>>(model_dir: P, quantized: bool) -> Result<std::path::PathBuf> {
933 let model_dir = model_dir.as_ref();
934
935 if quantized {
936 let q_path = model_dir.join("model_int8.onnx");
938 if q_path.exists() {
939 Ok(q_path)
940 } else {
941 anyhow::bail!(
942 "INT8 quantized model not found at {:?}. Remove --int8 flag to load model.onnx instead.",
943 q_path
944 )
945 }
946 } else {
947 let model_path = model_dir.join("model.onnx");
950 if model_path.exists() {
951 Ok(model_path)
952 } else {
953 anyhow::bail!(
954 "Model not found at {:?}. Use --int8 flag to load model_int8.onnx instead.",
955 model_path
956 )
957 }
958 }
959}
960
961fn update_token_ids(config: &mut ColbertConfig, tokenizer: &Tokenizer) {
962 if config.mask_token_id == default_mask_token_id() {
963 if let Some(mask_id) = tokenizer.token_to_id("[MASK]") {
964 config.mask_token_id = mask_id;
965 } else if let Some(mask_id) = tokenizer.token_to_id("<mask>") {
966 config.mask_token_id = mask_id;
967 }
968 }
969 if config.pad_token_id == default_pad_token_id() {
970 if let Some(pad_id) = tokenizer.token_to_id("[PAD]") {
971 config.pad_token_id = pad_id;
972 } else if let Some(pad_id) = tokenizer.token_to_id("<pad>") {
973 config.pad_token_id = pad_id;
974 }
975 }
976}
977
978fn build_skiplist(config: &ColbertConfig, tokenizer: &Tokenizer) -> HashSet<u32> {
979 let mut skiplist_ids = HashSet::new();
980 for word in &config.skiplist_words {
981 if let Some(token_id) = tokenizer.token_to_id(word) {
982 skiplist_ids.insert(token_id);
983 }
984 }
985 skiplist_ids
986}
987
/// Tokenize `texts`, run one forward pass on `session`, and slice the
/// model output into one `(tokens, embedding_dim)` matrix per input text.
///
/// Queries use the query prefix/length (plus optional mask-token
/// expansion); documents use the document prefix/length. With
/// `filter_skiplist`, rows for padding and skiplisted tokens are dropped
/// from document outputs.
///
/// NOTE(review): the ONNX graph is assumed to expose inputs named
/// `input_ids`, `attention_mask` (and optionally `token_type_ids`) and a
/// 3-D `output` tensor laid out `(batch, seq, dim)` — confirm against the
/// model exporter.
fn encode_batch_with_session(
    session: &mut Session,
    tokenizer: &Tokenizer,
    config: &ColbertConfig,
    skiplist_ids: &HashSet<u32>,
    texts: &[&str],
    is_query: bool,
    filter_skiplist: bool,
) -> Result<Vec<Array2<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }

    // Pick the prefix marker and length budget for this text kind.
    let (prefix_str, prefix_token_id_opt, max_length) = if is_query {
        (
            &config.query_prefix,
            config.query_prefix_id,
            config.query_length,
        )
    } else {
        (
            &config.document_prefix,
            config.document_prefix_id,
            config.document_length,
        )
    };

    // Use the pre-resolved prefix id when the config supplies one,
    // otherwise resolve the prefix string through the tokenizer.
    let prefix_token_id: u32 = match prefix_token_id_opt {
        Some(id) => id,
        None => tokenizer.token_to_id(prefix_str).ok_or_else(|| {
            anyhow::anyhow!(
                "Prefix token '{}' not found in tokenizer vocabulary",
                prefix_str
            )
        })?,
    };

    // Trim (and optionally lowercase) before tokenization.
    let processed_texts: Vec<String> = if config.do_lower_case {
        texts.iter().map(|t| t.trim().to_lowercase()).collect()
    } else {
        texts.iter().map(|t| t.trim().to_string()).collect()
    };
    let texts_to_encode: Vec<&str> = processed_texts.iter().map(|s| s.as_str()).collect();

    // `true` adds the tokenizer's special tokens (e.g. CLS/SEP).
    let batch_encodings = tokenizer
        .encode_batch(texts_to_encode, true)
        .map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;

    let mut encodings: Vec<BatchEncoding> = Vec::with_capacity(texts.len());
    let mut batch_max_len = 0usize;

    // Reserve one slot for the prefix token inserted below.
    // NOTE(review): underflows if max_length < 2 — config lengths are far
    // larger in practice, but this is not guarded.
    let truncate_limit = max_length - 1;

    for encoding in batch_encodings {
        let token_ids: Vec<u32> = encoding.get_ids().to_vec();
        let mut input_ids: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mut attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&x| x as i64)
            .collect();
        let mut token_type_ids: Vec<i64> =
            encoding.get_type_ids().iter().map(|&x| x as i64).collect();
        let mut token_ids_vec = token_ids;

        // Too long: drop middle tokens but keep the trailing special
        // token (assumed SEP) by saving it, truncating, and re-appending.
        if input_ids.len() > truncate_limit {
            let sep_token = input_ids[input_ids.len() - 1];
            let sep_mask = attention_mask[attention_mask.len() - 1];
            let sep_type = token_type_ids[token_type_ids.len() - 1];
            let sep_token_id = token_ids_vec[token_ids_vec.len() - 1];

            input_ids.truncate(truncate_limit - 1);
            attention_mask.truncate(truncate_limit - 1);
            token_type_ids.truncate(truncate_limit - 1);
            token_ids_vec.truncate(truncate_limit - 1);

            input_ids.push(sep_token);
            attention_mask.push(sep_mask);
            token_type_ids.push(sep_type);
            token_ids_vec.push(sep_token_id);
        }

        // Insert the [Q]/[D] marker right after position 0 (assumed to be
        // the leading special token, e.g. CLS — TODO confirm for
        // non-BERT tokenizers).
        input_ids.insert(1, prefix_token_id as i64);
        attention_mask.insert(1, 1);
        token_type_ids.insert(1, 0);
        token_ids_vec.insert(1, prefix_token_id);

        batch_max_len = batch_max_len.max(input_ids.len());
        encodings.push((input_ids, attention_mask, token_type_ids, token_ids_vec));
    }

    // Query expansion always pads queries to the full query length.
    if is_query && config.do_query_expansion {
        batch_max_len = max_length;
    }

    let batch_size = texts.len();
    let mut all_input_ids: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
    let mut all_attention_mask: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
    let mut all_token_type_ids: Vec<i64> = Vec::with_capacity(batch_size * batch_max_len);
    let mut all_token_ids: Vec<Vec<u32>> = Vec::with_capacity(batch_size);
    let mut original_lengths: Vec<usize> = Vec::with_capacity(batch_size);

    // Pad every sequence to batch_max_len and flatten row-major.
    for (mut input_ids, mut attention_mask, mut token_type_ids, mut token_ids) in encodings {
        original_lengths.push(input_ids.len());

        while input_ids.len() < batch_max_len {
            if is_query && config.do_query_expansion {
                // Expansion tokens are real (attended) mask tokens.
                input_ids.push(config.mask_token_id as i64);
                attention_mask.push(1);
                token_ids.push(config.mask_token_id);
            } else {
                // Ordinary padding is masked out of attention.
                input_ids.push(config.pad_token_id as i64);
                attention_mask.push(0);
                token_ids.push(config.pad_token_id);
            }
            token_type_ids.push(0);
        }

        all_input_ids.extend(input_ids);
        all_attention_mask.extend(attention_mask);
        all_token_type_ids.extend(token_type_ids);
        all_token_ids.push(token_ids);
    }

    let input_ids_tensor = Tensor::from_array(([batch_size, batch_max_len], all_input_ids))?;
    // Cloned: the mask is also needed below when filtering output rows.
    let attention_mask_tensor =
        Tensor::from_array(([batch_size, batch_max_len], all_attention_mask.clone()))?;

    // token_type_ids is only fed to models that declare the input.
    let token_type_ids_tensor = if config.uses_token_type_ids {
        Some(Tensor::from_array((
            [batch_size, batch_max_len],
            all_token_type_ids,
        ))?)
    } else {
        None
    };

    let outputs = if let Some(token_type_ids_tensor) = token_type_ids_tensor {
        session.run(ort::inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
            "token_type_ids" => token_type_ids_tensor,
        ])?
    } else {
        session.run(ort::inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
        ])?
    };

    let (output_shape, output_data) = outputs["output"]
        .try_extract_tensor::<f32>()
        .context("Failed to extract output tensor")?;

    // Take the embedding dim from the actual output shape (batch, seq, dim).
    let shape_slice: Vec<i64> = output_shape.iter().copied().collect();
    let embedding_dim = shape_slice[2] as usize;

    let mut all_embeddings = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        // Flat offsets into the row-major (batch, seq, dim) buffer.
        let batch_offset = i * batch_max_len * embedding_dim;
        let attention_offset = i * batch_max_len;

        if is_query && config.do_query_expansion {
            // Expanded queries keep every row, including mask expansion.
            let end = batch_offset + batch_max_len * embedding_dim;
            let flat: Vec<f32> = output_data[batch_offset..end].to_vec();
            let arr = Array2::from_shape_vec((batch_max_len, embedding_dim), flat)?;
            all_embeddings.push(arr);
        } else {
            let orig_len = original_lengths[i];
            let token_ids = &all_token_ids[i];

            // First pass: count surviving rows to size the buffer exactly.
            let valid_count = (0..orig_len)
                .filter(|&j| {
                    let mask = all_attention_mask[attention_offset + j];
                    let token_id = token_ids[j];
                    mask != 0 && !(filter_skiplist && skiplist_ids.contains(&token_id))
                })
                .count();

            // Second pass: copy rows for attended, non-skiplisted tokens.
            let mut flat: Vec<f32> = Vec::with_capacity(valid_count * embedding_dim);
            for j in 0..orig_len {
                let mask = all_attention_mask[attention_offset + j];
                let token_id = token_ids[j];

                if mask == 0 {
                    continue;
                }
                if filter_skiplist && skiplist_ids.contains(&token_id) {
                    continue;
                }

                let start = batch_offset + j * embedding_dim;
                flat.extend_from_slice(&output_data[start..start + embedding_dim]);
            }

            let arr = Array2::from_shape_vec((valid_count, embedding_dim), flat)?;
            all_embeddings.push(arr);
        }
    }

    Ok(all_embeddings)
}
1213
/// Compress a token-embedding matrix by Ward-linkage hierarchical
/// clustering, mean-pooling each cluster into a single row.
///
/// The first `protected_tokens` rows are copied through unpooled; the
/// remaining rows are grouped into about `tokens_to_pool / pool_factor`
/// clusters over cosine distance. Output rows are the protected rows
/// followed by cluster means in cluster-id order, so original token
/// order is not preserved within the pooled region. Matrices too small
/// to pool are returned unchanged.
fn pool_embeddings_hierarchical(
    embeddings: Array2<f32>,
    pool_factor: usize,
    protected_tokens: usize,
) -> Array2<f32> {
    let n_tokens = embeddings.nrows();
    let n_features = embeddings.ncols();

    // Nothing (or only protected rows) to pool.
    if n_tokens <= protected_tokens + 1 {
        return embeddings;
    }

    let tokens_to_pool = n_tokens - protected_tokens;
    let num_clusters = (tokens_to_pool / pool_factor).max(1);

    // Pooling would be a no-op if every token got its own cluster.
    if num_clusters >= tokens_to_pool {
        return embeddings;
    }

    // Cluster only the unprotected tail, flattened row-major.
    let to_pool = embeddings.slice(ndarray::s![protected_tokens.., ..]);
    let flat_embeddings: Vec<f32> = to_pool.iter().copied().collect();

    let distances = crate::hierarchy::pdist_cosine(&flat_embeddings, tokens_to_pool, n_features);

    let linkage_matrix = crate::hierarchy::linkage(
        &distances,
        tokens_to_pool,
        crate::hierarchy::LinkageMethod::Ward,
    );

    let labels = crate::hierarchy::fcluster(
        &linkage_matrix,
        tokens_to_pool,
        crate::hierarchy::FclusterCriterion::MaxClust,
        num_clusters as f64,
    );

    let mut pooled_rows: Vec<Vec<f32>> = Vec::with_capacity(num_clusters + protected_tokens);

    // Protected rows pass through verbatim.
    for i in 0..protected_tokens {
        pooled_rows.push(embeddings.row(i).to_vec());
    }

    // Labels are assumed 1-based (scipy-style fcluster convention —
    // TODO confirm against crate::hierarchy). Empty clusters are skipped,
    // so the output may have fewer than num_clusters pooled rows.
    for cluster_id in 1..=num_clusters {
        let mut sum = vec![0.0f32; n_features];
        let mut count = 0usize;

        for (idx, &label) in labels.iter().enumerate() {
            if label == cluster_id {
                let row = to_pool.row(idx);
                for (s, &v) in sum.iter_mut().zip(row.iter()) {
                    *s += v;
                }
                count += 1;
            }
        }

        // Mean-pool the cluster members.
        if count > 0 {
            for s in &mut sum {
                *s /= count as f32;
            }
            pooled_rows.push(sum);
        }
    }

    let n_pooled = pooled_rows.len();
    let flat: Vec<f32> = pooled_rows.into_iter().flatten().collect();
    Array2::from_shape_vec((n_pooled, n_features), flat)
        .expect("Shape mismatch in pooled embeddings")
}
1285
1286#[cfg(test)]
1287mod tests {
1288 use super::*;
1289
1290 #[test]
1295 fn test_default_config() {
1296 let config = ColbertConfig::default();
1297 assert_eq!(config.query_length, 48);
1298 assert_eq!(config.document_length, 300);
1299 assert!(config.do_query_expansion);
1300 assert_eq!(config.embedding_dim, 128);
1301 assert_eq!(config.mask_token_id, 103);
1302 assert_eq!(config.pad_token_id, 0);
1303 assert!(config.uses_token_type_ids);
1304 assert_eq!(config.query_prefix, "[Q] ");
1305 assert_eq!(config.document_prefix, "[D] ");
1306 assert!(config.skiplist_words.is_empty());
1307 }
1308
1309 #[test]
1310 fn test_config_serialization_roundtrip() {
1311 let config = ColbertConfig::default();
1312 let json = serde_json::to_string(&config).unwrap();
1313 let parsed: ColbertConfig = serde_json::from_str(&json).unwrap();
1314
1315 assert_eq!(parsed.query_length, config.query_length);
1316 assert_eq!(parsed.document_length, config.document_length);
1317 assert_eq!(parsed.do_query_expansion, config.do_query_expansion);
1318 assert_eq!(parsed.embedding_dim, config.embedding_dim);
1319 assert_eq!(parsed.mask_token_id, config.mask_token_id);
1320 assert_eq!(parsed.pad_token_id, config.pad_token_id);
1321 assert_eq!(parsed.uses_token_type_ids, config.uses_token_type_ids);
1322 }
1323
1324 #[test]
1325 fn test_config_deserialization_with_custom_values() {
1326 let json = r#"{
1327 "query_length": 64,
1328 "document_length": 512,
1329 "do_query_expansion": false,
1330 "embedding_dim": 256,
1331 "mask_token_id": 4,
1332 "pad_token_id": 1,
1333 "uses_token_type_ids": false,
1334 "query_prefix": "[query]",
1335 "document_prefix": "[doc]",
1336 "skiplist_words": ["the", "a", "an"]
1337 }"#;
1338
1339 let config: ColbertConfig = serde_json::from_str(json).unwrap();
1340
1341 assert_eq!(config.query_length, 64);
1342 assert_eq!(config.document_length, 512);
1343 assert!(!config.do_query_expansion);
1344 assert_eq!(config.embedding_dim, 256);
1345 assert_eq!(config.mask_token_id, 4);
1346 assert_eq!(config.pad_token_id, 1);
1347 assert!(!config.uses_token_type_ids);
1348 assert_eq!(config.query_prefix, "[query]");
1349 assert_eq!(config.document_prefix, "[doc]");
1350 assert_eq!(config.skiplist_words, vec!["the", "a", "an"]);
1351 }
1352
1353 #[test]
1354 fn test_config_deserialization_with_defaults() {
1355 let json = "{}";
1357 let config: ColbertConfig = serde_json::from_str(json).unwrap();
1358
1359 assert_eq!(config.query_length, 48);
1360 assert_eq!(config.document_length, 300);
1361 assert!(config.do_query_expansion);
1362 }
1363
1364 #[test]
1369 fn test_builder_defaults() {
1370 let builder = ColbertBuilder::new("test_model");
1371
1372 assert_eq!(builder.num_sessions, 1);
1373 assert!(!builder.quantized);
1374 assert!(builder.batch_size.is_none());
1375 assert_eq!(builder.execution_provider, ExecutionProvider::Auto);
1376 assert!(builder.query_length.is_none());
1377 assert!(builder.document_length.is_none());
1378 }
1379
1380 #[test]
1381 fn test_builder_with_parallel() {
1382 let builder = ColbertBuilder::new("test_model").with_parallel(25);
1383
1384 assert_eq!(builder.num_sessions, 25);
1385 assert_eq!(builder.threads_per_session, 1); }
1387
1388 #[test]
1389 fn test_builder_with_parallel_minimum() {
1390 let builder = ColbertBuilder::new("test_model").with_parallel(0);
1392
1393 assert_eq!(builder.num_sessions, 1);
1394 }
1395
1396 #[test]
1397 fn test_builder_with_threads() {
1398 let builder = ColbertBuilder::new("test_model").with_threads(8);
1399
1400 assert_eq!(builder.threads_per_session, 8);
1401 }
1402
1403 #[test]
1404 fn test_builder_with_batch_size() {
1405 let builder = ColbertBuilder::new("test_model").with_batch_size(64);
1406
1407 assert_eq!(builder.batch_size, Some(64));
1408 }
1409
1410 #[test]
1411 fn test_builder_with_quantized() {
1412 let builder = ColbertBuilder::new("test_model").with_quantized(true);
1413
1414 assert!(builder.quantized);
1415 }
1416
1417 #[test]
1418 fn test_builder_with_execution_provider() {
1419 let builder =
1420 ColbertBuilder::new("test_model").with_execution_provider(ExecutionProvider::Cpu);
1421
1422 assert_eq!(builder.execution_provider, ExecutionProvider::Cpu);
1423 }
1424
1425 #[test]
1426 fn test_builder_with_query_length() {
1427 let builder = ColbertBuilder::new("test_model").with_query_length(64);
1428
1429 assert_eq!(builder.query_length, Some(64));
1430 }
1431
1432 #[test]
1433 fn test_builder_with_document_length() {
1434 let builder = ColbertBuilder::new("test_model").with_document_length(512);
1435
1436 assert_eq!(builder.document_length, Some(512));
1437 }
1438
1439 #[test]
1440 fn test_builder_chained_configuration() {
1441 let builder = ColbertBuilder::new("test_model")
1442 .with_quantized(true)
1443 .with_parallel(16)
1444 .with_batch_size(4)
1445 .with_execution_provider(ExecutionProvider::Cuda)
1446 .with_query_length(64)
1447 .with_document_length(512);
1448
1449 assert!(builder.quantized);
1450 assert_eq!(builder.num_sessions, 16);
1451 assert_eq!(builder.threads_per_session, 1);
1452 assert_eq!(builder.batch_size, Some(4));
1453 assert_eq!(builder.execution_provider, ExecutionProvider::Cuda);
1454 assert_eq!(builder.query_length, Some(64));
1455 assert_eq!(builder.document_length, Some(512));
1456 }
1457
1458 #[test]
1463 fn test_execution_provider_default() {
1464 let provider = ExecutionProvider::default();
1465 assert_eq!(provider, ExecutionProvider::Auto);
1466 }
1467
1468 #[test]
1469 fn test_execution_provider_variants() {
1470 assert_ne!(ExecutionProvider::Auto, ExecutionProvider::Cpu);
1472 assert_ne!(ExecutionProvider::Cpu, ExecutionProvider::Cuda);
1473 assert_ne!(ExecutionProvider::Cuda, ExecutionProvider::TensorRT);
1474 assert_ne!(ExecutionProvider::TensorRT, ExecutionProvider::CoreML);
1475 assert_ne!(ExecutionProvider::CoreML, ExecutionProvider::DirectML);
1476 }
1477
1478 #[test]
1479 fn test_execution_provider_clone() {
1480 let provider = ExecutionProvider::Cuda;
1481 let cloned = provider;
1482 assert_eq!(provider, cloned);
1483 }
1484
1485 #[test]
1486 fn test_execution_provider_debug() {
1487 let provider = ExecutionProvider::Cuda;
1488 let debug_str = format!("{:?}", provider);
1489 assert_eq!(debug_str, "Cuda");
1490 }
1491
1492 #[test]
1497 fn test_pool_embeddings_no_pooling() {
1498 let embeddings = Array2::from_shape_vec(
1500 (5, 4),
1501 vec![
1502 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.0, 0.0, ],
1508 )
1509 .unwrap();
1510
1511 let result = pool_embeddings_hierarchical(embeddings.clone(), 1, 1);
1513 assert_eq!(result.dim(), embeddings.dim());
1514 }
1515
1516 #[test]
1517 fn test_pool_embeddings_with_pooling() {
1518 let embeddings = Array2::from_shape_vec(
1520 (5, 4),
1521 vec![
1522 1.0, 0.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.85, 0.15, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.9, 0.1, ],
1528 )
1529 .unwrap();
1530
1531 let result = pool_embeddings_hierarchical(embeddings, 2, 1);
1533
1534 assert!(result.nrows() < 5);
1536 assert!(result.nrows() >= 1);
1538 assert_eq!(result.ncols(), 4);
1540 }
1541
1542 #[test]
1543 fn test_pool_embeddings_too_few_tokens() {
1544 let embeddings = Array2::from_shape_vec(
1546 (2, 4),
1547 vec![
1548 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ],
1551 )
1552 .unwrap();
1553
1554 let result = pool_embeddings_hierarchical(embeddings.clone(), 2, 1);
1555
1556 assert_eq!(result.dim(), embeddings.dim());
1558 }
1559
1560 #[test]
1561 fn test_pool_embeddings_all_protected() {
1562 let embeddings = Array2::from_shape_vec(
1564 (3, 4),
1565 vec![
1566 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ],
1570 )
1571 .unwrap();
1572
1573 let result = pool_embeddings_hierarchical(embeddings.clone(), 2, 3);
1575
1576 assert_eq!(result.dim(), embeddings.dim());
1578 }
1579
1580 #[test]
1585 fn test_default_batch_sizes() {
1586 assert_eq!(DEFAULT_CPU_BATCH_SIZE, 32);
1587 assert_eq!(DEFAULT_GPU_BATCH_SIZE, 64);
1588 }
1589}