use crate::clustering::{ClusteringError, ClusteringResult};
use linfa::prelude::*;
use linfa_clustering::KMeans as LinfaKMeans;
use ndarray::Array2;
/// Tunable parameters for a k-means run.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// Number of clusters requested. `KMeans::fit` rejects inputs that have
    /// fewer samples (rows) than this value.
    pub n_clusters: usize,
    /// Upper bound on the number of iterations handed to the underlying
    /// solver.
    pub max_iterations: usize,
    /// Convergence tolerance handed to the underlying solver.
    pub tolerance: f64,
    /// Optional seed for the random number generator.
    ///
    /// NOTE(review): nothing in this file reads this field when fitting —
    /// confirm whether it is consumed elsewhere before relying on it for
    /// reproducibility.
    pub seed: Option<u64>,
}
impl Default for KMeansConfig {
fn default() -> Self {
Self {
n_clusters: 2,
max_iterations: 300,
tolerance: 1e-4,
seed: None,
}
}
}
/// Outcome of a k-means fit.
#[derive(Debug, Clone)]
pub struct KMeansResult {
    /// Index of the closest centroid for every input row, in row order.
    pub assignments: Vec<usize>,
    /// Fitted centroids, one centroid per row.
    pub centroids: Array2<f64>,
    /// Iteration count reported for the run.
    ///
    /// NOTE(review): `KMeans::fit` in this file fills this with the
    /// configured `max_iterations`, not a measured count.
    pub iterations: usize,
    /// Sum of squared Euclidean distances from each row to its assigned
    /// centroid.
    pub inertia: f64,
}
pub struct KMeans;
impl KMeans {
pub fn fit_from_rows(data_rows: Vec<Vec<f64>>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
if data_rows.is_empty() {
return Err(ClusteringError::EmptyData);
}
let n_features = data_rows[0].len();
let n_samples = data_rows.len();
let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
let data = Array2::from_shape_vec((n_samples, n_features), flat)
.map_err(|e| ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e)))?;
Self::fit(&data, config)
}
pub fn fit(data: &Array2<f64>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
if data.nrows() == 0 {
return Err(ClusteringError::EmptyData);
}
if data.nrows() < config.n_clusters {
return Err(ClusteringError::InsufficientData {
min: config.n_clusters,
actual: data.nrows(),
});
}
let dataset = DatasetBase::new(data.clone(), ());
let model = LinfaKMeans::params(config.n_clusters)
.max_n_iterations(config.max_iterations as u64)
.tolerance(config.tolerance)
.fit(&dataset)
.map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
let assignments: Vec<usize> = (0..data.nrows())
.map(|i| {
let point = data.row(i);
let mut min_dist = f64::INFINITY;
let mut best_cluster = 0;
for (j, centroid) in model.centroids().rows().into_iter().enumerate() {
let dist: f64 = point
.iter()
.zip(centroid.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
if dist < min_dist {
min_dist = dist;
best_cluster = j;
}
}
best_cluster
})
.collect();
let centroids = model.centroids().to_owned();
let inertia = Self::calculate_inertia(data, ¢roids, &assignments);
Ok(KMeansResult {
assignments,
centroids,
iterations: config.max_iterations, inertia,
})
}
fn calculate_inertia(
data: &Array2<f64>,
centroids: &Array2<f64>,
assignments: &[usize],
) -> f64 {
let mut inertia = 0.0;
for (i, assignment) in assignments.iter().enumerate() {
let point = data.row(i);
let centroid = centroids.row(*assignment);
let dist_sq: f64 = point
.iter()
.zip(centroid.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
inertia += dist_sq;
}
inertia
}
}