nodedb_types/
vector_distance.rs1#[derive(
10 Debug,
11 Clone,
12 Copy,
13 PartialEq,
14 Eq,
15 serde::Serialize,
16 serde::Deserialize,
17 zerompk::ToMessagePack,
18 zerompk::FromMessagePack,
19)]
20#[non_exhaustive]
21pub enum DistanceMetric {
22 L2 = 0,
24 Cosine = 1,
26 InnerProduct = 2,
28 Manhattan = 3,
30 Chebyshev = 4,
32 Hamming = 5,
34 Jaccard = 6,
36 Pearson = 7,
38}
39
/// Squared Euclidean (L2) distance: `Σ (aᵢ − bᵢ)²`.
///
/// No square root is taken, so the value is the *squared* distance.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Same reach into `b` as indexing `b[i]` for every `i < a.len()`.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    })
}

/// Cosine distance between `a` and `b`: `1 − (a·b) / (‖a‖·‖b‖)`.
///
/// For finite inputs the result is approximately in `[0, 2]`:
/// - `0` means same direction,
/// - `1` means orthogonal,
/// - `2` means opposite.
///
/// It is clamped below at `0.0` to absorb rounding that pushes the similarity
/// slightly above `1`.
///
/// Returns `1.0` when either vector has zero norm, since cosine similarity is
/// undefined there.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    // Take each square root separately. `sqrt(‖a‖² · ‖b‖²)` overflows f32 to
    // infinity once the squared norms multiply past ~3.4e38 (components around
    // 1e10). That would collapse `dot / denom` to 0 and report distance 1.0
    // even for identical vectors.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();

    // Exact zero test rather than an absolute epsilon. Cosine is scale-invariant,
    // so a fixed threshold would misclassify small but non-zero vectors
    // (e.g. norm ~1e-4) as zero vectors.
    if denom == 0.0 {
        return 1.0;
    }

    // NOTE: `f32::max` returns the non-NaN operand, so a NaN ratio (e.g. from
    // NaN or infinite components) yields 0.0 here.
    (1.0 - dot / denom).max(0.0)
}

/// Negated inner (dot) product: `−Σ aᵢ·bᵢ`.
///
/// Negating the dot product makes a smaller value mean "more similar", matching
/// the other distance functions.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    // Seeded with +0.0, so an empty input yields -0.0 after negation.
    let dot = a
        .iter()
        .zip(b)
        .fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    -dot
}

/// Manhattan (L1) distance: `Σ |aᵢ − bᵢ|`.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn manhattan(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        .fold(0.0f32, |total, gap| total + gap)
}

/// Chebyshev (L∞) distance: `max |aᵢ − bᵢ|`, or `0.0` for empty input.
///
/// A NaN component difference never becomes the result, because `f32::max`
/// returns the non-NaN operand.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn chebyshev(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        .fold(0.0f32, f32::max)
}

/// Hamming distance over binarized vectors.
///
/// A component counts as "set" when it is strictly greater than `0.5`. The
/// result is the number of positions whose set/unset state differs, returned
/// as `f32`.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn hamming_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let is_set = |v: f32| v > 0.5;
    let mismatches = a
        .iter()
        .zip(b)
        .fold(0u32, |n, (&x, &y)| n + u32::from(is_set(x) != is_set(y)));
    mismatches as f32
}

/// Jaccard distance over binarized vectors: `1 − |A∩B| / |A∪B|`.
///
/// `A` and `B` are the sets of positions whose component is strictly greater
/// than `0.5`. Returns `0.0` when both sets are empty.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn jaccard(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let (both, either) = a
        .iter()
        .zip(b)
        .fold((0u32, 0u32), |(both, either), (&x, &y)| {
            let (in_a, in_b) = (x > 0.5, y > 0.5);
            (
                both + u32::from(in_a && in_b),
                either + u32::from(in_a || in_b),
            )
        });
    match either {
        0 => 0.0,
        _ => 1.0 - both as f32 / either as f32,
    }
}

/// Pearson correlation distance between `a` and `b`: `1 − r`.
///
/// `r` is the Pearson correlation coefficient. The result is clamped below at
/// `0.0`:
/// - `0` means perfectly positively correlated,
/// - `1` means uncorrelated,
/// - `2` means perfectly anti-correlated.
///
/// Returns `1.0` in two cases:
/// - fewer than two elements are given;
/// - `sqrt(Σ(aᵢ−ā)²) · sqrt(Σ(bᵢ−b̄)²)` is below `f32::EPSILON`, e.g. when
///   either input is constant.
///
/// NOTE(review): that threshold is absolute, so inputs whose deviations are
/// tiny (around 1e-4) are also treated as constant. It is kept because it also
/// absorbs mean-rounding noise for near-constant inputs. A relative threshold
/// would be more principled.
///
/// `a` and `b` should have equal length; this is only checked in debug builds.
/// In release builds a shorter `b` panics and any extra tail of a longer `b`
/// is ignored.
#[inline]
pub fn pearson(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    if a.len() < 2 {
        return 1.0;
    }
    let b = &b[..a.len()];
    let n = a.len() as f32;

    let (sum_a, sum_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32), |(sa, sb), (&x, &y)| (sa + x, sb + y));
    let mean_a = sum_a / n;
    let mean_b = sum_b / n;

    let (cov, var_a, var_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(cov, va, vb), (&x, &y)| {
            let da = x - mean_a;
            let db = y - mean_b;
            (cov + da * db, va + da * da, vb + db * db)
        },
    );

    // Take each square root separately. `sqrt(var_a · var_b)` overflows f32 to
    // infinity for large-magnitude inputs (deviations around 1e10). That would
    // collapse `cov / denom` to 0 and report distance 1.0 even for perfectly
    // correlated data.
    let denom = var_a.sqrt() * var_b.sqrt();
    if denom < f32::EPSILON {
        return 1.0;
    }
    (1.0 - cov / denom).max(0.0)
}

188#[inline]
190pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
191 match metric {
192 DistanceMetric::L2 => l2_squared(a, b),
193 DistanceMetric::Cosine => cosine_distance(a, b),
194 DistanceMetric::InnerProduct => neg_inner_product(a, b),
195 DistanceMetric::Manhattan => manhattan(a, b),
196 DistanceMetric::Chebyshev => chebyshev(a, b),
197 DistanceMetric::Hamming => hamming_f32(a, b),
198 DistanceMetric::Jaccard => jaccard(a, b),
199 DistanceMetric::Pearson => pearson(a, b),
200 }
201}