use std::marker::PhantomData;
/// Type-state marker for an [`EmbeddingVector`] produced by
/// [`EmbeddingVector::normalize`] or [`EmbeddingVector::new_normalized`].
///
/// The private unit field prevents construction of the marker outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Normalized(());

/// Type-state marker for an [`EmbeddingVector`] holding raw, not yet normalized data.
///
/// The private unit field prevents construction of the marker outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Unnormalized(());

/// A vector of `f32` components tagged at compile time with its normalization state.
///
/// `State` is either [`Unnormalized`] or [`Normalized`]; it is carried only through
/// `PhantomData`, so no value of the marker type is stored.
#[derive(Debug, Clone)]
pub struct EmbeddingVector<State> {
    inner: Vec<f32>,
    _state: PhantomData<State>,
}

impl EmbeddingVector<Unnormalized> {
    /// Wraps raw components as an unnormalized embedding without modifying them.
    #[must_use]
    pub fn new(inner: Vec<f32>) -> Self {
        Self {
            inner,
            _state: PhantomData,
        }
    }

    /// Consumes the vector and divides every component by its Euclidean (L2) norm.
    ///
    /// A vector whose computed norm is exactly zero (all components zero, no
    /// components at all, or squares that underflow to zero) has no direction to
    /// preserve and is returned with its components unchanged.
    ///
    /// The guard is an exact comparison with zero on purpose: `f32::EPSILON` is the
    /// spacing of floats around `1.0`, not an absolute magnitude threshold, and
    /// using it would leave small but non-zero vectors unscaled while still tagging
    /// them as [`Normalized`].
    #[must_use]
    pub fn normalize(self) -> EmbeddingVector<Normalized> {
        let norm: f32 = self.inner.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mut inner = self.inner;
        // Only an exactly-zero norm would turn the division into 0/0.
        if norm != 0.0 {
            for component in &mut inner {
                *component /= norm;
            }
        }
        EmbeddingVector {
            inner,
            _state: PhantomData,
        }
    }
}

impl EmbeddingVector<Normalized> {
    /// Tags `inner` as normalized without checking or rescaling it.
    ///
    /// The components are stored exactly as given; the caller is responsible for
    /// passing data that is already normalized.
    #[must_use]
    pub fn new_normalized(inner: Vec<f32>) -> Self {
        Self {
            inner,
            _state: PhantomData,
        }
    }
}

impl<State> EmbeddingVector<State> {
    /// Borrows the components as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.inner
    }

    /// Consumes the wrapper and returns the underlying `Vec<f32>`.
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.inner
    }

    /// Returns the number of components.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the vector has no components.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

impl From<Vec<f32>> for EmbeddingVector<Unnormalized> {
    /// Equivalent to [`EmbeddingVector::new`].
    fn from(v: Vec<f32>) -> Self {
        Self::new(v)
    }
}
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns the dot product divided by the product of the two Euclidean norms,
/// clamped to `[-1.0, 1.0]` to absorb rounding error.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when
/// the product of the norms is exactly zero (for example, one input is all zeros).
///
/// The zero check is an exact comparison rather than a `f32::EPSILON` threshold:
/// cosine similarity does not depend on the magnitude of its inputs, so small but
/// non-zero vectors must not be treated as zero vectors.
#[inline]
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // Guard only against an actual division by zero.
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing against an expected value.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` is strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Asserts that the magnitude of `value` does not exceed `f32::EPSILON`.
    fn assert_negligible(value: f32) {
        assert!(value.abs() <= f32::EPSILON);
    }

    // --- cosine_similarity ---

    #[test]
    fn identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        assert_close(cosine_similarity(&v, &v), 1.0);
    }

    #[test]
    fn orthogonal_vectors() {
        let similarity = cosine_similarity(&[1.0_f32, 0.0], &[0.0_f32, 1.0]);
        assert_close(similarity, 0.0);
    }

    #[test]
    fn opposite_vectors() {
        let similarity = cosine_similarity(&[1.0_f32, 0.0], &[-1.0_f32, 0.0]);
        assert_close(similarity, -1.0);
    }

    #[test]
    fn zero_vector() {
        let zero = [0.0_f32, 0.0];
        let unit = [1.0_f32, 0.0];
        assert_negligible(cosine_similarity(&zero, &unit));
    }

    #[test]
    fn different_lengths() {
        let short = [1.0_f32];
        let long = [1.0_f32, 0.0];
        assert_negligible(cosine_similarity(&short, &long));
    }

    #[test]
    fn empty_vectors() {
        assert_negligible(cosine_similarity(&[], &[]));
    }

    #[test]
    fn parallel_vectors() {
        let similarity = cosine_similarity(&[2.0_f32, 0.0], &[5.0_f32, 0.0]);
        assert_close(similarity, 1.0);
    }

    #[test]
    fn normalized_vectors() {
        let component = 1.0_f32 / 2.0_f32.sqrt();
        let first = [component, component];
        let second = [component, component];
        assert_close(cosine_similarity(&first, &second), 1.0);
    }

    // --- EmbeddingVector ---

    #[test]
    fn embedding_vector_normalize_produces_unit_vector() {
        let normed = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]).normalize();
        let squared_norm: f32 = normed.as_slice().iter().map(|c| c * c).sum();
        assert_close(squared_norm, 1.0);
    }

    #[test]
    fn embedding_vector_normalize_zero_vector_is_safe() {
        let normed = EmbeddingVector::<Unnormalized>::new(vec![0.0_f32, 0.0]).normalize();
        assert_eq!(normed.as_slice(), &[0.0_f32, 0.0]);
    }

    #[test]
    fn embedding_vector_into_inner_roundtrip() {
        let expected = vec![1.0_f32, 2.0, 3.0];
        let wrapped = EmbeddingVector::<Unnormalized>::new(expected.clone());
        assert_eq!(wrapped.into_inner(), expected);
    }

    #[test]
    fn embedding_vector_len_and_is_empty() {
        let populated = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
        assert_eq!(populated.len(), 2);
        assert!(!populated.is_empty());

        let vacant = EmbeddingVector::<Unnormalized>::new(vec![]);
        assert!(vacant.is_empty());
    }

    #[test]
    fn embedding_vector_new_normalized_trust_caller() {
        // The constructor stores the data as given.
        let trusted = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
        assert_eq!(trusted.as_slice(), &[0.6_f32, 0.8]);
    }

    #[test]
    fn embedding_vector_from_vec() {
        let converted: EmbeddingVector<Unnormalized> = vec![1.0_f32, 2.0].into();
        assert_eq!(converted.as_slice(), &[1.0_f32, 2.0]);
    }
}