use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use crate::agglomerative::{AgglomerativeClustering, Linkage};
/// Linkage criterion used when merging clusters of features.
///
/// Each variant is forwarded one-to-one to the [`Linkage`] variant of the same
/// name on the underlying [`AgglomerativeClustering`] estimator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AgglomerativeLinkage {
    /// Ward linkage; the default chosen by `FeatureAgglomeration::new`.
    Ward,
    /// Complete linkage.
    Complete,
    /// Average linkage.
    Average,
    /// Single linkage.
    Single,
}

impl Default for AgglomerativeLinkage {
    /// Returns [`AgglomerativeLinkage::Ward`], the same linkage that
    /// `FeatureAgglomeration::new` starts with.
    fn default() -> Self {
        AgglomerativeLinkage::Ward
    }
}
/// How the features that share a cluster are combined into one output column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingFunc {
    /// Arithmetic mean of the cluster's features (the default chosen by
    /// `FeatureAgglomeration::new`).
    Mean,
    /// Maximum of the cluster's features.
    Max,
}

impl Default for PoolingFunc {
    /// Returns [`PoolingFunc::Mean`], the same pooling that
    /// `FeatureAgglomeration::new` starts with.
    fn default() -> Self {
        PoolingFunc::Mean
    }
}
/// Unfitted feature-agglomeration estimator.
///
/// Clusters the *columns* (features) of a data matrix with agglomerative
/// clustering and, once fitted, pools each cluster of columns into a single
/// output column. Build it with `new` and the `with_*` builders, then call
/// `fit`.
#[derive(Debug, Clone)]
pub struct FeatureAgglomeration<F> {
    // Number of feature clusters (output columns) requested.
    n_clusters: usize,
    // Linkage criterion forwarded to the underlying clustering.
    linkage: AgglomerativeLinkage,
    // How features in the same cluster are combined at transform time.
    pooling_func: PoolingFunc,
    // Ties the estimator to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> FeatureAgglomeration<F> {
    /// Create an estimator that groups features into `n_clusters` clusters.
    ///
    /// Starts with [`AgglomerativeLinkage::Ward`] linkage and
    /// [`PoolingFunc::Mean`] pooling. `n_clusters` is not validated here;
    /// `fit` rejects 0 and values larger than the number of features.
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        let linkage = AgglomerativeLinkage::Ward;
        let pooling_func = PoolingFunc::Mean;
        Self {
            n_clusters,
            linkage,
            pooling_func,
            _marker: std::marker::PhantomData,
        }
    }

    /// Return the estimator with `linkage` as its linkage criterion; every
    /// other setting is kept.
    #[must_use]
    pub fn with_linkage(self, linkage: AgglomerativeLinkage) -> Self {
        Self { linkage, ..self }
    }

    /// Return the estimator with `pooling` as its pooling function; every
    /// other setting is kept.
    #[must_use]
    pub fn with_pooling_func(self, pooling: PoolingFunc) -> Self {
        Self {
            pooling_func: pooling,
            ..self
        }
    }

    /// Requested number of feature clusters.
    #[must_use]
    pub fn n_clusters(&self) -> usize {
        let Self { n_clusters, .. } = *self;
        n_clusters
    }

    /// Configured linkage criterion.
    #[must_use]
    pub fn linkage(&self) -> AgglomerativeLinkage {
        let Self { linkage, .. } = *self;
        linkage
    }

    /// Configured pooling function.
    #[must_use]
    pub fn pooling_func(&self) -> PoolingFunc {
        let Self { pooling_func, .. } = *self;
        pooling_func
    }
}
/// Result of fitting a `FeatureAgglomeration`: the learned feature-to-cluster
/// assignment plus the settings `transform` needs.
#[derive(Debug, Clone)]
pub struct FittedFeatureAgglomeration<F> {
    // Cluster label of each training feature; `transform` indexes it by input
    // column, so it is expected to hold one entry per feature seen in `fit`.
    feature_labels_: Array1<usize>,
    // Number of clusters, i.e. output columns produced by `transform`.
    n_clusters_: usize,
    // Number of columns of the data passed to `fit`.
    n_features_: usize,
    // Pooling applied per cluster by `transform`.
    pooling_func_: PoolingFunc,
    // Ties the fitted model to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> FittedFeatureAgglomeration<F> {
#[must_use]
pub fn feature_labels(&self) -> &Array1<usize> {
&self.feature_labels_
}
#[must_use]
pub fn n_clusters(&self) -> usize {
self.n_clusters_
}
#[must_use]
pub fn n_features(&self) -> usize {
self.n_features_
}
}
fn map_linkage(l: AgglomerativeLinkage) -> Linkage {
match l {
AgglomerativeLinkage::Ward => Linkage::Ward,
AgglomerativeLinkage::Complete => Linkage::Complete,
AgglomerativeLinkage::Average => Linkage::Average,
AgglomerativeLinkage::Single => Linkage::Single,
}
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FeatureAgglomeration<F> {
    type Fitted = FittedFeatureAgglomeration<F>;
    type Error = FerroError;

    /// Learn a clustering of the features (columns) of `x`.
    ///
    /// `x` is transposed so that every feature becomes one observation. Those
    /// observations are then clustered by [`AgglomerativeClustering`] using this
    /// estimator's `n_clusters` and linkage. `_y` is ignored.
    ///
    /// # Errors
    ///
    /// * [`FerroError::InvalidParameter`] if `n_clusters` is 0 or larger than
    ///   the number of columns of `x`.
    /// * [`FerroError::InsufficientSamples`] if `x` has no rows.
    /// * Any error from the underlying clustering is propagated unchanged.
    fn fit(
        &self,
        x: &Array2<F>,
        _y: &(),
    ) -> Result<FittedFeatureAgglomeration<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        let wanted = self.n_clusters;

        if wanted == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_clusters".into(),
                reason: "must be at least 1".into(),
            });
        }
        if wanted > n_features {
            return Err(FerroError::InvalidParameter {
                name: "n_clusters".into(),
                reason: format!(
                    "n_clusters ({}) exceeds n_features ({})",
                    wanted, n_features
                ),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "FeatureAgglomeration::fit requires at least 1 sample".into(),
            });
        }

        // One row per feature, copied into standard (row-major) layout for the
        // clustering routine.
        let features_as_rows = x.t().as_standard_layout().into_owned();
        let estimator =
            AgglomerativeClustering::<F>::new(wanted).with_linkage(map_linkage(self.linkage));
        let clustered = estimator.fit(&features_as_rows, &())?;

        Ok(FittedFeatureAgglomeration {
            feature_labels_: clustered.labels_,
            n_clusters_: wanted,
            n_features_: n_features,
            pooling_func_: self.pooling_func,
            _marker: std::marker::PhantomData,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFeatureAgglomeration<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Pool the columns of `x` cluster by cluster, producing one output column
    /// per feature cluster.
    ///
    /// Input column `j` contributes to output column `feature_labels[j]`.
    /// With [`PoolingFunc::Mean`] each output is the arithmetic mean of its
    /// cluster's features. With [`PoolingFunc::Max`] it is their maximum. In both
    /// modes, a NaN among a cluster's features makes that row's pooled value NaN.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` does not have the same
    /// number of columns as the data passed to `fit`.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features_ {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), self.n_features_],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedFeatureAgglomeration::transform".into(),
            });
        }
        let labels = &self.feature_labels_;
        let mut result = Array2::<F>::zeros((x.nrows(), self.n_clusters_));
        match self.pooling_func_ {
            PoolingFunc::Mean => {
                // Cluster sizes depend only on the fitted labels, so convert them
                // to divisors once rather than once per row. A cluster with no
                // features (None) keeps its zero sum.
                let mut counts = vec![0usize; self.n_clusters_];
                for &label in labels.iter() {
                    counts[label] += 1;
                }
                let divisors: Vec<Option<F>> = counts
                    .iter()
                    .map(|&count| {
                        (count > 0).then(|| {
                            F::from(count)
                                .expect("a usize count is always representable as a float")
                        })
                    })
                    .collect();
                // Row-wise traversal matches the storage order of standard-layout
                // inputs. Within a row, features are summed in column order.
                for (x_row, mut out_row) in x.outer_iter().zip(result.outer_iter_mut()) {
                    for (&value, &label) in x_row.iter().zip(labels.iter()) {
                        out_row[label] = out_row[label] + value;
                    }
                    for (slot, divisor) in out_row.iter_mut().zip(divisors.iter()) {
                        if let Some(d) = *divisor {
                            *slot = *slot / d;
                        }
                    }
                }
            }
            PoolingFunc::Max => {
                result.fill(F::neg_infinity());
                for (x_row, mut out_row) in x.outer_iter().zip(result.outer_iter_mut()) {
                    for (&value, &label) in x_row.iter().zip(labels.iter()) {
                        let current = out_row[label];
                        // `value > current` is false whenever either side is NaN.
                        // Without the explicit check, NaN inputs would be silently
                        // skipped, and an all-NaN cluster would report -inf.
                        // Once stored, NaN stays, because nothing compares greater
                        // than it. This matches mean pooling, which propagates NaN.
                        if value.is_nan() || value > current {
                            out_row[label] = value;
                        }
                    }
                }
            }
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 5 samples x 6 features, where columns (0,1), (2,3) and (4,5) are
    /// near-duplicates of each other.
    fn make_correlated_features() -> Array2<f64> {
        Array2::from_shape_vec(
            (5, 6),
            vec![
                1.0, 1.1, 5.0, 5.1, 9.0, 9.1, 2.0, 2.1, 6.0, 6.1, 8.0, 8.1, 3.0, 3.1, 7.0,
                7.1, 7.0, 7.1, 4.0, 4.1, 8.0, 8.1, 6.0, 6.1, 5.0, 5.1, 9.0, 9.1, 5.0, 5.1,
            ],
        )
        .unwrap()
    }

    /// 5 samples x 4 features: columns 0/1 and 2/3 are almost identical pairs.
    fn make_paired_features() -> Array2<f64> {
        Array2::from_shape_vec(
            (5, 4),
            vec![
                1.0, 1.001, 100.0, 100.001, //
                2.0, 2.001, 90.0, 90.001, //
                3.0, 3.001, 80.0, 80.001, //
                4.0, 4.001, 70.0, 70.001, //
                5.0, 5.001, 60.0, 60.001,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_feature_agglom_basic() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.dim(), (5, 3));
    }

    #[test]
    fn test_feature_agglom_output_shape() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(2);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 2);
        assert_eq!(reduced.nrows(), 5);
    }

    #[test]
    fn test_feature_agglom_labels_valid_range() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3);
        let fitted = fa.fit(&x, &()).unwrap();
        for &l in fitted.feature_labels().iter() {
            assert!(l < 3, "label {l} out of range");
        }
    }

    #[test]
    fn test_feature_agglom_correlated_grouped() {
        let x = make_paired_features();
        let fa = FeatureAgglomeration::<f64>::new(2).with_linkage(AgglomerativeLinkage::Single);
        let fitted = fa.fit(&x, &()).unwrap();
        let labels = fitted.feature_labels();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_feature_agglom_mean_pooling() {
        let x = Array2::from_shape_vec((3, 2), vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]).unwrap();
        let fa = FeatureAgglomeration::<f64>::new(1);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(reduced[[1, 0]], 7.0, epsilon = 1e-10);
        assert_abs_diff_eq!(reduced[[2, 0]], 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_agglom_mean_pooling_per_cluster() {
        // Each output column must be the mean of exactly the features in its cluster.
        let x = make_paired_features();
        let fa = FeatureAgglomeration::<f64>::new(2).with_linkage(AgglomerativeLinkage::Single);
        let fitted = fa.fit(&x, &()).unwrap();
        let labels = fitted.feature_labels().clone();
        let reduced = fitted.transform(&x).unwrap();
        for i in 0..x.nrows() {
            assert_abs_diff_eq!(
                reduced[[i, labels[0]]],
                (x[[i, 0]] + x[[i, 1]]) / 2.0,
                epsilon = 1e-10
            );
            assert_abs_diff_eq!(
                reduced[[i, labels[2]]],
                (x[[i, 2]] + x[[i, 3]]) / 2.0,
                epsilon = 1e-10
            );
        }
    }

    #[test]
    fn test_feature_agglom_max_pooling() {
        let x = Array2::from_shape_vec((3, 2), vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]).unwrap();
        let fa = FeatureAgglomeration::<f64>::new(1).with_pooling_func(PoolingFunc::Max);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(reduced[[1, 0]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(reduced[[2, 0]], 12.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_agglom_max_pooling_negative_values() {
        // The -inf seed used for max pooling must never leak into the output.
        let x = Array2::from_shape_vec((2, 2), vec![-3.0, -1.0, -5.0, -2.0]).unwrap();
        let fa = FeatureAgglomeration::<f64>::new(1).with_pooling_func(PoolingFunc::Max);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(reduced[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(reduced[[1, 0]], -2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_feature_agglom_complete_linkage() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3).with_linkage(AgglomerativeLinkage::Complete);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 3);
    }

    #[test]
    fn test_feature_agglom_average_linkage() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3).with_linkage(AgglomerativeLinkage::Average);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 3);
    }

    #[test]
    fn test_feature_agglom_single_linkage() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3).with_linkage(AgglomerativeLinkage::Single);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 3);
    }

    #[test]
    fn test_feature_agglom_n_clusters_equals_n_features() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(6);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 6);
    }

    #[test]
    fn test_feature_agglom_zero_clusters_error() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(0);
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_feature_agglom_too_many_clusters_error() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(10);
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_feature_agglom_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 4));
        let fa = FeatureAgglomeration::<f64>::new(2);
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_feature_agglom_transform_shape_mismatch() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3);
        let fitted = fa.fit(&x, &()).unwrap();
        let x_bad =
            Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_feature_agglom_transform_unseen_samples() {
        // Only the column count is fixed by `fit`; any number of rows may be transformed.
        let x = make_correlated_features();
        let fitted = FeatureAgglomeration::<f64>::new(3).fit(&x, &()).unwrap();
        let x_new = Array2::from_shape_vec(
            (2, 6),
            vec![0.5, 0.6, 1.5, 1.6, 2.5, 2.6, 3.5, 3.6, 4.5, 4.6, 5.5, 5.6],
        )
        .unwrap();
        let reduced = fitted.transform(&x_new).unwrap();
        assert_eq!(reduced.dim(), (2, 3));
    }

    #[test]
    fn test_feature_agglom_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (4, 4),
            vec![
                1.0, 1.1, 5.0, 5.1, 2.0, 2.1, 6.0, 6.1, 3.0, 3.1, 7.0, 7.1, 4.0, 4.1, 8.0,
                8.1,
            ],
        )
        .unwrap();
        let fa = FeatureAgglomeration::<f32>::new(2);
        let fitted = fa.fit(&x, &()).unwrap();
        let reduced = fitted.transform(&x).unwrap();
        assert_eq!(reduced.ncols(), 2);
    }

    #[test]
    fn test_feature_agglom_getters() {
        let fa = FeatureAgglomeration::<f64>::new(3)
            .with_linkage(AgglomerativeLinkage::Complete)
            .with_pooling_func(PoolingFunc::Max);
        assert_eq!(fa.n_clusters(), 3);
        assert_eq!(fa.linkage(), AgglomerativeLinkage::Complete);
        assert_eq!(fa.pooling_func(), PoolingFunc::Max);
    }

    #[test]
    fn test_feature_agglom_n_features_getter() {
        let x = make_correlated_features();
        let fa = FeatureAgglomeration::<f64>::new(3);
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features(), 6);
        assert_eq!(fitted.n_clusters(), 3);
    }
}