use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Fit;
use ndarray::{Array1, Array2};
use num_traits::Float;
/// Linkage criterion: how the distance from a freshly merged cluster to every
/// other cluster is derived from the distances of its two constituent parts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Linkage {
    /// Ward's minimum-variance criterion (Lance–Williams update). This is the
    /// default chosen by [`AgglomerativeClustering::new`].
    Ward,
    /// Complete linkage: the larger of the two parts' distances, i.e. the
    /// maximum distance between members of the two clusters.
    Complete,
    /// Average linkage: size-weighted mean of the two parts' distances.
    Average,
    /// Single linkage: the smaller of the two parts' distances, i.e. the
    /// minimum distance between members of the two clusters.
    Single,
}
/// Unfitted bottom-up (agglomerative) hierarchical clustering estimator.
///
/// Construct with [`AgglomerativeClustering::new`], optionally pick a criterion
/// with [`AgglomerativeClustering::with_linkage`], then call `fit`.
#[derive(Debug, Clone)]
pub struct AgglomerativeClustering<F> {
    /// Number of clusters at which merging stops. `fit` rejects `0` and any
    /// value larger than the number of samples.
    pub n_clusters: usize,
    /// Criterion used to update inter-cluster distances after each merge.
    pub linkage: Linkage,
    // Binds the input float type `F` to the estimator without storing an `F`.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float> AgglomerativeClustering<F> {
    /// Creates an estimator that merges until `n_clusters` clusters remain,
    /// using [`Linkage::Ward`] unless changed with [`Self::with_linkage`].
    ///
    /// `n_clusters` is not validated here; `fit` performs the checks.
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        let linkage = Linkage::Ward;
        Self {
            n_clusters,
            linkage,
            _marker: std::marker::PhantomData,
        }
    }

    /// Consumes the estimator and returns it with `linkage` as its criterion.
    #[must_use]
    pub fn with_linkage(self, linkage: Linkage) -> Self {
        Self { linkage, ..self }
    }
}
/// Result of fitting an [`AgglomerativeClustering`] estimator.
#[derive(Debug, Clone)]
pub struct FittedAgglomerativeClustering<F> {
    /// Cluster label of every input row, in `0..n_clusters_`.
    ///
    /// Clusters are numbered in ascending order of their smallest member
    /// index, so the cluster containing row 0 is always label 0.
    pub labels_: Array1<usize>,
    /// Number of clusters requested when fitting.
    pub n_clusters_: usize,
    /// Merge history in the order merges happened, one `(kept, absorbed)` pair
    /// per merge. Each cluster is identified by its smallest member (row)
    /// index at the time of the merge, and `kept < absorbed`.
    ///
    /// There are `n_samples - n_clusters_` entries. NOTE: unlike
    /// scikit-learn's `children_`, merged clusters are not given new node ids.
    pub children_: Vec<(usize, usize)>,
    // Binds the input float type `F` without storing an `F`.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float> FittedAgglomerativeClustering<F> {
    /// Borrows the per-row cluster labels (see the `labels_` field).
    #[must_use]
    pub fn labels(&self) -> &Array1<usize> {
        &self.labels_
    }

    /// Returns how many clusters the model was fitted to produce.
    #[must_use]
    pub fn n_clusters(&self) -> usize {
        let count = self.n_clusters_;
        count
    }

