mod models;
mod provider;

pub use models::{EmbeddingConfig, ModelConfig, ModelInfo, DEFAULT_DIM, DEFAULT_MODEL_REPO};

use provider::ort_err;
pub(crate) use provider::{create_session, select_provider};

use lru::LruCache;
use ndarray::{Array2, Array3, Axis};
use once_cell::sync::OnceCell;
use ort::session::Session;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use thiserror::Error;

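/// Returns the repository id of the currently resolved model configuration.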
pub fn model_repo() -> String {
    ModelConfig::resolve(None, None).repo
}

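// BLAKE3 checksums for the pinned model files. Empty strings disable
// verification (see `ensure_model`).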
const MODEL_BLAKE3: &str = "";
const TOKENIZER_BLAKE3: &str = "";

#[derive(Error, Debug)]
pub enum EmbedderError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
    #[error("Query cannot be empty")]
    EmptyQuery,
    #[error("HuggingFace Hub error: {0}")]
    HfHub(String),
}

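/// An L2-normalized embedding vector produced by the model.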
#[derive(Debug, Clone)]
pub struct Embedding(Vec<f32>);

pub use crate::EMBEDDING_DIM;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingDimensionError {
    pub actual: usize,
    pub expected: usize,
}

impl std::fmt::Display for EmbeddingDimensionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Invalid embedding dimension: expected {}, got {}",
            self.expected, self.actual
        )
    }
}

impl std::error::Error for EmbeddingDimensionError {}

impl Embedding {
    pub fn new(data: Vec<f32>) -> Self {
        Self(data)
    }

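    /// Validating constructor: rejects empty vectors (reported as
    /// `expected: 1`, i.e. at least one element) and non-finite values
    /// (NaN and ±infinity).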
    pub fn try_new(data: Vec<f32>) -> Result<Self, EmbeddingDimensionError> {
        if data.is_empty() {
            return Err(EmbeddingDimensionError {
                actual: 0,
                expected: 1,
            });
        }
        if !data.iter().all(|v| v.is_finite()) {
            return Err(EmbeddingDimensionError {
                actual: data.len(),
                expected: data.len(),
            });
        }
        Ok(Self(data))
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }

    pub fn as_vec(&self) -> &Vec<f32> {
        &self.0
    }

    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

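/// The ONNX Runtime execution provider used for inference.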
#[derive(Debug, Clone, Copy)]
pub enum ExecutionProvider {
    CUDA { device_id: i32 },
    TensorRT { device_id: i32 },
    CPU,
}

impl std::fmt::Display for ExecutionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionProvider::CUDA { device_id } => write!(f, "CUDA (device {})", device_id),
            ExecutionProvider::TensorRT { device_id } => {
                write!(f, "TensorRT (device {})", device_id)
            }
            ExecutionProvider::CPU => write!(f, "CPU"),
        }
    }
}

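/// Lazily initialized ONNX embedder.
///
/// Model files, tokenizer, and the ORT session are all resolved on first
/// use, so construction is cheap. Query embeddings are memoized in a small
/// LRU cache.
///
/// A minimal usage sketch (`ModelConfig::e5_base()` is the config the tests
/// in this module use; any `ModelConfig` works):
///
/// ```ignore
/// let embedder = Embedder::new_cpu(ModelConfig::e5_base())?;
/// let query = embedder.embed_query("how is mean pooling done")?;
/// let docs = embedder.embed_documents(&["fn main() {}"])?;
/// assert_eq!(query.len(), embedder.embedding_dim());
/// ```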
pub struct Embedder {
    session: Mutex<Option<Session>>,
    tokenizer: OnceCell<tokenizers::Tokenizer>,
    model_paths: OnceCell<(PathBuf, PathBuf)>,
    provider: ExecutionProvider,
    max_length: usize,
    query_cache: Mutex<LruCache<String, Embedding>>,
    detected_dim: std::sync::OnceLock<usize>,
    model_config: ModelConfig,
}

const DEFAULT_QUERY_CACHE_SIZE: usize = 32;

impl Embedder {
    pub fn new(model_config: ModelConfig) -> Result<Self, EmbedderError> {
        Self::new_with_provider(model_config, select_provider())
    }

    pub fn new_cpu(model_config: ModelConfig) -> Result<Self, EmbedderError> {
        Self::new_with_provider(model_config, ExecutionProvider::CPU)
    }

    fn new_with_provider(
        model_config: ModelConfig,
        provider: ExecutionProvider,
    ) -> Result<Self, EmbedderError> {
        let max_length = model_config.max_seq_length;

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(DEFAULT_QUERY_CACHE_SIZE)
                .expect("DEFAULT_QUERY_CACHE_SIZE is non-zero"),
        ));

        Ok(Self {
            session: Mutex::new(None),
            tokenizer: OnceCell::new(),
            model_paths: OnceCell::new(),
            provider,
            max_length,
            query_cache,
            detected_dim: std::sync::OnceLock::new(),
            model_config,
        })
    }

    pub fn model_config(&self) -> &ModelConfig {
        &self.model_config
    }

    fn model_paths(&self) -> Result<&(PathBuf, PathBuf), EmbedderError> {
        self.model_paths
            .get_or_try_init(|| ensure_model(&self.model_config))
    }

    fn session(&self) -> Result<std::sync::MutexGuard<'_, Option<Session>>, EmbedderError> {
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        if guard.is_none() {
            let _span = tracing::info_span!("embedder_session_init").entered();
            let (model_path, _) = self.model_paths()?;
            *guard = Some(create_session(model_path, self.provider)?);
            tracing::info!("Embedder session initialized");
        }
        Ok(guard)
    }

    fn tokenizer(&self) -> Result<&tokenizers::Tokenizer, EmbedderError> {
        let (_, tokenizer_path) = self.model_paths()?;
        self.tokenizer.get_or_try_init(|| {
            tokenizers::Tokenizer::from_file(tokenizer_path)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))
        })
    }

    pub fn token_count(&self, text: &str) -> Result<usize, EmbedderError> {
        let encoding = self
            .tokenizer()?
            .encode(text, false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
        Ok(encoding.get_ids().len())
    }

    pub fn token_counts_batch(&self, texts: &[&str]) -> Result<Vec<usize>, EmbedderError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let encodings = self
            .tokenizer()?
            .encode_batch(texts.to_vec(), false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
        Ok(encodings.iter().map(|e| e.get_ids().len()).collect())
    }

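    /// Splits `text` into token windows of at most `max_tokens` tokens, with
    /// `overlap` tokens shared between consecutive windows. Returns each
    /// window's decoded text with its 0-based window index; text that already
    /// fits in a single window is returned verbatim.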
    pub fn split_into_windows(
        &self,
        text: &str,
        max_tokens: usize,
        overlap: usize,
    ) -> Result<Vec<(String, u32)>, EmbedderError> {
        if max_tokens == 0 {
            return Ok(vec![]);
        }

        if overlap >= max_tokens / 2 {
            return Err(EmbedderError::Tokenizer(format!(
                "overlap ({overlap}) must be less than max_tokens/2 ({})",
                max_tokens / 2
            )));
        }

        let tokenizer = self.tokenizer()?;
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;

        let ids = encoding.get_ids();
        if ids.len() <= max_tokens {
            return Ok(vec![(text.to_string(), 0)]);
        }

        let mut windows = Vec::new();
        let step = max_tokens - overlap;
        let mut start = 0;
        let mut window_idx = 0u32;

        while start < ids.len() {
            let end = (start + max_tokens).min(ids.len());
            let window_ids: Vec<u32> = ids[start..end].to_vec();

            let window_text = tokenizer
                .decode(&window_ids, true)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;

            windows.push((window_text, window_idx));
            window_idx += 1;

            if end >= ids.len() {
                break;
            }
            start += step;
        }

        Ok(windows)
    }

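    /// Embeds documents with the configured document prefix, chunking the
    /// work into batches of at most 64 texts per ONNX run.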
    pub fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedderError> {
        let _span = tracing::info_span!("embed_documents", count = texts.len()).entered();
        let prefix = &self.model_config.doc_prefix;
        const MAX_BATCH: usize = 64;
        if texts.len() <= MAX_BATCH {
            let prefixed: Vec<String> = texts.iter().map(|t| format!("{}{}", prefix, t)).collect();
            return self.embed_batch(&prefixed);
        }
        let mut all = Vec::with_capacity(texts.len());
        for chunk in texts.chunks(MAX_BATCH) {
            let prefixed: Vec<String> = chunk.iter().map(|t| format!("{}{}", prefix, t)).collect();
            all.extend(self.embed_batch(&prefixed)?);
        }
        Ok(all)
    }

    const MAX_QUERY_BYTES: usize = 32 * 1024;

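    /// Embeds a single query with the configured query prefix. The text is
    /// trimmed, truncated to `MAX_QUERY_BYTES` on a char boundary if
    /// oversized, and served from the LRU cache on repeat lookups.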
    pub fn embed_query(&self, text: &str) -> Result<Embedding, EmbedderError> {
        let _span = tracing::info_span!("embed_query").entered();
        let text = text.trim();
        if text.is_empty() {
            return Err(EmbedderError::EmptyQuery);
        }
        let text = if text.len() > Self::MAX_QUERY_BYTES {
            tracing::warn!(
                len = text.len(),
                max = Self::MAX_QUERY_BYTES,
                "Query text truncated before embedding"
            );
            let mut end = Self::MAX_QUERY_BYTES;
            while !text.is_char_boundary(end) && end > 0 {
                end -= 1;
            }
            &text[..end]
        } else {
            text
        };

        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            if let Some(cached) = cache.get(text) {
                tracing::trace!(query = text, "Embedding cache hit");
                return Ok(cached.clone());
            }
            tracing::trace!(query = text, "Embedding cache miss");
        }

        let prefixed = format!("{}{}", self.model_config.query_prefix, text);
        let results = self.embed_batch(&[prefixed])?;
        let embedding = results.into_iter().next().ok_or_else(|| {
            EmbedderError::InferenceFailed("embed_batch returned empty result".to_string())
        })?;

        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            cache.put(text.to_string(), embedding.clone());
            tracing::trace!(query = text, cache_len = cache.len(), "Embedding cached");
        }

        Ok(embedding)
    }

    pub fn provider(&self) -> ExecutionProvider {
        self.provider
    }

    pub fn clear_session(&self) {
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        *guard = None;
        let mut cache = self.query_cache.lock().unwrap_or_else(|p| p.into_inner());
        cache.clear();
        tracing::info!("Embedder session and query cache cleared");
    }

    pub fn warm(&self) -> Result<(), EmbedderError> {
        let _ = self.embed_query("warmup")?;
        Ok(())
    }

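    /// Returns the embedding dimension: the runtime-detected value once a
    /// batch has run, otherwise the configured dimension, falling back to
    /// `EMBEDDING_DIM` when the configured dimension is 0.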
    pub fn embedding_dim(&self) -> usize {
        let dim = *self.detected_dim.get().unwrap_or(&self.model_config.dim);
        if dim == 0 {
            EMBEDDING_DIM
        } else {
            dim
        }
    }

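    /// Core inference path: tokenize, run the ONNX session, mean-pool the
    /// last hidden state over the attention mask, and L2-normalize each
    /// pooled vector.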
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, EmbedderError> {
        use ort::value::Tensor;

        let _span = tracing::info_span!("embed_batch", count = texts.len()).entered();

        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = {
            let _tokenize = tracing::debug_span!("tokenize").entered();
            self.tokenizer()?
                .encode_batch(texts.to_vec(), true)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?
        };

        let input_ids: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();

        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);

        let input_ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let attention_mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        let token_type_ids_arr = Array2::<i64>::zeros((texts.len(), max_len));

        let input_ids_tensor = Tensor::from_array(input_ids_arr).map_err(ort_err)?;
        let attention_mask_tensor = Tensor::from_array(attention_mask_arr).map_err(ort_err)?;
        let token_type_ids_tensor = Tensor::from_array(token_type_ids_arr).map_err(ort_err)?;

        let mut guard = self.session()?;
        let session = guard
            .as_mut()
            .expect("session() guarantees initialized after Ok return");
        let _inference = tracing::debug_span!("inference", max_len).entered();
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ])
            .map_err(ort_err)?;

        let output = outputs.get("last_hidden_state").ok_or_else(|| {
            EmbedderError::InferenceFailed(format!(
                "ONNX model has no 'last_hidden_state' output. Available: {:?}",
                outputs.keys().collect::<Vec<_>>()
            ))
        })?;
        let (shape, data) = output.try_extract_tensor::<f32>().map_err(ort_err)?;

        let batch_size = texts.len();
        let seq_len = max_len;
        if shape.len() != 3 {
            return Err(EmbedderError::InferenceFailed(format!(
                "Unexpected tensor shape: expected 3 dimensions [batch, seq, dim], got {} dimensions",
                shape.len()
            )));
        }
        let embedding_dim = shape[2] as usize;
        match self.detected_dim.get() {
            Some(&expected) if expected != embedding_dim => {
                return Err(EmbedderError::InferenceFailed(format!(
                    "Embedding dimension changed: expected {expected}, got {embedding_dim}"
                )));
            }
            None => {
                let _ = self.detected_dim.set(embedding_dim);
                tracing::info!(
                    dim = embedding_dim,
                    "Detected embedding dimension from model"
                );
            }
            _ => {}
        }
        if shape[0] as usize != batch_size {
            return Err(EmbedderError::InferenceFailed(format!(
                "Tensor batch size mismatch: expected {}, got {}",
                batch_size, shape[0]
            )));
        }
        let hidden = Array3::from_shape_vec((batch_size, seq_len, embedding_dim), data.to_vec())
            .map_err(|e| EmbedderError::InferenceFailed(format!("tensor reshape failed: {e}")))?;

        let mask_2d = Array2::from_shape_fn((batch_size, seq_len), |(i, j)| {
            attention_mask[i].get(j).copied().unwrap_or(0) as f32
        });
        let mask_3d = mask_2d.clone().insert_axis(Axis(2));

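        // Mean pooling over the attention mask:
        //   pooled[i] = sum_j(mask[i][j] * hidden[i][j]) / sum_j(mask[i][j])
        // so padded positions contribute nothing to the average.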
        let masked = &hidden * &mask_3d;
        let summed = masked.sum_axis(Axis(1));
        let counts = mask_2d.sum_axis(Axis(1)).insert_axis(Axis(1));
        let results = (0..batch_size)
            .map(|i| {
                let count = counts[[i, 0]];
                let row = summed.row(i);
                let pooled: Vec<f32> = if count > 0.0 {
                    row.iter().map(|v| v / count).collect()
                } else {
                    tracing::warn!(batch_idx = i, "Zero attention mask — producing zero vector");
                    vec![0.0f32; embedding_dim]
                };
                Embedding::new(normalize_l2(pooled))
            })
            .collect();

        Ok(results)
    }
}

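/// Resolves the model and tokenizer file paths. `CQS_ONNX_DIR` is checked
/// first (structured `onnx/model.onnx` layout, then flat layout); otherwise
/// the files are fetched from the Hugging Face Hub and, when checksums are
/// pinned, verified once with BLAKE3 and recorded in a `.cqs_verified` marker.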
fn ensure_model(config: &ModelConfig) -> Result<(PathBuf, PathBuf), EmbedderError> {
    if let Ok(dir) = std::env::var("CQS_ONNX_DIR") {
        let dir = dunce::canonicalize(PathBuf::from(&dir)).unwrap_or_else(|_| PathBuf::from(dir));
        let model_path = dir.join(&config.onnx_path);
        let tokenizer_path = dir.join(&config.tokenizer_path);
        if model_path.exists() && tokenizer_path.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory");
            return Ok((model_path, tokenizer_path));
        }
        let flat_model = dir.join("model.onnx");
        let flat_tok = dir.join("tokenizer.json");
        if flat_model.exists() && flat_tok.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory (flat)");
            return Ok((flat_model, flat_tok));
        }
        tracing::warn!(dir = %dir.display(), "CQS_ONNX_DIR set but model files not found, falling back to HF download");
    }

    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let repo = api.model(config.repo.clone());

    let model_path = repo
        .get(&config.onnx_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let tokenizer_path = repo
        .get(&config.tokenizer_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;

    if !MODEL_BLAKE3.is_empty() || !TOKENIZER_BLAKE3.is_empty() {
        let marker = model_path
            .parent()
            .unwrap_or(Path::new("."))
            .join(".cqs_verified");
        let expected_marker = format!("{}\n{}", MODEL_BLAKE3, TOKENIZER_BLAKE3);
        let already_verified = std::fs::read_to_string(&marker)
            .map(|s| s == expected_marker)
            .unwrap_or(false);

        if !already_verified {
            if !MODEL_BLAKE3.is_empty() {
                verify_checksum(&model_path, MODEL_BLAKE3)?;
            }
            if !TOKENIZER_BLAKE3.is_empty() {
                verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
            }
            let _ = std::fs::write(&marker, &expected_marker);
        }
    }

    Ok((model_path, tokenizer_path))
}

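/// Hashes `path` with BLAKE3 and compares against `expected` (hex).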
fn verify_checksum(path: &Path, expected: &str) -> Result<(), EmbedderError> {
    let mut file =
        std::fs::File::open(path).map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let mut hasher = blake3::Hasher::new();
    std::io::copy(&mut file, &mut hasher)
        .map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let actual = hasher.finalize().to_hex().to_string();

    if actual != expected {
        return Err(EmbedderError::ChecksumMismatch {
            path: path.display().to_string(),
            expected: expected.to_string(),
            actual,
        });
    }
    Ok(())
}

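/// Pads (or truncates) each row to exactly `max_len` columns, filling short
/// rows with `pad_value`.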
pub(crate) fn pad_2d_i64(inputs: &[Vec<i64>], max_len: usize, pad_value: i64) -> Array2<i64> {
    let batch_size = inputs.len();
    let mut arr = Array2::from_elem((batch_size, max_len), pad_value);
    for (i, seq) in inputs.iter().enumerate() {
        for (j, &val) in seq.iter().take(max_len).enumerate() {
            arr[[i, j]] = val;
        }
    }
    arr
}

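/// Scales `v` to unit L2 norm; a zero vector is returned unchanged.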
fn normalize_l2(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq: f32 = v.iter().fold(0.0, |acc, &x| acc + x * x);
    if norm_sq > 0.0 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        v.iter_mut().for_each(|x| *x *= inv_norm);
    }
    v
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_new() {
        let data = vec![0.5; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_slice(), &data);
    }

    #[test]
    fn test_embedding_len() {
        let emb = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert_eq!(emb.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_embedding_is_empty() {
        let empty = Embedding::new(vec![]);
        assert!(empty.is_empty());

        let non_empty = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_embedding_into_inner() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.into_inner(), data);
    }

    #[test]
    fn test_embedding_as_vec() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_vec(), &data);
    }

    #[test]
    fn tc33_try_new_empty_vec_errors() {
        let result = Embedding::try_new(vec![]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.actual, 0);
        assert_eq!(err.expected, 1);
    }

    #[test]
    fn tc33_try_new_nan_errors() {
        let result = Embedding::try_new(vec![1.0, f32::NAN, 3.0]);
        assert!(result.is_err(), "NaN should be rejected by try_new");
    }

    #[test]
    fn tc33_try_new_inf_errors() {
        let result = Embedding::try_new(vec![1.0, f32::INFINITY, 3.0]);
        assert!(result.is_err(), "Infinity should be rejected by try_new");

        let result = Embedding::try_new(vec![f32::NEG_INFINITY]);
        assert!(result.is_err(), "Negative infinity should be rejected");
    }

    #[test]
    fn tc33_try_new_valid_ok() {
        let data = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = Embedding::try_new(data.clone());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_slice(), &data);
    }

    #[test]
    fn test_normalize_l2_unit_vector() {
        let v = normalize_l2(vec![1.0, 0.0, 0.0]);
        assert!((v[0] - 1.0).abs() < 1e-6);
        assert!((v[1] - 0.0).abs() < 1e-6);
        assert!((v[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_produces_unit_vector() {
        let v = normalize_l2(vec![3.0, 4.0]);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);

        let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_zero_vector() {
        let v = normalize_l2(vec![0.0, 0.0, 0.0]);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_normalize_l2_empty_vector() {
        let v = normalize_l2(vec![]);
        assert!(v.is_empty());
    }

    #[test]
    fn test_execution_provider_display() {
        assert_eq!(format!("{}", ExecutionProvider::CPU), "CPU");
        assert_eq!(
            format!("{}", ExecutionProvider::CUDA { device_id: 0 }),
            "CUDA (device 0)"
        );
        assert_eq!(
            format!("{}", ExecutionProvider::TensorRT { device_id: 1 }),
            "TensorRT (device 1)"
        );
    }

    #[test]
    fn test_model_dimensions() {
        assert_eq!(EMBEDDING_DIM, 1024);
    }

    #[test]
    fn test_pad_2d_i64_basic() {
        let inputs = vec![vec![1, 2, 3], vec![4, 5]];
        let result = pad_2d_i64(&inputs, 4, 0);
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
        assert_eq!(result[[0, 3]], 0);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 5);
        assert_eq!(result[[1, 2]], 0);
        assert_eq!(result[[1, 3]], 0);
    }

    #[test]
    fn test_pad_2d_i64_truncates() {
        let inputs = vec![vec![1, 2, 3, 4, 5]];
        let result = pad_2d_i64(&inputs, 3, 0);
        assert_eq!(result.shape(), &[1, 3]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
    }

    #[test]
    fn test_pad_2d_i64_empty_input() {
        let inputs: Vec<Vec<i64>> = vec![];
        let result = pad_2d_i64(&inputs, 5, 0);
        assert_eq!(result.shape(), &[0, 5]);
    }

    #[test]
    fn test_pad_2d_i64_custom_pad_value() {
        let inputs = vec![vec![1]];
        let result = pad_2d_i64(&inputs, 3, -1);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], -1);
        assert_eq!(result[[0, 2]], -1);
    }

    #[test]
    fn test_embedder_error_display() {
        let err = EmbedderError::EmptyQuery;
        assert_eq!(format!("{}", err), "Query cannot be empty");

        let err = EmbedderError::ModelNotFound("model.onnx".to_string());
        assert!(format!("{}", err).contains("model.onnx"));

        let err = EmbedderError::Tokenizer("invalid token".to_string());
        assert!(format!("{}", err).contains("invalid token"));

        let err = EmbedderError::ChecksumMismatch {
            path: "/path/to/file".to_string(),
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(format!("{}", err).contains("abc123"));
        assert!(format!("{}", err).contains("def456"));
    }

    #[test]
    fn test_embedder_error_from_ort() {
        let err: EmbedderError = EmbedderError::InferenceFailed("test error".to_string());
        assert!(matches!(err, EmbedderError::InferenceFailed(_)));
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_normalize_l2_unit_or_zero(v in prop::collection::vec(-1e6f32..1e6f32, 1..100)) {
                let normalized = normalize_l2(v.clone());

                let magnitude: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();

                let input_is_zero = v.iter().all(|&x| x == 0.0);
                if input_is_zero {
                    prop_assert!(magnitude < 1e-6, "Zero input should give zero output");
                } else {
                    prop_assert!(
                        (magnitude - 1.0).abs() < 1e-4,
                        "Non-zero input should give unit vector, got magnitude {}",
                        magnitude
                    );
                }
            }

            #[test]
            fn prop_normalize_l2_preserves_direction(v in prop::collection::vec(1.0f32..100.0, 1..50)) {
                let normalized = normalize_l2(v.clone());

                let dot: f32 = v.iter().zip(normalized.iter()).map(|(a, b)| a * b).sum();
                prop_assert!(dot > 0.0, "Direction should be preserved");
            }

            #[test]
            fn prop_embedding_length_preserved(use_model_dim in proptest::bool::ANY) {
                let _ = use_model_dim;
                let emb = Embedding::new(vec![0.5; EMBEDDING_DIM]);
                prop_assert_eq!(emb.len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_slice().len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_vec().len(), EMBEDDING_DIM);
            }
        }
    }

    #[test]
    #[ignore]
    fn test_clear_session_and_reinit() {
        let embedder = Embedder::new(ModelConfig::e5_base()).unwrap();
        let _ = embedder.embed_query("test");
        embedder.clear_session();
        let result = embedder.embed_query("test again");
        assert!(result.is_ok());
    }

    #[test]
    fn test_clear_session_idempotent() {
        let embedder = Embedder::new_cpu(ModelConfig::e5_base()).unwrap();
        embedder.clear_session();
        embedder.clear_session();
    }

    mod integration {
        use super::*;

        #[test]
        #[ignore]
        fn test_token_count_empty() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let count = embedder.token_count("").expect("token_count failed");
            assert_eq!(count, 0);
        }

        #[test]
        #[ignore]
        fn test_token_count_simple() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let count = embedder
                .token_count("hello world")
                .expect("token_count failed");
            assert!(
                (2..=4).contains(&count),
                "Expected 2-4 tokens, got {}",
                count
            );
        }

        #[test]
        #[ignore]
        fn test_token_count_code() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let code = "fn main() { println!(\"Hello\"); }";
            let count = embedder.token_count(code).expect("token_count failed");
            assert!(count > 5, "Expected >5 tokens for code, got {}", count);
        }

        #[test]
        #[ignore]
        fn test_token_count_unicode() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let text = "\u{3053}\u{3093}\u{306b}\u{3061}\u{306f}\u{4e16}\u{754c}";
            let count = embedder.token_count(text).expect("token_count failed");
            assert!(count > 0, "Expected >0 tokens for unicode, got {}", count);
        }
    }

    mod ensure_model_tests {
        use super::*;
        use std::sync::Mutex;

        static ONNX_DIR_MUTEX: Mutex<()> = Mutex::new(());

        fn test_model_config() -> ModelConfig {
            ModelConfig {
                name: "test".to_string(),
                repo: "test/model".to_string(),
                onnx_path: "onnx/model.onnx".to_string(),
                tokenizer_path: "tokenizer.json".to_string(),
                dim: 768,
                max_seq_length: 512,
                query_prefix: String::new(),
                doc_prefix: String::new(),
            }
        }

        #[test]
        fn cqs_onnx_dir_structured_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            let onnx_dir = dir.path().join("onnx");
            std::fs::create_dir_all(&onnx_dir).unwrap();
            std::fs::write(onnx_dir.join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();

            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }

        #[test]
        fn cqs_onnx_dir_flat_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            std::fs::write(dir.path().join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();

            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }

        #[test]
        fn cqs_onnx_dir_missing_files_falls_through() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            assert!(
                result.is_err() || !result.as_ref().unwrap().0.starts_with(dir.path()),
                "Should not return paths from empty CQS_ONNX_DIR"
            );
        }
    }
}