use crate::error::{Error, Result};
/// A non-empty, heap-allocated sequence of finite `f32` components.
///
/// The field is private, so values are only produced by the validating
/// constructors ([`Vector::new`], [`Vector::from_slice`] and the `TryFrom`
/// impls) and, with the `serde` feature enabled, by the validating
/// `Deserialize` impl that follows this definition.
///
/// With the `serde` feature the vector serializes transparently, i.e. as the
/// bare sequence of its components.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Vector {
    /// Component storage; a boxed slice because the length never changes
    /// after construction.
    data: Box<[f32]>,
}

/// Deserializes from the same bare-sequence form the transparent `Serialize`
/// derive emits, but routes the payload through `Vector::from_box` so that
/// empty input and non-finite components are rejected instead of silently
/// producing a `Vector` that violates the constructor invariants.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Vector {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let data = <Box<[f32]> as serde::Deserialize<'de>>::deserialize(deserializer)?;
        // NOTE(review): `custom` requires `crate::error::Error: Display` — confirm.
        Self::from_box(data).map_err(<D::Error as serde::de::Error>::custom)
    }
}
impl Vector {
    /// Builds a vector that takes ownership of `data`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-vector error when `data` is empty or holds a
    /// component that is NaN or infinite.
    pub fn new(data: Vec<f32>) -> Result<Self> {
        Self::from_box(data.into())
    }

    /// Builds a vector from a copy of `data`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-vector error when `data` is empty or holds a
    /// component that is NaN or infinite.
    pub fn from_slice(data: &[f32]) -> Result<Self> {
        Self::from_box(data.into())
    }

    /// Shared validation path for every constructor: rejects empty storage
    /// first, then any non-finite component.
    fn from_box(data: Box<[f32]>) -> Result<Self> {
        if data.is_empty() {
            return Err(Error::invalid_vector("vector is empty"));
        }
        if data.iter().any(|component| !component.is_finite()) {
            return Err(Error::invalid_vector("vector contains a non-finite value"));
        }
        Ok(Self { data })
    }

    /// Number of components in the vector.
    #[inline]
    #[must_use]
    pub fn dim(&self) -> usize {
        self.as_slice().len()
    }

    /// Borrows the components as a slice.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.data[..]
    }

    /// Consumes the vector and hands back its boxed component storage.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> Box<[f32]> {
        let Self { data } = self;
        data
    }

    /// Sum of the squared components, accumulated in `f32`.
    #[inline]
    #[must_use]
    pub fn norm_squared(&self) -> f32 {
        self.as_slice()
            .iter()
            .map(|&component| component * component)
            .sum::<f32>()
    }

    /// Euclidean length: the square root of [`Vector::norm_squared`].
    #[inline]
    #[must_use]
    pub fn norm(&self) -> f32 {
        f32::sqrt(self.norm_squared())
    }
}
/// Lets a `Vector` be passed wherever a borrowed `[f32]` is accepted.
impl AsRef<[f32]> for Vector {
    /// Borrows the components; same view as [`Vector::as_slice`].
    #[inline]
    fn as_ref(&self) -> &[f32] {
        self.as_slice()
    }
}
/// Fallible conversion from an owned `Vec<f32>`, applying the same checks as
/// [`Vector::new`].
impl TryFrom<Vec<f32>> for Vector {
    type Error = Error;

    /// Takes ownership of `value`; fails when it is empty or contains a
    /// non-finite component.
    #[inline]
    fn try_from(value: Vec<f32>) -> Result<Self> {
        Self::from_box(value.into_boxed_slice())
    }
}
/// Fallible conversion from a borrowed slice, applying the same checks as
/// [`Vector::from_slice`].
impl TryFrom<&[f32]> for Vector {
    type Error = Error;