    /// Borrows the merge history as a slice of `(kept, absorbed)` pairs.
    #[must_use]
    pub fn children(&self) -> &[(usize, usize)] {
        self.children_.as_slice()
    }
}
/// Squared Euclidean distance between two coordinate slices.
///
/// Elements are paired with `zip`, so when the lengths differ the trailing
/// elements of the longer slice are ignored.
#[inline]
fn sq_euclidean<F: Float>(a: &[F], b: &[F]) -> F {
    let mut total = F::zero();
    for (&p, &q) in a.iter().zip(b) {
        let diff = p - q;
        total = total + diff * diff;
    }
    total
}
/// Squared Euclidean distances between every pair of rows of `x`.
///
/// Returns a flattened, row-major, symmetric `n x n` matrix (`n = x.nrows()`):
/// entry `i * n + j` holds the squared distance between rows `i` and `j`, and
/// the diagonal is zero.
///
/// Any memory layout of `x` is accepted (row-major, column-major, strided).
fn pairwise_sq_dists<F: Float>(x: &Array2<F>) -> Vec<F> {
    let n = x.nrows();
    // Rows of a column-major or strided array are not contiguous, so
    // `as_slice()` would return `None`. Silently falling back to an empty slice
    // turns every distance into zero. Normalise to row-major first; this only
    // borrows (no copy) when `x` already is row-major.
    let x = x.as_standard_layout();
    let mut d = vec![F::zero(); n * n];
    for i in 0..n {
        let ri = x.row(i);
        let si = ri
            .as_slice()
            .expect("rows of a standard-layout array are contiguous");
        for j in (i + 1)..n {
            let rj = x.row(j);
            let sj = rj
                .as_slice()
                .expect("rows of a standard-layout array are contiguous");
            let dist = sq_euclidean(si, sj);
            d[i * n + j] = dist;
            d[j * n + i] = dist;
        }
    }
    d
}
/// Finds the closest pair of active cluster ids in a flattened square
/// distance matrix.
///
/// `dist_mat` is assumed to be `side * side` long (row-major). Only entries
/// `(i, j)` where `i` appears before `j` in `active` are examined. The first
/// strictly smallest value wins, so ties go to the earliest pair in that
/// order.
///
/// If no entry is below `f64::INFINITY` (e.g. all NaN or infinite), the result
/// is `(active[0], active[1])`.
///
/// # Panics
/// Panics if `active` has fewer than two elements.
fn find_min_pair(dist_mat: &[f64], active: &[usize]) -> (usize, usize) {
    // (best distance so far, row id, column id)
    let mut best = (f64::INFINITY, active[0], active[1]);
    let side = (dist_mat.len() as f64).sqrt() as usize;
    for (pos, &row) in active.iter().enumerate() {
        let row_base = row * side;
        for &col in &active[pos + 1..] {
            let candidate = dist_mat[row_base + col];
            if candidate < best.0 {
                best = (candidate, row, col);
            }
        }
    }
    (best.1, best.2)
}
/// `(labels, children)` produced by [`agglomerate`].
type AgglomerateResult = Result<(Array1<usize>, Vec<(usize, usize)>), FerroError>;

