use crate::vector::vector::VectorType;
/// Computes the Euclidean (L2) distance between two vectors of the same variant.
///
/// `Float32`/`Float32` and `Int8`/`Int8` pairs are supported. Each component is
/// widened to `f64` before subtracting, so the squared differences are
/// accumulated in double precision. Two empty vectors of a supported variant
/// yield `0.0`.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when:
/// - the two vectors have different lengths,
/// - both arguments are bit vectors,
/// - the variants differ, or the pairing is otherwise unsupported.
pub fn distance_l2(a: &VectorType, b: &VectorType) -> Result<f64, String> {
    match (a, b) {
        (VectorType::Float32(a_vec), VectorType::Float32(b_vec)) => {
            if a_vec.len() != b_vec.len() {
                return Err("Vector length mismatch".to_string());
            }
            let sum: f64 = a_vec
                .iter()
                .zip(b_vec.iter())
                .map(|(x, y)| (*x as f64 - *y as f64).powi(2))
                .sum();
            Ok(sum.sqrt())
        }
        (VectorType::Int8(a_vec), VectorType::Int8(b_vec)) => {
            if a_vec.len() != b_vec.len() {
                return Err("Vector length mismatch".to_string());
            }
            let sum: f64 = a_vec
                .iter()
                .zip(b_vec.iter())
                .map(|(&x, &y)| (x as f64 - y as f64).powi(2))
                .sum();
            Ok(sum.sqrt())
        }
        (VectorType::Bit(_), VectorType::Bit(_)) => {
            Err("Cannot calculate L2 distance between bitvectors".to_string())
        }
        // Mixed pairings (e.g. Float32 vs Int8) are not bit vectors, so they get
        // their own message instead of the bitvector one.
        _ => Err(
            "Cannot calculate L2 distance between mismatched or unsupported vector types"
                .to_string(),
        ),
    }
}
/// Computes the cosine distance (`1 - cosine similarity`) between two vectors
/// of the same variant.
///
/// `Float32`/`Float32` and `Int8`/`Int8` pairs are supported. Float components
/// are widened to `f64`; `Int8` dot products and squared norms are accumulated
/// in `i64` before being converted to `f64`.
///
/// The similarity is clamped to `[-1, 1]` before it is subtracted from one, so
/// a finite result always lies in `[0, 2]` even when rounding pushes the raw
/// quotient marginally past either bound.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when:
/// - the two vectors have different lengths,
/// - either vector has a zero norm (this includes empty vectors),
/// - both arguments are bit vectors,
/// - the variants differ, or the pairing is otherwise unsupported.
pub fn distance_cosine(a: &VectorType, b: &VectorType) -> Result<f64, String> {
    let (dot, norm_a, norm_b): (f64, f64, f64) = match (a, b) {
        (VectorType::Float32(a_vec), VectorType::Float32(b_vec)) => {
            if a_vec.len() != b_vec.len() {
                return Err("Vector length mismatch".to_string());
            }
            let dot: f64 = a_vec
                .iter()
                .zip(b_vec.iter())
                .map(|(x, y)| (*x as f64) * (*y as f64))
                .sum();
            let norm_a = a_vec
                .iter()
                .map(|x| (*x as f64) * (*x as f64))
                .sum::<f64>()
                .sqrt();
            let norm_b = b_vec
                .iter()
                .map(|x| (*x as f64) * (*x as f64))
                .sum::<f64>()
                .sqrt();
            (dot, norm_a, norm_b)
        }
        (VectorType::Int8(a_vec), VectorType::Int8(b_vec)) => {
            if a_vec.len() != b_vec.len() {
                return Err("Vector length mismatch".to_string());
            }
            // Integer accumulation keeps the dot product and squared norms exact.
            let dot: i64 = a_vec
                .iter()
                .zip(b_vec.iter())
                .map(|(&x, &y)| (x as i64) * (y as i64))
                .sum();
            let norm_a = (a_vec.iter().map(|&x| (x as i64).pow(2)).sum::<i64>() as f64).sqrt();
            let norm_b = (b_vec.iter().map(|&x| (x as i64).pow(2)).sum::<i64>() as f64).sqrt();
            (dot as f64, norm_a, norm_b)
        }
        (VectorType::Bit(_), VectorType::Bit(_)) => {
            return Err("Cannot calculate cosine distance between bitvectors".to_string());
        }
        // Mixed pairings (e.g. Float32 vs Int8) are not bit vectors, so they get
        // their own message instead of the bitvector one.
        _ => {
            return Err(
                "Cannot calculate cosine distance between mismatched or unsupported vector types"
                    .to_string(),
            );
        }
    };

    if norm_a == 0.0 || norm_b == 0.0 {
        return Err("Cannot calculate cosine distance with zero vector".to_string());
    }

    // sqrt(x) * sqrt(x) is not guaranteed to round back to x, so the quotient
    // can land a few ULPs outside [-1, 1]; clamp to keep the distance in [0, 2].
    // `clamp` propagates NaN unchanged.
    let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    Ok(1.0 - similarity)
}
/// Computes the Hamming distance between two bit vectors.
///
/// Corresponding elements are XOR-ed and the set bits of each result are
/// counted; the total number of differing bits is returned.
///
/// # Errors
///
/// Returns `Err` when either argument is not a `Bit` vector, or when the two
/// bit vectors have different lengths.
pub fn distance_hamming(a: &VectorType, b: &VectorType) -> Result<u32, String> {
    // Reject every pairing other than Bit/Bit up front.
    let (lhs, rhs) = match (a, b) {
        (VectorType::Bit(lhs), VectorType::Bit(rhs)) => (lhs, rhs),
        _ => {
            return Err("Cannot calculate hamming distance between non-bit vectors".to_string());
        }
    };

    if lhs.len() != rhs.len() {
        return Err("Vector length mismatch".to_string());
    }

    let differing_bits: u32 = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(l, r)| (l ^ r).count_ones())
        .sum();
    Ok(differing_bits)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::vector::VectorType;

    /// Absolute tolerance used when comparing floating-point distances.
    const EPSILON: f64 = 0.001;

    #[test]
    fn test_l2_distance() {
        let a = VectorType::Float32(vec![1.0, 1.0]);
        let b = VectorType::Float32(vec![2.0, 2.0]);
        let d = distance_l2(&a, &b).unwrap();
        assert!((d - 1.41421356237).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_distance() {
        let a = VectorType::Float32(vec![1.0, 0.0]);
        let b = VectorType::Float32(vec![1.0, 0.0]);
        let d = distance_cosine(&a, &b).unwrap();
        assert!(d < EPSILON);
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = VectorType::Float32(vec![1.0, 0.0]);
        let b = VectorType::Float32(vec![-1.0, 0.0]);
        let d = distance_cosine(&a, &b).unwrap();
        assert!((d - 2.0).abs() < EPSILON);
    }

    #[test]
    fn test_hamming_distance() {
        let a = VectorType::Bit(vec![0b11110000]);
        let b = VectorType::Bit(vec![0b00001111]);
        let d = distance_hamming(&a, &b).unwrap();
        assert_eq!(d, 8);
    }

    #[test]
    fn test_hamming_distance_same() {
        let a = VectorType::Bit(vec![0b11111111]);
        let b = VectorType::Bit(vec![0b11111111]);
        let d = distance_hamming(&a, &b).unwrap();
        assert_eq!(d, 0);
    }

    #[test]
    fn test_int8_l2() {
        let a = VectorType::Int8(vec![1, 2, 3]);
        let b = VectorType::Int8(vec![4, 5, 6]);
        let d = distance_l2(&a, &b).unwrap();
        assert!((d - 5.196152422706632).abs() < EPSILON);
    }

    // The Int8 cosine path had no coverage: identical vectors should be at
    // distance ~0 and orthogonal vectors at distance ~1.
    #[test]
    fn test_int8_cosine() {
        let a = VectorType::Int8(vec![1, 2, 3]);
        let b = VectorType::Int8(vec![1, 2, 3]);
        let d = distance_cosine(&a, &b).unwrap();
        assert!(d.abs() < EPSILON);

        let a = VectorType::Int8(vec![1, 0]);
        let b = VectorType::Int8(vec![0, 1]);
        let d = distance_cosine(&a, &b).unwrap();
        assert!((d - 1.0).abs() < EPSILON);
    }

    // Every distance function must reject operands of different lengths.
    #[test]
    fn test_length_mismatch_is_rejected() {
        let expected = "Vector length mismatch".to_string();

        let a = VectorType::Float32(vec![1.0]);
        let b = VectorType::Float32(vec![1.0, 2.0]);
        assert_eq!(distance_l2(&a, &b), Err(expected.clone()));
        assert_eq!(distance_cosine(&a, &b), Err(expected.clone()));

        let a = VectorType::Int8(vec![1]);
        let b = VectorType::Int8(vec![1, 2]);
        assert_eq!(distance_l2(&a, &b), Err(expected.clone()));
        assert_eq!(distance_cosine(&a, &b), Err(expected.clone()));

        let a = VectorType::Bit(vec![0b1]);
        let b = VectorType::Bit(vec![0b1, 0b0]);
        assert_eq!(distance_hamming(&a, &b), Err(expected));
    }

    // Cosine distance is undefined when either operand has zero norm.
    #[test]
    fn test_cosine_zero_vector_is_rejected() {
        let expected = "Cannot calculate cosine distance with zero vector".to_string();

        let zero = VectorType::Float32(vec![0.0, 0.0]);
        let unit = VectorType::Float32(vec![1.0, 0.0]);
        assert_eq!(distance_cosine(&zero, &unit), Err(expected.clone()));
        assert_eq!(distance_cosine(&unit, &zero), Err(expected.clone()));

        let zero = VectorType::Int8(vec![0, 0]);
        let unit = VectorType::Int8(vec![1, 0]);
        assert_eq!(distance_cosine(&zero, &unit), Err(expected));
    }

    // Unsupported variant pairings must yield an error rather than a value.
    #[test]
    fn test_unsupported_type_combinations_are_rejected() {
        let bits = VectorType::Bit(vec![0b1010]);
        let floats = VectorType::Float32(vec![1.0]);
        let ints = VectorType::Int8(vec![1]);

        assert!(distance_l2(&bits, &bits).is_err());
        assert!(distance_cosine(&bits, &bits).is_err());
        assert!(distance_l2(&floats, &ints).is_err());
        assert!(distance_cosine(&floats, &ints).is_err());
        assert!(distance_hamming(&floats, &floats).is_err());
        assert!(distance_hamming(&bits, &ints).is_err());
    }
}