Skip to main content

nexus_memory_core/
math.rs

/// Compute cosine similarity between two vectors, clamped to `[0.0, 1.0]`.
///
/// The clamp is intentional for embedding similarity: a negative cosine
/// (anti-correlated vectors) is treated as zero similarity for ranking
/// purposes. For the true mathematical range `[-1.0, 1.0]`, use
/// `cosine_similarity_raw`.
///
/// Returns `0.0` for degenerate input, where the cosine is undefined:
/// - the slices have different lengths or are empty,
/// - either vector consists only of zeros,
/// - a component is NaN or infinite.
///
/// The dot product and squared norms are accumulated in `f64`. Squaring an
/// `f32` can overflow to infinity or underflow to zero, which would turn a
/// perfectly valid pair of vectors into `NaN` or a spurious `0.0`; every
/// product of two finite `f32` values is representable in `f64`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both slices: (a·b, |a|², |b|²).
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    // `clamp` propagates NaN, so non-finite inputs must be handled explicitly
    // to keep the documented output range.
    if cosine.is_nan() {
        return 0.0;
    }
    // The clamped value lies in [0, 1], so narrowing to f32 cannot overflow.
    cosine.clamp(0.0, 1.0) as f32
}
19
/// Compute raw cosine similarity in the true mathematical range `[-1.0, 1.0]`.
///
/// Unlike a ranking-oriented similarity, negative values are preserved:
/// `-1.0` means the vectors point in opposite directions, `0.0` means they are
/// orthogonal, and `1.0` means they point in the same direction.
///
/// Returns `0.0` for degenerate input, where the cosine is undefined:
/// - the slices have different lengths or are empty,
/// - either vector consists only of zeros,
/// - a component is NaN or infinite.
///
/// The dot product and squared norms are accumulated in `f64` so that squaring
/// very large or very small `f32` components cannot overflow to infinity or
/// underflow to zero. The quotient is clamped to `[-1.0, 1.0]` because
/// floating-point rounding can otherwise push it marginally outside that range.
pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both slices: (a·b, |a|², |b|²).
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    // `clamp` propagates NaN, so non-finite inputs must be handled explicitly
    // to keep the documented output range.
    if cosine.is_nan() {
        return 0.0;
    }
    // The clamped value lies in [-1, 1], so narrowing to f32 cannot overflow.
    cosine.clamp(-1.0, 1.0) as f32
}
33
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has similarity 1 (up to rounding).
    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
    }

    /// Vectors pointing in opposite directions have a cosine of -1, which the
    /// clamped variant reports as zero similarity.
    #[test]
    fn test_cosine_similarity_opposite_is_clamped_to_zero() {
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    /// An all-zero vector has no direction; the result is defined as 0.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    /// Empty input is degenerate and yields 0.
    #[test]
    fn test_cosine_similarity_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    /// Slices of different lengths cannot be compared and yield 0.
    #[test]
    fn test_cosine_similarity_mismatched_lengths() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
    }
}