//! bones_search/semantic/model.rs — embedding model backends (ORT, model2vec, hash fallback).
1use anyhow::{Context, Result, anyhow, bail};
2use sha2::{Digest, Sha256};
3use std::fs;
4use std::path::{Path, PathBuf};
5
6#[cfg(feature = "semantic-ort")]
7use ort::{session::Session, value::Tensor};
8#[cfg(feature = "semantic-ort")]
9use std::io::Write;
10#[cfg(feature = "semantic-ort")]
11use std::sync::Mutex;
12#[cfg(feature = "semantic-ort")]
13use std::sync::atomic::{AtomicBool, Ordering};
14#[cfg(feature = "semantic-ort")]
15use std::time::Duration;
16#[cfg(feature = "semantic-ort")]
17use tokenizers::Tokenizer;
18
/// Cache file name for the quantized MiniLM embedding model.
const MODEL_FILENAME: &str = "minilm-l6-v2-int8.onnx";
/// Cache file name for the matching tokenizer definition.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_FILENAME: &str = "minilm-l6-v2-tokenizer.json";
/// Maximum number of tokens kept per encoded input (longer inputs are truncated).
#[cfg(feature = "semantic-ort")]
const MAX_TOKENS: usize = 256;
/// Env var that overrides the model download URL.
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_MODEL_URL";
/// Env var that overrides the tokenizer download URL.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_TOKENIZER_URL";
/// Env var that disables automatic downloads when set to 0/false/no/off.
#[cfg(feature = "semantic-ort")]
const AUTO_DOWNLOAD_ENV: &str = "BONES_SEMANTIC_AUTO_DOWNLOAD";
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/tokenizer.json";
// Short connect timeout (2 s); the read timeout is longer to allow the
// multi-megabyte model body to stream.
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_CONNECT_TIMEOUT_SECS: u64 = 2;
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_READ_TIMEOUT_SECS: u64 = 30;

// One-shot flags: `ensure_*_cached` swap these so each artifact is only
// download-attempted once per process.
#[cfg(feature = "semantic-ort")]
static MODEL_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);
#[cfg(feature = "semantic-ort")]
static TOKENIZER_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);

