use std::collections::HashMap;
use std::path::{Path, PathBuf};
#[cfg(feature = "onnx-embeddings")]
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

use crate::embedding::{EmbeddingError, EmbeddingProvider, MockEmbeddingProvider};
pub struct WordPieceTokenizer {
    /// Token string to vocabulary id, loaded from a BERT-style `vocab.txt`.
    vocab: HashMap<String, i64>,
    /// Total output length, including `[CLS]` and `[SEP]`.
    max_length: usize,
    /// Words longer than this many characters map straight to `[UNK]`.
    max_word_chars: usize,
}

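// Special token ids matching the standard bert-base-uncased vocabulary layout
// ([PAD] = 0, [UNK] = 100, [CLS] = 101, [SEP] = 102). A vocab.txt with a
// different layout would need these constants adjusted.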
const CLS_ID: i64 = 101;
const SEP_ID: i64 = 102;
const UNK_ID: i64 = 100;
const PAD_ID: i64 = 0;

impl WordPieceTokenizer {
    pub fn load(vocab_path: &Path) -> Option<Self> {
        let content = std::fs::read_to_string(vocab_path).ok()?;
        let line_count = content.lines().count();
        let mut vocab = HashMap::with_capacity(line_count);
        for (id, line) in content.lines().enumerate() {
            // trim_end guards against stray trailing whitespace and the CR
            // left over from CRLF-encoded vocab files.
            vocab.insert(line.trim_end().to_string(), id as i64);
        }
        if vocab.len() < 1000 {
            tracing::warn!(
                "vocab.txt at {} has only {} entries, expected ~30k",
                vocab_path.display(),
                vocab.len()
            );
            return None;
        }
        tracing::info!(
            "WordPiece vocab loaded: {} tokens from {}",
            vocab.len(),
            vocab_path.display()
        );
        Some(Self {
            vocab,
            max_length: 128,
            max_word_chars: 100,
        })
    }

    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

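    /// Encodes `text` into `(input_ids, attention_mask, token_type_ids)`,
    /// each exactly `max_length` long, BERT-style: `[CLS]`, up to
    /// `max_length - 2` content tokens, `[SEP]`, then `[PAD]` to length.
    ///
    /// A minimal usage sketch (the vocab path is illustrative):
    ///
    /// ```ignore
    /// let tok = WordPieceTokenizer::load(Path::new("models/vocab.txt"))?;
    /// let (ids, mask, _types) = tok.encode("hello world")?;
    /// assert_eq!(ids.len(), 128); // default max_length
    /// assert_eq!(ids[0], 101);    // [CLS]
    /// ```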
    pub fn encode(&self, text: &str) -> Option<(Vec<i64>, Vec<i64>, Vec<i64>)> {
        if self.vocab.is_empty() {
            return None;
        }

        let mut token_ids: Vec<i64> = Vec::with_capacity(self.max_length);
        token_ids.push(CLS_ID);

        let words = self.pre_tokenize(text);

        for word in &words {
            // Reserve the final slot for [SEP].
            if token_ids.len() >= self.max_length - 1 {
                break;
            }
            let sub_ids = self.wordpiece_split(word);
            for id in sub_ids {
                if token_ids.len() >= self.max_length - 1 {
                    break;
                }
                token_ids.push(id);
            }
        }

        token_ids.push(SEP_ID);

        let seq_len = token_ids.len();
        let mut attention_mask = vec![1i64; seq_len];
        let mut token_type_ids = vec![0i64; seq_len];

        while token_ids.len() < self.max_length {
            token_ids.push(PAD_ID);
            attention_mask.push(0);
            token_type_ids.push(0);
        }

        Some((token_ids, attention_mask, token_type_ids))
    }
131
132 fn pre_tokenize(&self, text: &str) -> Vec<String> {
134 let lower = text.to_lowercase();
135 let mut words = Vec::new();
136 let mut current = String::new();
137
138 for ch in lower.chars() {
139 if ch.is_whitespace() {
140 if !current.is_empty() {
141 words.push(std::mem::take(&mut current));
142 }
143 } else if ch.is_ascii_punctuation() || is_cjk_char(ch) {
144 if !current.is_empty() {
146 words.push(std::mem::take(&mut current));
147 }
148 words.push(ch.to_string());
149 } else if is_accent_char(ch) {
150 continue;
152 } else if ch.is_control() {
153 continue;
154 } else {
155 current.push(ch);
156 }
157 }
158 if !current.is_empty() {
159 words.push(current);
160 }
161 words
162 }

    fn wordpiece_split(&self, word: &str) -> Vec<i64> {
        let chars: Vec<char> = word.chars().collect();
        // Compare character count, not byte length, so multi-byte words are
        // not rejected early.
        if chars.len() > self.max_word_chars {
            return vec![UNK_ID];
        }

        let mut ids = Vec::new();
        let mut start = 0;

        while start < chars.len() {
            let mut end = chars.len();
            let mut found = false;

            // Greedy longest-match-first: shrink the window until a piece hits.
            while start < end {
                let substr: String = if start == 0 {
                    chars[start..end].iter().collect()
                } else {
                    format!("##{}", chars[start..end].iter().collect::<String>())
                };

                if let Some(&id) = self.vocab.get(&substr) {
                    ids.push(id);
                    found = true;
                    start = end;
                    break;
                }
                end -= 1;
            }

            if !found {
                ids.push(UNK_ID);
                start += 1;
            }
        }

        ids
    }
}

fn is_cjk_char(ch: char) -> bool {
    let cp = ch as u32;
    matches!(cp,
        0x4E00..=0x9FFF      // CJK Unified Ideographs
        | 0x3400..=0x4DBF    // Extension A
        | 0x20000..=0x2A6DF  // Extension B
        | 0x2A700..=0x2B73F  // Extension C
        | 0x2B740..=0x2B81F  // Extension D
        | 0x2B820..=0x2CEAF  // Extension E
        | 0xF900..=0xFAFF    // Compatibility Ideographs
        | 0x2F800..=0x2FA1F  // Compatibility Ideographs Supplement
    )
}

fn is_accent_char(ch: char) -> bool {
    let cp = ch as u32;
    // Unicode combining-mark blocks (diacritics stripped during pre-tokenization).
    matches!(cp, 0x0300..=0x036F | 0x1AB0..=0x1AFF | 0x1DC0..=0x1DFF | 0xFE20..=0xFE2F)
}

#[cfg_attr(not(feature = "onnx-embeddings"), allow(dead_code))]
fn vocab_search_paths(model_path: &Path) -> Vec<PathBuf> {
    let mut paths = Vec::new();

    if let Some(parent) = model_path.parent() {
        // Sibling vocab.txt next to the model file.
        paths.push(parent.join("vocab.txt"));

        // vocab.txt inside a directory named after the model file stem.
        if let Some(stem) = model_path.file_stem() {
            paths.push(parent.join(stem).join("vocab.txt"));
        }
    }

    let model_dir_name = "all-MiniLM-L6-v2";
    paths.push(PathBuf::from(format!(".weftos/models/{model_dir_name}/vocab.txt")));
    if let Ok(home) = std::env::var("HOME") {
        paths.push(PathBuf::from(format!("{home}/.weftos/models/{model_dir_name}/vocab.txt")));
    }
    if let Ok(env_dir) = std::env::var("WEFTOS_VOCAB_PATH") {
        paths.push(PathBuf::from(env_dir));
    }

    paths
}

fn simple_tokenize(text: &str, max_tokens: usize) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .take(max_tokens)
        .map(|s| s.chars().filter(|c| c.is_alphanumeric()).collect::<String>())
        .filter(|s| !s.is_empty())
        .collect()
}

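/// Deterministic hash-based pseudo-embedding: each `(token, position)` pair
/// is hashed with SHA-256, the 32 digest bytes are scattered across the
/// output dimensions, and contributions are scaled by `1/sqrt(token count)`
/// before a final l2-normalization. Cheap and deterministic, but it captures
/// lexical overlap only, not semantics.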
fn tokens_to_embedding(tokens: &[String], dims: usize) -> Vec<f32> {
    let mut embedding = vec![0.0f32; dims];

    for (i, token) in tokens.iter().enumerate() {
        let mut hasher = Sha256::new();
        hasher.update(token.as_bytes());
        hasher.update((i as u32).to_le_bytes());
        let hash = hasher.finalize();

        for (j, &byte) in hash.iter().enumerate() {
            let dim = (j + i * 32) % dims;
            // Map the byte to [-1, 1) and down-weight by sequence length.
            let val = (byte as f32 / 128.0) - 1.0;
            embedding[dim] += val / (tokens.len() as f32).sqrt();
        }
    }

    l2_normalize(&mut embedding);
    embedding
}

fn l2_normalize(vec: &mut [f32]) {
    let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        vec.iter_mut().for_each(|x| *x /= norm);
    }
}

#[cfg(test)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

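/// Embedding provider backed by an ONNX sentence-transformer model
/// (all-MiniLM-L6-v2 by default) with layered degradation: without the
/// `onnx-embeddings` feature or a loadable model it falls back to the
/// deterministic hash embedding; with a model but no `vocab.txt` it runs
/// ONNX inference over hash-derived token ids at reduced quality. The model
/// name gains a `-hash-fallback` suffix whenever the runtime is unavailable,
/// so callers can tell which path produced a vector.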
pub struct OnnxEmbeddingProvider {
    model_path: PathBuf,
    dimensions: usize,
    model_name: String,
    runtime_available: bool,
    max_tokens: usize,
    #[allow(dead_code)]
    fallback: MockEmbeddingProvider,
    #[cfg(feature = "onnx-embeddings")]
    tokenizer: Option<WordPieceTokenizer>,
    #[cfg(feature = "onnx-embeddings")]
    session: Option<Arc<ort::Session>>,
}
347
348impl OnnxEmbeddingProvider {
349 pub const DEFAULT_DIMS: usize = 384;
351 pub const DEFAULT_MAX_TOKENS: usize = 128;
353 pub const MODEL_NAME: &'static str = "all-MiniLM-L6-v2";
355
356 pub fn new(model_path: impl Into<PathBuf>) -> Self {
361 let model_path = model_path.into();
362 #[cfg(feature = "onnx-embeddings")]
363 let session = Self::try_load_session(&model_path);
364 #[cfg(feature = "onnx-embeddings")]
365 let runtime_available = session.is_some();
366 #[cfg(not(feature = "onnx-embeddings"))]
367 let runtime_available = false;
368 #[cfg(feature = "onnx-embeddings")]
369 let tokenizer = Self::try_load_tokenizer(&model_path, Self::DEFAULT_MAX_TOKENS);
370
371 Self {
372 model_name: if runtime_available {
373 Self::MODEL_NAME.to_string()
374 } else {
375 format!("{}-hash-fallback", Self::MODEL_NAME)
376 },
377 model_path,
378 dimensions: Self::DEFAULT_DIMS,
379 runtime_available,
380 max_tokens: Self::DEFAULT_MAX_TOKENS,
381 fallback: MockEmbeddingProvider::new(Self::DEFAULT_DIMS),
382 #[cfg(feature = "onnx-embeddings")]
383 tokenizer,
384 #[cfg(feature = "onnx-embeddings")]
385 session,
386 }
387 }
388
    pub fn with_config(
        model_path: impl Into<PathBuf>,
        dimensions: usize,
        max_tokens: usize,
    ) -> Self {
        let model_path = model_path.into();
        #[cfg(feature = "onnx-embeddings")]
        let session = Self::try_load_session(&model_path);
        #[cfg(feature = "onnx-embeddings")]
        let runtime_available = session.is_some();
        #[cfg(not(feature = "onnx-embeddings"))]
        let runtime_available = false;
        #[cfg(feature = "onnx-embeddings")]
        let tokenizer = Self::try_load_tokenizer(&model_path, max_tokens);

        Self {
            model_name: if runtime_available {
                Self::MODEL_NAME.to_string()
            } else {
                format!("{}-hash-fallback", Self::MODEL_NAME)
            },
            model_path,
            dimensions,
            runtime_available,
            max_tokens,
            fallback: MockEmbeddingProvider::new(dimensions),
            #[cfg(feature = "onnx-embeddings")]
            tokenizer,
            #[cfg(feature = "onnx-embeddings")]
            session,
        }
    }

    #[cfg(feature = "onnx-embeddings")]
    fn try_load_tokenizer(model_path: &Path, max_tokens: usize) -> Option<WordPieceTokenizer> {
        for path in vocab_search_paths(model_path) {
            if path.exists() {
                if let Some(tok) = WordPieceTokenizer::load(&path) {
                    // +2 leaves room for [CLS]/[SEP]; 512 is the BERT position limit.
                    let max_len = (max_tokens + 2).min(512);
                    return Some(tok.with_max_length(max_len));
                }
            }
        }
        tracing::debug!(
            "No vocab.txt found for WordPiece tokenizer near {}; \
             ONNX inference will use hash-based token IDs (degraded quality)",
            model_path.display()
        );
        None
    }

    #[cfg(feature = "onnx-embeddings")]
    fn try_load_session(model_path: &Path) -> Option<Arc<ort::Session>> {
        if !model_path.exists() {
            tracing::debug!("ONNX model not found at {}, using hash fallback", model_path.display());
            return None;
        }
        match ort::Session::builder()
            .and_then(|builder| builder.commit_from_file(model_path))
        {
            Ok(session) => {
                tracing::info!("ONNX session loaded from {}", model_path.display());
                Some(Arc::new(session))
            }
            Err(e) => {
                tracing::warn!("Failed to load ONNX session: {e}, using hash fallback");
                None
            }
        }
    }
465
466 pub fn is_runtime_available(&self) -> bool {
468 self.runtime_available
469 }
470
471 pub fn model_path(&self) -> &PathBuf {
473 &self.model_path
474 }
475
476 pub fn max_tokens(&self) -> usize {
478 self.max_tokens
479 }
480
    fn hash_embed(&self, text: &str) -> Vec<f32> {
        let tokens = simple_tokenize(text, self.max_tokens);
        if tokens.is_empty() {
            return vec![0.0f32; self.dimensions];
        }
        tokens_to_embedding(&tokens, self.dimensions)
    }

    #[cfg(feature = "onnx-embeddings")]
    fn onnx_embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        use ndarray::Array2;

        let session = self.session.as_ref().ok_or_else(|| {
            EmbeddingError::BackendError("ONNX session not loaded".to_string())
        })?;

        let (input_ids, attention_mask, token_type_ids) = if let Some(ref tokenizer) = self.tokenizer {
            tokenizer.encode(text).ok_or_else(|| {
                EmbeddingError::BackendError("WordPiece tokenizer returned None".to_string())
            })?
        } else {
            tracing::warn!(
                "ONNX inference without WordPiece vocab — embeddings will not be semantic"
            );
            let tokens = simple_tokenize(text, self.max_tokens);
            let mut ids = vec![CLS_ID];
            for token in &tokens {
                // Hash each token into the id range above the special tokens.
                let mut hasher = Sha256::new();
                hasher.update(token.as_bytes());
                let hash = hasher.finalize();
                let id = 1000
                    + (u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]]) % 29000)
                        as i64;
                ids.push(id);
            }
            ids.push(SEP_ID);

            // Derive the mask/type lengths from the ids actually produced so
            // all three sequences stay the same length.
            let seq_len = ids.len();
            let mask = vec![1i64; seq_len];
            let types = vec![0i64; seq_len];
            (ids, mask, types)
        };

        let seq_len = input_ids.len();

        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;
        let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask.clone())
            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;
        let token_type_ids_arr = Array2::from_shape_vec((1, seq_len), token_type_ids)
            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;

        let inputs = ort::inputs![
            "input_ids" => input_ids_arr,
            "attention_mask" => attention_mask_arr,
            "token_type_ids" => token_type_ids_arr,
        ].map_err(|e| EmbeddingError::BackendError(format!("input error: {e}")))?;

        let outputs = session.run(inputs)
            .map_err(|e| EmbeddingError::BackendError(format!("inference error: {e}")))?;

        // Prefer the conventional output name, else fall back to the first output.
        let output_tensor = outputs.get("last_hidden_state")
            .or_else(|| outputs.iter().next().map(|(_, v)| v))
            .ok_or_else(|| EmbeddingError::BackendError("no output tensor".to_string()))?;

        let tensor = output_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::BackendError(format!("extract error: {e}")))?;

        let shape = tensor.shape();
        if shape.len() < 2 {
            return Err(EmbeddingError::BackendError(
                format!("unexpected output shape: {shape:?}"),
            ));
        }
        let hidden_dim = *shape.last().unwrap();
        let seq = shape[1];

        let mut embedding = vec![0.0f32; hidden_dim];
        let data = tensor.as_slice().ok_or_else(|| {
            EmbeddingError::BackendError("tensor not contiguous".to_string())
        })?;

        // Mean pooling over the non-padded positions only.
        let mut active_count: f32 = 0.0;
        for s in 0..seq {
            let mask_val = if s < attention_mask.len() {
                attention_mask[s] as f32
            } else {
                0.0
            };
            if mask_val > 0.0 {
                for d in 0..hidden_dim {
                    embedding[d] += data[s * hidden_dim + d];
                }
                active_count += 1.0;
            }
        }
        if active_count > 0.0 {
            for val in &mut embedding {
                *val /= active_count;
            }
        }

        // Fit to the requested dimensionality first, then normalize, so the
        // returned vector is unit length even when truncated.
        embedding.resize(self.dimensions, 0.0);
        l2_normalize(&mut embedding);

        Ok(embedding)
    }
}

#[async_trait]
impl EmbeddingProvider for OnnxEmbeddingProvider {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        #[cfg(feature = "onnx-embeddings")]
        if self.runtime_available {
            return self.onnx_embed(text);
        }
        Ok(self.hash_embed(text))
    }

    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }
}

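/// Paragraph-level provider: strips markdown structure, optionally splits
/// the text into sentences, embeds each sentence, and mean-pools the
/// sentence vectors into one l2-normalized embedding. Its per-sentence
/// `embed_text` path currently always uses the inner provider's hash
/// embedding, never ONNX inference.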
pub struct SentenceTransformerProvider {
    base: OnnxEmbeddingProvider,
    max_tokens: usize,
    split_sentences: bool,
}

impl SentenceTransformerProvider {
    pub const DEFAULT_MAX_TOKENS: usize = 512;

    pub fn new(model_path: impl Into<PathBuf>) -> Self {
        Self::with_config(model_path, Self::DEFAULT_MAX_TOKENS, true)
    }

    pub fn with_config(
        model_path: impl Into<PathBuf>,
        max_tokens: usize,
        split_sentences: bool,
    ) -> Self {
        Self {
            base: OnnxEmbeddingProvider::with_config(
                model_path,
                OnnxEmbeddingProvider::DEFAULT_DIMS,
                max_tokens,
            ),
            max_tokens,
            split_sentences,
        }
    }

    pub fn split_sentences(&self) -> bool {
        self.split_sentences
    }

    pub fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    fn embed_text(&self, text: &str) -> Vec<f32> {
        // Hash path only; ONNX inference is not wired into this provider.
        self.base.hash_embed(text)
    }
}

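/// Flattens markdown to plain prose by dropping header, code-fence, table,
/// and checklist lines, then joining what remains with spaces. Note that
/// only the fence lines are dropped; code inside a fence survives. A minimal
/// sketch of the effect:
///
/// ```ignore
/// let md = "# Title\nBody text.\n```rust\nlet x = 1;\n```";
/// assert_eq!(preprocess_markdown(md), "Body text. let x = 1;");
/// ```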
pub fn preprocess_markdown(text: &str) -> String {
    text.lines()
        .filter(|l| !l.starts_with('#'))    // headers
        .filter(|l| !l.starts_with("```"))  // code fences
        .filter(|l| !l.starts_with('|'))    // table rows
        .filter(|l| !l.starts_with("- ["))  // checklist items
        .map(|l| l.trim())
        .filter(|l| !l.is_empty())
        .collect::<Vec<_>>()
        .join(" ")
}

pub fn split_sentences(text: &str) -> Vec<&str> {
    text.split(". ")
        .flat_map(|s| s.split('\n'))
        .map(|s| s.trim())
        .filter(|s| s.len() > 10) // drop fragments too short to embed usefully
        .collect()
}

#[async_trait]
impl EmbeddingProvider for SentenceTransformerProvider {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let cleaned = preprocess_markdown(text);

        if !self.split_sentences {
            return Ok(self.embed_text(&cleaned));
        }

        let sentences = split_sentences(&cleaned);
        if sentences.is_empty() {
            return Ok(self.embed_text(&cleaned));
        }

        // Mean-pool the per-sentence embeddings, then re-normalize.
        let dims = self.base.dimensions;
        let mut summed = vec![0.0f32; dims];
        let count = sentences.len() as f32;

        for sentence in &sentences {
            let vec = self.embed_text(sentence);
            for (i, val) in vec.iter().enumerate() {
                summed[i] += val;
            }
        }

        summed.iter_mut().for_each(|x| *x /= count);
        l2_normalize(&mut summed);

        Ok(summed)
    }

    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    fn dimensions(&self) -> usize {
        self.base.dimensions
    }

    fn model_name(&self) -> &str {
        "sentence-transformer"
    }
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RustCodeFeatures {
    pub signature: Option<String>,
    pub return_type: Option<String>,
    pub param_types: Vec<String>,
    pub visibility: String,
    pub is_async: bool,
    pub is_generic: bool,
    pub trait_bounds: Vec<String>,
    pub attributes: Vec<String>,
    pub item_kind: String,
}

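/// Extracts coarse structural features from a Rust snippet using plain
/// string matching (no real parser), so results are heuristic. A minimal
/// sketch of the output:
///
/// ```ignore
/// let f = extract_rust_features("pub async fn run(&self) -> anyhow::Result<()> {}");
/// assert_eq!(f.item_kind, "fn");
/// assert_eq!(f.visibility, "pub");
/// assert!(f.is_async);
/// ```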
pub fn extract_rust_features(code: &str) -> RustCodeFeatures {
    let mut features = RustCodeFeatures::default();

    // Detection order matters: an `impl` block usually contains `fn `, so the
    // `fn` arm wins when both appear. `contains("trait ")` already covers
    // `pub trait `, and likewise for the other kinds.
    if code.contains("trait ") {
        features.item_kind = "trait".into();
    } else if code.contains("struct ") {
        features.item_kind = "struct".into();
    } else if code.contains("enum ") {
        features.item_kind = "enum".into();
    } else if code.contains("fn ") {
        features.item_kind = "fn".into();
    } else if code.contains("impl ") {
        features.item_kind = "impl".into();
    } else if code.contains("mod ") {
        features.item_kind = "mod".into();
    }

    features.visibility = if code.contains("pub(crate)") {
        "pub(crate)".into()
    } else if code.contains("pub(super)") {
        "pub(super)".into()
    } else if code.contains("pub ") {
        "pub".into()
    } else {
        "private".into()
    };

    features.is_async = code.contains("async fn");

    // Heuristic: the presence of both '<' and '>' anywhere in the snippet,
    // which also fires on generic types in parameter lists (e.g. `Vec<Item>`).
    features.is_generic = code.contains('<') && code.contains('>');

    // The first line that looks like an item header becomes the signature.
    for line in code.lines() {
        let trimmed = line.trim();
        if trimmed.contains("fn ")
            || trimmed.starts_with("pub struct ")
            || trimmed.starts_with("struct ")
            || trimmed.starts_with("pub enum ")
            || trimmed.starts_with("enum ")
            || trimmed.starts_with("pub trait ")
            || trimmed.starts_with("trait ")
        {
            let sig = if let Some(brace) = trimmed.find('{') {
                trimmed[..brace].trim()
            } else {
                trimmed.trim_end_matches(';').trim()
            };
            features.signature = Some(sig.to_string());
            break;
        }
    }

    // Return type: the text between `->` and the following `{` or `;`.
    if let Some(arrow) = code.find("->") {
        let after = &code[arrow + 2..];
        if let Some(brace) = after.find('{') {
            features.return_type = Some(after[..brace].trim().to_string());
        } else if let Some(semi) = after.find(';') {
            features.return_type = Some(after[..semi].trim().to_string());
        }
    }

    // Parameter types from the first parenthesized list (fn items only).
    if features.item_kind == "fn"
        && let Some(open) = code.find('(')
        && let Some(close) = code.find(')')
        && close > open
    {
        let params = &code[open + 1..close];
        for param in params.split(',') {
            let param = param.trim();
            if param == "&self" || param == "&mut self" || param == "self" {
                features.param_types.push(param.to_string());
            } else if let Some(colon) = param.find(':') {
                let ty = param[colon + 1..].trim().to_string();
                if !ty.is_empty() {
                    features.param_types.push(ty);
                }
            }
        }
    }

    // Trait bounds from a `where` clause, up to the opening brace.
    if let Some(where_idx) = code.find("where") {
        let after = &code[where_idx + 5..];
        let end = after.find('{').unwrap_or(after.len());
        let clause = &after[..end];
        for bound in clause.split(',') {
            let bound = bound.trim();
            if !bound.is_empty() {
                features.trait_bounds.push(bound.to_string());
            }
        }
    }

    // Collect outer attributes like `#[test]` or `#[cfg(...)]`.
    for line in code.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with("#[") {
            features.attributes.push(trimmed.to_string());
        }
    }

    features
}

pub struct AstEmbeddingProvider {
    text_provider: OnnxEmbeddingProvider,
    structural_dims: usize,
    total_dims: usize,
    structural_weight: f32,
}

impl AstEmbeddingProvider {
    pub const DEFAULT_TOTAL_DIMS: usize = 256;
    pub const DEFAULT_STRUCTURAL_DIMS: usize = 64;
    pub const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.3;

    pub fn new(model_path: impl Into<PathBuf>) -> Self {
        Self::with_config(
            model_path,
            Self::DEFAULT_TOTAL_DIMS,
            Self::DEFAULT_STRUCTURAL_DIMS,
            Self::DEFAULT_STRUCTURAL_WEIGHT,
        )
    }

    pub fn with_config(
        model_path: impl Into<PathBuf>,
        total_dims: usize,
        structural_dims: usize,
        structural_weight: f32,
    ) -> Self {
        assert!(
            structural_dims < total_dims,
            "structural_dims must be less than total_dims"
        );
        let text_dims = total_dims - structural_dims;
        Self {
            text_provider: OnnxEmbeddingProvider::with_config(
                model_path,
                text_dims,
                OnnxEmbeddingProvider::DEFAULT_MAX_TOKENS,
            ),
            structural_dims,
            total_dims,
            structural_weight: structural_weight.clamp(0.0, 1.0),
        }
    }

    pub fn total_dims(&self) -> usize {
        self.total_dims
    }

    pub fn structural_weight(&self) -> f32 {
        self.structural_weight
    }

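    /// Hashes each extracted feature into a fixed slot range of the
    /// structural sub-vector (kind at offset 0, visibility at 8, async flag
    /// at 14, generics at 18, return type at 22, params from 30, attributes
    /// from 48), so items with the same shape collide in the same dimensions
    /// and score as structurally similar.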
    fn encode_structural(&self, features: &RustCodeFeatures) -> Vec<f32> {
        let dims = self.structural_dims;
        let mut vec = vec![0.0f32; dims];

        let mut write_hash = |label: &str, offset: usize, slots: usize| {
            let mut hasher = Sha256::new();
            hasher.update(label.as_bytes());
            let hash = hasher.finalize();
            for (j, &byte) in hash.iter().enumerate().take(slots.min(32)) {
                let dim = (offset + j) % dims;
                vec[dim] += (byte as f32 / 128.0) - 1.0;
            }
        };

        write_hash(&format!("kind:{}", features.item_kind), 0, 8);
        write_hash(&format!("vis:{}", features.visibility), 8, 6);

        if features.is_async {
            write_hash("async:true", 14, 4);
        }
        if features.is_generic {
            write_hash("generic:true", 18, 4);
        }
        if let Some(ref rt) = features.return_type {
            write_hash(&format!("ret:{rt}"), 22, 8);
        }
        for (i, pt) in features.param_types.iter().enumerate() {
            write_hash(&format!("param{i}:{pt}"), 30 + i * 6, 6);
        }
        for (i, attr) in features.attributes.iter().enumerate() {
            write_hash(&format!("attr{i}:{attr}"), 48 + i * 4, 4);
        }

        l2_normalize(&mut vec);
        vec
    }

    fn hybrid_embed(&self, code: &str) -> Vec<f32> {
        let features = extract_rust_features(code);
        let structural = self.encode_structural(&features);
        let text = self.text_provider.hash_embed(code);

        let w_s = self.structural_weight;
        let w_t = 1.0 - w_s;

        // Concatenate [weighted structural | weighted text], then fit to
        // total_dims and re-normalize.
        let mut combined = Vec::with_capacity(self.total_dims);
        for val in &structural {
            combined.push(val * w_s);
        }
        for val in &text {
            combined.push(val * w_t);
        }

        combined.resize(self.total_dims, 0.0);
        l2_normalize(&mut combined);
        combined
    }
}

#[async_trait]
impl EmbeddingProvider for AstEmbeddingProvider {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        Ok(self.hybrid_embed(text))
    }

    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        Ok(texts.iter().map(|t| self.hybrid_embed(t)).collect())
    }

    fn dimensions(&self) -> usize {
        self.total_dims
    }

    fn model_name(&self) -> &str {
        "ast-aware-hybrid"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn vec_magnitude(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn onnx_construction_default() {
        let p = OnnxEmbeddingProvider::new("/nonexistent/model.onnx");
        assert_eq!(p.dimensions(), 384);
        assert!(!p.is_runtime_available());
        assert!(p.model_name().contains("fallback"));
    }

    #[test]
    fn onnx_construction_custom() {
        let p = OnnxEmbeddingProvider::with_config("/tmp/model.onnx", 128, 64);
        assert_eq!(p.dimensions(), 128);
        assert_eq!(p.max_tokens(), 64);
    }

    #[tokio::test]
    async fn onnx_embed_returns_correct_dimensions() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let vec = p.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 384);
    }

    #[tokio::test]
    async fn onnx_embed_deterministic() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let v1 = p.embed("test input").await.unwrap();
        let v2 = p.embed("test input").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    async fn onnx_embed_different_inputs_differ() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let v1 = p.embed("alpha").await.unwrap();
        let v2 = p.embed("beta").await.unwrap();
        assert_ne!(v1, v2);
    }

    #[tokio::test]
    async fn onnx_embed_l2_normalized() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let vec = p.embed("normalisation check").await.unwrap();
        let mag = vec_magnitude(&vec);
        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
    }

    #[tokio::test]
    async fn onnx_embed_batch() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let results = p.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 384);
        }
    }

    #[tokio::test]
    async fn onnx_similar_inputs_high_cosine() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let v1 = p.embed("the quick brown fox").await.unwrap();
        let v2 = p.embed("the quick brown dog").await.unwrap();
        let sim = cosine_similarity(&v1, &v2);
        assert!(sim > 0.5, "similar inputs cosine = {sim}, expected > 0.5");
    }

    #[tokio::test]
    async fn onnx_empty_input_returns_zero_vector() {
        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
        let vec = p.embed("").await.unwrap();
        assert_eq!(vec.len(), 384);
        assert!(vec.iter().all(|x| *x == 0.0));
    }

    #[test]
    fn sentence_construction() {
        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
        assert_eq!(p.dimensions(), 384);
        assert_eq!(p.max_tokens(), 512);
        assert!(p.split_sentences());
        assert_eq!(p.model_name(), "sentence-transformer");
    }

    #[tokio::test]
    async fn sentence_embed_returns_correct_dimensions() {
        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
        let vec = p.embed("This is a test paragraph with enough words.").await.unwrap();
        assert_eq!(vec.len(), 384);
    }

    #[tokio::test]
    async fn sentence_embed_l2_normalized() {
        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
        let vec = p.embed("Testing normalisation of sentence embeddings here.").await.unwrap();
        let mag = vec_magnitude(&vec);
        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
    }

    #[tokio::test]
    async fn sentence_embed_batch() {
        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
        let results = p
            .embed_batch(&[
                "First paragraph with a decent amount of words in it.",
                "Second paragraph also has a reasonable length for testing.",
            ])
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        for v in &results {
            assert_eq!(v.len(), 384);
        }
    }

    #[tokio::test]
    async fn sentence_similar_inputs_positive_cosine() {
        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
        let v1 = p.embed("The kernel boots up the system and runs all the services correctly.").await.unwrap();
        let v2 = p.embed("The kernel boots up the system and runs all the services properly.").await.unwrap();
        let v3 = p.embed("Quantum chromodynamics explains the strong interaction between quarks.").await.unwrap();
        let sim_similar = cosine_similarity(&v1, &v2);
        let sim_different = cosine_similarity(&v1, &v3);
        assert!(
            sim_similar > sim_different,
            "similar ({sim_similar}) should be closer than different ({sim_different})"
        );
    }

    #[test]
    fn preprocess_markdown_strips_headers() {
        let md = "# Title\nSome text.\n## Subtitle\nMore text.";
        let result = preprocess_markdown(md);
        assert!(!result.contains("Title"));
        assert!(!result.contains("Subtitle"));
        assert!(result.contains("Some text."));
        assert!(result.contains("More text."));
    }

    #[test]
    fn preprocess_markdown_strips_code_fences() {
        let md = "Before.\n```rust\nlet x = 1;\n```\nAfter.";
        let result = preprocess_markdown(md);
        assert!(!result.contains("```"));
        assert!(result.contains("Before."));
        assert!(result.contains("After."));
    }

    #[test]
    fn preprocess_markdown_strips_tables() {
        let md = "Intro.\n| Col1 | Col2 |\n|------|------|\n| A | B |\nOutro.";
        let result = preprocess_markdown(md);
        assert!(!result.contains("Col1"));
        assert!(result.contains("Intro."));
        assert!(result.contains("Outro."));
    }

    #[test]
    fn preprocess_markdown_strips_checklists() {
        let md = "Text here.\n- [x] Done item\n- [ ] Todo item\nMore text.";
        let result = preprocess_markdown(md);
        assert!(!result.contains("Done item"));
        assert!(result.contains("Text here."));
    }

    #[test]
    fn split_sentences_basic() {
        let text = "First sentence here. Second sentence here. Third.";
        let sentences = split_sentences(text);
        // "Third." falls under the 10-character minimum and is dropped.
        assert_eq!(sentences.len(), 2);
        assert!(sentences[0].contains("First"));
        assert!(sentences[1].contains("Second"));
    }

    #[test]
    fn split_sentences_newlines() {
        let text = "Line one is long enough.\nLine two is also long enough.";
        let sentences = split_sentences(text);
        assert_eq!(sentences.len(), 2);
    }

    #[test]
    fn ast_construction_default() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        assert_eq!(p.dimensions(), 256);
        assert_eq!(p.total_dims(), 256);
        assert!((p.structural_weight() - 0.3).abs() < 0.001);
        assert_eq!(p.model_name(), "ast-aware-hybrid");
    }

    #[tokio::test]
    async fn ast_embed_returns_correct_dimensions() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        let vec = p.embed("pub fn hello() -> String { }").await.unwrap();
        assert_eq!(vec.len(), 256);
    }

    #[tokio::test]
    async fn ast_embed_l2_normalized() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        let vec = p.embed("pub fn hello() -> String { }").await.unwrap();
        let mag = vec_magnitude(&vec);
        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
    }

    #[tokio::test]
    async fn ast_embed_batch() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        let results = p
            .embed_batch(&["fn a() {}", "fn b() {}", "struct C {}"])
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 256);
        }
    }

    #[tokio::test]
    async fn ast_embed_different_inputs_differ() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        let v1 = p.embed("pub fn alpha() -> u32 {}").await.unwrap();
        let v2 = p.embed("struct Beta { x: f64 }").await.unwrap();
        assert_ne!(v1, v2);
    }

    #[tokio::test]
    async fn ast_structural_similarity_same_signature() {
        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
        let v_foo = p
            .embed("pub async fn foo(&self, x: u32) -> Result<(), Error> {}")
            .await
            .unwrap();
        let v_bar = p
            .embed("pub async fn bar(&self, x: u32) -> Result<(), Error> {}")
            .await
            .unwrap();
        let v_struct = p.embed("pub struct Baz { count: usize }").await.unwrap();

        let sim_fns = cosine_similarity(&v_foo, &v_bar);
        let sim_fn_struct = cosine_similarity(&v_foo, &v_struct);
        assert!(
            sim_fns > sim_fn_struct,
            "same-signature fns ({sim_fns}) should be more similar than fn-vs-struct ({sim_fn_struct})"
        );
    }

    #[test]
    fn rust_features_pub_async_fn() {
        let code = r#"
#[test]
pub async fn process_batch(&self, items: Vec<Item>) -> Result<(), Error> {
    // body
}
"#;
        let f = extract_rust_features(code);
        assert_eq!(f.item_kind, "fn");
        assert_eq!(f.visibility, "pub");
        assert!(f.is_async);
        assert!(f.is_generic);
        assert_eq!(f.return_type.as_deref(), Some("Result<(), Error>"));
        assert!(f.attributes.contains(&"#[test]".to_string()));
        assert!(f.param_types.contains(&"&self".to_string()));
        assert!(f.param_types.iter().any(|p| p.contains("Vec<Item>")));
    }

    #[test]
    fn rust_features_struct() {
        let code = "pub struct Config { pub name: String, pub value: u64 }";
        let f = extract_rust_features(code);
        assert_eq!(f.item_kind, "struct");
        assert_eq!(f.visibility, "pub");
        assert!(!f.is_async);
        assert!(!f.is_generic);
        assert!(f.return_type.is_none());
    }

    #[test]
    fn rust_features_private_fn() {
        let code = "fn helper(x: &str) -> bool { true }";
        let f = extract_rust_features(code);
        assert_eq!(f.item_kind, "fn");
        assert_eq!(f.visibility, "private");
        assert!(!f.is_async);
        assert_eq!(f.return_type.as_deref(), Some("bool"));
        assert!(f.param_types.iter().any(|p| p.contains("&str")));
    }

    #[test]
    fn rust_features_enum() {
        let code = "pub enum Status { Active, Inactive, Pending }";
        let f = extract_rust_features(code);
        assert_eq!(f.item_kind, "enum");
        assert_eq!(f.visibility, "pub");
    }

    #[test]
    fn rust_features_trait() {
        let code = "pub trait Displayable { fn display(&self) -> String; }";
        let f = extract_rust_features(code);
        assert_eq!(f.item_kind, "trait");
        assert_eq!(f.visibility, "pub");
    }

    #[test]
    fn rust_features_impl_block() {
        let code = "impl MyStruct { fn new() -> Self { Self {} } }";
        let f = extract_rust_features(code);
        // The inner `fn ` is detected before `impl ` by the kind ordering.
        assert_eq!(f.item_kind, "fn");
    }

    #[test]
    fn rust_features_where_clause() {
        let code = "pub fn serialize<T>(val: T) -> String where T: Serialize + Debug { }";
        let f = extract_rust_features(code);
        assert!(f.is_generic);
        assert!(!f.trait_bounds.is_empty());
        assert!(f.trait_bounds.iter().any(|b| b.contains("Serialize")));
    }

    #[test]
    fn rust_features_pub_crate() {
        let code = "pub(crate) fn internal_helper() {}";
        let f = extract_rust_features(code);
        assert_eq!(f.visibility, "pub(crate)");
    }

    #[test]
    fn rust_features_multiple_attributes() {
        let code = "#[cfg(test)]\n#[allow(dead_code)]\nfn test_fn() {}";
        let f = extract_rust_features(code);
        assert_eq!(f.attributes.len(), 2);
        assert!(f.attributes.contains(&"#[cfg(test)]".to_string()));
        assert!(f.attributes.contains(&"#[allow(dead_code)]".to_string()));
    }

    #[test]
    fn simple_tokenize_basic() {
        let tokens = simple_tokenize("Hello World! Foo-bar", 10);
        assert_eq!(tokens, vec!["hello", "world", "foobar"]);
    }

    #[test]
    fn simple_tokenize_max_tokens() {
        let tokens = simple_tokenize("a b c d e f", 3);
        assert_eq!(tokens.len(), 3);
    }

    #[test]
    fn simple_tokenize_empty() {
        let tokens = simple_tokenize("", 10);
        assert!(tokens.is_empty());
    }

    #[test]
    fn tokens_to_embedding_deterministic() {
        let tokens: Vec<String> = vec!["hello".into(), "world".into()];
        let v1 = tokens_to_embedding(&tokens, 64);
        let v2 = tokens_to_embedding(&tokens, 64);
        assert_eq!(v1, v2);
    }

    #[test]
    fn tokens_to_embedding_normalized() {
        let tokens: Vec<String> = vec!["test".into()];
        let v = tokens_to_embedding(&tokens, 128);
        let mag = vec_magnitude(&v);
        assert!((mag - 1.0).abs() < 0.01);
    }

    fn make_test_vocab() -> PathBuf {
        use std::fmt::Write as FmtWrite;
        let mut content = String::new();
        // Mirror the bert-base-uncased layout: [PAD] at id 0, unused slots,
        // then [UNK]=100, [CLS]=101, [SEP]=102, [MASK]=103.
        writeln!(content, "[PAD]").unwrap();
        for i in 1..100 {
            writeln!(content, "[unused{}]", i).unwrap();
        }
        writeln!(content, "[UNK]").unwrap();
        writeln!(content, "[CLS]").unwrap();
        writeln!(content, "[SEP]").unwrap();
        writeln!(content, "[MASK]").unwrap();
        for i in 104..1000 {
            writeln!(content, "[unused{}]", i).unwrap();
        }
        let words = [
            "the", "a", "is", "of", "and", "to", "in", "for", "that", "it",
            "hello", "world", "test", "input", "embedding", "model", "token",
            "##s", "##ing", "##ed", "##er", "##tion", "##ly", "##ize",
            ".", ",", "!", "?",
            "quick", "brown", "fox", "dog", "cat", "rust", "code",
            "function", "struct", "pub", "async", "fn",
        ];
        for w in &words {
            writeln!(content, "{}", w).unwrap();
        }
        // Pad past the 1000-entry minimum enforced by WordPieceTokenizer::load.
        for i in 0..100 {
            writeln!(content, "extra{}", i).unwrap();
        }

        let path = PathBuf::from(format!(
            "/tmp/clawft_test_vocab_{}.txt",
            std::process::id()
        ));
        std::fs::write(&path, &content).expect("failed to write test vocab");
        path
    }

    #[test]
    fn wordpiece_load_valid_vocab() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f);
        assert!(tok.is_some(), "should load a vocab with >1000 entries");
    }

    #[test]
    fn wordpiece_load_missing_file() {
        let tok = WordPieceTokenizer::load(Path::new("/nonexistent/vocab.txt"));
        assert!(tok.is_none());
    }

    #[test]
    fn wordpiece_encode_produces_cls_sep() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(32);
        let (ids, mask, types) = tok.encode("hello world").unwrap();
        assert_eq!(ids.len(), 32, "should be padded to max_length");
        assert_eq!(ids[0], CLS_ID, "first token must be [CLS]");
        let sep_pos = ids.iter().position(|&x| x == SEP_ID);
        assert!(sep_pos.is_some(), "must contain [SEP]");
        let sep_pos = sep_pos.unwrap();
        assert!(sep_pos >= 2, "[SEP] should come after at least one content token");
        assert_eq!(mask[0], 1);
        assert_eq!(mask[sep_pos], 1);
        if sep_pos + 1 < 32 {
            assert_eq!(mask[sep_pos + 1], 0, "padding should have mask=0");
        }
        assert!(types.iter().all(|&t| t == 0));
    }

    #[test]
    fn wordpiece_encode_known_tokens() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
        let (ids, _, _) = tok.encode("hello").unwrap();
        let content_ids: Vec<i64> = ids[1..].iter()
            .take_while(|&&x| x != SEP_ID)
            .cloned()
            .collect();
        assert!(!content_ids.is_empty(), "should tokenize 'hello' to at least one token");
        assert!(
            content_ids.iter().any(|&id| id != UNK_ID),
            "known word 'hello' should not be all [UNK]"
        );
    }

    #[test]
    fn wordpiece_encode_unknown_token_uses_unk() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
        let (ids, _, _) = tok.encode("xyzzyplugh").unwrap();
        let content_ids: Vec<i64> = ids[1..].iter()
            .take_while(|&&x| x != SEP_ID)
            .cloned()
            .collect();
        assert!(
            content_ids.contains(&UNK_ID),
            "unknown word should produce [UNK] token"
        );
    }

    #[test]
    fn wordpiece_encode_truncates_long_input() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(8);
        let long_input = "the quick brown fox hello world test input embedding model";
        let (ids, mask, _) = tok.encode(long_input).unwrap();
        assert_eq!(ids.len(), 8, "output must be exactly max_length");
        assert_eq!(mask.len(), 8);
        assert_eq!(ids[0], CLS_ID);
        assert!(ids.contains(&SEP_ID));
    }

    #[test]
    fn wordpiece_encode_empty_input() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
        let (ids, mask, _) = tok.encode("").unwrap();
        assert_eq!(ids[0], CLS_ID);
        assert_eq!(ids[1], SEP_ID);
        assert!(ids[2..].iter().all(|&x| x == PAD_ID));
        assert_eq!(mask[0], 1);
        assert_eq!(mask[1], 1);
        assert!(mask[2..].iter().all(|&x| x == 0));
    }

    #[test]
    fn wordpiece_pre_tokenize_punctuation() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap();
        let words = tok.pre_tokenize("Hello, World!");
        assert!(words.contains(&",".to_string()));
        assert!(words.contains(&"!".to_string()));
        assert!(words.contains(&"hello".to_string()));
        assert!(words.contains(&"world".to_string()));
    }

    #[test]
    fn wordpiece_subword_splitting() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap();
        let ids = tok.wordpiece_split("tokens");
        assert!(
            !ids.is_empty(),
            "should produce at least one subword token"
        );
    }

    #[test]
    fn wordpiece_deterministic() {
        let f = make_test_vocab();
        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(32);
        let (ids1, _, _) = tok.encode("the quick brown fox").unwrap();
        let (ids2, _, _) = tok.encode("the quick brown fox").unwrap();
        assert_eq!(ids1, ids2, "encoding must be deterministic");
    }

    #[test]
    fn vocab_search_paths_finds_sibling() {
        let paths = vocab_search_paths(Path::new("/models/all-MiniLM-L6-v2.onnx"));
        assert!(paths.iter().any(|p| p.ends_with("vocab.txt")));
        assert!(
            paths.iter().any(|p| p.to_string_lossy().contains("all-MiniLM-L6-v2/vocab.txt")),
            "should check sibling directory: {:?}",
            paths
        );
    }
}