use std::path::Path;
use anyhow::{Context, Result, anyhow};
use ndarray::Array2;
use rayon::prelude::*;
use safetensors::SafeTensors;
use safetensors::tensor::Dtype;
use serde_json::Value;
use tokenizers::Tokenizer;
use wide::f32x8;
/// Token cap used by `encode_query` and `encode_batch`: it sizes the character
/// pre-truncation of each input and bounds the number of token ids that are pooled.
pub const DEFAULT_MAX_TOKENS: usize = 512;
/// Number of input texts `encode_batch` hands to the tokenizer per chunk.
const BATCH_SIZE: usize = 1024;
/// A static (lookup-table) text embedding model: text is tokenized, each token id
/// selects a row of the embedding matrix, and the selected rows are mean-pooled.
pub struct StaticEmbedModel {
// Tokenizer used to turn text into token ids; padding and truncation are
// switched off when the model is loaded.
tokenizer: Tokenizer,
// Embedding matrix as `f32`; one row per embedding, `ncols()` is the hidden dimension.
embeddings: Array2<f32>,
// Optional per-token scale factors, indexed by token id. Ids without an entry
// are pooled with a scale of 1.0.
weights: Option<Vec<f32>>,
// Optional table translating a token id into an embedding row index. Ids without
// an entry use the token id itself as the row index.
token_mapping: Option<Vec<usize>>,
// Whether pooled vectors are L2-normalized before being returned.
normalize: bool,
// Median byte length of the vocabulary entries; multiplied by the token cap to
// bound how many characters of an input are tokenized.
median_token_length: usize,
// Id of the tokenizer model's `unk_token`, if it has one; such ids are dropped
// before pooling.
unk_token_id: Option<usize>,
}
impl StaticEmbedModel {
/// Loads a model from a directory containing `tokenizer.json`,
/// `model.safetensors` and `config.json`.
///
/// `normalize_override`, when `Some`, is forwarded to [`Self::from_bytes`] and
/// takes precedence over the `normalize` flag stored in `config.json`.
///
/// # Errors
///
/// Returns an error if any of the three files cannot be read, or if
/// [`Self::from_bytes`] rejects their contents.
pub fn from_path(path: &Path, normalize_override: Option<bool>) -> Result<Self> {
    // One reader for all three files; each call carries its own error context.
    let load = |file_name: &str, failure: &'static str| -> Result<Vec<u8>> {
        std::fs::read(path.join(file_name)).context(failure)
    };

    let tokenizer_json = load("tokenizer.json", "read tokenizer.json failed")?;
    let model_weights = load("model.safetensors", "read model.safetensors failed")?;
    let config_json = load("config.json", "read config.json failed")?;