    /// Copies `value`; fails when it is empty or contains a non-finite
    /// component.
    #[inline]
    fn try_from(value: &[f32]) -> Result<Self> {
        Self::from_box(Box::from(value))
    }
}
/// Selects the formula used to compare two vectors.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DistanceMetric {
    /// Euclidean distance: square root of the summed squared differences.
    L2,
    /// One minus the cosine similarity of the two vectors.
    Cosine,
    /// Negated inner product, so that a larger dot product yields a smaller
    /// distance value.
    Dot,
}
impl DistanceMetric {
pub fn distance(self, a: &Vector, b: &Vector) -> Result<f32> {
if a.dim() != b.dim() {
return Err(Error::DimensionMismatch {
left: a.dim(),
right: b.dim(),
});
}
let value = match self {
Self::L2 => l2_distance(a.as_slice(), b.as_slice()),
Self::Cosine => cosine_distance(a.as_slice(), b.as_slice()),
Self::Dot => -dot_product(a.as_slice(), b.as_slice()),
};
Ok(value)
}
}
/// Euclidean distance between two equally sized slices.
///
/// The length equality is only checked in debug builds; `zip` stops at the
/// shorter slice otherwise.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
        .sum::<f32>()
        .sqrt()
}
/// Cosine distance (`1 - cosine similarity`) between two equally sized slices.
///
/// Returns NaN when either input has zero norm, because the similarity is
/// undefined there. For all other inputs whose squared components stay
/// representable in `f32`, the result lies in `[0, 2]`.
///
/// The length equality is only checked in debug builds; `zip` stops at the
/// shorter slice otherwise.
#[inline]
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0_f32;
    let mut na = 0.0_f32;
    let mut nb = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    // Take the roots before multiplying: `na * nb` is a fourth-power quantity
    // that overflows to infinity (or underflows to zero) in `f32` for norms
    // that are themselves perfectly representable, which previously turned
    // identical large vectors into distance 1 and tiny ones into NaN.
    let denom = na.sqrt() * nb.sqrt();
    if denom == 0.0 {
        f32::NAN
    } else {
        // Rounding can push the ratio marginally outside [-1, 1]; clamping
        // keeps the distance inside [0, 2]. NaN passes through `clamp`.
        1.0 - (dot / denom).clamp(-1.0, 1.0)
    }
}
/// Inner product of two equally sized slices, accumulated in `f32`.
///
/// The length equality is only checked in debug builds; `zip` stops at the
/// shorter slice otherwise.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let products = a
        .iter()
        .copied()
        .zip(b.iter().copied())
        .map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vector from components the test expects to be valid.
    fn vector(components: Vec<f32>) -> Vector {
        Vector::new(components).unwrap()
    }

    /// Asserts that `Vector::new` rejects `components` as an invalid vector.
    fn assert_invalid(components: Vec<f32>) {
        let err = Vector::new(components).unwrap_err();
        assert!(matches!(err, Error::InvalidVector { .. }));
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn new_rejects_empty_vector() {
        assert_invalid(Vec::new());
    }

    #[test]
    fn new_rejects_nan_component() {
        assert_invalid(vec![1.0, f32::NAN, 2.0]);
    }

    #[test]
    fn new_rejects_positive_infinity() {
        assert_invalid(vec![f32::INFINITY]);
    }

    #[test]
    fn new_rejects_negative_infinity() {
        assert_invalid(vec![f32::NEG_INFINITY]);
    }

    #[test]
    fn from_slice_copies_input() {
        let data = [1.0, 2.0, 3.0];
        let copied = Vector::from_slice(&data).unwrap();
        assert_eq!(copied.as_slice(), &data);
    }

    #[test]
    fn try_from_vec_works() {
        let converted = Vector::try_from(vec![1.0_f32, 2.0]).unwrap();
        assert_eq!(converted.dim(), 2);
    }

    #[test]
    fn try_from_slice_works() {
        let data: &[f32] = &[1.0, 2.0];
        let converted = Vector::try_from(data).unwrap();
        assert_eq!(converted.dim(), 2);
    }

    // --- accessors and norms ----------------------------------------------

    #[test]
    fn dim_matches_input_length() {
        assert_eq!(vector(vec![0.0; 768]).dim(), 768);
    }

    #[test]
    fn norm_squared_equals_dot_with_self() {
        assert_eq!(vector(vec![3.0, 4.0]).norm_squared(), 25.0);
    }

    #[test]
    fn norm_of_3_4_is_5() {
        assert_close(vector(vec![3.0, 4.0]).norm(), 5.0, 1e-6);
    }

    #[test]
    fn as_slice_and_as_ref_agree() {
        let v = vector(vec![1.0, 2.0]);
        assert_eq!(v.as_slice(), v.as_ref());
    }

    #[test]
    fn into_inner_returns_box() {
        let owned = vector(vec![1.0, 2.0]).into_inner();
        assert_eq!(&*owned, &[1.0, 2.0]);
    }

    // --- metrics ----------------------------------------------------------

    #[test]
    fn l2_distance_identical_vectors_is_zero() {
        let a = vector(vec![1.0, 2.0, 3.0]);
        let b = vector(vec![1.0, 2.0, 3.0]);
        assert_eq!(DistanceMetric::L2.distance(&a, &b).unwrap(), 0.0);
    }

    #[test]
    fn l2_distance_3_4_5_triple() {
        let origin = vector(vec![0.0, 0.0]);
        let point = vector(vec![3.0, 4.0]);
        let d = DistanceMetric::L2.distance(&origin, &point).unwrap();
        assert_close(d, 5.0, 1e-6);
    }

    #[test]
    fn cosine_distance_identical_unit_vectors_is_zero() {
        let a = vector(vec![1.0, 0.0]);
        let b = vector(vec![1.0, 0.0]);
        let d = DistanceMetric::Cosine.distance(&a, &b).unwrap();
        assert_close(d, 0.0, 1e-6);
    }

    #[test]
    fn cosine_distance_orthogonal_unit_vectors_is_one() {
        let a = vector(vec![1.0, 0.0]);
        let b = vector(vec![0.0, 1.0]);
        let d = DistanceMetric::Cosine.distance(&a, &b).unwrap();
        assert_close(d, 1.0, 1e-6);
    }

    #[test]
    fn cosine_distance_opposite_unit_vectors_is_two() {
        let a = vector(vec![1.0, 0.0]);
        let b = vector(vec![-1.0, 0.0]);
        let d = DistanceMetric::Cosine.distance(&a, &b).unwrap();
        assert_close(d, 2.0, 1e-6);
    }

    #[test]
    fn dot_distance_is_negated_inner_product() {
        let a = vector(vec![1.0, 2.0, 3.0]);
        let b = vector(vec![4.0, 5.0, 6.0]);
        let d = DistanceMetric::Dot.distance(&a, &b).unwrap();
        assert_close(d, -32.0, 1e-5);
    }

    #[test]
    fn distance_rejects_dimension_mismatch() {
        let short = vector(vec![1.0, 2.0]);
        let long = vector(vec![1.0, 2.0, 3.0]);
        let metrics = [
            DistanceMetric::L2,
            DistanceMetric::Cosine,
            DistanceMetric::Dot,
        ];
        for metric in metrics {
            assert!(matches!(
                metric.distance(&short, &long),
                Err(Error::DimensionMismatch { left: 2, right: 3 })
            ));
        }
    }

    #[test]
    fn vector_clone_is_independent() {
        let original = vector(vec![1.0, 2.0, 3.0]);
        let duplicate = original.clone();
        assert_eq!(original.as_slice(), duplicate.as_slice());
        assert_eq!(original, duplicate);
    }
}