//! model2vec_rs — model.rs: Model2Vec static embedding model (loading + inference).

1use anyhow::{anyhow, Context, Result};
2use half::f16;
3use hf_hub::api::sync::Api;
4use ndarray::Array2;
5use safetensors::{tensor::Dtype, SafeTensors};
6use serde_json::Value;
7use std::{env, fs, path::Path};
8use tokenizers::Tokenizer;
9
/// Static embedding model for Model2Vec
pub struct StaticModel {
    // Tokenizer that converts input text into token IDs (rows of `embeddings`).
    tokenizer: Tokenizer,
    // Embedding table, one row per vocabulary token, decoded to f32 at load time.
    embeddings: Array2<f32>,
    // Whether pooled sentence vectors are L2-normalized (config.json default,
    // overridable by the caller in `from_pretrained`).
    normalize: bool,
    // Median length (in chars) of vocab token strings; used for cheap char-level
    // pre-truncation before tokenizing (see `truncate_str`).
    median_token_length: usize,
    // Vocab ID of the unknown token, if the tokenizer declares one; tokens with
    // this ID are dropped before pooling.
    unk_token_id: Option<usize>,
}
18
19impl StaticModel {
20    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
21    ///
22    /// # Arguments
23    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
24    /// * `token` - Optional HuggingFace token for authenticated downloads.
25    /// * `normalize` - Optional flag to normalize embeddings (default from config.json).
26    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
27    pub fn from_pretrained<P: AsRef<Path>>(
28        repo_or_path: P,
29        token: Option<&str>,
30        normalize: Option<bool>,
31        subfolder: Option<&str>,
32    ) -> Result<Self> {
33        // If provided, set HF token for authenticated downloads
34        if let Some(tok) = token {
35            env::set_var("HF_HUB_TOKEN", tok);
36        }
37
38        // Locate tokenizer.json, model.safetensors, config.json
39        let (tok_path, mdl_path, cfg_path) = {
40            let base = repo_or_path.as_ref();
41            if base.exists() {
42                let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
43                let t = folder.join("tokenizer.json");
44                let m = folder.join("model.safetensors");
45                let c = folder.join("config.json");
46                if !t.exists() || !m.exists() || !c.exists() {
47                    return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
48                }
49                (t, m, c)
50            } else {
51                let api = Api::new().context("hf-hub API init failed")?;
52                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
53                let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
54                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
55                let m = repo.get(&format!("{prefix}model.safetensors"))?;
56                let c = repo.get(&format!("{prefix}config.json"))?;
57                (t, m, c)
58            }
59        };
60
61        // Load the tokenizer
62        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
63
64        // Median-token-length hack for pre-truncation
65        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
66        lens.sort_unstable();
67        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
68
69        // Read normalize default from config.json
70        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
71        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
72        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
73        let normalize = normalize.unwrap_or(cfg_norm);
74
75        // Serialize the tokenizer to JSON, then parse it and get the unk_token
76        let spec_json = tokenizer
77            .to_string(false)
78            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
79        let spec: Value = serde_json::from_str(&spec_json)?;
80        let unk_token = spec
81            .get("model")
82            .and_then(|m| m.get("unk_token"))
83            .and_then(Value::as_str)
84            .unwrap_or("[UNK]");
85        let unk_token_id = tokenizer
86            .token_to_id(unk_token)
87            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
88            as usize;
89
90        // Load the safetensors
91        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
92        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
93        let tensor = safet
94            .tensor("embeddings")
95            .or_else(|_| safet.tensor("0"))
96            .context("embeddings tensor not found")?;
97
98        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2‑D")?;
99        let raw = tensor.data();
100        let dtype = tensor.dtype();
101
102        // Decode into f32
103        let floats: Vec<f32> = match dtype {
104            Dtype::F32 => raw
105                .chunks_exact(4)
106                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
107                .collect(),
108            Dtype::F16 => raw
109                .chunks_exact(2)
110                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
111                .collect(),
112            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
113            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
114        };
115        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
116
117        Ok(Self {
118            tokenizer,
119            embeddings,
120            normalize,
121            median_token_length,
122            unk_token_id: Some(unk_token_id),
123        })
124    }
125
126    /// Char-level truncation to max_tokens * median_token_length
127    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
128        let max_chars = max_tokens.saturating_mul(median_len);
129        match s.char_indices().nth(max_chars) {
130            Some((byte_idx, _)) => &s[..byte_idx],
131            None => s,
132        }
133    }
134
135    /// Encode texts into embeddings.
136    ///
137    /// # Arguments
138    /// * `sentences` - the list of sentences to encode.
139    /// * `max_length` - max tokens per text.
140    /// * `batch_size` - number of texts per batch.
141    pub fn encode_with_args(
142        &self,
143        sentences: &[String],
144        max_length: Option<usize>,
145        batch_size: usize,
146    ) -> Vec<Vec<f32>> {
147        let mut embeddings = Vec::with_capacity(sentences.len());
148
149        // Process in batches
150        for batch in sentences.chunks(batch_size) {
151            // Truncate each sentence to max_length * median_token_length chars
152            let truncated: Vec<&str> = batch
153                .iter()
154                .map(|text| {
155                    max_length
156                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
157                        .unwrap_or(text.as_str())
158                })
159                .collect();
160
161            // Tokenize the batch
162            let encodings = self
163                .tokenizer
164                .encode_batch_fast::<String>(
165                    // Into<EncodeInput>
166                    truncated.into_iter().map(Into::into).collect(),
167                    /* add_special_tokens = */ false,
168                )
169                .expect("tokenization failed");
170
171            // Pool each token-ID list into a single mean vector
172            for encoding in encodings {
173                let mut token_ids = encoding.get_ids().to_vec();
174                // Remove unk tokens if specified
175                if let Some(unk_id) = self.unk_token_id {
176                    token_ids.retain(|&id| id as usize != unk_id);
177                }
178                // Truncate to max_length if specified
179                if let Some(max_tok) = max_length {
180                    token_ids.truncate(max_tok);
181                }
182                embeddings.push(self.pool_ids(token_ids));
183            }
184        }
185
186        embeddings
187    }
188
189    /// Default encode: `max_length=512`, `batch_size=1024`
190    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
191        self.encode_with_args(sentences, Some(512), 1024)
192    }
193
194    // / Encode a single sentence into a vector
195    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
196        self.encode(&[sentence.to_string()])
197            .into_iter()
198            .next()
199            .unwrap_or_default()
200    }
201
202    /// Mean-pool a single token-ID list into a vector
203    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
204        let mut sum = vec![0.0; self.embeddings.ncols()];
205        for &id in &ids {
206            let row = self.embeddings.row(id as usize);
207            for (i, &v) in row.iter().enumerate() {
208                sum[i] += v;
209            }
210        }
211        let cnt = ids.len().max(1) as f32;
212        sum.iter_mut().for_each(|x| *x /= cnt);
213        if self.normalize {
214            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
215            sum.iter_mut().for_each(|x| *x /= norm);
216        }
217        sum
218    }
219}