Skip to main content

oxiphysics_core/
metric_spaces.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
//! Metric space abstractions and concrete metric implementations.
//!
//! This module provides the [`MetricSpace`] trait and several concrete metrics:
//! Euclidean (L2), Manhattan (L1), Chebyshev (L∞), general Minkowski (Lp),
//! Hamming, edit (Levenshtein) distance, and graph geodesic distance, plus the
//! cosine pseudo-distance. It also includes a [`MetricBallTree`] for
//! nearest-neighbor and range queries, and [`FrechetDistance`] for comparing
//! polygonal curves (with a Hausdorff helper for finite point sets).
11
12#![allow(dead_code)]
13
14use std::cmp::Ordering;
15use std::collections::BinaryHeap;
16
17// ---------------------------------------------------------------------------
18// Core trait
19// ---------------------------------------------------------------------------
20
/// A trait for types that form a metric space.
///
/// A metric space is a set `X` together with a function `d: X × X → ℝ≥0`
/// satisfying, for all `x, y, z ∈ X`:
/// 1. `d(x, y) = 0` iff `x == y`    (identity of indiscernibles)
/// 2. `d(x, y) = d(y, x)`           (symmetry)
/// 3. `d(x, z) ≤ d(x, y) + d(y, z)` (triangle inequality)
///
/// Implementors only have to supply [`MetricSpace::distance`]; the other
/// methods are provided in terms of it.
pub trait MetricSpace {
    /// The type of points in the metric space.
    type Point;

    /// Compute the distance between two points.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f64;

    /// Return references to every point of `candidates` whose distance from
    /// `center` is at most `radius` (a closed ball), in input order.
    fn ball<'a>(
        &self,
        center: &Self::Point,
        radius: f64,
        candidates: &'a [Self::Point],
    ) -> Vec<&'a Self::Point> {
        let mut inside = Vec::new();
        for candidate in candidates {
            if self.distance(center, candidate) <= radius {
                inside.push(candidate);
            }
        }
        inside
    }

    /// Determine whether the diameter of a finite point set is bounded.
    ///
    /// An empty set is bounded; otherwise the set is bounded exactly when
    /// [`MetricSpace::diameter`] returns a finite value.
    fn is_bounded(&self, points: &[Self::Point]) -> bool {
        points.is_empty() || self.diameter(points).is_finite()
    }

    /// Compute the diameter of a finite point set, i.e. `sup_{x,y} d(x,y)`.
    ///
    /// Every unordered pair is measured once. Returns `0.0` for an empty or
    /// single-point set. NaN distances never exceed the running maximum and
    /// are therefore ignored.
    fn diameter(&self, points: &[Self::Point]) -> f64 {
        if points.len() < 2 {
            return 0.0;
        }
        let mut widest = 0.0_f64;
        for (i, first) in points.iter().enumerate() {
            for second in &points[i + 1..] {
                let d = self.distance(first, second);
                if d > widest {
                    widest = d;
                }
            }
        }
        widest
    }
}
79
80// ---------------------------------------------------------------------------
81// Euclidean (L2) metric
82// ---------------------------------------------------------------------------
83
84/// Euclidean (L2) metric on ℝⁿ.
85///
86/// `d(x, y) = sqrt(Σ (xᵢ − yᵢ)²)`
87///
88/// The Cauchy–Schwarz inequality states `|⟨x,y⟩| ≤ ‖x‖₂ · ‖y‖₂`, which
89/// underlies many properties of this metric.
90#[derive(Debug, Clone, Copy, Default)]
91pub struct EuclideanMetric;
92
93impl MetricSpace for EuclideanMetric {
94    type Point = Vec<f64>;
95
96    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> f64 {
97        assert_eq!(a.len(), b.len(), "EuclideanMetric: dimension mismatch");
98        a.iter()
99            .zip(b.iter())
100            .map(|(x, y)| (x - y).powi(2))
101            .sum::<f64>()
102            .sqrt()
103    }
104}
105
106impl EuclideanMetric {
107    /// Compute the L2 norm (Euclidean length) of a vector.
108    ///
109    /// `‖v‖₂ = sqrt(Σ vᵢ²)`
110    pub fn norm(v: &[f64]) -> f64 {
111        v.iter().map(|x| x.powi(2)).sum::<f64>().sqrt()
112    }
113
114    /// Dot product of two equal-length vectors.
115    pub fn dot(a: &[f64], b: &[f64]) -> f64 {
116        assert_eq!(a.len(), b.len(), "dot: dimension mismatch");
117        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
118    }
119
120    /// Verify the Cauchy–Schwarz inequality: `|⟨a,b⟩| ≤ ‖a‖₂ · ‖b‖₂`.
121    ///
122    /// Returns `true` when the inequality holds (within floating-point tolerance).
123    pub fn cauchy_schwarz_holds(a: &[f64], b: &[f64]) -> bool {
124        let dot = Self::dot(a, b).abs();
125        let product = Self::norm(a) * Self::norm(b);
126        dot <= product + 1e-9
127    }
128
129    /// Project vector `v` onto unit vector `u`.
130    ///
131    /// Returns the scalar projection `⟨v, u⟩`.
132    pub fn scalar_projection(v: &[f64], u: &[f64]) -> f64 {
133        Self::dot(v, u) / Self::norm(u)
134    }
135
136    /// Normalize a vector to unit length. Returns `None` for the zero vector.
137    pub fn normalize(v: &[f64]) -> Option<Vec<f64>> {
138        let n = Self::norm(v);
139        if n < 1e-15 {
140            return None;
141        }
142        Some(v.iter().map(|x| x / n).collect())
143    }
144}
145
146// ---------------------------------------------------------------------------
147// Manhattan (L1) metric
148// ---------------------------------------------------------------------------
149
150/// Manhattan (taxicab / L1) metric on ℝⁿ.
151///
152/// `d(x, y) = Σ |xᵢ − yᵢ|`
153#[derive(Debug, Clone, Copy, Default)]
154pub struct ManhattanMetric;
155
156impl MetricSpace for ManhattanMetric {
157    type Point = Vec<f64>;
158
159    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> f64 {
160        assert_eq!(a.len(), b.len(), "ManhattanMetric: dimension mismatch");
161        a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
162    }
163}
164
165impl ManhattanMetric {
166    /// Compute the L1 norm of a vector.
167    pub fn norm(v: &[f64]) -> f64 {
168        v.iter().map(|x| x.abs()).sum()
169    }
170
171    /// The L1 unit ball is the cross-polytope; check membership.
172    pub fn in_unit_ball(v: &[f64]) -> bool {
173        Self::norm(v) <= 1.0 + 1e-12
174    }
175}
176
177// ---------------------------------------------------------------------------
178// Chebyshev (L∞) metric
179// ---------------------------------------------------------------------------
180
181/// Chebyshev (chessboard / L∞) metric on ℝⁿ.
182///
183/// `d(x, y) = max_i |xᵢ − yᵢ|`
184#[derive(Debug, Clone, Copy, Default)]
185pub struct ChebyshevMetric;
186
187impl MetricSpace for ChebyshevMetric {
188    type Point = Vec<f64>;
189
190    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> f64 {
191        assert_eq!(a.len(), b.len(), "ChebyshevMetric: dimension mismatch");
192        a.iter()
193            .zip(b.iter())
194            .map(|(x, y)| (x - y).abs())
195            .fold(0.0_f64, f64::max)
196    }
197}
198
199impl ChebyshevMetric {
200    /// Compute the L∞ norm of a vector.
201    pub fn norm(v: &[f64]) -> f64 {
202        v.iter().map(|x| x.abs()).fold(0.0_f64, f64::max)
203    }
204}
205
206// ---------------------------------------------------------------------------
207// Hamming metric
208// ---------------------------------------------------------------------------
209
210/// Hamming metric on equal-length sequences.
211///
212/// `d(x, y) = |{ i : xᵢ ≠ yᵢ }|`
213///
214/// Works with any element type that implements `PartialEq`.
215#[derive(Debug, Clone, Copy, Default)]
216pub struct HammingMetric;
217
218/// A fixed-length sequence for the Hamming metric.
219#[derive(Debug, Clone, PartialEq)]
220pub struct HammingPoint(pub Vec<u8>);
221
222impl MetricSpace for HammingMetric {
223    type Point = HammingPoint;
224
225    fn distance(&self, a: &HammingPoint, b: &HammingPoint) -> f64 {
226        assert_eq!(
227            a.0.len(),
228            b.0.len(),
229            "HammingMetric: length mismatch ({} vs {})",
230            a.0.len(),
231            b.0.len()
232        );
233        a.0.iter().zip(b.0.iter()).filter(|(x, y)| x != y).count() as f64
234    }
235}
236
237impl HammingMetric {
238    /// Compute the Hamming distance between two byte slices of equal length.
239    pub fn hamming_distance(a: &[u8], b: &[u8]) -> usize {
240        assert_eq!(a.len(), b.len(), "hamming_distance: length mismatch");
241        a.iter().zip(b.iter()).filter(|(x, y)| x != y).count()
242    }
243
244    /// Compute the Hamming weight (number of 1 bits) of a byte slice.
245    pub fn hamming_weight(v: &[u8]) -> usize {
246        v.iter().map(|b| b.count_ones() as usize).sum()
247    }
248}
249
250// ---------------------------------------------------------------------------
251// Edit (Levenshtein) distance
252// ---------------------------------------------------------------------------
253
254/// Edit (Levenshtein) distance metric on sequences.
255///
256/// The edit distance counts the minimum number of single-element insertions,
257/// deletions, and substitutions required to transform one sequence into another.
258#[derive(Debug, Clone, Copy, Default)]
259pub struct EditDistance;
260
261/// A sequence point for edit-distance computation.
262#[derive(Debug, Clone, PartialEq)]
263pub struct Sequence(pub Vec<char>);
264
265impl MetricSpace for EditDistance {
266    type Point = Sequence;
267
268    fn distance(&self, a: &Sequence, b: &Sequence) -> f64 {
269        levenshtein(&a.0, &b.0) as f64
270    }
271}
272
/// Compute the Levenshtein distance between two character slices.
///
/// Uses the classic dynamic-programming recurrence, keeping only two rows of
/// the table. The distance is symmetric, so the rows are sized by the
/// *shorter* input. This gives O(m·n) time and O(min(m, n)) extra space.
pub fn levenshtein(a: &[char], b: &[char]) -> usize {
    // Swapping the inputs does not change the answer, but letting the inner
    // dimension be the shorter one is what bounds the working space.
    let (long, short) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    if short.is_empty() {
        return long.len();
    }

    let n = short.len();
    let mut prev: Vec<usize> = (0..=n).collect();
    let mut curr = vec![0usize; n + 1];

    for (i, &lc) in long.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &sc) in short.iter().enumerate() {
            let cost = usize::from(lc != sc);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + cost); // substitution / match
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[n]
}

