linfa_clustering/optics/
algorithm.rs

1use crate::optics::hyperparams::{OpticsParams, OpticsValidParams};
2use linfa::traits::Transformer;
3use linfa::Float;
4use linfa_nn::distance::{Distance, L2Dist};
5use linfa_nn::{CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex};
6use ndarray::{ArrayView, Ix1, Ix2};
7use noisy_float::{checkers::NumChecker, NoisyFloat};
8#[cfg(feature = "serde")]
9use serde_crate::{Deserialize, Serialize};
10use std::cmp::Ordering;
11use std::collections::BTreeSet;
12use std::ops::Index;
13use std::slice::SliceIndex;
14
/// OPTICS (Ordering Points To Identify the Clustering Structure) is a clustering algorithm that
/// doesn't explicitly cluster the data but instead creates an "augmented ordering" of the dataset
/// representing its density-based clustering structure. This ordering contains information which
/// is equivalent to the density-based clusterings and can then be used for automatic and
/// interactive cluster analysis.
///
/// OPTICS cluster analysis can be used to derive clusters equivalent to the output of other
/// clustering algorithms such as DBSCAN. However, due to its more complicated neighborhood
/// queries it typically has a higher computational cost than other more specific algorithms.
///
/// More details on the OPTICS algorithm can be found
/// [here](https://en.wikipedia.org/wiki/OPTICS_algorithm)
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Optics;
34
/// A data point of the dataset together with the distances assigned to it by the OPTICS
/// analysis.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Sample<F> {
    /// Index of the observation in the dataset
    index: usize,
    /// The core distance, `None` while undefined
    core_distance: Option<F>,
    /// The reachability distance, `None` while undefined
    reachability_distance: Option<F>,
}
51
52impl<F: Float> Sample<F> {
53    /// Create a new neighbor
54    fn new(index: usize) -> Self {
55        Self {
56            index,
57            core_distance: None,
58            reachability_distance: None,
59        }
60    }
61
62    /// Index of the sample in the dataset.
63    pub fn index(&self) -> usize {
64        self.index
65    }
66
67    /// The reachability distance of a sample is the distance between the point and it's cluster
68    /// core or another point whichever is larger.
69    pub fn reachability_distance(&self) -> &Option<F> {
70        &self.reachability_distance
71    }
72
73    /// The distance to the nth closest point where n is the minimum points to form a cluster.
74    pub fn core_distance(&self) -> &Option<F> {
75        &self.core_distance
76    }
77}
78
79impl<F: Float> Eq for Sample<F> {}
80
81impl<F: Float> PartialEq for Sample<F> {
82    fn eq(&self, other: &Self) -> bool {
83        self.reachability_distance == other.reachability_distance
84    }
85}
86
87#[allow(clippy::non_canonical_partial_ord_impl)]
88impl<F: Float> PartialOrd for Sample<F> {
89    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90        self.reachability_distance
91            .partial_cmp(&other.reachability_distance)
92    }
93}
94
95impl<F: Float> Ord for Sample<F> {
96    fn cmp(&self, other: &Self) -> Ordering {
97        self.reachability_distance
98            .map(NoisyFloat::<_, NumChecker>::new)
99            .cmp(
100                &other
101                    .reachability_distance
102                    .map(NoisyFloat::<_, NumChecker>::new),
103            )
104    }
105}
106
107/// The analysis from running OPTICS on a dataset, this allows you iterate over the data points and
108/// access their core and reachability distances. The ordering of the points also doesn't match
109/// that of the dataset instead ordering based on the clustering structure worked out during
110/// analysis.
111#[derive(Clone, Debug, PartialEq)]
112#[cfg_attr(
113    feature = "serde",
114    derive(Serialize, Deserialize),
115    serde(crate = "serde_crate")
116)]
117pub struct OpticsAnalysis<F: Float> {
118    /// A list of the samples in the dataset sorted and with their reachability and core distances
119    /// computed.
120    orderings: Vec<Sample<F>>,
121}
122
123impl<F: Float> OpticsAnalysis<F> {
124    /// Extracts a slice containing all samples in the dataset
125    pub fn as_slice(&self) -> &[Sample<F>] {
126        self.orderings.as_slice()
127    }
128
129    /// Returns an iterator over the samples in the dataset
130    pub fn iter(&self) -> std::slice::Iter<'_, Sample<F>> {
131        self.orderings.iter()
132    }
133}
134
135impl<I, F: Float> Index<I> for OpticsAnalysis<F>
136where
137    I: SliceIndex<[Sample<F>]>,
138{
139    type Output = I::Output;
140
141    fn index(&self, index: I) -> &Self::Output {
142        self.orderings.index(index)
143    }
144}
145
146impl Optics {
147    /// Configures the hyperparameters with the minimum number of points required to form a cluster
148    ///
149    /// Defaults are provided if the optional parameters are not specified:
150    /// * `tolerance = f64::MAX`
151    /// * `dist_fn = L2Dist` (Euclidean distance)
152    /// * `nn_algo = KdTree`
153    pub fn params<F: Float>(min_points: usize) -> OpticsParams<F, L2Dist, CommonNearestNeighbour> {
154        OpticsParams::new(min_points, L2Dist, CommonNearestNeighbour::KdTree)
155    }
156
157    /// Configures the hyperparameters with the minimum number of points, a custom distance metric,
158    /// and a custom nearest neighbour algorithm
159    pub fn params_with<F: Float, D: Distance<F>, N: NearestNeighbour>(
160        min_points: usize,
161        dist_fn: D,
162        nn_algo: N,
163    ) -> OpticsParams<F, D, N> {
164        OpticsParams::new(min_points, dist_fn, nn_algo)
165    }
166}
167
168impl<F: Float, D: Distance<F>, N: NearestNeighbour>
169    Transformer<ArrayView<'_, F, Ix2>, OpticsAnalysis<F>> for OpticsValidParams<F, D, N>
170{
171    fn transform(&self, observations: ArrayView<F, Ix2>) -> OpticsAnalysis<F> {
172        let mut result = OpticsAnalysis { orderings: vec![] };
173
174        let mut points = (0..observations.nrows())
175            .map(Sample::new)
176            .collect::<Vec<_>>();
177
178        let nn = match self
179            .nn_algo()
180            .from_batch(&observations, self.dist_fn().clone())
181        {
182            Ok(nn) => nn,
183            Err(linfa_nn::BuildError::ZeroDimension) => {
184                return OpticsAnalysis { orderings: points }
185            }
186            Err(e) => panic!("Unexpected nearest neighbour error: {}", e),
187        };
188
189        // The BTreeSet is used so that the indexes are ordered to make it easy to find next
190        // index
191        let mut processed = BTreeSet::new();
192        let mut index = 0;
193        let mut seeds = Vec::new();
194        loop {
195            if index == points.len() {
196                break;
197            } else if processed.contains(&index) {
198                index += 1;
199                continue;
200            }
201            let mut expected = if processed.is_empty() { 0 } else { index };
202            let mut points_index = index;
203            // Look for next point to process starting from lowest possible unprocessed index
204            for index in processed.range(index..) {
205                if expected != *index {
206                    points_index = expected;
207                    break;
208                }
209                expected += 1;
210            }
211            index += 1;
212            let neighbors = self.find_neighbors(&*nn, observations.row(points_index));
213            let n = &mut points[points_index];
214            self.set_core_distance(n, &neighbors, observations);
215            if n.core_distance.is_some() {
216                seeds.clear();
217                // Here we get a list of "density reachable" samples that haven't been processed
218                // and sort them by reachability so we can process the closest ones first.
219                self.get_seeds(
220                    observations,
221                    n.clone(),
222                    &neighbors,
223                    &mut points,
224                    &processed,
225                    &mut seeds,
226                );
227                while !seeds.is_empty() {
228                    seeds.sort_unstable_by(|a, b| b.cmp(a));
229                    let (i, min_point) = seeds
230                        .iter()
231                        .enumerate()
232                        .min_by(|(_, a), (_, b)| points[**a].cmp(&points[**b]))
233                        .unwrap();
234                    let n = &mut points[*min_point];
235                    seeds.remove(i);
236                    processed.insert(n.index);
237                    let neighbors = self.find_neighbors(&*nn, observations.row(n.index));
238
239                    self.set_core_distance(n, &neighbors, observations);
240                    result.orderings.push(n.clone());
241                    if n.core_distance.is_some() {
242                        self.get_seeds(
243                            observations,
244                            n.clone(),
245                            &neighbors,
246                            &mut points,
247                            &processed,
248                            &mut seeds,
249                        );
250                    }
251                }
252            } else {
253                // Ensure whole dataset is included so we can see the points with undefined core or
254                // reachability distance
255                result.orderings.push(n.clone());
256                processed.insert(n.index);
257            }
258        }
259        result
260    }
261}
262
263impl<F: Float, D: Distance<F>, N: NearestNeighbour> OpticsValidParams<F, D, N> {
264    /// Given a candidate point, a list of observations, epsilon and list of already
265    /// assigned cluster IDs return a list of observations that neighbor the candidate. This function
266    /// uses euclidean distance and the neighbours are returned in sorted order.
267    fn find_neighbors(
268        &self,
269        nn: &dyn NearestNeighbourIndex<F>,
270        candidate: ArrayView<F, Ix1>,
271    ) -> Vec<Sample<F>> {
272        // Unwrap here is fine because we don't expect any dimension mismatch when calling
273        // within_range with points from the observations
274        nn.within_range(candidate, self.tolerance())
275            .unwrap()
276            .into_iter()
277            .map(|(pt, index)| Sample {
278                index,
279                reachability_distance: Some(self.dist_fn().distance(pt, candidate)),
280                core_distance: None,
281            })
282            .collect()
283    }
284
285    /// Set the core distance given the minimum points in a cluster and the points neighbors
286    fn set_core_distance(
287        &self,
288        point: &mut Sample<F>,
289        neighbors: &[Sample<F>],
290        dataset: ArrayView<F, Ix2>,
291    ) {
292        let observation = dataset.row(point.index);
293        point.core_distance = neighbors
294            .get(self.minimum_points() - 1)
295            .map(|x| dataset.row(x.index))
296            .map(|x| self.dist_fn().distance(observation, x));
297    }
298
299    /// For a sample find the points which are directly density reachable which have not
300    /// yet been processed
301    fn get_seeds(
302        &self,
303        observations: ArrayView<F, Ix2>,
304        sample: Sample<F>,
305        neighbors: &[Sample<F>],
306        points: &mut [Sample<F>],
307        processed: &BTreeSet<usize>,
308        seeds: &mut Vec<usize>,
309    ) {
310        for n in neighbors.iter().filter(|x| !processed.contains(&x.index)) {
311            let dist = self
312                .dist_fn()
313                .distance(observations.row(n.index), observations.row(sample.index));
314            let r_dist = F::max(sample.core_distance.unwrap(), dist);
315            match points[n.index].reachability_distance {
316                None => {
317                    points[n.index].reachability_distance = Some(r_dist);
318                    seeds.push(n.index);
319                }
320                Some(s) if r_dist < s => points[n.index].reachability_distance = Some(r_dist),
321                _ => {}
322            }
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpticsError;
    use linfa::ParamGuard;
    use linfa_nn::KdTree;
    use ndarray::Array2;
    use std::collections::BTreeSet;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<OpticsAnalysis<f64>>();
        has_autotraits::<Optics>();
        has_autotraits::<Sample<f64>>();
        has_autotraits::<OpticsError>();
        has_autotraits::<OpticsParams<f64, L2Dist, KdTree>>();
        has_autotraits::<OpticsValidParams<f64, L2Dist, KdTree>>();
    }

    #[test]
    fn optics_consistency() {
        let params = Optics::params(3);
        let data = vec![1.0, 2.0, 3.0, 8.0, 8.0, 7.0, 2.0, 5.0, 6.0, 7.0, 8.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let samples = params.transform(data.view()).unwrap();

        // Make sure whole dataset is present:
        let indexes = samples
            .orderings
            .iter()
            .map(|x| x.index)
            .collect::<BTreeSet<_>>();
        assert!((0..data.len()).all(|x| indexes.contains(&x)));

        // As we haven't set a tolerance every point should have a core distance
        assert!(samples.orderings.iter().all(|x| x.core_distance.is_some()));
    }

    #[test]
    fn simple_dataset() {
        let params = Optics::params(3).tolerance(4.0);
        //               0    1   2    3     4     5     6     7     8    9     10    11     12
        let data = vec![
            1.0, 2.0, 3.0, 10.0, 18.0, 18.0, 15.0, 2.0, 15.0, 18.0, 3.0, 100.0, 101.0,
        ];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        // indexes of groupings of points in the dataset. These will end up with an outlier value
        // in between them to help separate things
        let first_grouping = [0, 1, 2, 7, 10].iter().collect::<BTreeSet<_>>();
        let second_grouping = [4, 5, 6, 8, 9].iter().collect::<BTreeSet<_>>();

        let samples = params.transform(data.view()).unwrap();

        let indexes = samples
            .orderings
            .iter()
            .map(|x| x.index)
            .collect::<BTreeSet<_>>();
        assert!((0..data.len()).all(|x| indexes.contains(&x)));

        assert!(samples
            .orderings
            .iter()
            .take(first_grouping.len())
            .all(|x| first_grouping.contains(&x.index)));
        let skip_len = first_grouping.len() + 1;
        assert!(samples
            .orderings
            .iter()
            .skip(skip_len)
            .take(second_grouping.len())
            .all(|x| second_grouping.contains(&x.index)));

        let anomaly = samples.orderings.iter().find(|x| x.index == 3).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());

        let anomaly = samples.orderings.iter().find(|x| x.index == 11).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());

        let anomaly = samples.orderings.iter().find(|x| x.index == 12).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());
    }

    #[test]
    fn dataset_too_small() {
        let params = Optics::params(4);
        let data = vec![1.0, 2.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let samples = params.transform(data.view()).unwrap();

        assert!(samples
            .orderings
            .iter()
            .all(|x| x.core_distance.is_none() && x.reachability_distance.is_none()));
    }

    #[test]
    fn invalid_params() {
        let params = Optics::params(1);
        let data = vec![1.0, 2.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();
        assert!(params.transform(data.view()).is_err());

        let params = Optics::params(2);
        assert!(params.transform(data.view()).is_ok());

        let params = params.tolerance(0.0);
        assert!(params.transform(data.view()).is_err());
    }

    #[test]
    fn find_neighbors_test() {
        let data = vec![1.0, 2.0, 10.0, 15.0, 13.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let param = Optics::params(3).tolerance(6.0).check_unwrap();
        let nn = CommonNearestNeighbour::KdTree
            .from_batch(&data, L2Dist)
            .unwrap();

        let neighbors = param.find_neighbors(&*nn, data.row(0));
        assert_eq!(neighbors.len(), 2);
        assert_eq!(
            vec![0, 1],
            neighbors
                .iter()
                .map(|x| x.reachability_distance.unwrap() as u32)
                .collect::<Vec<u32>>()
        );
        assert!(neighbors.iter().all(|x| x.core_distance.is_none()));

        let neighbors = param.find_neighbors(&*nn, data.row(4));
        assert_eq!(neighbors.len(), 3);
        assert!(neighbors.iter().all(|x| x.core_distance.is_none()));
        assert_eq!(
            vec![0, 2, 3],
            neighbors
                .iter()
                .map(|x| x.reachability_distance.unwrap() as u32)
                .collect::<Vec<u32>>()
        );
    }

    #[test]
    fn get_seeds_test() {
        let data = vec![1.0, 2.0, 10.0, 15.0, 13.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let param = Optics::params(3).tolerance(6.0).check_unwrap();
        let nn = CommonNearestNeighbour::KdTree
            .from_batch(&data, L2Dist)
            .unwrap();

        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        let neighbors = param.find_neighbors(&*nn, data.row(0));
        // Set the core distance and make sure it's set correctly given the minimum number of
        // neighbours restriction

        param.set_core_distance(&mut points[0], &neighbors, data.view());
        assert!(points[0].core_distance.is_none());

        let neighbors = param.find_neighbors(&*nn, data.row(4));
        param.set_core_distance(&mut points[4], &neighbors, data.view());
        assert!(points[4].core_distance.is_some());

        let mut seeds = vec![];
        let mut processed = BTreeSet::new();
        // With a valid core distance make sure neighbours to point are returned in order if
        // unprocessed

        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 3, 2]);

        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        // if one of the neighbours has been processed make sure it's not in the seed list

        param.set_core_distance(&mut points[4], &neighbors, data.view());
        processed.insert(3);
        seeds.clear();

        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 2]);

        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        // If one of the neighbours has a smaller R distance than it has to the core point make
        // sure it's not added to the seed list

        processed.clear();
        param.set_core_distance(&mut points[4], &neighbors, data.view());
        points[2].reachability_distance = Some(0.001);
        seeds.clear();

        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 3]);
    }
}