use anyhow::{anyhow, Context, Result};
use half::f16;
use hf_hub::api::sync::Api;
use ndarray::Array2;
use safetensors::{tensor::Dtype, SafeTensors};
use serde_json::Value;
use std::{env, fs, path::Path};
use tokenizers::Tokenizer;

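/// Static embedding model: per-token embedding lookup followed by mean pooling,
/// with optional L2 normalization of the pooled vector.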
pub struct StaticModel {
    tokenizer: Tokenizer,
    embeddings: Array2<f32>,
    normalize: bool,
    median_token_length: usize,
    unk_token_id: Option<usize>,
}

impl StaticModel {
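    /// Load a model from a local directory or a Hugging Face Hub repo id.
    ///
    /// `token` is an optional HF access token, `normalize` overrides the
    /// `normalize` flag from `config.json`, and `subfolder` selects a nested
    /// directory inside the repo or local path.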
    pub fn from_pretrained<P: AsRef<Path>>(
        repo_or_path: P,
        token: Option<&str>,
        normalize: Option<bool>,
        subfolder: Option<&str>,
    ) -> Result<Self> {
        // If a token was supplied, expose it via the environment for the Hub API.
        if let Some(tok) = token {
            env::set_var("HF_HUB_TOKEN", tok);
        }

        // Resolve tokenizer.json / model.safetensors / config.json from a local
        // directory, or download them from the Hugging Face Hub.
        let (tok_path, mdl_path, cfg_path) = {
            let base = repo_or_path.as_ref();
            if base.exists() {
                let folder = subfolder
                    .map(|s| base.join(s))
                    .unwrap_or_else(|| base.to_path_buf());
                let t = folder.join("tokenizer.json");
                let m = folder.join("model.safetensors");
                let c = folder.join("config.json");
                if !t.exists() || !m.exists() || !c.exists() {
                    return Err(anyhow!(
                        "local path {folder:?} is missing tokenizer.json, model.safetensors, or config.json"
                    ));
                }
                (t, m, c)
            } else {
                let api = Api::new().context("hf-hub API init failed")?;
                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
                let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
                let m = repo.get(&format!("{prefix}model.safetensors"))?;
                let c = repo.get(&format!("{prefix}config.json"))?;
                (t, m, c)
            }
        };

        let tokenizer = Tokenizer::from_file(&tok_path)
            .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

        // Median token string length, used to estimate a character cut-off when
        // pre-truncating inputs by token count.
        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
        lens.sort_unstable();
        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

        // The explicit `normalize` argument takes precedence over config.json.
        let cfg_file = fs::File::open(&cfg_path).context("failed to open config.json")?;
        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
        let normalize = normalize.unwrap_or(cfg_norm);

        // Look up the unknown-token id so unknown tokens can be dropped before pooling.
        let spec_json = tokenizer
            .to_string(false)
            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
        let spec: Value = serde_json::from_str(&spec_json)?;
        let unk_token = spec
            .get("model")
            .and_then(|m| m.get("unk_token"))
            .and_then(Value::as_str)
            .unwrap_or("[UNK]");
        let unk_token_id = tokenizer
            .token_to_id(unk_token)
            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
            as usize;

        // Load the embedding matrix; it is stored under the key "embeddings" (or "0").
        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
        let tensor = safet
            .tensor("embeddings")
            .or_else(|_| safet.tensor("0"))
            .context("embeddings tensor not found")?;

        let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
        let raw = tensor.data();
        let dtype = tensor.dtype();

        // Decode the raw little-endian buffer into f32, whatever the stored dtype.
        let floats: Vec<f32> = match dtype {
            Dtype::F32 => raw
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
                .collect(),
            Dtype::F16 => raw
                .chunks_exact(2)
                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
                .collect(),
            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
        };
        let embeddings = Array2::from_shape_vec((rows, cols), floats)
            .context("failed to build embeddings array")?;

        Ok(Self {
            tokenizer,
            embeddings,
            normalize,
            median_token_length,
            unk_token_id: Some(unk_token_id),
        })
    }

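    /// Truncate `s` to roughly `max_tokens` tokens by cutting it at
    /// `max_tokens * median_len` characters.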
    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
        let max_chars = max_tokens.saturating_mul(median_len);
        match s.char_indices().nth(max_chars) {
            Some((byte_idx, _)) => &s[..byte_idx],
            None => s,
        }
    }

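    /// Encode sentences into embedding vectors, with explicit control over the
    /// per-sentence token limit and the batch size used for tokenization.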
    pub fn encode_with_args(
        &self,
        sentences: &[String],
        max_length: Option<usize>,
        batch_size: usize,
    ) -> Vec<Vec<f32>> {
        let mut embeddings = Vec::with_capacity(sentences.len());

        for batch in sentences.chunks(batch_size) {
            // Pre-truncate by character count so the tokenizer has less work to do.
            let truncated: Vec<&str> = batch
                .iter()
                .map(|text| {
                    max_length
                        .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
                        .unwrap_or(text.as_str())
                })
                .collect();

            let encodings = self
                .tokenizer
                .encode_batch_fast::<String>(
                    truncated.into_iter().map(Into::into).collect(),
                    false,
                )
                .expect("tokenization failed");

            for encoding in encodings {
                // Drop unknown tokens and enforce the token limit before pooling.
                let mut token_ids = encoding.get_ids().to_vec();
                if let Some(unk_id) = self.unk_token_id {
                    token_ids.retain(|&id| id as usize != unk_id);
                }
                if let Some(max_tok) = max_length {
                    token_ids.truncate(max_tok);
                }
                embeddings.push(self.pool_ids(token_ids));
            }
        }

        embeddings
    }

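    /// Encode sentences with the defaults: at most 512 tokens per sentence and a
    /// batch size of 1024.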
    pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
        self.encode_with_args(sentences, Some(512), 1024)
    }

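    /// Convenience wrapper around `encode` for a single sentence.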
    pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
        self.encode(&[sentence.to_string()])
            .into_iter()
            .next()
            .unwrap_or_default()
    }

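    /// Mean-pool the embedding rows for the given token ids, optionally
    /// L2-normalizing the result.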
    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
        // Sum the embedding rows for every token id...
        let mut sum = vec![0.0; self.embeddings.ncols()];
        for &id in &ids {
            let row = self.embeddings.row(id as usize);
            for (i, &v) in row.iter().enumerate() {
                sum[i] += v;
            }
        }
        // ...then take the mean (guarding against an empty id list).
        let cnt = ids.len().max(1) as f32;
        sum.iter_mut().for_each(|x| *x /= cnt);
        if self.normalize {
            let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
            sum.iter_mut().for_each(|x| *x /= norm);
        }
        sum
    }
}