/// Compute the Levenshtein distance between two strings, measured in
/// Unicode scalar values (`char`s) rather than bytes.
pub fn levenshtein_str(a: &str, b: &str) -> usize {
    let ac: Vec<char> = a.chars().collect();
    let bc: Vec<char> = b.chars().collect();
    levenshtein(&ac, &bc)
}
306
307impl EditDistance {
308    /// Compute the longest common subsequence length of two sequences.
309    pub fn lcs_length(a: &[char], b: &[char]) -> usize {
310        let m = a.len();
311        let n = b.len();
312        let mut dp = vec![vec![0usize; n + 1]; m + 1];
313        for i in 1..=m {
314            for j in 1..=n {
315                dp[i][j] = if a[i - 1] == b[j - 1] {
316                    dp[i - 1][j - 1] + 1
317                } else {
318                    dp[i - 1][j].max(dp[i][j - 1])
319                };
320            }
321        }
322        dp[m][n]
323    }
324
325    /// Compute the edit distance with custom costs for insert, delete, and substitute.
326    #[allow(clippy::too_many_arguments)]
327    pub fn weighted_edit(
328        a: &[char],
329        b: &[char],
330        insert_cost: f64,
331        delete_cost: f64,
332        subst_cost: f64,
333    ) -> f64 {
334        let m = a.len();
335        let n = b.len();
336        if m == 0 {
337            return n as f64 * insert_cost;
338        }
339        if n == 0 {
340            return m as f64 * delete_cost;
341        }
342        let mut prev: Vec<f64> = (0..=n).map(|j| j as f64 * insert_cost).collect();
343        let mut curr = vec![0.0_f64; n + 1];
344        for i in 1..=m {
345            curr[0] = i as f64 * delete_cost;
346            for j in 1..=n {
347                let cost = if a[i - 1] == b[j - 1] {
348                    0.0
349                } else {
350                    subst_cost
351                };
352                curr[j] = (prev[j] + delete_cost)
353                    .min(curr[j - 1] + insert_cost)
354                    .min(prev[j - 1] + cost);
355            }
356            std::mem::swap(&mut prev, &mut curr);
357        }
358        prev[n]
359    }
360}
361
362// ---------------------------------------------------------------------------
363// Minkowski metric (generalization)
364// ---------------------------------------------------------------------------
365
366/// Minkowski metric of order `p` on ℝⁿ.
367///
368/// `d_p(x, y) = (Σ |xᵢ − yᵢ|^p)^(1/p)`
369///
370/// Special cases: `p = 1` → Manhattan, `p = 2` → Euclidean.
371/// For `p → ∞` this converges to the Chebyshev metric.
372#[derive(Debug, Clone, Copy)]
373pub struct MinkowskiMetric {
374    /// The order parameter `p ≥ 1`.
375    pub p: f64,
376}
377
378impl MinkowskiMetric {
379    /// Create a Minkowski metric of order `p`.
380    ///
381    /// # Panics
382    /// Panics if `p < 1.0`.
383    pub fn new(p: f64) -> Self {
384        assert!(p >= 1.0, "MinkowskiMetric: p must be >= 1, got {p}");
385        Self { p }
386    }
387}
388
389impl MetricSpace for MinkowskiMetric {
390    type Point = Vec<f64>;
391
392    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> f64 {
393        assert_eq!(a.len(), b.len(), "MinkowskiMetric: dimension mismatch");
394        a.iter()
395            .zip(b.iter())
396            .map(|(x, y)| (x - y).abs().powf(self.p))
397            .sum::<f64>()
398            .powf(1.0 / self.p)
399    }
400}
401
402// ---------------------------------------------------------------------------
403// Cosine distance (not a true metric but common)
404// ---------------------------------------------------------------------------
405
/// Cosine similarity-based pseudo-distance on ℝⁿ.
///
/// `d_cos(x, y) = 1 − cos(θ) = 1 − (⟨x,y⟩ / (‖x‖₂ · ‖y‖₂))`
///
/// Note: this is not a true metric (the triangle inequality may fail), so it
/// is exposed as associated functions rather than a `MetricSpace` impl. It is
/// nevertheless widely used in information retrieval and machine-learning
/// contexts.
#[derive(Debug, Clone, Copy, Default)]
pub struct CosineDistance;

