Skip to main content

nodedb_types/
vector_distance.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared scalar distance metric implementations.
4//!
5//! Pure scalar functions that the compiler auto-vectorizes. Used by both
6//! `nodedb` (with optional SIMD dispatch) and `nodedb-lite` (scalar only).
7
/// Distance metric selection.
///
/// Each variant corresponds to one scalar kernel in this module, and
/// [`distance`] dispatches on it. Every variant carries an explicit integer
/// discriminant, and the enum is `#[non_exhaustive]`, so code outside this
/// crate must include a wildcard arm when matching on it.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Euclidean (L2) squared distance; no square root is taken
    /// (see [`l2_squared`]).
    L2 = 0,
    /// Cosine distance (1 - cosine_similarity); see [`cosine_distance`].
    Cosine = 1,
    /// Negative inner product (for max-inner-product search via min-heap):
    /// the dot product is negated so that a larger inner product yields a
    /// smaller value (see [`neg_inner_product`]).
    InnerProduct = 2,
    /// Manhattan (L1) distance: sum of absolute differences
    /// (see [`manhattan`]).
    Manhattan = 3,
    /// Chebyshev (L-infinity) distance: max absolute difference
    /// (see [`chebyshev`]).
    Chebyshev = 4,
    /// Hamming distance for binary-like vectors: the number of positions
    /// where the two `> 0.5` tests disagree (see [`hamming_f32`]).
    Hamming = 5,
    /// Jaccard distance for binary-like vectors: elements `> 0.5` are treated
    /// as set members (see [`jaccard`]).
    Jaccard = 6,
    /// Pearson distance: 1 - Pearson correlation coefficient
    /// (see [`pearson`]).
    Pearson = 7,
}
39
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Accumulates the squared element-wise differences in index order and
/// returns the sum without taking a square root. Both slices are expected to
/// have the same length, which is verified by a debug assertion only.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
51
/// Cosine distance: `1.0 - cosine_similarity(a, b)`.
///
/// Returns 0.0 for identical directions, 1.0 for orthogonal vectors and 2.0
/// for opposite directions. The result is clamped from below at 0.0 so that
/// rounding error never produces a negative distance.
///
/// If the product of the two magnitudes is below `f32::EPSILON` (for example
/// when either input is the zero vector) the similarity is undefined and 1.0
/// is returned.
///
/// Both slices are expected to have the same length, which is verified by a
/// debug assertion only.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // `norm_a * norm_b` can overflow to infinity (or underflow to zero) in
    // f32 even when both squared norms are finite, which used to collapse the
    // similarity to 0. The product of two f32 values is exact in f64 and its
    // square root always fits back into f32, so widen just for this step.
    let denom = (f64::from(norm_a) * f64::from(norm_b)).sqrt() as f32;
    if denom < f32::EPSILON {
        return 1.0;
    }
    (1.0 - (dot / denom)).max(0.0)
}
74
/// Negated dot product of `a` and `b`.
///
/// Negating the inner product lets a min-heap based search rank the largest
/// inner products first. Both slices are expected to have the same length,
/// which is verified by a debug assertion only.
#[inline]
pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    -dot
}
85
/// Manhattan (L1) distance between `a` and `b`.
///
/// Adds up the absolute element-wise differences in index order. Both slices
/// are expected to have the same length, which is verified by a debug
/// assertion only.
#[inline]
pub fn manhattan(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        .fold(0.0f32, |total, gap| total + gap)
}
96
/// Chebyshev (L-infinity) distance between `a` and `b`.
///
/// Returns the largest absolute element-wise difference, or 0.0 for empty
/// input. A difference that is NaN never replaces the running maximum. Both
/// slices are expected to have the same length, which is verified by a debug
/// assertion only.
#[inline]
pub fn chebyshev(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        // `f32::max` keeps the non-NaN operand, matching a plain `>` update.
        .fold(0.0f32, f32::max)
}
110
/// Hamming distance for `f32` vectors.
///
/// Each element is reduced to a bit by testing `value > 0.5`; the result is
/// the number of positions where the two bits disagree, returned as `f32`.
/// Both slices are expected to have the same length, which is verified by a
/// debug assertion only.
#[inline]
pub fn hamming_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];
    let mismatches = a.iter().zip(b).fold(0u32, |acc, (&x, &y)| {
        let differs = (x > 0.5) != (y > 0.5);
        acc + u32::from(differs)
    });
    mismatches as f32
}
125
/// Jaccard distance for `f32` vectors.
///
/// An index belongs to a vector's set when its value is `> 0.5`. The result
/// is `1 - |intersection| / |union|`; when neither vector has any member the
/// union is empty and 0.0 is returned. Both slices are expected to have the
/// same length, which is verified by a debug assertion only.
#[inline]
pub fn jaccard(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];

    let mut in_either = 0u32;
    let mut in_both = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        match (x > 0.5, y > 0.5) {
            (false, false) => {}
            (true, true) => {
                in_either += 1;
                in_both += 1;
            }
            _ => in_either += 1,
        }
    }

    // Two empty sets are considered identical.
    if in_either == 0 {
        return 0.0;
    }
    1.0 - (in_both as f32 / in_either as f32)
}
150
/// Pearson distance: 1 - Pearson correlation coefficient.
///
/// Returns 0.0 for perfectly correlated, 1.0 for uncorrelated, ~2.0 for
/// perfectly anti-correlated. The result is clamped from below at 0.0 so that
/// rounding error never produces a negative distance.
///
/// The correlation is undefined, and 1.0 is returned, when the input has
/// fewer than two elements or when the product of the two centred magnitudes
/// is below `f32::EPSILON` (for example when either input is constant).
///
/// Both slices are expected to have the same length, which is verified by a
/// debug assertion only.
#[inline]
pub fn pearson(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let n = a.len() as f32;
    if n < 2.0 {
        return 1.0;
    }
    // Only the first `a.len()` elements of `b` participate; a shorter `b`
    // panics here instead of being silently truncated by `zip`.
    let b = &b[..a.len()];

    // First pass: means of both inputs.
    let mut sum_a = 0.0f32;
    let mut sum_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        sum_a += x;
        sum_b += y;
    }
    let mean_a = sum_a / n;
    let mean_b = sum_b / n;

    // Second pass: unnormalised covariance and variances of the centred data.
    let mut cov = 0.0f32;
    let mut var_a = 0.0f32;
    let mut var_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        let da = x - mean_a;
        let db = y - mean_b;
        cov += da * db;
        var_a += da * da;
        var_b += db * db;
    }

    // `var_a * var_b` can overflow to infinity (or underflow to zero) in f32
    // even when both variances are finite, which used to collapse the
    // correlation to 0. The product of two f32 values is exact in f64 and its
    // square root always fits back into f32, so widen just for this step.
    let denom = (f64::from(var_a) * f64::from(var_b)).sqrt() as f32;
    if denom < f32::EPSILON {
        return 1.0;
    }
    (1.0 - cov / denom).max(0.0)
}
187
188/// Compute distance using the specified metric (scalar dispatch).
189#[inline]
190pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
191    match metric {
192        DistanceMetric::L2 => l2_squared(a, b),
193        DistanceMetric::Cosine => cosine_distance(a, b),
194        DistanceMetric::InnerProduct => neg_inner_product(a, b),
195        DistanceMetric::Manhattan => manhattan(a, b),
196        DistanceMetric::Chebyshev => chebyshev(a, b),
197        DistanceMetric::Hamming => hamming_f32(a, b),
198        DistanceMetric::Jaccard => jaccard(a, b),
199        DistanceMetric::Pearson => pearson(a, b),
200    }
201}