/// Model bytes compiled into the binary when `bundled-model` is enabled.
#[cfg(feature = "bundled-model")]
const BUNDLED_MODEL_BYTES: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/models/minilm-l6-v2-int8.onnx"
));
51
52use super::hash_embed::HashEmbedBackend;
53#[cfg(feature = "semantic-model2vec")]
54use super::model2vec::Model2VecBackend;
55
/// Wrapper around an embedding model backend.
///
/// Supports multiple backends selected at compile time via feature flags.
/// When multiple backends are available, ORT is preferred (higher quality);
/// model2vec is used as a fallback.
pub struct SemanticModel {
    // Chosen once in `SemanticModel::load`; never changes afterwards.
    inner: BackendInner,
}
64
/// The concrete embedding backend held by a `SemanticModel`.
#[allow(clippy::large_enum_variant)]
enum BackendInner {
    /// ONNX Runtime backend: the session is behind a `Mutex` so inference can
    /// take exclusive access, paired with the matching tokenizer.
    #[cfg(feature = "semantic-ort")]
    Ort {
        session: Mutex<Session>,
        tokenizer: Tokenizer,
    },
    /// model2vec backend (no ONNX runtime dependency).
    #[cfg(feature = "semantic-model2vec")]
    Model2Vec(Model2VecBackend),
    /// Zero-dependency hash-based embeddings.  Always available as an
    /// ultimate fallback when no ML backend is compiled in.
    Hash(HashEmbedBackend),
}
78
/// A single tokenized input, truncated to `MAX_TOKENS`, with both buffers
/// widened to `i64` as the ONNX model inputs expect.
#[cfg(feature = "semantic-ort")]
struct EncodedText {
    /// Token ids, one per kept token.
    input_ids: Vec<i64>,
    /// Attention mask aligned with `input_ids` (1 = attend, 0 = ignore).
    attention_mask: Vec<i64>,
}
84
/// Which flattened buffer feeds a given ONNX model input; resolved per input
/// by `input_source` from the input's name (or position as a fallback).
#[cfg(feature = "semantic-ort")]
enum InputSource {
    InputIds,
    AttentionMask,
    TokenTypeIds,
}
91
92impl SemanticModel {
93    /// Load the model from the OS cache directory.
94    ///
95    /// Tries backends in priority order: ORT (highest quality), then model2vec
96    /// (no ONNX dependency, Windows-friendly).
97    ///
98    /// # Errors
99    ///
100    /// Returns an error if no backend is available or loading fails.
101    pub fn load() -> Result<Self> {
102        // Try ORT first (higher quality embeddings).
103        #[cfg(feature = "semantic-ort")]
104        {
105            match Self::load_ort() {
106                Ok(model) => return Ok(model),
107                Err(err) => {
108                    tracing::debug!("ORT backend unavailable, trying next: {err:#}");
109                }
110            }
111        }
112
113        // Try model2vec (no ONNX runtime needed).
114        #[cfg(feature = "semantic-model2vec")]
115        {
116            match Model2VecBackend::load() {
117                Ok(backend) => {
118                    return Ok(Self {
119                        inner: BackendInner::Model2Vec(backend),
120                    });
121                }
122                Err(err) => {
123                    tracing::debug!("model2vec backend unavailable: {err:#}");
124                }
125            }
126        }
127
128        // Hash embedder is always available as a zero-dependency fallback.
129        tracing::debug!("using hash embedder (no ML backend available)");
130        Ok(Self {
131            inner: BackendInner::Hash(HashEmbedBackend::new()),
132        })
133    }
134
135    #[cfg(feature = "semantic-ort")]
136    fn load_ort() -> Result<Self> {
137        let path = Self::ort_model_cache_path()?;
138        Self::ensure_model_cached(&path)?;
139
140        let tokenizer_path = Self::tokenizer_cache_path()?;
141        Self::ensure_tokenizer_cached(&tokenizer_path)?;
142
143        let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
144            anyhow!(
145                "failed to load semantic tokenizer from {}: {e}",
146                tokenizer_path.display()
147            )
148        })?;
149
150        let session = Session::builder()
151            .context("failed to create ONNX Runtime session builder")?
152            .commit_from_file(&path)
153            .with_context(|| format!("failed to load semantic model from {}", path.display()))?;
154
155        Ok(Self {
156            inner: BackendInner::Ort {
157                session: Mutex::new(session),
158                tokenizer,
159            },
160        })
161    }
162
163    /// Return the OS-appropriate cache path for the ORT model file.
164    ///
165    /// Uses `dirs::cache_dir() / bones / models / minilm-l6-v2-int8.onnx`.
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if the OS cache directory cannot be determined.
170    pub fn model_cache_path() -> Result<PathBuf> {
171        Ok(Self::model_cache_root()?.join(MODEL_FILENAME))
172    }
173
174    #[cfg(feature = "semantic-ort")]
175    fn ort_model_cache_path() -> Result<PathBuf> {
176        Ok(Self::model_cache_root()?.join(MODEL_FILENAME))
177    }
178
179    #[cfg(feature = "semantic-ort")]
180    fn tokenizer_cache_path() -> Result<PathBuf> {
181        Ok(Self::model_cache_root()?.join(TOKENIZER_FILENAME))
182    }
183
184    /// Return the OS-appropriate cache root for model files.
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if the OS cache directory cannot be determined.
189    pub fn model_cache_root() -> Result<PathBuf> {
190        let mut path = dirs::cache_dir().context("unable to determine OS cache directory")?;
191        path.push("bones");
192        path.push("models");
193        Ok(path)
194    }
195
196    /// Check if cached model matches expected SHA256.
197    #[must_use]
198    pub fn is_cached_valid(path: &Path) -> bool {
199        let expected_sha256 = expected_model_sha256();
200        if expected_sha256.is_none() {
201            return path.is_file();
202        }
203
204        let Ok(contents) = fs::read(path) else {
205            return false;
206        };
207
208        expected_sha256.is_some_and(|sha256| sha256_hex(&contents) == sha256)
209    }
210
211    /// Extract bundled model bytes to cache directory.
212    ///
213    /// # Errors
214    ///
215    /// Returns an error if no bundled bytes are available, or if the file
216    /// system operations fail, or if the extracted model fails SHA256 verification.
217    pub fn extract_to_cache(path: &Path) -> Result<()> {
218        let bundled = bundled_model_bytes().ok_or_else(|| {
219            anyhow!(
220                "semantic model bytes are not bundled; enable `bundled-model` with a packaged ONNX file"
221            )
222        })?;
223
224        let parent = path.parent().with_context(|| {
225            format!(
226                "model cache path '{}' has no parent directory",
227                path.display()
228            )
229        })?;
230        fs::create_dir_all(parent).with_context(|| {
231            format!(
232                "failed to create semantic model cache directory {}",
233                parent.display()
234            )
235        })?;
236
237        let temp_path = parent.join(format!("{MODEL_FILENAME}.tmp"));
238        fs::write(&temp_path, bundled)
239            .with_context(|| format!("failed to write bundled model to {}", temp_path.display()))?;
240
241        if path.exists() {
242            fs::remove_file(path)
243                .with_context(|| format!("failed to replace existing model {}", path.display()))?;
244        }
245
246        fs::rename(&temp_path, path).with_context(|| {
247            format!(
248                "failed to move extracted model from {} to {}",
249                temp_path.display(),
250                path.display()
251            )
252        })?;
253
254        if !Self::is_cached_valid(path) {
255            bail!(
256                "extracted semantic model at {} failed SHA256 verification",
257                path.display()
258            );
259        }
260
261        Ok(())
262    }
263
    /// Ensure a valid model file exists at `path`, trying in order: an
    /// already-valid cached copy, bundled bytes, then (ORT builds only) a
    /// one-shot network download.
    #[cfg(any(feature = "semantic-ort", feature = "bundled-model"))]
    fn ensure_model_cached(path: &Path) -> Result<()> {
        if Self::is_cached_valid(path) {
            return Ok(());
        }

        // Bundled bytes take priority over any network access.
        if bundled_model_bytes().is_some() {
            Self::extract_to_cache(path)?;
            return Ok(());
        }

        #[cfg(feature = "semantic-ort")]
        {
            if !auto_download_enabled() {
                bail!(
                    "semantic model not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
                    path.display()
                );
            }

            // `swap` returns the previous value, so only the first caller in
            // this process proceeds to download; later callers fail fast.
            if MODEL_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
                bail!(
                    "semantic model not found at {} and auto-download was already attempted in this process",
                    path.display()
                );
            }

            download_to_path(&model_download_url(), path, "semantic model")
                .with_context(|| format!("failed to fetch semantic model to {}", path.display()))?;

            // Re-validate: a completed download may still be corrupt/truncated.
            if !Self::is_cached_valid(path) {
                bail!(
                    "downloaded semantic model at {} failed validation",
                    path.display()
                );
            }

            Ok(())
        }

        // Without the ORT feature there is no downloader to fall back on.
        #[cfg(not(feature = "semantic-ort"))]
        {
            bail!(
                "semantic model not found at {}; enable `bundled-model` or place `{MODEL_FILENAME}` in the cache path",
                path.display()
            );
        }
    }