impl CosineDistance {
    /// Compute the cosine similarity between two vectors, clamped to `[-1, 1]`.
    ///
    /// Returns `0.0` when either vector has L2 norm below `1e-15`.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths, matching the other
    /// metrics in this module instead of silently truncating the longer one.
    pub fn similarity(a: &[f64], b: &[f64]) -> f64 {
        assert_eq!(a.len(), b.len(), "CosineDistance: dimension mismatch");
        // One pass for the dot product and both squared norms.
        let (dot, sq_a, sq_b) = a
            .iter()
            .zip(b)
            .fold((0.0_f64, 0.0_f64, 0.0_f64), |(d, sa, sb), (x, y)| {
                (d + x * y, sa + x * x, sb + y * y)
            });
        let na = sq_a.sqrt();
        let nb = sq_b.sqrt();
        if na < 1e-15 || nb < 1e-15 {
            return 0.0;
        }
        // Rounding can push the ratio marginally outside [-1, 1].
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }

    /// Compute cosine distance: `1 − similarity(a, b)`, in `[0, 2]`.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn distance(a: &[f64], b: &[f64]) -> f64 {
        1.0 - Self::similarity(a, b)
    }
}
432
433// ---------------------------------------------------------------------------
434// Ball Tree for nearest-neighbor queries
435// ---------------------------------------------------------------------------
436
/// A node in a [`MetricBallTree`].
///
/// Nodes are stored in a flat `Vec<BallTreeNode>` and refer to their children
/// by index into it. `usize::MAX` marks an absent child.
#[derive(Debug, Clone)]
struct BallTreeNode {
    /// Dataset index of the point the ball is centred on.
    pivot_idx: usize,
    /// Distance from the pivot to the farthest point in this subtree.
    radius: f64,
    /// Index of the left child node, or `usize::MAX` at a leaf.
    left: usize,
    /// Index of the right child node, or `usize::MAX` at a leaf.
    right: usize,
    /// Dataset indices held by a leaf; empty for internal nodes.
    leaf_points: Vec<usize>,
}

impl BallTreeNode {
    /// A node is a leaf when neither child slot is populated.
    fn is_leaf(&self) -> bool {
        !(self.left != usize::MAX || self.right != usize::MAX)
    }
}
457
458/// Ball tree (metric tree) for efficient nearest-neighbor and range queries.
459///
460/// Supports any [`MetricSpace`] whose `Point` type is `Clone`. The tree is
461/// built once from a dataset and then supports O(log n) average-case queries.
462pub struct MetricBallTree {
463    points: Vec<Vec<f64>>,
464    nodes: Vec<BallTreeNode>,
465    root: usize,
466    leaf_size: usize,
467}
468
/// A nearest-neighbor search result.
///
/// Results are ordered by ascending `distance`, so a
/// `BinaryHeap<NearestNeighbor>` (a max-heap) keeps the *farthest* retained
/// candidate on top. That is exactly the element a bounded k-NN search must
/// compare against and evict.
#[derive(Debug, Clone)]
pub struct NearestNeighbor {
    /// Index of the neighbor in the original dataset.
    pub index: usize,
    /// Distance from the query point to this neighbor.
    pub distance: f64,
}

