/// Turns text into fixed-length `f32` vectors.
///
/// Implementors must be `Send + Sync` so a single instance can be shared
/// across threads.
pub trait Embedder: Send + Sync {
    /// Embeds every entry of `texts`, returning one vector per input.
    ///
    /// # Errors
    /// Returns whatever error the underlying implementation reports.
    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;

    /// Identifier of the model that produces the vectors.
    fn model_id(&self) -> &str;

    /// Length of the vectors produced by this embedder.
    fn dim(&self) -> usize;

    /// Convenience wrapper around [`Embedder::embed`] for a single text.
    ///
    /// If the implementation returns no vectors at all, an empty vector is
    /// returned rather than an error.
    ///
    /// # Errors
    /// Propagates any error from [`Embedder::embed`].
    fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let batch = self.embed(&[text])?;
        // Take the last vector of the batch (the only one for a one-element input).
        let vector = batch.into_iter().next_back().unwrap_or_default();
        Ok(vector)
    }
}
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// one has a zero norm; otherwise returns `dot(a, b) / (|a| * |b|)`.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
}
/// Serialises a vector into a byte blob: each `f32` becomes its four
/// little-endian bytes, in order.
pub fn to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|f| f.to_le_bytes()));
    bytes
}
/// Decodes a byte blob into a vector, reading each group of four bytes as a
/// little-endian `f32`.
///
/// Trailing bytes that do not form a complete four-byte group are ignored.
pub fn from_blob(b: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(b.len() / 4);
    for chunk in b.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        floats.push(f32::from_le_bytes(word));
    }
    floats
}
/// Reports whether `text` is long enough to be worth embedding.
///
/// The text is trimmed of leading and trailing whitespace first; it qualifies
/// when at least 12 characters (Unicode scalar values, not bytes) remain.
pub fn is_embeddable(text: &str) -> bool {
    /// Minimum number of characters required after trimming.
    const MIN_CHARS: usize = 12;
    // `nth` stops as soon as the threshold is reached, so long inputs are not
    // scanned to the end just to be compared against a small constant.
    text.trim().chars().nth(MIN_CHARS - 1).is_some()
}
/// Embedder that needs no model: tokens are hashed into a fixed number of
/// buckets.
pub struct HashEmbedder {
    /// Number of buckets, i.e. the length of every produced vector.
    /// `new` guarantees this is at least 1.
    dim: usize,
}

impl HashEmbedder {
    /// Creates an embedder producing `dim`-length vectors.
    ///
    /// A `dim` of zero is raised to 1.
    pub fn new(dim: usize) -> Self {
        let dim = if dim == 0 { 1 } else { dim };
        Self { dim }
    }

    /// 64-bit FNV-1a hash of the token's UTF-8 bytes.
    fn hash_token(tok: &str) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const PRIME: u64 = 0x100000001b3;
        tok.bytes()
            .fold(OFFSET_BASIS, |h, b| (h ^ u64::from(b)).wrapping_mul(PRIME))
    }
}

impl Default for HashEmbedder {
    /// A 64-dimensional hash embedder.
    fn default() -> Self {
        HashEmbedder::new(64)
    }
}
impl Embedder for HashEmbedder {
    /// Builds one L2-normalised token-count vector per input text.
    ///
    /// Each text is split on non-alphanumeric characters; every non-empty
    /// token is lowercased, hashed, and counted in bucket `hash % dim`.
    /// A text with no tokens yields an all-zero vector. Never returns `Err`.
    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // Widening `usize -> u64` is lossless on 16/32/64-bit targets.
        let buckets = self.dim as u64;
        let mut out = Vec::with_capacity(texts.len());
        for text in texts {
            let mut v = vec![0.0f32; self.dim];
            let tokens = text
                .split(|c: char| !c.is_alphanumeric())
                .filter(|s| !s.is_empty());
            for tok in tokens {
                let hash = Self::hash_token(&tok.to_lowercase());
                // Reduce in u64 space *before* narrowing. Casting the hash to
                // `usize` first would drop its upper bits on 32-bit targets
                // and assign tokens to different buckets than on 64-bit ones,
                // making stored vectors incomparable across platforms.
                // The remainder is < dim, so the narrowing cast is lossless.
                let bucket = (hash % buckets) as usize;
                v[bucket] += 1.0;
            }
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            // Leave token-less (all-zero) vectors untouched to avoid 0/0.
            if norm > 0.0 {
                for x in &mut v {
                    *x /= norm;
                }
            }
            out.push(v);
        }
        Ok(out)
    }

    /// Fixed identifier for this hashing scheme.
    fn model_id(&self) -> &str {
        "hash-v1"
    }

    /// Number of buckets, i.e. the length of every produced vector.
    fn dim(&self) -> usize {
        self.dim
    }
}
/// Model repository that `default_embedder` passes to
/// `Model2VecEmbedder::load` when the `TJ_EMBED_MODEL` environment variable
/// cannot be read.
///
/// Only compiled when the `embed` feature is enabled.
#[cfg(feature = "embed")]
pub const DEFAULT_EMBED_MODEL: &str = "minishlab/potion-multilingual-128M";
/// Picks the embedder to use for this process.
///
/// Selection order:
/// 1. If the `TJ_EMBED` environment variable is exactly `hash`, a default
///    [`HashEmbedder`] is returned.
/// 2. With the `embed` feature enabled, the model named by `TJ_EMBED_MODEL`
///    (or `DEFAULT_EMBED_MODEL` when that variable cannot be read) is loaded.
///    A load failure is logged as a warning and is not fatal.
/// 3. Otherwise a default [`HashEmbedder`] is returned.
pub fn default_embedder() -> Box<dyn Embedder> {
    if matches!(std::env::var("TJ_EMBED").as_deref(), Ok("hash")) {
        return Box::new(HashEmbedder::default());
    }

    #[cfg(feature = "embed")]
    {
        let configured = std::env::var("TJ_EMBED_MODEL");
        let repo = configured.as_deref().unwrap_or(DEFAULT_EMBED_MODEL);
        match Model2VecEmbedder::load(repo) {
            Ok(embedder) => return Box::new(embedder),
            Err(e) => {
                tracing::warn!("model2vec load failed ({e:#}); using hash embedder fallback");
            }
        }
    }

    // Either the feature is off or the model failed to load.
    Box::new(HashEmbedder::default())
}
/// Embedder backed by a `model2vec_rs` static model.
///
/// Only compiled when the `embed` feature is enabled.
#[cfg(feature = "embed")]
pub struct Model2VecEmbedder {
    /// The loaded model used for encoding.
    model: model2vec_rs::model::StaticModel,
    /// `"model2vec:"` followed by the repo string the model was loaded from.
    model_id: String,
    /// Vector length, measured once at load time.
    dim: usize,
}

