1use crate::Vector;
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9
/// Distance metrics supported beyond the basic set.
///
/// Each variant is dispatched by [`ExtendedDistanceMetric::distance`]; the
/// metric-specific implementations live in the `impl` block below. Several
/// variants (Hamming, Jaccard, Dice, Levenshtein) binarize the input by
/// thresholding each component at 0.5, and the probabilistic metrics
/// (KL, Jensen-Shannon, Bhattacharyya, Hellinger) assume the inputs are
/// non-negative, distribution-like vectors — TODO confirm callers enforce
/// that.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ExtendedDistanceMetric {
    // Geometric / Lp-style metrics.
    /// 1 - cosine similarity (1.0 when either vector has zero norm).
    Cosine,
    /// L2 (straight-line) distance.
    Euclidean,
    /// L1 (taxicab) distance.
    Manhattan,
    /// L-infinity distance: maximum absolute component difference.
    Chebyshev,
    /// Generalized Lp distance with exponent `p` (`p` must be positive;
    /// `p = +inf` degenerates to Chebyshev).
    Minkowski { p: f32 },

    // Set / correlation metrics.
    /// Count of positions whose 0.5-thresholded binarizations differ.
    Hamming,
    /// 1 - |A ∩ B| / |A ∪ B| on 0.5-thresholded components.
    Jaccard,
    /// 1 - 2|A ∩ B| / (|A| + |B|) on 0.5-thresholded components.
    Dice,
    /// 1 - Pearson correlation coefficient.
    Pearson,
    /// Pearson distance computed on component ranks.
    Spearman,
    /// 1 - Kendall tau (pairwise concordance).
    Kendall,

    // Probability-distribution divergences.
    /// Kullback-Leibler divergence (asymmetric).
    KLDivergence,
    /// Jensen-Shannon divergence (symmetrized KL).
    JensenShannon,
    /// Bhattacharyya distance, -ln of the Bhattacharyya coefficient.
    Bhattacharyya,
    /// Hellinger distance.
    Hellinger,

    // Edit distances on 0.5-thresholded binary sequences.
    /// Classic Levenshtein edit distance.
    Levenshtein,
    /// Damerau-Levenshtein edit distance (adds transpositions).
    DamerauLevenshtein,

    // Information-theoretic measures.
    /// Mutual information estimate.
    MutualInformation,
    /// Normalized compression distance using a unique-value proxy for
    /// compressed size (not a real compressor).
    NormalizedCompressionDistance,

    // Statistical metrics.
    /// Mahalanobis distance. NOTE(review): the current implementation falls
    /// back to Euclidean (no covariance matrix is available here).
    Mahalanobis,
    /// Bray-Curtis dissimilarity (ecology; assumes non-negative components).
    BrayCurtis,

    /// User-registered metric identified by id; not implemented yet.
    Custom(u32),
}
49
50impl ExtendedDistanceMetric {
51 pub fn distance(&self, a: &Vector, b: &Vector) -> Result<f32> {
53 let a_f32 = a.as_f32();
54 let b_f32 = b.as_f32();
55
56 if a_f32.len() != b_f32.len() {
57 return Err(anyhow::anyhow!(
58 "Vector dimensions must match: {} != {}",
59 a_f32.len(),
60 b_f32.len()
61 ));
62 }
63
64 match self {
65 ExtendedDistanceMetric::Cosine => Self::cosine_distance(&a_f32, &b_f32),
66 ExtendedDistanceMetric::Euclidean => Self::euclidean_distance(&a_f32, &b_f32),
67 ExtendedDistanceMetric::Manhattan => Self::manhattan_distance(&a_f32, &b_f32),
68 ExtendedDistanceMetric::Chebyshev => Self::chebyshev_distance(&a_f32, &b_f32),
69 ExtendedDistanceMetric::Minkowski { p } => Self::minkowski_distance(&a_f32, &b_f32, *p),
70 ExtendedDistanceMetric::Hamming => Self::hamming_distance(&a_f32, &b_f32),
71 ExtendedDistanceMetric::Jaccard => Self::jaccard_distance(&a_f32, &b_f32),
72 ExtendedDistanceMetric::Dice => Self::dice_distance(&a_f32, &b_f32),
73 ExtendedDistanceMetric::Pearson => Self::pearson_distance(&a_f32, &b_f32),
74 ExtendedDistanceMetric::Spearman => Self::spearman_distance(&a_f32, &b_f32),
75 ExtendedDistanceMetric::Kendall => Self::kendall_distance(&a_f32, &b_f32),
76 ExtendedDistanceMetric::KLDivergence => Self::kl_divergence(&a_f32, &b_f32),
77 ExtendedDistanceMetric::JensenShannon => Self::jensen_shannon(&a_f32, &b_f32),
78 ExtendedDistanceMetric::Bhattacharyya => Self::bhattacharyya(&a_f32, &b_f32),
79 ExtendedDistanceMetric::Hellinger => Self::hellinger(&a_f32, &b_f32),
80 ExtendedDistanceMetric::Levenshtein => Self::levenshtein_distance(&a_f32, &b_f32),
81 ExtendedDistanceMetric::DamerauLevenshtein => {
82 Self::damerau_levenshtein_distance(&a_f32, &b_f32)
83 }
84 ExtendedDistanceMetric::MutualInformation => Self::mutual_information(&a_f32, &b_f32),
85 ExtendedDistanceMetric::NormalizedCompressionDistance => Self::ncd(&a_f32, &b_f32),
86 ExtendedDistanceMetric::Mahalanobis => Self::mahalanobis_distance(&a_f32, &b_f32),
87 ExtendedDistanceMetric::BrayCurtis => Self::bray_curtis_distance(&a_f32, &b_f32),
88 ExtendedDistanceMetric::Custom(_id) => {
89 Err(anyhow::anyhow!("Custom metrics not implemented"))
91 }
92 }
93 }
94
95 fn cosine_distance(a: &[f32], b: &[f32]) -> Result<f32> {
98 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
99 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
100 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
101
102 if norm_a == 0.0 || norm_b == 0.0 {
103 return Ok(1.0);
104 }
105
106 Ok(1.0 - (dot / (norm_a * norm_b)))
107 }
108
109 fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
110 let dist: f32 = a
111 .iter()
112 .zip(b)
113 .map(|(x, y)| (x - y).powi(2))
114 .sum::<f32>()
115 .sqrt();
116 Ok(dist)
117 }
118
119 fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
120 let dist: f32 = a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum();
121 Ok(dist)
122 }
123
124 fn chebyshev_distance(a: &[f32], b: &[f32]) -> Result<f32> {
125 let dist = a
126 .iter()
127 .zip(b)
128 .map(|(x, y)| (x - y).abs())
129 .fold(0.0f32, |max, val| max.max(val));
130 Ok(dist)
131 }
132
133 fn minkowski_distance(a: &[f32], b: &[f32], p: f32) -> Result<f32> {
134 if p <= 0.0 {
135 return Err(anyhow::anyhow!("p must be positive for Minkowski distance"));
136 }
137
138 if p == f32::INFINITY {
139 return Self::chebyshev_distance(a, b);
140 }
141
142 let dist = a
143 .iter()
144 .zip(b)
145 .map(|(x, y)| (x - y).abs().powf(p))
146 .sum::<f32>()
147 .powf(1.0 / p);
148 Ok(dist)
149 }
150
151 fn hamming_distance(a: &[f32], b: &[f32]) -> Result<f32> {
154 let threshold = 0.5; let dist = a
156 .iter()
157 .zip(b)
158 .filter(|(x, y)| {
159 let x_bin = **x > threshold;
160 let y_bin = **y > threshold;
161 x_bin != y_bin
162 })
163 .count();
164 Ok(dist as f32)
165 }
166
167 fn jaccard_distance(a: &[f32], b: &[f32]) -> Result<f32> {
168 let threshold = 0.5;
169 let mut intersection = 0;
170 let mut union = 0;
171
172 for (x, y) in a.iter().zip(b) {
173 let x_bin = *x > threshold;
174 let y_bin = *y > threshold;
175
176 if x_bin || y_bin {
177 union += 1;
178 if x_bin && y_bin {
179 intersection += 1;
180 }
181 }
182 }
183
184 if union == 0 {
185 return Ok(0.0);
186 }
187
188 Ok(1.0 - (intersection as f32 / union as f32))
189 }
190
191 fn dice_distance(a: &[f32], b: &[f32]) -> Result<f32> {
192 let threshold = 0.5;
193 let mut intersection = 0;
194 let mut a_count = 0;
195 let mut b_count = 0;
196
197 for (x, y) in a.iter().zip(b) {
198 let x_bin = *x > threshold;
199 let y_bin = *y > threshold;
200
201 if x_bin {
202 a_count += 1;
203 }
204 if y_bin {
205 b_count += 1;
206 }
207 if x_bin && y_bin {
208 intersection += 1;
209 }
210 }
211
212 let sum = a_count + b_count;
213 if sum == 0 {
214 return Ok(0.0);
215 }
216
217 Ok(1.0 - (2.0 * intersection as f32 / sum as f32))
218 }
219
220 fn pearson_distance(a: &[f32], b: &[f32]) -> Result<f32> {
221 let n = a.len() as f32;
222 let mean_a: f32 = a.iter().sum::<f32>() / n;
223 let mean_b: f32 = b.iter().sum::<f32>() / n;
224
225 let mut numerator = 0.0;
226 let mut sum_sq_a = 0.0;
227 let mut sum_sq_b = 0.0;
228
229 for (x, y) in a.iter().zip(b) {
230 let da = x - mean_a;
231 let db = y - mean_b;
232 numerator += da * db;
233 sum_sq_a += da * da;
234 sum_sq_b += db * db;
235 }
236
237 if sum_sq_a == 0.0 || sum_sq_b == 0.0 {
238 return Ok(1.0);
239 }
240
241 let correlation = numerator / (sum_sq_a.sqrt() * sum_sq_b.sqrt());
242 Ok(1.0 - correlation)
243 }
244
245 fn spearman_distance(a: &[f32], b: &[f32]) -> Result<f32> {
246 let rank_a = Self::rank_vector(a);
248 let rank_b = Self::rank_vector(b);
249
250 Self::pearson_distance(&rank_a, &rank_b)
252 }
253
254 fn kendall_distance(a: &[f32], b: &[f32]) -> Result<f32> {
255 let n = a.len();
256 let mut concordant = 0;
257 let mut discordant = 0;
258
259 for i in 0..n {
260 for j in (i + 1)..n {
261 let sign_a = (a[j] - a[i]).signum();
262 let sign_b = (b[j] - b[i]).signum();
263
264 if sign_a * sign_b > 0.0 {
265 concordant += 1;
266 } else if sign_a * sign_b < 0.0 {
267 discordant += 1;
268 }
269 }
270 }
271
272 let total_pairs = (n * (n - 1)) / 2;
273 if total_pairs == 0 {
274 return Ok(0.0);
275 }
276
277 let tau = (concordant - discordant) as f32 / total_pairs as f32;
278 Ok(1.0 - tau)
279 }
280
281 fn kl_divergence(p: &[f32], q: &[f32]) -> Result<f32> {
284 let epsilon = 1e-10;
285 let mut divergence = 0.0;
286
287 for (pi, qi) in p.iter().zip(q) {
288 let pi_safe = pi.max(epsilon);
289 let qi_safe = qi.max(epsilon);
290 divergence += pi_safe * (pi_safe / qi_safe).ln();
291 }
292
293 Ok(divergence)
294 }
295
296 fn jensen_shannon(p: &[f32], q: &[f32]) -> Result<f32> {
297 let m: Vec<f32> = p.iter().zip(q).map(|(pi, qi)| (pi + qi) / 2.0).collect();
298
299 let kl_pm = Self::kl_divergence(p, &m)?;
300 let kl_qm = Self::kl_divergence(q, &m)?;
301
302 Ok((kl_pm + kl_qm) / 2.0)
303 }
304
305 fn bhattacharyya(p: &[f32], q: &[f32]) -> Result<f32> {
306 let bc: f32 = p.iter().zip(q).map(|(pi, qi)| (pi * qi).sqrt()).sum();
307 Ok(-bc.ln())
308 }
309
310 fn hellinger(p: &[f32], q: &[f32]) -> Result<f32> {
311 let sum: f32 = p
312 .iter()
313 .zip(q)
314 .map(|(pi, qi)| (pi.sqrt() - qi.sqrt()).powi(2))
315 .sum();
316 Ok((sum / 2.0).sqrt())
317 }
318
319 #[allow(clippy::needless_range_loop)]
322 fn levenshtein_distance(a: &[f32], b: &[f32]) -> Result<f32> {
323 let threshold = 0.5;
324 let a_bin: Vec<bool> = a.iter().map(|x| *x > threshold).collect();
325 let b_bin: Vec<bool> = b.iter().map(|x| *x > threshold).collect();
326
327 let m = a_bin.len();
328 let n = b_bin.len();
329
330 if m == 0 {
331 return Ok(n as f32);
332 }
333 if n == 0 {
334 return Ok(m as f32);
335 }
336
337 let mut dp = vec![vec![0; n + 1]; m + 1];
338
339 for i in 0..=m {
340 dp[i][0] = i;
341 }
342 for j in 0..=n {
343 dp[0][j] = j;
344 }
345
346 for i in 1..=m {
347 for j in 1..=n {
348 let cost = if a_bin[i - 1] == b_bin[j - 1] { 0 } else { 1 };
349 dp[i][j] = (dp[i - 1][j] + 1)
350 .min(dp[i][j - 1] + 1)
351 .min(dp[i - 1][j - 1] + cost);
352 }
353 }
354
355 Ok(dp[m][n] as f32)
356 }
357
358 fn damerau_levenshtein_distance(a: &[f32], b: &[f32]) -> Result<f32> {
359 Self::levenshtein_distance(a, b)
362 }
363
364 fn mutual_information(a: &[f32], b: &[f32]) -> Result<f32> {
367 let joint_entropy = Self::calculate_entropy(a)? + Self::calculate_entropy(b)?;
370 let individual_entropy = Self::calculate_joint_entropy(a, b)?;
371
372 Ok(joint_entropy - individual_entropy)
373 }
374
375 fn ncd(a: &[f32], b: &[f32]) -> Result<f32> {
376 let ca = Self::estimate_compression_size(a);
379 let cb = Self::estimate_compression_size(b);
380 let cab = Self::estimate_joint_compression_size(a, b);
381
382 let min_c = ca.min(cb);
383 let max_c = ca.max(cb);
384
385 if max_c == 0.0 {
386 return Ok(0.0);
387 }
388
389 Ok((cab - min_c) / max_c)
390 }
391
392 fn mahalanobis_distance(a: &[f32], b: &[f32]) -> Result<f32> {
395 Self::euclidean_distance(a, b)
398 }
399
400 fn bray_curtis_distance(a: &[f32], b: &[f32]) -> Result<f32> {
401 let mut numerator = 0.0;
402 let mut denominator = 0.0;
403
404 for (x, y) in a.iter().zip(b) {
405 numerator += (x - y).abs();
406 denominator += x + y;
407 }
408
409 if denominator == 0.0 {
410 return Ok(0.0);
411 }
412
413 Ok(numerator / denominator)
414 }
415
416 fn rank_vector(v: &[f32]) -> Vec<f32> {
419 let mut indexed: Vec<(usize, f32)> = v.iter().enumerate().map(|(i, &x)| (i, x)).collect();
420 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
421
422 let mut ranks = vec![0.0; v.len()];
423 for (rank, (original_index, _)) in indexed.iter().enumerate() {
424 ranks[*original_index] = rank as f32;
425 }
426
427 ranks
428 }
429
430 fn calculate_entropy(v: &[f32]) -> Result<f32> {
431 let epsilon = 1e-10;
432 let mut entropy = 0.0;
433
434 for &x in v {
435 if x > epsilon {
436 entropy -= x * x.ln();
437 }
438 }
439
440 Ok(entropy)
441 }
442
443 fn calculate_joint_entropy(a: &[f32], b: &[f32]) -> Result<f32> {
444 let epsilon = 1e-10;
445 let mut entropy = 0.0;
446
447 for (x, y) in a.iter().zip(b) {
448 let joint = x * y;
449 if joint > epsilon {
450 entropy -= joint * joint.ln();
451 }
452 }
453
454 Ok(entropy)
455 }
456
457 fn estimate_compression_size(v: &[f32]) -> f32 {
458 let mut sorted = v.to_vec();
461 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
462
463 let mut unique_count = 1;
464 for i in 1..sorted.len() {
465 if (sorted[i] - sorted[i - 1]).abs() > 1e-6 {
466 unique_count += 1;
467 }
468 }
469
470 unique_count as f32
471 }
472
473 fn estimate_joint_compression_size(a: &[f32], b: &[f32]) -> f32 {
474 let mut combined = Vec::with_capacity(a.len() + b.len());
475 combined.extend_from_slice(a);
476 combined.extend_from_slice(b);
477 Self::estimate_compression_size(&combined)
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for approximate floating-point comparisons.
    const EPS: f32 = 0.01;

    #[test]
    fn test_cosine_distance() {
        // Identical directions: cosine distance is (near) zero.
        let u = Vector::new(vec![1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 0.0]);

        let d = ExtendedDistanceMetric::Cosine.distance(&u, &v).unwrap();
        assert!(d < EPS);
    }

    #[test]
    fn test_euclidean_distance() {
        // Classic 3-4-5 right triangle: distance from origin is 5.
        let origin = Vector::new(vec![0.0, 0.0]);
        let point = Vector::new(vec![3.0, 4.0]);

        let d = ExtendedDistanceMetric::Euclidean
            .distance(&origin, &point)
            .unwrap();
        assert!((d - 5.0).abs() < EPS);
    }

    #[test]
    fn test_hamming_distance() {
        // Binarized at 0.5, the vectors differ at exactly two positions.
        let u = Vector::new(vec![1.0, 1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 1.0, 0.0]);

        let d = ExtendedDistanceMetric::Hamming.distance(&u, &v).unwrap();
        assert_eq!(d, 2.0);
    }

    #[test]
    fn test_jaccard_distance() {
        // Overlapping but unequal sets: distance is strictly between 0 and 1.
        let u = Vector::new(vec![1.0, 1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 1.0, 0.0]);

        let d = ExtendedDistanceMetric::Jaccard.distance(&u, &v).unwrap();
        assert!(d > 0.0 && d < 1.0);
    }

    #[test]
    fn test_pearson_distance() {
        // Perfectly correlated (identical) vectors: distance near zero.
        let u = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        let v = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);

        let d = ExtendedDistanceMetric::Pearson.distance(&u, &v).unwrap();
        assert!(d < EPS);
    }

    #[test]
    fn test_manhattan_distance() {
        // |1-4| + |2-5| + |3-6| = 9.
        let u = Vector::new(vec![1.0, 2.0, 3.0]);
        let v = Vector::new(vec![4.0, 5.0, 6.0]);

        let d = ExtendedDistanceMetric::Manhattan.distance(&u, &v).unwrap();
        assert_eq!(d, 9.0);
    }
}