/// Two results are equal when their distances compare equal under
/// [`f64::total_cmp`]; the index is ignored. This keeps `==` consistent with
/// [`Ord`], as `Eq`/`Ord` require.
impl PartialEq for NearestNeighbor {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for NearestNeighbor {}

impl PartialOrd for NearestNeighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NearestNeighbor {
    /// Ascending by distance. `total_cmp` keeps the order total even if a NaN
    /// distance slips in, so heap invariants cannot be broken.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
501
502impl MetricBallTree {
503    /// Build a ball tree from a set of points using the Euclidean metric.
504    ///
505    /// `leaf_size` controls the maximum number of points stored at a leaf node.
506    pub fn build(points: Vec<Vec<f64>>, leaf_size: usize) -> Self {
507        let n = points.len();
508        assert!(!points.is_empty(), "MetricBallTree: empty dataset");
509        let leaf_size = leaf_size.max(1);
510
511        let mut nodes: Vec<BallTreeNode> = Vec::with_capacity(2 * n / leaf_size + 4);
512        let indices: Vec<usize> = (0..n).collect();
513        let root = build_node(&points, &indices, &mut nodes, leaf_size);
514
515        Self {
516            points,
517            nodes,
518            root,
519            leaf_size,
520        }
521    }
522
523    /// Find the `k` nearest neighbors of `query` using the Euclidean metric.
524    pub fn knn(&self, query: &[f64], k: usize) -> Vec<NearestNeighbor> {
525        let mut heap: BinaryHeap<NearestNeighbor> = BinaryHeap::new();
526        let metric = EuclideanMetric;
527        knn_search(
528            &self.points,
529            &self.nodes,
530            self.root,
531            query,
532            k,
533            &mut heap,
534            &metric,
535        );
536        let mut result: Vec<NearestNeighbor> = heap.into_sorted_vec();
537        result.sort_by(|a, b| {
538            a.distance
539                .partial_cmp(&b.distance)
540                .unwrap_or(Ordering::Equal)
541        });
542        result
543    }
544
545    /// Find all points within `radius` of `query`.
546    pub fn range_query(&self, query: &[f64], radius: f64) -> Vec<NearestNeighbor> {
547        let metric = EuclideanMetric;
548        let mut result = Vec::new();
549        range_search(
550            &self.points,
551            &self.nodes,
552            self.root,
553            query,
554            radius,
555            &mut result,
556            &metric,
557        );
558        result.sort_by(|a, b| {
559            a.distance
560                .partial_cmp(&b.distance)
561                .unwrap_or(Ordering::Equal)
562        });
563        result
564    }
565
566    /// Return the number of points in the tree.
567    pub fn len(&self) -> usize {
568        self.points.len()
569    }
570
571    /// Return `true` if the tree has no points.
572    pub fn is_empty(&self) -> bool {
573        self.points.is_empty()
574    }
575
576    /// Return the configured leaf size.
577    pub fn leaf_size(&self) -> usize {
578        self.leaf_size
579    }
580}
581
582// --- Ball tree construction helpers ---
583
/// Euclidean distance between two coordinate slices.
///
/// Slice-based counterpart of `EuclideanMetric::distance`, used by the ball
/// tree and the Fréchet/Hausdorff helpers.
///
/// # Panics
/// Panics if `a` and `b` have different lengths. Zipping mismatched slices
/// would otherwise silently drop the extra coordinates and return a wrong
/// distance.
fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "euclidean_dist: dimension mismatch");
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f64>()
        .sqrt()
}
591
/// Arithmetic mean of the points selected by `indices`.
///
/// The dimension is taken from the first selected point. Panics if
/// `indices` is empty (it indexes `indices[0]`).
fn centroid(points: &[Vec<f64>], indices: &[usize]) -> Vec<f64> {
    let dim = points[indices[0]].len();
    let count = indices.len() as f64;
    let mut mean = vec![0.0; dim];
    for p in indices.iter().map(|&i| &points[i]) {
        for (acc, v) in mean.iter_mut().zip(p) {
            *acc += *v;
        }
    }
    mean.into_iter().map(|total| total / count).collect()
}
604
605fn build_node(
606    points: &[Vec<f64>],
607    indices: &[usize],
608    nodes: &mut Vec<BallTreeNode>,
609    leaf_size: usize,
610) -> usize {
611    let c = centroid(points, indices);
612    // Choose pivot as point farthest from centroid
613    let pivot_idx = indices
614        .iter()
615        .copied()
616        .max_by(|&a, &b| {
617            euclidean_dist(&points[a], &c)
618                .partial_cmp(&euclidean_dist(&points[b], &c))
619                .unwrap_or(Ordering::Equal)
620        })
621        .expect("indices is non-empty");
622
623    let radius = indices
624        .iter()
625        .map(|&i| euclidean_dist(&points[i], &points[pivot_idx]))
626        .fold(0.0_f64, f64::max);
627
628    let node_idx = nodes.len();
629
630    if indices.len() <= leaf_size {
631        nodes.push(BallTreeNode {
632            pivot_idx,
633            radius,
634            left: usize::MAX,
635            right: usize::MAX,
636            leaf_points: indices.to_vec(),
637        });
638        return node_idx;
639    }
640
641    // Split by distance to pivot
642    let mut left_idx: Vec<usize> = Vec::new();
643    let mut right_idx: Vec<usize> = Vec::new();
644    for &i in indices {
645        if i == pivot_idx {
646            left_idx.push(i);
647            continue;
648        }
649        if left_idx.len() <= right_idx.len() {
650            left_idx.push(i);
651        } else {
652            right_idx.push(i);
653        }
654    }
655    if left_idx.is_empty() {
656        left_idx.push(pivot_idx);
657    }
658    if right_idx.is_empty() {
659        right_idx = left_idx.split_off(left_idx.len() / 2 + 1);
660    }
661
662    // Push a placeholder so we can fix up children indices later
663    nodes.push(BallTreeNode {
664        pivot_idx,
665        radius,
666        left: usize::MAX,
667        right: usize::MAX,
668        leaf_points: vec![],
669    });
670
671    let left_child = build_node(points, &left_idx, nodes, leaf_size);
672    let right_child = build_node(points, &right_idx, nodes, leaf_size);
673    nodes[node_idx].left = left_child;
674    nodes[node_idx].right = right_child;
675
676    node_idx
677}
678
679fn knn_search(
680    points: &[Vec<f64>],
681    nodes: &[BallTreeNode],
682    node_idx: usize,
683    query: &[f64],
684    k: usize,
685    heap: &mut BinaryHeap<NearestNeighbor>,
686    _metric: &EuclideanMetric,
687) {
688    let node = &nodes[node_idx];
689    let pivot_dist = euclidean_dist(query, &points[node.pivot_idx]);
690
691    // Pruning: if the closest possible point in this ball is farther than
692    // the current k-th best, skip.
693    if heap.len() >= k {
694        let worst = heap.peek().map(|n| n.distance).unwrap_or(f64::MAX);
695        if pivot_dist - node.radius > worst {
696            return;
697        }
698    }
699
700    if node.is_leaf() {
701        for &idx in &node.leaf_points {
702            let d = euclidean_dist(query, &points[idx]);
703            if heap.len() < k {
704                heap.push(NearestNeighbor {
705                    index: idx,
706                    distance: d,
707                });
708            } else if let Some(worst) = heap.peek()
709                && d < worst.distance
710            {
711                heap.pop();
712                heap.push(NearestNeighbor {
713                    index: idx,
714                    distance: d,
715                });
716            }
717        }
718        return;
719    }
720
721    // Recurse into closer child first
722    let left_dist = if node.left != usize::MAX {
723        euclidean_dist(query, &points[nodes[node.left].pivot_idx])
724    } else {
725        f64::MAX
726    };
727    let right_dist = if node.right != usize::MAX {
728        euclidean_dist(query, &points[nodes[node.right].pivot_idx])
729    } else {
730        f64::MAX
731    };
732
733    if left_dist <= right_dist {
734        if node.left != usize::MAX {
735            knn_search(points, nodes, node.left, query, k, heap, _metric);
736        }
737        if node.right != usize::MAX {
738            knn_search(points, nodes, node.right, query, k, heap, _metric);
739        }
740    } else {
741        if node.right != usize::MAX {
742            knn_search(points, nodes, node.right, query, k, heap, _metric);
743        }
744        if node.left != usize::MAX {
745            knn_search(points, nodes, node.left, query, k, heap, _metric);
746        }
747    }
748}
749
750fn range_search(
751    points: &[Vec<f64>],
752    nodes: &[BallTreeNode],
753    node_idx: usize,
754    query: &[f64],
755    radius: f64,
756    result: &mut Vec<NearestNeighbor>,
757    _metric: &EuclideanMetric,
758) {
759    let node = &nodes[node_idx];
760    let pivot_dist = euclidean_dist(query, &points[node.pivot_idx]);
761
762    // Pruning: if the closest possible point in the ball is outside `radius`, skip.
763    if pivot_dist - node.radius > radius {
764        return;
765    }
766
767    if node.is_leaf() {
768        for &idx in &node.leaf_points {
769            let d = euclidean_dist(query, &points[idx]);
770            if d <= radius {
771                result.push(NearestNeighbor {
772                    index: idx,
773                    distance: d,
774                });
775            }
776        }
777        return;
778    }
779
780    if node.left != usize::MAX {
781        range_search(points, nodes, node.left, query, radius, result, _metric);
782    }
783    if node.right != usize::MAX {
784        range_search(points, nodes, node.right, query, radius, result, _metric);
785    }
786}
787
788// ---------------------------------------------------------------------------
789// Discrete Fréchet distance
790// ---------------------------------------------------------------------------
791
792/// Discrete Fréchet distance between two polygonal curves.
793///
794/// The Fréchet distance is informally described as the minimum leash length
795/// required for a person walking a dog, where person and dog each traverse
796/// their respective curve from start to finish without backtracking.
797///
798/// This implementation computes the *discrete* variant using dynamic programming
799/// in O(mn) time and space.
800pub struct FrechetDistance;
801
802impl FrechetDistance {
803    /// Compute the discrete Fréchet distance between curves `p` and `q`.
804    ///
805    /// Each curve is given as a slice of points in ℝⁿ (represented as `Vec`f64`).
806    /// Returns the minimum coupling distance.
807    pub fn compute(p: &[Vec<f64>], q: &[Vec<f64>]) -> f64 {
808        let m = p.len();
809        let n = q.len();
810        if m == 0 || n == 0 {
811            return 0.0;
812        }
813
814        let mut dp = vec![vec![f64::MAX; n]; m];
815
816        for i in 0..m {
817            for j in 0..n {
818                let d = euclidean_dist(&p[i], &q[j]);
819                dp[i][j] = if i == 0 && j == 0 {
820                    d
821                } else if i == 0 {
822                    dp[i][j - 1].max(d)
823                } else if j == 0 {
824                    dp[i - 1][j].max(d)
825                } else {
826                    dp[i - 1][j].min(dp[i][j - 1]).min(dp[i - 1][j - 1]).max(d)
827                };
828            }
829        }
830
831        dp[m - 1][n - 1]
832    }
833
834    /// Compute the discrete Fréchet distance using a custom metric.
835    pub fn compute_with_metric<M: MetricSpace<Point = Vec<f64>>>(
836        p: &[Vec<f64>],
837        q: &[Vec<f64>],
838        metric: &M,
839    ) -> f64 {
840        let m = p.len();
841        let n = q.len();
842        if m == 0 || n == 0 {
843            return 0.0;
844        }
845
846        let mut dp = vec![vec![f64::MAX; n]; m];
847
848        for i in 0..m {
849            for j in 0..n {
850                let d = metric.distance(&p[i], &q[j]);
851                dp[i][j] = if i == 0 && j == 0 {
852                    d
853                } else if i == 0 {
854                    dp[i][j - 1].max(d)
855                } else if j == 0 {
856                    dp[i - 1][j].max(d)
857                } else {
858                    dp[i - 1][j].min(dp[i][j - 1]).min(dp[i - 1][j - 1]).max(d)
859                };
860            }
861        }
862
863        dp[m - 1][n - 1]
864    }
865
866    /// Compute the Hausdorff distance between two finite point sets.
867    ///
868    /// `d_H(A, B) = max(sup_{a∈A} inf_{b∈B} d(a,b), sup_{b∈B} inf_{a∈A} d(a,b))`
869    pub fn hausdorff(p: &[Vec<f64>], q: &[Vec<f64>]) -> f64 {
870        let one_sided = |a: &[Vec<f64>], b: &[Vec<f64>]| {
871            a.iter()
872                .map(|ai| {
873                    b.iter()
874                        .map(|bj| euclidean_dist(ai, bj))
875                        .fold(f64::MAX, f64::min)
876                })
877                .fold(0.0_f64, f64::max)
878        };
879        one_sided(p, q).max(one_sided(q, p))
880    }
881}
882
883// ---------------------------------------------------------------------------
884// Geodesic distance on a graph
885// ---------------------------------------------------------------------------
886
887/// Graph-based geodesic metric using Dijkstra's shortest path.
888///
889/// Points are vertices (indexed by `usize`), and edges carry non-negative weights.
890pub struct GeodesicMetric {
891    /// Adjacency list: `adj\[i\]` = list of `(neighbor, weight)` pairs.
892    pub adj: Vec<Vec<(usize, f64)>>,
893}
894
895impl GeodesicMetric {
896    /// Create a new geodesic metric for `n` vertices.
897    pub fn new(n: usize) -> Self {
898        Self {
899            adj: vec![Vec::new(); n],
900        }
901    }
902
903    /// Add an undirected edge with the given weight.
904    pub fn add_edge(&mut self, u: usize, v: usize, weight: f64) {
905        self.adj[u].push((v, weight));
906        self.adj[v].push((u, weight));
907    }
908
909    /// Compute shortest distances from source `s` to all vertices (Dijkstra).
910    pub fn dijkstra(&self, s: usize) -> Vec<f64> {
911        let n = self.adj.len();
912        let mut dist = vec![f64::INFINITY; n];
913        dist[s] = 0.0;
914
915        // (distance, vertex) — use a min-heap via Reverse
916        let mut heap: BinaryHeap<(ordered_float::OrderedF64, usize)> = BinaryHeap::new();
917        heap.push((ordered_float::OrderedF64(0.0), s));
918
919        while let Some((ordered_float::OrderedF64(d), u)) = heap.pop() {
920            if d > dist[u] {
921                continue;
922            }
923            for &(v, w) in &self.adj[u] {
924                let nd = dist[u] + w;
925                if nd < dist[v] {
926                    dist[v] = nd;
927                    heap.push((ordered_float::OrderedF64(nd), v));
928                }
929            }
930        }
931        dist
932    }
933}
934
935/// Wrapper for `f64` that implements `Ord` for use in `BinaryHeap` with negated distances.
936mod ordered_float {
937    #[derive(Debug, Clone, Copy, PartialEq)]
938    pub struct OrderedF64(pub f64);
939
940    impl Eq for OrderedF64 {}
941
942    impl PartialOrd for OrderedF64 {
943        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
944            Some(self.cmp(other))
945        }
946    }
947
948    impl Ord for OrderedF64 {
949        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
950            // Min-heap: smaller f64 → larger in reversed comparison
951            other
952                .0
953                .partial_cmp(&self.0)
954                .unwrap_or(std::cmp::Ordering::Equal)
955        }
956    }
957}
958
959impl MetricSpace for GeodesicMetric {
960    type Point = usize; // vertices
961
962    fn distance(&self, a: &usize, b: &usize) -> f64 {
963        let dists = self.dijkstra(*a);
964        dists[*b]
965    }
966}
967
968// ---------------------------------------------------------------------------
969// Metric space utilities
970// ---------------------------------------------------------------------------
971
/// Check whether a pairwise distance matrix satisfies the triangle inequality.
///
/// `matrix[i][j]` should equal `d(i, j)`, and the matrix is expected to be
/// square. For every ordered triple the check is
/// `d(i, k) ≤ d(i, j) + d(j, k) + 1e-9`. The absolute slack of `1e-9` absorbs
/// floating-point round-off; use [`check_triangle_inequality_with_tolerance`]
/// to choose a different slack.
///
/// Returns `Ok(())` on success. Otherwise it returns `Err((i, j, k))`, where
/// `j` is the intermediate point: the first violated triple in loop order
/// (`i` outermost, then `j`, then `k`). A comparison involving `NaN` is
/// reported as a violation rather than silently passing.
///
/// # Panics
///
/// Panics if any row is shorter than `matrix.len()`.
pub fn check_triangle_inequality(matrix: &[Vec<f64>]) -> Result<(), (usize, usize, usize)> {
    check_triangle_inequality_with_tolerance(matrix, 1e-9)
}

