hdbscan/
centers.rs

1use num_traits::Float;
2use std::cmp::Ordering;
3use std::collections::HashSet;
4
/// Possible methodologies for calculating the center of clusters.
///
/// The enum carries no data, so it is cheap to copy and can be compared for equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Center {
    /// The elementwise mean of all data points in a cluster.
    /// The output is not guaranteed to be an observed data point.
    Centroid,
    /// The geographical centroid for lat/lon coordinates.
    /// Assumes input coordinates are in degrees, ordered (latitude, longitude).
    /// Output coordinates are also in degrees.
    GeoCentroid,
    /// The point in a cluster with the minimum total distance to all other points in that
    /// cluster. Computationally more expensive than centroids as it requires calculation of
    /// pairwise distances (using the selected distance metric). The output will be an observed
    /// data point in the cluster.
    Medoid,
}
20
21impl Center {
22    pub(crate) fn calc_centers<T: Float, F: Fn(&[T], &[T]) -> T>(
23        &self,
24        data: &[Vec<T>],
25        labels: &[i32],
26        dist_func: F,
27    ) -> Vec<Vec<T>> {
28        match self {
29            Center::Centroid => self.calc_centroids(data, labels),
30            Center::GeoCentroid => self.calc_geo_centroids(data, labels),
31            Center::Medoid => self.calc_medoids(data, labels, dist_func),
32        }
33    }
34
35    fn calc_centroids<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
36        // All points weighted equally for now
37        let weights = vec![T::one(); data.len()];
38        Center::calc_weighted_centroids(data, labels, &weights)
39    }
40
41    fn calc_weighted_centroids<T: Float>(
42        data: &[Vec<T>],
43        labels: &[i32],
44        weights: &[T],
45    ) -> Vec<Vec<T>> {
46        let n_dims = data[0].len();
47        let n_clusters = labels
48            .iter()
49            .filter(|&&label| label != -1)
50            .collect::<HashSet<_>>()
51            .len();
52
53        let mut centroids = Vec::with_capacity(n_clusters);
54        for cluster_id in 0..n_clusters as i32 {
55            let mut count = T::zero();
56            let mut element_wise_mean = vec![T::zero(); n_dims];
57            for n in 0..data.len() {
58                if cluster_id == labels[n] {
59                    count = count + T::one();
60                    element_wise_mean = data[n]
61                        .iter()
62                        .zip(element_wise_mean.iter())
63                        .map(|(&element, &sum)| (element * weights[n]) + sum)
64                        .collect();
65                }
66            }
67            for element in element_wise_mean.iter_mut() {
68                *element = *element / count;
69            }
70            centroids.push(element_wise_mean);
71        }
72        centroids
73    }
74
75    /// Calculates the geographical centeroid for each cluster.
76    ///
77    /// This method is specifically designed for geographical data where each point
78    /// is represented by latitude and longitude coordinates.
79    ///
80    /// # Arguments
81    ///
82    /// * `data` - A slice of vectors, where each vector contains [latitude, longitude] in degrees.
83    /// * `labels` - A slice of cluster labels corresponding to each data point.
84    ///
85    /// # Returns
86    ///
87    /// A vector of cluster centers, where each center is a vector of [latitude, longitude] in degrees.
88    ///
89    /// # Notes
90    ///
91    /// - Assumes input coordinates are in degrees.
92    /// - Output coordinates are in degrees.
93    /// - Points with label -1 are considered noise and are ignored in calculations.
94    /// - Uses a spherical approximation of the Earth for calculations.
95    fn calc_geo_centroids<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
96        let n_clusters = labels
97            .iter()
98            .filter(|&&label| label != -1)
99            .collect::<HashSet<_>>()
100            .len();
101        let mut centers = vec![vec![T::zero(), T::zero(), T::zero()]; n_clusters];
102        let mut counts = vec![T::zero(); n_clusters];
103
104        for (point, &label) in data.iter().zip(labels.iter()) {
105            if label != -1 {
106                let cluster_index = label as usize;
107
108                let lat = point[0].to_radians();
109                let lon = point[1].to_radians();
110
111                let x = lon.cos() * lat.cos();
112                let y = lon.sin() * lat.cos();
113                let z = lat.sin();
114
115                centers[cluster_index][0] = centers[cluster_index][0] + x;
116                centers[cluster_index][1] = centers[cluster_index][1] + y;
117                centers[cluster_index][2] = centers[cluster_index][2] + z;
118                counts[cluster_index] = counts[cluster_index] + T::one();
119            }
120        }
121
122        for (center, &count) in centers.iter_mut().zip(counts.iter()) {
123            if count > T::zero() {
124                let x = center[0] / count;
125                let y = center[1] / count;
126                let z = center[2] / count;
127
128                let lon = y.atan2(x);
129                let hyp = (x * x + y * y).sqrt();
130                let lat = z.atan2(hyp);
131
132                // Convert back to degrees
133                center[0] = lat.to_degrees();
134                center[1] = lon.to_degrees();
135            }
136        }
137
138        centers.iter().map(|c| vec![c[0], c[1]]).collect()
139    }
140
141    fn calc_medoids<T: Float, F: Fn(&[T], &[T]) -> T>(
142        &self,
143        data: &[Vec<T>],
144        labels: &[i32],
145        dist_func: F,
146    ) -> Vec<Vec<T>> {
147        let n_clusters = labels
148            .iter()
149            .filter(|&&label| label != -1)
150            .collect::<HashSet<_>>()
151            .len();
152        let mut medoids = Vec::with_capacity(n_clusters);
153
154        for cluster_id in 0..n_clusters as i32 {
155            let cluster_data = data
156                .iter()
157                .zip(labels.iter())
158                .filter(|(_datapoint, &label)| label == cluster_id)
159                .map(|(datapoint, _label)| datapoint)
160                .collect::<Vec<&Vec<_>>>();
161
162            let n_samples = cluster_data.len();
163            let medoid_idx = (0..n_samples)
164                .map(|i| {
165                    (0..n_samples)
166                        .map(|j| dist_func(cluster_data[i], cluster_data[j]))
167                        .fold(T::zero(), std::ops::Add::add)
168                })
169                .enumerate()
170                .min_by(|(_idx_a, sum_a), (_idx_b, sum_b)| {
171                    sum_a.partial_cmp(sum_b).unwrap_or(Ordering::Equal)
172                })
173                .map(|(idx, _sum)| idx)
174                .unwrap_or(0);
175
176            medoids.push(cluster_data[medoid_idx].clone())
177        }
178
179        medoids
180    }
181}