// oxirs_vec/distance_metrics.rs

//! Extended distance metrics for vector similarity
//!
//! This module provides a comprehensive collection of distance metrics
//! including specialized metrics for different data types and use cases.
use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::Vector;

/// Extended distance metric types.
///
/// Groups standard geometric metrics, set/binary similarity metrics,
/// rank-correlation metrics, statistical divergences, edit distances,
/// information-theoretic measures, and a hook for user-defined metrics.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ExtendedDistanceMetric {
    // Standard metrics
    Cosine,
    Euclidean,
    Manhattan,
    Chebyshev,
    /// Minkowski distance of order `p` (p = 1: Manhattan, p = 2:
    /// Euclidean, p = infinity: Chebyshev).
    Minkowski { p: f32 },

    // Specialized metrics.
    // Hamming/Jaccard/Dice binarize each component at a 0.5 threshold.
    Hamming,
    Jaccard,
    Dice,
    // Pearson/Spearman/Kendall are 1 - correlation coefficient.
    Pearson,
    Spearman,
    Kendall,

    // Statistical metrics — inputs are interpreted as probability
    // distributions (they are not normalized by the implementation).
    KLDivergence,
    JensenShannon,
    Bhattacharyya,
    Hellinger,

    // Edit distance metrics (computed on 0.5-threshold binarized vectors)
    Levenshtein,
    DamerauLevenshtein,

    // Information-theoretic metrics
    MutualInformation,
    NormalizedCompressionDistance,

    // Specialized for embeddings
    Mahalanobis,
    BrayCurtis,

    // Custom metric (user-defined)
    Custom(u32), // ID for custom metric lookup

}
50impl ExtendedDistanceMetric {
51    /// Calculate distance between two vectors
52    pub fn distance(&self, a: &Vector, b: &Vector) -> Result<f32> {
53        let a_f32 = a.as_f32();
54        let b_f32 = b.as_f32();
55
56        if a_f32.len() != b_f32.len() {
57            return Err(anyhow::anyhow!(
58                "Vector dimensions must match: {} != {}",
59                a_f32.len(),
60                b_f32.len()
61            ));
62        }
63
64        match self {
65            ExtendedDistanceMetric::Cosine => Self::cosine_distance(&a_f32, &b_f32),
66            ExtendedDistanceMetric::Euclidean => Self::euclidean_distance(&a_f32, &b_f32),
67            ExtendedDistanceMetric::Manhattan => Self::manhattan_distance(&a_f32, &b_f32),
68            ExtendedDistanceMetric::Chebyshev => Self::chebyshev_distance(&a_f32, &b_f32),
69            ExtendedDistanceMetric::Minkowski { p } => Self::minkowski_distance(&a_f32, &b_f32, *p),
70            ExtendedDistanceMetric::Hamming => Self::hamming_distance(&a_f32, &b_f32),
71            ExtendedDistanceMetric::Jaccard => Self::jaccard_distance(&a_f32, &b_f32),
72            ExtendedDistanceMetric::Dice => Self::dice_distance(&a_f32, &b_f32),
73            ExtendedDistanceMetric::Pearson => Self::pearson_distance(&a_f32, &b_f32),
74            ExtendedDistanceMetric::Spearman => Self::spearman_distance(&a_f32, &b_f32),
75            ExtendedDistanceMetric::Kendall => Self::kendall_distance(&a_f32, &b_f32),
76            ExtendedDistanceMetric::KLDivergence => Self::kl_divergence(&a_f32, &b_f32),
77            ExtendedDistanceMetric::JensenShannon => Self::jensen_shannon(&a_f32, &b_f32),
78            ExtendedDistanceMetric::Bhattacharyya => Self::bhattacharyya(&a_f32, &b_f32),
79            ExtendedDistanceMetric::Hellinger => Self::hellinger(&a_f32, &b_f32),
80            ExtendedDistanceMetric::Levenshtein => Self::levenshtein_distance(&a_f32, &b_f32),
81            ExtendedDistanceMetric::DamerauLevenshtein => {
82                Self::damerau_levenshtein_distance(&a_f32, &b_f32)
83            }
84            ExtendedDistanceMetric::MutualInformation => Self::mutual_information(&a_f32, &b_f32),
85            ExtendedDistanceMetric::NormalizedCompressionDistance => Self::ncd(&a_f32, &b_f32),
86            ExtendedDistanceMetric::Mahalanobis => Self::mahalanobis_distance(&a_f32, &b_f32),
87            ExtendedDistanceMetric::BrayCurtis => Self::bray_curtis_distance(&a_f32, &b_f32),
88            ExtendedDistanceMetric::Custom(_id) => {
89                // Custom metrics would be looked up from a registry
90                Err(anyhow::anyhow!("Custom metrics not implemented"))
91            }
92        }
93    }
94
95    // Standard distance metrics
96
97    fn cosine_distance(a: &[f32], b: &[f32]) -> Result<f32> {
98        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
99        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
100        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
101
102        if norm_a == 0.0 || norm_b == 0.0 {
103            return Ok(1.0);
104        }
105
106        Ok(1.0 - (dot / (norm_a * norm_b)))
107    }
108
109    fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
110        let dist: f32 = a
111            .iter()
112            .zip(b)
113            .map(|(x, y)| (x - y).powi(2))
114            .sum::<f32>()
115            .sqrt();
116        Ok(dist)
117    }
118
119    fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
120        let dist: f32 = a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum();
121        Ok(dist)
122    }
123
124    fn chebyshev_distance(a: &[f32], b: &[f32]) -> Result<f32> {
125        let dist = a
126            .iter()
127            .zip(b)
128            .map(|(x, y)| (x - y).abs())
129            .fold(0.0f32, |max, val| max.max(val));
130        Ok(dist)
131    }
132
133    fn minkowski_distance(a: &[f32], b: &[f32], p: f32) -> Result<f32> {
134        if p <= 0.0 {
135            return Err(anyhow::anyhow!("p must be positive for Minkowski distance"));
136        }
137
138        if p == f32::INFINITY {
139            return Self::chebyshev_distance(a, b);
140        }
141
142        let dist = a
143            .iter()
144            .zip(b)
145            .map(|(x, y)| (x - y).abs().powf(p))
146            .sum::<f32>()
147            .powf(1.0 / p);
148        Ok(dist)
149    }
150
151    // Specialized distance metrics
152
153    fn hamming_distance(a: &[f32], b: &[f32]) -> Result<f32> {
154        let threshold = 0.5; // Threshold for binary conversion
155        let dist = a
156            .iter()
157            .zip(b)
158            .filter(|(x, y)| {
159                let x_bin = **x > threshold;
160                let y_bin = **y > threshold;
161                x_bin != y_bin
162            })
163            .count();
164        Ok(dist as f32)
165    }
166
167    fn jaccard_distance(a: &[f32], b: &[f32]) -> Result<f32> {
168        let threshold = 0.5;
169        let mut intersection = 0;
170        let mut union = 0;
171
172        for (x, y) in a.iter().zip(b) {
173            let x_bin = *x > threshold;
174            let y_bin = *y > threshold;
175
176            if x_bin || y_bin {
177                union += 1;
178                if x_bin && y_bin {
179                    intersection += 1;
180                }
181            }
182        }
183
184        if union == 0 {
185            return Ok(0.0);
186        }
187
188        Ok(1.0 - (intersection as f32 / union as f32))
189    }
190
191    fn dice_distance(a: &[f32], b: &[f32]) -> Result<f32> {
192        let threshold = 0.5;
193        let mut intersection = 0;
194        let mut a_count = 0;
195        let mut b_count = 0;
196
197        for (x, y) in a.iter().zip(b) {
198            let x_bin = *x > threshold;
199            let y_bin = *y > threshold;
200
201            if x_bin {
202                a_count += 1;
203            }
204            if y_bin {
205                b_count += 1;
206            }
207            if x_bin && y_bin {
208                intersection += 1;
209            }
210        }
211
212        let sum = a_count + b_count;
213        if sum == 0 {
214            return Ok(0.0);
215        }
216
217        Ok(1.0 - (2.0 * intersection as f32 / sum as f32))
218    }
219
220    fn pearson_distance(a: &[f32], b: &[f32]) -> Result<f32> {
221        let n = a.len() as f32;
222        let mean_a: f32 = a.iter().sum::<f32>() / n;
223        let mean_b: f32 = b.iter().sum::<f32>() / n;
224
225        let mut numerator = 0.0;
226        let mut sum_sq_a = 0.0;
227        let mut sum_sq_b = 0.0;
228
229        for (x, y) in a.iter().zip(b) {
230            let da = x - mean_a;
231            let db = y - mean_b;
232            numerator += da * db;
233            sum_sq_a += da * da;
234            sum_sq_b += db * db;
235        }
236
237        if sum_sq_a == 0.0 || sum_sq_b == 0.0 {
238            return Ok(1.0);
239        }
240
241        let correlation = numerator / (sum_sq_a.sqrt() * sum_sq_b.sqrt());
242        Ok(1.0 - correlation)
243    }
244
245    fn spearman_distance(a: &[f32], b: &[f32]) -> Result<f32> {
246        // Convert to ranks
247        let rank_a = Self::rank_vector(a);
248        let rank_b = Self::rank_vector(b);
249
250        // Calculate Pearson on ranks
251        Self::pearson_distance(&rank_a, &rank_b)
252    }
253
254    fn kendall_distance(a: &[f32], b: &[f32]) -> Result<f32> {
255        let n = a.len();
256        let mut concordant = 0;
257        let mut discordant = 0;
258
259        for i in 0..n {
260            for j in (i + 1)..n {
261                let sign_a = (a[j] - a[i]).signum();
262                let sign_b = (b[j] - b[i]).signum();
263
264                if sign_a * sign_b > 0.0 {
265                    concordant += 1;
266                } else if sign_a * sign_b < 0.0 {
267                    discordant += 1;
268                }
269            }
270        }
271
272        let total_pairs = (n * (n - 1)) / 2;
273        if total_pairs == 0 {
274            return Ok(0.0);
275        }
276
277        let tau = (concordant - discordant) as f32 / total_pairs as f32;
278        Ok(1.0 - tau)
279    }
280
281    // Statistical distance metrics
282
283    fn kl_divergence(p: &[f32], q: &[f32]) -> Result<f32> {
284        let epsilon = 1e-10;
285        let mut divergence = 0.0;
286
287        for (pi, qi) in p.iter().zip(q) {
288            let pi_safe = pi.max(epsilon);
289            let qi_safe = qi.max(epsilon);
290            divergence += pi_safe * (pi_safe / qi_safe).ln();
291        }
292
293        Ok(divergence)
294    }
295
296    fn jensen_shannon(p: &[f32], q: &[f32]) -> Result<f32> {
297        let m: Vec<f32> = p.iter().zip(q).map(|(pi, qi)| (pi + qi) / 2.0).collect();
298
299        let kl_pm = Self::kl_divergence(p, &m)?;
300        let kl_qm = Self::kl_divergence(q, &m)?;
301
302        Ok((kl_pm + kl_qm) / 2.0)
303    }
304
305    fn bhattacharyya(p: &[f32], q: &[f32]) -> Result<f32> {
306        let bc: f32 = p.iter().zip(q).map(|(pi, qi)| (pi * qi).sqrt()).sum();
307        Ok(-bc.ln())
308    }
309
310    fn hellinger(p: &[f32], q: &[f32]) -> Result<f32> {
311        let sum: f32 = p
312            .iter()
313            .zip(q)
314            .map(|(pi, qi)| (pi.sqrt() - qi.sqrt()).powi(2))
315            .sum();
316        Ok((sum / 2.0).sqrt())
317    }
318
319    // Edit distance metrics
320
321    #[allow(clippy::needless_range_loop)]
322    fn levenshtein_distance(a: &[f32], b: &[f32]) -> Result<f32> {
323        let threshold = 0.5;
324        let a_bin: Vec<bool> = a.iter().map(|x| *x > threshold).collect();
325        let b_bin: Vec<bool> = b.iter().map(|x| *x > threshold).collect();
326
327        let m = a_bin.len();
328        let n = b_bin.len();
329
330        if m == 0 {
331            return Ok(n as f32);
332        }
333        if n == 0 {
334            return Ok(m as f32);
335        }
336
337        let mut dp = vec![vec![0; n + 1]; m + 1];
338
339        for i in 0..=m {
340            dp[i][0] = i;
341        }
342        for j in 0..=n {
343            dp[0][j] = j;
344        }
345
346        for i in 1..=m {
347            for j in 1..=n {
348                let cost = if a_bin[i - 1] == b_bin[j - 1] { 0 } else { 1 };
349                dp[i][j] = (dp[i - 1][j] + 1)
350                    .min(dp[i][j - 1] + 1)
351                    .min(dp[i - 1][j - 1] + cost);
352            }
353        }
354
355        Ok(dp[m][n] as f32)
356    }
357
358    fn damerau_levenshtein_distance(a: &[f32], b: &[f32]) -> Result<f32> {
359        // Simplified Damerau-Levenshtein (allows transpositions)
360        // Full implementation is complex, this is an approximation
361        Self::levenshtein_distance(a, b)
362    }
363
364    // Information-theoretic metrics
365
366    fn mutual_information(a: &[f32], b: &[f32]) -> Result<f32> {
367        // Simplified mutual information calculation
368        // Full implementation would require histogram binning
369        let joint_entropy = Self::calculate_entropy(a)? + Self::calculate_entropy(b)?;
370        let individual_entropy = Self::calculate_joint_entropy(a, b)?;
371
372        Ok(joint_entropy - individual_entropy)
373    }
374
375    fn ncd(a: &[f32], b: &[f32]) -> Result<f32> {
376        // Normalized Compression Distance
377        // Approximation using simple compression ratios
378        let ca = Self::estimate_compression_size(a);
379        let cb = Self::estimate_compression_size(b);
380        let cab = Self::estimate_joint_compression_size(a, b);
381
382        let min_c = ca.min(cb);
383        let max_c = ca.max(cb);
384
385        if max_c == 0.0 {
386            return Ok(0.0);
387        }
388
389        Ok((cab - min_c) / max_c)
390    }
391
392    // Advanced distance metrics
393
394    fn mahalanobis_distance(a: &[f32], b: &[f32]) -> Result<f32> {
395        // Simplified Mahalanobis distance (assuming identity covariance)
396        // Full implementation would require covariance matrix
397        Self::euclidean_distance(a, b)
398    }
399
400    fn bray_curtis_distance(a: &[f32], b: &[f32]) -> Result<f32> {
401        let mut numerator = 0.0;
402        let mut denominator = 0.0;
403
404        for (x, y) in a.iter().zip(b) {
405            numerator += (x - y).abs();
406            denominator += x + y;
407        }
408
409        if denominator == 0.0 {
410            return Ok(0.0);
411        }
412
413        Ok(numerator / denominator)
414    }
415
416    // Helper functions
417
418    fn rank_vector(v: &[f32]) -> Vec<f32> {
419        let mut indexed: Vec<(usize, f32)> = v.iter().enumerate().map(|(i, &x)| (i, x)).collect();
420        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
421
422        let mut ranks = vec![0.0; v.len()];
423        for (rank, (original_index, _)) in indexed.iter().enumerate() {
424            ranks[*original_index] = rank as f32;
425        }
426
427        ranks
428    }
429
430    fn calculate_entropy(v: &[f32]) -> Result<f32> {
431        let epsilon = 1e-10;
432        let mut entropy = 0.0;
433
434        for &x in v {
435            if x > epsilon {
436                entropy -= x * x.ln();
437            }
438        }
439
440        Ok(entropy)
441    }
442
443    fn calculate_joint_entropy(a: &[f32], b: &[f32]) -> Result<f32> {
444        let epsilon = 1e-10;
445        let mut entropy = 0.0;
446
447        for (x, y) in a.iter().zip(b) {
448            let joint = x * y;
449            if joint > epsilon {
450                entropy -= joint * joint.ln();
451            }
452        }
453
454        Ok(entropy)
455    }
456
457    fn estimate_compression_size(v: &[f32]) -> f32 {
458        // Rough estimate based on unique values and entropy
459        // Since f32 doesn't implement Eq/Hash, we'll use a different approach
460        let mut sorted = v.to_vec();
461        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
462
463        let mut unique_count = 1;
464        for i in 1..sorted.len() {
465            if (sorted[i] - sorted[i - 1]).abs() > 1e-6 {
466                unique_count += 1;
467            }
468        }
469
470        unique_count as f32
471    }
472
473    fn estimate_joint_compression_size(a: &[f32], b: &[f32]) -> f32 {
474        let mut combined = Vec::with_capacity(a.len() + b.len());
475        combined.extend_from_slice(a);
476        combined.extend_from_slice(b);
477        Self::estimate_compression_size(&combined)
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const TOL: f32 = 0.01;

    #[test]
    fn test_cosine_distance() {
        let u = Vector::new(vec![1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 0.0]);

        let d = ExtendedDistanceMetric::Cosine.distance(&u, &v).unwrap();
        // Identical directions: distance should be ~0.
        assert!(d < TOL);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = Vector::new(vec![0.0, 0.0]);
        let point = Vector::new(vec![3.0, 4.0]);

        let d = ExtendedDistanceMetric::Euclidean
            .distance(&origin, &point)
            .unwrap();
        // Classic 3-4-5 triangle.
        assert!((d - 5.0).abs() < TOL);
    }

    #[test]
    fn test_hamming_distance() {
        let u = Vector::new(vec![1.0, 1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 1.0, 0.0]);

        let d = ExtendedDistanceMetric::Hamming.distance(&u, &v).unwrap();
        // Positions 1 and 2 differ after binarization.
        assert_eq!(d, 2.0);
    }

    #[test]
    fn test_jaccard_distance() {
        let u = Vector::new(vec![1.0, 1.0, 0.0, 0.0]);
        let v = Vector::new(vec![1.0, 0.0, 1.0, 0.0]);

        let d = ExtendedDistanceMetric::Jaccard.distance(&u, &v).unwrap();
        // Partial overlap: strictly between identical and disjoint.
        assert!(d > 0.0);
        assert!(d < 1.0);
    }

    #[test]
    fn test_pearson_distance() {
        let u = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        let v = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);

        let d = ExtendedDistanceMetric::Pearson.distance(&u, &v).unwrap();
        // Perfectly correlated: distance should be ~0.
        assert!(d < TOL);
    }

    #[test]
    fn test_manhattan_distance() {
        let u = Vector::new(vec![1.0, 2.0, 3.0]);
        let v = Vector::new(vec![4.0, 5.0, 6.0]);

        let d = ExtendedDistanceMetric::Manhattan.distance(&u, &v).unwrap();
        // |1-4| + |2-5| + |3-6| = 9.
        assert_eq!(d, 9.0);
    }
}