aprender/cluster/
agglomerative.rs1use crate::error::Result;
7use crate::primitives::Matrix;
8use crate::traits::UnsupervisedEstimator;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum Linkage {
14 Single,
16 Complete,
18 Average,
20 Ward,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Merge {
27 pub clusters: (usize, usize),
29 pub distance: f32,
31 pub size: usize,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct AgglomerativeClustering {
72 n_clusters: usize,
74 linkage: Linkage,
76 labels: Option<Vec<usize>>,
78 dendrogram: Option<Vec<Merge>>,
80}
81
82impl AgglomerativeClustering {
83 #[must_use]
85 pub fn new(n_clusters: usize, linkage: Linkage) -> Self {
86 Self {
87 n_clusters,
88 linkage,
89 labels: None,
90 dendrogram: None,
91 }
92 }
93
94 #[must_use]
96 pub fn n_clusters(&self) -> usize {
97 self.n_clusters
98 }
99
100 #[must_use]
102 pub fn linkage(&self) -> Linkage {
103 self.linkage
104 }
105
106 #[must_use]
108 pub fn is_fitted(&self) -> bool {
109 self.labels.is_some()
110 }
111
112 #[must_use]
114 pub fn labels(&self) -> &Vec<usize> {
115 self.labels
116 .as_ref()
117 .expect("Model not fitted. Call fit() first.")
118 }
119
120 #[must_use]
122 pub fn dendrogram(&self) -> &Vec<Merge> {
123 self.dendrogram
124 .as_ref()
125 .expect("Model not fitted. Call fit() first.")
126 }
127
128 #[allow(clippy::unused_self)]
130 fn euclidean_distance(&self, x: &Matrix<f32>, i: usize, j: usize) -> f32 {
131 let n_features = x.shape().1;
132 let row_i: Vec<f32> = (0..n_features).map(|k| x.get(i, k)).collect();
133 let row_j: Vec<f32> = (0..n_features).map(|k| x.get(j, k)).collect();
134 crate::nn::functional::euclidean_distance(&row_i, &row_j)
135 }
136
137 #[allow(clippy::needless_range_loop)]
139 fn pairwise_distances(&self, x: &Matrix<f32>) -> Vec<Vec<f32>> {
140 let n_samples = x.shape().0;
141 let mut distances = vec![vec![0.0; n_samples]; n_samples];
142 for i in 0..n_samples {
143 for j in (i + 1)..n_samples {
144 let dist = self.euclidean_distance(x, i, j);
145 distances[i][j] = dist;
146 distances[j][i] = dist;
147 }
148 }
149 distances
150 }
151
152 #[allow(clippy::unused_self)]
154 fn find_closest_clusters(
155 &self,
156 distances: &[Vec<f32>],
157 active: &[bool],
158 ) -> (usize, usize, f32) {
159 let n = distances.len();
160 let mut min_dist = f32::INFINITY;
161 let mut min_i = 0;
162 let mut min_j = 1;
163
164 for i in 0..n {
165 if !active[i] {
166 continue;
167 }
168 for j in (i + 1)..n {
169 if !active[j] {
170 continue;
171 }
172 if distances[i][j] < min_dist {
173 min_dist = distances[i][j];
174 min_i = i;
175 min_j = j;
176 }
177 }
178 }
179
180 (min_i, min_j, min_dist)
181 }
182
183 fn pairwise_cluster_distances(
185 &self,
186 x: &Matrix<f32>,
187 cluster_a: &[usize],
188 cluster_b: &[usize],
189 ) -> Vec<f32> {
190 let mut dists = Vec::with_capacity(cluster_a.len() * cluster_b.len());
191 for &i in cluster_a {
192 for &j in cluster_b {
193 dists.push(self.euclidean_distance(x, i, j));
194 }
195 }
196 dists
197 }
198
199 fn update_distances(
201 &self,
202 x: &Matrix<f32>,
203 distances: &mut [Vec<f32>],
204 clusters: &[Vec<usize>],
205 merged_idx: usize,
206 other_idx: usize,
207 ) {
208 let merged_cluster = &clusters[merged_idx];
209 let other_cluster = &clusters[other_idx];
210
211 let dist = match self.linkage {
212 Linkage::Single => {
213 let dists = self.pairwise_cluster_distances(x, merged_cluster, other_cluster);
214 dists.into_iter().fold(f32::INFINITY, f32::min)
215 }
216 Linkage::Complete => {
217 let dists = self.pairwise_cluster_distances(x, merged_cluster, other_cluster);
218 dists.into_iter().fold(0.0_f32, f32::max)
219 }
220 Linkage::Average => {
221 let dists = self.pairwise_cluster_distances(x, merged_cluster, other_cluster);
222 if dists.is_empty() {
223 0.0
224 } else {
225 dists.iter().sum::<f32>() / dists.len() as f32
226 }
227 }
228 Linkage::Ward => {
229 let merged_centroid = self.compute_centroid(x, merged_cluster);
230 let other_centroid = self.compute_centroid(x, other_cluster);
231 let centroid_dist =
232 crate::nn::functional::euclidean_distance(&merged_centroid, &other_centroid);
233 let n1 = merged_cluster.len() as f32;
234 let n2 = other_cluster.len() as f32;
235 ((n1 * n2) / (n1 + n2)) * centroid_dist
236 }
237 };
238
239 distances[merged_idx][other_idx] = dist;
240 distances[other_idx][merged_idx] = dist;
241 }
242
243 #[allow(clippy::needless_range_loop)]
245 #[allow(clippy::unused_self)]
246 fn compute_centroid(&self, x: &Matrix<f32>, cluster: &[usize]) -> Vec<f32> {
247 let n_features = x.shape().1;
248 let mut centroid = vec![0.0; n_features];
249
250 for &idx in cluster {
251 for k in 0..n_features {
252 centroid[k] += x.get(idx, k);
253 }
254 }
255
256 let size = cluster.len() as f32;
257 for val in &mut centroid {
258 *val /= size;
259 }
260
261 centroid
262 }
263}
264
265impl UnsupervisedEstimator for AgglomerativeClustering {
266 type Labels = Vec<usize>;
267
268 fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
269 let n_samples = x.shape().0;
270
271 let mut clusters: Vec<Vec<usize>> = (0..n_samples).map(|i| vec![i]).collect();
273 let mut active = vec![true; n_samples];
274 let mut cluster_labels = vec![0; n_samples];
275 let mut dendrogram = Vec::new();
276
277 let mut distances = self.pairwise_distances(x);
279
280 while clusters.iter().filter(|c| !c.is_empty()).count() > self.n_clusters {
282 let (i, j, dist) = self.find_closest_clusters(&distances, &active);
284
285 let merged_cluster = clusters[j].clone();
287 clusters[i].extend(&merged_cluster);
288 clusters[j].clear();
289 active[j] = false;
290
291 dendrogram.push(Merge {
293 clusters: (i, j),
294 distance: dist,
295 size: clusters[i].len(),
296 });
297
298 #[allow(clippy::needless_range_loop)]
300 for k in 0..n_samples {
301 if k == i || !active[k] {
302 continue;
303 }
304 self.update_distances(x, &mut distances, &clusters, i, k);
305 }
306 }
307
308 let mut cluster_id = 0;
310 for cluster in &clusters {
311 if !cluster.is_empty() {
312 for &point_idx in cluster {
313 cluster_labels[point_idx] = cluster_id;
314 }
315 cluster_id += 1;
316 }
317 }
318
319 self.labels = Some(cluster_labels);
320 self.dendrogram = Some(dendrogram);
321 Ok(())
322 }
323
324 fn predict(&self, _x: &Matrix<f32>) -> Self::Labels {
325 self.labels().clone()
328 }
329}