1use anyhow::{anyhow, Context, Result};
2use half::f16;
3#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
4use hf_hub::api::sync::{Api, ApiRepo};
5use ndarray::{Array2, ArrayView2, CowArray, Ix2};
6use safetensors::{tensor::Dtype, SafeTensors};
7use serde_json::Value;
8use std::borrow::Cow;
9#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
10use std::env;
11use std::{
12 fs,
13 path::{Path, PathBuf},
14};
15use tokenizers::Tokenizer;
16
/// A static embedding model: a tokenizer plus a matrix of per-token embedding rows.
///
/// A sentence is embedded by tokenizing it, looking up one row per token id and averaging the
/// rows (optionally scaled per token and L2-normalized); see `encode_with_args` and `pool_ids`.
#[derive(Debug, Clone)]
pub struct StaticModel {
    /// Tokenizer that turns sentences into token ids.
    tokenizer: Tokenizer,
    /// 2-D embedding matrix, either owned or borrowed from `'static` data.
    embeddings: CowArray<'static, f32, Ix2>,
    /// Optional per-token scale factors, indexed by token id.
    weights: Option<Cow<'static, [f32]>>,
    /// Optional map from token id to embedding row; ids beyond it index rows directly.
    token_mapping: Option<Cow<'static, [usize]>>,
    /// Whether pooled sentence embeddings are L2-normalized.
    normalize: bool,
    /// Median vocabulary token length, used to cheaply pre-truncate input text.
    median_token_length: usize,
    /// Id of the tokenizer's unknown token, dropped from encodings; `None` if there is none.
    unk_token_id: Option<usize>,
}
28
/// Paths of the three files a model is loaded from, once a supported layout has been found.
#[derive(Debug, Clone)]
struct ModelFiles {
    /// Path to `tokenizer.json`.
    tokenizer: PathBuf,
    /// Path to `model.safetensors`.
    model: PathBuf,
    /// Path to the JSON config (`config.json` or `config_sentence_transformers.json`).
    config: PathBuf,
}
35
36fn match_local_layout(config_base: &Path, model_base: &Path, config_file: &str) -> Option<ModelFiles> {
37 let config = config_base.join(config_file);
38 let tokenizer = model_base.join("tokenizer.json");
39 let model = model_base.join("model.safetensors");
40 (config.exists() && tokenizer.exists() && model.exists()).then_some(ModelFiles {
41 tokenizer,
42 model,
43 config,
44 })
45}
46
/// Reports whether an hf-hub request failed because the remote file does not exist
/// (an HTTP 404 response). Every other kind of failure yields `false`.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn is_not_found(e: &hf_hub::api::sync::ApiError) -> bool {
    use hf_hub::api::sync::ApiError;

    match e {
        ApiError::RequestError(request_err) => {
            matches!(request_err.as_ref(), ureq::Error::Status(404, _))
        }
        _ => false,
    }
}
53
/// Tries to fetch one candidate file layout from a Hub repository.
///
/// The config named `config_file` is requested under `config_prefix`; `tokenizer.json` and
/// `model.safetensors` are requested under `model_prefix`. The callers in this file pass
/// prefixes that are either empty or end in `/`. Files are requested in the order config,
/// tokenizer, model, stopping at the first one the Hub reports as missing.
///
/// Returns `Ok(Some(files))` with the local paths returned by `ApiRepo::get` when all three
/// exist, `Ok(None)` when one of them is missing (HTTP 404), and `Err` for any other failure.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn match_hub_layout(
    repo: &ApiRepo,
    config_prefix: &str,
    model_prefix: &str,
    config_file: &str,
) -> Result<Option<ModelFiles>> {
    // A 404 only means "this layout does not apply"; anything else is a real error.
    let download = |remote_path: String| -> Result<Option<PathBuf>> {
        repo.get(&remote_path).map(Some).or_else(|err| {
            if is_not_found(&err) {
                Ok(None)
            } else {
                Err(err.into())
            }
        })
    };

    let config = match download(format!("{config_prefix}{config_file}"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let tokenizer = match download(format!("{model_prefix}tokenizer.json"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let model = match download(format!("{model_prefix}model.safetensors"))? {
        Some(path) => path,
        None => return Ok(None),
    };

    Ok(Some(ModelFiles { tokenizer, model, config }))
}
83
84fn resolve_local_model_files(folder: &Path) -> Option<ModelFiles> {
85 match_local_layout(folder, folder, "config.json")
86 .or_else(|| match_local_layout(folder, folder, "config_sentence_transformers.json"))
87 .or_else(|| {
88 match_local_layout(
89 folder,
90 &folder.join("0_StaticEmbedding"),
91 "config_sentence_transformers.json",
92 )
93 })
94 .or_else(|| {
95 folder
96 .parent()
97 .and_then(|p| match_local_layout(p, folder, "config_sentence_transformers.json"))
98 })
99}
100
/// Resolves the model files inside a HuggingFace Hub repository.
///
/// `prefix` is either empty (the repository root) or a subfolder path ending in `/`. Layouts
/// are tried in the same order as for local folders:
/// 1. model2vec: `config.json`, `tokenizer.json`, `model.safetensors` under `prefix`;
/// 2. the same, with `config_sentence_transformers.json` as the config;
/// 3. `config_sentence_transformers.json` under `prefix`, the other files under
///    `{prefix}0_StaticEmbedding/`;
/// 4. `config_sentence_transformers.json` in the parent directory of `prefix`, the other
///    files under `prefix` (skipped for the repository root, which has no parent).
///
/// # Errors
/// Propagates any request failure other than "not found", and fails if no layout matches.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn resolve_hub_model_files(repo: &ApiRepo, prefix: &str) -> Result<ModelFiles> {
    const ST_CONFIG: &str = "config_sentence_transformers.json";

    let sub_prefix = format!("{prefix}0_StaticEmbedding/");
    let parent = match Path::new(prefix.trim_end_matches('/')).parent() {
        Some(path) if !path.as_os_str().is_empty() => format!("{}/", path.display()),
        _ => String::new(),
    };

    if let Some(files) = match_hub_layout(repo, prefix, prefix, "config.json")? {
        return Ok(files);
    }
    if let Some(files) = match_hub_layout(repo, prefix, prefix, ST_CONFIG)? {
        return Ok(files);
    }
    if let Some(files) = match_hub_layout(repo, prefix, &sub_prefix, ST_CONFIG)? {
        return Ok(files);
    }
    // For the repository root `parent` equals `prefix` (both empty), so this lookup would just
    // repeat attempt 2 and issue the same Hub requests a second time.
    if parent != prefix {
        if let Some(files) = match_hub_layout(repo, &parent, prefix, ST_CONFIG)? {
            return Ok(files);
        }
    }
    Err(anyhow!("no valid model layout found in '{prefix}'"))
}
122
123impl StaticModel {
124 pub fn from_bytes<T, M, C>(
129 tokenizer_bytes: T,
130 model_bytes: M,
131 config_bytes: C,
132 normalize: Option<bool>,
133 ) -> Result<Self>
134 where
135 T: AsRef<[u8]>,
136 M: AsRef<[u8]>,
137 C: AsRef<[u8]>,
138 {
139 let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
140
141 let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
143 let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
144 let normalize = normalize.unwrap_or(cfg_norm);
145
146 let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
148 let tensor = safet
149 .tensor("embeddings")
150 .or_else(|_| safet.tensor("0"))
151 .or_else(|_| safet.tensor("embedding.weight"))
152 .context("embeddings tensor not found")?;
153
154 let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
155 let raw = tensor.data();
156 let floats: Vec<f32> = match tensor.dtype() {
157 Dtype::F32 => raw
158 .chunks_exact(4)
159 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
160 .collect(),
161 Dtype::F16 => raw
162 .chunks_exact(2)
163 .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
164 .collect(),
165 Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
166 other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
167 };
168
169 let weights = match safet.tensor("weights") {
170 Ok(t) => {
171 let raw = t.data();
172 let v: Vec<f32> = match t.dtype() {
173 Dtype::F64 => raw
174 .chunks_exact(8)
175 .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
176 .collect(),
177 Dtype::F32 => raw
178 .chunks_exact(4)
179 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
180 .collect(),
181 Dtype::F16 => raw
182 .chunks_exact(2)
183 .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
184 .collect(),
185 other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
186 };
187 Some(v)
188 }
189 Err(_) => None,
190 };
191
192 let token_mapping = match safet.tensor("mapping") {
193 Ok(t) => {
194 let raw = t.data();
195 let v: Vec<usize> = raw
196 .chunks_exact(4)
197 .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
198 .collect();
199 Some(v)
200 }
201 Err(_) => None,
202 };
203
204 Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
205 }
206
207 pub fn from_pretrained<P: AsRef<Path>>(
215 repo_or_path: P,
216 token: Option<&str>,
217 normalize: Option<bool>,
218 subfolder: Option<&str>,
219 ) -> Result<Self> {
220 let files = resolve_model_files(repo_or_path, token, subfolder)?;
221 let tokenizer_bytes = fs::read(&files.tokenizer).context("failed to read tokenizer.json")?;
222 let model_bytes = fs::read(&files.model).context("failed to read model.safetensors")?;
223 let config_bytes = fs::read(&files.config).context("failed to read config.json")?;
224 Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
225 }
226
227 pub fn from_owned(
238 tokenizer: Tokenizer,
239 embeddings: Vec<f32>,
240 rows: usize,
241 cols: usize,
242 normalize: bool,
243 weights: Option<Vec<f32>>,
244 token_mapping: Option<Vec<usize>>,
245 ) -> Result<Self> {
246 if embeddings.len() != rows * cols {
247 return Err(anyhow!(
248 "embeddings length {} != rows {} * cols {}",
249 embeddings.len(),
250 rows,
251 cols
252 ));
253 }
254 let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
255 let embeddings =
256 Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
257 Ok(Self {
258 tokenizer,
259 embeddings: CowArray::from(embeddings),
260 weights: weights.map(Cow::Owned),
261 token_mapping: token_mapping.map(Cow::Owned),
262 normalize,
263 median_token_length,
264 unk_token_id,
265 })
266 }
267
268 #[allow(dead_code)] pub fn from_borrowed(
280 tokenizer: Tokenizer,
281 embeddings: &'static [f32],
282 rows: usize,
283 cols: usize,
284 normalize: bool,
285 weights: Option<&'static [f32]>,
286 token_mapping: Option<&'static [usize]>,
287 ) -> Result<Self> {
288 if embeddings.len() != rows * cols {
289 return Err(anyhow!(
290 "embeddings length {} != rows {} * cols {}",
291 embeddings.len(),
292 rows,
293 cols
294 ));
295 }
296 let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
297 let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
298 Ok(Self {
299 tokenizer,
300 embeddings: CowArray::from(embeddings),
301 weights: weights.map(Cow::Borrowed),
302 token_mapping: token_mapping.map(Cow::Borrowed),
303 normalize,
304 median_token_length,
305 unk_token_id,
306 })
307 }
308
309 fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
311 let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
312 lens.sort_unstable();
313 let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
314
315 let spec: Value = serde_json::to_value(tokenizer).context("failed to serialize tokenizer")?;
316 let unk_token = spec
317 .get("model")
318 .and_then(|m| m.get("unk_token"))
319 .and_then(Value::as_str);
320 let unk_token_id = if let Some(tok) = unk_token {
321 let id = tokenizer
322 .token_to_id(tok)
323 .ok_or_else(|| anyhow!("unk_token '{tok}' not found in vocabulary"))?;
324 Some(id as usize)
325 } else {
326 None
327 };
328
329 Ok((median_token_length, unk_token_id))
330 }
331
332 fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
334 s.char_indices()
335 .nth(max_tokens.saturating_mul(median_len))
336 .map_or(s, |(byte_idx, _)| &s[..byte_idx])
337 }
338
339 pub fn encode_with_args(
346 &self,
347 sentences: &[String],
348 max_length: Option<usize>,
349 batch_size: usize,
350 ) -> Vec<Vec<f32>> {
351 let mut embeddings = Vec::with_capacity(sentences.len());
352 for batch in sentences.chunks(batch_size) {
353 let truncated: Vec<&str> = batch
354 .iter()
355 .map(|text| {
356 max_length
357 .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
358 .unwrap_or(text.as_str())
359 })
360 .collect();
361 let encodings = self
362 .tokenizer
363 .encode_batch_fast::<String>(truncated.into_iter().map(Into::into).collect(), false)
364 .expect("tokenization failed");
365 for encoding in encodings {
366 let mut token_ids = encoding.get_ids().to_vec();
367 if let Some(unk_id) = self.unk_token_id {
368 token_ids.retain(|&id| id as usize != unk_id);
369 }
370 if let Some(max_tok) = max_length {
371 token_ids.truncate(max_tok);
372 }
373 embeddings.push(self.pool_ids(token_ids));
374 }
375 }
376 embeddings
377 }
378
379 pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
381 self.encode_with_args(sentences, Some(512), 1024)
382 }
383
384 pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
386 self.encode(&[sentence.to_string()])
387 .into_iter()
388 .next()
389 .unwrap_or_default()
390 }
391
392 fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
394 let dim = self.embeddings.ncols();
395 let mut sum = vec![0.0_f32; dim];
396 let mut cnt = 0usize;
397 for &id in &ids {
398 let tok = id as usize;
399 let row_idx = self
400 .token_mapping
401 .as_ref()
402 .and_then(|m| m.get(tok))
403 .copied()
404 .unwrap_or(tok);
405 let scale = self.weights.as_ref().and_then(|w| w.get(tok)).copied().unwrap_or(1.0);
406 let row = self.embeddings.row(row_idx);
407 for (s, &v) in sum.iter_mut().zip(row.iter()) {
408 *s += v * scale;
409 }
410 cnt += 1;
411 }
412 let denom = cnt.max(1) as f32;
413 for x in &mut sum {
414 *x /= denom;
415 }
416 if self.normalize {
417 let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
418 for x in &mut sum {
419 *x /= norm;
420 }
421 }
422 sum
423 }
424}
425
426fn resolve_model_files<P: AsRef<Path>>(
427 repo_or_path: P,
428 token: Option<&str>,
429 subfolder: Option<&str>,
430) -> Result<ModelFiles> {
431 #[cfg(any(not(feature = "hf-hub"), feature = "local-only"))]
432 let _ = token;
433
434 let base = repo_or_path.as_ref();
435 if base.exists() {
436 let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
437 return resolve_local_model_files(&folder).ok_or_else(|| {
438 anyhow!(
439 "no valid model layout found in {folder:?}. \
440 Tried: model2vec (config.json), sentence-transformers \
441 (config_sentence_transformers.json), and 0_StaticEmbedding subfolder."
442 )
443 });
444 }
445
446 #[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
447 {
448 download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)
449 }
450 #[cfg(feature = "local-only")]
451 {
452 Err(anyhow!(
453 "remote model downloads are disabled by the `local-only` feature; pass a local model directory instead"
454 ))
455 }
456 #[cfg(all(not(feature = "hf-hub"), not(feature = "local-only")))]
457 {
458 Err(anyhow!(
459 "remote model downloads require the `hf-hub` feature; pass a local model directory instead"
460 ))
461 }
462}
463
/// Fetches a model's files from a HuggingFace Hub repository.
///
/// * `repo_id` — the Hub repository id.
/// * `token` — access token for authenticated downloads. When `None`, a non-empty
///   `HF_HUB_TOKEN` environment variable is used if set. Failing that, the hf-hub client's own
///   default token lookup applies.
/// * `subfolder` — optional directory inside the repository. Surrounding `/` characters are
///   ignored, and an empty subfolder means the repository root.
///
/// # Errors
/// Fails if the hf-hub client cannot be built, or if no supported layout can be resolved or
/// downloaded. The latter error is wrapped with the repository id.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&str>) -> Result<ModelFiles> {
    use hf_hub::api::sync::ApiBuilder;

    // Hand the token to the client directly instead of exporting it through the process
    // environment. `env::set_var` mutates global state (racy with other threads, and `unsafe`
    // in Rust 2024), and it only helps if the client happens to read that variable.
    let token = token
        .map(str::to_owned)
        .or_else(|| env::var("HF_HUB_TOKEN").ok())
        .filter(|tok| !tok.is_empty());

    let mut builder = ApiBuilder::new();
    if let Some(tok) = token {
        // Only override when a token is available, so the client's default lookup still applies.
        builder = builder.with_token(Some(tok));
    }
    let api: Api = builder.build().context("hf-hub API init failed")?;
    let repo = api.model(repo_id.to_owned());

    // Normalize to "" (repository root) or "dir/" so paths never contain "//" or a leading "/".
    let prefix = subfolder
        .map(|s| s.trim_matches('/'))
        .filter(|s| !s.is_empty())
        .map(|s| format!("{s}/"))
        .unwrap_or_default();
    resolve_hub_model_files(&repo, &prefix)
        .with_context(|| format!("could not load '{repo_id}' from HuggingFace Hub"))
}
488}