use crate::error::{Error, Result};
use super::embedding::{Embedding, check_compatible};
/// A source of [`Embedding`]s computed from some input type.
///
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): nothing in this trait ties the values reported by
/// `model_id` / `dimension` to the embeddings returned by `embed`;
/// presumably implementors keep them consistent — confirm against the
/// concrete providers.
pub trait EmbeddingProvider: Send + Sync {
/// The kind of value this provider embeds. The `?Sized` bound allows
/// unsized types (such as `str`), since inputs are only taken by reference.
type Input: ?Sized;
/// Produces an [`Embedding`] for `input`.
///
/// # Errors
///
/// Returns an error when the provider cannot embed `input`; the exact
/// failure conditions are defined by each implementation.
fn embed(&self, input: &Self::Input) -> Result<Embedding>;
/// Returns the provider's model identifier as a string slice.
///
/// NOTE(review): presumably the same identifier that embeddings carry as
/// their optional model tag — confirm.
fn model_id(&self) -> &str;
/// Returns the provider's embedding dimension.
///
/// NOTE(review): presumably the vector length of every embedding returned
/// by `embed`; this is not enforced by the trait — confirm.
fn dimension(&self) -> usize;
}
/// Computes the cosine similarity of two embeddings.
///
/// The result is `dot(a, b) / (‖a‖ · ‖b‖)`, clamped to the closed range
/// `[-1.0, 1.0]`: `1.0` for vectors pointing the same way, `0.0` for
/// orthogonal vectors and `-1.0` for opposite vectors.
///
/// # Errors
///
/// * Propagates any error reported by [`check_compatible`] for the pair
///   (this module's tests exercise dimension and model-id mismatches).
/// * Returns [`Error::InvalidInput`] when either embedding has a zero
///   Euclidean norm, because the cosine is undefined in that case.
pub fn semantic_similarity(a: &Embedding, b: &Embedding) -> Result<f32> {
    check_compatible(a, b)?;

    // Accumulate in f64. Squaring an f32 component can underflow to zero
    // (tiny values, which would be misreported as a zero-norm vector) or
    // overflow to infinity (large values, which would yield NaN). Every
    // finite f32 squares to a normal, finite f64, so neither can happen here.
    let mut dot = 0.0_f64;
    let mut norm_a_sq = 0.0_f64;
    let mut norm_b_sq = 0.0_f64;

    // `check_compatible` is relied on to reject differing lengths; `zip`
    // would otherwise stop silently at the shorter vector.
    for (&av, &bv) in a.vector.iter().zip(b.vector.iter()) {
        let (av, bv) = (f64::from(av), f64::from(bv));
        dot += av * bv;
        norm_a_sq += av * av;
        norm_b_sq += bv * bv;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return Err(Error::InvalidInput(
            "cannot compute cosine for a zero-norm embedding".into(),
        ));
    }

    // Rounding can leave the quotient a hair outside the mathematical range,
    // so clamp before narrowing. The value is within [-1, 1] at that point,
    // so the cast to f32 cannot overflow.
    let cosine = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    Ok(cosine as f32)
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec::Vec;

    /// Builds an embedding without a model id, panicking if construction fails.
    fn emb(v: Vec<f32>) -> Embedding {
        Embedding::new(v).unwrap()
    }

    /// Computes the similarity of `lhs` and `rhs`, panicking on any error.
    fn similarity(lhs: &Embedding, rhs: &Embedding) -> f32 {
        semantic_similarity(lhs, rhs).unwrap()
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn identical_vectors_score_one() {
        let lhs = emb(alloc::vec![1.0, 0.0, 0.0]);
        let rhs = emb(alloc::vec![1.0, 0.0, 0.0]);
        assert_close(similarity(&lhs, &rhs), 1.0, 1e-6);
    }

    #[test]
    fn orthogonal_vectors_score_zero() {
        let x_axis = emb(alloc::vec![1.0, 0.0]);
        let y_axis = emb(alloc::vec![0.0, 1.0]);
        assert_close(similarity(&x_axis, &y_axis), 0.0, 1e-6);
    }

    #[test]
    fn opposite_vectors_score_minus_one() {
        let forward = emb(alloc::vec![1.0, 0.0]);
        let backward = emb(alloc::vec![-1.0, 0.0]);
        assert_close(similarity(&forward, &backward), -1.0, 1e-6);
    }

    #[test]
    fn rejects_dim_mismatch() {
        let three_d = emb(alloc::vec![1.0, 0.0, 0.0]);
        let two_d = emb(alloc::vec![1.0, 0.0]);
        let outcome = semantic_similarity(&three_d, &two_d);
        assert!(matches!(outcome, Err(Error::DimensionMismatch { .. })));
    }

    #[test]
    fn rejects_model_mismatch() {
        let from_ma = Embedding::with_model(alloc::vec![1.0; 4], Some("ma".into())).unwrap();
        let from_mb = Embedding::with_model(alloc::vec![1.0; 4], Some("mb".into())).unwrap();
        let outcome = semantic_similarity(&from_ma, &from_mb);
        assert!(matches!(outcome, Err(Error::ModelMismatch { .. })));
    }

    #[test]
    fn allows_one_sided_model_id() {
        // Only one side carries a model id; the comparison must still succeed.
        let untagged = emb(alloc::vec![1.0; 4]);
        let tagged = Embedding::with_model(alloc::vec![1.0; 4], Some("mb".to_string())).unwrap();
        assert_close(similarity(&untagged, &tagged), 1.0, 1e-6);
    }

    #[test]
    fn large_dim_round_trip() {
        let dims = 768;
        let components: Vec<f32> = (0..dims)
            .map(|idx| (idx as f32 + 1.0) / dims as f32)
            .collect();
        let embedding = emb(components);
        // A vector compared with itself must score (approximately) one.
        assert_close(similarity(&embedding, &embedding), 1.0, 1e-5);
    }
}