312
313    #[cfg(feature = "semantic-ort")]
314    fn ensure_tokenizer_cached(path: &Path) -> Result<()> {
315        if path.is_file() {
316            return Ok(());
317        }
318
319        if !auto_download_enabled() {
320            bail!(
321                "semantic tokenizer not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
322                path.display()
323            );
324        }
325
326        if TOKENIZER_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
327            bail!(
328                "semantic tokenizer not found at {} and auto-download was already attempted in this process",
329                path.display()
330            );
331        }
332
333        download_to_path(&tokenizer_download_url(), path, "semantic tokenizer")
334            .with_context(|| format!("failed to fetch semantic tokenizer to {}", path.display()))?;
335
336        if !path.is_file() {
337            bail!("semantic tokenizer download completed but file was not created");
338        }
339
340        Ok(())
341    }
342
    /// The dimensionality of embedding vectors this model produces.
    ///
    /// The ORT backend is fixed at 384 (MiniLM-L6-v2); the other backends
    /// report their own sizes.
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn dimensions(&self) -> usize {
        match &self.inner {
            #[cfg(feature = "semantic-ort")]
            BackendInner::Ort { .. } => 384, // MiniLM-L6-v2
            #[cfg(feature = "semantic-model2vec")]
            BackendInner::Model2Vec(m) => m.dimensions(),
            BackendInner::Hash(h) => h.dimensions(),
        }
    }
