smartcore/algorithm/neighbour/
fastpair.rs

1///
2/// ### FastPair: Data-structure for the dynamic closest-pair problem.
3///
4/// Reference:
5///  Eppstein, David: Fast hierarchical clustering and other applications of
6///  dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
7///
8/// Example:
9/// ```
10/// use smartcore::metrics::distance::PairwiseDistance;
11/// use smartcore::linalg::basic::matrix::DenseMatrix;
12/// use smartcore::algorithm::neighbour::fastpair::FastPair;
13/// let x = DenseMatrix::<f64>::from_2d_array(&[
14///     &[5.1, 3.5, 1.4, 0.2],
15///     &[4.9, 3.0, 1.4, 0.2],
16///     &[4.7, 3.2, 1.3, 0.2],
17///     &[4.6, 3.1, 1.5, 0.2],
18///     &[5.0, 3.6, 1.4, 0.2],
19///     &[5.4, 3.9, 1.7, 0.4],
20/// ]).unwrap();
21/// let fastpair = FastPair::new(&x);
22/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
23/// ```
24/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
25/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
26use std::collections::HashMap;
27
28use num::Bounded;
29
30use crate::error::{Failed, FailedError};
31use crate::linalg::basic::arrays::{Array1, Array2};
32use crate::metrics::distance::euclidian::Euclidian;
33use crate::metrics::distance::PairwiseDistance;
34use crate::numbers::floatnum::FloatNumber;
35use crate::numbers::realnum::RealNumber;
36
37///
38/// Inspired by Python implementation:
39/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
40/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
41///
42/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
43///
44#[derive(Debug, Clone)]
45pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
46    /// initial matrix
47    samples: &'a M,
48    /// closest pair hashmap (connectivity matrix for closest pairs)
49    pub distances: HashMap<usize, PairwiseDistance<T>>,
50    /// conga line used to keep track of the closest pair
51    pub neighbours: Vec<usize>,
52}
53
54impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
55    /// Constructor
56    /// Instantiate and initialize the algorithm
57    pub fn new(m: &'a M) -> Result<Self, Failed> {
58        if m.shape().0 < 3 {
59            return Err(Failed::because(
60                FailedError::FindFailed,
61                "min number of rows should be 3",
62            ));
63        }
64
65        let mut init = Self {
66            samples: m,
67            // to be computed in init(..)
68            distances: HashMap::with_capacity(m.shape().0),
69            neighbours: Vec::with_capacity(m.shape().0 + 1),
70        };
71        init.init();
72        Ok(init)
73    }
74
75    /// Initialise `FastPair` by passing a `Array2`.
76    /// Build a FastPairs data-structure from a set of (new) points.
77    fn init(&mut self) {
78        // basic measures
79        let len = self.samples.shape().0;
80        let max_index = self.samples.shape().0 - 1;
81
82        // Store all closest neighbors
83        let _distances = Box::new(HashMap::with_capacity(len));
84        let _neighbours = Box::new(Vec::with_capacity(len));
85
86        let mut distances = *_distances;
87        let mut neighbours = *_neighbours;
88
89        // fill neighbours with -1 values
90        neighbours.extend(0..len);
91
92        // init closest neighbour pairwise data
93        for index_row_i in 0..(max_index) {
94            distances.insert(
95                index_row_i,
96                PairwiseDistance {
97                    node: index_row_i,
98                    neighbour: Option::None,
99                    distance: Some(<T as Bounded>::max_value()),
100                },
101            );
102        }
103
104        // loop through indeces and neighbours
105        for index_row_i in 0..(len) {
106            // start looking for the neighbour in the second element
107            let mut index_closest = index_row_i + 1; // closest neighbour index
108            let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
109            for index_row_j in (index_row_i + 1)..len {
110                distances.insert(
111                    index_row_j,
112                    PairwiseDistance {
113                        node: index_row_j,
114                        neighbour: Some(index_row_i),
115                        distance: nbd,
116                    },
117                );
118
119                let d = Euclidian::squared_distance(
120                    &Vec::from_iterator(
121                        self.samples.get_row(index_row_i).iterator(0).copied(),
122                        self.samples.shape().1,
123                    ),
124                    &Vec::from_iterator(
125                        self.samples.get_row(index_row_j).iterator(0).copied(),
126                        self.samples.shape().1,
127                    ),
128                );
129                if d < nbd.unwrap().to_f64().unwrap() {
130                    // set this j-value to be the closest neighbour
131                    index_closest = index_row_j;
132                    nbd = Some(T::from(d).unwrap());
133                }
134            }
135
136            // Add that edge
137            distances.entry(index_row_i).and_modify(|e| {
138                e.distance = nbd;
139                e.neighbour = Some(index_closest);
140            });
141        }
142        // No more neighbors, terminate conga line.
143        // Last person on the line has no neigbors
144        distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
145        distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
146
147        // compute sparse matrix (connectivity matrix)
148        let mut sparse_matrix = M::zeros(len, len);
149        for (_, p) in distances.iter() {
150            sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
151        }
152
153        self.distances = distances;
154        self.neighbours = neighbours;
155    }
156
157    /// Find closest pair by scanning list of nearest neighbors.
158    #[allow(dead_code)]
159    pub fn closest_pair(&self) -> PairwiseDistance<T> {
160        let mut a = self.neighbours[0]; // Start with first point
161        let mut d = self.distances[&a].distance;
162        for p in self.neighbours.iter() {
163            if self.distances[p].distance < d {
164                a = *p; // Update `a` and distance `d`
165                d = self.distances[p].distance;
166            }
167        }
168        let b = self.distances[&a].neighbour;
169        PairwiseDistance {
170            node: a,
171            neighbour: b,
172            distance: d,
173        }
174    }
175
176    ///
177    /// Return order dissimilarities from closest to furthest
178    ///
179    #[allow(dead_code)]
180    pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
181        // improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
182        // need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
183        let mut distances = self
184            .distances
185            .values()
186            .collect::<Vec<&PairwiseDistance<T>>>();
187        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
188        distances.into_iter()
189    }
190
191    //
192    // Compute distances from input to all other points in data-structure.
193    // input is the row index of the sample matrix
194    //
195    #[allow(dead_code)]
196    fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
197        let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
198        for other in self.neighbours.iter() {
199            if index_row != *other {
200                distances.push(PairwiseDistance {
201                    node: index_row,
202                    neighbour: Some(*other),
203                    distance: Some(
204                        T::from(Euclidian::squared_distance(
205                            &Vec::from_iterator(
206                                self.samples.get_row(index_row).iterator(0).copied(),
207                                self.samples.shape().1,
208                            ),
209                            &Vec::from_iterator(
210                                self.samples.get_row(*other).iterator(0).copied(),
211                                self.samples.shape().1,
212                            ),
213                        ))
214                        .unwrap(),
215                    ),
216                })
217            }
218        }
219        distances
220    }
221}
222
#[cfg(test)]
mod tests_fastpair {