/// Naive bottom-up agglomerative clustering that stops at `n_clusters_target`
/// clusters.
///
/// Each cluster is identified by its smallest member index (its
/// "representative"). Every iteration proceeds in three steps:
/// 1. Merge the closest pair of active clusters.
/// 2. Keep the smaller representative `ci` and retire `cj`.
/// 3. Update the distances from the merged cluster to all other active
///    clusters with the Lance–Williams formula for `linkage`.
///
/// Ward's update is only valid on squared Euclidean distances, so Ward works
/// on those. Single, complete and average linkage work on plain Euclidean
/// distances, as their usual definitions require.
///
/// Returns `(labels, children)`:
/// - `labels[s]` is in `0..n_clusters_target`; clusters are numbered by
///   ascending representative.
/// - `children` lists the merged `(ci, cj)` pairs in merge order.
///
/// Callers are expected to pass `1 <= n_clusters_target <= x.nrows()`. The
/// function currently never returns `Err`.
///
/// Cost: O(n^2) memory and O(n^3) time.
fn agglomerate<F: Float>(
    x: &Array2<F>,
    n_clusters_target: usize,
    linkage: Linkage,
) -> AgglomerateResult {
    let n = x.nrows();
    let x_f64: Array2<f64> = x.mapv(|v| v.to_f64().unwrap_or(0.0));

    // Ward's Lance–Williams update needs squared Euclidean distances. For
    // single/complete linkage the merge order is the same either way (sqrt is
    // monotone), but average linkage must average distances, not squared
    // distances.
    let mut dists = pairwise_sq_dists(&x_f64);
    if !matches!(linkage, Linkage::Ward) {
        for d in &mut dists {
            *d = d.sqrt();
        }
    }

    let mut sizes: Vec<f64> = vec![1.0; n];
    // Representatives of the clusters still alive, kept in ascending order.
    let mut active: Vec<usize> = (0..n).collect();
    let mut children: Vec<(usize, usize)> =
        Vec::with_capacity(n.saturating_sub(n_clusters_target));
    // assignment[s] = representative of the cluster currently holding sample s.
    let mut assignment: Vec<usize> = (0..n).collect();

    // `max(1)` guarantees at least two clusters whenever `find_min_pair` runs,
    // even if a caller passes a target of 0.
    while active.len() > n_clusters_target.max(1) {
        // `ci` precedes `cj` in the sorted `active` list, so `ci < cj` and the
        // merged cluster keeps the smaller index as its representative.
        let (ci, cj) = find_min_pair(&dists, &active);
        active.retain(|&id| id != cj);
        children.push((ci, cj));

        let ni = sizes[ci];
        let nj = sizes[cj];
        // Loop-invariant: the loop below never writes entry (ci, cj) because
        // `ck` is never `ci` or `cj`.
        let d_ij = dists[ci * n + cj];
        for &ck in active.iter().filter(|&&ck| ck != ci) {
            let nk = sizes[ck];
            let d_ik = dists[ci * n + ck];
            let d_jk = dists[cj * n + ck];
            let new_dist = match linkage {
                Linkage::Single => d_ik.min(d_jk),
                Linkage::Complete => d_ik.max(d_jk),
                Linkage::Average => (ni * d_ik + nj * d_jk) / (ni + nj),
                Linkage::Ward => {
                    let denom = ni + nj + nk;
                    ((ni + nk) / denom) * d_ik + ((nj + nk) / denom) * d_jk
                        - (nk / denom) * d_ij
                }
            };
            dists[ci * n + ck] = new_dist;
            dists[ck * n + ci] = new_dist;
        }

        sizes[ci] = ni + nj;
        for rep in assignment.iter_mut().filter(|rep| **rep == cj) {
            *rep = ci;
        }
    }

    // Number surviving clusters 0.. in ascending representative order.
    let mut label_of = vec![0usize; n];
    for (label, &rep) in active.iter().enumerate() {
        label_of[rep] = label;
    }
    let labels: Array1<usize> = assignment.iter().map(|&rep| label_of[rep]).collect();
    Ok((labels, children))
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for AgglomerativeClustering<F> {
    type Fitted = FittedAgglomerativeClustering<F>;
    type Error = FerroError;

    /// Clusters the rows of `x` into `self.n_clusters` groups using
    /// `self.linkage`. The target `_y` is ignored.
    ///
    /// # Errors
    /// - [`FerroError::InvalidParameter`] when `n_clusters` is 0.
    /// - [`FerroError::InsufficientSamples`] when `x` has fewer rows than
    ///   `n_clusters`, including when `x` is empty.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedAgglomerativeClustering<F>, FerroError> {
        let requested = self.n_clusters;
        if requested == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_clusters".into(),
                reason: "must be at least 1".into(),
            });
        }

        // `requested >= 1` at this point, so this also rejects an empty `x`.
        let available = x.nrows();
        if available < requested {
            return Err(FerroError::InsufficientSamples {
                required: requested,
                actual: available,
                context: "AgglomerativeClustering requires at least n_clusters samples".into(),
            });
        }

        let (labels_, children_) = agglomerate(x, requested, self.linkage)?;
        Ok(FittedAgglomerativeClustering {
            labels_,
            n_clusters_: requested,
            children_,
            _marker: std::marker::PhantomData,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two well-separated blobs of four points each:
    // rows 0..4 near (0, 0), rows 4..8 near (10, 10).
    fn make_two_blobs() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 0.05, 0.05, 10.0, 10.0, 10.1, 10.0, 10.0, 10.1,
                10.05, 10.05,
            ],
        )
        .unwrap()
    }

    // Three well-separated blobs of three points each:
    // rows 0..3 near (0, 0), rows 3..6 near (10, 10), rows 6..9 near (0, 10).
    fn make_three_blobs() -> Array2<f64> {
        Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1, 0.0, 10.0, 0.1,
                10.1, -0.1, 9.9,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_new_defaults() {
        let model = AgglomerativeClustering::<f64>::new(3);
        assert_eq!(model.n_clusters, 3);
        assert_eq!(model.linkage, Linkage::Ward);
    }

    #[test]
    fn test_with_linkage() {
        let model = AgglomerativeClustering::<f64>::new(2).with_linkage(Linkage::Complete);
        assert_eq!(model.linkage, Linkage::Complete);
    }

    #[test]
    fn test_zero_clusters_error() {
        let x = make_two_blobs();
        let result = AgglomerativeClustering::<f64>::new(0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let result = AgglomerativeClustering::<f64>::new(2).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_more_clusters_than_samples_error() {
        let x = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
        let result = AgglomerativeClustering::<f64>::new(5).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_ward_two_blobs() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2)
            .with_linkage(Linkage::Ward)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[0], labels[2]);
        assert_eq!(labels[0], labels[3]);
        assert_eq!(labels[4], labels[5]);
        assert_eq!(labels[4], labels[6]);
        assert_eq!(labels[4], labels[7]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_ward_three_blobs() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3)
            .with_linkage(Linkage::Ward)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[6], labels[7]);
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[0], labels[6]);
        assert_ne!(labels[3], labels[6]);
    }

    #[test]
    fn test_complete_two_blobs() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2)
            .with_linkage(Linkage::Complete)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_complete_three_blobs() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3)
            .with_linkage(Linkage::Complete)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[0], labels[6]);
    }

    #[test]
    fn test_average_two_blobs() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2)
            .with_linkage(Linkage::Average)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_average_three_blobs() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3)
            .with_linkage(Linkage::Average)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[0], labels[6]);
    }

    #[test]
    fn test_single_two_blobs() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2)
            .with_linkage(Linkage::Single)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_single_three_blobs() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3)
            .with_linkage(Linkage::Single)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[0], labels[6]);
    }

    #[test]
    fn test_label_count_equals_n_samples() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), x.nrows());
    }

    #[test]
    fn test_labels_in_valid_range() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3).fit(&x, &()).unwrap();
        for &l in fitted.labels().iter() {
            assert!(l < 3, "label {l} out of range");
        }
    }

    #[test]
    fn test_n_clusters_matches_config() {
        let x = make_three_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(3).fit(&x, &()).unwrap();
        assert_eq!(fitted.n_clusters(), 3);
    }

    #[test]
    fn test_children_length() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(2).fit(&x, &()).unwrap();
        // One merge per sample beyond the requested cluster count.
        assert_eq!(fitted.children().len(), x.nrows() - 2);
    }

    #[test]
    fn test_children_empty_when_n_clusters_equals_n_samples() {
        let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0]).unwrap();
        let fitted = AgglomerativeClustering::<f64>::new(3).fit(&x, &()).unwrap();
        assert!(fitted.children().is_empty());
    }

    #[test]
    fn test_single_cluster() {
        let x = make_two_blobs();
        let fitted = AgglomerativeClustering::<f64>::new(1).fit(&x, &()).unwrap();
        for &l in fitted.labels().iter() {
            assert_eq!(l, 0);
        }
    }

    #[test]
    fn test_n_clusters_equals_n_samples() {
        let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0]).unwrap();
        let fitted = AgglomerativeClustering::<f64>::new(3).fit(&x, &()).unwrap();
        let labels = fitted.labels();
        assert_ne!(labels[0], labels[1]);
        assert_ne!(labels[0], labels[2]);
        assert_ne!(labels[1], labels[2]);
    }

    #[test]
    fn test_single_sample_single_cluster() {
        let x = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).unwrap();
        let fitted = AgglomerativeClustering::<f64>::new(1).fit(&x, &()).unwrap();
        assert_eq!(fitted.labels()[0], 0);
        assert_eq!(fitted.n_clusters(), 1);
        assert!(fitted.children().is_empty());
    }

    #[test]
    fn test_1d_data() {
        let x = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, -0.1, 100.0, 100.1, 99.9]).unwrap();
        let fitted = AgglomerativeClustering::<f64>::new(2).fit(&x, &()).unwrap();
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[0], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_f32_support() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 10.0, 10.0, 10.1, 10.0, 10.0, 10.1,
            ],
        )
        .unwrap();
        let fitted = AgglomerativeClustering::<f32>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), 6);
        let labels = fitted.labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[3], labels[4]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_identical_points() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
        let fitted = AgglomerativeClustering::<f64>::new(1).fit(&x, &()).unwrap();
        for &l in fitted.labels().iter() {
            assert_eq!(l, 0);
        }
    }
}