// rustframe 0.0.1-a.20250805
//
// A simple dataframe and math toolkit
// Documentation
//! Simple k-means clustering working on [`Matrix`] data.
//!
//! ```
//! use rustframe::compute::models::k_means::KMeans;
//! use rustframe::matrix::Matrix;
//!
//! let data = Matrix::from_vec(vec![1.0, 1.0, 5.0, 5.0], 2, 2);
//! let (model, labels) = KMeans::fit(&data, 2, 10, 1e-4);
//! assert_eq!(model.centroids.rows(), 2);
//! assert_eq!(labels.len(), 2);
//! ```
use crate::compute::stats::mean_vertical;
use crate::matrix::Matrix;
use crate::random::prelude::*;

/// A fitted k-means clustering model.
///
/// Built by [`KMeans::fit`]; [`KMeans::predict`] assigns samples to the nearest centroid.
pub struct KMeans {
    /// Cluster centres, one row per cluster: shape `(k, n_features)`.
    pub centroids: Matrix<f64>,
}

impl KMeans {
    /// Fit a k-means model with `k` clusters to the rows of `x`.
    ///
    /// Each row of `x` is a sample and each column a feature. Centroids start at the
    /// column-wise mean of `x` when `k == 1`, and at `k` distinct randomly chosen rows
    /// otherwise. The fit then alternates
    ///
    /// * an **assignment** step: every sample is labelled with its nearest centroid
    ///   (squared Euclidean distance, ties go to the lowest centroid index), and
    /// * an **update** step: every centroid moves to the mean of its samples. A cluster
    ///   left without samples is re-seeded at the sample that is currently furthest
    ///   from its own centroid.
    ///
    /// Fitting stops when an assignment step changes no label and no cluster had to be
    /// re-seeded, when `tol > 0` and the Frobenius norm of the centroid shift drops
    /// below `tol`, or after `max_iter` rounds. The returned labels always refer to the
    /// returned centroids, so they coincide with [`KMeans::predict`] on `x`.
    ///
    /// Returns the fitted model and one label in `0..k` per row of `x`.
    ///
    /// # Panics
    ///
    /// Panics if `k` is larger than the number of rows of `x`, or if `k == 0` while
    /// `x` has at least one row.
    pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
        let m = x.rows();
        let n = x.cols();
        assert!(k <= m, "k must be ≤ number of samples");
        // With no centroids there is nothing to label a sample with; fail clearly here
        // rather than with an out-of-bounds index in the update step.
        assert!(k > 0 || m == 0, "k must be ≥ 1 when there are samples");

        let mut centroids = Self::init_centroids(x, k);

        let mut labels = vec![0usize; m];
        // Squared distance from each sample to the centroid it was last assigned to.
        let mut dist_sq = vec![0.0f64; m];
        // Set once an assignment step reproduces the previous labels; `labels` then
        // already agree with the final centroids.
        let mut labels_stable = false;

        for round in 0..max_iter {
            // The initial `labels` are placeholders rather than a real assignment, so
            // the first round never counts as "nothing changed".
            let mut changed = round == 0;

            // ----- assignment step -----
            for (i, (label, d)) in labels.iter_mut().zip(dist_sq.iter_mut()).enumerate() {
                let (best, best_dist_sq) = Self::nearest_centroid(x, i, &centroids);
                *d = best_dist_sq;
                if *label != best {
                    *label = best;
                    changed = true;
                }
            }

            // ----- update step -----
            let mut new_centroids: Matrix<f64> = Matrix::zeros(k, n);
            let mut counts = vec![0usize; k];
            for (i, &c) in labels.iter().enumerate() {
                counts[c] += 1;
                for j in 0..n {
                    new_centroids[(c, j)] += x[(i, j)];
                }
            }

            let mut reseeded = false;
            for (c, &count) in counts.iter().enumerate() {
                if count == 0 {
                    // Empty cluster: move its centroid onto the sample furthest from its
                    // own centroid so the cluster can attract points again. An empty
                    // cluster implies k ≥ 1, hence m ≥ 1, so `dist_sq` is non-empty.
                    let mut furthest = 0usize;
                    let mut furthest_dist_sq = 0.0;
                    for (i, &d) in dist_sq.iter().enumerate() {
                        if d > furthest_dist_sq {
                            furthest_dist_sq = d;
                            furthest = i;
                        }
                    }
                    for j in 0..n {
                        new_centroids[(c, j)] = x[(furthest, j)];
                    }
                    // Don't hand the same sample to another empty cluster this round.
                    dist_sq[furthest] = 0.0;
                    reseeded = true;
                } else {
                    let count = count as f64;
                    for j in 0..n {
                        new_centroids[(c, j)] /= count;
                    }
                }
            }

            // ----- convergence test -----
            if !changed && !reseeded {
                // Same labels as the previous round and no re-seeding: the means equal
                // the current centroids, i.e. this is a fixed point.
                centroids = new_centroids;
                labels_stable = true;
                break;
            }

            let converged = tol > 0.0 && Self::centroid_shift(&centroids, &new_centroids) < tol;
            centroids = new_centroids;
            if converged {
                break;
            }
        }

