//! Semantic embedding model management (`bones_search/semantic/model.rs`).
//!
//! Handles caching, bundled extraction, optional one-time download, and
//! ONNX Runtime inference for the MiniLM-L6-v2 int8 embedding model.
1use anyhow::{Context, Result, anyhow, bail};
2use sha2::{Digest, Sha256};
3use std::fs;
4use std::path::{Path, PathBuf};
5
6#[cfg(feature = "semantic-ort")]
7use ort::{session::Session, value::Tensor};
8#[cfg(feature = "semantic-ort")]
9use std::io::Write;
10#[cfg(feature = "semantic-ort")]
11use std::sync::Mutex;
12#[cfg(feature = "semantic-ort")]
13use std::sync::atomic::{AtomicBool, Ordering};
14#[cfg(feature = "semantic-ort")]
15use std::time::Duration;
16#[cfg(feature = "semantic-ort")]
17use tokenizers::Tokenizer;
18
/// File name of the quantized MiniLM embedding model inside the cache directory.
const MODEL_FILENAME: &str = "minilm-l6-v2-int8.onnx";
/// File name of the cached tokenizer definition, stored next to the model.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_FILENAME: &str = "minilm-l6-v2-tokenizer.json";
/// Maximum number of tokens fed to the model per text; longer inputs are truncated.
#[cfg(feature = "semantic-ort")]
const MAX_TOKENS: usize = 256;
/// Environment variable overriding the model download URL.
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_MODEL_URL";
/// Environment variable overriding the tokenizer download URL.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_TOKENIZER_URL";
/// Environment variable gating automatic downloads ("0"/"false"/"no"/"off" disable).
#[cfg(feature = "semantic-ort")]
const AUTO_DOWNLOAD_ENV: &str = "BONES_SEMANTIC_AUTO_DOWNLOAD";
/// Default download URL for the quantized ONNX model.
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";
/// Default download URL for the tokenizer JSON.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/tokenizer.json";
/// TCP connect timeout for asset downloads, in seconds.
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_CONNECT_TIMEOUT_SECS: u64 = 2;
/// Read timeout for asset downloads, in seconds.
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_READ_TIMEOUT_SECS: u64 = 30;

/// One-shot guard so a model download is attempted at most once per process.
#[cfg(feature = "semantic-ort")]
static MODEL_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);
/// One-shot guard so a tokenizer download is attempted at most once per process.
#[cfg(feature = "semantic-ort")]
static TOKENIZER_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);

/// Model bytes baked into the binary when the `bundled-model` feature is on.
#[cfg(feature = "bundled-model")]
const BUNDLED_MODEL_BYTES: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/models/minilm-l6-v2-int8.onnx"
));
51
/// Wrapper around an ONNX Runtime session for the embedding model.
pub struct SemanticModel {
    // The session sits behind a Mutex because inference takes the session
    // mutably (see `run_model_batch`) while `SemanticModel` is shared by `&self`.
    #[cfg(feature = "semantic-ort")]
    session: Mutex<Session>,
    // Tokenizer matching the model's vocabulary; loaded from the cached JSON.
    #[cfg(feature = "semantic-ort")]
    tokenizer: Tokenizer,
}
59
/// Token ids plus matching attention mask for one encoded input text.
#[cfg(feature = "semantic-ort")]
struct EncodedText {
    input_ids: Vec<i64>,
    attention_mask: Vec<i64>,
}

/// Which flattened buffer feeds a given model input.
///
/// Resolved per input by name, with a positional fallback (see `input_source`).
#[cfg(feature = "semantic-ort")]
enum InputSource {
    InputIds,
    AttentionMask,
    TokenTypeIds,
}
72
impl SemanticModel {
    /// Load the model from the OS cache directory.
    ///
    /// If the cache file is missing or has a mismatched SHA256, this extracts
    /// bundled model bytes into the cache directory before creating an ORT
    /// session. When no bundled bytes are available, this attempts a one-time
    /// download of model/tokenizer assets from stable URLs.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded or the cache is invalid.
    pub fn load() -> Result<Self> {
        let path = Self::model_cache_path()?;
        Self::ensure_model_cached(&path)?;

        #[cfg(feature = "semantic-ort")]
        {
            let tokenizer_path = Self::tokenizer_cache_path()?;
            Self::ensure_tokenizer_cached(&tokenizer_path)?;

            // `Tokenizer::from_file` reports a boxed error without the path,
            // so wrap it with the offending file for diagnostics.
            let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
                anyhow!(
                    "failed to load semantic tokenizer from {}: {e}",
                    tokenizer_path.display()
                )
            })?;