355
    /// A stable identifier for the active backend, used to detect backend
    /// switches that require re-embedding stored vectors.
    ///
    /// These strings are persisted alongside stored embeddings, so changing
    /// one invalidates existing indexes built with it.
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn backend_id(&self) -> &'static str {
        match &self.inner {
            #[cfg(feature = "semantic-ort")]
            BackendInner::Ort { .. } => "ort-minilm-384",
            #[cfg(feature = "semantic-model2vec")]
            BackendInner::Model2Vec(_) => "model2vec-potion-8m",
            BackendInner::Hash(_) => "hash-ngram-256",
        }
    }
369
370    /// Run inference for a single text input.
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if the runtime is unavailable or inference fails.
375    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
376        match &self.inner {
377            #[cfg(feature = "semantic-ort")]
378            BackendInner::Ort { .. } => {
379                let encoded = self.encode_text(text)?;
380                let mut out = self.run_model_batch(&[encoded])?;
381                out.pop()
382                    .ok_or_else(|| anyhow!("semantic model returned no embedding"))
383            }
384            #[cfg(feature = "semantic-model2vec")]
385            BackendInner::Model2Vec(m) => m.embed(text),
386            BackendInner::Hash(h) => h.embed(text),
387        }
388    }
389
390    /// Batch inference for efficiency.
391    ///
392    /// # Errors
393    ///
394    /// Returns an error if the runtime is unavailable or batch inference fails.
395    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
396        match &self.inner {
397            #[cfg(feature = "semantic-ort")]
398            BackendInner::Ort { .. } => {
399                let encoded: Vec<EncodedText> = texts
400                    .iter()
401                    .map(|text| self.encode_text(text))
402                    .collect::<Result<Vec<_>>>()?;
403                self.run_model_batch(&encoded)
404            }
405            #[cfg(feature = "semantic-model2vec")]
406            BackendInner::Model2Vec(m) => m.embed_batch(texts),
407            BackendInner::Hash(h) => h.embed_batch(texts),
408        }
409    }
410
    /// Borrow the ORT tokenizer.  Internal helper only reached from ORT code
    /// paths, so landing on a non-ORT variant is a programming error.
    #[cfg(feature = "semantic-ort")]
    fn ort_tokenizer(&self) -> &Tokenizer {
        match &self.inner {
            BackendInner::Ort { tokenizer, .. } => tokenizer,
            #[cfg(feature = "semantic-model2vec")]
            BackendInner::Model2Vec(_) | BackendInner::Hash(_) => {
                unreachable!("encode_text called on non-ORT backend")
            }
            // Separate arm so the match stays exhaustive without model2vec.
            #[cfg(not(feature = "semantic-model2vec"))]
            BackendInner::Hash(_) => unreachable!("encode_text called on non-ORT backend"),
        }
    }