/// Like [`check_triangle_inequality`], but with a caller-chosen absolute slack.
///
/// Each triple is accepted when `d(i, k) ≤ d(i, j) + d(j, k) + tol`. `tol`
/// should be finite and non-negative; a `NaN` tolerance makes every triple fail.
///
/// # Panics
///
/// Panics if any row is shorter than `matrix.len()`.
pub fn check_triangle_inequality_with_tolerance(
    matrix: &[Vec<f64>],
    tol: f64,
) -> Result<(), (usize, usize, usize)> {
    let n = matrix.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                let direct = matrix[i][k];
                let via_j = matrix[i][j] + matrix[j][k] + tol;
                // `<=` is false whenever either side is NaN, so NaN entries are
                // flagged instead of slipping through as `direct > via_j` would.
                let holds = direct <= via_j;
                if !holds {
                    return Err((i, j, k));
                }
            }
        }
    }
    Ok(())
}
989
990/// Compute the diameter of a point set under a given metric.
991///
992/// This is a convenience wrapper around [`MetricSpace::diameter`].
993pub fn set_diameter<M: MetricSpace>(metric: &M, points: &[M::Point]) -> f64 {
994    metric.diameter(points)
995}
996
997/// Compute the pairwise distance matrix for a set of points.
998pub fn distance_matrix<M: MetricSpace>(metric: &M, points: &[M::Point]) -> Vec<Vec<f64>> {
999    let n = points.len();
1000    let mut mat = vec![vec![0.0; n]; n];
1001    for i in 0..n {
1002        for j in 0..n {
1003            mat[i][j] = metric.distance(&points[i], &points[j]);
1004        }
1005    }
1006    mat
1007}
1008
1009// ---------------------------------------------------------------------------
1010// Tests
1011// ---------------------------------------------------------------------------
1012
#[cfg(test)]
mod tests {
    use super::*;