    Self::from_bytes(
        &tokenizer_json,
        &model_weights,
        &config_json,
        normalize_override,
    )
}
#[allow(clippy::too_many_lines)]
pub fn from_bytes(
tokenizer_bytes: &[u8],
model_bytes: &[u8],
config_bytes: &[u8],
normalize_override: Option<bool>,
) -> Result<Self> {
let mut tokenizer = Tokenizer::from_bytes(tokenizer_bytes)
.map_err(|e| anyhow!("tokenizer load failed: {e}"))?;
tokenizer.with_padding(None).with_truncation(None).ok();
let cfg: Value = serde_json::from_slice(config_bytes).context("config.json parse")?;
let cfg_norm = cfg
.get("normalize")
.and_then(Value::as_bool)
.unwrap_or(true);
let normalize = normalize_override.unwrap_or(cfg_norm);
let safet = SafeTensors::deserialize(model_bytes).context("safetensors deserialize")?;
let embed_tensor = safet
.tensor("embeddings")
.or_else(|_| safet.tensor("0"))
.or_else(|_| safet.tensor("embedding.weight"))
.map_err(|_| anyhow!("embeddings tensor not found in safetensors"))?;
let [rows, cols]: [usize; 2] = embed_tensor
.shape()
.try_into()
.map_err(|_| anyhow!("embedding tensor is not 2-D"))?;
let raw = embed_tensor.data();
let floats: Vec<f32> = match embed_tensor.dtype() {
Dtype::F32 => raw
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect(),
Dtype::F16 => raw
.chunks_exact(2)
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
.collect(),
Dtype::I8 => raw.iter().map(|&b| f32::from(b.cast_signed())).collect(),
other => return Err(anyhow!("unsupported embedding dtype: {other:?}")),
};
let embeddings = Array2::from_shape_vec((rows, cols), floats)
.context("embedding matrix shape mismatch")?;
let weights = safet.tensor("weights").ok().map(|t| {
let raw = t.data();
match t.dtype() {
Dtype::F64 => raw
.chunks_exact(8)
.map(|b| {
#[expect(
clippy::cast_possible_truncation,
reason = "weights are bounded; f32 precision is sufficient downstream"
)]
let v = f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
as f32;
v
})
.collect::<Vec<f32>>(),
Dtype::F32 => raw
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect::<Vec<f32>>(),
Dtype::F16 => raw
.chunks_exact(2)
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
.collect::<Vec<f32>>(),
_ => Vec::new(),
}
});
let token_mapping = safet.tensor("mapping").ok().map(|t| {
let raw = t.data();
#[expect(
clippy::cast_sign_loss,
clippy::cast_possible_truncation,
reason = "mapping values are non-negative row indices well within usize range"
)]
match t.dtype() {
Dtype::I64 => raw
.chunks_exact(8)
.map(|b| {
i64::from_le_bytes([
b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
]) as usize
})
.collect::<Vec<usize>>(),
Dtype::I32 => raw
.chunks_exact(4)
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize)
.collect::<Vec<usize>>(),
Dtype::U32 => raw
.chunks_exact(4)
.map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize)
.collect::<Vec<usize>>(),
Dtype::U64 => raw
.chunks_exact(8)
.map(|b| {
u64::from_le_bytes([
b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
]) as usize
})
.collect::<Vec<usize>>(),
_ => Vec::new(),
}
});
let (median_token_length, unk_token_id) = compute_metadata(&tokenizer)?;
Ok(Self {
tokenizer,
embeddings,
weights,
token_mapping,
normalize,
median_token_length,
unk_token_id,
})
}
/// Returns the length of every vector this model produces, i.e. the number
/// of columns of the embedding matrix.
#[must_use]
pub fn hidden_dim(&self) -> usize {
    let (_rows, cols) = self.embeddings.dim();
    cols
}
/// Embeds a single text.
///
/// The text is clipped by character count, tokenized without special tokens,
/// stripped of unknown-token ids, capped at [`DEFAULT_MAX_TOKENS`] ids and
/// pooled. If tokenization fails, a zero vector of length
/// [`Self::hidden_dim`] is returned instead of an error.
pub fn encode_query(&self, text: &str) -> Vec<f32> {
    let clipped = truncate_chars(text, DEFAULT_MAX_TOKENS, self.median_token_length);
    match self.tokenizer.encode_fast(clipped, false) {
        Ok(encoding) => {
            let kept = filter_ids(encoding.get_ids(), self.unk_token_id, DEFAULT_MAX_TOKENS);
            self.pool_ids(&kept)
        }
        Err(_) => vec![0.0; self.hidden_dim()],
    }
}
/// Embeds many texts, returning one vector per input in input order.
///
/// Inputs are processed in chunks of `BATCH_SIZE`: each chunk is clipped by
/// character count, tokenized in one call, and its encodings are pooled in
/// parallel. If tokenizing a chunk fails, every text of that chunk yields a
/// zero vector of length [`Self::hidden_dim`] and processing continues with
/// the next chunk.
pub fn encode_batch(&self, texts: &[&str]) -> Vec<Vec<f32>> {
    // An empty input produces no chunks, so the loop below is skipped entirely.
    let mut results: Vec<Vec<f32>> = Vec::with_capacity(texts.len());

    for batch in texts.chunks(BATCH_SIZE) {
        let clipped: Vec<String> = batch
            .iter()
            .map(|&text| {
                truncate_chars(text, DEFAULT_MAX_TOKENS, self.median_token_length).to_owned()
            })
            .collect();

        match self.tokenizer.encode_batch_fast::<String>(clipped, false) {
            Ok(encodings) => {
                let vectors: Vec<Vec<f32>> = encodings
                    .par_iter()
                    .map(|encoding| {
                        let kept = filter_ids(
                            encoding.get_ids(),
                            self.unk_token_id,
                            DEFAULT_MAX_TOKENS,
                        );
                        self.pool_ids(&kept)
                    })
                    .collect();
                results.extend(vectors);
            }
            Err(_) => {
                // Best effort: keep the output aligned with the input.
                let zero = vec![0.0_f32; self.hidden_dim()];
                results.extend(batch.iter().map(|_| zero.clone()));
            }
        }
    }

    results
}
/// Mean-pools the embedding rows selected by `ids` into a single vector.
///
/// For each id, the row index comes from `token_mapping` when it has an entry
/// for that id, otherwise the id itself is used; ids whose row index is outside
/// the matrix are skipped. Each row is scaled by the `weights` entry for the
/// token id (1.0 when absent) before being added. The sum is divided by the
/// number of rows actually used (at least 1) and, when `normalize` is set,
/// rescaled to unit L2 norm with the norm floored at `1e-12`.
///
/// # Panics
///
/// Panics if the embedding matrix is not stored contiguously.
fn pool_ids(&self, ids: &[u32]) -> Vec<f32> {
    let dim = self.hidden_dim();
    let nrows = self.embeddings.nrows();
    let table = self
        .embeddings
        .as_slice()
        .expect("embedding matrix is non-contiguous; static_model load invariant violated");
    let mapping = self.token_mapping.as_deref();
    let weights = self.weights.as_deref();

    let mut pooled = vec![0.0_f32; dim];
    let mut used: usize = 0;

    for tok in ids.iter().map(|&id| id as usize) {
        let row_idx = mapping.and_then(|m| m.get(tok)).copied().unwrap_or(tok);
        // Bounds are checked before the multiplication below.
        if row_idx >= nrows {
            continue;
        }
        let start = row_idx * dim;
        let row = &table[start..start + dim];

        let scale = weights.and_then(|w| w.get(tok)).copied().unwrap_or(1.0);
        #[expect(
            clippy::float_cmp,
            reason = "bit-exact 1.0 check is the intended fast-path gate"
        )]
        let unit_scale = scale == 1.0;
        if unit_scale {
            accumulate_f32x8(&mut pooled, row);
        } else {
            accumulate_scaled_f32x8(&mut pooled, row, scale);
        }
        used += 1;
    }

    let denom = used.max(1) as f32;
    scale_in_place_f32x8(&mut pooled, 1.0 / denom);

    if self.normalize {
        let norm = l2_norm_f32x8(&pooled).max(1e-12);
        scale_in_place_f32x8(&mut pooled, 1.0 / norm);
    }
    pooled
}
}
/// Returns the prefix of `s` holding at most `max_tokens * median_len`
/// characters (the product saturates instead of overflowing).
///
/// The cut always falls on a character boundary; `s` is returned unchanged
/// when it has no more characters than the budget.
fn truncate_chars(s: &str, max_tokens: usize, median_len: usize) -> &str {
    let char_budget = max_tokens.saturating_mul(median_len);
    match s.char_indices().nth(char_budget) {
        // `cut` is the byte offset of the first character past the budget.
        Some((cut, _)) => &s[..cut],
        None => s,
    }
}
/// Adds `row` element-wise into `acc`, eight lanes at a time, with a scalar
/// loop for the trailing `len % 8` elements.
///
/// Both slices are expected to have the same length (checked in debug builds).
fn accumulate_f32x8(acc: &mut [f32], row: &[f32]) {
    debug_assert_eq!(acc.len(), row.len(), "pool dim mismatch");
    let simd_len = acc.len() / 8 * 8;
    let (acc_simd, acc_rest) = acc.split_at_mut(simd_len);
    let (row_simd, row_rest) = row.split_at(simd_len);

    let lanes = acc_simd.chunks_exact_mut(8).zip(row_simd.chunks_exact(8));
    for (dst, src) in lanes {
        let lhs: [f32; 8] = (&*dst).try_into().unwrap();
        let rhs: [f32; 8] = src.try_into().unwrap();
        let total = f32x8::from(lhs) + f32x8::from(rhs);
        dst.copy_from_slice(total.as_array());
    }

    acc_rest
        .iter_mut()
        .zip(row_rest)
        .for_each(|(dst, &src)| *dst += src);
}
/// Adds `row * scale` element-wise into `acc`, eight lanes at a time via
/// `mul_add`, with a scalar loop for the trailing `len % 8` elements.
///
/// Both slices are expected to have the same length (checked in debug builds).
fn accumulate_scaled_f32x8(acc: &mut [f32], row: &[f32], scale: f32) {
    debug_assert_eq!(acc.len(), row.len(), "pool dim mismatch");
    let simd_len = acc.len() / 8 * 8;
    let (acc_simd, acc_rest) = acc.split_at_mut(simd_len);
    let (row_simd, row_rest) = row.split_at(simd_len);
    let factor = f32x8::splat(scale);

    let lanes = acc_simd.chunks_exact_mut(8).zip(row_simd.chunks_exact(8));
    for (dst, src) in lanes {
        let current: [f32; 8] = (&*dst).try_into().unwrap();
        let incoming: [f32; 8] = src.try_into().unwrap();
        let updated = f32x8::from(incoming).mul_add(factor, f32x8::from(current));
        dst.copy_from_slice(updated.as_array());
    }

    acc_rest
        .iter_mut()
        .zip(row_rest)
        .for_each(|(dst, &src)| *dst += src * scale);
}
/// Multiplies every element of `v` by `factor` in place, eight lanes at a
/// time, with a scalar loop for the trailing `len % 8` elements.
fn scale_in_place_f32x8(v: &mut [f32], factor: f32) {
    let simd_len = v.len() / 8 * 8;
    let (simd_part, rest) = v.split_at_mut(simd_len);
    let multiplier = f32x8::splat(factor);

    for lane in simd_part.chunks_exact_mut(8) {
        let values: [f32; 8] = (&*lane).try_into().unwrap();
        let scaled = f32x8::from(values) * multiplier;
        lane.copy_from_slice(scaled.as_array());
    }

    rest.iter_mut().for_each(|x| *x *= factor);
}
/// Returns the Euclidean (L2) norm of `v`.
///
/// Squares are accumulated eight lanes at a time via `mul_add`; the lane
/// totals are then summed, followed by the squares of the trailing
/// `len % 8` elements.
fn l2_norm_f32x8(v: &[f32]) -> f32 {
    let simd_len = v.len() / 8 * 8;
    let (simd_part, rest) = v.split_at(simd_len);

    let mut lane_totals = f32x8::splat(0.0);
    for lane in simd_part.chunks_exact(8) {
        let values: [f32; 8] = lane.try_into().unwrap();
        let x = f32x8::from(values);
        lane_totals = x.mul_add(x, lane_totals);
    }

    let mut sum_sq: f32 = lane_totals.as_array().iter().sum();
    for &x in rest {
        sum_sq += x * x;
    }
    sum_sq.sqrt()
}
/// Returns the first `max_tokens` ids of `ids` that remain after dropping
/// every occurrence of `unk_id` (nothing is dropped when `unk_id` is `None`).
///
/// The relative order of the surviving ids is preserved.
fn filter_ids(ids: &[u32], unk_id: Option<usize>, max_tokens: usize) -> Vec<u32> {
    ids.iter()
        .copied()
        .filter(|&id| unk_id != Some(id as usize))
        .take(max_tokens)
        .collect()
}
/// Derives two values from the tokenizer: the median byte length of its
/// vocabulary entries (added tokens excluded) and the id of the model's
/// `unk_token`, if any.
///
/// The median is the element at index `len / 2` of the sorted lengths and
/// falls back to 1 for an empty vocabulary. The unknown token is read from
/// the `model.unk_token` field of the tokenizer's serialized form and is
/// `None` when that field is missing, not a string, or not in the vocabulary.
///
/// # Errors
///
/// Returns an error if the tokenizer cannot be serialized to JSON.
fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
    let mut token_lens: Vec<usize> = tokenizer
        .get_vocab(false)
        .keys()
        .map(|token| token.len())
        .collect();

    // Only the element at the median position is needed, so a selection
    // replaces the full sort; the selected value is the same.
    let median_token_length = if token_lens.is_empty() {
        1
    } else {
        let mid = token_lens.len() / 2;
        *token_lens.select_nth_unstable(mid).1
    };

    let serialized: Value =
        serde_json::to_value(tokenizer).context("tokenizer serialize for unk lookup")?;
    let unk_token_id = serialized
        .get("model")
        .and_then(|model| model.get("unk_token"))
        .and_then(Value::as_str)
        .and_then(|token| tokenizer.token_to_id(token))
        .map(|id| id as usize);

    Ok((median_token_length, unk_token_id))
}
// Unit tests for this module.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): despite its name, this test never calls `pool_ids` and makes no
// assertions; it only takes a reference to `compute_metadata`, so it passes
// as long as the module compiles. Presumably a placeholder — confirm and
// replace with real coverage.
#[test]
fn pool_ids_empty_input() {
let _ = compute_metadata;
}
}