Skip to main content

model2vec_rs/
model.rs

1use anyhow::{anyhow, Context, Result};
2use half::f16;
3#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
4use hf_hub::api::sync::{Api, ApiRepo};
5use ndarray::{Array2, ArrayView2, CowArray, Ix2};
6use safetensors::{tensor::Dtype, SafeTensors};
7use serde_json::Value;
8use std::borrow::Cow;
9#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
10use std::env;
11use std::{
12    fs,
13    path::{Path, PathBuf},
14};
15use tokenizers::Tokenizer;
16
/// Static embedding model for Model2Vec.
///
/// Text is tokenized with `tokenizer`; each resulting token id selects a row of
/// `embeddings` (optionally remapped through `token_mapping` and scaled by `weights`),
/// and the selected rows are mean-pooled into one vector per input text.
#[derive(Debug, Clone)]
pub struct StaticModel {
    /// Tokenizer that turns input text into token ids.
    tokenizer: Tokenizer,
    /// `(rows, cols)` embedding matrix, either owned or borrowed from `'static` data.
    embeddings: CowArray<'static, f32, Ix2>,
    /// Optional per-token scale factors, indexed by (unmapped) token id.
    weights: Option<Cow<'static, [f32]>>,
    /// Optional map from token id to embedding row, used by quantized models.
    token_mapping: Option<Cow<'static, [usize]>>,
    /// Whether pooled output vectors are L2-normalized.
    normalize: bool,
    /// Median byte length of vocabulary tokens; used to pre-truncate long input text.
    median_token_length: usize,
    /// Id of the tokenizer's unknown token, if it defines one; such ids are dropped before pooling.
    unk_token_id: Option<usize>,
}
28
/// Paths to the three files needed to build a model, as resolved from a local
/// folder or from the hub download cache.
#[derive(Debug, Clone)]
struct ModelFiles {
    /// Path to `tokenizer.json`.
    tokenizer: PathBuf,
    /// Path to `model.safetensors`.
    model: PathBuf,
    /// Path to the config file (`config.json` or `config_sentence_transformers.json`).
    config: PathBuf,
}
35
36fn match_local_layout(config_base: &Path, model_base: &Path, config_file: &str) -> Option<ModelFiles> {
37    let config = config_base.join(config_file);
38    let tokenizer = model_base.join("tokenizer.json");
39    let model = model_base.join("model.safetensors");
40    (config.exists() && tokenizer.exists() && model.exists()).then_some(ModelFiles {
41        tokenizer,
42        model,
43        config,
44    })
45}
46
/// Reports whether a hub request failed specifically with HTTP status 404.
///
/// Any other error variant, or any other HTTP status, yields `false`.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn is_not_found(e: &hf_hub::api::sync::ApiError) -> bool {
    use hf_hub::api::sync::ApiError;

    match e {
        ApiError::RequestError(inner) => matches!(inner.as_ref(), ureq::Error::Status(404, _)),
        _ => false,
    }
}
53
/// Checks a single hub-repository layout candidate.
///
/// Fetches `{config_prefix}{config_file}`, then `{model_prefix}tokenizer.json`, then
/// `{model_prefix}model.safetensors`, stopping at the first one that does not exist.
///
/// Returns `Ok(Some(..))` with the paths returned by `ApiRepo::get` when all three exist,
/// `Ok(None)` when any of them is missing (HTTP 404), and `Err` for any other hub error.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn match_hub_layout(
    repo: &ApiRepo,
    config_prefix: &str,
    model_prefix: &str,
    config_file: &str,
) -> Result<Option<ModelFiles>> {
    // A 404 means "this layout does not apply here"; anything else is a real failure.
    let try_get = |remote: &str| -> Result<Option<PathBuf>> {
        repo.get(remote).map(Some).or_else(|err| {
            if is_not_found(&err) {
                Ok(None)
            } else {
                Err(err.into())
            }
        })
    };

    let config_path = format!("{config_prefix}{config_file}");
    let tokenizer_path = format!("{model_prefix}tokenizer.json");
    let model_path = format!("{model_prefix}model.safetensors");

    let config = match try_get(&config_path)? {
        Some(path) => path,
        None => return Ok(None),
    };
    let tokenizer = match try_get(&tokenizer_path)? {
        Some(path) => path,
        None => return Ok(None),
    };
    let model = match try_get(&model_path)? {
        Some(path) => path,
        None => return Ok(None),
    };

    Ok(Some(ModelFiles {
        tokenizer,
        model,
        config,
    }))
}
83
84fn resolve_local_model_files(folder: &Path) -> Option<ModelFiles> {
85    match_local_layout(folder, folder, "config.json")
86        .or_else(|| match_local_layout(folder, folder, "config_sentence_transformers.json"))
87        .or_else(|| {
88            match_local_layout(
89                folder,
90                &folder.join("0_StaticEmbedding"),
91                "config_sentence_transformers.json",
92            )
93        })
94        .or_else(|| {
95            folder
96                .parent()
97                .and_then(|p| match_local_layout(p, folder, "config_sentence_transformers.json"))
98        })
99}
100
/// Finds the model files inside a hub repository, mirroring `resolve_local_model_files`.
///
/// `prefix` is either empty (repository root) or a subfolder path ending in `/`.
/// Layouts are tried in order:
///
/// 1. `{prefix}config.json` with the model files under `prefix`;
/// 2. `{prefix}config_sentence_transformers.json` with the model files under `prefix`;
/// 3. the same config with the model files under `{prefix}0_StaticEmbedding/`;
/// 4. `config_sentence_transformers.json` in the parent of `prefix`, with the model files
///    under `prefix`. This probe is skipped when `prefix` is the repository root, because
///    it would repeat layout 2 and issue the same (failing) requests again.
///
/// # Errors
/// Propagates any non-404 hub error, and fails when no layout matches.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn resolve_hub_model_files(repo: &ApiRepo, prefix: &str) -> Result<ModelFiles> {
    const ST_CONFIG: &str = "config_sentence_transformers.json";

    if let Some(f) = match_hub_layout(repo, prefix, prefix, "config.json")? {
        return Ok(f);
    }
    if let Some(f) = match_hub_layout(repo, prefix, prefix, ST_CONFIG)? {
        return Ok(f);
    }
    let sub_prefix = format!("{prefix}0_StaticEmbedding/");
    if let Some(f) = match_hub_layout(repo, prefix, &sub_prefix, ST_CONFIG)? {
        return Ok(f);
    }

    // Hub paths are always `/`-separated, so split on `/` directly instead of going through
    // `std::path::Path`, whose separator rules depend on the host platform.
    let parent = prefix
        .trim_end_matches('/')
        .rsplit_once('/')
        .map(|(head, _)| head.trim_end_matches('/'))
        .filter(|head| !head.is_empty())
        .map(|head| format!("{head}/"))
        .unwrap_or_default();

    if parent != prefix {
        if let Some(f) = match_hub_layout(repo, &parent, prefix, ST_CONFIG)? {
            return Ok(f);
        }
    }

    Err(anyhow!("no valid model layout found in '{prefix}'"))
}
122
123impl StaticModel {
124    /// Load a Model2Vec model directly from in-memory bytes.
125    ///
126    /// This path is useful for runtimes that fetch model assets as bytes
127    /// rather than reading them from a local filesystem.
128    pub fn from_bytes<T, M, C>(
129        tokenizer_bytes: T,
130        model_bytes: M,
131        config_bytes: C,
132        normalize: Option<bool>,
133    ) -> Result<Self>
134    where
135        T: AsRef<[u8]>,
136        M: AsRef<[u8]>,
137        C: AsRef<[u8]>,
138    {
139        let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
140
141        // Read normalize default from config.json
142        let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
143        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
144        let normalize = normalize.unwrap_or(cfg_norm);
145
146        // Load the safetensors
147        let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
148        let tensor = safet
149            .tensor("embeddings")
150            .or_else(|_| safet.tensor("0"))
151            .or_else(|_| safet.tensor("embedding.weight"))
152            .context("embeddings tensor not found")?;
153
154        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
155        let raw = tensor.data();
156        let floats: Vec<f32> = match tensor.dtype() {
157            Dtype::F32 => raw
158                .chunks_exact(4)
159                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
160                .collect(),
161            Dtype::F16 => raw
162                .chunks_exact(2)
163                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
164                .collect(),
165            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
166            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
167        };
168
169        let weights = match safet.tensor("weights") {
170            Ok(t) => {
171                let raw = t.data();
172                let v: Vec<f32> = match t.dtype() {
173                    Dtype::F64 => raw
174                        .chunks_exact(8)
175                        .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
176                        .collect(),
177                    Dtype::F32 => raw
178                        .chunks_exact(4)
179                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
180                        .collect(),
181                    Dtype::F16 => raw
182                        .chunks_exact(2)
183                        .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
184                        .collect(),
185                    other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
186                };
187                Some(v)
188            }
189            Err(_) => None,
190        };
191
192        let token_mapping = match safet.tensor("mapping") {
193            Ok(t) => {
194                let raw = t.data();
195                let v: Vec<usize> = raw
196                    .chunks_exact(4)
197                    .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
198                    .collect();
199                Some(v)
200            }
201            Err(_) => None,
202        };
203
204        Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
205    }
206
207    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
208    ///
209    /// # Arguments
210    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
211    /// * `token` - Optional HuggingFace token for authenticated downloads.
212    /// * `normalize` - Optional flag to normalize embeddings (default from the resolved config file).
213    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
214    pub fn from_pretrained<P: AsRef<Path>>(
215        repo_or_path: P,
216        token: Option<&str>,
217        normalize: Option<bool>,
218        subfolder: Option<&str>,
219    ) -> Result<Self> {
220        let files = resolve_model_files(repo_or_path, token, subfolder)?;
221        let tokenizer_bytes = fs::read(&files.tokenizer).context("failed to read tokenizer.json")?;
222        let model_bytes = fs::read(&files.model).context("failed to read model.safetensors")?;
223        let config_bytes = fs::read(&files.config).context("failed to read config.json")?;
224        Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
225    }
226
227    /// Construct from owned data.
228    ///
229    /// # Arguments
230    /// * `tokenizer` - Pre-deserialized tokenizer
231    /// * `embeddings` - Owned f32 embedding data
232    /// * `rows` - Number of vocabulary entries
233    /// * `cols` - Embedding dimension
234    /// * `normalize` - Whether to L2-normalize output embeddings
235    /// * `weights` - Optional per-token weights for quantized models
236    /// * `token_mapping` - Optional token ID mapping for quantized models
237    pub fn from_owned(
238        tokenizer: Tokenizer,
239        embeddings: Vec<f32>,
240        rows: usize,
241        cols: usize,
242        normalize: bool,
243        weights: Option<Vec<f32>>,
244        token_mapping: Option<Vec<usize>>,
245    ) -> Result<Self> {
246        if embeddings.len() != rows * cols {
247            return Err(anyhow!(
248                "embeddings length {} != rows {} * cols {}",
249                embeddings.len(),
250                rows,
251                cols
252            ));
253        }
254        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
255        let embeddings =
256            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
257        Ok(Self {
258            tokenizer,
259            embeddings: CowArray::from(embeddings),
260            weights: weights.map(Cow::Owned),
261            token_mapping: token_mapping.map(Cow::Owned),
262            normalize,
263            median_token_length,
264            unk_token_id,
265        })
266    }
267
268    /// Construct from static slices (zero-copy for embedded binary data).
269    ///
270    /// # Arguments
271    /// * `tokenizer` - Pre-deserialized tokenizer
272    /// * `embeddings` - Static f32 embedding data (borrowed, no copy)
273    /// * `rows` - Number of vocabulary entries
274    /// * `cols` - Embedding dimension
275    /// * `normalize` - Whether to L2-normalize output embeddings
276    /// * `weights` - Optional static per-token weights for quantized models
277    /// * `token_mapping` - Optional static token ID mapping for quantized models
278    #[allow(dead_code)] // Public API for external crates
279    pub fn from_borrowed(
280        tokenizer: Tokenizer,
281        embeddings: &'static [f32],
282        rows: usize,
283        cols: usize,
284        normalize: bool,
285        weights: Option<&'static [f32]>,
286        token_mapping: Option<&'static [usize]>,
287    ) -> Result<Self> {
288        if embeddings.len() != rows * cols {
289            return Err(anyhow!(
290                "embeddings length {} != rows {} * cols {}",
291                embeddings.len(),
292                rows,
293                cols
294            ));
295        }
296        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
297        let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
298        Ok(Self {
299            tokenizer,
300            embeddings: CowArray::from(embeddings),
301            weights: weights.map(Cow::Borrowed),
302            token_mapping: token_mapping.map(Cow::Borrowed),
303            normalize,
304            median_token_length,
305            unk_token_id,
306        })
307    }
308
309    /// Compute median token length and unk_token_id from tokenizer.
310    fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
311        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
312        lens.sort_unstable();
313        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
314
315        let spec: Value = serde_json::to_value(tokenizer).context("failed to serialize tokenizer")?;
316        let unk_token = spec
317            .get("model")
318            .and_then(|m| m.get("unk_token"))
319            .and_then(Value::as_str);
320        let unk_token_id = if let Some(tok) = unk_token {
321            let id = tokenizer
322                .token_to_id(tok)
323                .ok_or_else(|| anyhow!("unk_token '{tok}' not found in vocabulary"))?;
324            Some(id as usize)
325        } else {
326            None
327        };
328
329        Ok((median_token_length, unk_token_id))
330    }
331
332    /// Char-level truncation to max_tokens * median_token_length
333    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
334        s.char_indices()
335            .nth(max_tokens.saturating_mul(median_len))
336            .map_or(s, |(byte_idx, _)| &s[..byte_idx])
337    }
338
339    /// Encode texts into embeddings.
340    ///
341    /// # Arguments
342    /// * `sentences` - the list of sentences to encode.
343    /// * `max_length` - max tokens per text.
344    /// * `batch_size` - number of texts per batch.
345    pub fn encode_with_args(
346        &self,
347        sentences: &[String],
348        max_length: Option<usize>,
349        batch_size: usize,
350    ) -> Vec<Vec<f32>> {
351        let mut embeddings = Vec::with_capacity(sentences.len());
352        for batch in sentences.chunks(batch_size) {
353            let truncated: Vec<&str> = batch
354                .iter()
355                .map(|text| {
356                    max_length
357                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
358                        .unwrap_or(text.as_str())
359                })
360                .collect();
361            let encodings = self
362                .tokenizer
363                .encode_batch_fast::<String>(truncated.into_iter().map(Into::into).collect(), false)
364                .expect("tokenization failed");
365            for encoding in encodings {
366                let mut token_ids = encoding.get_ids().to_vec();
367                if let Some(unk_id) = self.unk_token_id {
368                    token_ids.retain(|&id| id as usize != unk_id);
369                }
370                if let Some(max_tok) = max_length {
371                    token_ids.truncate(max_tok);
372                }
373                embeddings.push(self.pool_ids(token_ids));
374            }
375        }
376        embeddings
377    }
378
379    /// Default encode: `max_length=512`, `batch_size=1024`
380    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
381        self.encode_with_args(sentences, Some(512), 1024)
382    }
383
384    /// Encode a single sentence into a vector.
385    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
386        self.encode(&[sentence.to_string()])
387            .into_iter()
388            .next()
389            .unwrap_or_default()
390    }
391
392    /// Mean-pool a token-ID list into a single vector.
393    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
394        let dim = self.embeddings.ncols();
395        let mut sum = vec![0.0_f32; dim];
396        let mut cnt = 0usize;
397        for &id in &ids {
398            let tok = id as usize;
399            let row_idx = self
400                .token_mapping
401                .as_ref()
402                .and_then(|m| m.get(tok))
403                .copied()
404                .unwrap_or(tok);
405            let scale = self.weights.as_ref().and_then(|w| w.get(tok)).copied().unwrap_or(1.0);
406            let row = self.embeddings.row(row_idx);
407            for (s, &v) in sum.iter_mut().zip(row.iter()) {
408                *s += v * scale;
409            }
410            cnt += 1;
411        }
412        let denom = cnt.max(1) as f32;
413        for x in &mut sum {
414            *x /= denom;
415        }
416        if self.normalize {
417            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
418            for x in &mut sum {
419                *x /= norm;
420            }
421        }
422        sum
423    }
424}
425
426fn resolve_model_files<P: AsRef<Path>>(
427    repo_or_path: P,
428    token: Option<&str>,
429    subfolder: Option<&str>,
430) -> Result<ModelFiles> {
431    #[cfg(any(not(feature = "hf-hub"), feature = "local-only"))]
432    let _ = token;
433
434    let base = repo_or_path.as_ref();
435    if base.exists() {
436        let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
437        return resolve_local_model_files(&folder).ok_or_else(|| {
438            anyhow!(
439                "no valid model layout found in {folder:?}. \
440                 Tried: model2vec (config.json), sentence-transformers \
441                 (config_sentence_transformers.json), and 0_StaticEmbedding subfolder."
442            )
443        });
444    }
445
446    #[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
447    {
448        download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)
449    }
450    #[cfg(feature = "local-only")]
451    {
452        Err(anyhow!(
453            "remote model downloads are disabled by the `local-only` feature; pass a local model directory instead"
454        ))
455    }
456    #[cfg(all(not(feature = "hf-hub"), not(feature = "local-only")))]
457    {
458        Err(anyhow!(
459            "remote model downloads require the `hf-hub` feature; pass a local model directory instead"
460        ))
461    }
462}
463
/// Resolves the model files for `repo_id` on the HuggingFace Hub, downloading them into
/// the hub cache as needed.
///
/// When `token` is given, it is handed to the hub client explicitly through
/// `ApiBuilder::with_token`. Otherwise the builder keeps its default token lookup
/// (the same builder defaults `Api::new()` uses). No process environment variables are
/// read or modified.
///
/// # Errors
/// Fails if the hub client cannot be built, if a request fails with anything other than
/// a 404, or if no supported layout exists under `subfolder`.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&str>) -> Result<ModelFiles> {
    use hf_hub::api::sync::ApiBuilder;

    // Passing the token to the builder, instead of exporting it with `env::set_var`, avoids
    // mutating process-global state. That mutation races with other threads reading the
    // environment, and `set_var` is `unsafe` in the 2024 edition.
    let mut builder = ApiBuilder::new();
    if let Some(tok) = token {
        builder = builder.with_token(Some(tok.to_owned()));
    }
    let api = builder.build().context("hf-hub API init failed")?;
    let repo = api.model(repo_id.to_owned());
    let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
    resolve_hub_model_files(&repo, &prefix)
        .with_context(|| format!("could not load '{repo_id}' from HuggingFace Hub"))
}
484        }
485    }
486
487    result
488}