use serde::{Deserialize, Serialize, de};
use std::iter::Sum;
use std::ops::{Add, Div, Mul, Neg, Sub};
pub trait Scalar:
Copy
+ Default
+ PartialOrd
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Neg<Output = Self>
+ Sum
+ Serialize
+ de::DeserializeOwned
+ Send
+ Sync
+ 'static
{
fn sqrt(self) -> Self;
fn epsilon() -> Self;
fn one() -> Self;
fn zero() -> Self;
fn total_cmp(&self, other: &Self) -> std::cmp::Ordering;
}
impl Scalar for f32 {
#[inline]
fn sqrt(self) -> Self {
f32::sqrt(self)
}
#[inline]
fn epsilon() -> Self {
f32::EPSILON
}
#[inline]
fn one() -> Self {
1.0
}
#[inline]
fn zero() -> Self {
0.0
}
#[inline]
fn total_cmp(&self, other: &Self) -> std::cmp::Ordering {
f32::total_cmp(self, other)
}
}
impl Scalar for f64 {
#[inline]
fn sqrt(self) -> Self {
f64::sqrt(self)
}
#[inline]
fn epsilon() -> Self {
f64::EPSILON
}
#[inline]
fn one() -> Self {
1.0
}
#[inline]
fn zero() -> Self {
0.0
}
#[inline]
fn total_cmp(&self, other: &Self) -> std::cmp::Ordering {
f64::total_cmp(self, other)
}
}
pub trait DistanceMetric<S: Scalar = f32>:
Clone + Default + Serialize + de::DeserializeOwned + Send + Sync + 'static
{
fn distance(a: &[S], b: &[S]) -> S;
}
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct L2;
impl<S: Scalar> DistanceMetric<S> for L2 {
#[inline]
fn distance(a: &[S], b: &[S]) -> S {
debug_assert_eq!(a.len(), b.len());
a.iter()
.zip(b)
.map(|(x, y)| {
let d = *x - *y;
d * d
})
.sum()
}
}
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Cosine;
impl<S: Scalar> DistanceMetric<S> for Cosine {
#[inline]
fn distance(a: &[S], b: &[S]) -> S {
debug_assert_eq!(a.len(), b.len());
let (mut dot, mut na, mut nb) = (S::zero(), S::zero(), S::zero());
for (x, y) in a.iter().zip(b) {
dot = dot + *x * *y;
na = na + *x * *x;
nb = nb + *y * *y;
}
let denom = na.sqrt() * nb.sqrt();
if denom < S::epsilon() {
S::one()
} else {
S::one() - dot / denom
}
}
}
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct InnerProduct;
impl<S: Scalar> DistanceMetric<S> for InnerProduct {
#[inline]
fn distance(a: &[S], b: &[S]) -> S {
debug_assert_eq!(a.len(), b.len());
-(a.iter().zip(b).map(|(x, y)| *x * *y).sum::<S>())
}
}