linfa_clustering/dbscan/
algorithm.rs

1use crate::dbscan::{DbscanParams, DbscanValidParams};
2use linfa_nn::{
3    distance::{Distance, L2Dist},
4    CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex,
5};
6use ndarray::{Array1, ArrayBase, Data, Ix2};
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11use linfa::Float;
12use linfa::{traits::Transformer, DatasetBase};
13
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
/// clusters together points which are close together with enough neighbors
/// and labels points which are sparsely neighbored as noise. As points may be
/// part of a cluster or noise, the `transform` method returns
/// `Array1<Option<usize>>`.
///
/// As it groups together points in dense regions, the number of clusters is
/// determined by the dataset and distance tolerance, not by the user.
///
/// We provide an implementation of the standard O(N^2) query-based algorithm
/// of which more details can be found in the next section or
/// [here](https://en.wikipedia.org/wiki/DBSCAN).
///
/// The standard DBSCAN algorithm isn't iterative and therefore there's
/// no `fit` method provided, only `transform`.
///
/// ## The algorithm
///
/// The algorithm iterates over each point in the dataset and for every point
/// not yet assigned to a cluster:
/// - Find all points within the neighborhood of size `tolerance`
/// - If the number of points in the neighborhood is below a minimum size, label the point as noise
/// - Otherwise label the point with the cluster ID and repeat with each of the neighbours
///
/// ## Tutorial
///
/// Let's do a walkthrough of an example running DBSCAN on some data.
///
/// ```rust
/// use linfa::traits::*;
/// use linfa_clustering::{DbscanParams, Dbscan};
/// use linfa_datasets::generate;
/// use ndarray::{Axis, array, s};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
/// use approx::assert_abs_diff_eq;
///
/// // Our random number generator, seeded for reproducibility
/// let seed = 42;
/// let mut rng = Xoshiro256Plus::seed_from_u64(seed);
///
/// // `expected_centroids` has shape `(n_centroids, n_features)`
/// // i.e. three points in the 2-dimensional plane
/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
/// // Let's generate a synthetic dataset: three blobs of observations
/// // (100 points each) centered around our `expected_centroids`
/// let observations = generate::blobs(100, &expected_centroids, &mut rng);
///
/// // Let's configure and run our DBSCAN algorithm
/// // We use the builder pattern to specify the hyperparameters
/// // `min_points` is the only mandatory parameter.
/// // If you don't specify the others (e.g. `tolerance`)
/// // default values will be used.
/// let min_points = 3;
/// let clusters = Dbscan::params(min_points)
///     .tolerance(1e-2)
///     .transform(&observations)
///     .unwrap();
/// // Points are `None` if noise, `Some(id)` if belonging to a cluster.
/// ```
pub struct Dbscan;
82
83impl Dbscan {
84    /// Configures the hyperparameters with the minimum number of points required to form a cluster
85    ///
86    /// Defaults are provided if the optional parameters are not specified:
87    /// * `tolerance = 1e-4`
88    /// * `dist_fn = L2Dist` (Euclidean distance)
89    /// * `nn_algo = KdTree`
90    pub fn params<F: Float>(min_points: usize) -> DbscanParams<F, L2Dist, CommonNearestNeighbour> {
91        Self::params_with(min_points, L2Dist, CommonNearestNeighbour::KdTree)
92    }
93
94    /// Configures the hyperparameters with the minimum number of points, a custom distance metric,
95    /// and a custom nearest neighbour algorithm
96    pub fn params_with<F: Float, D: Distance<F>, N: NearestNeighbour>(
97        min_points: usize,
98        dist_fn: D,
99        nn_algo: N,
100    ) -> DbscanParams<F, D, N> {
101        DbscanParams::new(min_points, dist_fn, nn_algo)
102    }
103}
104
105impl<F: Float, D: Data<Elem = F>, DF: Distance<F>, N: NearestNeighbour>
106    Transformer<&ArrayBase<D, Ix2>, Array1<Option<usize>>> for DbscanValidParams<F, DF, N>
107{
108    fn transform(&self, observations: &ArrayBase<D, Ix2>) -> Array1<Option<usize>> {
109        let mut cluster_memberships = Array1::from_elem(observations.nrows(), None);
110        let mut current_cluster_id = 0;
111        // Tracks whether a value is in the search queue to prevent duplicates
112        let mut search_found = vec![false; observations.nrows()];
113        let mut search_queue = VecDeque::with_capacity(observations.nrows());
114
115        // Construct NN index
116        let nn = match self.nn_algo.from_batch(observations, self.dist_fn.clone()) {
117            Ok(nn) => nn,
118            Err(linfa_nn::BuildError::ZeroDimension) => {
119                return Array1::from_elem(observations.nrows(), None)
120            }
121            Err(e) => panic!("Unexpected nearest neighbour error: {}", e),
122        };
123
124        for i in 0..observations.nrows() {
125            if cluster_memberships[i].is_some() {
126                continue;
127            }
128            let (neighbor_count, neighbors) =
129                self.find_neighbors(&*nn, i, observations, self.tolerance, &cluster_memberships);
130            if neighbor_count < self.min_points {
131                continue;
132            }
133            neighbors.iter().for_each(|&n| search_found[n] = true);
134            search_queue.extend(neighbors.into_iter());
135
136            // Now go over the neighbours adding them to the cluster
137            cluster_memberships[i] = Some(current_cluster_id);
138
139            while let Some(candidate_idx) = search_queue.pop_front() {
140                search_found[candidate_idx] = false;
141
142                let (neighbor_count, neighbors) = self.find_neighbors(
143                    &*nn,
144                    candidate_idx,
145                    observations,
146                    self.tolerance,
147                    &cluster_memberships,
148                );
149                // Make the candidate a part of the cluster even if it's not a core point
150                cluster_memberships[candidate_idx] = Some(current_cluster_id);
151                if neighbor_count >= self.min_points {
152                    for n in neighbors.into_iter() {
153                        if !search_found[n] {
154                            search_queue.push_back(n);
155                            search_found[n] = true;
156                        }
157                    }
158                }
159            }
160            current_cluster_id += 1;
161        }
162        cluster_memberships
163    }
164}
165
166impl<F: Float, D: Data<Elem = F>, DF: Distance<F>, N: NearestNeighbour, T>
167    Transformer<
168        DatasetBase<ArrayBase<D, Ix2>, T>,
169        DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
170    > for DbscanValidParams<F, DF, N>
171{
172    fn transform(
173        &self,
174        dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
175    ) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
176        let predicted = self.transform(dataset.records());
177        dataset.with_targets(predicted)
178    }
179}
180
181impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
182    fn find_neighbors(
183        &self,
184        nn: &dyn NearestNeighbourIndex<F>,
185        idx: usize,
186        observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
187        eps: F,
188        clusters: &Array1<Option<usize>>,
189    ) -> (usize, Vec<usize>) {
190        let candidate = observations.row(idx);
191        let mut res = Vec::with_capacity(self.min_points);
192        let mut count = 0;
193
194        // Unwrap here is fine because we don't expect any dimension mismatch when calling
195        // within_range with points from the observations
196        for (_, i) in nn.within_range(candidate.view(), eps).unwrap().into_iter() {
197            count += 1;
198            if clusters[i].is_none() && i != idx {
199                res.push(i);
200            }
201        }
202        (count, res)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::ParamGuard;
    use linfa_nn::{distance::L1Dist, BallTree};
    use ndarray::{arr2, s, Array2};

    /// Asserts that the first point is noise and every other point is in cluster 0.
    fn assert_outlier_then_single_cluster(labels: &Array1<Option<usize>>) {
        assert_eq!(labels[0], None);
        for label in labels.iter().skip(1) {
            assert_eq!(label, &Some(0));
        }
    }

    #[test]
    fn nested_clusters() {
        // A square outline of points with a separate blob inside it:
        // the outline and the blob must come out as two distinct clusters
        let ramp = Array1::linspace(0.0, 8.0, 10);
        let mut data: Array2<f64> = Array2::zeros((50, 2));

        // Bottom and top edges: x walks along the ramp, y is fixed
        data.slice_mut(s![0..10, 0]).assign(&ramp);
        data.slice_mut(s![0..10, 1]).fill(0.0);
        data.slice_mut(s![10..20, 0]).assign(&ramp);
        data.slice_mut(s![10..20, 1]).fill(8.0);

        // Left and right edges: y walks along the ramp, x is fixed
        data.slice_mut(s![20..30, 0]).fill(0.0);
        data.slice_mut(s![20..30, 1]).assign(&ramp);
        data.slice_mut(s![30..40, 0]).fill(8.0);
        data.slice_mut(s![30..40, 1]).assign(&ramp);

        // Inner blob: the remaining points all sit at (5, 5)
        data.slice_mut(s![40.., ..]).fill(5.0);

        let params = Dbscan::params(2).tolerance(1.0).check().unwrap();
        let labels = params.transform(&data);

        assert!(labels.iter().take(40).all(|&label| label == Some(0)));
        assert!(labels.iter().skip(40).all(|&label| label == Some(1)));
    }

    #[test]
    fn non_cluster_points() {
        // Four points at the origin plus one far-away point in row 0
        let mut data: Array2<f64> = Array2::zeros((5, 2));
        data.row_mut(0).fill(10.0);

        let params = Dbscan::params(4).check().unwrap();
        let labels = params.transform(&data);

        let expected = Array1::from(vec![None, Some(0), Some(0), Some(0), Some(0)]);
        assert_eq!(labels, expected);
    }

    #[test]
    fn border_points() {
        let data: Array2<f64> = arr2(&[
            // Outlier
            [0.0, 2.0],
            // Core point
            [0.0, 0.0],
            // Border points
            [0.0, 1.0],
            [0.0, -1.0],
            [-1.0, 0.0],
            [1.0, 0.0],
        ]);

        // Tolerance of 1.1 and a density threshold of 5 points
        let params = Dbscan::params(5).tolerance(1.1).check().unwrap();
        let labels = params.transform(&data);

        assert_outlier_then_single_cluster(&labels);
    }

    #[test]
    fn l1_dist() {
        let data: Array2<f64> = arr2(&[
            // Outlier
            [0.0, 6.0],
            // Core point
            [0.0, 0.0],
            // Border points
            [2.0, 3.0],
            [1.0, -3.0],
            [-4.0, 1.0],
            [1.0, 1.0],
        ]);

        // L1 distance, tolerance of 5.01 and a density threshold of 5 points
        let params = Dbscan::params_with(5, L1Dist, BallTree)
            .tolerance(5.01)
            .check()
            .unwrap();
        let labels = params.transform(&data);

        assert_outlier_then_single_cluster(&labels);
    }

    #[test]
    fn dataset_too_small() {
        // Three points can never reach a density threshold of four
        let data: Array2<f64> = Array2::zeros((3, 2));

        let params = Dbscan::params(4).check().unwrap();
        let labels = params.transform(&data);

        assert!(!labels.iter().any(|label| label.is_some()));
    }
}