Skip to main content

radiate_core/domain/math/
centroid.rs

1use crate::diversity::Distance;
2use crate::math::knn::KNN;
3use std::sync::Arc;
4
5/// Simple online clustering over a set of centroids.
6///
7/// - `P`: point type (e.g. Vec<f32>, Genotype<C>, etc.)
8pub struct CentroidClusterer<P> {
9    centroids: Vec<P>,
10    metric: Arc<dyn Distance<P>>,
11}
12
13impl<P> CentroidClusterer<P> {
14    pub fn new(metric: Arc<dyn Distance<P>>) -> Self {
15        CentroidClusterer {
16            centroids: Vec::new(),
17            metric,
18        }
19    }
20
21    pub fn with_centroids(mut self, centroids: Vec<P>) -> Self {
22        self.centroids = centroids;
23        self
24    }
25
26    /// Returns a slice of current centroids.
27    pub fn centroids(&self) -> &[P] {
28        &self.centroids
29    }
30
31    /// Returns the number of centroids.
32    pub fn len(&self) -> usize {
33        self.centroids.len()
34    }
35
36    pub fn is_empty(&self) -> bool {
37        self.centroids.is_empty()
38    }
39}
40
41impl<P: Clone> CentroidClusterer<P> {
42    /// Assign a point to the nearest centroid if within `threshold`.
43    /// Otherwise, create a new centroid with this point as its center.
44    ///
45    /// Returns the index of the assigned/created centroid and the distance.
46    pub fn assign_or_create(&mut self, point: &P, threshold: Option<f32>) -> (usize, f32) {
47        if self.centroids.is_empty() {
48            self.centroids.push(point.clone());
49            return (0, 0.0);
50        }
51
52        let mut knn = KNN::new(&self.centroids, Arc::clone(&self.metric));
53        let result = knn.query_point(point, 1);
54
55        if let Some(&(idx, dist)) = result.cluster.first() {
56            if let Some(threshold) = threshold {
57                if dist <= threshold {
58                    (idx, dist)
59                } else {
60                    let new_idx = self.centroids.len();
61                    self.centroids.push(point.clone());
62                    (new_idx, dist)
63                }
64            } else {
65                (idx, dist)
66            }
67        } else {
68            let new_idx = self.centroids.len();
69            self.centroids.push(point.clone());
70            (new_idx, 0.0)
71        }
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EuclideanDistance;
    use std::sync::Arc;

    /// Feeds six 2-D points through the clusterer with a distance threshold of
    /// 3.0 and checks which points end up sharing a centroid.
    #[test]
    fn test_centroid_clusterer() {
        let mut clusterer = CentroidClusterer::new(Arc::new(EuclideanDistance));

        let points = vec![
            vec![1.0, 2.0],
            vec![1.5, 1.8],
            vec![5.0, 8.0],
            vec![8.0, 8.0],
            vec![1.0, 0.6],
            vec![9.0, 11.0],
        ];
        let threshold = 3.0;

        // Only the centroid index matters for the assertions below.
        let cluster_ids: Vec<usize> = points
            .iter()
            .map(|point| clusterer.assign_or_create(point, Some(threshold)).0)
            .collect();

        assert_eq!(clusterer.len(), 3);
        // Points 1 and 2 share a cluster.
        assert_eq!(cluster_ids[0], cluster_ids[1]);
        // Points 1 and 5 share a cluster.
        assert_eq!(cluster_ids[0], cluster_ids[4]);
        // Point 3 starts a different cluster from point 1.
        assert_ne!(cluster_ids[0], cluster_ids[2]);
        // Points 3 and 4 share a cluster.
        assert_eq!(cluster_ids[2], cluster_ids[3]);
    }
}