radiate_core/domain/math/
centroid.rs1use crate::diversity::Distance;
2use crate::math::knn::KNN;
3use std::sync::Arc;
4
5pub struct CentroidClusterer<P> {
9 centroids: Vec<P>,
10 metric: Arc<dyn Distance<P>>,
11}
12
13impl<P> CentroidClusterer<P> {
14 pub fn new(metric: Arc<dyn Distance<P>>) -> Self {
15 CentroidClusterer {
16 centroids: Vec::new(),
17 metric,
18 }
19 }
20
21 pub fn with_centroids(mut self, centroids: Vec<P>) -> Self {
22 self.centroids = centroids;
23 self
24 }
25
26 pub fn centroids(&self) -> &[P] {
28 &self.centroids
29 }
30
31 pub fn len(&self) -> usize {
33 self.centroids.len()
34 }
35
36 pub fn is_empty(&self) -> bool {
37 self.centroids.is_empty()
38 }
39}
40
41impl<P: Clone> CentroidClusterer<P> {
42 pub fn assign_or_create(&mut self, point: &P, threshold: Option<f32>) -> (usize, f32) {
47 if self.centroids.is_empty() {
48 self.centroids.push(point.clone());
49 return (0, 0.0);
50 }
51
52 let mut knn = KNN::new(&self.centroids, Arc::clone(&self.metric));
53 let result = knn.query_point(point, 1);
54
55 if let Some(&(idx, dist)) = result.cluster.first() {
56 if let Some(threshold) = threshold {
57 if dist <= threshold {
58 (idx, dist)
59 } else {
60 let new_idx = self.centroids.len();
61 self.centroids.push(point.clone());
62 (new_idx, dist)
63 }
64 } else {
65 (idx, dist)
66 }
67 } else {
68 let new_idx = self.centroids.len();
69 self.centroids.push(point.clone());
70 (new_idx, 0.0)
71 }
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EuclideanDistance;
    use std::sync::Arc;

    /// Streams six 2-D points through `assign_or_create` with a threshold of
    /// 3.0 and checks which points end up sharing a centroid.
    #[test]
    fn test_centroid_clusterer() {
        let mut clusterer = CentroidClusterer::new(Arc::new(EuclideanDistance));

        let points = vec![
            vec![1.0, 2.0],
            vec![1.5, 1.8],
            vec![5.0, 8.0],
            vec![8.0, 8.0],
            vec![1.0, 0.6],
            vec![9.0, 11.0],
        ];
        let threshold = 3.0;

        // Feed the points in order; each call may add a centroid that later
        // points are then compared against.
        let assignments: Vec<(usize, f32)> = points
            .iter()
            .map(|point| clusterer.assign_or_create(point, Some(threshold)))
            .collect();

        // Three centroids are expected to have been created in total.
        assert_eq!(clusterer.len(), 3);

        // Points 0, 1 and 4 sit close together and share a centroid.
        assert_eq!(assignments[0].0, assignments[1].0);
        assert_eq!(assignments[0].0, assignments[4].0);

        // Point 2 is far from that group and gets its own centroid ...
        assert_ne!(assignments[0].0, assignments[2].0);

        // ... which point 3 then joins.
        assert_eq!(assignments[2].0, assignments[3].0);
    }
}