1use num_traits::Float;
2use std::cmp::Ordering;
3use std::collections::HashSet;
4
/// Strategy for summarising each cluster by a single representative point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Center {
    /// Arithmetic mean of the cluster's points, every point weighted equally.
    Centroid,
    /// Mean position on the sphere for points given as `[latitude, longitude]` in degrees:
    /// points are averaged as unit 3-D vectors and the mean vector is converted back to
    /// `[latitude, longitude]` in degrees.
    GeoCentroid,
    /// The point, chosen from the cluster's own members, whose summed distance to all
    /// members of its cluster is smallest under the caller-supplied distance function.
    Medoid,
}
20
21impl Center {
22 pub(crate) fn calc_centers<T: Float, F: Fn(&[T], &[T]) -> T>(
23 &self,
24 data: &[Vec<T>],
25 labels: &[i32],
26 dist_func: F,
27 ) -> Vec<Vec<T>> {
28 match self {
29 Center::Centroid => self.calc_centroids(data, labels),
30 Center::GeoCentroid => self.calc_geo_centroids(data, labels),
31 Center::Medoid => self.calc_medoids(data, labels, dist_func),
32 }
33 }
34
35 fn calc_centroids<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
36 let weights = vec![T::one(); data.len()];
38 Center::calc_weighted_centroids(data, labels, &weights)
39 }
40
41 fn calc_weighted_centroids<T: Float>(
42 data: &[Vec<T>],
43 labels: &[i32],
44 weights: &[T],
45 ) -> Vec<Vec<T>> {
46 let n_dims = data[0].len();
47 let n_clusters = labels
48 .iter()
49 .filter(|&&label| label != -1)
50 .collect::<HashSet<_>>()
51 .len();
52
53 let mut centroids = Vec::with_capacity(n_clusters);
54 for cluster_id in 0..n_clusters as i32 {
55 let mut count = T::zero();
56 let mut element_wise_mean = vec![T::zero(); n_dims];
57 for n in 0..data.len() {
58 if cluster_id == labels[n] {
59 count = count + T::one();
60 element_wise_mean = data[n]
61 .iter()
62 .zip(element_wise_mean.iter())
63 .map(|(&element, &sum)| (element * weights[n]) + sum)
64 .collect();
65 }
66 }
67 for element in element_wise_mean.iter_mut() {
68 *element = *element / count;
69 }
70 centroids.push(element_wise_mean);
71 }
72 centroids
73 }
74
75 fn calc_geo_centroids<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
96 let n_clusters = labels
97 .iter()
98 .filter(|&&label| label != -1)
99 .collect::<HashSet<_>>()
100 .len();
101 let mut centers = vec![vec![T::zero(), T::zero(), T::zero()]; n_clusters];
102 let mut counts = vec![T::zero(); n_clusters];
103
104 for (point, &label) in data.iter().zip(labels.iter()) {
105 if label != -1 {
106 let cluster_index = label as usize;
107
108 let lat = point[0].to_radians();
109 let lon = point[1].to_radians();
110
111 let x = lon.cos() * lat.cos();
112 let y = lon.sin() * lat.cos();
113 let z = lat.sin();
114
115 centers[cluster_index][0] = centers[cluster_index][0] + x;
116 centers[cluster_index][1] = centers[cluster_index][1] + y;
117 centers[cluster_index][2] = centers[cluster_index][2] + z;
118 counts[cluster_index] = counts[cluster_index] + T::one();
119 }
120 }
121
122 for (center, &count) in centers.iter_mut().zip(counts.iter()) {
123 if count > T::zero() {
124 let x = center[0] / count;
125 let y = center[1] / count;
126 let z = center[2] / count;
127
128 let lon = y.atan2(x);
129 let hyp = (x * x + y * y).sqrt();
130 let lat = z.atan2(hyp);
131
132 center[0] = lat.to_degrees();
134 center[1] = lon.to_degrees();
135 }
136 }
137
138 centers.iter().map(|c| vec![c[0], c[1]]).collect()
139 }
140
141 fn calc_medoids<T: Float, F: Fn(&[T], &[T]) -> T>(
142 &self,
143 data: &[Vec<T>],
144 labels: &[i32],
145 dist_func: F,
146 ) -> Vec<Vec<T>> {
147 let n_clusters = labels
148 .iter()
149 .filter(|&&label| label != -1)
150 .collect::<HashSet<_>>()
151 .len();
152 let mut medoids = Vec::with_capacity(n_clusters);
153
154 for cluster_id in 0..n_clusters as i32 {
155 let cluster_data = data
156 .iter()
157 .zip(labels.iter())
158 .filter(|(_datapoint, &label)| label == cluster_id)
159 .map(|(datapoint, _label)| datapoint)
160 .collect::<Vec<&Vec<_>>>();
161
162 let n_samples = cluster_data.len();
163 let medoid_idx = (0..n_samples)
164 .map(|i| {
165 (0..n_samples)
166 .map(|j| dist_func(cluster_data[i], cluster_data[j]))
167 .fold(T::zero(), std::ops::Add::add)
168 })
169 .enumerate()
170 .min_by(|(_idx_a, sum_a), (_idx_b, sum_b)| {
171 sum_a.partial_cmp(sum_b).unwrap_or(Ordering::Equal)
172 })
173 .map(|(idx, _sum)| idx)
174 .unwrap_or(0);
175
176 medoids.push(cluster_data[medoid_idx].clone())
177 }
178
179 medoids
180 }
181}