steiner_tree/minimum_spanning_tree/mod.rs

// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: GPL-3.0-or-later

//! Compute minimum spanning trees (MST) of points in a multi-dimensional space
//! with a defined distance metric.

8mod locality_sensitive_hash;
9mod nearest_neighbor;
10
11use locality_sensitive_hash::*;
12use nearest_neighbor::*;
13
14use num_traits::Zero;
15use std::cmp::Ordering;
16use std::collections::{BinaryHeap, HashMap, HashSet};
17use std::hash::Hash;
18
19pub fn minimum_spanning_tree_hamming_space<P>(points: &[P], exact: bool) -> HashMap<P, P>
20where
21    P: HammingPoint<Coord = bool> + Clone + Hash + Eq,
22{
23    if points.is_empty() {
24        return Default::default();
25    }
26
27    // Build locality sensitive hash.
28    let lsh = {
29        let bin_size = 8;
30        let num_points = points.len() as f64;
31        let num_bins = num_points / (bin_size as f64);
32        let num_sub_samples = num_bins.log2().ceil() as usize;
33        let lsh: HammingLSH<_, usize> = HammingLSH::new(num_sub_samples, &points[1..]);
34        lsh
35    };
36
37    // Build nearest-neighbor search.
38    let mut search = LSHNearestNeighbourSearch::new(lsh);
39
40    for p in &points[1..] {
41        search.insert(p.clone());
42    }
43
44    let mut front = BinaryHeap::new();
45    front.push(PQElement {
46        distance: 0u32,
47        point: points[0].clone(),
48        prev: points[0].clone(),
49    });
50
51    // Edges of the minimum spanning tree.
52    let mut tree_edges = HashMap::new();
53
54    // Store potential shortest edges.
55    // If a point `p` is added to the tree, then all other nodes who had `n` as a nearest neighbour
56    // need to get another nearest-neighbour candidate.
57    let mut potential_prev_nodes: HashMap<P, Vec<P>> = HashMap::new();
58
59    while let Some(e) = front.pop() {
60        if tree_edges.contains_key(&e.point) {
61            continue;
62        }
63
64        println!(
65            "progress: {:.1} %",
66            (100. * (tree_edges.len() as f64) / (points.len() as f64))
67        );
68
69        search.remove(&e.point);
70
71        // All nodes who had `e.point` as nearest neighbour need to be assigned with new nearest neighbours.
72        {
73            let users_of_p = potential_prev_nodes.remove(&e.point).unwrap_or(Vec::new());
74            // `e.point` was a nearest neighbour of itself. Hence it needs to be assigned with a yes unused nearest neighbour.
75            let users_of_p = std::iter::once(e.point.clone()).chain(users_of_p); // Skip already used nodes.
76
77            // Find yet unused nearest neighbours for all points who had `e.point` as nearest neighbour.
78            for p in users_of_p {
79                let maybe_nearest_neighbour = if exact {
80                    search.exact_nearest_neighbour(&p)
81                } else {
82                    search
83                        .approx_nearest_neighbour(&p)
84                        .or_else(|| search.exact_nearest_neighbour(&p))
85                };
86
87                if let Some(n) = maybe_nearest_neighbour {
88                    let dist = p.distance(n);
89
90                    front.push(PQElement {
91                        distance: dist,
92                        point: n.clone(),
93                        prev: p.clone(),
94                    });
95
96                    potential_prev_nodes
97                        .entry(n.clone())
98                        .or_insert(Vec::new())
99                        .push(p);
100                }
101            }
102        }
103
104        println!("dist = {}", e.point.distance(&e.prev));
105        // Create edge.
106        if e.point != e.prev {
107            // Skip the edge from the first point to itself.
108            tree_edges.insert(e.point, e.prev);
109        }
110    }
111
112    tree_edges
113}
114
/// Candidate tree edge stored in the priority queue of the MST construction.
///
/// `point` is the point that would be connected, `prev` is the tree point it would be
/// attached to, and `distance` is the length of that edge.
///
/// Ordering and equality consider only `distance`; `point` and `prev` are ignored.
/// The ordering is reversed so that [`BinaryHeap`] (a max-heap) pops the element
/// with the *smallest* distance first.
struct PQElement<D, V> {
    distance: D,
    point: V,
    prev: V,
}

impl<D: Ord, V> Ord for PQElement<D, V> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: a shorter edge compares as greater and is popped first.
        self.distance.cmp(&other.distance).reverse()
    }
}

impl<D: PartialOrd, V> PartialOrd for PQElement<D, V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Must agree with `Ord::cmp` whenever `D: Ord`.
        self.distance
            .partial_cmp(&other.distance)
            .map(Ordering::reverse)
    }
}

// `Eq` promises reflexive equality, which only holds when it holds for `D`
// (it would not hold for e.g. `f64` with NaN).
impl<D: Eq, V> Eq for PQElement<D, V> {}

impl<D: PartialEq, V> PartialEq for PQElement<D, V> {
    fn eq(&self, other: &Self) -> bool {
        self.distance.eq(&other.distance)
    }
}
144
#[test]
fn test_minimum_spanning_tree_with_hamming_distance_of_integers() {
    // Under the Hamming distance, every integer `k > 0` differs in exactly one bit from
    // the smaller integer `k & (k - 1)` (its lowest set bit cleared). Hence a minimum
    // spanning tree of the integers `0..n` consists solely of edges of length 1.
    let num_points = 100;
    let points: Vec<_> = (0..num_points).collect();

    let tree_edges = minimum_spanning_tree_hamming_space(&points, true);

    // Every tree edge must have length 1.
    for (a, b) in tree_edges.iter() {
        let dist = a.distance(b);
        assert_eq!(dist, 1);
    }

    // A spanning tree over `n` points has `n - 1` edges: one parent link per non-root point.
    assert_eq!(tree_edges.len(), num_points - 1);

    // The root `points[0]` has no parent. From every other point, following parent
    // links must reach the root without looping.
    assert!(!tree_edges.contains_key(&points[0]));
    for p in &points[1..] {
        let mut current = p;
        let mut steps = 0;
        while let Some(parent) = tree_edges.get(current) {
            current = parent;
            steps += 1;
            assert!(steps < num_points, "cycle in parent links starting at {}", p);
        }
        assert_eq!(*current, points[0]);
    }
}