use std::fs::{self, File};
use std::iter::zip;
use std::path::Path;
use anyhow::{Context, Result, anyhow, bail};
use ndarray::Array2;
use half::f16;
use serde_json::Value;
use safetensors::{SafeTensors, tensor::Dtype};
use tokenizers::Tokenizer;
/// Static embedding model: a tokenizer paired with a matrix that holds one embedding row per
/// token id. Sentence vectors are produced by pooling the rows of a sentence's token ids.
pub struct Model2Vec {
/// Tokenizer loaded from the model directory's `tokenizer.json`.
tokenizer: Tokenizer,
/// Embedding matrix; rows are looked up by token id, columns are the embedding dimensions.
embeddings: Array2<f32>,
/// When `true`, pooled vectors are scaled by their L2 norm instead of the token count.
normalize: bool,
/// Median length (`String::len`, i.e. bytes) of the vocabulary entries; multiplied by the
/// token limit to get a character budget for truncating input text before tokenization.
median_token_length: usize,
/// Id of the unknown token, if the tokenizer declares one; these ids are dropped before pooling.
unk_token_id: Option<usize>,
}
impl Model2Vec {
/// Loads a model from a directory containing `tokenizer.json`, `model.safetensors` and
/// `config.json`.
///
/// # Arguments
/// * `base_path` - directory that holds the model files.
/// * `normalize` - when `Some`, overrides the `normalize` flag from `config.json`; when `None`
///   the config value is used, falling back to `true` if the key is absent or not a boolean.
/// * `subdir` - optional sub-directory of `base_path` that actually contains the files.
///
/// # Errors
/// Fails when any of the three files cannot be opened or parsed, when the tokenizer declares
/// an `unk_token` that is missing from its vocabulary (or an `unk_id` that does not fit in
/// `usize`), when neither an `embeddings` nor a `0` tensor exists, when that tensor is not
/// two-dimensional, or when its dtype is not F32, F16 or I8.
pub fn from_pretrained<P>(
    base_path: P,
    normalize: Option<bool>,
    subdir: Option<&str>,
) -> Result<Self>
where
    P: AsRef<Path>,
{
    let base_path = base_path.as_ref();
    let directory = subdir.map_or_else(|| base_path.to_path_buf(), |s| base_path.join(s));
    let tok_path = directory.join("tokenizer.json");
    let mdl_path = directory.join("model.safetensors");
    let cfg_path = directory.join("config.json");

    let tokenizer = Tokenizer::from_file(&tok_path)
        .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

    // Median vocabulary entry length (in bytes); stored so a token limit can later be turned
    // into a character budget for pre-truncating input text.
    let mut lens: Vec<usize> = tokenizer
        .get_vocab(false)
        .keys()
        .map(String::len)
        .collect();
    lens.sort_unstable();
    let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

    // serde_json does not buffer its reader, so wrap the files in a `BufReader` instead of
    // handing it a raw `File`.
    let cfg_file = File::open(&cfg_path).context("failed to read config.json")?;
    let cfg: Value = serde_json::from_reader(std::io::BufReader::new(cfg_file))
        .context("failed to parse config.json")?;
    let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
    let normalize = normalize.unwrap_or(cfg_norm);

    // The tokenizer file is parsed a second time as plain JSON to find the unknown token.
    let spec_file = File::open(&tok_path).context("can't open tokenizer JSON")?;
    let spec: Value = serde_json::from_reader(std::io::BufReader::new(spec_file))
        .context("failed to parse tokenizer JSON spec")?;

    // Prefer an explicit numeric `unk_id`; otherwise resolve `unk_token` through the vocabulary.
    let unk_token_id = if let Some(unk_id) =
        spec.pointer("/model/unk_id").and_then(Value::as_u64)
    {
        let unk_id = usize::try_from(unk_id)
            .context("tokenizer JSON declared an unk_id that does not fit in usize")?;
        Some(unk_id)
    } else if let Some(unk_token) = spec.pointer("/model/unk_token").and_then(Value::as_str) {
        let unk_id = tokenizer.token_to_id(unk_token).ok_or_else(|| anyhow!(
            "tokenizer JSON declared unk_token=\"{unk_token}\", which is not in the vocab"
        ))?;
        Some(unk_id as usize)
    } else {
        None
    };

    let model_bytes = fs::read(&mdl_path).context("failed to read safetensors file")?;
    let safet = SafeTensors::deserialize(&model_bytes)
        .context("failed to parse safetensors data")?;
    // The matrix is stored under "embeddings", with "0" accepted as a fallback name.
    let tensor = safet
        .tensor("embeddings")
        .or_else(|_| safet.tensor("0"))
        .context("no 'embeddings' tensor found")?;
    let &[rows, cols] = tensor
        .shape()
        .try_into()
        .context("embedding tensor is not a 2D matrix")?;

    // Decode the raw little-endian bytes into f32 regardless of the stored dtype.
    let raw = tensor.data();
    let dtype = tensor.dtype();
    let floats: Vec<f32> = match dtype {
        Dtype::F32 => raw
            .chunks_exact(4)
            .map(|bs| f32::from_le_bytes(bs.try_into().expect("chunks_exact(4) yields 4 bytes")))
            .collect(),
        Dtype::F16 => raw
            .chunks_exact(2)
            .map(|bs| {
                f16::from_le_bytes(bs.try_into().expect("chunks_exact(2) yields 2 bytes")).to_f32()
            })
            .collect(),
        Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
        other => bail!("unsupported scalar dtype in tensor: {:?}", other),
    };
    // `from_shape_vec` rejects a decoded length that disagrees with `rows * cols`.
    let embeddings = Array2::from_shape_vec((rows, cols), floats)
        .context("failed to build embeddings array")?;

    Ok(Self {
        tokenizer,
        embeddings,
        normalize,
        median_token_length,
        unk_token_id,
    })
}
/// Cuts `s` down to roughly `max_tokens` tokens' worth of text.
///
/// The budget is `max_tokens * median_len` characters (saturating on overflow). The returned
/// slice always ends on a character boundary, and `s` is returned unchanged when it already
/// fits within the budget.
fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
    let char_budget = max_tokens.saturating_mul(median_len);
    // Byte offset of the first character that falls outside the budget, if any.
    let cut = s
        .char_indices()
        .map(|(offset, _)| offset)
        .nth(char_budget);
    if let Some(end) = cut {
        return &s[..end];
    }
    s
}
pub fn encode_with_args<A, S>(
&self,
sentences: A,
max_length: Option<usize>,
batch_size: usize,
) -> Result<Array2<f32>>
where
A: AsRef<[S]>,
S: AsRef<str>,
{
let sentences = sentences.as_ref();
let mut embeddings = vec![0.0; sentences.len() * self.embedding_dim()];
let batch_iter = zip(
sentences.chunks(batch_size),
embeddings.chunks_mut(batch_size * self.embedding_dim()),
);
for (batch, emb_batch) in batch_iter {
let truncated: Vec<&str> = if let Some(max_tok) = max_length {
batch
.iter()
.map(|text| {
Self::truncate_str(text.as_ref(), max_tok, self.median_token_length)
})
.collect()
} else {
batch.iter().map(S::as_ref).collect()
};
let means = emb_batch.chunks_exact_mut(self.embedding_dim());
assert_eq!(batch.len(), means.len());
assert_eq!(truncated.len(), means.len());
let encodings = self
.tokenizer
.encode_batch_fast(
truncated,
false,
)
.map_err(|e| anyhow!("tokenization failed: {e}"))?;
assert_eq!(encodings.len(), means.len());
for (encoding, out_mean) in zip(encodings, means) {
let token_ids = encoding
.get_ids()
.iter()
.copied()
.map(|id| id as usize)
.filter(|&id| {
self.unk_token_id != Some(id)
})
.take(max_length.unwrap_or(usize::MAX));
self.pool_ids(token_ids, out_mean);
}
}
Array2::from_shape_vec(
(sentences.len(), self.embedding_dim()),
embeddings,
).context(
"embedding shape (input/output count or dimensionality) mismatch"
)
}
/// Embeds `sentences` with the default settings: a limit of 512 tokens per sentence and
/// batches of 1024 sentences.
///
/// Returns whatever `encode_with_args` returns for those arguments.
pub fn encode<A, S>(&self, sentences: A) -> Result<Array2<f32>>
where
    A: AsRef<[S]>,
    S: AsRef<str>,
{
    const DEFAULT_MAX_LENGTH: usize = 512;
    const DEFAULT_BATCH_SIZE: usize = 1024;
    self.encode_with_args(sentences, Some(DEFAULT_MAX_LENGTH), DEFAULT_BATCH_SIZE)
}
fn pool_ids(&self, ids: impl IntoIterator<Item = usize>, mean: &mut [f32]) {
assert_eq!(mean.len(), self.embedding_dim());
mean.fill(0.0);
let mut cnt = 0;
for id in ids {
let row = self.embeddings.row(id);
for (x, &v) in zip(&mut *mean, row) {
*x += v;
}
cnt += 1;
}
let denominator = if self.normalize {
mean.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12)
} else {
cnt.max(1) as f32
};
for x in mean {
*x /= denominator;
}
}
/// Returns the number of columns of the embedding matrix, which is the length of every
/// vector this model produces.
pub fn embedding_dim(&self) -> usize {
    let (_rows, width) = self.embeddings.dim();
    width
}
}