use ndarray::Array1;
/// Distance metric used to compare two embedding vectors.
///
/// `Cosine` is the default variant (see the `#[default]` attribute below).
#[derive(Default)]
pub enum Metric {
/// Cosine distance: one minus the cosine similarity of the two vectors.
#[default]
Cosine,
/// Euclidean (L2) distance: square root of the sum of squared differences.
Euclidean,
/// Manhattan (L1) distance: sum of the absolute element-wise differences.
Manhattan,
/// Jaccard-style distance derived from the number of positions at which
/// the two vectors hold exactly equal values.
Jaccard,
}
impl Metric {
    /// Computes the distance between `emb1` and `emb2` under this metric.
    ///
    /// Dispatches to the private helper that matches the selected variant and
    /// returns its result unchanged.
    ///
    /// Both embeddings are expected to have the same length. The helpers do
    /// not all react the same way when they do not: the iterator-based ones
    /// (Manhattan, Jaccard) only pair up the overlapping prefix, while the
    /// ndarray-based ones (Euclidean, Cosine) rely on ndarray's own shape
    /// checks.
    // NOTE(review): ndarray is expected to panic on incompatible shapes in
    // `dot` and in element-wise subtraction — confirm against the ndarray docs.
    pub(crate) fn distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        match self {
            Metric::Euclidean => self.euclidean_distance(emb1, emb2),
            Metric::Cosine => self.cosine_distance(emb1, emb2),
            Metric::Manhattan => self.manhattan_distance(emb1, emb2),
            Metric::Jaccard => self.jaccard_distance(emb1, emb2),
        }
    }

    /// Euclidean (L2) distance: `sqrt(sum((emb1[i] - emb2[i])^2))`.
    fn euclidean_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        // Element-wise difference; its dot product with itself is the sum of
        // squared differences.
        let diff = emb1 - emb2;
        diff.dot(&diff).sqrt()
    }

    /// Manhattan (L1) distance: `sum(|emb1[i] - emb2[i]|)`.
    ///
    /// Elements are paired with `zip`, so if the lengths differ only the
    /// overlapping prefix contributes to the sum.
    fn manhattan_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        emb1.iter()
            .zip(emb2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum()
    }

    /// Jaccard-style distance: `1 - matches / (len1 + len2 - matches)`.
    ///
    /// `matches` is the number of positions at which both embeddings hold
    /// exactly equal values (plain `f32` equality, so `NaN` never matches).
    /// Returns `0.0` when both embeddings are empty, instead of the `NaN`
    /// that `0 / 0` would produce.
    fn jaccard_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        let intersection = emb1
            .iter()
            .zip(emb2.iter())
            .filter(|(a, b)| a == b)
            .count();
        // `intersection` never exceeds the shorter length, so this cannot
        // underflow.
        let union = emb1.len() + emb2.len() - intersection;
        if union == 0 {
            // Two empty embeddings are treated as identical.
            return 0.0;
        }
        1.0 - (intersection as f32 / union as f32)
    }

    /// Cosine distance: `1 - (emb1 . emb2) / (|emb1| * |emb2|)`.
    ///
    /// Returns `1.0` when the product of the norms is zero (for example when
    /// either embedding is all zeros), because the direction of a zero vector
    /// is undefined and the division would otherwise yield `NaN`. For finite,
    /// non-degenerate inputs the result lies in `[0.0, 2.0]`.
    fn cosine_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        let dot_product = emb1.dot(emb2);
        let norm1 = emb1.dot(emb1).sqrt();
        let norm2 = emb2.dot(emb2).sqrt();
        let denominator = norm1 * norm2;
        if denominator == 0.0 {
            return 1.0;
        }
        // Rounding can push the similarity marginally outside [-1, 1], which
        // would surface as a tiny negative distance; clamp to keep the result
        // a valid distance. `NaN` inputs still propagate as `NaN`.
        let similarity = (dot_product / denominator).clamp(-1.0, 1.0);
        1.0 - similarity
    }
}