/// Something that can turn a batch of texts into embedding vectors.
///
/// The `Send + Sync` supertraits mean every implementor can be shared
/// across threads.
pub trait Embedder: Send + Sync {
    /// Produces embedding vectors for `texts`.
    ///
    /// NOTE(review): implementations presumably return one vector per input
    /// text, in input order; the trait itself does not enforce this.
    fn embed(&self, texts: &[&str]) -> Vec<Vec<f32>>;

    /// Returns an identifier for this embedder.
    ///
    /// The default implementation returns `"anonymous"`.
    fn id(&self) -> &str {
        "anonymous"
    }
}
/// Adapter that pairs an embedding closure with a name.
///
/// Despite the name, the closure is stored inline as the generic field `f`,
/// not behind a `Box`.
///
/// The `Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync` requirement on `F` is
/// deliberately *not* repeated on the struct definition: it is stated on the
/// `impl` blocks that actually call the closure, which keeps the type itself
/// nameable without dragging the bounds into every signature that mentions it.
pub struct BoxedEmbedder<F> {
    /// The wrapped embedding function.
    f: F,
    /// Identifier reported for this embedder.
    name: String,
}
impl<F> BoxedEmbedder<F>
where
    F: Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync,
{
    /// Wraps `f` using the default name `"boxed"`.
    pub fn new(f: F) -> Self {
        Self::named(f, "boxed")
    }

    /// Wraps `f` using a caller-supplied `name`.
    ///
    /// `name` accepts anything convertible into a `String`.
    pub fn named(f: F, name: impl Into<String>) -> Self {
        let name = name.into();
        Self { f, name }
    }
}
/// Lets a wrapped closure be used wherever an [`Embedder`] is expected.
impl<F> Embedder for BoxedEmbedder<F>
where
    F: Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync,
{
    /// Forwards `texts` to the wrapped closure and returns its result as-is.
    fn embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        let embed_fn = &self.f;
        embed_fn(texts)
    }

    /// Returns the name this embedder was constructed with.
    fn id(&self) -> &str {
        self.name.as_str()
    }
}
/// Cosine similarity of two vectors, clamped to `[-1.0, 1.0]`.
///
/// Edge cases:
/// - slices of different lengths yield `0.0`;
/// - if both vectors have a norm below the zero threshold (this includes two
///   empty slices) the result is `1.0`;
/// - if exactly one vector has a norm below the threshold the result is `0.0`.
///
/// The dot product and squared norms are accumulated in `f64`. Squaring an
/// `f32` component around `1e20` overflows `f32` to infinity, which would
/// turn the final division into `inf / inf = NaN`; widening avoids that for
/// every finite `f32` input and also reduces rounding error.
///
/// Non-finite components (NaN or infinities) are not special-cased and can
/// produce a NaN result.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Norms below this count as "zero vector". Spelled as an `f32` literal and
    // then widened so the cut-off is the same value the `f32` comparison used.
    const ZERO_NORM_EPS: f64 = 1e-12_f32 as f64;

    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();

    match (norm_a < ZERO_NORM_EPS, norm_b < ZERO_NORM_EPS) {
        // Two zero vectors are treated as identical.
        (true, true) => 1.0,
        // A zero vector has no direction to compare against.
        (true, false) | (false, true) => 0.0,
        // The clamped value lies in [-1, 1] (or is NaN), so narrowing back to
        // `f32` only rounds; it cannot overflow.
        (false, false) => (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors_is_one() {
        let sample = [1.0_f32, 2.0, 3.0];
        let similarity = cosine(&sample, &sample);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors_is_zero() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let similarity = cosine(&x_axis, &y_axis);
        assert!(similarity.abs() < 1e-6);
    }

    /// Two all-zero vectors are reported as identical.
    #[test]
    fn cosine_both_zero_returns_one() {
        let zeros = [0.0_f32; 4];
        let similarity = cosine(&zeros, &zeros);
        assert!((similarity - 1.0).abs() < 1e-9);
    }

    /// A zero vector against a non-zero vector gives exactly 0.
    #[test]
    fn cosine_one_zero_returns_zero() {
        let zeros = [0.0_f32; 4];
        let non_zero = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(cosine(&zeros, &non_zero), 0.0);
    }

    /// Slices of different lengths give exactly 0.
    #[test]
    fn cosine_dim_mismatch_returns_zero() {
        let short = [1.0_f32, 2.0];
        let long = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine(&short, &long), 0.0);
    }

    /// A named closure wrapper forwards `embed` calls and reports its name.
    #[test]
    fn boxed_embedder_round_trip() {
        let embedder = BoxedEmbedder::named(
            |texts: &[&str]| texts.iter().map(|t| vec![t.len() as f32, 1.0]).collect(),
            "len-embed",
        );
        let vectors = embedder.embed(&["abc", "abcdef"]);
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![3.0, 1.0]);
        assert_eq!(vectors[1], vec![6.0, 1.0]);
        assert_eq!(embedder.id(), "len-embed");
    }
}