#[cfg(feature = "embed")]
impl Model2VecEmbedder {
    /// Loads the model identified by `repo` and records its output dimension.
    ///
    /// The dimension is discovered by encoding a throwaway probe string.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded, or if the probe embedding is empty.
    pub fn load(repo: &str) -> anyhow::Result<Self> {
        // NOTE(review): the positional arguments after `repo` are presumably
        // (token, normalize, subfolder) — confirm against the model2vec_rs API.
        let model =
            model2vec_rs::model::StaticModel::from_pretrained(repo, None, Some(true), None)?;

        let dim = model.encode_single("probe").len();
        if dim == 0 {
            anyhow::bail!("model2vec model {repo} produced a zero-dim embedding");
        }

        let model_id = format!("model2vec:{repo}");
        Ok(Self {
            model,
            model_id,
            dim,
        })
    }
}
#[cfg(feature = "embed")]
impl Embedder for Model2VecEmbedder {
    /// Encodes all `texts` with the loaded model. Never returns `Err`.
    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // The model is handed owned strings, so copy each input first.
        let mut sentences = Vec::with_capacity(texts.len());
        for text in texts {
            sentences.push(String::from(*text));
        }
        let vectors = self.model.encode(&sentences);
        Ok(vectors)
    }

    /// The identifier built at load time (`model2vec:<repo>`).
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Vector length measured at load time.
    fn dim(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself scores 1 (within float tolerance).
    #[test]
    fn cosine_identical_is_one() {
        let v = [1.0f32, 2.0, 3.0];
        let similarity = cosine(&v, &v);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    /// Perpendicular unit vectors score exactly 0.
    #[test]
    fn cosine_orthogonal_is_zero() {
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];
        assert_eq!(cosine(&x_axis, &y_axis), 0.0);
    }

    /// Length mismatches and zero-norm inputs fall back to 0.
    #[test]
    fn cosine_mismatch_or_zero_norm_is_zero() {
        let mismatched = cosine(&[1.0, 2.0], &[1.0]);
        assert_eq!(mismatched, 0.0);

        let zero_norm = cosine(&[0.0, 0.0], &[1.0, 1.0]);
        assert_eq!(zero_norm, 0.0);
    }

    /// Encoding then decoding a vector returns the same values.
    #[test]
    fn blob_round_trips() {
        let original: Vec<f32> = vec![0.5, -1.25, 3.0, 0.0];
        let restored = from_blob(&to_blob(&original));
        assert_eq!(restored, original);
    }

    /// Empty and very short texts are rejected; a real sentence is accepted.
    #[test]
    fn is_embeddable_skips_short_boilerplate() {
        for rejected in ["", "[open]"] {
            assert!(!is_embeddable(rejected));
        }
        assert!(is_embeddable("Fix the auth bug in middleware"));
    }

    /// Same input gives the same unit-length vector of the requested size.
    #[test]
    fn hash_embedder_is_deterministic_and_normalised() {
        let embedder = HashEmbedder::new(32);
        let first = embedder.embed_one("payment gateway dedup").unwrap();
        let second = embedder.embed_one("payment gateway dedup").unwrap();
        assert_eq!(first, second);
        assert_eq!(first.len(), 32);

        let norm = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// A text sharing the query's words must outrank an unrelated text.
    #[test]
    fn hash_embedder_overlap_ranks_above_disjoint() {
        let embedder = HashEmbedder::new(256);
        let q = embedder.embed_one("payment refund duplicate write").unwrap();
        let near = embedder
            .embed_one("duplicate refund write on payment")
            .unwrap();
        let far = embedder.embed_one("frontend button color tweak").unwrap();

        let near_score = cosine(&q, &near);
        let far_score = cosine(&q, &far);
        assert!(
            near_score > far_score,
            "lexical overlap must score higher: near={} far={}",
            near_score,
            far_score
        );
    }
}