#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Cosine distance metric, generic over the numeric element type `T`.
///
/// The struct stores no runtime data: `T` is tracked solely through a
/// zero-sized [`PhantomData`] marker so that the inherent methods and the
/// `Distance` implementation can be parameterised by the element type.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct Cosine<T> {
    /// Zero-sized marker tying the element type `T` to this metric.
    _t: PhantomData<T>,
}
impl<T: Number> Default for Cosine<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number> Cosine<T> {
    /// Creates a new cosine metric. The value holds no state beyond the
    /// `PhantomData` marker for `T`.
    pub fn new() -> Cosine<T> {
        Cosine { _t: PhantomData }
    }

    /// Returns the dot product of `x` and `y` as an `f64`.
    ///
    /// Every element is converted to `f64` *before* it is multiplied, and the
    /// products are summed in `f64`. Accumulating in `f64` rather than in `T`
    /// means the sum cannot overflow a narrow integer element type, and it
    /// matches the `f64` result of [`Self::magnitude`], so both parts of the
    /// cosine ratio are computed at the same precision.
    ///
    /// # Panics
    ///
    /// Panics with `"Input vector sizes are different."` when the shapes of
    /// `x` and `y` differ, and panics if an element's `to_f64()` conversion
    /// yields `None`.
    #[inline]
    pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
        if x.shape() != y.shape() {
            panic!("Input vector sizes are different.");
        }
        // Shapes are equal at this point, so `zip` pairs every element.
        x.iterator(0)
            .zip(y.iterator(0))
            .map(|(&a, &b)| a.to_f64().unwrap() * b.to_f64().unwrap())
            .sum()
    }

    /// Returns the sum of the squared elements of `x`, computed in `f64`.
    ///
    /// # Panics
    ///
    /// Panics if an element's `to_f64()` conversion yields `None`.
    #[inline]
    // Kept as a crate-internal helper; nothing in this impl calls it, hence
    // the lint suppression.
    #[allow(dead_code)]
    pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
        x.iterator(0)
            .map(|&a| {
                let val = a.to_f64().unwrap();
                val * val
            })
            .sum()
    }

    /// Returns the magnitude of `x` as reported by `ArrayView1::norm2`
    /// (presumably the Euclidean / L2 norm — its definition is not in this file).
    #[inline]
    pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
        x.norm2()
    }

    /// Returns the cosine similarity of `x` and `y`, i.e.
    /// `dot(x, y) / (|x| * |y|)`, limited to the closed range `[-1.0, 1.0]`.
    ///
    /// When either vector has zero magnitude the similarity is undefined; in
    /// that case the sentinel `f64::MIN` is returned (unclamped), so a caller
    /// computing `1.0 - similarity` obtains a very large positive value.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::dot_product`].
    #[inline]
    pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
        // Evaluated first so a shape mismatch panics before anything else.
        let dot_product = Self::dot_product(x, y);
        let magnitude_x = Self::magnitude(x);
        let magnitude_y = Self::magnitude(y);
        if magnitude_x == 0.0 || magnitude_y == 0.0 {
            return f64::MIN;
        }
        // Rounding can push the ratio marginally outside [-1, 1] for
        // (anti-)parallel vectors; clamping keeps `1 - similarity` within
        // [0, 2] so the derived distance is never slightly negative.
        // A NaN ratio is passed through unchanged by `clamp`.
        (dot_product / (magnitude_x * magnitude_y)).clamp(-1.0, 1.0)
    }
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
    /// Returns the cosine distance between `x` and `y`, defined here as
    /// `1.0 - cosine_similarity(x, y)`.
    ///
    /// If either vector has zero magnitude, `cosine_similarity` yields
    /// `f64::MIN`, which makes the returned distance a very large positive
    /// number. Panics raised by `cosine_similarity` (e.g. on differing vector
    /// sizes) propagate to the caller.
    fn distance(&self, x: &A, y: &A) -> f64 {
        1.0 - Self::cosine_similarity(x, y)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equal vectors are at distance 0.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_identical_vectors() {
        let u = vec![1, 2, 3];
        let v = vec![1, 2, 3];
        let d: f64 = Cosine::new().distance(&u, &v);
        let err = d.abs();
        assert!(err < 1e-8);
    }

    /// Perpendicular vectors have similarity 0, hence distance 1.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_orthogonal_vectors() {
        let u = vec![1, 0];
        let v = vec![0, 1];
        let d: f64 = Cosine::new().distance(&u, &v);
        let err = (d - 1.0).abs();
        assert!(err < 1e-8);
    }

    /// Vectors pointing in opposite directions have similarity -1, hence
    /// distance 2.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_opposite_vectors() {
        let u = vec![1, 2, 3];
        let v = vec![-1, -2, -3];
        let d: f64 = Cosine::new().distance(&u, &v);
        let err = (d - 2.0).abs();
        assert!(err < 1e-8);
    }

    /// Generic floating-point case: dot = 13 and both norms are sqrt(14).
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_general_case() {
        let u = vec![1.0, 2.0, 3.0];
        let v = vec![2.0, 1.0, 3.0];
        let d: f64 = Cosine::new().distance(&u, &v);
        let expected = 1.0 - (13.0 / 14.0);
        let err = (d - expected).abs();
        assert!(err < 1e-8);
    }

    /// Vectors of different lengths must trigger the size-mismatch panic.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[should_panic(expected = "Input vector sizes are different.")]
    fn cosine_distance_different_sizes() {
        let shorter = vec![1, 2];
        let longer = vec![1, 2, 3];
        let _d: f64 = Cosine::new().distance(&shorter, &longer);
    }

    /// A zero vector has no direction; the metric reports a huge distance
    /// (derived from the `f64::MIN` similarity sentinel).
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_zero_vector() {
        let zeros = vec![0, 0, 0];
        let v = vec![1, 2, 3];
        let d: f64 = Cosine::new().distance(&zeros, &v);
        assert!(d > 1e300);
    }

    /// `f32` inputs should agree with a reference value computed in `f64`
    /// to within 1e-6.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_distance_float_precision() {
        let u = vec![1.0f32, 2.0, 3.0];
        let v = vec![4.0f32, 5.0, 6.0];
        let d: f64 = Cosine::new().distance(&u, &v);

        // Reference computation carried out entirely in f64.
        let dot = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
        let norm_u = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt();
        let norm_v = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt();
        let reference_similarity = dot / (norm_u * norm_v);
        let reference_distance = 1.0 - reference_similarity;

        let err = (d - reference_distance).abs();
        assert!(err < 1e-6);
    }
}