use half::{bf16, f16};
use serde::{Deserialize, Serialize};
/// Element precision a [`VectorData`] is stored at.
///
/// Defaults to [`VectorPrecision::F32`]. The two 16-bit variants halve the
/// per-element footprint (see [`VectorPrecision::bytes_per_element`]) at the
/// cost of precision.
// Keep the variant order stable: serde's derived impls pass the variant index
// to the serializer, and index-based formats persist it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum VectorPrecision {
    /// 32-bit `f32` elements (4 bytes each). This is the default.
    #[default]
    F32,
    /// 16-bit half-precision elements (`half::f16`, 2 bytes each).
    F16,
    /// 16-bit bfloat16 elements (`half::bf16`, 2 bytes each).
    BF16,
}
impl VectorPrecision {
    /// Storage size, in bytes, of a single element at this precision.
    #[must_use]
    pub const fn bytes_per_element(&self) -> usize {
        match *self {
            Self::F16 | Self::BF16 => 2,
            Self::F32 => 4,
        }
    }

    /// Bytes needed to store `dimension` elements at this precision, i.e.
    /// `dimension * bytes_per_element()`.
    ///
    /// Overflow follows ordinary `usize` multiplication rules (panics in debug
    /// builds).
    #[must_use]
    pub const fn memory_size(&self, dimension: usize) -> usize {
        dimension * self.bytes_per_element()
    }
}
/// An owned vector whose elements are stored at one of the
/// [`VectorPrecision`] levels.
///
/// Build one at a chosen precision with [`VectorData::from_f32_slice`] or
/// [`VectorData::from_f32_vec`]. Widen it back with [`VectorData::to_f32_vec`].
// Keep the variant order stable: serde's derived impls pass the variant index
// to the serializer, and index-based formats persist it.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
pub enum VectorData {
    /// Full-precision `f32` elements.
    F32(Vec<f32>),
    /// Half-precision elements (`half::f16`).
    F16(Vec<f16>),
    /// bfloat16 elements (`half::bf16`).
    BF16(Vec<bf16>),
}
impl VectorData {
#[must_use]
pub fn from_f32_slice(data: &[f32], precision: VectorPrecision) -> Self {
match precision {
VectorPrecision::F32 => Self::F32(data.to_vec()),
VectorPrecision::F16 => Self::F16(data.iter().map(|&x| f16::from_f32(x)).collect()),
VectorPrecision::BF16 => Self::BF16(data.iter().map(|&x| bf16::from_f32(x)).collect()),
}
}
#[must_use]
pub fn from_f32_vec(data: Vec<f32>, precision: VectorPrecision) -> Self {
if precision == VectorPrecision::F32 {
Self::F32(data)
} else {
Self::from_f32_slice(&data, precision)
}
}
#[must_use]
pub const fn precision(&self) -> VectorPrecision {
match self {
Self::F32(_) => VectorPrecision::F32,
Self::F16(_) => VectorPrecision::F16,
Self::BF16(_) => VectorPrecision::BF16,
}
}
#[must_use]
pub fn len(&self) -> usize {
match self {
Self::F32(v) => v.len(),
Self::F16(v) => v.len(),
Self::BF16(v) => v.len(),
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn memory_size(&self) -> usize {
self.precision().memory_size(self.len())
}
#[must_use]
pub fn to_f32_vec(&self) -> Vec<f32> {
match self {
Self::F32(v) => v.clone(),
Self::F16(v) => v.iter().map(|x| x.to_f32()).collect(),
Self::BF16(v) => v.iter().map(|x| x.to_f32()).collect(),
}
}
#[must_use]
pub fn as_f32_slice(&self) -> Option<&[f32]> {
match self {
Self::F32(v) => Some(v.as_slice()),
Self::F16(_) | Self::BF16(_) => None,
}
}
#[must_use]
pub fn convert(&self, target: VectorPrecision) -> Self {
if self.precision() == target {
return self.clone();
}
Self::from_f32_slice(&self.to_f32_vec(), target)
}
}
impl From<Vec<f32>> for VectorData {
fn from(data: Vec<f32>) -> Self {
Self::F32(data)
}
}
impl From<&[f32]> for VectorData {
fn from(data: &[f32]) -> Self {
Self::F32(data.to_vec())
}
}
/// Returns `v`'s elements as an `f32` slice.
///
/// When `v` is already stored as [`VectorData::F32`] the slice is borrowed.
/// Otherwise the elements are widened into a new `Vec<f32>`. Unlike calling
/// [`VectorData::to_f32_vec`] unconditionally, an `f32` operand is never copied.
fn f32_view(v: &VectorData) -> std::borrow::Cow<'_, [f32]> {
    use std::borrow::Cow;
    match v.as_f32_slice() {
        Some(values) => Cow::Borrowed(values),
        None => Cow::Owned(v.to_f32_vec()),
    }
}

/// Returns `v` itself when it is stored as [`VectorData::F32`].
///
/// Otherwise returns an owned `VectorData::F32` holding its widened elements.
/// Used when the same operand feeds several `VectorData`-level helpers, so it
/// is widened only once.
fn widen_to_f32(v: &VectorData) -> std::borrow::Cow<'_, VectorData> {
    use std::borrow::Cow;
    match v {
        VectorData::F32(_) => Cow::Borrowed(v),
        VectorData::F16(_) | VectorData::BF16(_) => Cow::Owned(VectorData::F32(v.to_f32_vec())),
    }
}