423
    /// Borrow the mutex guarding the ORT session.  Internal helper only
    /// reached from ORT code paths, so a non-ORT variant is unreachable.
    #[cfg(feature = "semantic-ort")]
    fn ort_session(&self) -> &Mutex<Session> {
        match &self.inner {
            BackendInner::Ort { session, .. } => session,
            #[cfg(feature = "semantic-model2vec")]
            BackendInner::Model2Vec(_) | BackendInner::Hash(_) => {
                unreachable!("run_model_batch called on non-ORT backend")
            }
            // Separate arm so the match stays exhaustive without model2vec.
            #[cfg(not(feature = "semantic-model2vec"))]
            BackendInner::Hash(_) => unreachable!("run_model_batch called on non-ORT backend"),
        }
    }
436
437    #[cfg(feature = "semantic-ort")]
438    fn encode_text(&self, text: &str) -> Result<EncodedText> {
439        let encoding = self
440            .ort_tokenizer()
441            .encode(text, true)
442            .map_err(|e| anyhow!("failed to tokenize semantic query: {e}"))?;
443
444        let ids = encoding.get_ids();
445        if ids.is_empty() {
446            bail!("semantic tokenizer produced zero tokens");
447        }
448
449        let attention = encoding.get_attention_mask();
450        let keep = ids.len().min(MAX_TOKENS);
451
452        let mut input_ids = Vec::with_capacity(keep);
453        let mut attention_mask = Vec::with_capacity(keep);
454        for (idx, id) in ids.iter().enumerate().take(keep) {
455            input_ids.push(i64::from(*id));
456            attention_mask.push(i64::from(*attention.get(idx).unwrap_or(&1_u32)));
457        }
458        if attention_mask.iter().all(|v| *v == 0) {
459            attention_mask.fill(1);
460        }
461
462        Ok(EncodedText {
463            input_ids,
464            attention_mask,
465        })
466    }
467
    /// Run one ONNX inference over the whole batch and decode per-row
    /// embeddings.  Rows shorter than the longest are zero-padded.
    #[cfg(feature = "semantic-ort")]
    #[allow(clippy::significant_drop_tightening, clippy::cast_precision_loss)]
    fn run_model_batch(&self, encoded: &[EncodedText]) -> Result<Vec<Vec<f32>>> {
        if encoded.is_empty() {
            return Ok(Vec::new());
        }

        // Pad every row to the longest sequence in the batch.
        let batch = encoded.len();
        let seq_len = encoded.iter().map(|e| e.input_ids.len()).max().unwrap_or(0);
        if seq_len == 0 {
            bail!("semantic batch has no tokens");
        }

        // Row-major [batch, seq_len] buffers; padding positions stay 0
        // (id 0 / mask 0, so padding is ignored by pooling).
        let mut flat_ids = vec![0_i64; batch * seq_len];
        let mut flat_attention = vec![0_i64; batch * seq_len];
        for (row_idx, row) in encoded.iter().enumerate() {
            let row_base = row_idx * seq_len;
            flat_ids[row_base..(row.input_ids.len() + row_base)].copy_from_slice(&row.input_ids);
            flat_attention[row_base..(row.attention_mask.len() + row_base)]
                .copy_from_slice(&row.attention_mask);
        }
        // Single-sentence inputs: all token-type ids are segment 0.
        let flat_token_types = vec![0_i64; batch * seq_len];

        let mut session = self
            .ort_session()
            .lock()
            .map_err(|_| anyhow!("semantic model session mutex poisoned"))?;

        // Feed each declared model input from the matching buffer, resolved by
        // input name (or position as a fallback) via `input_source`.
        let model_inputs = session.inputs();
        let mut inputs: Vec<(String, Tensor<i64>)> = Vec::with_capacity(model_inputs.len());
        for (index, input) in model_inputs.iter().enumerate() {
            let input_name = input.name();
            let source = input_source(index, input_name);
            let data = match source {
                InputSource::InputIds => flat_ids.clone(),
                InputSource::AttentionMask => flat_attention.clone(),
                InputSource::TokenTypeIds => flat_token_types.clone(),
            };
            let tensor = Tensor::<i64>::from_array(([batch, seq_len], data.into_boxed_slice()))
                .with_context(|| format!("failed to build ONNX input tensor '{input_name}'"))?;
            inputs.push((input_name.to_string(), tensor));
        }

        let outputs = session
            .run(inputs)
            .context("failed to run ONNX semantic inference")?;

        if outputs.len() == 0 {
            bail!("semantic model returned no outputs");
        }

        // Prefer well-known output names from sentence-transformer exports,
        // falling back to the first output otherwise.
        let output = outputs
            .get("sentence_embedding")
            .or_else(|| outputs.get("last_hidden_state"))
            .or_else(|| outputs.get("token_embeddings"))
            .unwrap_or(&outputs[0]);

        let (shape, data) = output.try_extract_tensor::<f32>().with_context(
            || "semantic model output tensor is not f32; expected sentence embedding tensor",
        )?;

        decode_embeddings(shape, data, &flat_attention, batch, seq_len)
    }
