Skip to main content

augurs_clustering/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{collections::VecDeque, num::NonZeroU32};
4
5pub use augurs_core::DistanceMatrix;
6
/// The label DBSCAN assigns to a single point.
///
/// Every point ends up either as [`DbscanCluster::Noise`] or as a member of a
/// numbered cluster.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DbscanCluster {
    /// The point was not assigned to any cluster.
    Noise,
    /// The point belongs to the cluster carrying this identifier.
    ///
    /// Identifiers are only labels: they are not guaranteed to be stable across
    /// separate runs of the algorithm.
    ///
    /// A `NonZeroU32` is used so that zero is never a valid identifier, which is
    /// mostly a size optimization (the enum can reuse the zero niche for `Noise`).
    Cluster(NonZeroU32),
}
22
23impl DbscanCluster {
24    /// Returns true if this cluster is a noise cluster.
25    pub fn is_noise(&self) -> bool {
26        matches!(self, Self::Noise)
27    }
28
29    /// Returns true if this cluster is a cluster with the given ID.
30    pub fn is_cluster(&self) -> bool {
31        matches!(self, Self::Cluster(_))
32    }
33
34    /// Returns the ID of the cluster, if it is a cluster, or `-1` if it is a noise cluster.
35    pub fn as_i32(&self) -> i32 {
36        match self {
37            Self::Noise => -1,
38            Self::Cluster(id) => id.get() as i32,
39        }
40    }
41
42    fn increment(&mut self) {
43        match self {
44            Self::Noise => unreachable!(),
45            Self::Cluster(id) => *id = id.checked_add(1).expect("cluster ID overflow"),
46        }
47    }
48}
49
// Lets tests compare labels directly against plain integers, with `-1` standing for noise.
#[cfg(test)]
impl PartialEq<i32> for DbscanCluster {
    fn eq(&self, other: &i32) -> bool {
        // `as_i32` already maps noise to -1, so no separate noise branch is needed.
        *other == self.as_i32()
    }
}
61
/// A configured DBSCAN clusterer.
///
/// Holds the two DBSCAN hyperparameters. Call `fit` with a distance matrix to obtain one
/// [`DbscanCluster`] label per point.
#[derive(Debug)]
pub struct DbscanClusterer {
    /// Neighbourhood radius: two points are neighbours when their distance is `<= epsilon`.
    epsilon: f64,
    /// Minimum neighbourhood size, counting the point itself, for a point to be a core point.
    min_cluster_size: usize,
}
68
69impl DbscanClusterer {
70    /// Create a new DBSCAN instance clustering instance.
71    ///
72    /// # Arguments
73    /// * `epsilon` - The maximum distance between two samples for one to be considered as in the
74    ///   neighborhood of the other.
75    /// * `min_cluster_size` - The number of samples in a neighborhood for a point to be considered as a core
76    ///   point.
77    pub fn new(epsilon: f64, min_cluster_size: usize) -> Self {
78        Self {
79            epsilon,
80            min_cluster_size,
81        }
82    }
83
84    /// Return epsilon, the maximum distance between two samples for one to be considered as in the
85    /// neighborhood of the other.
86    pub fn epsilon(&self) -> f64 {
87        self.epsilon
88    }
89
90    /// Return the minimum number of points in a neighborhood for a point to be considered as a core
91    /// point.
92    pub fn min_cluster_size(&self) -> usize {
93        self.min_cluster_size
94    }
95
96    /// Run the DBSCAN clustering algorithm.
97    ///
98    /// The return value is a vector of cluster assignments, with `DbscanCluster::Noise` indicating noise.
99    pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<DbscanCluster> {
100        let n = distance_matrix.shape().0;
101        let mut clusters = vec![DbscanCluster::Noise; n];
102        let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap());
103        let mut visited = vec![false; n];
104        let mut to_visit = VecDeque::with_capacity(n);
105
106        // We'll reuse this vector to avoid reallocations.
107        let mut neighbours = Vec::with_capacity(n);
108
109        for (i, d) in distance_matrix.iter().enumerate() {
110            // Skip if already assigned to a cluster.
111            if clusters[i].is_cluster() {
112                continue;
113            }
114            self.find_neighbours(i, d, &mut neighbours);
115            if neighbours.len() < self.min_cluster_size - 1 {
116                // Not a core point: leave marked as noise.
117                continue;
118            }
119            // We're in a cluster: expand it to all core neighbours.
120            // Mark this point as visited so we can skip checking it later.
121            visited[i] = true;
122            clusters[i] = cluster;
123            // Mark all noise neighbours as visited and add them to the queue.
124            for neighbour in neighbours.drain(..) {
125                if clusters[neighbour].is_noise() {
126                    visited[neighbour] = true;
127                    to_visit.push_back(neighbour);
128                }
129            }
130
131            // Expand the cluster.
132            while let Some(candidate) = to_visit.pop_front() {
133                clusters[candidate] = cluster;
134                self.find_neighbours(candidate, &distance_matrix[candidate], &mut neighbours);
135                if neighbours.len() >= self.min_cluster_size - 1 {
136                    // Add unvisited extended neighbours to the queue.
137                    for neighbour in neighbours.drain(..) {
138                        if !visited[neighbour] {
139                            visited[neighbour] = true;
140                            to_visit.push_back(neighbour);
141                        }
142                    }
143                }
144            }
145            cluster.increment();
146        }
147        clusters
148    }
149
150    #[inline]
151    fn find_neighbours(&self, i: usize, dists: &[f64], n: &mut Vec<usize>) {
152        n.clear();
153        n.extend(
154            dists
155                .iter()
156                .enumerate()
157                .filter(|(j, &x)| i != *j && x <= self.epsilon)
158                .map(|(j, _)| j),
159        );
160    }
161}
162
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a comma-separated square matrix (one row per line) into a `DistanceMatrix`.
    fn parse_square(csv: &str) -> DistanceMatrix {
        let rows: Vec<Vec<f64>> = csv
            .lines()
            .map(|line| {
                line.split(',')
                    .map(|cell| cell.parse::<f64>().unwrap())
                    .collect::<Vec<f64>>()
            })
            .collect();
        DistanceMatrix::try_from_square(rows).unwrap()
    }

    #[test]
    fn dbscan() {
        let distance_matrix = DistanceMatrix::try_from_square(vec![
            vec![0.0, 1.0, 2.0, 3.0],
            vec![1.0, 0.0, 3.0, 3.0],
            vec![2.0, 3.0, 0.0, 4.0],
            vec![3.0, 3.0, 4.0, 0.0],
        ])
        .unwrap();

        // (epsilon, min_cluster_size, expected labels)
        let cases: [(f64, usize, [i32; 4]); 6] = [
            (0.5, 2, [-1, -1, -1, -1]),
            (1.0, 2, [1, 1, -1, -1]),
            (1.0, 3, [-1, -1, -1, -1]),
            (2.0, 2, [1, 1, 1, -1]),
            (2.0, 3, [1, 1, 1, -1]),
            (3.0, 3, [1, 1, 1, 1]),
        ];
        for &(epsilon, min_cluster_size, expected) in cases.iter() {
            let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
            assert_eq!(
                clusters,
                expected.to_vec(),
                "epsilon = {}, min_cluster_size = {}",
                epsilon,
                min_cluster_size
            );
        }
    }

    #[test]
    fn dbscan_real() {
        let distance_matrix = parse_square(include_str!("../data/dist.csv"));
        let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix);

        // Almost every point lands in cluster 1. Points 34..37 form cluster 2, and a short
        // run starting at 585 alternates between cluster 3 and noise.
        let mut expected = vec![1; 868];
        expected[34..37].copy_from_slice(&[2, 2, 2]);
        expected[585..590].copy_from_slice(&[3, -1, 3, -1, 3]);
        assert_eq!(clusters, expected);
    }
}