1#![doc = include_str!("../README.md")]
2
3use std::{collections::VecDeque, num::NonZeroU32};
4
5pub use augurs_core::DistanceMatrix;
6
/// The label DBSCAN assigns to a single point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbscanCluster {
    /// The point was not assigned to any cluster.
    Noise,
    /// The point belongs to the cluster with this identifier.
    Cluster(NonZeroU32),
}

impl DbscanCluster {
    /// Returns `true` if this label is [`DbscanCluster::Noise`].
    pub fn is_noise(&self) -> bool {
        match self {
            Self::Noise => true,
            Self::Cluster(_) => false,
        }
    }

    /// Returns `true` if this label is a [`DbscanCluster::Cluster`].
    pub fn is_cluster(&self) -> bool {
        !self.is_noise()
    }

    /// Encodes the label as an `i32`: `-1` for noise, otherwise the cluster ID.
    ///
    /// The ID is converted with an `as` cast, so an ID above `i32::MAX` would
    /// wrap to a negative value.
    pub fn as_i32(&self) -> i32 {
        if let Self::Cluster(id) = self {
            id.get() as i32
        } else {
            -1
        }
    }

    /// Advances this label to the next cluster ID in place.
    ///
    /// # Panics
    ///
    /// Panics if called on [`DbscanCluster::Noise`], or if the ID would
    /// overflow `u32`.
    fn increment(&mut self) {
        if let Self::Cluster(id) = self {
            *id = id.checked_add(1).expect("cluster ID overflow");
        } else {
            unreachable!()
        }
    }
}
49
/// Test-only comparison of a label against its `i32` encoding
/// (`-1` for noise, otherwise the cluster ID).
#[cfg(test)]
impl PartialEq<i32> for DbscanCluster {
    fn eq(&self, other: &i32) -> bool {
        let expected = *other;
        match self {
            Self::Noise => expected == -1,
            Self::Cluster(id) => id.get() as i32 == expected,
        }
    }
}
61
62#[derive(Debug)]
64pub struct DbscanClusterer {
65 epsilon: f64,
66 min_cluster_size: usize,
67}
68
69impl DbscanClusterer {
70 pub fn new(epsilon: f64, min_cluster_size: usize) -> Self {
78 Self {
79 epsilon,
80 min_cluster_size,
81 }
82 }
83
84 pub fn epsilon(&self) -> f64 {
87 self.epsilon
88 }
89
90 pub fn min_cluster_size(&self) -> usize {
93 self.min_cluster_size
94 }
95
96 pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<DbscanCluster> {
100 let n = distance_matrix.shape().0;
101 let mut clusters = vec![DbscanCluster::Noise; n];
102 let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap());
103 let mut visited = vec![false; n];
104 let mut to_visit = VecDeque::with_capacity(n);
105
106 let mut neighbours = Vec::with_capacity(n);
108
109 for (i, d) in distance_matrix.iter().enumerate() {
110 if clusters[i].is_cluster() {
112 continue;
113 }
114 self.find_neighbours(i, d, &mut neighbours);
115 if neighbours.len() < self.min_cluster_size - 1 {
116 continue;
118 }
119 visited[i] = true;
122 clusters[i] = cluster;
123 for neighbour in neighbours.drain(..) {
125 if clusters[neighbour].is_noise() {
126 visited[neighbour] = true;
127 to_visit.push_back(neighbour);
128 }
129 }
130
131 while let Some(candidate) = to_visit.pop_front() {
133 clusters[candidate] = cluster;
134 self.find_neighbours(candidate, &distance_matrix[candidate], &mut neighbours);
135 if neighbours.len() >= self.min_cluster_size - 1 {
136 for neighbour in neighbours.drain(..) {
138 if !visited[neighbour] {
139 visited[neighbour] = true;
140 to_visit.push_back(neighbour);
141 }
142 }
143 }
144 }
145 cluster.increment();
146 }
147 clusters
148 }
149
150 #[inline]
151 fn find_neighbours(&self, i: usize, dists: &[f64], n: &mut Vec<usize>) {
152 n.clear();
153 n.extend(
154 dists
155 .iter()
156 .enumerate()
157 .filter(|(j, &x)| i != *j && x <= self.epsilon)
158 .map(|(j, _)| j),
159 );
160 }
161}
162
#[cfg(test)]
mod test {
    use super::*;

    /// Small hand-checked 4-point example across several parameter settings.
    #[test]
    fn dbscan() {
        let rows = vec![
            vec![0.0, 1.0, 2.0, 3.0],
            vec![1.0, 0.0, 3.0, 3.0],
            vec![2.0, 3.0, 0.0, 4.0],
            vec![3.0, 3.0, 4.0, 0.0],
        ];
        let distance_matrix = DistanceMatrix::try_from_square(rows).unwrap();

        // (epsilon, min_cluster_size, expected labels); -1 denotes noise.
        let cases: [(f64, usize, [i32; 4]); 6] = [
            (0.5, 2, [-1, -1, -1, -1]),
            (1.0, 2, [1, 1, -1, -1]),
            (1.0, 3, [-1, -1, -1, -1]),
            (2.0, 2, [1, 1, 1, -1]),
            (2.0, 3, [1, 1, 1, -1]),
            (3.0, 3, [1, 1, 1, 1]),
        ];
        for &(epsilon, min_cluster_size, expected) in cases.iter() {
            let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
            assert_eq!(
                clusters, expected,
                "epsilon = {}, min_cluster_size = {}",
                epsilon, min_cluster_size
            );
        }
    }

    /// Checks labels for the bundled `data/dist.csv` matrix against a fixed
    /// expected vector.
    #[test]
    fn dbscan_real() {
        let rows: Vec<Vec<f64>> = include_str!("../data/dist.csv")
            .lines()
            .map(|line| {
                line.split(',')
                    .map(|cell| cell.parse::<f64>().unwrap())
                    .collect::<Vec<f64>>()
            })
            .collect();
        let distance_matrix = DistanceMatrix::try_from_square(rows).unwrap();
        let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix);
        let expected = vec![
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 3, -1, 3, -1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(clusters, expected);
    }
}