531}
532
/// Decide which flattened buffer feeds the model input at `index` named
/// `input_name`.  Name matching wins; positional BERT order is the fallback.
#[cfg(feature = "semantic-ort")]
fn input_source(index: usize, input_name: &str) -> InputSource {
    let lowered = input_name.to_ascii_lowercase();
    if lowered.contains("attention") {
        InputSource::AttentionMask
    } else if lowered.contains("token_type") || lowered.contains("segment") {
        InputSource::TokenTypeIds
    } else if lowered.contains("input_ids") || (lowered.contains("input") && lowered.contains("id"))
    {
        InputSource::InputIds
    } else {
        // Unrecognized name: assume the conventional BERT input ordering.
        match index {
            0 => InputSource::InputIds,
            1 => InputSource::AttentionMask,
            _ => InputSource::TokenTypeIds,
        }
    }
}
552
/// Decode a raw ONNX output tensor into one L2-normalized embedding per batch
/// row, dispatching on tensor rank:
/// rank 2 = already pooled, rank 3 = token embeddings needing mean pooling,
/// rank 1 = a single pre-pooled vector.
#[cfg(feature = "semantic-ort")]
#[allow(clippy::cast_precision_loss)]
fn decode_embeddings(
    shape: &[i64],
    data: &[f32],
    flat_attention: &[i64],
    batch: usize,
    seq_len: usize,
) -> Result<Vec<Vec<f32>>> {
    match shape.len() {
        // [batch, hidden]
        2 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let hidden = usize::try_from(shape[1]).unwrap_or(0);
            if out_batch == 0 || hidden == 0 {
                bail!("invalid sentence embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }

            // Already pooled: slice out each row and normalize it.
            let mut out = Vec::with_capacity(out_batch);
            for row in 0..out_batch {
                let start = row * hidden;
                let end = start + hidden;
                let mut emb = data[start..end].to_vec();
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [batch, tokens, hidden] -> mean pool with attention mask.
        3 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let out_tokens = usize::try_from(shape[1]).unwrap_or(0);
            let hidden = usize::try_from(shape[2]).unwrap_or(0);
            if out_batch == 0 || out_tokens == 0 || hidden == 0 {
                bail!("invalid token embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }

            let mut out = Vec::with_capacity(out_batch);
            for b in 0..out_batch {
                let mut emb = vec![0.0_f32; hidden];
                let mut weight_sum = 0.0_f32;

                for t in 0..out_tokens {
                    // Tokens beyond our padded input get zero weight; masked
                    // tokens are skipped entirely.
                    let mask_weight = if t < seq_len {
                        flat_attention[b * seq_len + t] as f32
                    } else {
                        0.0
                    };
                    if mask_weight <= 0.0 {
                        continue;
                    }

                    // Accumulate the weighted token vector into the row sum.
                    let token_base = (b * out_tokens + t) * hidden;
                    for h in 0..hidden {
                        emb[h] += data[token_base + h] * mask_weight;
                    }
                    weight_sum += mask_weight;
                }

                // Divide by total weight to finish the mean; skip if nothing
                // was attended (avoids 0/0).
                if weight_sum > 0.0 {
                    for value in &mut emb {
                        *value /= weight_sum;
                    }
                }
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [hidden] (single-row fallback)
        1 => {
            if batch != 1 {
                bail!("rank-1 semantic output only supported for single-row batch");
            }
            let hidden = usize::try_from(shape[0]).unwrap_or(0);
            if hidden == 0 {
                bail!("invalid rank-1 semantic output shape {shape:?}");
            }
            let mut emb = data[0..hidden].to_vec();
            normalize_l2(&mut emb);
            Ok(vec![emb])
        }

        rank => bail!("unsupported semantic output rank {rank}: shape {shape:?}"),
    }
}
647
648#[cfg(feature = "semantic-ort")]
/// Scale `values` in place to unit L2 norm; zero vectors are left untouched
/// (dividing by a zero norm would produce NaNs).
fn normalize_l2(values: &mut [f32]) {
    let norm_sq: f32 = values.iter().map(|v| v * v).sum();
    if norm_sq == 0.0 {
        return;
    }

    let norm = norm_sq.sqrt();
    for value in values.iter_mut() {
        *value /= norm;
    }
}
664
665/// Check if semantic search is currently available.
666#[must_use]
667pub fn is_semantic_available() -> bool {
668    SemanticModel::load().is_ok()
669}
670
/// Bytes of the bundled ONNX model, when compiled in and non-empty.
/// An empty placeholder file counts as "not bundled".
const fn bundled_model_bytes() -> Option<&'static [u8]> {
    #[cfg(feature = "bundled-model")]
    {
        if !BUNDLED_MODEL_BYTES.is_empty() {
            return Some(BUNDLED_MODEL_BYTES);
        }

        None
    }

    #[cfg(not(feature = "bundled-model"))]
    {
        None
    }
}
686
687fn expected_model_sha256() -> Option<String> {
688    bundled_model_bytes().map(sha256_hex)
689}
690
691fn sha256_hex(bytes: &[u8]) -> String {
692    let mut hasher = Sha256::new();
693    hasher.update(bytes);
694    format!("{:x}", hasher.finalize())
695}
696
/// Whether automatic artifact downloads are permitted.  Unset (or unreadable)
/// means enabled; any of the usual "off" spellings disables them.
#[cfg(feature = "semantic-ort")]
fn auto_download_enabled() -> bool {
    match std::env::var(AUTO_DOWNLOAD_ENV) {
        Err(_) => true,
        Ok(raw) => !matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "0" | "false" | "no" | "off"
        ),
    }
}
706
/// The model download URL: env override when non-blank, default otherwise.
#[cfg(feature = "semantic-ort")]
fn model_download_url() -> String {
    match std::env::var(MODEL_DOWNLOAD_URL_ENV) {
        Ok(url) if !url.trim().is_empty() => url,
        _ => MODEL_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
714
/// The tokenizer download URL: env override when non-blank, default otherwise.
#[cfg(feature = "semantic-ort")]
fn tokenizer_download_url() -> String {
    match std::env::var(TOKENIZER_DOWNLOAD_URL_ENV) {
        Ok(url) if !url.trim().is_empty() => url,
        _ => TOKENIZER_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
722
/// Download `url` into `path`, streaming through a `.download` staging file
/// so an interrupted transfer never leaves a half-written artifact behind.
/// `artifact_label` is used only in error messages.
#[cfg(feature = "semantic-ort")]
fn download_to_path(url: &str, path: &Path, artifact_label: &str) -> Result<()> {
    let dir = path.parent().with_context(|| {
        format!(
            "{artifact_label} cache path '{}' has no parent directory",
            path.display()
        )
    })?;
    fs::create_dir_all(dir).with_context(|| {
        format!(
            "failed to create {} cache directory {}",
            artifact_label,
            dir.display()
        )
    })?;

    let file_name = path.file_name().unwrap_or_default().to_string_lossy();
    let staging = dir.join(format!("{file_name}.download"));

    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(DOWNLOAD_CONNECT_TIMEOUT_SECS))
        .timeout_read(Duration::from_secs(DOWNLOAD_READ_TIMEOUT_SECS))
        .build();

    let request = agent
        .get(url)
        .set("User-Agent", "bones-search/semantic-downloader");
    let response = match request.call() {
        Ok(resp) => resp,
        Err(ureq::Error::Status(code, _)) => {
            bail!("{artifact_label} download failed: HTTP {code} from {url}")
        }
        Err(ureq::Error::Transport(err)) => {
            bail!("{artifact_label} download failed from {url}: {err}")
        }
    };

    // Stream the body straight to disk and flush before the rename below.
    {
        let mut body = response.into_reader();
        let mut file = fs::File::create(&staging)
            .with_context(|| format!("failed to create temporary file {}", staging.display()))?;
        std::io::copy(&mut body, &mut file)
            .with_context(|| format!("failed to write {artifact_label} download"))?;
        file.flush()
            .with_context(|| format!("failed to flush {artifact_label} download"))?;
    }

    if path.exists() {
        fs::remove_file(path).with_context(|| {
            format!(
                "failed to replace existing {} at {}",
                artifact_label,
                path.display()
            )
        })?;
    }

    fs::rename(&staging, path).with_context(|| {
        format!(
            "failed to move downloaded {} from {} to {}",
            artifact_label,
            staging.display(),
            path.display()
        )
    })?;

    Ok(())
}
794
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    // The resolved cache path must end in the fixed bones/models/<file> suffix
    // regardless of which OS cache root `dirs` picked.
    #[test]
    fn cache_path_uses_expected_suffix() {
        let path = SemanticModel::model_cache_path().expect("cache path should resolve");
        let expected = Path::new("bones")
            .join("models")
            .join("minilm-l6-v2-int8.onnx");
        assert!(path.ends_with(expected));
    }

    // NIST test vector for SHA-256("abc").
    #[test]
    fn sha256_hex_matches_known_vector() {
        let digest = sha256_hex(b"abc");
        assert_eq!(
            digest,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // Without a bundled model there is no reference digest, so any existing
    // file is treated as valid.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn cached_model_is_accepted_when_not_bundled() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");
        fs::write(&model, b"anything").expect("test file should be writable");

        assert!(SemanticModel::is_cached_valid(&model));
    }

    // Extraction must fail loudly when no bytes were compiled in.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn extract_to_cache_fails_without_bundled_model() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");

        let err =
            SemanticModel::extract_to_cache(&model).expect_err("should fail without bundled model");
        assert!(err.to_string().contains("not bundled"));
    }

    // With no ML backend compiled in, the hash fallback must still load.
    #[cfg(not(any(feature = "semantic-ort", feature = "semantic-model2vec")))]
    #[test]
    fn hash_embed_always_available_as_fallback() {
        assert!(is_semantic_available());
    }

    #[cfg(feature = "semantic-ort")]
    #[test]
    fn normalize_l2_produces_unit_norm() {
        let mut emb = vec![3.0_f32, 4.0_f32, 0.0_f32];
        normalize_l2(&mut emb);
        let norm = emb.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // Name-based matching must win even when the positional index disagrees.
    #[cfg(feature = "semantic-ort")]
    #[test]
    fn input_source_prefers_named_fields() {
        assert!(matches!(
            input_source(5, "attention_mask"),
            InputSource::AttentionMask
        ));
        assert!(matches!(
            input_source(5, "token_type_ids"),
            InputSource::TokenTypeIds
        ));
        assert!(matches!(
            input_source(5, "input_ids"),
            InputSource::InputIds
        ));
    }
}