Skip to main content

clawft_kernel/
embedding_onnx.rs

1//! ONNX, sentence-transformer, and AST-aware embedding backends (K3c-G2).
2//!
3//! These are alternative [`EmbeddingProvider`] implementations that complement
4//! the existing [`LlmEmbeddingProvider`] and [`MockEmbeddingProvider`].
5//!
6//! - [`OnnxEmbeddingProvider`] -- local model inference (all-MiniLM-L6-v2, 384-d).
7//! - [`SentenceTransformerProvider`] -- documentation-optimised paragraph embedder.
8//! - [`AstEmbeddingProvider`] -- hybrid structural + semantic embedder for Rust code.
9
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12#[cfg(feature = "onnx-embeddings")]
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19use crate::embedding::{EmbeddingError, EmbeddingProvider, MockEmbeddingProvider};
20
21// ---------------------------------------------------------------------------
22// WordPiece tokenizer for BERT / all-MiniLM-L6-v2
23// ---------------------------------------------------------------------------
24
/// BERT-compatible WordPiece tokenizer.
///
/// Loads a `vocab.txt` file (one token per line, ID = zero-based line number)
/// and performs:
/// 1. Lowercasing and removal of Unicode combining marks that are already
///    present in the input. No NFD normalisation is applied, so precomposed
///    characters such as `é` are kept as-is rather than decomposed first.
/// 2. Whitespace + punctuation pre-tokenization (ASCII punctuation and CJK
///    ideographs each become a standalone token).
/// 3. Greedy longest-match WordPiece splitting (continuation subwords are
///    prefixed with `##`).
/// 4. [CLS] / [SEP] framing and [PAD] / truncation to `max_length`.
///
/// [`WordPieceTokenizer::load`] returns `None` when no usable vocab file is
/// available, so callers can fall back to hash-based tokenization.
pub struct WordPieceTokenizer {
    /// Token string -> vocab ID.
    vocab: HashMap<String, i64>,
    /// Maximum sequence length including [CLS] and [SEP].
    max_length: usize,
    /// Maximum characters per word before treating it as [UNK].
    max_word_chars: usize,
}
43
// Special token IDs for the BERT uncased vocabulary. These are fixed
// constants; they are not looked up in the loaded `vocab.txt`.

