Skip to main content

model2vec_rs/
model.rs

1use anyhow::{anyhow, Context, Result};
2use half::f16;
3#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
4use hf_hub::api::sync::{Api, ApiRepo};
5use ndarray::{Array2, ArrayView2, CowArray, Ix2};
6use safetensors::{tensor::Dtype, SafeTensors};
7use serde_json::Value;
8use std::borrow::Cow;
9#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
10use std::env;
11use std::{
12    fs,
13    path::{Path, PathBuf},
14};
15use tokenizers::Tokenizer;
16
17/// Static embedding model for Model2Vec
18#[derive(Debug, Clone)]
19pub struct StaticModel {
20    tokenizer: Tokenizer,
21    embeddings: CowArray<'static, f32, Ix2>,
22    weights: Option<Cow<'static, [f32]>>,
23    token_mapping: Option<Cow<'static, [usize]>>,
24    normalize: bool,
25    median_token_length: usize,
26    unk_token_id: Option<usize>,
27}
28
/// Resolved on-disk locations of the three files that make up a model.
#[derive(Clone, Debug)]
struct ModelFiles {
    /// Path to `tokenizer.json`.
    tokenizer: PathBuf,
    /// Path to `model.safetensors`.
    model: PathBuf,
    /// Path to the configuration file that was matched.
    config: PathBuf,
}
35
36fn match_local_layout(config_base: &Path, model_base: &Path, config_file: &str) -> Option<ModelFiles> {
37    let config = config_base.join(config_file);
38    let tokenizer = model_base.join("tokenizer.json");
39    let model = model_base.join("model.safetensors");
40    (config.exists() && tokenizer.exists() && model.exists()).then_some(ModelFiles {
41        tokenizer,
42        model,
43        config,
44    })
45}
46
47fn decode_token_mapping(dtype: Dtype, raw: &[u8]) -> Result<Vec<usize>> {
48    let mapping = match dtype {
49        Dtype::I64 => raw
50            .chunks_exact(8)
51            .map(|b| i64::from_le_bytes(b.try_into().unwrap()) as usize)
52            .collect(),
53        Dtype::I32 => raw
54            .chunks_exact(4)
55            .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
56            .collect(),
57        other => return Err(anyhow!("unsupported mapping dtype: {:?}", other)),
58    };
59
60    Ok(mapping)
61}
62
/// Report whether a hub API error is an HTTP 404 response.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn is_not_found(e: &hf_hub::api::sync::ApiError) -> bool {
    use hf_hub::api::sync::ApiError;

    match e {
        ApiError::RequestError(inner) => matches!(inner.as_ref(), ureq::Error::Status(404, _)),
        _ => false,
    }
}
69
/// Try to fetch one candidate model layout from a hub repository.
///
/// The config file is requested under `config_prefix`; `tokenizer.json` and
/// `model.safetensors` are requested under `model_prefix`. Files are fetched
/// in the order config, tokenizer, model.
///
/// Returns `Ok(None)` as soon as one of the files is reported missing (404),
/// and propagates every other API error.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn match_hub_layout(
    repo: &ApiRepo,
    config_prefix: &str,
    model_prefix: &str,
    config_file: &str,
) -> Result<Option<ModelFiles>> {
    // A 404 means "this layout does not apply"; anything else is a real failure.
    let try_get = |remote: String| -> Result<Option<PathBuf>> {
        repo.get(&remote).map(Some).or_else(|err| {
            if is_not_found(&err) {
                Ok(None)
            } else {
                Err(anyhow::Error::from(err))
            }
        })
    };

    let config = match try_get(format!("{config_prefix}{config_file}"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let tokenizer = match try_get(format!("{model_prefix}tokenizer.json"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let model = match try_get(format!("{model_prefix}model.safetensors"))? {
        Some(path) => path,
        None => return Ok(None),
    };

    Ok(Some(ModelFiles {
        tokenizer,
        model,
        config,
    }))
}
99
100fn resolve_local_model_files(folder: &Path) -> Option<ModelFiles> {
101    match_local_layout(folder, folder, "config.json")
102        .or_else(|| match_local_layout(folder, folder, "config_sentence_transformers.json"))
103        .or_else(|| {
104            match_local_layout(
105                folder,
106                &folder.join("0_StaticEmbedding"),
107                "config_sentence_transformers.json",
108            )
109        })
110        .or_else(|| {
111            folder
112                .parent()
113                .and_then(|p| match_local_layout(p, folder, "config_sentence_transformers.json"))
114        })
115}
116
/// Locate model files inside a hub repository, trying each known layout in turn.
///
/// `prefix` is either empty or a subfolder path ending in `/`. The layouts
/// are probed in the same order as for local folders; the first complete one
/// wins, and an error is returned when none matches.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn resolve_hub_model_files(repo: &ApiRepo, prefix: &str) -> Result<ModelFiles> {
    const ST_CONFIG: &str = "config_sentence_transformers.json";

    let sub_prefix = format!("{prefix}0_StaticEmbedding/");

    // Prefix of the directory above `prefix`, or empty when there is none.
    let parent = Path::new(prefix.trim_end_matches('/'))
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| format!("{}/", dir.display()))
        .unwrap_or_default();

    // (config prefix, model prefix, config file name), in probing order.
    let candidates: [(&str, &str, &str); 4] = [
        (prefix, prefix, "config.json"),
        (prefix, prefix, ST_CONFIG),
        (prefix, sub_prefix.as_str(), ST_CONFIG),
        (parent.as_str(), prefix, ST_CONFIG),
    ];

    for (config_prefix, model_prefix, config_file) in candidates {
        if let Some(found) = match_hub_layout(repo, config_prefix, model_prefix, config_file)? {
            return Ok(found);
        }
    }

    Err(anyhow!("no valid model layout found in '{prefix}'"))
}
138
139impl StaticModel {
140    /// Load a Model2Vec model directly from in-memory bytes.
141    ///
142    /// This path is useful for runtimes that fetch model assets as bytes
143    /// rather than reading them from a local filesystem.
144    pub fn from_bytes<T, M, C>(
145        tokenizer_bytes: T,
146        model_bytes: M,
147        config_bytes: C,
148        normalize: Option<bool>,
149    ) -> Result<Self>
150    where
151        T: AsRef<[u8]>,
152        M: AsRef<[u8]>,
153        C: AsRef<[u8]>,
154    {
155        let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
156
157        // Read normalize default from config.json
158        let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
159        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
160        let normalize = normalize.unwrap_or(cfg_norm);
161
162        // Load the safetensors
163        let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
164        let tensor = safet
165            .tensor("embeddings")
166            .or_else(|_| safet.tensor("0"))
167            .or_else(|_| safet.tensor("embedding.weight"))
168            .context("embeddings tensor not found")?;
169
170        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
171        let raw = tensor.data();
172        let floats: Vec<f32> = match tensor.dtype() {
173            Dtype::F32 => raw
174                .chunks_exact(4)
175                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
176                .collect(),
177            Dtype::F16 => raw
178                .chunks_exact(2)
179                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
180                .collect(),
181            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
182            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
183        };
184
185        let weights = match safet.tensor("weights") {
186            Ok(t) => {
187                let raw = t.data();
188                let v: Vec<f32> = match t.dtype() {
189                    Dtype::F64 => raw
190                        .chunks_exact(8)
191                        .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
192                        .collect(),
193                    Dtype::F32 => raw
194                        .chunks_exact(4)
195                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
196                        .collect(),
197                    Dtype::F16 => raw
198                        .chunks_exact(2)
199                        .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
200                        .collect(),
201                    other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
202                };
203                Some(v)
204            }
205            Err(_) => None,
206        };
207
208        let token_mapping = match safet.tensor("mapping") {
209            Ok(t) => Some(decode_token_mapping(t.dtype(), t.data())?),
210            Err(_) => None,
211        };
212
213        Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
214    }
215
216    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
217    ///
218    /// # Arguments
219    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
220    /// * `token` - Optional HuggingFace token for authenticated downloads.
221    /// * `normalize` - Optional flag to normalize embeddings (default from the resolved config file).
222    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
223    pub fn from_pretrained<P: AsRef<Path>>(
224        repo_or_path: P,
225        token: Option<&str>,
226        normalize: Option<bool>,
227        subfolder: Option<&str>,
228    ) -> Result<Self> {
229        let files = resolve_model_files(repo_or_path, token, subfolder)?;
230        let tokenizer_bytes = fs::read(&files.tokenizer).context("failed to read tokenizer.json")?;
231        let model_bytes = fs::read(&files.model).context("failed to read model.safetensors")?;
232        let config_bytes = fs::read(&files.config).context("failed to read config.json")?;
233        Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
234    }
235
236    /// Construct from owned data.
237    ///
238    /// # Arguments
239    /// * `tokenizer` - Pre-deserialized tokenizer
240    /// * `embeddings` - Owned f32 embedding data
241    /// * `rows` - Number of vocabulary entries
242    /// * `cols` - Embedding dimension
243    /// * `normalize` - Whether to L2-normalize output embeddings
244    /// * `weights` - Optional per-token weights for quantized models
245    /// * `token_mapping` - Optional token ID mapping for quantized models
246    pub fn from_owned(
247        tokenizer: Tokenizer,
248        embeddings: Vec<f32>,
249        rows: usize,
250        cols: usize,
251        normalize: bool,
252        weights: Option<Vec<f32>>,
253        token_mapping: Option<Vec<usize>>,
254    ) -> Result<Self> {
255        if embeddings.len() != rows * cols {
256            return Err(anyhow!(
257                "embeddings length {} != rows {} * cols {}",
258                embeddings.len(),
259                rows,
260                cols
261            ));
262        }
263        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
264        let embeddings =
265            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
266        Ok(Self {
267            tokenizer,
268            embeddings: CowArray::from(embeddings),
269            weights: weights.map(Cow::Owned),
270            token_mapping: token_mapping.map(Cow::Owned),
271            normalize,
272            median_token_length,
273            unk_token_id,
274        })
275    }
276
277    /// Construct from static slices (zero-copy for embedded binary data).
278    ///
279    /// # Arguments
280    /// * `tokenizer` - Pre-deserialized tokenizer
281    /// * `embeddings` - Static f32 embedding data (borrowed, no copy)
282    /// * `rows` - Number of vocabulary entries
283    /// * `cols` - Embedding dimension
284    /// * `normalize` - Whether to L2-normalize output embeddings
285    /// * `weights` - Optional static per-token weights for quantized models
286    /// * `token_mapping` - Optional static token ID mapping for quantized models
287    #[allow(dead_code)] // Public API for external crates
288    pub fn from_borrowed(
289        tokenizer: Tokenizer,
290        embeddings: &'static [f32],
291        rows: usize,
292        cols: usize,
293        normalize: bool,
294        weights: Option<&'static [f32]>,
295        token_mapping: Option<&'static [usize]>,
296    ) -> Result<Self> {
297        if embeddings.len() != rows * cols {
298            return Err(anyhow!(
299                "embeddings length {} != rows {} * cols {}",
300                embeddings.len(),
301                rows,
302                cols
303            ));
304        }
305        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
306        let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
307        Ok(Self {
308            tokenizer,
309            embeddings: CowArray::from(embeddings),
310            weights: weights.map(Cow::Borrowed),
311            token_mapping: token_mapping.map(Cow::Borrowed),
312            normalize,
313            median_token_length,
314            unk_token_id,
315        })
316    }
317
318    /// Compute median token length and unk_token_id from tokenizer.
319    fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
320        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
321        lens.sort_unstable();
322        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
323
324        let spec: Value = serde_json::to_value(tokenizer).context("failed to serialize tokenizer")?;
325        let unk_token = spec
326            .get("model")
327            .and_then(|m| m.get("unk_token"))
328            .and_then(Value::as_str);
329        let unk_token_id = if let Some(tok) = unk_token {
330            let id = tokenizer
331                .token_to_id(tok)
332                .ok_or_else(|| anyhow!("unk_token '{tok}' not found in vocabulary"))?;
333            Some(id as usize)
334        } else {
335            None
336        };
337
338        Ok((median_token_length, unk_token_id))
339    }
340
341    /// Char-level truncation to max_tokens * median_token_length
342    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
343        s.char_indices()
344            .nth(max_tokens.saturating_mul(median_len))
345            .map_or(s, |(byte_idx, _)| &s[..byte_idx])
346    }
347
348    /// Encode texts into embeddings.
349    ///
350    /// # Arguments
351    /// * `sentences` - the list of sentences to encode.
352    /// * `max_length` - max tokens per text.
353    /// * `batch_size` - number of texts per batch.
354    pub fn encode_with_args(
355        &self,
356        sentences: &[String],
357        max_length: Option<usize>,
358        batch_size: usize,
359    ) -> Vec<Vec<f32>> {
360        let mut embeddings = Vec::with_capacity(sentences.len());
361        for batch in sentences.chunks(batch_size) {
362            let truncated: Vec<&str> = batch
363                .iter()
364                .map(|text| {
365                    max_length
366                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
367                        .unwrap_or(text.as_str())
368                })
369                .collect();
370            let encodings = self
371                .tokenizer
372                .encode_batch_fast::<String>(truncated.into_iter().map(Into::into).collect(), false)
373                .expect("tokenization failed");
374            for encoding in encodings {
375                let mut token_ids = encoding.get_ids().to_vec();
376                if let Some(unk_id) = self.unk_token_id {
377                    token_ids.retain(|&id| id as usize != unk_id);
378                }
379                if let Some(max_tok) = max_length {
380                    token_ids.truncate(max_tok);
381                }
382                embeddings.push(self.pool_ids(token_ids));
383            }
384        }
385        embeddings
386    }
387
388    /// Default encode: `max_length=512`, `batch_size=1024`
389    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
390        self.encode_with_args(sentences, Some(512), 1024)
391    }
392
393    /// Encode a single sentence into a vector.
394    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
395        self.encode(&[sentence.to_string()])
396            .into_iter()
397            .next()
398            .unwrap_or_default()
399    }
400
401    /// Mean-pool a token-ID list into a single vector.
402    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
403        let dim = self.embeddings.ncols();
404        let mut sum = vec![0.0_f32; dim];
405        let mut cnt = 0usize;
406        for &id in &ids {
407            let tok = id as usize;
408            let row_idx = self
409                .token_mapping
410                .as_ref()
411                .and_then(|m| m.get(tok))
412                .copied()
413                .unwrap_or(tok);
414            let scale = self.weights.as_ref().and_then(|w| w.get(tok)).copied().unwrap_or(1.0);
415            let row = self.embeddings.row(row_idx);
416            for (s, &v) in sum.iter_mut().zip(row.iter()) {
417                *s += v * scale;
418            }
419            cnt += 1;
420        }
421        let denom = cnt.max(1) as f32;
422        for x in &mut sum {
423            *x /= denom;
424        }
425        if self.normalize {
426            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
427            for x in &mut sum {
428                *x /= norm;
429            }
430        }
431        sum
432    }
433}
434
435fn resolve_model_files<P: AsRef<Path>>(
436    repo_or_path: P,
437    token: Option<&str>,
438    subfolder: Option<&str>,
439) -> Result<ModelFiles> {
440    #[cfg(any(not(feature = "hf-hub"), feature = "local-only"))]
441    let _ = token;
442
443    let base = repo_or_path.as_ref();
444    if base.exists() {
445        let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
446        return resolve_local_model_files(&folder).ok_or_else(|| {
447            anyhow!(
448                "no valid model layout found in {folder:?}. \
449                 Tried: model2vec (config.json), sentence-transformers \
450                 (config_sentence_transformers.json), and 0_StaticEmbedding subfolder."
451            )
452        });
453    }
454
455    #[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
456    {
457        download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)
458    }
459    #[cfg(feature = "local-only")]
460    {
461        Err(anyhow!(
462            "remote model downloads are disabled by the `local-only` feature; pass a local model directory instead"
463        ))
464    }
465    #[cfg(all(not(feature = "hf-hub"), not(feature = "local-only")))]
466    {
467        Err(anyhow!(
468            "remote model downloads require the `hf-hub` feature; pass a local model directory instead"
469        ))
470    }
471}
472
/// Download (or reuse from the hub cache) the files of a HuggingFace model.
///
/// # Arguments
/// * `repo_id` - HuggingFace repository ID.
/// * `token` - Optional access token for authenticated downloads.
/// * `subfolder` - Optional subfolder inside the repository.
///
/// When a token is given it is handed to the hub client explicitly. For
/// backward compatibility it is also exported as `HF_HUB_TOKEN` for the
/// duration of the call; the previous value of that variable is restored
/// afterwards, including when the download path unwinds.
///
/// # Errors
/// Returns an error when the hub client cannot be initialized or when no
/// valid model layout can be fetched from the repository.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&str>) -> Result<ModelFiles> {
    /// Restores the previous `HF_HUB_TOKEN` value when dropped.
    struct TokenEnvGuard {
        previous: Option<std::ffi::OsString>,
    }

    impl Drop for TokenEnvGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var("HF_HUB_TOKEN", value),
                None => env::remove_var("HF_HUB_TOKEN"),
            }
        }
    }

    // Only touch the environment when a token was actually supplied.
    let _env_guard = token.map(|tok| {
        let previous = env::var_os("HF_HUB_TOKEN");
        env::set_var("HF_HUB_TOKEN", tok);
        TokenEnvGuard { previous }
    });

    // Pass the token to the client directly instead of relying solely on the
    // environment; without a token, keep the default client configuration.
    let api = match token {
        Some(tok) => hf_hub::api::sync::ApiBuilder::new()
            .with_token(Some(tok.to_owned()))
            .build(),
        None => Api::new(),
    }
    .context("hf-hub API init failed")?;

    let repo = api.model(repo_id.to_owned());
    let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
    resolve_hub_model_files(&repo, &prefix)
        .with_context(|| format!("could not load '{repo_id}' from HuggingFace Hub"))
}
498
#[cfg(test)]
mod tests {
    use super::decode_token_mapping;
    use safetensors::tensor::Dtype;

    /// Both supported integer widths decode into the expected indices.
    #[test]
    fn decode_token_mapping_supports_i32_and_i64() {
        let mut i32_raw: Vec<u8> = Vec::new();
        for value in [1i32, 2, 3] {
            i32_raw.extend_from_slice(&value.to_le_bytes());
        }
        let mut i64_raw: Vec<u8> = Vec::new();
        for value in [4i64, 5, 6] {
            i64_raw.extend_from_slice(&value.to_le_bytes());
        }

        let from_i32 = decode_token_mapping(Dtype::I32, &i32_raw).unwrap();
        let from_i64 = decode_token_mapping(Dtype::I64, &i64_raw).unwrap();

        assert_eq!(from_i32, vec![1, 2, 3]);
        assert_eq!(from_i64, vec![4, 5, 6]);
    }

    /// Non-integer dtypes are rejected with a descriptive error.
    #[test]
    fn decode_token_mapping_rejects_unsupported_dtype() {
        let message = decode_token_mapping(Dtype::F32, &[0, 0, 0, 0])
            .unwrap_err()
            .to_string();
        assert!(message.contains("unsupported mapping dtype"));
    }
}
524}