1use crate::dbscan::{DbscanParams, DbscanValidParams};
2use linfa_nn::{
3 distance::{Distance, L2Dist},
4 CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex,
5};
6use ndarray::{Array1, ArrayBase, Data, Ix2};
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11use linfa::Float;
12use linfa::{traits::Transformer, DatasetBase};
13
/// Marker type that serves as the entry point for DBSCAN clustering.
///
/// The struct carries no data of its own. Hyperparameters are assembled through
/// the associated functions [`Dbscan::params`] and [`Dbscan::params_with`],
/// both of which hand back a [`DbscanParams`] value.
///
/// When the `serde` feature is enabled the type additionally derives
/// `Serialize` and `Deserialize`.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct Dbscan;
82
83impl Dbscan {
84 pub fn params<F: Float>(min_points: usize) -> DbscanParams<F, L2Dist, CommonNearestNeighbour> {
91 Self::params_with(min_points, L2Dist, CommonNearestNeighbour::KdTree)
92 }
93
94 pub fn params_with<F: Float, D: Distance<F>, N: NearestNeighbour>(
97 min_points: usize,
98 dist_fn: D,
99 nn_algo: N,
100 ) -> DbscanParams<F, D, N> {
101 DbscanParams::new(min_points, dist_fn, nn_algo)
102 }
103}
104
105impl<F: Float, D: Data<Elem = F>, DF: Distance<F>, N: NearestNeighbour>
106 Transformer<&ArrayBase<D, Ix2>, Array1<Option<usize>>> for DbscanValidParams<F, DF, N>
107{
108 fn transform(&self, observations: &ArrayBase<D, Ix2>) -> Array1<Option<usize>> {
109 let mut cluster_memberships = Array1::from_elem(observations.nrows(), None);
110 let mut current_cluster_id = 0;
111 let mut search_found = vec![false; observations.nrows()];
113 let mut search_queue = VecDeque::with_capacity(observations.nrows());
114
115 let nn = match self.nn_algo.from_batch(observations, self.dist_fn.clone()) {
117 Ok(nn) => nn,
118 Err(linfa_nn::BuildError::ZeroDimension) => {
119 return Array1::from_elem(observations.nrows(), None)
120 }
121 Err(e) => panic!("Unexpected nearest neighbour error: {}", e),
122 };
123
124 for i in 0..observations.nrows() {
125 if cluster_memberships[i].is_some() {
126 continue;
127 }
128 let (neighbor_count, neighbors) =
129 self.find_neighbors(&*nn, i, observations, self.tolerance, &cluster_memberships);
130 if neighbor_count < self.min_points {
131 continue;
132 }
133 neighbors.iter().for_each(|&n| search_found[n] = true);
134 search_queue.extend(neighbors.into_iter());
135
136 cluster_memberships[i] = Some(current_cluster_id);
138
139 while let Some(candidate_idx) = search_queue.pop_front() {
140 search_found[candidate_idx] = false;
141
142 let (neighbor_count, neighbors) = self.find_neighbors(
143 &*nn,
144 candidate_idx,
145 observations,
146 self.tolerance,
147 &cluster_memberships,
148 );
149 cluster_memberships[candidate_idx] = Some(current_cluster_id);
151 if neighbor_count >= self.min_points {
152 for n in neighbors.into_iter() {
153 if !search_found[n] {
154 search_queue.push_back(n);
155 search_found[n] = true;
156 }
157 }
158 }
159 }
160 current_cluster_id += 1;
161 }
162 cluster_memberships
163 }
164}
165
166impl<F: Float, D: Data<Elem = F>, DF: Distance<F>, N: NearestNeighbour, T>
167 Transformer<
168 DatasetBase<ArrayBase<D, Ix2>, T>,
169 DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
170 > for DbscanValidParams<F, DF, N>
171{
172 fn transform(
173 &self,
174 dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
175 ) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
176 let predicted = self.transform(dataset.records());
177 dataset.with_targets(predicted)
178 }
179}
180
181impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
182 fn find_neighbors(
183 &self,
184 nn: &dyn NearestNeighbourIndex<F>,
185 idx: usize,
186 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
187 eps: F,
188 clusters: &Array1<Option<usize>>,
189 ) -> (usize, Vec<usize>) {
190 let candidate = observations.row(idx);
191 let mut res = Vec::with_capacity(self.min_points);
192 let mut count = 0;
193
194 for (_, i) in nn.within_range(candidate.view(), eps).unwrap().into_iter() {
197 count += 1;
198 if clusters[i].is_none() && i != idx {
199 res.push(i);
200 }
201 }
202 (count, res)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::ParamGuard;
    use linfa_nn::{distance::L1Dist, BallTree};
    use ndarray::{arr1, arr2, s, Array2};

    /// Rows 0..40 trace the four sides of a square spanning 0..=8 on both
    /// axes, and rows 40..50 all sit at (5, 5). The outline is expected to
    /// become cluster 0 and the inner group cluster 1.
    #[test]
    fn nested_clusters() {
        let mut points: Array2<f64> = Array2::zeros((50, 2));
        let ramp = Array1::linspace(0.0, 8.0, 10);

        // One entry per side of the square:
        // (first row, column that follows the ramp, constant for the other column).
        let sides = [(0usize, 0usize, 0.0), (10, 0, 8.0), (20, 1, 0.0), (30, 1, 8.0)];
        for &(first, ramp_axis, constant) in sides.iter() {
            points
                .column_mut(ramp_axis)
                .slice_mut(s![first..first + 10])
                .assign(&ramp);
            points
                .column_mut(1 - ramp_axis)
                .slice_mut(s![first..first + 10])
                .fill(constant);
        }

        // The nested group: ten coincident points inside the square.
        points.slice_mut(s![40.., ..]).fill(5.0);

        let labels = Dbscan::params(2)
            .tolerance(1.0)
            .check()
            .unwrap()
            .transform(&points);

        for (row, label) in labels.iter().enumerate() {
            let expected = if row < 40 { Some(0) } else { Some(1) };
            assert_eq!(*label, expected);
        }
    }

    /// Four rows at the origin and one row far away at (10, 10): the distant
    /// row is expected to stay unlabelled while the rest form cluster 0.
    #[test]
    fn non_cluster_points() {
        let mut points: Array2<f64> = Array2::zeros((5, 2));
        points.row_mut(0).fill(10.0);

        let labels = Dbscan::params(4).check().unwrap().transform(&points);

        assert_eq!(labels, arr1(&[None, Some(0), Some(0), Some(0), Some(0)]));
    }

    /// Row 0 is expected to stay unlabelled; every other row is expected to
    /// land in cluster 0.
    #[test]
    fn border_points() {
        let points: Array2<f64> = arr2(&[
            // Expected to be left out of the cluster
            [0.0, 2.0],
            // Centre of the cross formed by the remaining rows
            [0.0, 0.0],
            // Arms of the cross
            [0.0, 1.0],
            [0.0, -1.0],
            [-1.0, 0.0],
            [1.0, 0.0],
        ]);

        let labels = Dbscan::params(5)
            .tolerance(1.1)
            .check()
            .unwrap()
            .transform(&points);

        assert!(labels[0].is_none());
        assert!(labels.iter().skip(1).all(|label| *label == Some(0)));
    }

    /// Same expectation as `border_points`, but using the L1 distance and a
    /// ball tree index.
    #[test]
    fn l1_dist() {
        let points: Array2<f64> = arr2(&[
            // Expected to be left out of the cluster
            [0.0, 6.0],
            // Remaining rows are expected to share cluster 0
            [0.0, 0.0],
            [2.0, 3.0],
            [1.0, -3.0],
            [-4.0, 1.0],
            [1.0, 1.0],
        ]);

        let labels = Dbscan::params_with(5, L1Dist, BallTree)
            .tolerance(5.01)
            .check()
            .unwrap()
            .transform(&points);

        assert!(labels[0].is_none());
        assert!(labels.iter().skip(1).all(|label| *label == Some(0)));
    }

    /// With only three rows and `min_points` of 4, no row is expected to be
    /// labelled.
    #[test]
    fn dataset_too_small() {
        let points: Array2<f64> = Array2::zeros((3, 2));

        let labels = Dbscan::params(4).check().unwrap().transform(&points);

        assert!(!labels.iter().any(|label| label.is_some()));
    }
}