use alloc::string::String;
use alloc::vec::Vec;
use crate::error::{Error, Result};
/// A dense `f32` embedding vector, optionally tagged with the identifier of
/// the model that produced it.
///
/// Both fields are public, so the checks performed by `Embedding::with_model`
/// (non-empty, all components finite) can be bypassed by constructing or
/// mutating the struct directly.
// NOTE(review): with the `serde` feature the derived `Deserialize` presumably
// builds the struct field-by-field and skips that validation as well — confirm.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Embedding {
/// The components of the embedding.
pub vector: Vec<f32>,
/// Identifier of the producing model; `None` when it is not recorded.
pub model_id: Option<String>,
}
impl Embedding {
    /// Creates an embedding that carries no model identifier.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] when `vector` is empty or contains a
    /// non-finite component (NaN or ±infinity).
    pub fn new(vector: Vec<f32>) -> Result<Self> {
        Self::with_model(vector, None)
    }

    /// Creates an embedding, optionally tagged with the producing model's id.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] when `vector` is empty or contains a
    /// non-finite component (NaN or ±infinity).
    pub fn with_model(vector: Vec<f32>, model_id: Option<String>) -> Result<Self> {
        if vector.is_empty() {
            return Err(Error::InvalidInput("embedding vector is empty".into()));
        }
        if vector.iter().any(|x| !x.is_finite()) {
            return Err(Error::InvalidInput(
                "embedding contains non-finite values".into(),
            ));
        }
        Ok(Self { vector, model_id })
    }

    /// Number of components in the embedding.
    #[inline]
    #[must_use]
    pub fn dim(&self) -> usize {
        self.vector.len()
    }

    /// Euclidean (L2) norm of the vector.
    ///
    /// The squares are accumulated in `f64`, so finite components whose
    /// squares would overflow or underflow `f32` still produce an accurate
    /// result. If the norm itself is larger than `f32::MAX` the returned value
    /// is `f32::INFINITY`.
    #[must_use]
    pub fn l2_norm(&self) -> f32 {
        // Intentional narrowing: an `f64 -> f32` cast rounds to nearest and
        // maps out-of-range values to infinity.
        self.l2_norm_f64() as f32
    }

    /// Scales the vector in place to unit L2 length.
    ///
    /// The vector is left untouched when its norm is zero or not finite (the
    /// latter can only happen if the public `vector` field was made to hold
    /// non-finite values after construction).
    pub fn normalize(&mut self) {
        let norm = self.l2_norm_f64();
        if norm > 0.0 && norm.is_finite() {
            for component in &mut self.vector {
                // Divide in f64 so a norm beyond the f32 range still works;
                // every quotient has magnitude <= 1 and narrows safely.
                *component = (f64::from(*component) / norm) as f32;
            }
        }
    }

    /// Dot product of `self` and `other`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `check_compatible` reports for the pair
    /// (differing model identifiers or differing dimensions).
    pub fn dot(&self, other: &Embedding) -> Result<f32> {
        check_compatible(self, other)?;
        // Lengths are equal past this point, so `zip` visits every pair.
        // An explicit fold keeps the `0.0` seed and left-to-right f32
        // accumulation order.
        let product = self
            .vector
            .iter()
            .zip(&other.vector)
            .fold(0.0_f32, |acc, (x, y)| acc + x * y);
        Ok(product)
    }

    /// L2 norm computed entirely in `f64`, where the square of any finite
    /// `f32` is representable without overflowing or flushing to zero.
    fn l2_norm_f64(&self) -> f64 {
        self.vector
            .iter()
            .map(|&x| {
                let wide = f64::from(x);
                wide * wide
            })
            .sum::<f64>()
            .sqrt()
    }
}
pub(super) fn check_compatible(a: &Embedding, b: &Embedding) -> Result<()> {
if let (Some(am), Some(bm)) = (&a.model_id, &b.model_id) {
if am != bm {
return Err(Error::ModelMismatch {
a: am.clone(),
b: bm.clone(),
});
}
}
if a.vector.len() != b.vector.len() {
return Err(Error::DimensionMismatch {
a: a.vector.len(),
b: b.vector.len(),
});
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the floating-point comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Reports whether `actual` lies strictly within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Constructing from an empty vector is refused.
    #[test]
    fn empty_vector_rejected() {
        assert!(matches!(
            Embedding::new(Vec::new()),
            Err(Error::InvalidInput(_))
        ));
    }

    /// A NaN component is refused.
    #[test]
    fn nan_rejected() {
        let outcome = Embedding::new(vec![1.0, f32::NAN, 0.0]);
        assert!(matches!(outcome, Err(Error::InvalidInput(_))));
    }

    /// An infinite component is refused.
    #[test]
    fn infinity_rejected() {
        let outcome = Embedding::new(vec![1.0, f32::INFINITY, 0.0]);
        assert!(matches!(outcome, Err(Error::InvalidInput(_))));
    }

    /// `dim` reports the number of components.
    #[test]
    fn dim_matches_length() {
        let four = Embedding::new(vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        assert_eq!(four.dim(), 4);
    }

    /// The 3-4-5 triangle gives a norm of 5.
    #[test]
    fn l2_norm_matches_pythag() {
        let triangle = Embedding::new(vec![3.0, 4.0]).unwrap();
        assert!(close(triangle.l2_norm(), 5.0));
    }

    /// After normalising, the norm is 1.
    #[test]
    fn normalize_sets_unit_length() {
        let mut triangle = Embedding::new(vec![3.0, 4.0]).unwrap();
        triangle.normalize();
        assert!(close(triangle.l2_norm(), 1.0));
    }

    /// Normalising an all-zero vector leaves every component at zero.
    #[test]
    fn normalize_zero_vector_is_safe() {
        let mut zeros =
            Embedding::with_model(vec![0.0, 0.0, 0.0], Some("zero".into())).unwrap();
        zeros.normalize();
        assert!(!zeros.vector.iter().any(|component| *component != 0.0));
    }

    /// Embeddings tagged with the same model can be multiplied.
    #[test]
    fn dot_with_matching_models() {
        let make = || Embedding::with_model(vec![1.0, 0.0], Some("m1".into())).unwrap();
        let (lhs, rhs) = (make(), make());
        assert!(close(lhs.dot(&rhs).unwrap(), 1.0));
    }

    /// Differing model identifiers are reported as a model mismatch.
    #[test]
    fn dot_rejects_model_mismatch() {
        let lhs = Embedding::with_model(vec![1.0; 4], Some("a".into())).unwrap();
        let rhs = Embedding::with_model(vec![1.0; 4], Some("b".into())).unwrap();
        let outcome = lhs.dot(&rhs);
        assert!(matches!(outcome, Err(Error::ModelMismatch { .. })));
    }

    /// Differing lengths are reported as a dimension mismatch.
    #[test]
    fn dot_rejects_dim_mismatch() {
        let short = Embedding::new(vec![1.0; 3]).unwrap();
        let long = Embedding::new(vec![1.0; 4]).unwrap();
        let outcome = short.dot(&long);
        assert!(matches!(outcome, Err(Error::DimensionMismatch { .. })));
    }
}