//! model2vec_rs/model.rs — static embedding model loading and inference for Model2Vec.

1use anyhow::{anyhow, Context, Result};
2use half::f16;
3use hf_hub::api::sync::Api;
4use ndarray::Array2;
5use safetensors::{tensor::Dtype, SafeTensors};
6use serde_json::Value;
7use std::{env, fs, path::Path};
8use tokenizers::Tokenizer;
9
/// Static embedding model for Model2Vec
#[derive(Debug, Clone)]
pub struct StaticModel {
    /// HuggingFace tokenizer used to turn text into token IDs.
    tokenizer: Tokenizer,
    /// Embedding matrix: one row per embedding entry, `ncols` = embedding dim.
    embeddings: Array2<f32>,
    /// Optional per-token weights (vocabulary quantization), indexed by raw token ID.
    weights: Option<Vec<f32>>,
    /// Optional raw-token-ID -> embedding-row remapping (vocabulary quantization).
    token_mapping: Option<Vec<usize>>,
    /// Whether pooled embeddings are L2-normalized.
    normalize: bool,
    /// Median token string length in chars; used for cheap pre-tokenization truncation.
    median_token_length: usize,
    /// ID of the unknown token; tokens with this ID are dropped before pooling.
    unk_token_id: Option<usize>,
}
21
22impl StaticModel {
    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
    ///
    /// # Arguments
    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
    /// * `token` - Optional HuggingFace token for authenticated downloads.
    /// * `normalize` - Optional flag to normalize embeddings (default from config.json).
    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
    ///
    /// # Errors
    /// Fails when `tokenizer.json`, `model.safetensors`, or `config.json` is
    /// missing or unparsable, when the unk token is not in the vocabulary, or
    /// when the embeddings tensor is not 2-D or has an unsupported dtype.
    pub fn from_pretrained<P: AsRef<Path>>(
        repo_or_path: P,
        token: Option<&str>,
        normalize: Option<bool>,
        subfolder: Option<&str>,
    ) -> Result<Self> {
        // If provided, set HF token for authenticated downloads.
        // NOTE(review): this mutates process-wide environment state and assumes
        // the pinned hf-hub version reads `HF_HUB_TOKEN` — confirm against its docs.
        if let Some(tok) = token {
            env::set_var("HF_HUB_TOKEN", tok);
        }

        // Locate tokenizer.json, model.safetensors, config.json.
        // A path that exists on disk is treated as a local model folder;
        // anything else is treated as a HuggingFace repo ID and downloaded.
        let (tok_path, mdl_path, cfg_path) = {
            let base = repo_or_path.as_ref();
            if base.exists() {
                let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
                let t = folder.join("tokenizer.json");
                let m = folder.join("model.safetensors");
                let c = folder.join("config.json");
                if !t.exists() || !m.exists() || !c.exists() {
                    return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
                }
                (t, m, c)
            } else {
                let api = Api::new().context("hf-hub API init failed")?;
                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
                // Files inside a repo subfolder are addressed as "subfolder/file".
                let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
                let m = repo.get(&format!("{prefix}model.safetensors"))?;
                let c = repo.get(&format!("{prefix}config.json"))?;
                (t, m, c)
            }
        };

        // Load the tokenizer
        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

        // Median-token-length hack for pre-truncation: texts are later cut to
        // max_tokens * median_token_length chars before tokenization.
        // (lens.len()/2 on a sorted vec; falls back to 1 for an empty vocab.)
        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
        lens.sort_unstable();
        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

        // Read normalize default from config.json; a missing key defaults to
        // true, and the caller-supplied flag (if any) wins over the config.
        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
        let normalize = normalize.unwrap_or(cfg_norm);

        // Serialize the tokenizer to JSON, then parse it and get the unk_token.
        // The tokenizers API does not expose unk_token directly, so it is read
        // from the serialized spec; "[UNK]" is assumed when absent.
        let spec_json = tokenizer
            .to_string(false)
            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
        let spec: Value = serde_json::from_str(&spec_json)?;
        let unk_token = spec
            .get("model")
            .and_then(|m| m.get("unk_token"))
            .and_then(Value::as_str)
            .unwrap_or("[UNK]");
        let unk_token_id = tokenizer
            .token_to_id(unk_token)
            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
            as usize;

        // Load the safetensors. The embeddings tensor is looked up under
        // "embeddings" first, then "0" (presumably a legacy export name —
        // verify against the exporter).
        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
        let tensor = safet
            .tensor("embeddings")
            .or_else(|_| safet.tensor("0"))
            .context("embeddings tensor not found")?;

        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2‑D")?;
        let raw = tensor.data();
        let dtype = tensor.dtype();

        // Decode into f32 (safetensors data is little-endian by spec).
        let floats: Vec<f32> = match dtype {
            Dtype::F32 => raw
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                .collect(),
            Dtype::F16 => raw
                .chunks_exact(2)
                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
                .collect(),
            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
        };
        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

        // Load optional weights for vocabulary quantization; absence is not an error.
        let weights = match safet.tensor("weights") {
            Ok(t) => {
                let raw = t.data();
                let v: Vec<f32> = match t.dtype() {
                    Dtype::F64 => raw
                        .chunks_exact(8)
                        .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
                        .collect(),
                    Dtype::F32 => raw
                        .chunks_exact(4)
                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                        .collect(),
                    Dtype::F16 => raw
                        .chunks_exact(2)
                        .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
                        .collect(),
                    other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
                };
                Some(v)
            }
            Err(_) => None,
        };

        // Load optional token mapping for vocabulary quantization.
        // NOTE(review): this assumes the mapping tensor is little-endian i32;
        // the dtype is not checked, so an i64/u32 mapping would be silently
        // misread — confirm the exporter's on-disk format.
        let token_mapping = match safet.tensor("mapping") {
            Ok(t) => {
                let raw = t.data();
                let v: Vec<usize> = raw
                    .chunks_exact(4)
                    .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
                    .collect();
                Some(v)
            }
            Err(_) => None,
        };

        Ok(Self {
            tokenizer,
            embeddings,
            weights,
            token_mapping,
            normalize,
            median_token_length,
            // Always Some here; the Option in the struct allows for models
            // without an unk token constructed elsewhere.
            unk_token_id: Some(unk_token_id),
        })
    }
167
168    /// Char-level truncation to max_tokens * median_token_length
169    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
170        let max_chars = max_tokens.saturating_mul(median_len);
171        match s.char_indices().nth(max_chars) {
172            Some((byte_idx, _)) => &s[..byte_idx],
173            None => s,
174        }
175    }
176
177    /// Encode texts into embeddings.
178    ///
179    /// # Arguments
180    /// * `sentences` - the list of sentences to encode.
181    /// * `max_length` - max tokens per text.
182    /// * `batch_size` - number of texts per batch.
183    pub fn encode_with_args(
184        &self,
185        sentences: &[String],
186        max_length: Option<usize>,
187        batch_size: usize,
188    ) -> Vec<Vec<f32>> {
189        let mut embeddings = Vec::with_capacity(sentences.len());
190
191        // Process in batches
192        for batch in sentences.chunks(batch_size) {
193            // Truncate each sentence to max_length * median_token_length chars
194            let truncated: Vec<&str> = batch
195                .iter()
196                .map(|text| {
197                    max_length
198                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
199                        .unwrap_or(text.as_str())
200                })
201                .collect();
202
203            // Tokenize the batch
204            let encodings = self
205                .tokenizer
206                .encode_batch_fast::<String>(
207                    // Into<EncodeInput>
208                    truncated.into_iter().map(Into::into).collect(),
209                    /* add_special_tokens = */ false,
210                )
211                .expect("tokenization failed");
212
213            // Pool each token-ID list into a single mean vector
214            for encoding in encodings {
215                let mut token_ids = encoding.get_ids().to_vec();
216                // Remove unk tokens if specified
217                if let Some(unk_id) = self.unk_token_id {
218                    token_ids.retain(|&id| id as usize != unk_id);
219                }
220                // Truncate to max_length if specified
221                if let Some(max_tok) = max_length {
222                    token_ids.truncate(max_tok);
223                }
224                embeddings.push(self.pool_ids(token_ids));
225            }
226        }
227
228        embeddings
229    }
230
231    /// Default encode: `max_length=512`, `batch_size=1024`
232    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
233        self.encode_with_args(sentences, Some(512), 1024)
234    }
235
    /// Encode a single sentence into a vector.
    ///
    /// Convenience wrapper around [`StaticModel::encode`]; returns an empty
    /// vector if encoding yields no output.
    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
        self.encode(&[sentence.to_string()])
            .into_iter()
            .next()
            .unwrap_or_default()
    }
243
244    /// Mean-pool a single token-ID list into a vector
245    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
246        let dim = self.embeddings.ncols();
247        let mut sum = vec![0.0; dim];
248        let mut cnt = 0usize;
249
250        for &id in &ids {
251            let tok = id as usize;
252
253            // Remap: row = token_mapping[id] or id
254            let row_idx = if let Some(m) = &self.token_mapping {
255                *m.get(tok).unwrap_or(&tok)
256            } else {
257                tok
258            };
259
260            // Scale by per-token weight if present
261            let scale = if let Some(w) = &self.weights {
262                *w.get(tok).unwrap_or(&1.0)
263            } else {
264                1.0
265            };
266
267            let row = self.embeddings.row(row_idx);
268            for (i, &v) in row.iter().enumerate() {
269                sum[i] += v * scale;
270            }
271            cnt += 1;
272        }
273
274        // Mean pool the embeddings
275        let denom = (cnt.max(1)) as f32;
276        for x in &mut sum {
277            *x /= denom;
278        }
279
280        // Normalize the embeddings if required
281        if self.normalize {
282            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
283            for x in &mut sum {
284                *x /= norm;
285            }
286        }
287        sum
288    }
289}