    use super::*;
    use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};

    /// The 15-row, 4-column fixture shared by several tests.
    fn sample_matrix() -> DenseMatrix<f64> {
        DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ])
        .unwrap()
    }

    /// Copies row `index` of `m` into an owned vector.
    fn row_of(m: &DenseMatrix<f64>, index: usize) -> Vec<f64> {
        Vec::from_iterator(m.get_row(index).iterator(0).copied(), m.shape().1)
    }

    /// Asserts that construction was rejected with the "min 3 rows" error.
    fn assert_too_few_rows(result: Result<FastPair<'_, f64, DenseMatrix<f64>>, Failed>) {
        match result {
            Ok(_) => panic!("FastPair::new should fail on fewer than 3 rows"),
            Err(e) => assert_eq!(
                e,
                Failed::because(FailedError::FindFailed, "min number of rows should be 3")
            ),
        }
    }

    /// Brute force algorithm, used only for comparison and testing
    pub fn closest_pair_brute(
        fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
    ) -> PairwiseDistance<f64> {
        let m = fastpair.samples.shape().0;

        let mut closest_pair = PairwiseDistance {
            node: 0,
            neighbour: Option::None,
            distance: Some(f64::MAX),
        };
        // visit every unordered pair (i, j), i < j, in lexicographic order
        for i in 0..m {
            let row_i = row_of(fastpair.samples, i);
            for j in (i + 1)..m {
                let d = Euclidian::squared_distance(&row_i, &row_of(fastpair.samples, j));
                if d < closest_pair.distance.unwrap() {
                    closest_pair.node = i;
                    closest_pair.neighbour = Some(j);
                    closest_pair.distance = Some(d);
                }
            }
        }
        closest_pair
    }

    #[test]
    fn fastpair_init() {
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let built = FastPair::new(&x);
        assert!(built.is_ok());

        let fastpair = built.unwrap();

        assert!(!fastpair.distances.is_empty());
        assert!(!fastpair.neighbours.is_empty());

        assert_eq!(10, fastpair.neighbours.len());
        assert_eq!(10, fastpair.distances.len());
    }

    #[test]
    fn dataset_has_at_least_three_points() {
        // A dataset which consists of only two points:
        // A(0.0, 0.0) and B(1.0, 1.0).
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();

        // `FastPair` currently only works on a minimum of 3 points,
        // so an error is expected here.
        assert_too_few_rows(FastPair::new(&dataset));
    }

    #[test]
    fn one_dimensional_dataset_minimal() {
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let expected_closest_pair = PairwiseDistance {
            node: 0,
            neighbour: Some(1),
            distance: Some(4.0),
        };
        assert_eq!(fastpair.closest_pair(), expected_closest_pair);
        assert_eq!(closest_pair_brute(&fastpair), expected_closest_pair);
    }

    #[test]
    fn one_dimensional_dataset_2() {
        let dataset =
            DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        let expected_closest_pair = PairwiseDistance {
            node: 1,
            neighbour: Some(3),
            distance: Some(4.0),
        };
        assert_eq!(closest_pair, closest_pair_brute(&fastpair));
        assert_eq!(closest_pair, expected_closest_pair);
    }

    #[test]
    fn fastpair_new() {
        let x = sample_matrix();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());
        let result = fastpair.unwrap();

        // minimal pairwise dissimilarities as (node, neighbour, squared distance)
        let expected: [(usize, usize, f64); 15] = [
            (0, 4, 0.01999999999999995),
            (1, 9, 0.030000000000000037),
            (2, 3, 0.0600000000000001),
            (3, 8, 0.08999999999999982),
            (4, 7, 0.050000000000000086),
            (5, 7, 0.4900000000000002),
            (6, 7, 0.18000000000000027),
            (7, 9, 0.10999999999999982),
            (8, 9, 0.3100000000000001),
            (9, 13, 8.69),
            (10, 12, 0.07000000000000003),
            (11, 14, 0.18000000000000013),
            (12, 14, 0.34000000000000086),
            (13, 14, 1.6499999999999997),
            // tail of the conga line: points at itself with the maximum distance
            (14, 14, f64::MAX),
        ];

        assert_eq!(expected.len(), result.distances.len());

        for &(node, neighbour, distance) in expected.iter() {
            // the structure itself must hold the expected edge
            let want = PairwiseDistance {
                node,
                neighbour: Some(neighbour),
                distance: Some(distance),
            };
            assert_eq!(result.distances[&node], want);

            // and the stored distance must be the actual squared distance
            if node != neighbour {
                let recomputed = Euclidian::squared_distance(
                    &row_of(result.samples, node),
                    &row_of(result.samples, neighbour),
                );
                assert_eq!(recomputed, distance);
            }
        }
    }

    #[test]
    fn fastpair_closest_pair() {
        let x = sample_matrix();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarity = fastpair.unwrap().closest_pair();
        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, dissimilarity);
    }

    #[test]
    fn fastpair_closest_pair_random_matrix() {
        let x = DenseMatrix::<f64>::rand(200, 25);
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let result = fastpair.unwrap();

        let dissimilarity1 = result.closest_pair();
        let dissimilarity2 = closest_pair_brute(&result);

        assert_eq!(dissimilarity1, dissimilarity2);
    }

    #[test]
    fn fastpair_distances() {
        let x = sample_matrix();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarities = fastpair.unwrap().distances_from(0);

        let mut min_dissimilarity = PairwiseDistance {
            node: 0,
            neighbour: Option::None,
            distance: Some(f64::MAX),
        };
        for p in dissimilarities.iter() {
            if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
                min_dissimilarity = *p
            }
        }

        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, min_dissimilarity);
    }

    #[test]
    fn fastpair_ordered_pairs() {
        // same points as the shared fixture, but in a different row order
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
        ])
        .unwrap();
        let fastpair = FastPair::new(&x).unwrap();

        let ordered: Vec<f64> = fastpair
            .ordered_pairs()
            .map(|p| p.distance.unwrap())
            .collect();

        // every distance must be at least as large as the one before it
        assert!(ordered.windows(2).all(|w| w[1] >= w[0]));
    }

    #[test]
    fn test_empty_set() {
        let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
        assert_too_few_rows(FastPair::new(&empty_matrix));
    }

    #[test]
    fn test_single_point() {
        let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
        assert_too_few_rows(FastPair::new(&single_point));
    }

    #[test]
    fn test_two_points() {
        let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
        assert_too_few_rows(FastPair::new(&two_points));
    }

    #[test]
    fn test_three_identical_points() {
        let identical_points =
            DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
        let result = FastPair::new(&identical_points);
        assert!(result.is_ok());
        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        assert_eq!(closest_pair.distance, Some(0.0));
    }

    #[test]
    fn test_result_unwrapping() {
        let valid_matrix =
            DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
                .unwrap();

        let result = FastPair::new(&valid_matrix);
        assert!(result.is_ok());

        // This should not panic
        let _fastpair = result.unwrap();
    }
}