            let session = Session::builder()
                .context("failed to create ONNX Runtime session builder")?
                .commit_from_file(&path)
                .with_context(|| {
                    format!("failed to load semantic model from {}", path.display())
                })?;

            Ok(Self {
                session: Mutex::new(session),
                tokenizer,
            })
        }

        #[cfg(not(feature = "semantic-ort"))]
        {
            // `path` was used only for cache validation above; silence the
            // unused-variable warning in the no-runtime build.
            let _ = path;
            bail!("semantic runtime unavailable: compile bones-search with `semantic-ort`");
        }
    }

    /// Return the OS-appropriate cache path for the model file.
    ///
    /// Uses `dirs::cache_dir() / bones / models / minilm-l6-v2-int8.onnx`.
    ///
    /// # Errors
    ///
    /// Returns an error if the OS cache directory cannot be determined.
    pub fn model_cache_path() -> Result<PathBuf> {
        Ok(Self::model_cache_root()?.join(MODEL_FILENAME))
    }

    /// Cache path for the tokenizer JSON, stored next to the model file.
    #[cfg(feature = "semantic-ort")]
    fn tokenizer_cache_path() -> Result<PathBuf> {
        Ok(Self::model_cache_root()?.join(TOKENIZER_FILENAME))
    }

    /// Shared cache root: `<OS cache dir>/bones/models`.
    fn model_cache_root() -> Result<PathBuf> {
        let mut path = dirs::cache_dir().context("unable to determine OS cache directory")?;
        path.push("bones");
        path.push("models");
        Ok(path)
    }

    /// Check if cached model matches expected SHA256.
    ///
    /// When no bundled reference digest exists (no `bundled-model` feature),
    /// any existing file at `path` is accepted — there is nothing to compare
    /// against.
    #[must_use]
    pub fn is_cached_valid(path: &Path) -> bool {
        let expected_sha256 = expected_model_sha256();
        if expected_sha256.is_none() {
            return path.is_file();
        }

        let Ok(contents) = fs::read(path) else {
            return false;
        };

        expected_sha256.is_some_and(|sha256| sha256_hex(&contents) == sha256)
    }

    /// Extract bundled model bytes to cache directory.
    ///
    /// # Errors
    ///
    /// Returns an error if no bundled bytes are available, or if the file
    /// system operations fail, or if the extracted model fails SHA256 verification.
    pub fn extract_to_cache(path: &Path) -> Result<()> {
        let bundled = bundled_model_bytes().ok_or_else(|| {
            anyhow!(
                "semantic model bytes are not bundled; enable `bundled-model` with a packaged ONNX file"
            )
        })?;

        let parent = path.parent().with_context(|| {
            format!(
                "model cache path '{}' has no parent directory",
                path.display()
            )
        })?;
        fs::create_dir_all(parent).with_context(|| {
            format!(
                "failed to create semantic model cache directory {}",
                parent.display()
            )
        })?;

        // Write to a sibling temp file, then rename into place, so readers
        // never observe a partially written model.
        let temp_path = parent.join(format!("{MODEL_FILENAME}.tmp"));
        fs::write(&temp_path, bundled)
            .with_context(|| format!("failed to write bundled model to {}", temp_path.display()))?;

        // Remove any existing destination first; renaming over an existing
        // file is not portable across platforms.
        if path.exists() {
            fs::remove_file(path)
                .with_context(|| format!("failed to replace existing model {}", path.display()))?;
        }

        fs::rename(&temp_path, path).with_context(|| {
            format!(
                "failed to move extracted model from {} to {}",
                temp_path.display(),
                path.display()
            )
        })?;

        // Re-verify the on-disk copy against the bundled digest.
        if !Self::is_cached_valid(path) {
            bail!(
                "extracted semantic model at {} failed SHA256 verification",
                path.display()
            );
        }

        Ok(())
    }

    /// Ensure a valid model file exists at `path`, extracting bundled bytes
    /// or attempting a one-time download when it is missing or invalid.
    fn ensure_model_cached(path: &Path) -> Result<()> {
        if Self::is_cached_valid(path) {
            return Ok(());
        }

        // Bundled bytes take priority over any network download.
        if bundled_model_bytes().is_some() {
            Self::extract_to_cache(path)?;
            return Ok(());
        }

        #[cfg(feature = "semantic-ort")]
        {
            if !auto_download_enabled() {
                bail!(
                    "semantic model not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
                    path.display()
                );
            }

            // `swap` makes this a once-per-process attempt, even across threads.
            if MODEL_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
                bail!(
                    "semantic model not found at {} and auto-download was already attempted in this process",
                    path.display()
                );
            }

            download_to_path(&model_download_url(), path, "semantic model")
                .with_context(|| format!("failed to fetch semantic model to {}", path.display()))?;

            if !Self::is_cached_valid(path) {
                bail!(
                    "downloaded semantic model at {} failed validation",
                    path.display()
                );
            }

            Ok(())
        }

        #[cfg(not(feature = "semantic-ort"))]
        {
            bail!(
                "semantic model not found at {}; enable `bundled-model` or place `{MODEL_FILENAME}` in the cache path",
                path.display()
            );
        }
    }

    /// Ensure the tokenizer JSON exists at `path`, downloading it at most
    /// once per process when missing.
    #[cfg(feature = "semantic-ort")]
    fn ensure_tokenizer_cached(path: &Path) -> Result<()> {
        // No digest is available for the tokenizer, so presence is the only check.
        if path.is_file() {
            return Ok(());
        }

        if !auto_download_enabled() {
            bail!(
                "semantic tokenizer not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
                path.display()
            );
        }

        // `swap` makes this a once-per-process attempt, even across threads.
        if TOKENIZER_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
            bail!(
                "semantic tokenizer not found at {} and auto-download was already attempted in this process",
                path.display()
            );
        }

        download_to_path(&tokenizer_download_url(), path, "semantic tokenizer")
            .with_context(|| format!("failed to fetch semantic tokenizer to {}", path.display()))?;

        if !path.is_file() {
            bail!("semantic tokenizer download completed but file was not created");
        }

        Ok(())
    }

    /// Run inference for a single text input.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime is unavailable or inference fails.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        #[cfg(feature = "semantic-ort")]
        {
            // Delegate to the batch path with a batch of one.
            let encoded = self.encode_text(text)?;
            let mut out = self.run_model_batch(&[encoded])?;
            out.pop()
                .ok_or_else(|| anyhow!("semantic model returned no embedding"))
        }

        #[cfg(not(feature = "semantic-ort"))]
        {
            // Silence unused-variable warnings in the no-runtime build.
            let _ = self;
            let _ = text;
            bail!("semantic runtime unavailable: compile bones-search with `semantic-ort`");
        }
    }

    /// Batch inference for efficiency.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime is unavailable or batch inference fails.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        #[cfg(feature = "semantic-ort")]
        {
            // Fail the whole batch on the first tokenization error.
            let encoded: Vec<EncodedText> = texts
                .iter()
                .map(|text| self.encode_text(text))
                .collect::<Result<Vec<_>>>()?;
            self.run_model_batch(&encoded)
        }

        #[cfg(not(feature = "semantic-ort"))]
        {
            // Silence unused-variable warnings in the no-runtime build.
            let _ = self;
            let _ = texts;
            bail!("semantic runtime unavailable: compile bones-search with `semantic-ort`");
        }
    }

    /// Tokenize `text` and truncate to `MAX_TOKENS`, returning ids and mask.
    #[cfg(feature = "semantic-ort")]
    fn encode_text(&self, text: &str) -> Result<EncodedText> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow!("failed to tokenize semantic query: {e}"))?;

        let ids = encoding.get_ids();
        if ids.is_empty() {
            bail!("semantic tokenizer produced zero tokens");
        }

        let attention = encoding.get_attention_mask();
        // Truncate to the model's context budget.
        let keep = ids.len().min(MAX_TOKENS);

        let mut input_ids = Vec::with_capacity(keep);
        let mut attention_mask = Vec::with_capacity(keep);
        for (idx, id) in ids.iter().enumerate().take(keep) {
            input_ids.push(i64::from(*id));
            // Missing mask entries default to 1 (attend).
            attention_mask.push(i64::from(*attention.get(idx).unwrap_or(&1_u32)));
        }
        // An all-zero mask would make mean pooling degenerate downstream,
        // so fall back to attending to every kept token.
        if attention_mask.iter().all(|v| *v == 0) {
            attention_mask.fill(1);
        }

        Ok(EncodedText {
            input_ids,
            attention_mask,
        })
    }

    /// Run the ONNX session on a batch and return one embedding per row.
    #[cfg(feature = "semantic-ort")]
    #[allow(clippy::significant_drop_tightening, clippy::cast_precision_loss)]
    fn run_model_batch(&self, encoded: &[EncodedText]) -> Result<Vec<Vec<f32>>> {
        if encoded.is_empty() {
            return Ok(Vec::new());
        }

        // Pad every row to the longest sequence in the batch.
        let batch = encoded.len();
        let seq_len = encoded.iter().map(|e| e.input_ids.len()).max().unwrap_or(0);
        if seq_len == 0 {
            bail!("semantic batch has no tokens");
        }

        // Zero padding: id 0 with attention 0 contributes nothing to pooling.
        let mut flat_ids = vec![0_i64; batch * seq_len];
        let mut flat_attention = vec![0_i64; batch * seq_len];
        for (row_idx, row) in encoded.iter().enumerate() {
            let row_base = row_idx * seq_len;
            flat_ids[row_base..(row.input_ids.len() + row_base)].copy_from_slice(&row.input_ids);
            flat_attention[row_base..(row.attention_mask.len() + row_base)]
                .copy_from_slice(&row.attention_mask);
        }
        // BERT-style models expect token_type_ids; all zeros for single-segment input.
        let flat_token_types = vec![0_i64; batch * seq_len];

        let mut session = self
            .session
            .lock()
            .map_err(|_| anyhow!("semantic model session mutex poisoned"))?;

        // Feed each declared model input from the matching flattened buffer;
        // unknown names fall back to positional order (see `input_source`).
        let model_inputs = session.inputs();
        let mut inputs: Vec<(String, Tensor<i64>)> = Vec::with_capacity(model_inputs.len());
        for (index, input) in model_inputs.iter().enumerate() {
            let input_name = input.name();
            let source = input_source(index, input_name);
            let data = match source {
                InputSource::InputIds => flat_ids.clone(),
                InputSource::AttentionMask => flat_attention.clone(),
                InputSource::TokenTypeIds => flat_token_types.clone(),
            };
            let tensor = Tensor::<i64>::from_array(([batch, seq_len], data.into_boxed_slice()))
                .with_context(|| format!("failed to build ONNX input tensor '{input_name}'"))?;
            inputs.push((input_name.to_string(), tensor));
        }

        let outputs = session
            .run(inputs)
            .context("failed to run ONNX semantic inference")?;

        if outputs.len() == 0 {
            bail!("semantic model returned no outputs");
        }

        // Prefer an already-pooled sentence embedding when the model exposes
        // one; otherwise fall back to token-level outputs or output 0.
        let output = outputs
            .get("sentence_embedding")
            .or_else(|| outputs.get("last_hidden_state"))
            .or_else(|| outputs.get("token_embeddings"))
            .unwrap_or(&outputs[0]);

        let (shape, data) = output.try_extract_tensor::<f32>().with_context(
            || "semantic model output tensor is not f32; expected sentence embedding tensor",
        )?;

        decode_embeddings(shape, data, &flat_attention, batch, seq_len)
    }
}
429
/// Map a model input to the buffer that should feed it.
///
/// Matching by (lowercased) name takes precedence; inputs with unrecognized
/// names fall back to the conventional positional order
/// `input_ids, attention_mask, token_type_ids`.
#[cfg(feature = "semantic-ort")]
fn input_source(index: usize, input_name: &str) -> InputSource {
    let lowered = input_name.to_ascii_lowercase();
    if lowered.contains("attention") {
        InputSource::AttentionMask
    } else if lowered.contains("token_type") || lowered.contains("segment") {
        InputSource::TokenTypeIds
    } else if lowered.contains("input_ids") || (lowered.contains("input") && lowered.contains("id"))
    {
        InputSource::InputIds
    } else if index == 0 {
        InputSource::InputIds
    } else if index == 1 {
        InputSource::AttentionMask
    } else {
        InputSource::TokenTypeIds
    }
}
449
/// Convert a raw ONNX output tensor into one L2-normalized embedding per row.
///
/// Supported output layouts:
/// - rank 2 `[batch, hidden]`: already-pooled sentence embeddings,
/// - rank 3 `[batch, tokens, hidden]`: token embeddings, mean-pooled using
///   `flat_attention` as weights,
/// - rank 1 `[hidden]`: single-row fallback.
///
/// # Errors
///
/// Returns an error when the rank is unsupported, the shape disagrees with
/// `batch`, or `data`/`flat_attention` are too short for the declared shape
/// (previously these cases panicked on out-of-bounds slicing).
#[cfg(feature = "semantic-ort")]
#[allow(clippy::cast_precision_loss)]
fn decode_embeddings(
    shape: &[i64],
    data: &[f32],
    flat_attention: &[i64],
    batch: usize,
    seq_len: usize,
) -> Result<Vec<Vec<f32>>> {
    match shape.len() {
        // [batch, hidden]
        2 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let hidden = usize::try_from(shape[1]).unwrap_or(0);
            if out_batch == 0 || hidden == 0 {
                bail!("invalid sentence embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }
            // Guard the row slicing below; a short buffer must not panic.
            let needed = out_batch * hidden;
            if data.len() < needed {
                bail!(
                    "semantic output has {} values but shape {shape:?} requires {needed}",
                    data.len()
                );
            }

            let mut out = Vec::with_capacity(out_batch);
            for row in 0..out_batch {
                let start = row * hidden;
                let end = start + hidden;
                let mut emb = data[start..end].to_vec();
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [batch, tokens, hidden] -> mean pool with attention mask.
        3 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let out_tokens = usize::try_from(shape[1]).unwrap_or(0);
            let hidden = usize::try_from(shape[2]).unwrap_or(0);
            if out_batch == 0 || out_tokens == 0 || hidden == 0 {
                bail!("invalid token embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }
            // Guard both indexed buffers before the pooling loops.
            let needed = out_batch * out_tokens * hidden;
            if data.len() < needed {
                bail!(
                    "semantic output has {} values but shape {shape:?} requires {needed}",
                    data.len()
                );
            }
            if flat_attention.len() < batch * seq_len {
                bail!(
                    "semantic attention buffer has {} entries but needs {}",
                    flat_attention.len(),
                    batch * seq_len
                );
            }

            let mut out = Vec::with_capacity(out_batch);
            for b in 0..out_batch {
                let mut emb = vec![0.0_f32; hidden];
                let mut weight_sum = 0.0_f32;

                for t in 0..out_tokens {
                    // Tokens beyond the padded sequence carry zero weight.
                    let mask_weight = if t < seq_len {
                        flat_attention[b * seq_len + t] as f32
                    } else {
                        0.0
                    };
                    if mask_weight <= 0.0 {
                        continue;
                    }

                    let token_base = (b * out_tokens + t) * hidden;
                    for h in 0..hidden {
                        emb[h] += data[token_base + h] * mask_weight;
                    }
                    weight_sum += mask_weight;
                }

                if weight_sum > 0.0 {
                    for value in &mut emb {
                        *value /= weight_sum;
                    }
                }
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [hidden] (single-row fallback)
        1 => {
            if batch != 1 {
                bail!("rank-1 semantic output only supported for single-row batch");
            }
            let hidden = usize::try_from(shape[0]).unwrap_or(0);
            if hidden == 0 {
                bail!("invalid rank-1 semantic output shape {shape:?}");
            }
            // Guard the slice below against a buffer shorter than the shape claims.
            if data.len() < hidden {
                bail!(
                    "semantic output has {} values but shape {shape:?} requires {hidden}",
                    data.len()
                );
            }
            let mut emb = data[0..hidden].to_vec();
            normalize_l2(&mut emb);
            Ok(vec![emb])
        }

        rank => bail!("unsupported semantic output rank {rank}: shape {shape:?}"),
    }
}
544
/// Scale `values` in place to unit L2 norm; an all-zero vector is left untouched.
#[cfg(feature = "semantic-ort")]
fn normalize_l2(values: &mut [f32]) {
    let norm_sq: f32 = values.iter().map(|v| v * v).sum();
    if norm_sq == 0.0 {
        return;
    }

    let norm = norm_sq.sqrt();
    values.iter_mut().for_each(|v| *v /= norm);
}
561
/// Check if semantic search is currently available.
///
/// NOTE: this performs a full `SemanticModel::load()` — cache validation,
/// possible bundled extraction, and (with `semantic-ort`) a possible one-time
/// asset download — so the first call can be slow; cache the result where
/// latency matters.
#[must_use]
pub fn is_semantic_available() -> bool {
    SemanticModel::load().is_ok()
}
567
/// Return the compiled-in model bytes, or `None` when not bundled.
///
/// An empty `include_bytes!` payload (placeholder file) is treated the same
/// as the `bundled-model` feature being disabled.
const fn bundled_model_bytes() -> Option<&'static [u8]> {
    #[cfg(feature = "bundled-model")]
    {
        if BUNDLED_MODEL_BYTES.is_empty() {
            return None;
        }

        return Some(BUNDLED_MODEL_BYTES);
    }

    #[cfg(not(feature = "bundled-model"))]
    {
        None
    }
}
583
584fn expected_model_sha256() -> Option<String> {
585    bundled_model_bytes().map(sha256_hex)
586}
587
588fn sha256_hex(bytes: &[u8]) -> String {
589    let mut hasher = Sha256::new();
590    hasher.update(bytes);
591    format!("{:x}", hasher.finalize())
592}
593
/// Whether automatic asset downloads are permitted.
///
/// Enabled unless `BONES_SEMANTIC_AUTO_DOWNLOAD` is set to one of
/// `0`, `false`, `no`, or `off` (case-insensitive, surrounding whitespace
/// ignored). An unset or unreadable variable means enabled.
#[cfg(feature = "semantic-ort")]
fn auto_download_enabled() -> bool {
    match std::env::var(AUTO_DOWNLOAD_ENV) {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            !matches!(normalized.as_str(), "0" | "false" | "no" | "off")
        }
        Err(_) => true,
    }
}
603
/// URL to fetch the model from: the env override when set and non-blank,
/// otherwise the default Hugging Face URL.
#[cfg(feature = "semantic-ort")]
fn model_download_url() -> String {
    match std::env::var(MODEL_DOWNLOAD_URL_ENV) {
        Ok(value) if !value.trim().is_empty() => value,
        _ => MODEL_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
611
/// URL to fetch the tokenizer from: the env override when set and non-blank,
/// otherwise the default Hugging Face URL.
#[cfg(feature = "semantic-ort")]
fn tokenizer_download_url() -> String {
    match std::env::var(TOKENIZER_DOWNLOAD_URL_ENV) {
        Ok(value) if !value.trim().is_empty() => value,
        _ => TOKENIZER_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
619
/// Download `url` to `path`, writing through a sibling temp file and renaming
/// into place so readers never observe a partial file.
///
/// `artifact_label` appears only in error messages (e.g. "semantic model").
///
/// # Errors
///
/// Returns an error if the cache directory cannot be created, the HTTP
/// request fails or returns an error status, or any filesystem step fails.
#[cfg(feature = "semantic-ort")]
fn download_to_path(url: &str, path: &Path, artifact_label: &str) -> Result<()> {
    let parent = path.parent().with_context(|| {
        format!(
            "{artifact_label} cache path '{}' has no parent directory",
            path.display()
        )
    })?;
    fs::create_dir_all(parent).with_context(|| {
        format!(
            "failed to create {} cache directory {}",
            artifact_label,
            parent.display()
        )
    })?;

    // Temp file lives in the same directory so the final rename stays on one
    // filesystem (cross-device renames can fail).
    let temp_path = parent.join(format!(
        "{}.download",
        path.file_name().unwrap_or_default().to_string_lossy()
    ));

    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(DOWNLOAD_CONNECT_TIMEOUT_SECS))
        .timeout_read(Duration::from_secs(DOWNLOAD_READ_TIMEOUT_SECS))
        .build();

    let response = match agent
        .get(url)
        .set("User-Agent", "bones-search/semantic-downloader")
        .call()
    {
        Ok(resp) => resp,
        // HTTP error statuses and transport failures get distinct messages.
        Err(ureq::Error::Status(code, _)) => {
            bail!("{artifact_label} download failed: HTTP {code} from {url}")
        }
        Err(ureq::Error::Transport(err)) => {
            bail!("{artifact_label} download failed from {url}: {err}")
        }
    };

    // Scope the file handle so it is closed before the rename below.
    {
        let mut reader = response.into_reader();
        let mut out = fs::File::create(&temp_path)
            .with_context(|| format!("failed to create temporary file {}", temp_path.display()))?;
        std::io::copy(&mut reader, &mut out)
            .with_context(|| format!("failed to write {artifact_label} download"))?;
        // Flush explicitly: Drop would swallow any flush error.
        out.flush()
            .with_context(|| format!("failed to flush {artifact_label} download"))?;
    }

    // Remove any existing destination first; renaming over an existing file
    // is not portable across platforms.
    if path.exists() {
        fs::remove_file(path).with_context(|| {
            format!(
                "failed to replace existing {} at {}",
                artifact_label,
                path.display()
            )
        })?;
    }

    fs::rename(&temp_path, path).with_context(|| {
        format!(
            "failed to move downloaded {} from {} to {}",
            artifact_label,
            temp_path.display(),
            path.display()
        )
    })?;

    Ok(())
}
691
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    // The cache root is OS-dependent, so only the trailing path components
    // are pinned here.
    #[test]
    fn cache_path_uses_expected_suffix() {
        let path = SemanticModel::model_cache_path().expect("cache path should resolve");
        let expected = Path::new("bones")
            .join("models")
            .join("minilm-l6-v2-int8.onnx");
        assert!(path.ends_with(expected));
    }

    // Standard SHA-256 test vector for the input "abc".
    #[test]
    fn sha256_hex_matches_known_vector() {
        let digest = sha256_hex(b"abc");
        assert_eq!(
            digest,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // Without a bundled reference digest, any existing file must be accepted.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn cached_model_is_accepted_when_not_bundled() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");
        fs::write(&model, b"anything").expect("test file should be writable");

        assert!(SemanticModel::is_cached_valid(&model));
    }

    // Extraction must fail loudly when no bytes were compiled in.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn extract_to_cache_fails_without_bundled_model() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");

        let err =
            SemanticModel::extract_to_cache(&model).expect_err("should fail without bundled model");
        assert!(err.to_string().contains("not bundled"));
    }

    #[cfg(not(feature = "semantic-ort"))]
    #[test]
    fn semantic_is_reported_unavailable_without_runtime_feature() {
        assert!(!is_semantic_available());
    }

    #[cfg(feature = "semantic-ort")]
    #[test]
    fn normalize_l2_produces_unit_norm() {
        let mut emb = vec![3.0_f32, 4.0_f32, 0.0_f32];
        normalize_l2(&mut emb);
        let norm = emb.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // Name-based matching must win over the positional fallback: index 5
    // would otherwise map everything to TokenTypeIds.
    #[cfg(feature = "semantic-ort")]
    #[test]
    fn input_source_prefers_named_fields() {
        assert!(matches!(
            input_source(5, "attention_mask"),
            InputSource::AttentionMask
        ));
        assert!(matches!(
            input_source(5, "token_type_ids"),
            InputSource::TokenTypeIds
        ));
        assert!(matches!(
            input_source(5, "input_ids"),
            InputSource::InputIds
        ));
    }
}