1use anyhow::{anyhow, Context, Result};
2use half::f16;
3#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
4use hf_hub::api::sync::{Api, ApiRepo};
5use ndarray::{Array2, ArrayView2, CowArray, Ix2};
6use safetensors::{tensor::Dtype, SafeTensors};
7use serde_json::Value;
8use std::borrow::Cow;
9#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
10use std::env;
11use std::{
12 fs,
13 path::{Path, PathBuf},
14};
15use tokenizers::Tokenizer;
16
17#[derive(Debug, Clone)]
19pub struct StaticModel {
20 tokenizer: Tokenizer,
21 embeddings: CowArray<'static, f32, Ix2>,
22 weights: Option<Cow<'static, [f32]>>,
23 token_mapping: Option<Cow<'static, [usize]>>,
24 normalize: bool,
25 median_token_length: usize,
26 unk_token_id: Option<usize>,
27}
28
/// Resolved locations of the three files that make up a model.
#[derive(Clone, Debug)]
struct ModelFiles {
    /// Path of `tokenizer.json`.
    tokenizer: PathBuf,
    /// Path of `model.safetensors`.
    model: PathBuf,
    /// Path of the configuration file (`config.json` or
    /// `config_sentence_transformers.json`, depending on the layout that matched).
    config: PathBuf,
}
35
36fn match_local_layout(config_base: &Path, model_base: &Path, config_file: &str) -> Option<ModelFiles> {
37 let config = config_base.join(config_file);
38 let tokenizer = model_base.join("tokenizer.json");
39 let model = model_base.join("model.safetensors");
40 (config.exists() && tokenizer.exists() && model.exists()).then_some(ModelFiles {
41 tokenizer,
42 model,
43 config,
44 })
45}
46
47fn decode_token_mapping(dtype: Dtype, raw: &[u8]) -> Result<Vec<usize>> {
48 let mapping = match dtype {
49 Dtype::I64 => raw
50 .chunks_exact(8)
51 .map(|b| i64::from_le_bytes(b.try_into().unwrap()) as usize)
52 .collect(),
53 Dtype::I32 => raw
54 .chunks_exact(4)
55 .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
56 .collect(),
57 other => return Err(anyhow!("unsupported mapping dtype: {:?}", other)),
58 };
59
60 Ok(mapping)
61}
62
/// Tells whether a Hub API error is an HTTP 404 response.
///
/// Only a `RequestError` wrapping a `ureq` status error with code 404 counts;
/// every other error yields `false`.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn is_not_found(e: &hf_hub::api::sync::ApiError) -> bool {
    use hf_hub::api::sync::ApiError;

    match e {
        ApiError::RequestError(inner) => matches!(inner.as_ref(), ureq::Error::Status(404, _)),
        _ => false,
    }
}
69
/// Tries to fetch the three files of one candidate layout from a Hub repository.
///
/// The config is requested as `{config_prefix}{config_file}`; `tokenizer.json`
/// and `model.safetensors` are requested under `model_prefix`. Files are
/// requested in that order and the function stops at the first one that is
/// reported missing.
///
/// Returns `Ok(Some(..))` with the local paths when all three were obtained,
/// `Ok(None)` when one of them is answered with HTTP 404, and `Err` for any
/// other failure.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn match_hub_layout(
    repo: &ApiRepo,
    config_prefix: &str,
    model_prefix: &str,
    config_file: &str,
) -> Result<Option<ModelFiles>> {
    // "Not found" means this layout does not apply; anything else is a real error.
    let try_get = |remote: String| -> Result<Option<PathBuf>> {
        match repo.get(&remote) {
            Err(err) if !is_not_found(&err) => Err(err.into()),
            Err(_) => Ok(None),
            Ok(local) => Ok(Some(local)),
        }
    };

    let config = match try_get(format!("{config_prefix}{config_file}"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let tokenizer = match try_get(format!("{model_prefix}tokenizer.json"))? {
        Some(path) => path,
        None => return Ok(None),
    };
    let model = match try_get(format!("{model_prefix}model.safetensors"))? {
        Some(path) => path,
        None => return Ok(None),
    };

    Ok(Some(ModelFiles {
        tokenizer,
        model,
        config,
    }))
}
99
100fn resolve_local_model_files(folder: &Path) -> Option<ModelFiles> {
101 match_local_layout(folder, folder, "config.json")
102 .or_else(|| match_local_layout(folder, folder, "config_sentence_transformers.json"))
103 .or_else(|| {
104 match_local_layout(
105 folder,
106 &folder.join("0_StaticEmbedding"),
107 "config_sentence_transformers.json",
108 )
109 })
110 .or_else(|| {
111 folder
112 .parent()
113 .and_then(|p| match_local_layout(p, folder, "config_sentence_transformers.json"))
114 })
115}
116
/// Finds the model files inside a Hub repository, trying the known layouts in order.
///
/// `prefix` is either empty or a sub-path ending in `/`. The layouts mirror the
/// local ones: `config.json` under `prefix`; `config_sentence_transformers.json`
/// under `prefix`; that config under `prefix` with the model files under
/// `{prefix}0_StaticEmbedding/`; and finally that config in the parent of
/// `prefix` with the model files under `prefix`.
///
/// # Errors
///
/// Propagates any fetch error other than "not found", and fails with
/// "no valid model layout found" when no layout is complete.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn resolve_hub_model_files(repo: &ApiRepo, prefix: &str) -> Result<ModelFiles> {
    const ST_CONFIG: &str = "config_sentence_transformers.json";

    let sub_prefix = format!("{prefix}0_StaticEmbedding/");
    // Parent of the sub-path, again with a trailing '/', or empty when there is none.
    let parent = Path::new(prefix.trim_end_matches('/'))
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| format!("{}/", dir.display()))
        .unwrap_or_default();

    // (config prefix, model prefix, config file name), in priority order.
    let candidates = [
        (prefix, prefix, "config.json"),
        (prefix, prefix, ST_CONFIG),
        (prefix, sub_prefix.as_str(), ST_CONFIG),
        (parent.as_str(), prefix, ST_CONFIG),
    ];
    for (config_prefix, model_prefix, config_file) in candidates {
        if let Some(files) = match_hub_layout(repo, config_prefix, model_prefix, config_file)? {
            return Ok(files);
        }
    }

    Err(anyhow!("no valid model layout found in '{prefix}'"))
}
138
139impl StaticModel {
140 pub fn from_bytes<T, M, C>(
145 tokenizer_bytes: T,
146 model_bytes: M,
147 config_bytes: C,
148 normalize: Option<bool>,
149 ) -> Result<Self>
150 where
151 T: AsRef<[u8]>,
152 M: AsRef<[u8]>,
153 C: AsRef<[u8]>,
154 {
155 let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
156
157 let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
159 let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
160 let normalize = normalize.unwrap_or(cfg_norm);
161
162 let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
164 let tensor = safet
165 .tensor("embeddings")
166 .or_else(|_| safet.tensor("0"))
167 .or_else(|_| safet.tensor("embedding.weight"))
168 .context("embeddings tensor not found")?;
169
170 let [rows, cols]: [usize; 2] = tensor.shape().try_into().context("embedding tensor is not 2-D")?;
171 let raw = tensor.data();
172 let floats: Vec<f32> = match tensor.dtype() {
173 Dtype::F32 => raw
174 .chunks_exact(4)
175 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
176 .collect(),
177 Dtype::F16 => raw
178 .chunks_exact(2)
179 .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
180 .collect(),
181 Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
182 other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
183 };
184
185 let weights = match safet.tensor("weights") {
186 Ok(t) => {
187 let raw = t.data();
188 let v: Vec<f32> = match t.dtype() {
189 Dtype::F64 => raw
190 .chunks_exact(8)
191 .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
192 .collect(),
193 Dtype::F32 => raw
194 .chunks_exact(4)
195 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
196 .collect(),
197 Dtype::F16 => raw
198 .chunks_exact(2)
199 .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
200 .collect(),
201 other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
202 };
203 Some(v)
204 }
205 Err(_) => None,
206 };
207
208 let token_mapping = match safet.tensor("mapping") {
209 Ok(t) => Some(decode_token_mapping(t.dtype(), t.data())?),
210 Err(_) => None,
211 };
212
213 Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
214 }
215
216 pub fn from_pretrained<P: AsRef<Path>>(
224 repo_or_path: P,
225 token: Option<&str>,
226 normalize: Option<bool>,
227 subfolder: Option<&str>,
228 ) -> Result<Self> {
229 let files = resolve_model_files(repo_or_path, token, subfolder)?;
230 let tokenizer_bytes = fs::read(&files.tokenizer).context("failed to read tokenizer.json")?;
231 let model_bytes = fs::read(&files.model).context("failed to read model.safetensors")?;
232 let config_bytes = fs::read(&files.config).context("failed to read config.json")?;
233 Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
234 }
235
236 pub fn from_owned(
247 tokenizer: Tokenizer,
248 embeddings: Vec<f32>,
249 rows: usize,
250 cols: usize,
251 normalize: bool,
252 weights: Option<Vec<f32>>,
253 token_mapping: Option<Vec<usize>>,
254 ) -> Result<Self> {
255 if embeddings.len() != rows * cols {
256 return Err(anyhow!(
257 "embeddings length {} != rows {} * cols {}",
258 embeddings.len(),
259 rows,
260 cols
261 ));
262 }
263 let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
264 let embeddings =
265 Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
266 Ok(Self {
267 tokenizer,
268 embeddings: CowArray::from(embeddings),
269 weights: weights.map(Cow::Owned),
270 token_mapping: token_mapping.map(Cow::Owned),
271 normalize,
272 median_token_length,
273 unk_token_id,
274 })
275 }
276
277 #[allow(dead_code)] pub fn from_borrowed(
289 tokenizer: Tokenizer,
290 embeddings: &'static [f32],
291 rows: usize,
292 cols: usize,
293 normalize: bool,
294 weights: Option<&'static [f32]>,
295 token_mapping: Option<&'static [usize]>,
296 ) -> Result<Self> {
297 if embeddings.len() != rows * cols {
298 return Err(anyhow!(
299 "embeddings length {} != rows {} * cols {}",
300 embeddings.len(),
301 rows,
302 cols
303 ));
304 }
305 let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
306 let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
307 Ok(Self {
308 tokenizer,
309 embeddings: CowArray::from(embeddings),
310 weights: weights.map(Cow::Borrowed),
311 token_mapping: token_mapping.map(Cow::Borrowed),
312 normalize,
313 median_token_length,
314 unk_token_id,
315 })
316 }
317
318 fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
320 let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
321 lens.sort_unstable();
322 let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
323
324 let spec: Value = serde_json::to_value(tokenizer).context("failed to serialize tokenizer")?;
325 let unk_token = spec
326 .get("model")
327 .and_then(|m| m.get("unk_token"))
328 .and_then(Value::as_str);
329 let unk_token_id = if let Some(tok) = unk_token {
330 let id = tokenizer
331 .token_to_id(tok)
332 .ok_or_else(|| anyhow!("unk_token '{tok}' not found in vocabulary"))?;
333 Some(id as usize)
334 } else {
335 None
336 };
337
338 Ok((median_token_length, unk_token_id))
339 }
340
/// Cuts `s` down to at most `max_tokens * median_len` characters.
///
/// The budget is counted in `char`s and the cut always falls on a character
/// boundary. The product saturates instead of overflowing. Strings that are
/// not longer than the budget are returned unchanged.
fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
    let char_budget = max_tokens.saturating_mul(median_len);
    // The byte offset of the first character past the budget is where the prefix ends.
    match s.char_indices().nth(char_budget) {
        Some((cut, _)) => &s[..cut],
        None => s,
    }
}
347
348 pub fn encode_with_args(
355 &self,
356 sentences: &[String],
357 max_length: Option<usize>,
358 batch_size: usize,
359 ) -> Vec<Vec<f32>> {
360 let mut embeddings = Vec::with_capacity(sentences.len());
361 for batch in sentences.chunks(batch_size) {
362 let truncated: Vec<&str> = batch
363 .iter()
364 .map(|text| {
365 max_length
366 .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
367 .unwrap_or(text.as_str())
368 })
369 .collect();
370 let encodings = self
371 .tokenizer
372 .encode_batch_fast::<String>(truncated.into_iter().map(Into::into).collect(), false)
373 .expect("tokenization failed");
374 for encoding in encodings {
375 let mut token_ids = encoding.get_ids().to_vec();
376 if let Some(unk_id) = self.unk_token_id {
377 token_ids.retain(|&id| id as usize != unk_id);
378 }
379 if let Some(max_tok) = max_length {
380 token_ids.truncate(max_tok);
381 }
382 embeddings.push(self.pool_ids(token_ids));
383 }
384 }
385 embeddings
386 }
387
388 pub fn encode(&self, sentences: &[String]) -> Vec<Vec<f32>> {
390 self.encode_with_args(sentences, Some(512), 1024)
391 }
392
393 pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
395 self.encode(&[sentence.to_string()])
396 .into_iter()
397 .next()
398 .unwrap_or_default()
399 }
400
401 fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
403 let dim = self.embeddings.ncols();
404 let mut sum = vec![0.0_f32; dim];
405 let mut cnt = 0usize;
406 for &id in &ids {
407 let tok = id as usize;
408 let row_idx = self
409 .token_mapping
410 .as_ref()
411 .and_then(|m| m.get(tok))
412 .copied()
413 .unwrap_or(tok);
414 let scale = self.weights.as_ref().and_then(|w| w.get(tok)).copied().unwrap_or(1.0);
415 let row = self.embeddings.row(row_idx);
416 for (s, &v) in sum.iter_mut().zip(row.iter()) {
417 *s += v * scale;
418 }
419 cnt += 1;
420 }
421 let denom = cnt.max(1) as f32;
422 for x in &mut sum {
423 *x /= denom;
424 }
425 if self.normalize {
426 let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
427 for x in &mut sum {
428 *x /= norm;
429 }
430 }
431 sum
432 }
433}
434
435fn resolve_model_files<P: AsRef<Path>>(
436 repo_or_path: P,
437 token: Option<&str>,
438 subfolder: Option<&str>,
439) -> Result<ModelFiles> {
440 #[cfg(any(not(feature = "hf-hub"), feature = "local-only"))]
441 let _ = token;
442
443 let base = repo_or_path.as_ref();
444 if base.exists() {
445 let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
446 return resolve_local_model_files(&folder).ok_or_else(|| {
447 anyhow!(
448 "no valid model layout found in {folder:?}. \
449 Tried: model2vec (config.json), sentence-transformers \
450 (config_sentence_transformers.json), and 0_StaticEmbedding subfolder."
451 )
452 });
453 }
454
455 #[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
456 {
457 download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)
458 }
459 #[cfg(feature = "local-only")]
460 {
461 Err(anyhow!(
462 "remote model downloads are disabled by the `local-only` feature; pass a local model directory instead"
463 ))
464 }
465 #[cfg(all(not(feature = "hf-hub"), not(feature = "local-only")))]
466 {
467 Err(anyhow!(
468 "remote model downloads require the `hf-hub` feature; pass a local model directory instead"
469 ))
470 }
471}
472
/// Resolves the model files of `repo_id` through the Hugging Face Hub client.
///
/// * `token` – access token. It is handed to the client through
///   `ApiBuilder::with_token`, so the process environment is never modified
///   (changing environment variables is process-global and racy when other
///   threads are running). When `None`, the value of the `HF_HUB_TOKEN`
///   environment variable is used if it is set; otherwise the client keeps its
///   own default.
/// * `subfolder` – optional sub-path inside the repository.
///
/// # Errors
///
/// Fails when the client cannot be initialised or when no complete model
/// layout can be obtained from the repository.
#[cfg(all(feature = "hf-hub", not(feature = "local-only")))]
fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&str>) -> Result<ModelFiles> {
    // An explicit token wins; the environment variable is only a fallback and is read, never written.
    let token = token
        .map(str::to_owned)
        .or_else(|| env::var("HF_HUB_TOKEN").ok());

    let mut builder = hf_hub::api::sync::ApiBuilder::new();
    if let Some(tok) = token {
        // Only override when a token is known, so the client's default stays in effect otherwise.
        builder = builder.with_token(Some(tok));
    }
    let api: Api = builder.build().context("hf-hub API init failed")?;

    let repo = api.model(repo_id.to_owned());
    let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
    resolve_hub_model_files(&repo, &prefix)
        .with_context(|| format!("could not load '{repo_id}' from HuggingFace Hub"))
}
498
#[cfg(test)]
mod tests {
    use super::decode_token_mapping;
    use safetensors::tensor::Dtype;

    /// Both supported integer widths decode to the same plain indices.
    #[test]
    fn decode_token_mapping_supports_i32_and_i64() {
        let mut i32_raw: Vec<u8> = Vec::new();
        for value in [1i32, 2, 3] {
            i32_raw.extend_from_slice(&value.to_le_bytes());
        }
        let mut i64_raw: Vec<u8> = Vec::new();
        for value in [4i64, 5, 6] {
            i64_raw.extend_from_slice(&value.to_le_bytes());
        }

        let from_i32 = decode_token_mapping(Dtype::I32, &i32_raw).unwrap();
        let from_i64 = decode_token_mapping(Dtype::I64, &i64_raw).unwrap();

        assert_eq!(from_i32, vec![1, 2, 3]);
        assert_eq!(from_i64, vec![4, 5, 6]);
    }

    /// Any dtype other than the two integer widths is reported as unsupported.
    #[test]
    fn decode_token_mapping_rejects_unsupported_dtype() {
        let message = decode_token_mapping(Dtype::F32, &[0, 0, 0, 0])
            .unwrap_err()
            .to_string();
        assert!(message.contains("unsupported mapping dtype"));
    }
}