Skip to main content

scry_learn/cluster/
agglomerative.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Agglomerative (hierarchical) clustering.
//!
//! Bottom-up clustering that starts with each sample as its own cluster
//! and repeatedly merges the closest pair of clusters until `n_clusters` remain.
//!
//! # Example
//!
//! ```
//! use scry_learn::cluster::AgglomerativeClustering;
//! use scry_learn::dataset::Dataset;
//!
//! let data = Dataset::new(
//!     vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
//!     vec![0.0; 4],
//!     vec!["x".into(), "y".into()],
//!     "label",
//! );
//!
//! let mut model = AgglomerativeClustering::new(2);
//! model.fit(&data).unwrap();
//! assert_eq!(model.labels().len(), 4);
//! ```
24
25use crate::dataset::Dataset;
26use crate::distance::euclidean_sq;
27use crate::error::{Result, ScryLearnError};
28use std::cmp::Ordering;
29use std::collections::BinaryHeap;
30
/// Inter-cluster distance measure for agglomerative clustering.
///
/// Controls how the distance between two clusters is computed when
/// deciding which pair to merge next. [`Linkage::Ward`] is the default.
///
/// # Example
///
/// ```
/// use scry_learn::cluster::{AgglomerativeClustering, Linkage};
///
/// let model = AgglomerativeClustering::new(3).linkage(Linkage::Ward);
/// assert_eq!(Linkage::default(), Linkage::Ward);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Linkage {
    /// Minimum distance between any pair of points across two clusters.
    Single,
    /// Maximum distance between any pair of points across two clusters.
    Complete,
    /// Mean distance between all pairs of points across two clusters.
    Average,
    /// Merge the pair whose union least increases the total within-cluster
    /// sum of squares (Ward's minimum-variance method; the default).
    #[default]
    Ward,
}
57
/// One merge recorded in the dendrogram built by [`AgglomerativeClustering`].
///
/// Identifies the two clusters that were joined, the linkage distance at
/// which the join happened, and how many samples the resulting cluster holds.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MergeStep {
    /// Id of the first (lower-numbered) cluster in the merge. Ids below the
    /// number of samples refer to individual samples.
    pub cluster_a: usize,
    /// Id of the second (higher-numbered) cluster in the merge.
    pub cluster_b: usize,
    /// Linkage distance at which the two clusters were joined.
    pub distance: f64,
    /// Sample count of the cluster produced by this merge.
    pub size: usize,
}
74
75/// Agglomerative (hierarchical) clustering.
76///
77/// Starts with each sample as its own cluster and iteratively merges the
78/// closest pair until `n_clusters` remain. Supports four linkage criteria.
79///
80/// # Example
81///
82/// ```
83/// use scry_learn::cluster::AgglomerativeClustering;
84/// use scry_learn::dataset::Dataset;
85///
86/// let data = Dataset::new(
87///     vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
88///     vec![0.0; 4],
89///     vec!["x".into(), "y".into()],
90///     "label",
91/// );
92///
93/// let mut model = AgglomerativeClustering::new(2);
94/// model.fit(&data).unwrap();
95/// assert_eq!(model.labels().len(), 4);
96/// ```
97#[derive(Clone, Debug)]
98#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
99#[non_exhaustive]
100pub struct AgglomerativeClustering {
101    n_clusters: usize,
102    linkage: Linkage,
103    labels: Vec<usize>,
104    children: Vec<MergeStep>,
105    fitted: bool,
106    #[cfg_attr(feature = "serde", serde(default))]
107    _schema_version: u32,
108}
109
110impl AgglomerativeClustering {
111    /// Create a new agglomerative clustering model.
112    ///
113    /// # Arguments
114    ///
115    /// * `n_clusters` — target number of clusters.
116    pub fn new(n_clusters: usize) -> Self {
117        Self {
118            n_clusters,
119            linkage: Linkage::Ward,
120            labels: Vec::new(),
121            children: Vec::new(),
122            fitted: false,
123            _schema_version: crate::version::SCHEMA_VERSION,
124        }
125    }
126
127    /// Set the linkage criterion.
128    pub fn linkage(mut self, l: Linkage) -> Self {
129        self.linkage = l;
130        self
131    }
132
133    /// Fit the model on a dataset.
134    ///
135    /// Uses the features and ignores the target column. Computes the
136    /// full O(n²) pairwise distance matrix, then greedily merges the
137    /// closest cluster pair using a priority queue.
138    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
139        data.validate_finite()?;
140        let n = data.n_samples();
141        if n == 0 {
142            return Err(ScryLearnError::EmptyDataset);
143        }
144        if self.n_clusters == 0 || self.n_clusters > n {
145            return Err(ScryLearnError::InvalidParameter(format!(
146                "n_clusters must be between 1 and n_samples ({}), got {}",
147                n, self.n_clusters
148            )));
149        }
150
151        let rows = data.feature_matrix();
152        let n_features = data.n_features();
153
154        // Compute pairwise squared-Euclidean distance matrix (upper triangle).
155        let mut dist = vec![vec![0.0_f64; n]; n];
156        for i in 0..n {
157            for j in (i + 1)..n {
158                let d = euclidean_sq(&rows[i], &rows[j]);
159                dist[i][j] = d;
160                dist[j][i] = d;
161            }
162        }
163
164        // Track which cluster each original sample belongs to.
165        // cluster_id[i] = current cluster for original sample i.
166        let mut cluster_of = (0..n).collect::<Vec<usize>>();
167
168        // Members of each cluster (indexed by cluster id).
169        let mut members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
170
171        // Centroids for Ward linkage.
172        let mut centroids: Vec<Vec<f64>> = rows.clone();
173
174        // Priority queue: (neg_distance, cluster_a, cluster_b).
175        // We use a max-heap with negated distances to get a min-heap.
176        let mut heap: BinaryHeap<MergeCandidate> = BinaryHeap::new();
177
178        // Populate initial distances.
179        for i in 0..n {
180            for j in (i + 1)..n {
181                let d = self.linkage_distance(i, j, &dist, &members, &centroids, n_features);
182                heap.push(MergeCandidate {
183                    neg_dist: -d,
184                    a: i,
185                    b: j,
186                });
187            }
188        }
189
190        let mut active: Vec<bool> = vec![true; n];
191        let mut n_active = n;
192        let mut next_cluster_id = n; // new clusters get IDs >= n
193        let mut children = Vec::new();
194
195        while n_active > self.n_clusters {
196            // Pop the closest pair (skip stale entries).
197            let merge = loop {
198                let Some(candidate) = heap.pop() else {
199                    break None;
200                };
201                if active[candidate.a] && active[candidate.b] {
202                    break Some(candidate);
203                }
204            };
205
206            let Some(merge) = merge else { break };
207
208            let ca = merge.a;
209            let cb = merge.b;
210            let merge_dist = -merge.neg_dist;
211
212            // Create a new merged cluster.
213            let new_id = next_cluster_id;
214            next_cluster_id += 1;
215
216            // Merge members.
217            let mut new_members = std::mem::take(&mut members[ca]);
218            new_members.extend(std::mem::take(&mut members[cb]));
219            let new_size = new_members.len();
220
221            children.push(MergeStep {
222                cluster_a: ca,
223                cluster_b: cb,
224                distance: merge_dist.sqrt(),
225                size: new_size,
226            });
227
228            // Compute new centroid (for Ward).
229            let new_centroid = if matches!(self.linkage, Linkage::Ward) {
230                let mut c = vec![0.0; n_features];
231                for &idx in &new_members {
232                    for (j, &v) in rows[idx].iter().enumerate() {
233                        c[j] += v;
234                    }
235                }
236                for v in &mut c {
237                    *v /= new_size as f64;
238                }
239                c
240            } else {
241                Vec::new()
242            };
243
244            // Deactivate old clusters.
245            active[ca] = false;
246            active[cb] = false;
247
248            // Expand storage for the new cluster.
249            while active.len() <= new_id {
250                active.push(false);
251                members.push(Vec::new());
252                centroids.push(Vec::new());
253                // Expand dist matrix
254                for row in &mut dist {
255                    row.push(f64::INFINITY);
256                }
257                dist.push(vec![f64::INFINITY; dist[0].len()]);
258            }
259
260            active[new_id] = true;
261            members[new_id] = new_members;
262            centroids[new_id] = new_centroid;
263
264            // Compute distances from new cluster to all remaining active clusters.
265            for other in 0..active.len() {
266                if !active[other] || other == new_id {
267                    continue;
268                }
269                let d = self.compute_merged_distance(
270                    ca, cb, other, &dist, &members, &centroids, n_features, &rows,
271                );
272                dist[new_id][other] = d;
273                dist[other][new_id] = d;
274                heap.push(MergeCandidate {
275                    neg_dist: -d,
276                    a: new_id.min(other),
277                    b: new_id.max(other),
278                });
279            }
280
281            // Update cluster_of for merged members.
282            for &idx in &members[new_id] {
283                cluster_of[idx] = new_id;
284            }
285
286            n_active -= 1;
287        }
288
289        // Assign final labels 0..n_clusters-1.
290        let active_ids: Vec<usize> = active
291            .iter()
292            .enumerate()
293            .filter(|(_, &a)| a)
294            .map(|(i, _)| i)
295            .collect();
296
297        let mut labels = vec![0usize; n];
298        for (label, &cid) in active_ids.iter().enumerate() {
299            for &sample in &members[cid] {
300                labels[sample] = label;
301            }
302        }
303
304        self.labels = labels;
305        self.children = children;
306        self.fitted = true;
307        Ok(())
308    }
309
310    /// Compute linkage distance between two clusters.
311    fn linkage_distance(
312        &self,
313        a: usize,
314        b: usize,
315        dist: &[Vec<f64>],
316        members: &[Vec<usize>],
317        centroids: &[Vec<f64>],
318        _n_features: usize,
319    ) -> f64 {
320        match self.linkage {
321            Linkage::Single => {
322                let mut min_d = f64::INFINITY;
323                for &i in &members[a] {
324                    for &j in &members[b] {
325                        let d = dist[i][j];
326                        if d < min_d {
327                            min_d = d;
328                        }
329                    }
330                }
331                min_d
332            }
333            Linkage::Complete => {
334                let mut max_d = 0.0_f64;
335                for &i in &members[a] {
336                    for &j in &members[b] {
337                        let d = dist[i][j];
338                        if d > max_d {
339                            max_d = d;
340                        }
341                    }
342                }
343                max_d
344            }
345            Linkage::Average => {
346                let mut sum = 0.0;
347                let count = members[a].len() * members[b].len();
348                for &i in &members[a] {
349                    for &j in &members[b] {
350                        sum += dist[i][j];
351                    }
352                }
353                if count > 0 {
354                    sum / count as f64
355                } else {
356                    0.0
357                }
358            }
359            Linkage::Ward => {
360                // Ward distance = size_a * size_b / (size_a + size_b) * ||c_a - c_b||²
361                let sa = members[a].len() as f64;
362                let sb = members[b].len() as f64;
363                let d: f64 = centroids[a]
364                    .iter()
365                    .zip(centroids[b].iter())
366                    .map(|(ca, cb)| (ca - cb).powi(2))
367                    .sum();
368                sa * sb / (sa + sb) * d
369            }
370        }
371    }
372
373    /// Compute distance from a newly merged cluster (ca+cb) to another cluster.
374    #[allow(clippy::too_many_arguments)]
375    fn compute_merged_distance(
376        &self,
377        ca: usize,
378        cb: usize,
379        other: usize,
380        dist: &[Vec<f64>],
381        members: &[Vec<usize>],
382        _centroids: &[Vec<f64>],
383        _n_features: usize,
384        _rows: &[Vec<f64>],
385    ) -> f64 {
386        match self.linkage {
387            Linkage::Single => dist[ca][other].min(dist[cb][other]),
388            Linkage::Complete => dist[ca][other].max(dist[cb][other]),
389            Linkage::Average => {
390                let na = members[ca].len() as f64;
391                let nb = members[cb].len() as f64;
392                (na * dist[ca][other] + nb * dist[cb][other]) / (na + nb)
393            }
394            Linkage::Ward => {
395                // Lance-Williams formula for Ward's method
396                let na = members[ca].len() as f64;
397                let nb = members[cb].len() as f64;
398                let nc = members[other].len() as f64;
399                let total = na + nb + nc;
400                ((na + nc) * dist[ca][other] + (nb + nc) * dist[cb][other] - nc * dist[ca][cb])
401                    / total
402            }
403        }
404    }
405
406    /// Get cluster labels for training data.
407    pub fn labels(&self) -> &[usize] {
408        &self.labels
409    }
410
411    /// Number of clusters.
412    pub fn n_clusters(&self) -> usize {
413        self.n_clusters
414    }
415
416    /// Merge history (dendrogram data).
417    ///
418    /// Each entry records which two clusters were merged and at what distance.
419    pub fn children(&self) -> &[MergeStep] {
420        &self.children
421    }
422}
423
424/// Priority queue entry for cluster merging.
425#[derive(Clone, Copy)]
426struct MergeCandidate {
427    neg_dist: f64, // negated so BinaryHeap (max-heap) gives us the minimum
428    a: usize,
429    b: usize,
430}
431
432impl PartialEq for MergeCandidate {
433    fn eq(&self, other: &Self) -> bool {
434        self.neg_dist == other.neg_dist
435    }
436}
437
438impl Eq for MergeCandidate {}
439
440impl PartialOrd for MergeCandidate {
441    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
442        Some(self.cmp(other))
443    }
444}
445
446impl Ord for MergeCandidate {
447    fn cmp(&self, other: &Self) -> Ordering {
448        self.neg_dist
449            .partial_cmp(&other.neg_dist)
450            .unwrap_or(Ordering::Equal)
451    }
452}
453
454#[cfg(test)]
455mod tests {
456    use super::*;
457
458    #[test]
459    fn test_agglomerative_three_clusters() {
460        // Three well-separated clusters.
461        let mut rng = crate::rng::FastRng::new(0);
462        let mut f1 = Vec::new();
463        let mut f2 = Vec::new();
464        for _ in 0..10 {
465            f1.push(rng.f64() * 2.0);
466            f2.push(rng.f64() * 2.0);
467        }
468        for _ in 0..10 {
469            f1.push(50.0 + rng.f64() * 2.0);
470            f2.push(50.0 + rng.f64() * 2.0);
471        }
472        for _ in 0..10 {
473            f1.push(100.0 + rng.f64() * 2.0);
474            f2.push(100.0 + rng.f64() * 2.0);
475        }
476
477        let data = Dataset::new(
478            vec![f1, f2],
479            vec![0.0; 30],
480            vec!["x".into(), "y".into()],
481            "label",
482        );
483
484        let mut model = AgglomerativeClustering::new(3);
485        model.fit(&data).unwrap();
486
487        let labels = model.labels();
488        assert_eq!(labels.len(), 30);
489
490        // All points in the same group should have the same label.
491        let label_a = labels[0];
492        assert!(
493            labels[..10].iter().all(|&l| l == label_a),
494            "Cluster A inconsistent"
495        );
496
497        let label_b = labels[10];
498        assert!(
499            labels[10..20].iter().all(|&l| l == label_b),
500            "Cluster B inconsistent"
501        );
502
503        let label_c = labels[20];
504        assert!(
505            labels[20..].iter().all(|&l| l == label_c),
506            "Cluster C inconsistent"
507        );
508
509        // All three labels should be distinct.
510        assert_ne!(label_a, label_b);
511        assert_ne!(label_a, label_c);
512        assert_ne!(label_b, label_c);
513    }
514
515    #[test]
516    fn test_agglomerative_linkage_variants() {
517        let data = Dataset::new(
518            vec![vec![0.0, 1.0, 5.0, 6.0], vec![0.0, 0.0, 0.0, 0.0]],
519            vec![0.0; 4],
520            vec!["x".into(), "y".into()],
521            "label",
522        );
523
524        for linkage in [
525            Linkage::Single,
526            Linkage::Complete,
527            Linkage::Average,
528            Linkage::Ward,
529        ] {
530            let mut model = AgglomerativeClustering::new(2).linkage(linkage);
531            model.fit(&data).unwrap();
532            assert_eq!(model.labels().len(), 4, "Failed for {linkage:?}");
533        }
534    }
535
536    #[test]
537    fn test_agglomerative_ward_vs_single() {
538        // Ward and Single should produce different merge histories
539        // on data where they disagree.
540        let data = Dataset::new(
541            vec![
542                vec![0.0, 1.0, 3.0, 10.0, 11.0, 13.0],
543                vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
544            ],
545            vec![0.0; 6],
546            vec!["x".into(), "y".into()],
547            "label",
548        );
549
550        let mut ward = AgglomerativeClustering::new(2).linkage(Linkage::Ward);
551        ward.fit(&data).unwrap();
552
553        let mut single = AgglomerativeClustering::new(2).linkage(Linkage::Single);
554        single.fit(&data).unwrap();
555
556        // Both should output valid labels of length 6.
557        assert_eq!(ward.labels().len(), 6);
558        assert_eq!(single.labels().len(), 6);
559
560        // The merge histories should have the right number of steps.
561        assert_eq!(ward.children().len(), 4); // 6 - 2 merges
562        assert_eq!(single.children().len(), 4);
563    }
564
565    #[test]
566    fn test_agglomerative_single_cluster() {
567        let data = Dataset::new(
568            vec![vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 2.0]],
569            vec![0.0; 3],
570            vec!["x".into(), "y".into()],
571            "label",
572        );
573
574        let mut model = AgglomerativeClustering::new(1);
575        model.fit(&data).unwrap();
576        assert!(
577            model.labels().iter().all(|&l| l == 0),
578            "All should be cluster 0"
579        );
580    }
581}