        // Stopping on `tol` or `max_iter` leaves `labels` describing the centroids from
        // before the last update; re-run the assignment so they match the result.
        if !labels_stable {
            for (i, label) in labels.iter_mut().enumerate() {
                *label = Self::nearest_centroid(x, i, &centroids).0;
            }
        }

        (Self { centroids }, labels)
    }

    /// Predict the nearest centroid for each row of `x`.
    ///
    /// Uses the same rule as [`KMeans::fit`]: squared Euclidean distance, ties go to
    /// the lowest centroid index. An input with no rows yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if `x` has rows but its column count differs from the centroids'.
    pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
        let m = x.rows();
        if m == 0 {
            return Vec::new();
        }
        // Comparing vectors of different lengths would silently ignore features.
        assert_eq!(
            x.cols(),
            self.centroids.cols(),
            "number of features must match the fitted centroids"
        );

        (0..m)
            .map(|i| Self::nearest_centroid(x, i, &self.centroids).0)
            .collect()
    }

    /// Initial centroids for [`KMeans::fit`]: the column-wise mean of `x` when `k == 1`,
    /// otherwise `k` distinct rows of `x` picked at random. When `k` or the number of
    /// rows is zero the all-zero `(k, n_features)` matrix is returned unchanged.
    fn init_centroids(x: &Matrix<f64>, k: usize) -> Matrix<f64> {
        let m = x.rows();
        let mut centroids: Matrix<f64> = Matrix::zeros(k, x.cols());
        if k == 0 || m == 0 {
            return centroids;
        }

        if k == 1 {
            // For a single cluster the mean is already the optimal centroid.
            let mean = mean_vertical(x);
            centroids.row_copy_from_slice(0, &mean.data());
        } else {
            // A shuffled index list guarantees the k chosen rows are distinct.
            let mut rng = rng();
            let mut indices: Vec<usize> = (0..m).collect();
            indices.shuffle(&mut rng);
            for (c, &row) in indices.iter().take(k).enumerate() {
                centroids.row_copy_from_slice(c, &x.row(row));
            }
        }
        centroids
    }

    /// Index of the centroid closest to row `i` of `x` and its squared Euclidean
    /// distance; ties go to the lowest centroid index. Features are read by index, so
    /// no per-row vectors are allocated.
    fn nearest_centroid(x: &Matrix<f64>, i: usize, centroids: &Matrix<f64>) -> (usize, f64) {
        let n = x.cols();
        let mut best = (0usize, f64::INFINITY);
        for c in 0..centroids.rows() {
            let dist_sq: f64 = (0..n).map(|j| (x[(i, j)] - centroids[(c, j)]).powi(2)).sum();
            if dist_sq < best.1 {
                best = (c, dist_sq);
            }
        }
        best
    }

    /// Frobenius norm of `a - b` for two matrices of the same shape.
    fn centroid_shift(a: &Matrix<f64>, b: &Matrix<f64>) -> f64 {
        a.data()
            .iter()
            .zip(b.data().iter())
            .map(|(p, q)| (p - q).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::FloatMatrix;

    #[test]
    fn test_k_means_empty_cluster_reinit_centroid() {
        // The duplicated point makes it likely that random initialisation picks two
        // identical centroids, leaving one cluster empty; try several times to raise
        // the chance of exercising that path.
        for _ in 0..20 {
            let data = vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0];
            let x = FloatMatrix::from_rows_vec(data, 3, 2);
            let k = 2;
            let max_iter = 10;
            let tol = 1e-6;

            let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

            // Count members per cluster to detect an empty one.
            let mut counts = vec![0; k];
            for &label in &labels {
                counts[label] += 1;
            }
            if counts.contains(&0) {
                // An empty cluster's centroid must have been re-seeded onto a sample.
                let centroids = kmeans_model.centroids;
                for c in (0..k).filter(|&c| counts[c] == 0) {
                    let matches_data_point = (0..x.rows()).any(|i| {
                        (centroids[(c, 0)] - x[(i, 0)]).abs() < 1e-9
                            && (centroids[(c, 1)] - x[(i, 1)]).abs() < 1e-9
                    });
                    assert!(
                        matches_data_point,
                        "Centroid {} (empty cluster) does not match any data point",
                        c
                    );
                }
                break;
            }
        }
        // If we never saw an empty cluster, that's fine; the test passes as long as no panic occurred
    }

    fn create_test_data() -> (FloatMatrix, usize) {
        // Simple 2D data for testing K-Means
        // Cluster 1: (1,1), (1.5,1.5)
        // Cluster 2: (5,8), (8,8), (6,7)
        let data = vec![
            1.0, 1.0, // Sample 0
            1.5, 1.5, // Sample 1
            5.0, 8.0, // Sample 2
            8.0, 8.0, // Sample 3
            6.0, 7.0, // Sample 4
        ];
        let x = FloatMatrix::from_rows_vec(data, 5, 2);
        let k = 2;
        (x, k)
    }

    // Helper for single cluster test with exact mean
    fn create_simple_integer_data() -> FloatMatrix {
        // Data points: (1,1), (2,2), (3,3)
        FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
    }

    #[test]
    fn test_k_means_fit_predict_basic() {
        let (x, k) = create_test_data();
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        // Assertions for fit
        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(kmeans_model.centroids.cols(), x.cols());
        assert_eq!(labels.len(), x.rows());

        // Check if labels are within expected range (0 to k-1)
        for &label in &labels {
            assert!(label < k);
        }

        // Predict with the same data
        let predicted_labels = kmeans_model.predict(&x);

        // The exact labels might vary due to random initialization,
        // but the clustering should be consistent.
        // We expect two clusters. Let's check if samples 0,1 are in one cluster
        // and samples 2,3,4 are in another.
        let cluster_0_members = [labels[0], labels[1]];
        let cluster_1_members = [labels[2], labels[3], labels[4]];

        // All members of cluster 0 should have the same label
        assert_eq!(cluster_0_members[0], cluster_0_members[1]);
        // All members of cluster 1 should have the same label
        assert_eq!(cluster_1_members[0], cluster_1_members[1]);
        assert_eq!(cluster_1_members[0], cluster_1_members[2]);
        // The two clusters should have different labels
        assert_ne!(cluster_0_members[0], cluster_1_members[0]);

        // Check predicted labels are consistent with fitted labels
        assert_eq!(labels, predicted_labels);

        // Test with a new sample
        let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0
        let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2);
        let new_sample_label = kmeans_model.predict(&new_sample)[0];
        assert_eq!(new_sample_label, cluster_0_members[0]);

        let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1
        let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2);
        let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0];
        assert_eq!(new_sample_label_2, cluster_1_members[0]);
    }

    #[test]
    fn test_k_means_fit_k_equals_m() {
        // Test case where k (number of clusters) equals m (number of samples)
        let (x, _) = create_test_data(); // 5 samples
        let k = 5; // 5 clusters
        let max_iter = 10;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(labels.len(), x.rows());

        // Each sample should be its own cluster. Due to random init, labels
        // might not be [0,1,2,3,4] but will be a permutation of it.
        let mut sorted_labels = labels.clone();
        sorted_labels.sort_unstable();
        sorted_labels.dedup();
        // Labels should all be unique when k==m
        assert_eq!(sorted_labels.len(), k);
    }

    #[test]
    #[should_panic(expected = "k must be ≤ number of samples")]
    fn test_k_means_fit_k_greater_than_m() {
        let (x, _) = create_test_data(); // 5 samples
        let k = 6; // k > m
        let max_iter = 10;
        let tol = 1e-6;

        let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
    }

    #[test]
    fn test_k_means_fit_single_cluster() {
        // Test with k=1
        let x = create_simple_integer_data(); // Use integer data
        let k = 1;
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), 1);
        assert_eq!(labels.len(), x.rows());

        // All labels should be 0
        assert!(labels.iter().all(|&l| l == 0));

        // Centroid should be the mean of all data points
        let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
        let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;

        assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
        assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
    }

    #[test]
    fn test_k_means_predict_empty_matrix() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        // `Matrix` does not support 0xN or Nx0 shapes, so a 0x0 matrix is the
        // natural empty input to test.
        let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
        let predicted_labels = kmeans_model.predict(&empty_x);
        assert!(predicted_labels.is_empty());
    }

    #[test]
    fn test_k_means_predict_single_sample() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
        let predicted_label = kmeans_model.predict(&single_sample);
        assert_eq!(predicted_label.len(), 1);
        assert!(predicted_label[0] < k);
    }
}