/// `[CLS]` token ID, placed at the start of every encoded sequence.
const CLS_ID: i64 = 101;
/// `[SEP]` token ID, placed after the last content token.
const SEP_ID: i64 = 102;
/// `[UNK]` token ID, used for words that cannot be split into vocab pieces.
const UNK_ID: i64 = 100;
/// `[PAD]` token ID, used to fill sequences up to `max_length`.
const PAD_ID: i64 = 0;
49
50impl WordPieceTokenizer {
51    /// Try to load a `vocab.txt` from the given path.
52    ///
53    /// Returns `None` if the file does not exist or cannot be read.
54    pub fn load(vocab_path: &Path) -> Option<Self> {
55        let content = std::fs::read_to_string(vocab_path).ok()?;
56        let line_count = content.lines().count();
57        let mut vocab = HashMap::with_capacity(line_count);
58        for (id, line) in content.lines().enumerate() {
59            vocab.insert(line.to_string(), id as i64);
60        }
61        if vocab.len() < 1000 {
62            // Suspiciously small — probably not a real BERT vocab.
63            tracing::warn!(
64                "vocab.txt at {} has only {} entries, expected ~30k",
65                vocab_path.display(),
66                vocab.len()
67            );
68            return None;
69        }
70        tracing::info!(
71            "WordPiece vocab loaded: {} tokens from {}",
72            vocab.len(),
73            vocab_path.display()
74        );
75        Some(Self {
76            vocab,
77            max_length: 128,
78            max_word_chars: 100,
79        })
80    }
81
82    /// Create a tokenizer with a custom max sequence length.
83    pub fn with_max_length(mut self, max_length: usize) -> Self {
84        self.max_length = max_length;
85        self
86    }
87
88    /// Encode text into (input_ids, attention_mask, token_type_ids).
89    ///
90    /// All three vectors have length `self.max_length`, padded or truncated
91    /// as needed. Returns `None` if the vocab is empty.
92    pub fn encode(&self, text: &str) -> Option<(Vec<i64>, Vec<i64>, Vec<i64>)> {
93        if self.vocab.is_empty() {
94            return None;
95        }
96
97        let mut token_ids: Vec<i64> = Vec::with_capacity(self.max_length);
98        token_ids.push(CLS_ID);
99
100        // Pre-tokenize: lowercase, split on whitespace and punctuation.
101        let words = self.pre_tokenize(text);
102
103        for word in &words {
104            if token_ids.len() >= self.max_length - 1 {
105                break; // Reserve space for [SEP].
106            }
107            let sub_ids = self.wordpiece_split(word);
108            for id in sub_ids {
109                if token_ids.len() >= self.max_length - 1 {
110                    break;
111                }
112                token_ids.push(id);
113            }
114        }
115
116        token_ids.push(SEP_ID);
117
118        let seq_len = token_ids.len();
119        let mut attention_mask = vec![1i64; seq_len];
120        let mut token_type_ids = vec![0i64; seq_len];
121
122        // Pad to max_length.
123        while token_ids.len() < self.max_length {
124            token_ids.push(PAD_ID);
125            attention_mask.push(0);
126            token_type_ids.push(0);
127        }
128
129        Some((token_ids, attention_mask, token_type_ids))
130    }
131
132    /// Pre-tokenize: lowercase, split on whitespace and punctuation.
133    fn pre_tokenize(&self, text: &str) -> Vec<String> {
134        let lower = text.to_lowercase();
135        let mut words = Vec::new();
136        let mut current = String::new();
137
138        for ch in lower.chars() {
139            if ch.is_whitespace() {
140                if !current.is_empty() {
141                    words.push(std::mem::take(&mut current));
142                }
143            } else if ch.is_ascii_punctuation() || is_cjk_char(ch) {
144                // Punctuation and CJK chars become individual tokens.
145                if !current.is_empty() {
146                    words.push(std::mem::take(&mut current));
147                }
148                words.push(ch.to_string());
149            } else if is_accent_char(ch) {
150                // Strip combining marks (basic accent removal).
151                continue;
152            } else if ch.is_control() {
153                continue;
154            } else {
155                current.push(ch);
156            }
157        }
158        if !current.is_empty() {
159            words.push(current);
160        }
161        words
162    }
163
164    /// WordPiece greedy longest-match splitting for a single pre-token.
165    fn wordpiece_split(&self, word: &str) -> Vec<i64> {
166        if word.len() > self.max_word_chars {
167            return vec![UNK_ID];
168        }
169
170        let chars: Vec<char> = word.chars().collect();
171        let mut ids = Vec::new();
172        let mut start = 0;
173
174        while start < chars.len() {
175            let mut end = chars.len();
176            let mut found = false;
177
178            while start < end {
179                let substr: String = if start == 0 {
180                    chars[start..end].iter().collect()
181                } else {
182                    format!("##{}", chars[start..end].iter().collect::<String>())
183                };
184
185                if self.vocab.contains_key(&substr) {
186                    ids.push(self.vocab[&substr]);
187                    found = true;
188                    start = end;
189                    break;
190                }
191                end -= 1;
192            }
193
194            if !found {
195                ids.push(UNK_ID);
196                start += 1;
197            }
198        }
199
200        ids
201    }
202}
203
/// Return `true` when `ch` lies in one of the CJK ideograph blocks below.
///
/// Such characters are split into standalone tokens during pre-tokenization.
fn is_cjk_char(ch: char) -> bool {
    const IDEOGRAPH_RANGES: [(u32, u32); 8] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0x3400, 0x4DBF),   // Extension A
        (0x20000, 0x2A6DF), // Extension B
        (0x2A700, 0x2B73F), // Extension C
        (0x2B740, 0x2B81F), // Extension D
        (0x2B820, 0x2CEAF), // Extension E
        (0xF900, 0xFAFF),   // Compatibility Ideographs
        (0x2F800, 0x2FA1F), // Compatibility Ideographs Supplement
    ];
    let code_point = u32::from(ch);
    IDEOGRAPH_RANGES
        .iter()
        .any(|&(low, high)| (low..=high).contains(&code_point))
}
218
/// Return `true` when `ch` is in one of the Unicode combining-mark blocks
/// listed below; such characters are stripped as accents.
fn is_accent_char(ch: char) -> bool {
    let code_point = u32::from(ch);
    (0x0300..=0x036F).contains(&code_point)        // Combining Diacritical Marks
        || (0x1AB0..=0x1AFF).contains(&code_point) // ... Extended
        || (0x1DC0..=0x1DFF).contains(&code_point) // ... Supplement
        || (0xFE20..=0xFE2F).contains(&code_point) // Combining Half Marks
}
224
/// Candidate locations for the `vocab.txt` that accompanies an ONNX model.
///
/// Candidates are returned in search order (existence is not checked here):
/// 1. `vocab.txt` in the model's own directory.
/// 2. `<model dir>/<model file stem>/vocab.txt`
///    (e.g. `models/all-MiniLM-L6-v2/vocab.txt`).
/// 3. `.weftos/models/all-MiniLM-L6-v2/vocab.txt`, relative to the working
///    directory.
/// 4. The same path under the home directory (`HOME`, or `USERPROFILE` when
///    `HOME` is unset, as is usual on Windows).
/// 5. `WEFTOS_VOCAB_PATH`, then `WEFTOS_VOCAB_PATH/vocab.txt`, so the variable
///    may name either the vocab file itself or the directory containing it.
///
/// Empty environment values are ignored.
#[cfg_attr(not(feature = "onnx-embeddings"), allow(dead_code))]
fn vocab_search_paths(model_path: &Path) -> Vec<PathBuf> {
    const MODEL_DIR_NAME: &str = "all-MiniLM-L6-v2";
    let standard_relative: PathBuf = [".weftos", "models", MODEL_DIR_NAME, "vocab.txt"]
        .iter()
        .collect();

    let mut paths = Vec::new();

    if let Some(parent) = model_path.parent() {
        // Same directory as the model.
        paths.push(parent.join("vocab.txt"));
        // Sibling directory named after the model file.
        if let Some(stem) = model_path.file_stem() {
            paths.push(parent.join(stem).join("vocab.txt"));
        }
    }

    // Standard WeftOS model locations.
    paths.push(standard_relative.clone());
    let home = std::env::var_os("HOME")
        .filter(|h| !h.is_empty())
        .or_else(|| std::env::var_os("USERPROFILE").filter(|h| !h.is_empty()));
    if let Some(home) = home {
        paths.push(PathBuf::from(home).join(&standard_relative));
    }

    // Explicit override: a file path, or a directory holding vocab.txt.
    if let Some(env_path) = std::env::var_os("WEFTOS_VOCAB_PATH").filter(|p| !p.is_empty()) {
        let env_path = PathBuf::from(env_path);
        let in_dir = env_path.join("vocab.txt");
        paths.push(env_path);
        paths.push(in_dir);
    }

    paths
}
255
256// ---------------------------------------------------------------------------
257// Shared tokenisation helpers
258// ---------------------------------------------------------------------------
259
/// Simple whitespace tokeniser.
///
/// Lowercases `text`, splits it on whitespace, strips every non-alphanumeric
/// character from each piece, and returns at most `max_tokens` non-empty
/// tokens. The limit is applied *after* filtering, so punctuation-only pieces
/// (e.g. `-` or `—`) do not consume the token budget.
fn simple_tokenize(text: &str, max_tokens: usize) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|piece| piece.chars().filter(|c| c.is_alphanumeric()).collect::<String>())
        .filter(|token| !token.is_empty())
        .take(max_tokens)
        .collect()
}
269
270/// Convert a token sequence into a fixed-size embedding via position-weighted
271/// SHA-256 hashing.  Produces consistent, deterministic vectors per token set.
272fn tokens_to_embedding(tokens: &[String], dims: usize) -> Vec<f32> {
273    let mut embedding = vec![0.0f32; dims];
274
275    for (i, token) in tokens.iter().enumerate() {
276        let mut hasher = Sha256::new();
277        hasher.update(token.as_bytes());
278        hasher.update((i as u32).to_le_bytes());
279        let hash = hasher.finalize();
280
281        // Scatter hash bytes across embedding dimensions.
282        for (j, &byte) in hash.iter().enumerate() {
283            let dim = (j + i * 32) % dims;
284            let val = (byte as f32 / 128.0) - 1.0; // [-1, 1]
285            embedding[dim] += val / (tokens.len() as f32).sqrt();
286        }
287    }
288
289    l2_normalize(&mut embedding);
290    embedding
291}
292
/// Scale `vec` in place to unit Euclidean length.
///
/// Vectors whose norm is not strictly positive (all zeros, or NaN) are left
/// unchanged.
fn l2_normalize(vec: &mut [f32]) {
    let sum_of_squares = vec.iter().map(|x| x * x).sum::<f32>();
    let magnitude = sum_of_squares.sqrt();
    if magnitude > 0.0 {
        for component in vec.iter_mut() {
            *component /= magnitude;
        }
    }
}
300
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when either vector has zero magnitude.
#[cfg(test)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
313
314// =========================================================================
315// Backend 1: OnnxEmbeddingProvider
316// =========================================================================
317
318/// ONNX Runtime embedding provider.
319///
320/// Uses all-MiniLM-L6-v2 (384 dimensions) for semantic text embeddings.
321/// When the ONNX runtime is not available (no `onnx-embeddings` feature or
322/// missing model file), falls back to a position-aware token-hashing approach
323/// that is architecturally compatible with real inference.
324pub struct OnnxEmbeddingProvider {
325    /// Path to the ONNX model file.
326    model_path: PathBuf,
327    /// Output dimensions.
328    dimensions: usize,
329    /// Model name for identification.
330    model_name: String,
331    /// Whether the ONNX runtime session is available.
332    runtime_available: bool,
333    /// Max input tokens.
334    max_tokens: usize,
335    /// Fallback provider used when runtime is not available.
336    #[allow(dead_code)]
337    fallback: MockEmbeddingProvider,
338    /// WordPiece tokenizer loaded from vocab.txt (if available).
339    /// When `None`, ONNX inference falls back to hash-based token IDs.
340    #[cfg(feature = "onnx-embeddings")]
341    tokenizer: Option<WordPieceTokenizer>,
342    /// ONNX runtime session (only present when `onnx-embeddings` feature is active
343    /// and model was loaded successfully).
344    #[cfg(feature = "onnx-embeddings")]
345    session: Option<Arc<ort::Session>>,
346}
347
348impl OnnxEmbeddingProvider {
349    /// Default output dimensionality (all-MiniLM-L6-v2).
350    pub const DEFAULT_DIMS: usize = 384;
351    /// Default maximum input tokens for code snippets.
352    pub const DEFAULT_MAX_TOKENS: usize = 128;
353    /// Model identifier.
354    pub const MODEL_NAME: &'static str = "all-MiniLM-L6-v2";
355
356    /// Create a new ONNX provider pointing at the given model path.
357    ///
358    /// If the model file does not exist or the `onnx-embeddings` feature is
359    /// disabled, the provider transparently falls back to token-hashing.
360    pub fn new(model_path: impl Into<PathBuf>) -> Self {
361        let model_path = model_path.into();
362        #[cfg(feature = "onnx-embeddings")]
363        let session = Self::try_load_session(&model_path);
364        #[cfg(feature = "onnx-embeddings")]
365        let runtime_available = session.is_some();
366        #[cfg(not(feature = "onnx-embeddings"))]
367        let runtime_available = false;
368        #[cfg(feature = "onnx-embeddings")]
369        let tokenizer = Self::try_load_tokenizer(&model_path, Self::DEFAULT_MAX_TOKENS);
370
371        Self {
372            model_name: if runtime_available {
373                Self::MODEL_NAME.to_string()
374            } else {
375                format!("{}-hash-fallback", Self::MODEL_NAME)
376            },
377            model_path,
378            dimensions: Self::DEFAULT_DIMS,
379            runtime_available,
380            max_tokens: Self::DEFAULT_MAX_TOKENS,
381            fallback: MockEmbeddingProvider::new(Self::DEFAULT_DIMS),
382            #[cfg(feature = "onnx-embeddings")]
383            tokenizer,
384            #[cfg(feature = "onnx-embeddings")]
385            session,
386        }
387    }
388
389    /// Create a provider with custom dimensions and max tokens.
390    pub fn with_config(
391        model_path: impl Into<PathBuf>,
392        dimensions: usize,
393        max_tokens: usize,
394    ) -> Self {
395        let model_path = model_path.into();
396        #[cfg(feature = "onnx-embeddings")]
397        let session = Self::try_load_session(&model_path);
398        #[cfg(feature = "onnx-embeddings")]
399        let runtime_available = session.is_some();
400        #[cfg(not(feature = "onnx-embeddings"))]
401        let runtime_available = false;
402        #[cfg(feature = "onnx-embeddings")]
403        let tokenizer = Self::try_load_tokenizer(&model_path, max_tokens);
404
405        Self {
406            model_name: if runtime_available {
407                Self::MODEL_NAME.to_string()
408            } else {
409                format!("{}-hash-fallback", Self::MODEL_NAME)
410            },
411            model_path,
412            dimensions,
413            runtime_available,
414            max_tokens,
415            fallback: MockEmbeddingProvider::new(dimensions),
416            #[cfg(feature = "onnx-embeddings")]
417            tokenizer,
418            #[cfg(feature = "onnx-embeddings")]
419            session,
420        }
421    }
422
423    /// Attempt to load a WordPiece tokenizer from vocab.txt near the model.
424    #[cfg(feature = "onnx-embeddings")]
425    fn try_load_tokenizer(model_path: &Path, max_tokens: usize) -> Option<WordPieceTokenizer> {
426        for path in vocab_search_paths(model_path) {
427            if path.exists() {
428                if let Some(tok) = WordPieceTokenizer::load(&path) {
429                    // max_tokens here refers to the token count limit; for
430                    // WordPiece the max_length (including [CLS]/[SEP]) is
431                    // max_tokens + 2, capped at 512 for BERT models.
432                    let max_len = (max_tokens + 2).min(512);
433                    return Some(tok.with_max_length(max_len));
434                }
435            }
436        }
437        tracing::debug!(
438            "No vocab.txt found for WordPiece tokenizer near {}; \
439             ONNX inference will use hash-based token IDs (degraded quality)",
440            model_path.display()
441        );
442        None
443    }
444
445    /// Attempt to load an ONNX runtime session from the model path.
446    #[cfg(feature = "onnx-embeddings")]
447    fn try_load_session(model_path: &PathBuf) -> Option<Arc<ort::Session>> {
448        if !model_path.exists() {
449            tracing::debug!("ONNX model not found at {}, using hash fallback", model_path.display());
450            return None;
451        }
452        match ort::Session::builder()
453            .and_then(|builder| builder.commit_from_file(model_path))
454        {
455            Ok(session) => {
456                tracing::info!("ONNX session loaded from {}", model_path.display());
457                Some(Arc::new(session))
458            }
459            Err(e) => {
460                tracing::warn!("Failed to load ONNX session: {e}, using hash fallback");
461                None
462            }
463        }
464    }
465
466    /// Whether the real ONNX runtime is active (vs. fallback).
467    pub fn is_runtime_available(&self) -> bool {
468        self.runtime_available
469    }
470
471    /// Path to the configured model file.
472    pub fn model_path(&self) -> &PathBuf {
473        &self.model_path
474    }
475
476    /// Maximum input token count.
477    pub fn max_tokens(&self) -> usize {
478        self.max_tokens
479    }
480
481    /// Embed using the token-hashing fallback.
482    fn hash_embed(&self, text: &str) -> Vec<f32> {
483        let tokens = simple_tokenize(text, self.max_tokens);
484        if tokens.is_empty() {
485            // Return zero vector for empty input.
486            return vec![0.0f32; self.dimensions];
487        }
488        tokens_to_embedding(&tokens, self.dimensions)
489    }
490
491    /// Run real ONNX inference on the input text.
492    ///
493    /// Uses the WordPiece tokenizer (if vocab.txt was loaded) to produce
494    /// correct token IDs for the all-MiniLM-L6-v2 model. Falls back to
495    /// hash-based token IDs when no vocab is available (degraded quality
496    /// but still structurally valid).
497    ///
498    /// Builds input tensors (input_ids, attention_mask, token_type_ids),
499    /// runs the model, and mean-pools the last hidden state (masked by
500    /// attention_mask) to produce a fixed-size embedding.
501    #[cfg(feature = "onnx-embeddings")]
502    fn onnx_embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
503        use ndarray::Array2;
504
505        let session = self.session.as_ref().ok_or_else(|| {
506            EmbeddingError::BackendError("ONNX session not loaded".to_string())
507        })?;
508
509        // Tokenize using WordPiece if available, otherwise fall back to hashing.
510        let (input_ids, attention_mask, token_type_ids) = if let Some(ref tokenizer) = self.tokenizer {
511            tokenizer.encode(text).ok_or_else(|| {
512                EmbeddingError::BackendError("WordPiece tokenizer returned None".to_string())
513            })?
514        } else {
515            // Legacy hash-based fallback (produces structurally valid but
516            // semantically meaningless token IDs).
517            tracing::warn_once!(
518                "ONNX inference without WordPiece vocab — embeddings will not be semantic"
519            );
520            let tokens = simple_tokenize(text, self.max_tokens);
521            let seq_len = tokens.len().max(1) + 2; // +2 for [CLS] and [SEP]
522
523            let mut ids = vec![CLS_ID];
524            for token in &tokens {
525                let mut hasher = Sha256::new();
526                hasher.update(token.as_bytes());
527                let hash = hasher.finalize();
528                let id = 1000
529                    + (u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]]) % 29000)
530                        as i64;
531                ids.push(id);
532            }
533            ids.push(SEP_ID);
534
535            let mask = vec![1i64; seq_len];
536            let types = vec![0i64; seq_len];
537            (ids, mask, types)
538        };
539
540        let seq_len = input_ids.len();
541
542        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
543            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;
544        let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask.clone())
545            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;
546        let token_type_ids_arr = Array2::from_shape_vec((1, seq_len), token_type_ids)
547            .map_err(|e| EmbeddingError::BackendError(format!("shape error: {e}")))?;
548
549        let inputs = ort::inputs![
550            "input_ids" => input_ids_arr,
551            "attention_mask" => attention_mask_arr,
552            "token_type_ids" => token_type_ids_arr,
553        ].map_err(|e| EmbeddingError::BackendError(format!("input error: {e}")))?;
554
555        let outputs = session.run(inputs)
556            .map_err(|e| EmbeddingError::BackendError(format!("inference error: {e}")))?;
557
558        // Extract the last_hidden_state output and mean-pool across the sequence.
559        // Output shape: (1, seq_len, hidden_dim)
560        let output_tensor = outputs.get("last_hidden_state")
561            .or_else(|| outputs.iter().next().map(|(_, v)| v))
562            .ok_or_else(|| EmbeddingError::BackendError("no output tensor".to_string()))?;
563
564        let tensor = output_tensor
565            .try_extract_tensor::<f32>()
566            .map_err(|e| EmbeddingError::BackendError(format!("extract error: {e}")))?;
567
568        let shape = tensor.shape();
569        if shape.len() < 2 {
570            return Err(EmbeddingError::BackendError(
571                format!("unexpected output shape: {shape:?}"),
572            ));
573        }
574        let hidden_dim = *shape.last().unwrap();
575        let seq = shape[1];
576
577        // Attention-masked mean pooling: only average over non-padding tokens.
578        let mut embedding = vec![0.0f32; hidden_dim];
579        let data = tensor.as_slice().ok_or_else(|| {
580            EmbeddingError::BackendError("tensor not contiguous".to_string())
581        })?;
582
583        let mut active_count: f32 = 0.0;
584        for s in 0..seq {
585            let mask_val = if s < attention_mask.len() {
586                attention_mask[s] as f32
587            } else {
588                0.0
589            };
590            if mask_val > 0.0 {
591                for d in 0..hidden_dim {
592                    embedding[d] += data[s * hidden_dim + d];
593                }
594                active_count += 1.0;
595            }
596        }
597        if active_count > 0.0 {
598            for val in &mut embedding {
599                *val /= active_count;
600            }
601        }
602
603        // L2 normalize.
604        l2_normalize(&mut embedding);
605
606        // Truncate or pad to expected dimensions.
607        embedding.truncate(self.dimensions);
608        while embedding.len() < self.dimensions {
609            embedding.push(0.0);
610        }
611
612        Ok(embedding)
613    }
614}
615
616#[async_trait]
617impl EmbeddingProvider for OnnxEmbeddingProvider {
618    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
619        #[cfg(feature = "onnx-embeddings")]
620        if self.runtime_available {
621            return self.onnx_embed(text);
622        }
623        Ok(self.hash_embed(text))
624    }
625
626    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
627        let mut results = Vec::with_capacity(texts.len());
628        for text in texts {
629            results.push(self.embed(text).await?);
630        }
631        Ok(results)
632    }
633
634    fn dimensions(&self) -> usize {
635        self.dimensions
636    }
637
638    fn model_name(&self) -> &str {
639        &self.model_name
640    }
641}
642
643// =========================================================================
644// Backend 2: SentenceTransformerProvider
645// =========================================================================
646
647/// Sentence-transformer embedding provider for documentation.
648///
649/// Optimised for natural language paragraphs rather than code.  Pre-processes
650/// markdown, splits into sentences, embeds each, and averages (mean pooling).
651pub struct SentenceTransformerProvider {
652    /// Base ONNX provider (reuses the same infrastructure).
653    base: OnnxEmbeddingProvider,
654    /// Max token length (longer for docs than code).
655    max_tokens: usize,
656    /// Whether sentence splitting is enabled.
657    split_sentences: bool,
658}
659
660impl SentenceTransformerProvider {
661    /// Default max tokens for documentation (longer context than code).
662    pub const DEFAULT_MAX_TOKENS: usize = 512;
663
664    /// Create a new sentence-transformer provider.
665    pub fn new(model_path: impl Into<PathBuf>) -> Self {
666        Self {
667            base: OnnxEmbeddingProvider::with_config(
668                model_path,
669                OnnxEmbeddingProvider::DEFAULT_DIMS,
670                Self::DEFAULT_MAX_TOKENS,
671            ),
672            max_tokens: Self::DEFAULT_MAX_TOKENS,
673            split_sentences: true,
674        }
675    }
676
677    /// Create with custom max tokens and optional sentence splitting.
678    pub fn with_config(
679        model_path: impl Into<PathBuf>,
680        max_tokens: usize,
681        split_sentences: bool,
682    ) -> Self {
683        Self {
684            base: OnnxEmbeddingProvider::with_config(
685                model_path,
686                OnnxEmbeddingProvider::DEFAULT_DIMS,
687                max_tokens,
688            ),
689            max_tokens,
690            split_sentences,
691        }
692    }
693
694    /// Whether sentence splitting is enabled.
695    pub fn split_sentences(&self) -> bool {
696        self.split_sentences
697    }
698
699    /// Max token length.
700    pub fn max_tokens(&self) -> usize {
701        self.max_tokens
702    }
703
704    /// Embed a single sentence/paragraph through the base provider.
705    fn embed_text(&self, text: &str) -> Vec<f32> {
706        self.base.hash_embed(text)
707    }
708}
709
/// Pre-process markdown text by stripping structural elements.
///
/// Each line is trimmed *before* it is classified, so indented structure is
/// also recognised (nested checklist items, list-indented code fences,
/// indented tables). The following lines are dropped:
/// - headers (`#…`);
/// - code-fence markers (```` ``` ````), though not the code between them;
/// - table rows (`|…`);
/// - checklist items (`- […`);
/// - blank lines.
///
/// The remaining trimmed lines are joined with single spaces.
pub fn preprocess_markdown(text: &str) -> String {
    // Lines that carry layout rather than prose.
    let is_structural = |line: &str| {
        line.starts_with('#')           // headers
            || line.starts_with("```")  // code fences
            || line.starts_with('|')    // tables
            || line.starts_with("- [")  // checklists
    };

    text.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !is_structural(line))
        .collect::<Vec<_>>()
        .join(" ")
}
722
/// Split text into sentences on `". "` and on newlines.
///
/// Pieces are trimmed, and any piece of 10 bytes or fewer is discarded as a
/// fragment. The returned slices borrow from `text`.
pub fn split_sentences(text: &str) -> Vec<&str> {
    let mut sentences = Vec::new();
    for chunk in text.split(". ") {
        for line in chunk.split('\n') {
            let candidate = line.trim();
            if candidate.len() > 10 {
                sentences.push(candidate);
            }
        }
    }
    sentences
}
731
732#[async_trait]
733impl EmbeddingProvider for SentenceTransformerProvider {
734    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
735        let cleaned = preprocess_markdown(text);
736
737        if !self.split_sentences {
738            return Ok(self.embed_text(&cleaned));
739        }
740
741        let sentences = split_sentences(&cleaned);
742        if sentences.is_empty() {
743            // Fall through to full-text embedding.
744            return Ok(self.embed_text(&cleaned));
745        }
746
747        // Mean pooling across sentence embeddings.
748        let dims = self.base.dimensions;
749        let mut summed = vec![0.0f32; dims];
750        let count = sentences.len() as f32;
751
752        for sentence in &sentences {
753            let vec = self.embed_text(sentence);
754            for (i, val) in vec.iter().enumerate() {
755                summed[i] += val;
756            }
757        }
758
759        // Average and re-normalise.
760        summed.iter_mut().for_each(|x| *x /= count);
761        l2_normalize(&mut summed);
762
763        Ok(summed)
764    }
765
766    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
767        let mut results = Vec::with_capacity(texts.len());
768        for text in texts {
769            results.push(self.embed(text).await?);
770        }
771        Ok(results)
772    }
773
774    fn dimensions(&self) -> usize {
775        self.base.dimensions
776    }
777
778    fn model_name(&self) -> &str {
779        "sentence-transformer"
780    }
781}
782
783// =========================================================================
784// Backend 3: AstEmbeddingProvider
785// =========================================================================
786
787/// Structural features extracted from Rust source code via regex parsing.
788#[derive(Debug, Clone, Default, Serialize, Deserialize)]
789pub struct RustCodeFeatures {
790    /// Function/type/impl signature.
791    pub signature: Option<String>,
792    /// Return type.
793    pub return_type: Option<String>,
794    /// Parameter types.
795    pub param_types: Vec<String>,
796    /// Visibility (pub, pub(crate), private).
797    pub visibility: String,
798    /// Whether it is async.
799    pub is_async: bool,
800    /// Whether it is generic.
801    pub is_generic: bool,
802    /// Trait bounds.
803    pub trait_bounds: Vec<String>,
804    /// Attributes (#[test], #[cfg(...)], etc.)
805    pub attributes: Vec<String>,
806    /// Item kind (fn, struct, enum, impl, trait, mod).
807    pub item_kind: String,
808}
809
810/// Extract structural features from a Rust code snippet using simple regex-like
811/// parsing.  Does not depend on tree-sitter so the kernel stays lightweight.
812pub fn extract_rust_features(code: &str) -> RustCodeFeatures {
813    let mut features = RustCodeFeatures::default();
814
815    // -- Item kind --------------------------------------------------------
816    // Order matters: check trait/struct/enum before fn, because trait and
817    // impl bodies often contain `fn` keywords.
818    if code.contains("pub trait ") || code.contains("trait ") {
819        features.item_kind = "trait".into();
820    } else if code.contains("pub struct ") || code.contains("struct ") {
821        features.item_kind = "struct".into();
822    } else if code.contains("pub enum ") || code.contains("enum ") {
823        features.item_kind = "enum".into();
824    } else if code.contains("pub fn ") || code.contains("fn ") {
825        features.item_kind = "fn".into();
826    } else if code.contains("impl ") {
827        features.item_kind = "impl".into();
828    } else if code.contains("pub mod ") || code.contains("mod ") {
829        features.item_kind = "mod".into();
830    }
831
832    // -- Visibility -------------------------------------------------------
833    features.visibility = if code.contains("pub(crate)") {
834        "pub(crate)".into()
835    } else if code.contains("pub(super)") {
836        "pub(super)".into()
837    } else if code.contains("pub ") {
838        "pub".into()
839    } else {
840        "private".into()
841    };
842
843    // -- Async ------------------------------------------------------------
844    features.is_async = code.contains("async fn");
845
846    // -- Generics ---------------------------------------------------------
847    features.is_generic = code.contains('<') && code.contains('>');
848
849    // -- Signature (first fn/struct/enum/trait line) -----------------------
850    for line in code.lines() {
851        let trimmed = line.trim();
852        if trimmed.contains("fn ")
853            || trimmed.starts_with("pub struct ")
854            || trimmed.starts_with("struct ")
855            || trimmed.starts_with("pub enum ")
856            || trimmed.starts_with("enum ")
857            || trimmed.starts_with("pub trait ")
858            || trimmed.starts_with("trait ")
859        {
860            // Take up to '{' or end of line.
861            let sig = if let Some(brace) = trimmed.find('{') {
862                trimmed[..brace].trim()
863            } else {
864                trimmed.trim_end_matches(';').trim()
865            };
866            features.signature = Some(sig.to_string());
867            break;
868        }
869    }
870
871    // -- Return type (after -> before { or ;) -----------------------------
872    if let Some(arrow) = code.find("->") {
873        let after = &code[arrow + 2..];
874        if let Some(brace) = after.find('{') {
875            features.return_type = Some(after[..brace].trim().to_string());
876        } else if let Some(semi) = after.find(';') {
877            features.return_type = Some(after[..semi].trim().to_string());
878        }
879    }
880
881    // -- Parameter types (inside parentheses of fn) -----------------------
882    if features.item_kind == "fn"
883        && let Some(open) = code.find('(')
884        && let Some(close) = code.find(')')
885        && close > open
886    {
887        let params = &code[open + 1..close];
888        for param in params.split(',') {
889            let param = param.trim();
890            if param == "&self" || param == "&mut self" || param == "self" {
891                features.param_types.push(param.to_string());
892            } else if let Some(colon) = param.find(':') {
893                let ty = param[colon + 1..].trim().to_string();
894                if !ty.is_empty() {
895                    features.param_types.push(ty);
896                }
897            }
898        }
899    }
900
901    // -- Trait bounds (where clause) --------------------------------------
902    if let Some(where_idx) = code.find("where") {
903        let after = &code[where_idx + 5..];
904        let end = after.find('{').unwrap_or(after.len());
905        let clause = &after[..end];
906        for bound in clause.split(',') {
907            let bound = bound.trim();
908            if !bound.is_empty() {
909                features.trait_bounds.push(bound.to_string());
910            }
911        }
912    }
913
914    // -- Attributes -------------------------------------------------------
915    for line in code.lines() {
916        let trimmed = line.trim();
917        if trimmed.starts_with("#[") {
918            features.attributes.push(trimmed.to_string());
919        }
920    }
921
922    features
923}
924
925/// AST-aware embedding provider for Rust source code.
926///
927/// Combines structural features (signature, types, visibility) with semantic
928/// text embeddings for hybrid code understanding.
929pub struct AstEmbeddingProvider {
930    /// Base text embedding provider.
931    text_provider: OnnxEmbeddingProvider,
932    /// Dimensions allocated to structural features.
933    structural_dims: usize,
934    /// Total output dimensions.
935    total_dims: usize,
936    /// Weight for structural vs. text features [0.0, 1.0].
937    structural_weight: f32,
938}
939
940impl AstEmbeddingProvider {
941    /// Default total output dimensionality.
942    pub const DEFAULT_TOTAL_DIMS: usize = 256;
943    /// Default structural feature dimensions.
944    pub const DEFAULT_STRUCTURAL_DIMS: usize = 64;
945    /// Default weight for structural features.
946    pub const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.3;
947
948    /// Create a new AST-aware provider with default configuration.
949    pub fn new(model_path: impl Into<PathBuf>) -> Self {
950        Self {
951            text_provider: OnnxEmbeddingProvider::with_config(
952                model_path,
953                Self::DEFAULT_TOTAL_DIMS - Self::DEFAULT_STRUCTURAL_DIMS,
954                OnnxEmbeddingProvider::DEFAULT_MAX_TOKENS,
955            ),
956            structural_dims: Self::DEFAULT_STRUCTURAL_DIMS,
957            total_dims: Self::DEFAULT_TOTAL_DIMS,
958            structural_weight: Self::DEFAULT_STRUCTURAL_WEIGHT,
959        }
960    }
961
962    /// Create with custom configuration.
963    pub fn with_config(
964        model_path: impl Into<PathBuf>,
965        total_dims: usize,
966        structural_dims: usize,
967        structural_weight: f32,
968    ) -> Self {
969        assert!(
970            structural_dims < total_dims,
971            "structural_dims must be less than total_dims"
972        );
973        let text_dims = total_dims - structural_dims;
974        Self {
975            text_provider: OnnxEmbeddingProvider::with_config(
976                model_path,
977                text_dims,
978                OnnxEmbeddingProvider::DEFAULT_MAX_TOKENS,
979            ),
980            structural_dims,
981            total_dims,
982            structural_weight: structural_weight.clamp(0.0, 1.0),
983        }
984    }
985
986    /// Total output dimensions.
987    pub fn total_dims(&self) -> usize {
988        self.total_dims
989    }
990
991    /// Weight applied to structural features.
992    pub fn structural_weight(&self) -> f32 {
993        self.structural_weight
994    }
995
996    /// Encode [`RustCodeFeatures`] into a fixed-size structural vector.
997    fn encode_structural(&self, features: &RustCodeFeatures) -> Vec<f32> {
998        let dims = self.structural_dims;
999        let mut vec = vec![0.0f32; dims];
1000
1001        // Hash each feature category into different regions of the vector.
1002        let mut write_hash = |label: &str, offset: usize, slots: usize| {
1003            let mut hasher = Sha256::new();
1004            hasher.update(label.as_bytes());
1005            let hash = hasher.finalize();
1006            for (j, &byte) in hash.iter().enumerate().take(slots.min(32)) {
1007                let dim = (offset + j) % dims;
1008                vec[dim] += (byte as f32 / 128.0) - 1.0;
1009            }
1010        };
1011
1012        // Item kind (fn, struct, enum, ...).
1013        write_hash(&format!("kind:{}", features.item_kind), 0, 8);
1014
1015        // Visibility.
1016        write_hash(&format!("vis:{}", features.visibility), 8, 6);
1017
1018        // Async flag.
1019        if features.is_async {
1020            write_hash("async:true", 14, 4);
1021        }
1022
1023        // Generic flag.
1024        if features.is_generic {
1025            write_hash("generic:true", 18, 4);
1026        }
1027
1028        // Return type.
1029        if let Some(ref rt) = features.return_type {
1030            write_hash(&format!("ret:{rt}"), 22, 8);
1031        }
1032
1033        // Parameter types.
1034        for (i, pt) in features.param_types.iter().enumerate() {
1035            write_hash(&format!("param{i}:{pt}"), 30 + i * 6, 6);
1036        }
1037
1038        // Attributes.
1039        for (i, attr) in features.attributes.iter().enumerate() {
1040            write_hash(&format!("attr{i}:{attr}"), 48 + i * 4, 4);
1041        }
1042
1043        l2_normalize(&mut vec);
1044        vec
1045    }
1046
1047    /// Produce the hybrid embedding for a Rust code snippet.
1048    fn hybrid_embed(&self, code: &str) -> Vec<f32> {
1049        let features = extract_rust_features(code);
1050        let structural = self.encode_structural(&features);
1051        let text = self.text_provider.hash_embed(code);
1052
1053        let w_s = self.structural_weight;
1054        let w_t = 1.0 - w_s;
1055
1056        // Concatenate weighted structural + text vectors.
1057        let mut combined = Vec::with_capacity(self.total_dims);
1058        for val in &structural {
1059            combined.push(val * w_s);
1060        }
1061        for val in &text {
1062            combined.push(val * w_t);
1063        }
1064
1065        // Ensure exact dimensionality.
1066        combined.truncate(self.total_dims);
1067        while combined.len() < self.total_dims {
1068            combined.push(0.0);
1069        }
1070
1071        l2_normalize(&mut combined);
1072        combined
1073    }
1074}
1075
1076#[async_trait]
1077impl EmbeddingProvider for AstEmbeddingProvider {
1078    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
1079        Ok(self.hybrid_embed(text))
1080    }
1081
1082    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
1083        Ok(texts.iter().map(|t| self.hybrid_embed(t)).collect())
1084    }
1085
1086    fn dimensions(&self) -> usize {
1087        self.total_dims
1088    }
1089
1090    fn model_name(&self) -> &str {
1091        "ast-aware-hybrid"
1092    }
1093}
1094
1095// =========================================================================
1096// Tests
1097// =========================================================================
1098
1099#[cfg(test)]
1100mod tests {
1101    use super::*;
1102
1103    // -- Helper -----------------------------------------------------------
1104
1105    fn vec_magnitude(v: &[f32]) -> f32 {
1106        v.iter().map(|x| x * x).sum::<f32>().sqrt()
1107    }
1108
1109    // =====================================================================
1110    // OnnxEmbeddingProvider tests
1111    // =====================================================================
1112
1113    #[test]
1114    fn onnx_construction_default() {
1115        let p = OnnxEmbeddingProvider::new("/nonexistent/model.onnx");
1116        assert_eq!(p.dimensions(), 384);
1117        assert!(!p.is_runtime_available());
1118        assert!(p.model_name().contains("fallback"));
1119    }
1120
1121    #[test]
1122    fn onnx_construction_custom() {
1123        let p = OnnxEmbeddingProvider::with_config("/tmp/model.onnx", 128, 64);
1124        assert_eq!(p.dimensions(), 128);
1125        assert_eq!(p.max_tokens(), 64);
1126    }
1127
1128    #[tokio::test]
1129    async fn onnx_embed_returns_correct_dimensions() {
1130        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1131        let vec = p.embed("hello world").await.unwrap();
1132        assert_eq!(vec.len(), 384);
1133    }
1134
1135    #[tokio::test]
1136    async fn onnx_embed_deterministic() {
1137        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1138        let v1 = p.embed("test input").await.unwrap();
1139        let v2 = p.embed("test input").await.unwrap();
1140        assert_eq!(v1, v2);
1141    }
1142
1143    #[tokio::test]
1144    async fn onnx_embed_different_inputs_differ() {
1145        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1146        let v1 = p.embed("alpha").await.unwrap();
1147        let v2 = p.embed("beta").await.unwrap();
1148        assert_ne!(v1, v2);
1149    }
1150
1151    #[tokio::test]
1152    async fn onnx_embed_l2_normalized() {
1153        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1154        let vec = p.embed("normalisation check").await.unwrap();
1155        let mag = vec_magnitude(&vec);
1156        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
1157    }
1158
1159    #[tokio::test]
1160    async fn onnx_embed_batch() {
1161        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1162        let results = p.embed_batch(&["a", "b", "c"]).await.unwrap();
1163        assert_eq!(results.len(), 3);
1164        for v in &results {
1165            assert_eq!(v.len(), 384);
1166        }
1167    }
1168
1169    #[tokio::test]
1170    async fn onnx_similar_inputs_high_cosine() {
1171        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1172        let v1 = p.embed("the quick brown fox").await.unwrap();
1173        let v2 = p.embed("the quick brown dog").await.unwrap();
1174        let sim = cosine_similarity(&v1, &v2);
1175        assert!(sim > 0.5, "similar inputs cosine = {sim}, expected > 0.5");
1176    }
1177
1178    #[tokio::test]
1179    async fn onnx_empty_input_returns_zero_vector() {
1180        let p = OnnxEmbeddingProvider::new("/tmp/model.onnx");
1181        let vec = p.embed("").await.unwrap();
1182        assert_eq!(vec.len(), 384);
1183        assert!(vec.iter().all(|x| *x == 0.0));
1184    }
1185
1186    // =====================================================================
1187    // SentenceTransformerProvider tests
1188    // =====================================================================
1189
1190    #[test]
1191    fn sentence_construction() {
1192        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
1193        assert_eq!(p.dimensions(), 384);
1194        assert_eq!(p.max_tokens(), 512);
1195        assert!(p.split_sentences());
1196        assert_eq!(p.model_name(), "sentence-transformer");
1197    }
1198
1199    #[tokio::test]
1200    async fn sentence_embed_returns_correct_dimensions() {
1201        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
1202        let vec = p.embed("This is a test paragraph with enough words.").await.unwrap();
1203        assert_eq!(vec.len(), 384);
1204    }
1205
1206    #[tokio::test]
1207    async fn sentence_embed_l2_normalized() {
1208        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
1209        let vec = p.embed("Testing normalisation of sentence embeddings here.").await.unwrap();
1210        let mag = vec_magnitude(&vec);
1211        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
1212    }
1213
1214    #[tokio::test]
1215    async fn sentence_embed_batch() {
1216        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
1217        let results = p
1218            .embed_batch(&[
1219                "First paragraph with a decent amount of words in it.",
1220                "Second paragraph also has a reasonable length for testing.",
1221            ])
1222            .await
1223            .unwrap();
1224        assert_eq!(results.len(), 2);
1225        for v in &results {
1226            assert_eq!(v.len(), 384);
1227        }
1228    }
1229
1230    #[tokio::test]
1231    async fn sentence_similar_inputs_positive_cosine() {
1232        let p = SentenceTransformerProvider::new("/tmp/model.onnx");
1233        let v1 = p.embed("The kernel boots up the system and runs all the services correctly.").await.unwrap();
1234        let v2 = p.embed("The kernel boots up the system and runs all the services properly.").await.unwrap();
1235        let v3 = p.embed("Quantum chromodynamics explains the strong interaction between quarks.").await.unwrap();
1236        let sim_similar = cosine_similarity(&v1, &v2);
1237        let sim_different = cosine_similarity(&v1, &v3);
1238        assert!(
1239            sim_similar > sim_different,
1240            "similar ({sim_similar}) should be closer than different ({sim_different})"
1241        );
1242    }
1243
1244    #[test]
1245    fn preprocess_markdown_strips_headers() {
1246        let md = "# Title\nSome text.\n## Subtitle\nMore text.";
1247        let result = preprocess_markdown(md);
1248        assert!(!result.contains("Title"));
1249        assert!(!result.contains("Subtitle"));
1250        assert!(result.contains("Some text."));
1251        assert!(result.contains("More text."));
1252    }
1253
1254    #[test]
1255    fn preprocess_markdown_strips_code_fences() {
1256        let md = "Before.\n```rust\nlet x = 1;\n```\nAfter.";
1257        let result = preprocess_markdown(md);
1258        assert!(!result.contains("```"));
1259        // Code line itself is kept (only fence markers are stripped).
1260        assert!(result.contains("Before."));
1261        assert!(result.contains("After."));
1262    }
1263
1264    #[test]
1265    fn preprocess_markdown_strips_tables() {
1266        let md = "Intro.\n| Col1 | Col2 |\n|------|------|\n| A | B |\nOutro.";
1267        let result = preprocess_markdown(md);
1268        assert!(!result.contains("Col1"));
1269        assert!(result.contains("Intro."));
1270        assert!(result.contains("Outro."));
1271    }
1272
1273    #[test]
1274    fn preprocess_markdown_strips_checklists() {
1275        let md = "Text here.\n- [x] Done item\n- [ ] Todo item\nMore text.";
1276        let result = preprocess_markdown(md);
1277        assert!(!result.contains("Done item"));
1278        assert!(result.contains("Text here."));
1279    }
1280
1281    #[test]
1282    fn split_sentences_basic() {
1283        let text = "First sentence here. Second sentence here. Third.";
1284        let sentences = split_sentences(text);
1285        // "Third." is only 6 chars, below the 10-char minimum.
1286        assert_eq!(sentences.len(), 2);
1287        assert!(sentences[0].contains("First"));
1288        assert!(sentences[1].contains("Second"));
1289    }
1290
1291    #[test]
1292    fn split_sentences_newlines() {
1293        let text = "Line one is long enough.\nLine two is also long enough.";
1294        let sentences = split_sentences(text);
1295        assert_eq!(sentences.len(), 2);
1296    }
1297
1298    // =====================================================================
1299    // AstEmbeddingProvider tests
1300    // =====================================================================
1301
1302    #[test]
1303    fn ast_construction_default() {
1304        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1305        assert_eq!(p.dimensions(), 256);
1306        assert_eq!(p.total_dims(), 256);
1307        assert!((p.structural_weight() - 0.3).abs() < 0.001);
1308        assert_eq!(p.model_name(), "ast-aware-hybrid");
1309    }
1310
1311    #[tokio::test]
1312    async fn ast_embed_returns_correct_dimensions() {
1313        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1314        let vec = p.embed("pub fn hello() -> String { }").await.unwrap();
1315        assert_eq!(vec.len(), 256);
1316    }
1317
1318    #[tokio::test]
1319    async fn ast_embed_l2_normalized() {
1320        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1321        let vec = p.embed("pub fn hello() -> String { }").await.unwrap();
1322        let mag = vec_magnitude(&vec);
1323        assert!((mag - 1.0).abs() < 0.01, "magnitude = {mag}, expected ~1.0");
1324    }
1325
1326    #[tokio::test]
1327    async fn ast_embed_batch() {
1328        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1329        let results = p
1330            .embed_batch(&["fn a() {}", "fn b() {}", "struct C {}"])
1331            .await
1332            .unwrap();
1333        assert_eq!(results.len(), 3);
1334        for v in &results {
1335            assert_eq!(v.len(), 256);
1336        }
1337    }
1338
1339    #[tokio::test]
1340    async fn ast_embed_different_inputs_differ() {
1341        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1342        let v1 = p.embed("pub fn alpha() -> u32 {}").await.unwrap();
1343        let v2 = p.embed("struct Beta { x: f64 }").await.unwrap();
1344        assert_ne!(v1, v2);
1345    }
1346
1347    #[tokio::test]
1348    async fn ast_structural_similarity_same_signature() {
1349        // Two functions with same signature but different names should be
1350        // closer than two items with different signatures.
1351        let p = AstEmbeddingProvider::new("/tmp/model.onnx");
1352        let v_foo = p
1353            .embed("pub async fn foo(&self, x: u32) -> Result<(), Error> {}")
1354            .await
1355            .unwrap();
1356        let v_bar = p
1357            .embed("pub async fn bar(&self, x: u32) -> Result<(), Error> {}")
1358            .await
1359            .unwrap();
1360        let v_struct = p.embed("pub struct Baz { count: usize }").await.unwrap();
1361
1362        let sim_fns = cosine_similarity(&v_foo, &v_bar);
1363        let sim_fn_struct = cosine_similarity(&v_foo, &v_struct);
1364        assert!(
1365            sim_fns > sim_fn_struct,
1366            "same-signature fns ({sim_fns}) should be more similar than fn-vs-struct ({sim_fn_struct})"
1367        );
1368    }
1369
1370    // =====================================================================
1371    // extract_rust_features tests
1372    // =====================================================================
1373
1374    #[test]
1375    fn rust_features_pub_async_fn() {
1376        let code = r#"
1377#[test]
1378pub async fn process_batch(&self, items: Vec<Item>) -> Result<(), Error> {
1379    // body
1380}
1381"#;
1382        let f = extract_rust_features(code);
1383        assert_eq!(f.item_kind, "fn");
1384        assert_eq!(f.visibility, "pub");
1385        assert!(f.is_async);
1386        assert!(f.is_generic);
1387        assert_eq!(f.return_type.as_deref(), Some("Result<(), Error>"));
1388        assert!(f.attributes.contains(&"#[test]".to_string()));
1389        assert!(f.param_types.contains(&"&self".to_string()));
1390        assert!(f.param_types.iter().any(|p| p.contains("Vec<Item>")));
1391    }
1392
1393    #[test]
1394    fn rust_features_struct() {
1395        let code = "pub struct Config { pub name: String, pub value: u64 }";
1396        let f = extract_rust_features(code);
1397        assert_eq!(f.item_kind, "struct");
1398        assert_eq!(f.visibility, "pub");
1399        assert!(!f.is_async);
1400        assert!(!f.is_generic); // no < > in this struct
1401        assert!(f.return_type.is_none());
1402    }
1403
1404    #[test]
1405    fn rust_features_private_fn() {
1406        let code = "fn helper(x: &str) -> bool { true }";
1407        let f = extract_rust_features(code);
1408        assert_eq!(f.item_kind, "fn");
1409        assert_eq!(f.visibility, "private");
1410        assert!(!f.is_async);
1411        assert_eq!(f.return_type.as_deref(), Some("bool"));
1412        assert!(f.param_types.iter().any(|p| p.contains("&str")));
1413    }
1414
1415    #[test]
1416    fn rust_features_enum() {
1417        let code = "pub enum Status { Active, Inactive, Pending }";
1418        let f = extract_rust_features(code);
1419        assert_eq!(f.item_kind, "enum");
1420        assert_eq!(f.visibility, "pub");
1421    }
1422
1423    #[test]
1424    fn rust_features_trait() {
1425        let code = "pub trait Displayable { fn display(&self) -> String; }";
1426        let f = extract_rust_features(code);
1427        assert_eq!(f.item_kind, "trait");
1428        assert_eq!(f.visibility, "pub");
1429    }
1430
1431    #[test]
1432    fn rust_features_impl_block() {
1433        let code = "impl MyStruct { fn new() -> Self { Self {} } }";
1434        let f = extract_rust_features(code);
1435        // "fn" is detected before "impl" because code.contains("fn ")
1436        assert_eq!(f.item_kind, "fn");
1437    }
1438
1439    #[test]
1440    fn rust_features_where_clause() {
1441        let code = "pub fn serialize<T>(val: T) -> String where T: Serialize + Debug { }";
1442        let f = extract_rust_features(code);
1443        assert!(f.is_generic);
1444        assert!(!f.trait_bounds.is_empty());
1445        assert!(f.trait_bounds.iter().any(|b| b.contains("Serialize")));
1446    }
1447
1448    #[test]
1449    fn rust_features_pub_crate() {
1450        let code = "pub(crate) fn internal_helper() {}";
1451        let f = extract_rust_features(code);
1452        assert_eq!(f.visibility, "pub(crate)");
1453    }
1454
1455    #[test]
1456    fn rust_features_multiple_attributes() {
1457        let code = "#[cfg(test)]\n#[allow(dead_code)]\nfn test_fn() {}";
1458        let f = extract_rust_features(code);
1459        assert_eq!(f.attributes.len(), 2);
1460        assert!(f.attributes.contains(&"#[cfg(test)]".to_string()));
1461        assert!(f.attributes.contains(&"#[allow(dead_code)]".to_string()));
1462    }
1463
1464    // =====================================================================
1465    // Tokenisation helper tests
1466    // =====================================================================
1467
1468    #[test]
1469    fn simple_tokenize_basic() {
1470        let tokens = simple_tokenize("Hello World! Foo-bar", 10);
1471        assert_eq!(tokens, vec!["hello", "world", "foobar"]);
1472    }
1473
1474    #[test]
1475    fn simple_tokenize_max_tokens() {
1476        let tokens = simple_tokenize("a b c d e f", 3);
1477        assert_eq!(tokens.len(), 3);
1478    }
1479
1480    #[test]
1481    fn simple_tokenize_empty() {
1482        let tokens = simple_tokenize("", 10);
1483        assert!(tokens.is_empty());
1484    }
1485
1486    #[test]
1487    fn tokens_to_embedding_deterministic() {
1488        let tokens: Vec<String> = vec!["hello".into(), "world".into()];
1489        let v1 = tokens_to_embedding(&tokens, 64);
1490        let v2 = tokens_to_embedding(&tokens, 64);
1491        assert_eq!(v1, v2);
1492    }
1493
1494    #[test]
1495    fn tokens_to_embedding_normalized() {
1496        let tokens: Vec<String> = vec!["test".into()];
1497        let v = tokens_to_embedding(&tokens, 128);
1498        let mag = vec_magnitude(&v);
1499        assert!((mag - 1.0).abs() < 0.01);
1500    }
1501
1502    // =====================================================================
1503    // WordPiece tokenizer tests
1504    // =====================================================================
1505
1506    /// Create a small test vocab file for WordPiece tests.
1507    /// Returns the path to the written file.
1508    fn make_test_vocab() -> PathBuf {
1509        use std::fmt::Write as FmtWrite;
1510        let mut content = String::new();
1511        // Build a minimal BERT-style vocab (needs >1000 entries).
1512        // IDs 0-99: [unused0]..[unused99]
1513        for i in 0..100 {
1514            writeln!(content, "[unused{}]", i).unwrap();
1515        }
1516        writeln!(content, "[UNK]").unwrap();   // ID 100
1517        writeln!(content, "[CLS]").unwrap();   // ID 101
1518        writeln!(content, "[SEP]").unwrap();   // ID 102
1519        writeln!(content, "[MASK]").unwrap();  // ID 103
1520        for i in 104..1000 {
1521            writeln!(content, "[unused{}]", i).unwrap();
1522        }
1523        // ID 1000+: real tokens
1524        let words = [
1525            "the", "a", "is", "of", "and", "to", "in", "for", "that", "it",
1526            "hello", "world", "test", "input", "embedding", "model", "token",
1527            "##s", "##ing", "##ed", "##er", "##tion", "##ly", "##ize",
1528            ".", ",", "!", "?",
1529            "quick", "brown", "fox", "dog", "cat", "rust", "code",
1530            "function", "struct", "pub", "async", "fn",
1531        ];
1532        for w in &words {
1533            writeln!(content, "{}", w).unwrap();
1534        }
1535        // Pad to >1000 entries total.
1536        for i in 0..100 {
1537            writeln!(content, "extra{}", i).unwrap();
1538        }
1539
1540        let path = PathBuf::from(format!(
1541            "/tmp/clawft_test_vocab_{}.txt",
1542            std::process::id()
1543        ));
1544        std::fs::write(&path, &content).expect("failed to write test vocab");
1545        path
1546    }
1547
1548    #[test]
1549    fn wordpiece_load_valid_vocab() {
1550        let f = make_test_vocab();
1551        let tok = WordPieceTokenizer::load(&f);
1552        assert!(tok.is_some(), "should load a vocab with >1000 entries");
1553    }
1554
1555    #[test]
1556    fn wordpiece_load_missing_file() {
1557        let tok = WordPieceTokenizer::load(Path::new("/nonexistent/vocab.txt"));
1558        assert!(tok.is_none());
1559    }
1560
1561    #[test]
1562    fn wordpiece_encode_produces_cls_sep() {
1563        let f = make_test_vocab();
1564        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(32);
1565        let (ids, mask, types) = tok.encode("hello world").unwrap();
1566        assert_eq!(ids.len(), 32, "should be padded to max_length");
1567        assert_eq!(ids[0], CLS_ID, "first token must be [CLS]");
1568        // Find [SEP] -- it should be after the content tokens.
1569        let sep_pos = ids.iter().position(|&x| x == SEP_ID);
1570        assert!(sep_pos.is_some(), "must contain [SEP]");
1571        let sep_pos = sep_pos.unwrap();
1572        assert!(sep_pos >= 2, "[SEP] should come after at least one content token");
1573        // Attention mask: 1s up to and including [SEP], then 0s.
1574        assert_eq!(mask[0], 1);
1575        assert_eq!(mask[sep_pos], 1);
1576        if sep_pos + 1 < 32 {
1577            assert_eq!(mask[sep_pos + 1], 0, "padding should have mask=0");
1578        }
1579        // Token type IDs should all be 0 for single-sentence input.
1580        assert!(types.iter().all(|&t| t == 0));
1581    }
1582
1583    #[test]
1584    fn wordpiece_encode_known_tokens() {
1585        let f = make_test_vocab();
1586        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
1587        let (ids, _, _) = tok.encode("hello").unwrap();
1588        // "hello" is in our test vocab -- should NOT be [UNK].
1589        let content_ids: Vec<i64> = ids[1..].iter()
1590            .take_while(|&&x| x != SEP_ID)
1591            .cloned()
1592            .collect();
1593        assert!(!content_ids.is_empty(), "should tokenize 'hello' to at least one token");
1594        assert!(
1595            content_ids.iter().any(|&id| id != UNK_ID),
1596            "known word 'hello' should not be all [UNK]"
1597        );
1598    }
1599
1600    #[test]
1601    fn wordpiece_encode_unknown_token_uses_unk() {
1602        let f = make_test_vocab();
1603        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
1604        let (ids, _, _) = tok.encode("xyzzyplugh").unwrap();
1605        // "xyzzyplugh" is not in our vocab, so it should produce [UNK].
1606        let content_ids: Vec<i64> = ids[1..].iter()
1607            .take_while(|&&x| x != SEP_ID)
1608            .cloned()
1609            .collect();
1610        assert!(
1611            content_ids.contains(&UNK_ID),
1612            "unknown word should produce [UNK] token"
1613        );
1614    }
1615
1616    #[test]
1617    fn wordpiece_encode_truncates_long_input() {
1618        let f = make_test_vocab();
1619        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(8);
1620        // Input much longer than max_length=8.
1621        let long_input = "the quick brown fox hello world test input embedding model";
1622        let (ids, mask, _) = tok.encode(long_input).unwrap();
1623        assert_eq!(ids.len(), 8, "output must be exactly max_length");
1624        assert_eq!(mask.len(), 8);
1625        assert_eq!(ids[0], CLS_ID);
1626        // [SEP] must be present.
1627        assert!(ids.contains(&SEP_ID));
1628    }
1629
1630    #[test]
1631    fn wordpiece_encode_empty_input() {
1632        let f = make_test_vocab();
1633        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(16);
1634        let (ids, mask, _) = tok.encode("").unwrap();
1635        assert_eq!(ids[0], CLS_ID);
1636        assert_eq!(ids[1], SEP_ID);
1637        // Rest should be padding.
1638        assert!(ids[2..].iter().all(|&x| x == PAD_ID));
1639        assert_eq!(mask[0], 1);
1640        assert_eq!(mask[1], 1);
1641        assert!(mask[2..].iter().all(|&x| x == 0));
1642    }
1643
1644    #[test]
1645    fn wordpiece_pre_tokenize_punctuation() {
1646        let f = make_test_vocab();
1647        let tok = WordPieceTokenizer::load(&f).unwrap();
1648        let words = tok.pre_tokenize("Hello, World!");
1649        // Should split into: ["hello", ",", "world", "!"]
1650        assert!(words.contains(&",".to_string()));
1651        assert!(words.contains(&"!".to_string()));
1652        assert!(words.contains(&"hello".to_string()));
1653        assert!(words.contains(&"world".to_string()));
1654    }
1655
1656    #[test]
1657    fn wordpiece_subword_splitting() {
1658        let f = make_test_vocab();
1659        let tok = WordPieceTokenizer::load(&f).unwrap();
1660        // "tokens" should split into "token" + "##s" since both are in vocab.
1661        let ids = tok.wordpiece_split("tokens");
1662        // If "token" and "##s" are in the vocab, we should get 2 IDs (neither UNK).
1663        assert!(
1664            ids.len() >= 1,
1665            "should produce at least one subword token"
1666        );
1667    }
1668
1669    #[test]
1670    fn wordpiece_deterministic() {
1671        let f = make_test_vocab();
1672        let tok = WordPieceTokenizer::load(&f).unwrap().with_max_length(32);
1673        let (ids1, _, _) = tok.encode("the quick brown fox").unwrap();
1674        let (ids2, _, _) = tok.encode("the quick brown fox").unwrap();
1675        assert_eq!(ids1, ids2, "encoding must be deterministic");
1676    }
1677
1678    #[test]
1679    fn vocab_search_paths_finds_sibling() {
1680        let paths = vocab_search_paths(Path::new("/models/all-MiniLM-L6-v2.onnx"));
1681        assert!(paths.iter().any(|p| p.ends_with("vocab.txt")));
1682        assert!(
1683            paths.iter().any(|p| p.to_string_lossy().contains("all-MiniLM-L6-v2/vocab.txt")),
1684            "should check sibling directory: {:?}",
1685            paths
1686        );
1687    }
1688}