/// Applies `simd_fn` to both operands viewed as `f32` slices.
///
/// Operands already stored as `f32` are passed through by reference. `F16`/`BF16`
/// operands are widened first.
fn with_f32_simd(a: &VectorData, b: &VectorData, simd_fn: fn(&[f32], &[f32]) -> f32) -> f32 {
    let lhs = f32_view(a);
    let rhs = f32_view(b);
    simd_fn(&lhs, &rhs)
}

/// Dot product of `a` and `b`, computed in `f32` by
/// `crate::simd_native::dot_product_native`.
///
/// The operands may have different precisions; half-precision operands are
/// widened first.
// NOTE(review): behaviour for `a.len() != b.len()` is whatever the native kernel
// does — confirm its contract.
#[must_use]
pub fn dot_product(a: &VectorData, b: &VectorData) -> f32 {
    with_f32_simd(a, b, crate::simd_native::dot_product_native)
}

/// Cosine similarity of `a` and `b`.
///
/// If both operands are `F32`, this delegates to
/// `crate::simd_native::cosine_similarity_native`. Otherwise it computes
/// `dot / (|a| * |b|)` in `f32`. That path returns `0.0` when either norm is
/// below `f32::EPSILON`, and clamps the result to `[-1.0, 1.0]` to absorb
/// rounding from the reduced-precision inputs.
#[must_use]
pub fn cosine_similarity(a: &VectorData, b: &VectorData) -> f32 {
    if let (VectorData::F32(va), VectorData::F32(vb)) = (a, b) {
        return crate::simd_native::cosine_similarity_native(va, vb);
    }
    // Widen each half-precision operand once up front so the dot product and
    // both norms share the same f32 buffers instead of each re-widening it.
    let a = widen_to_f32(a);
    let b = widen_to_f32(b);
    let dot = dot_product(&a, &b);
    let norm_a = norm_squared(&a).sqrt();
    let norm_b = norm_squared(&b).sqrt();
    if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}

/// Euclidean distance between `a` and `b`, computed in `f32` by
/// `crate::simd_native::euclidean_native`.
///
/// The operands may have different precisions; half-precision operands are
/// widened first.
#[must_use]
pub fn euclidean_distance(a: &VectorData, b: &VectorData) -> f32 {
    with_f32_simd(a, b, crate::simd_native::euclidean_native)
}

/// Squared norm of `v`: the square of `crate::simd_native::norm_native`
/// evaluated on `v` widened to `f32`.
fn norm_squared(v: &VectorData) -> f32 {
    let norm = crate::simd_native::norm_native(&f32_view(v));
    norm * norm
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tolerance`. The comparison is strict, so
    /// a NaN always fails.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Shorthand for building a vector at the given precision.
    fn vector(values: &[f32], precision: VectorPrecision) -> VectorData {
        VectorData::from_f32_slice(values, precision)
    }

    #[test]
    fn test_vector_data_f16_roundtrip() {
        let original: Vec<f32> = vec![1.0, 2.0, 3.0];
        let decoded = vector(&original, VectorPrecision::F16).to_f32_vec();
        for (expected, actual) in original.iter().zip(&decoded) {
            assert_close(*actual, *expected, 0.01);
        }
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let x = vector(&[1.0, 0.0, 0.0], VectorPrecision::F32);
        let y = vector(&[1.0, 0.0, 0.0], VectorPrecision::F32);
        assert_close(cosine_similarity(&x, &y), 1.0, 1e-5);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x = vector(&[1.0, 0.0, 0.0], VectorPrecision::F32);
        let y = vector(&[0.0, 1.0, 0.0], VectorPrecision::F32);
        assert_close(cosine_similarity(&x, &y), 0.0, 1e-5);
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let x = vector(&[1.0, 2.0, 3.0], VectorPrecision::F32);
        let y = vector(&[1.0, 2.0, 3.0], VectorPrecision::F32);
        assert_close(euclidean_distance(&x, &y), 0.0, 1e-5);
    }

    #[test]
    fn test_euclidean_distance_345() {
        let origin = vector(&[0.0, 0.0], VectorPrecision::F32);
        let point = vector(&[3.0, 4.0], VectorPrecision::F32);
        assert_close(euclidean_distance(&origin, &point), 5.0, 1e-5);
    }

    #[test]
    fn test_norm_squared_f32() {
        let v = vector(&[3.0, 4.0], VectorPrecision::F32);
        assert_close(norm_squared(&v), 25.0, 1e-5);
    }

    #[test]
    fn test_cosine_similarity_f16_vs_f32() {
        let half = vector(&[1.0, 2.0, 3.0], VectorPrecision::F16);
        let full = vector(&[1.0, 2.0, 3.0], VectorPrecision::F32);
        assert_close(cosine_similarity(&half, &full), 1.0, 0.01);
    }

    #[test]
    fn test_cosine_similarity_is_clamped_to_unit_interval() {
        let half = vector(&[1.0, 1.0, 1.0, 1.0], VectorPrecision::F16);
        let brain = vector(&[1.0, 1.0, 1.0, 1.0], VectorPrecision::BF16);
        let sim = cosine_similarity(&half, &brain);
        assert!(
            (-1.0..=1.0).contains(&sim),
            "cosine similarity must be clamped to [-1, 1], got {sim}"
        );
    }
}