1use anyhow::{Context, Result, anyhow, bail};
2use sha2::{Digest, Sha256};
3use std::fs;
4use std::path::{Path, PathBuf};
5
6#[cfg(feature = "semantic-ort")]
7use ort::{session::Session, value::Tensor};
8#[cfg(feature = "semantic-ort")]
9use std::io::Write;
10#[cfg(feature = "semantic-ort")]
11use std::sync::Mutex;
12#[cfg(feature = "semantic-ort")]
13use std::sync::atomic::{AtomicBool, Ordering};
14#[cfg(feature = "semantic-ort")]
15use std::time::Duration;
16#[cfg(feature = "semantic-ort")]
17use tokenizers::Tokenizer;
18
/// File name of the quantized MiniLM-L6-v2 ONNX model inside the cache dir.
const MODEL_FILENAME: &str = "minilm-l6-v2-int8.onnx";
/// File name of the matching HuggingFace tokenizer JSON inside the cache dir.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_FILENAME: &str = "minilm-l6-v2-tokenizer.json";
/// Hard truncation limit on tokens fed to the model per text.
#[cfg(feature = "semantic-ort")]
const MAX_TOKENS: usize = 256;
/// Env var that overrides the model download URL.
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_MODEL_URL";
/// Env var that overrides the tokenizer download URL.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_ENV: &str = "BONES_SEMANTIC_TOKENIZER_URL";
/// Env var gating auto-download; "0"/"false"/"no"/"off" disable it.
#[cfg(feature = "semantic-ort")]
const AUTO_DOWNLOAD_ENV: &str = "BONES_SEMANTIC_AUTO_DOWNLOAD";
/// Default model URL (quantized MiniLM from the Xenova HF mirror).
#[cfg(feature = "semantic-ort")]
const MODEL_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";
/// Default tokenizer URL paired with the default model above.
#[cfg(feature = "semantic-ort")]
const TOKENIZER_DOWNLOAD_URL_DEFAULT: &str =
    "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/tokenizer.json";
/// Short connect timeout so offline machines fail fast instead of hanging.
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_CONNECT_TIMEOUT_SECS: u64 = 2;
/// Read timeout while streaming the artifact body.
#[cfg(feature = "semantic-ort")]
const DOWNLOAD_READ_TIMEOUT_SECS: u64 = 30;

/// One-shot guard: at most one model download attempt per process.
#[cfg(feature = "semantic-ort")]
static MODEL_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);
/// One-shot guard: at most one tokenizer download attempt per process.
#[cfg(feature = "semantic-ort")]
static TOKENIZER_DOWNLOAD_ATTEMPTED: AtomicBool = AtomicBool::new(false);

