use anyhow::{anyhow, Context, Result};
use half::f16;
use hf_hub::api::sync::Api;
use ndarray::Array2;
use safetensors::{tensor::Dtype, SafeTensors};
use serde_json::Value;
use std::{env, fs, path::Path};
use tokenizers::Tokenizer;

/// A static (non-contextual) embedding model: sentence embeddings are produced
/// by looking up token vectors and mean-pooling them, with no forward pass.
#[derive(Debug, Clone)]
pub struct StaticModel {
    tokenizer: Tokenizer,
    embeddings: Array2<f32>,
    weights: Option<Vec<f32>>,
    token_mapping: Option<Vec<usize>>,
    normalize: bool,
    median_token_length: usize,
    unk_token_id: Option<usize>,
}

impl StaticModel {
    /// Loads a model from a local directory or a Hugging Face Hub repo id.
    ///
    /// * `token` - optional HF token for private repos.
    /// * `normalize` - overrides the `normalize` flag from `config.json` when set.
    /// * `subfolder` - optional subdirectory holding the model files.
    pub fn from_pretrained<P: AsRef<Path>>(
        repo_or_path: P,
        token: Option<&str>,
        normalize: Option<bool>,
        subfolder: Option<&str>,
    ) -> Result<Self> {
        if let Some(tok) = token {
            env::set_var("HF_HUB_TOKEN", tok);
        }

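        // Resolution strategy: if `repo_or_path` exists on disk, expect the
        // three files below inside it (or inside `subfolder`); otherwise treat
        // it as a Hub repo id and download them. Expected layout:
        //   <folder>/tokenizer.json
        //   <folder>/model.safetensors
        //   <folder>/config.json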
        let (tok_path, mdl_path, cfg_path) = {
            let base = repo_or_path.as_ref();
            if base.exists() {
                let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
                let t = folder.join("tokenizer.json");
                let m = folder.join("model.safetensors");
                let c = folder.join("config.json");
                if !t.exists() || !m.exists() || !c.exists() {
                    return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
                }
                (t, m, c)
            } else {
                let api = Api::new().context("hf-hub API init failed")?;
                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
                let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
                let m = repo.get(&format!("{prefix}model.safetensors"))?;
                let c = repo.get(&format!("{prefix}config.json"))?;
                (t, m, c)
            }
        };

        let tokenizer = Tokenizer::from_file(&tok_path)
            .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

        // Median token length (in characters) across the vocabulary; used to
        // cheaply pre-truncate long inputs before tokenization.
        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
        lens.sort_unstable();
        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
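        // Worked example (illustrative numbers): with a median token length of
        // 4 chars and `max_length = 512`, `truncate_str` caps each input at
        // 512 * 4 = 2048 chars before the tokenizer runs.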

        let cfg_file = fs::File::open(&cfg_path).context("failed to open config.json")?;
        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
        let normalize = normalize.unwrap_or(cfg_norm);

        // Read the unk token from the serialized tokenizer spec. Not every
        // tokenizer defines one (e.g. byte-level BPE), so the id stays optional.
        let spec_json = tokenizer
            .to_string(false)
            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
        let spec: Value = serde_json::from_str(&spec_json)?;
        let unk_token_id = match spec
            .get("model")
            .and_then(|m| m.get("unk_token"))
            .and_then(Value::as_str)
        {
            Some(unk_token) => Some(
                tokenizer
                    .token_to_id(unk_token)
                    .ok_or_else(|| {
                        anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab")
                    })? as usize,
            ),
            None => None,
        };

        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
        let tensor = safet
            .tensor("embeddings")
            .or_else(|_| safet.tensor("0"))
            .context("embeddings tensor not found")?;

        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
        let raw = tensor.data();
        let dtype = tensor.dtype();

        // Decode the raw little-endian bytes into f32, widening f16 and i8.
        let floats: Vec<f32> = match dtype {
            Dtype::F32 => raw
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                .collect(),
            Dtype::F16 => raw
                .chunks_exact(2)
                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
                .collect(),
            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
        };
        let embeddings = Array2::from_shape_vec((rows, cols), floats)
            .context("failed to build embeddings array")?;

        // Optional per-token weights used to scale rows during pooling.
        let weights = match safet.tensor("weights") {
            Ok(t) => {
                let raw = t.data();
                let v: Vec<f32> = match t.dtype() {
                    Dtype::F64 => raw
                        .chunks_exact(8)
                        .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
                        .collect(),
                    Dtype::F32 => raw
                        .chunks_exact(4)
                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                        .collect(),
                    Dtype::F16 => raw
                        .chunks_exact(2)
                        .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
                        .collect(),
                    other => return Err(anyhow!("unsupported weights dtype: {other:?}")),
                };
                Some(v)
            }
            Err(_) => None,
        };

        // Optional token-id -> embedding-row mapping, stored as i32 values
        // (the 4-byte chunking below assumes an I32 tensor).
        let token_mapping = match safet.tensor("mapping") {
            Ok(t) => {
                let v: Vec<usize> = t
                    .data()
                    .chunks_exact(4)
                    .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
                    .collect();
                Some(v)
            }
            Err(_) => None,
        };

        Ok(Self {
            tokenizer,
            embeddings,
            weights,
            token_mapping,
            normalize,
            median_token_length,
            unk_token_id,
        })
    }

    /// Truncates `s` to at most `max_tokens * median_len` characters, cutting
    /// on a char boundary. A cheap pre-tokenization bound, not an exact one.
    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
        let max_chars = max_tokens.saturating_mul(median_len);
        match s.char_indices().nth(max_chars) {
            Some((byte_idx, _)) => &s[..byte_idx],
            None => s,
        }
    }

    /// Encodes `sentences` into embeddings.
    ///
    /// * `max_length` - optional per-sentence token budget.
    /// * `batch_size` - number of sentences tokenized at once.
    pub fn encode_with_args(
        &self,
        sentences: &[String],
        max_length: Option<usize>,
        batch_size: usize,
    ) -> Vec<Vec<f32>> {
        let mut embeddings = Vec::with_capacity(sentences.len());

        for batch in sentences.chunks(batch_size) {
            // Character-level pre-truncation so the tokenizer never sees
            // pathologically long inputs.
            let truncated: Vec<&str> = batch
                .iter()
                .map(|text| {
                    max_length
                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
                        .unwrap_or(text.as_str())
                })
                .collect();

            let encodings = self
                .tokenizer
                .encode_batch_fast::<String>(
                    truncated.into_iter().map(Into::into).collect(),
                    false,
                )
                .expect("tokenization failed");

            for encoding in encodings {
                let mut token_ids = encoding.get_ids().to_vec();
                // Drop unknown tokens: their embedding carries no signal.
                if let Some(unk_id) = self.unk_token_id {
                    token_ids.retain(|&id| id as usize != unk_id);
                }
                // Enforce the exact token budget after tokenization.
                if let Some(max_tok) = max_length {
                    token_ids.truncate(max_tok);
                }
                embeddings.push(self.pool_ids(token_ids));
            }
        }

        embeddings
    }

    /// Encodes with defaults: truncate to 512 tokens, batches of 1024.
    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
        self.encode_with_args(sentences, Some(512), 1024)
    }

    /// Convenience wrapper that encodes a single sentence.
    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
        self.encode(&[sentence.to_string()])
            .into_iter()
            .next()
            .unwrap_or_default()
    }

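    /// Mean-pools the embedding rows for `ids`:
    ///
    ///   pooled = (1 / max(n, 1)) * sum_i( w[id_i] * E[map(id_i)] )
    ///
    /// then L2-normalizes the result when `self.normalize` is set. `map` and
    /// `w` fall back to the identity and 1.0 when the optional `mapping` /
    /// `weights` tensors are absent.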
    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
        let dim = self.embeddings.ncols();
        let mut sum = vec![0.0; dim];
        let mut cnt = 0usize;

        for &id in &ids {
            let tok = id as usize;

            // Map the token id to its embedding row, if a mapping exists.
            let row_idx = if let Some(m) = &self.token_mapping {
                *m.get(tok).unwrap_or(&tok)
            } else {
                tok
            };

            // Per-token weight; defaults to 1.0 when no weights are loaded.
            let scale = if let Some(w) = &self.weights {
                *w.get(tok).unwrap_or(&1.0)
            } else {
                1.0
            };

            let row = self.embeddings.row(row_idx);
            for (i, &v) in row.iter().enumerate() {
                sum[i] += v * scale;
            }
            cnt += 1;
        }

        // Average; max(1) guards against an empty id list.
        let denom = cnt.max(1) as f32;
        for x in &mut sum {
            *x /= denom;
        }

        // Optional L2 normalization; the epsilon guards the zero vector.
        if self.normalize {
            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
            for x in &mut sum {
                *x /= norm;
            }
        }
        sum
    }
}
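
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the library). The repo id below is
// an assumption; any Hub repo or local folder containing tokenizer.json,
// model.safetensors, and config.json in the layout noted above should work.
//
// fn main() -> anyhow::Result<()> {
//     let model = StaticModel::from_pretrained(
//         "minishlab/potion-base-8M", // assumed example repo id
//         None, // token: public repos need none
//         None, // normalize: defer to config.json
//         None, // subfolder: files at the repo root
//     )?;
//     let embs = model.encode(&["hello world".to_string()]);
//     println!("embedding dim = {}", embs[0].len());
//     Ok(())
// }
// ---------------------------------------------------------------------------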