    // --- EuclideanMetric ---

    #[test]
    fn test_euclidean_zero_distance() {
        let m = EuclideanMetric;
        let p = vec![1.0, 2.0, 3.0];
        assert!(m.distance(&p, &p).abs() < 1e-12);
    }

    #[test]
    fn test_euclidean_known_distance() {
        let m = EuclideanMetric;
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((m.distance(&a, &b) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_euclidean_symmetry() {
        let m = EuclideanMetric;
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, -1.0, 0.0];
        let diff = (m.distance(&a, &b) - m.distance(&b, &a)).abs();
        assert!(diff < 1e-12);
    }

    #[test]
    fn test_euclidean_triangle_inequality() {
        let m = EuclideanMetric;
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let c = vec![0.5, 1.0];
        let dab = m.distance(&a, &b);
        let dbc = m.distance(&b, &c);
        let dac = m.distance(&a, &c);
        assert!(dac <= dab + dbc + 1e-10);
    }

    #[test]
    fn test_euclidean_norm() {
        let v = vec![3.0, 4.0];
        assert!((EuclideanMetric::norm(&v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_cauchy_schwarz() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, -1.0, 2.0];
        assert!(EuclideanMetric::cauchy_schwarz_holds(&a, &b));
    }

    #[test]
    fn test_euclidean_normalize() {
        let v = vec![3.0, 4.0];
        let u = EuclideanMetric::normalize(&v).unwrap();
        assert!((EuclideanMetric::norm(&u) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_euclidean_diameter() {
        let m = EuclideanMetric;
        let pts = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let d = m.diameter(&pts);
        assert!((d - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_euclidean_ball() {
        let m = EuclideanMetric;
        let center = vec![0.0, 0.0];
        let pts = vec![vec![0.5, 0.0], vec![2.0, 0.0], vec![0.0, 0.8]];
        let ball = m.ball(&center, 1.0, &pts);
        assert_eq!(ball.len(), 2);
    }

    #[test]
    fn test_euclidean_is_bounded() {
        let m = EuclideanMetric;
        let pts = vec![vec![0.0, 0.0], vec![100.0, 100.0]];
        assert!(m.is_bounded(&pts));
    }

    // --- ManhattanMetric ---

    #[test]
    fn test_manhattan_distance() {
        let m = ManhattanMetric;
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((m.distance(&a, &b) - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_manhattan_symmetry() {
        let m = ManhattanMetric;
        let a = vec![1.0, -2.0, 3.0];
        let b = vec![-1.0, 4.0, 0.0];
        assert!((m.distance(&a, &b) - m.distance(&b, &a)).abs() < 1e-12);
    }

    #[test]
    fn test_manhattan_norm() {
        let v = vec![-3.0, 4.0];
        assert!((ManhattanMetric::norm(&v) - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_manhattan_diameter() {
        let m = ManhattanMetric;
        let pts = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![-1.0, -1.0]];
        let d = m.diameter(&pts);
        assert!((d - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_manhattan_ball() {
        let m = ManhattanMetric;
        let center = vec![0.0, 0.0];
        let pts = vec![vec![1.0, 0.0], vec![1.0, 1.0], vec![0.5, 0.5]];
        let b = m.ball(&center, 1.0, &pts);
        // d([1,0]) = 1 ≤ 1, d([1,1]) = 2 > 1, d([0.5,0.5]) = 1 ≤ 1
        assert_eq!(b.len(), 2);
    }

    // --- ChebyshevMetric ---

    #[test]
    fn test_chebyshev_distance() {
        let m = ChebyshevMetric;
        let a = vec![1.0, 3.0, 5.0];
        let b = vec![4.0, 2.0, 1.0];
        // diffs: 3, 1, 4 → max = 4
        assert!((m.distance(&a, &b) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_chebyshev_zero() {
        let m = ChebyshevMetric;
        let p = vec![1.0, 2.0];
        assert!((m.distance(&p, &p)).abs() < 1e-12);
    }

    #[test]
    fn test_chebyshev_norm() {
        let v = vec![-5.0, 3.0, 2.0];
        assert!((ChebyshevMetric::norm(&v) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_chebyshev_diameter() {
        let m = ChebyshevMetric;
        let pts = vec![vec![0.0, 0.0], vec![2.0, 1.0], vec![1.0, 3.0]];
        let d = m.diameter(&pts);
        // pairs: d([0,0],[2,1])=2, d([0,0],[1,3])=3, d([2,1],[1,3])=2 → max=3
        assert!((d - 3.0).abs() < 1e-12);
    }

    // --- HammingMetric ---

    #[test]
    fn test_hamming_equal() {
        let m = HammingMetric;
        let a = HammingPoint(vec![1, 0, 1, 1]);
        assert!((m.distance(&a, &a)).abs() < 1e-12);
    }

    #[test]
    fn test_hamming_known() {
        let m = HammingMetric;
        let a = HammingPoint(vec![0, 0, 0, 1]);
        let b = HammingPoint(vec![1, 1, 0, 0]);
        assert!((m.distance(&a, &b) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_hamming_symmetry() {
        let m = HammingMetric;
        let a = HammingPoint(vec![1, 0, 1]);
        let b = HammingPoint(vec![0, 1, 1]);
        assert!((m.distance(&a, &b) - m.distance(&b, &a)).abs() < 1e-12);
    }

    #[test]
    fn test_hamming_weight() {
        let v = vec![0b1010_1010u8, 0b1111_0000u8];
        // 4 ones + 4 ones = 8
        assert_eq!(HammingMetric::hamming_weight(&v), 8);
    }

    // --- EditDistance ---

    #[test]
    fn test_edit_identical() {
        let m = EditDistance;
        let s: Vec<char> = "hello".chars().collect();
        let a = Sequence(s.clone());
        let b = Sequence(s);
        assert!((m.distance(&a, &b)).abs() < 1e-12);
    }

    #[test]
    fn test_edit_known() {
        assert_eq!(levenshtein_str("kitten", "sitting"), 3);
    }

    #[test]
    fn test_edit_empty() {
        assert_eq!(levenshtein_str("", "abc"), 3);
        assert_eq!(levenshtein_str("abc", ""), 3);
    }

    #[test]
    fn test_edit_symmetry() {
        let a: Vec<char> = "sunday".chars().collect();
        let b: Vec<char> = "saturday".chars().collect();
        assert_eq!(levenshtein(&a, &b), levenshtein(&b, &a));
    }

    #[test]
    fn test_lcs_length() {
        let a: Vec<char> = "ABCBDAB".chars().collect();
        let b: Vec<char> = "BDCAB".chars().collect();
        assert_eq!(EditDistance::lcs_length(&a, &b), 4);
    }

    #[test]
    fn test_weighted_edit() {
        let a: Vec<char> = "abc".chars().collect();
        let b: Vec<char> = "axc".chars().collect();
        // One substitution, cost 2.0
        let d = EditDistance::weighted_edit(&a, &b, 1.0, 1.0, 2.0);
        assert!((d - 2.0).abs() < 1e-12);
    }

    // --- MinkowskiMetric ---

    #[test]
    fn test_minkowski_p1_equals_manhattan() {
        let m1 = MinkowskiMetric::new(1.0);
        let m2 = ManhattanMetric;
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 0.0, -1.0];
        assert!((m1.distance(&a, &b) - m2.distance(&a, &b)).abs() < 1e-10);
    }

    #[test]
    fn test_minkowski_p2_equals_euclidean() {
        let m1 = MinkowskiMetric::new(2.0);
        let m2 = EuclideanMetric;
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((m1.distance(&a, &b) - m2.distance(&a, &b)).abs() < 1e-10);
    }

    // --- CosineDistance ---

    #[test]
    fn test_cosine_identical() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((CosineDistance::similarity(&v, &v) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((CosineDistance::similarity(&a, &b)).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_distance_range() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let d = CosineDistance::distance(&a, &b);
        assert!((0.0..=2.0 + 1e-12).contains(&d));
    }

    // --- MetricBallTree ---

    #[test]
    fn test_ball_tree_knn_1() {
        let points: Vec<Vec<f64>> = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![5.0, 5.0],
        ];
        let tree = MetricBallTree::build(points, 2);
        let result = tree.knn(&[0.1, 0.1], 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].index, 0);
    }

    #[test]
    fn test_ball_tree_knn_k() {
        let points: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, 0.0]).collect();
        let tree = MetricBallTree::build(points, 4);
        let result = tree.knn(&[9.5, 0.0], 3);
        assert_eq!(result.len(), 3);
        // The query lies exactly between 9 and 10 (both at distance 0.5), so an
        // exact 3-NN must contain both; the third is 8 or 11 (tied at 1.5).
        let indices: Vec<usize> = result.iter().map(|n| n.index).collect();
        assert!(indices.contains(&9) && indices.contains(&10), "got {indices:?}");
    }

    #[test]
    fn test_ball_tree_range_query() {
        let points: Vec<Vec<f64>> = vec![
            vec![0.0, 0.0],
            vec![0.5, 0.0],
            vec![1.5, 0.0],
            vec![10.0, 10.0],
        ];
        let tree = MetricBallTree::build(points, 2);
        let result = tree.range_query(&[0.0, 0.0], 1.0);
        let indices: Vec<usize> = result.iter().map(|n| n.index).collect();
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        // [1.5, 0] is at distance 1.5 > 1.0 and must be excluded too.
        assert!(!indices.contains(&2));
        assert!(!indices.contains(&3));
        assert_eq!(indices.len(), 2);
    }

    #[test]
    fn test_ball_tree_len() {
        let points: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64]).collect();
        let tree = MetricBallTree::build(points, 3);
        assert_eq!(tree.len(), 10);
        assert!(!tree.is_empty());
    }

    // --- FrechetDistance ---

    #[test]
    fn test_frechet_identical_curves() {
        let c: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];
        let d = FrechetDistance::compute(&c, &c);
        assert!(d < 1e-12, "identical curves → Fréchet = 0, got {d}");
    }

    #[test]
    fn test_frechet_offset_curves() {
        // Curve q is curve p shifted up by 1
        let p: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0]];
        let q: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 1.0]];
        let d = FrechetDistance::compute(&p, &q);
        assert!((d - 1.0).abs() < 1e-10, "expected 1.0, got {d}");
    }

    #[test]
    fn test_frechet_single_point() {
        let p = vec![vec![0.0, 0.0]];
        let q = vec![vec![3.0, 4.0]];
        let d = FrechetDistance::compute(&p, &q);
        assert!((d - 5.0).abs() < 1e-10, "expected 5.0, got {d}");
    }

    #[test]
    fn test_frechet_with_manhattan_metric() {
        let p: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0]];
        let q: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 1.0]];
        let d = FrechetDistance::compute_with_metric(&p, &q, &ManhattanMetric);
        assert!((d - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_hausdorff() {
        let p: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![2.0, 0.0]];
        let q: Vec<Vec<f64>> = vec![vec![0.0, 1.0]];
        let h = FrechetDistance::hausdorff(&p, &q);
        // The point of p farthest from q is [2,0]: d([2,0],[0,1]) = sqrt(5) ≈ 2.236.
        // The only point of q is within 1 of p, so the Hausdorff distance is sqrt(5).
        assert!((h - 5.0_f64.sqrt()).abs() < 1e-10, "expected sqrt(5), got {h}");
    }

    // --- Distance matrix and utilities ---

    #[test]
    fn test_distance_matrix_diagonal_zero() {
        let m = EuclideanMetric;
        let pts = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let mat = distance_matrix(&m, &pts);
        for i in 0..3 {
            assert!(mat[i][i].abs() < 1e-12);
        }
    }

    #[test]
    fn test_distance_matrix_symmetry() {
        let m = ManhattanMetric;
        let pts = vec![vec![0.0, 0.0], vec![1.0, 2.0], vec![3.0, -1.0]];
        let mat = distance_matrix(&m, &pts);
        for i in 0..3 {
            for j in 0..3 {
                assert!((mat[i][j] - mat[j][i]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_triangle_inequality_check_pass() {
        let m = EuclideanMetric;
        let pts = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let mat = distance_matrix(&m, &pts);
        assert!(check_triangle_inequality(&mat).is_ok());
    }

    #[test]
    fn test_triangle_inequality_check_fail() {
        // d(0,2) = 5 exceeds d(0,1) + d(1,2) = 2, so (0, 1, 2) is the first violation.
        let mat = vec![
            vec![0.0, 1.0, 5.0],
            vec![1.0, 0.0, 1.0],
            vec![5.0, 1.0, 0.0],
        ];
        assert_eq!(check_triangle_inequality(&mat), Err((0, 1, 2)));
    }

    #[test]
    fn test_set_diameter_empty() {
        let m = EuclideanMetric;
        let pts: Vec<Vec<f64>> = vec![];
        assert!((set_diameter(&m, &pts)).abs() < 1e-12);
    }

    // --- GeodesicMetric ---

    #[test]
    fn test_geodesic_simple() {
        let mut g = GeodesicMetric::new(4);
        g.add_edge(0, 1, 1.0);
        g.add_edge(1, 2, 2.0);
        g.add_edge(2, 3, 1.0);
        assert!((g.distance(&0, &3) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_geodesic_shortest_path() {
        let mut g = GeodesicMetric::new(3);
        g.add_edge(0, 1, 10.0);
        g.add_edge(0, 2, 1.0);
        g.add_edge(2, 1, 2.0);
        assert!((g.distance(&0, &1) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_geodesic_self_distance() {
        let g = GeodesicMetric::new(3);
        assert!((g.distance(&0, &0)).abs() < 1e-12);
    }

    #[test]
    fn test_geodesic_unreachable_is_infinite() {
        // No edges at all: vertex 2 cannot be reached from vertex 0.
        let g = GeodesicMetric::new(3);
        assert!(g.distance(&0, &2).is_infinite());
    }
}