/// Model bytes embedded at compile time when `bundled-model` is enabled.
/// May legitimately be empty if packaging shipped a placeholder file.
#[cfg(feature = "bundled-model")]
const BUNDLED_MODEL_BYTES: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/models/minilm-l6-v2-int8.onnx"
));
51
52use super::hash_embed::HashEmbedBackend;
53#[cfg(feature = "semantic-model2vec")]
54use super::model2vec::Model2VecBackend;
55
/// Sentence-embedding model with a backend chosen at load time:
/// ONNX Runtime MiniLM, model2vec, or a dependency-free hash embedder.
pub struct SemanticModel {
    // The concrete backend; selection happens once in `load()`.
    inner: BackendInner,
}
64
/// Concrete embedding implementation behind [`SemanticModel`].
/// Boxing the large variants is skipped deliberately: only one instance
/// is expected per process, so the size cost is paid once.
#[allow(clippy::large_enum_variant)]
enum BackendInner {
    /// ONNX Runtime MiniLM session; the `Mutex` gives `run_model_batch`
    /// the exclusive access it needs while keeping `&self` APIs.
    #[cfg(feature = "semantic-ort")]
    Ort {
        session: Mutex<Session>,
        tokenizer: Tokenizer,
    },
    /// Static-embedding model2vec backend.
    #[cfg(feature = "semantic-model2vec")]
    Model2Vec(Model2VecBackend),
    /// Hash n-gram fallback; always constructible, lowest quality.
    Hash(HashEmbedBackend),
}
78
/// Tokenizer output for one text, truncated to `MAX_TOKENS` and widened
/// to the `i64` values the ONNX graph inputs expect.
#[cfg(feature = "semantic-ort")]
struct EncodedText {
    /// Token ids, one per kept token.
    input_ids: Vec<i64>,
    /// Attention mask aligned with `input_ids` (1 = attend, 0 = pad).
    attention_mask: Vec<i64>,
}
84
/// Which flattened host tensor should feed a given ONNX graph input.
#[cfg(feature = "semantic-ort")]
enum InputSource {
    InputIds,
    AttentionMask,
    TokenTypeIds,
}
91
92impl SemanticModel {
93 pub fn load() -> Result<Self> {
102 #[cfg(feature = "semantic-ort")]
104 {
105 match Self::load_ort() {
106 Ok(model) => return Ok(model),
107 Err(err) => {
108 tracing::debug!("ORT backend unavailable, trying next: {err:#}");
109 }
110 }
111 }
112
113 #[cfg(feature = "semantic-model2vec")]
115 {
116 match Model2VecBackend::load() {
117 Ok(backend) => {
118 return Ok(Self {
119 inner: BackendInner::Model2Vec(backend),
120 });
121 }
122 Err(err) => {
123 tracing::debug!("model2vec backend unavailable: {err:#}");
124 }
125 }
126 }
127
128 tracing::debug!("using hash embedder (no ML backend available)");
130 Ok(Self {
131 inner: BackendInner::Hash(HashEmbedBackend::new()),
132 })
133 }
134
135 #[cfg(feature = "semantic-ort")]
136 fn load_ort() -> Result<Self> {
137 let path = Self::ort_model_cache_path()?;
138 Self::ensure_model_cached(&path)?;
139
140 let tokenizer_path = Self::tokenizer_cache_path()?;
141 Self::ensure_tokenizer_cached(&tokenizer_path)?;
142
143 let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
144 anyhow!(
145 "failed to load semantic tokenizer from {}: {e}",
146 tokenizer_path.display()
147 )
148 })?;
149
150 let session = Session::builder()
151 .context("failed to create ONNX Runtime session builder")?
152 .commit_from_file(&path)
153 .with_context(|| format!("failed to load semantic model from {}", path.display()))?;
154
155 Ok(Self {
156 inner: BackendInner::Ort {
157 session: Mutex::new(session),
158 tokenizer,
159 },
160 })
161 }
162
163 pub fn model_cache_path() -> Result<PathBuf> {
171 Ok(Self::model_cache_root()?.join(MODEL_FILENAME))
172 }
173
174 #[cfg(feature = "semantic-ort")]
175 fn ort_model_cache_path() -> Result<PathBuf> {
176 Ok(Self::model_cache_root()?.join(MODEL_FILENAME))
177 }
178
179 #[cfg(feature = "semantic-ort")]
180 fn tokenizer_cache_path() -> Result<PathBuf> {
181 Ok(Self::model_cache_root()?.join(TOKENIZER_FILENAME))
182 }
183
184 pub fn model_cache_root() -> Result<PathBuf> {
190 let mut path = dirs::cache_dir().context("unable to determine OS cache directory")?;
191 path.push("bones");
192 path.push("models");
193 Ok(path)
194 }
195
196 #[must_use]
198 pub fn is_cached_valid(path: &Path) -> bool {
199 let expected_sha256 = expected_model_sha256();
200 if expected_sha256.is_none() {
201 return path.is_file();
202 }
203
204 let Ok(contents) = fs::read(path) else {
205 return false;
206 };
207
208 expected_sha256.is_some_and(|sha256| sha256_hex(&contents) == sha256)
209 }
210
211 pub fn extract_to_cache(path: &Path) -> Result<()> {
218 let bundled = bundled_model_bytes().ok_or_else(|| {
219 anyhow!(
220 "semantic model bytes are not bundled; enable `bundled-model` with a packaged ONNX file"
221 )
222 })?;
223
224 let parent = path.parent().with_context(|| {
225 format!(
226 "model cache path '{}' has no parent directory",
227 path.display()
228 )
229 })?;
230 fs::create_dir_all(parent).with_context(|| {
231 format!(
232 "failed to create semantic model cache directory {}",
233 parent.display()
234 )
235 })?;
236
237 let temp_path = parent.join(format!("{MODEL_FILENAME}.tmp"));
238 fs::write(&temp_path, bundled)
239 .with_context(|| format!("failed to write bundled model to {}", temp_path.display()))?;
240
241 if path.exists() {
242 fs::remove_file(path)
243 .with_context(|| format!("failed to replace existing model {}", path.display()))?;
244 }
245
246 fs::rename(&temp_path, path).with_context(|| {
247 format!(
248 "failed to move extracted model from {} to {}",
249 temp_path.display(),
250 path.display()
251 )
252 })?;
253
254 if !Self::is_cached_valid(path) {
255 bail!(
256 "extracted semantic model at {} failed SHA256 verification",
257 path.display()
258 );
259 }
260
261 Ok(())
262 }
263
264 #[cfg(any(feature = "semantic-ort", feature = "bundled-model"))]
265 fn ensure_model_cached(path: &Path) -> Result<()> {
266 if Self::is_cached_valid(path) {
267 return Ok(());
268 }
269
270 if bundled_model_bytes().is_some() {
271 Self::extract_to_cache(path)?;
272 return Ok(());
273 }
274
275 #[cfg(feature = "semantic-ort")]
276 {
277 if !auto_download_enabled() {
278 bail!(
279 "semantic model not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
280 path.display()
281 );
282 }
283
284 if MODEL_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
285 bail!(
286 "semantic model not found at {} and auto-download was already attempted in this process",
287 path.display()
288 );
289 }
290
291 download_to_path(&model_download_url(), path, "semantic model")
292 .with_context(|| format!("failed to fetch semantic model to {}", path.display()))?;
293
294 if !Self::is_cached_valid(path) {
295 bail!(
296 "downloaded semantic model at {} failed validation",
297 path.display()
298 );
299 }
300
301 Ok(())
302 }
303
304 #[cfg(not(feature = "semantic-ort"))]
305 {
306 bail!(
307 "semantic model not found at {}; enable `bundled-model` or place `{MODEL_FILENAME}` in the cache path",
308 path.display()
309 );
310 }
311 }
312
313 #[cfg(feature = "semantic-ort")]
314 fn ensure_tokenizer_cached(path: &Path) -> Result<()> {
315 if path.is_file() {
316 return Ok(());
317 }
318
319 if !auto_download_enabled() {
320 bail!(
321 "semantic tokenizer not found at {}. Automatic download is disabled via {AUTO_DOWNLOAD_ENV}=0",
322 path.display()
323 );
324 }
325
326 if TOKENIZER_DOWNLOAD_ATTEMPTED.swap(true, Ordering::SeqCst) {
327 bail!(
328 "semantic tokenizer not found at {} and auto-download was already attempted in this process",
329 path.display()
330 );
331 }
332
333 download_to_path(&tokenizer_download_url(), path, "semantic tokenizer")
334 .with_context(|| format!("failed to fetch semantic tokenizer to {}", path.display()))?;
335
336 if !path.is_file() {
337 bail!("semantic tokenizer download completed but file was not created");
338 }
339
340 Ok(())
341 }
342
343 #[must_use]
345 #[allow(clippy::missing_const_for_fn)]
346 pub fn dimensions(&self) -> usize {
347 match &self.inner {
348 #[cfg(feature = "semantic-ort")]
349 BackendInner::Ort { .. } => 384, #[cfg(feature = "semantic-model2vec")]
351 BackendInner::Model2Vec(m) => m.dimensions(),
352 BackendInner::Hash(h) => h.dimensions(),
353 }
354 }
355
356 #[must_use]
359 #[allow(clippy::missing_const_for_fn)]
360 pub fn backend_id(&self) -> &'static str {
361 match &self.inner {
362 #[cfg(feature = "semantic-ort")]
363 BackendInner::Ort { .. } => "ort-minilm-384",
364 #[cfg(feature = "semantic-model2vec")]
365 BackendInner::Model2Vec(_) => "model2vec-potion-8m",
366 BackendInner::Hash(_) => "hash-ngram-256",
367 }
368 }
369
370 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
376 match &self.inner {
377 #[cfg(feature = "semantic-ort")]
378 BackendInner::Ort { .. } => {
379 let encoded = self.encode_text(text)?;
380 let mut out = self.run_model_batch(&[encoded])?;
381 out.pop()
382 .ok_or_else(|| anyhow!("semantic model returned no embedding"))
383 }
384 #[cfg(feature = "semantic-model2vec")]
385 BackendInner::Model2Vec(m) => m.embed(text),
386 BackendInner::Hash(h) => h.embed(text),
387 }
388 }
389
390 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
396 match &self.inner {
397 #[cfg(feature = "semantic-ort")]
398 BackendInner::Ort { .. } => {
399 let encoded: Vec<EncodedText> = texts
400 .iter()
401 .map(|text| self.encode_text(text))
402 .collect::<Result<Vec<_>>>()?;
403 self.run_model_batch(&encoded)
404 }
405 #[cfg(feature = "semantic-model2vec")]
406 BackendInner::Model2Vec(m) => m.embed_batch(texts),
407 BackendInner::Hash(h) => h.embed_batch(texts),
408 }
409 }
410
411 #[cfg(feature = "semantic-ort")]
412 fn ort_tokenizer(&self) -> &Tokenizer {
413 match &self.inner {
414 BackendInner::Ort { tokenizer, .. } => tokenizer,
415 #[cfg(feature = "semantic-model2vec")]
416 BackendInner::Model2Vec(_) | BackendInner::Hash(_) => {
417 unreachable!("encode_text called on non-ORT backend")
418 }
419 #[cfg(not(feature = "semantic-model2vec"))]
420 BackendInner::Hash(_) => unreachable!("encode_text called on non-ORT backend"),
421 }
422 }
423
424 #[cfg(feature = "semantic-ort")]
425 fn ort_session(&self) -> &Mutex<Session> {
426 match &self.inner {
427 BackendInner::Ort { session, .. } => session,
428 #[cfg(feature = "semantic-model2vec")]
429 BackendInner::Model2Vec(_) | BackendInner::Hash(_) => {
430 unreachable!("run_model_batch called on non-ORT backend")
431 }
432 #[cfg(not(feature = "semantic-model2vec"))]
433 BackendInner::Hash(_) => unreachable!("run_model_batch called on non-ORT backend"),
434 }
435 }
436
437 #[cfg(feature = "semantic-ort")]
438 fn encode_text(&self, text: &str) -> Result<EncodedText> {
439 let encoding = self
440 .ort_tokenizer()
441 .encode(text, true)
442 .map_err(|e| anyhow!("failed to tokenize semantic query: {e}"))?;
443
444 let ids = encoding.get_ids();
445 if ids.is_empty() {
446 bail!("semantic tokenizer produced zero tokens");
447 }
448
449 let attention = encoding.get_attention_mask();
450 let keep = ids.len().min(MAX_TOKENS);
451
452 let mut input_ids = Vec::with_capacity(keep);
453 let mut attention_mask = Vec::with_capacity(keep);
454 for (idx, id) in ids.iter().enumerate().take(keep) {
455 input_ids.push(i64::from(*id));
456 attention_mask.push(i64::from(*attention.get(idx).unwrap_or(&1_u32)));
457 }
458 if attention_mask.iter().all(|v| *v == 0) {
459 attention_mask.fill(1);
460 }
461
462 Ok(EncodedText {
463 input_ids,
464 attention_mask,
465 })
466 }
467
468 #[cfg(feature = "semantic-ort")]
469 #[allow(clippy::significant_drop_tightening, clippy::cast_precision_loss)]
470 fn run_model_batch(&self, encoded: &[EncodedText]) -> Result<Vec<Vec<f32>>> {
471 if encoded.is_empty() {
472 return Ok(Vec::new());
473 }
474
475 let batch = encoded.len();
476 let seq_len = encoded.iter().map(|e| e.input_ids.len()).max().unwrap_or(0);
477 if seq_len == 0 {
478 bail!("semantic batch has no tokens");
479 }
480
481 let mut flat_ids = vec![0_i64; batch * seq_len];
482 let mut flat_attention = vec![0_i64; batch * seq_len];
483 for (row_idx, row) in encoded.iter().enumerate() {
484 let row_base = row_idx * seq_len;
485 flat_ids[row_base..(row.input_ids.len() + row_base)].copy_from_slice(&row.input_ids);
486 flat_attention[row_base..(row.attention_mask.len() + row_base)]
487 .copy_from_slice(&row.attention_mask);
488 }
489 let flat_token_types = vec![0_i64; batch * seq_len];
490
491 let mut session = self
492 .ort_session()
493 .lock()
494 .map_err(|_| anyhow!("semantic model session mutex poisoned"))?;
495
496 let model_inputs = session.inputs();
497 let mut inputs: Vec<(String, Tensor<i64>)> = Vec::with_capacity(model_inputs.len());
498 for (index, input) in model_inputs.iter().enumerate() {
499 let input_name = input.name();
500 let source = input_source(index, input_name);
501 let data = match source {
502 InputSource::InputIds => flat_ids.clone(),
503 InputSource::AttentionMask => flat_attention.clone(),
504 InputSource::TokenTypeIds => flat_token_types.clone(),
505 };
506 let tensor = Tensor::<i64>::from_array(([batch, seq_len], data.into_boxed_slice()))
507 .with_context(|| format!("failed to build ONNX input tensor '{input_name}'"))?;
508 inputs.push((input_name.to_string(), tensor));
509 }
510
511 let outputs = session
512 .run(inputs)
513 .context("failed to run ONNX semantic inference")?;
514
515 if outputs.len() == 0 {
516 bail!("semantic model returned no outputs");
517 }
518
519 let output = outputs
520 .get("sentence_embedding")
521 .or_else(|| outputs.get("last_hidden_state"))
522 .or_else(|| outputs.get("token_embeddings"))
523 .unwrap_or(&outputs[0]);
524
525 let (shape, data) = output.try_extract_tensor::<f32>().with_context(
526 || "semantic model output tensor is not f32; expected sentence embedding tensor",
527 )?;
528
529 decode_embeddings(shape, data, &flat_attention, batch, seq_len)
530 }
531}
532
/// Decides which host tensor feeds an ONNX graph input.
///
/// Name-based detection takes priority; the positional mapping
/// (ids, mask, token types) is only a fallback for unnamed graphs.
#[cfg(feature = "semantic-ort")]
fn input_source(index: usize, input_name: &str) -> InputSource {
    let lowered = input_name.to_ascii_lowercase();

    if lowered.contains("attention") {
        InputSource::AttentionMask
    } else if lowered.contains("token_type") || lowered.contains("segment") {
        InputSource::TokenTypeIds
    } else if lowered.contains("input_ids") || (lowered.contains("input") && lowered.contains("id"))
    {
        InputSource::InputIds
    } else if index == 0 {
        InputSource::InputIds
    } else if index == 1 {
        InputSource::AttentionMask
    } else {
        InputSource::TokenTypeIds
    }
}
552
#[cfg(feature = "semantic-ort")]
#[allow(clippy::cast_precision_loss)]
/// Converts a raw ORT output tensor into one L2-normalized embedding per
/// batch row. Supports the layouts sentence-embedding exports produce:
/// rank-2 `[batch, hidden]` (already pooled), rank-3
/// `[batch, tokens, hidden]` (token embeddings, mean-pooled here using the
/// attention mask), and rank-1 `[hidden]` (single-row pooled output).
fn decode_embeddings(
    shape: &[i64],
    data: &[f32],
    flat_attention: &[i64],
    batch: usize,
    seq_len: usize,
) -> Result<Vec<Vec<f32>>> {
    match shape.len() {
        // [batch, hidden]: already pooled, just slice out each row.
        2 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let hidden = usize::try_from(shape[1]).unwrap_or(0);
            if out_batch == 0 || hidden == 0 {
                bail!("invalid sentence embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }

            let mut out = Vec::with_capacity(out_batch);
            for row in 0..out_batch {
                let start = row * hidden;
                let end = start + hidden;
                let mut emb = data[start..end].to_vec();
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [batch, tokens, hidden]: mask-weighted mean pooling over tokens.
        3 => {
            let out_batch = usize::try_from(shape[0]).unwrap_or(0);
            let out_tokens = usize::try_from(shape[1]).unwrap_or(0);
            let hidden = usize::try_from(shape[2]).unwrap_or(0);
            if out_batch == 0 || out_tokens == 0 || hidden == 0 {
                bail!("invalid token embedding output shape {shape:?}");
            }
            if out_batch != batch {
                bail!("semantic output batch mismatch: expected {batch}, got {out_batch}");
            }

            let mut out = Vec::with_capacity(out_batch);
            for b in 0..out_batch {
                let mut emb = vec![0.0_f32; hidden];
                let mut weight_sum = 0.0_f32;

                for t in 0..out_tokens {
                    // Tokens beyond the mask we supplied get zero weight.
                    let mask_weight = if t < seq_len {
                        flat_attention[b * seq_len + t] as f32
                    } else {
                        0.0
                    };
                    if mask_weight <= 0.0 {
                        continue;
                    }

                    let token_base = (b * out_tokens + t) * hidden;
                    for h in 0..hidden {
                        emb[h] += data[token_base + h] * mask_weight;
                    }
                    weight_sum += mask_weight;
                }

                // Divide by total mask weight to finish the mean.
                if weight_sum > 0.0 {
                    for value in &mut emb {
                        *value /= weight_sum;
                    }
                }
                normalize_l2(&mut emb);
                out.push(emb);
            }
            Ok(out)
        }

        // [hidden]: only meaningful for a single-row batch.
        1 => {
            if batch != 1 {
                bail!("rank-1 semantic output only supported for single-row batch");
            }
            let hidden = usize::try_from(shape[0]).unwrap_or(0);
            if hidden == 0 {
                bail!("invalid rank-1 semantic output shape {shape:?}");
            }
            let mut emb = data[0..hidden].to_vec();
            normalize_l2(&mut emb);
            Ok(vec![emb])
        }

        rank => bail!("unsupported semantic output rank {rank}: shape {shape:?}"),
    }
}
647
/// Scales `values` in place to unit L2 norm; an all-zero vector is left
/// untouched to avoid dividing by zero.
#[cfg(feature = "semantic-ort")]
fn normalize_l2(values: &mut [f32]) {
    let norm_sq: f32 = values.iter().map(|v| v * v).sum();

    // Guard on the squared norm, matching the zero-vector behavior exactly.
    if norm_sq == 0.0 {
        return;
    }

    let norm = norm_sq.sqrt();
    values.iter_mut().for_each(|v| *v /= norm);
}
664
665#[must_use]
667pub fn is_semantic_available() -> bool {
668 SemanticModel::load().is_ok()
669}
670
/// Returns the compile-time bundled model bytes, or `None` when the
/// `bundled-model` feature is off or the embedded payload is empty
/// (a placeholder file shipped instead of a real model).
const fn bundled_model_bytes() -> Option<&'static [u8]> {
    #[cfg(feature = "bundled-model")]
    {
        if BUNDLED_MODEL_BYTES.is_empty() {
            None
        } else {
            Some(BUNDLED_MODEL_BYTES)
        }
    }

    #[cfg(not(feature = "bundled-model"))]
    {
        None
    }
}
686
687fn expected_model_sha256() -> Option<String> {
688 bundled_model_bytes().map(sha256_hex)
689}
690
691fn sha256_hex(bytes: &[u8]) -> String {
692 let mut hasher = Sha256::new();
693 hasher.update(bytes);
694 format!("{:x}", hasher.finalize())
695}
696
/// Whether automatic artifact download is permitted. Enabled unless the
/// env var is set to an explicit off value ("0", "false", "no", "off",
/// case-insensitive); unset or unreadable values mean enabled.
#[cfg(feature = "semantic-ort")]
fn auto_download_enabled() -> bool {
    match std::env::var(AUTO_DOWNLOAD_ENV) {
        Ok(raw) => !matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "0" | "false" | "no" | "off"
        ),
        Err(_) => true,
    }
}
706
/// Model download URL: the env override when set to a non-blank value,
/// otherwise the default HuggingFace mirror URL.
#[cfg(feature = "semantic-ort")]
fn model_download_url() -> String {
    match std::env::var(MODEL_DOWNLOAD_URL_ENV) {
        Ok(value) if !value.trim().is_empty() => value,
        _ => MODEL_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
714
/// Tokenizer download URL: the env override when set to a non-blank
/// value, otherwise the default HuggingFace mirror URL.
#[cfg(feature = "semantic-ort")]
fn tokenizer_download_url() -> String {
    match std::env::var(TOKENIZER_DOWNLOAD_URL_ENV) {
        Ok(value) if !value.trim().is_empty() => value,
        _ => TOKENIZER_DOWNLOAD_URL_DEFAULT.to_string(),
    }
}
722
#[cfg(feature = "semantic-ort")]
/// Downloads `url` to `path`, staging through a `.download` temp file in
/// the same directory so a failed transfer never leaves a partial file at
/// the final path. `artifact_label` is only used in error messages.
fn download_to_path(url: &str, path: &Path, artifact_label: &str) -> Result<()> {
    let parent = path.parent().with_context(|| {
        format!(
            "{artifact_label} cache path '{}' has no parent directory",
            path.display()
        )
    })?;
    fs::create_dir_all(parent).with_context(|| {
        format!(
            "failed to create {} cache directory {}",
            artifact_label,
            parent.display()
        )
    })?;

    // Sibling temp file: same filesystem, so the final rename is atomic.
    let temp_path = parent.join(format!(
        "{}.download",
        path.file_name().unwrap_or_default().to_string_lossy()
    ));

    // Short connect timeout fails fast when offline; longer read timeout
    // tolerates slow transfers of the model body.
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(DOWNLOAD_CONNECT_TIMEOUT_SECS))
        .timeout_read(Duration::from_secs(DOWNLOAD_READ_TIMEOUT_SECS))
        .build();

    let response = match agent
        .get(url)
        .set("User-Agent", "bones-search/semantic-downloader")
        .call()
    {
        Ok(resp) => resp,
        Err(ureq::Error::Status(code, _)) => {
            bail!("{artifact_label} download failed: HTTP {code} from {url}")
        }
        Err(ureq::Error::Transport(err)) => {
            bail!("{artifact_label} download failed from {url}: {err}")
        }
    };

    // Scope the file handle so it is flushed and closed before the rename.
    {
        let mut reader = response.into_reader();
        let mut out = fs::File::create(&temp_path)
            .with_context(|| format!("failed to create temporary file {}", temp_path.display()))?;
        std::io::copy(&mut reader, &mut out)
            .with_context(|| format!("failed to write {artifact_label} download"))?;
        out.flush()
            .with_context(|| format!("failed to flush {artifact_label} download"))?;
    }

    // Rename-over-existing is not portable (Windows), so remove first.
    if path.exists() {
        fs::remove_file(path).with_context(|| {
            format!(
                "failed to replace existing {} at {}",
                artifact_label,
                path.display()
            )
        })?;
    }

    fs::rename(&temp_path, path).with_context(|| {
        format!(
            "failed to move downloaded {} from {} to {}",
            artifact_label,
            temp_path.display(),
            path.display()
        )
    })?;

    Ok(())
}
794
795#[cfg(test)]
796mod tests {
797 use super::*;
798 use std::path::Path;
799
    // The cache path must end in the fixed bones/models/<model> suffix,
    // regardless of the platform-specific cache root.
    #[test]
    fn cache_path_uses_expected_suffix() {
        let path = SemanticModel::model_cache_path().expect("cache path should resolve");
        let expected = Path::new("bones")
            .join("models")
            .join("minilm-l6-v2-int8.onnx");
        assert!(path.ends_with(expected));
    }
808
    // Checks sha256_hex against the published NIST test vector for "abc".
    #[test]
    fn sha256_hex_matches_known_vector() {
        let digest = sha256_hex(b"abc");
        assert_eq!(
            digest,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
817
    // Without a bundled reference digest, any existing file must validate.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn cached_model_is_accepted_when_not_bundled() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");
        fs::write(&model, b"anything").expect("test file should be writable");

        assert!(SemanticModel::is_cached_valid(&model));
    }
827
    // Extraction must fail loudly (not silently no-op) when no model
    // bytes were bundled into the binary.
    #[cfg(not(feature = "bundled-model"))]
    #[test]
    fn extract_to_cache_fails_without_bundled_model() {
        let tmp = tempfile::tempdir().expect("tempdir must be created");
        let model = tmp.path().join("minilm-l6-v2-int8.onnx");

        let err =
            SemanticModel::extract_to_cache(&model).expect_err("should fail without bundled model");
        assert!(err.to_string().contains("not bundled"));
    }
838
    // With no ML backend compiled in, the hash fallback must still make
    // semantic search report itself as available.
    #[cfg(not(any(feature = "semantic-ort", feature = "semantic-model2vec")))]
    #[test]
    fn hash_embed_always_available_as_fallback() {
        assert!(is_semantic_available());
    }
844
    // A classic 3-4-5 vector should come out with L2 norm 1 after scaling.
    #[cfg(feature = "semantic-ort")]
    #[test]
    fn normalize_l2_produces_unit_norm() {
        let mut emb = vec![3.0_f32, 4.0_f32, 0.0_f32];
        normalize_l2(&mut emb);
        let norm = emb.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
853
    // Name-based matching must win even when the positional index (5)
    // would map to a different input.
    #[cfg(feature = "semantic-ort")]
    #[test]
    fn input_source_prefers_named_fields() {
        assert!(matches!(
            input_source(5, "attention_mask"),
            InputSource::AttentionMask
        ));
        assert!(matches!(
            input_source(5, "token_type_ids"),
            InputSource::TokenTypeIds
        ));
        assert!(matches!(
            input_source(5, "input_ids"),
            InputSource::InputIds
        ));
    }
870}