use crate::distance::{compute_modes, CategoricalDistance, MatchingDistance, JaccardDistance, CentroidTracker};
use crate::error::{Error, Result};
use crate::initialization::{initialize_centroids, InitMethod};
use crate::utils::{
assign_points_to_centroids, assignments_equal, calculate_cost, get_cluster_indices,
validate_data, validate_parameters,
};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use rayon::prelude::*;
use std::hash::Hash;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Selects which dissimilarity implementation from `crate::distance` is used
/// when a data row is compared with a centroid.
///
/// The default is [`DistanceMetric::Matching`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DistanceMetric {
    /// Dispatches to `MatchingDistance`.
    #[default]
    Matching,
    /// Dispatches to `crate::distance::HammingDistance`.
    Hamming,
    /// Dispatches to `JaccardDistance`.
    Jaccard,
}
/// Configuration for a K-modes clustering run.
///
/// Build one with [`KModes::new`] and the chained setter methods, then call
/// `fit` or `fit_predict`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct KModes {
/// Number of clusters to form; `fit` rejects values larger than the number of data rows.
pub n_clusters: usize,
/// Strategy handed to `initialize_centroids` to pick the starting centroids.
pub init_method: InitMethod,
/// Upper bound on the number of iterations of a single run.
pub max_iter: usize,
/// Threshold on the fraction of centroid entries that changed between two
/// consecutive iterations; a smaller change stops the run as converged.
pub tol: f64,
/// Number of independent runs; `fit` returns the one with the lowest inertia.
pub n_init: usize,
/// Base seed. Run `i` is seeded with `random_state + i`; `None` is treated as `0`.
pub random_state: Option<u64>,
/// Parallelism switch: `Some(1)` runs the restarts sequentially, any other
/// `Some` value runs them in parallel, and `None` uses parallel execution
/// only when `n_init > 1`. The number itself is not used to size a thread
/// pool in this file.
pub n_jobs: Option<usize>,
/// When `true`, progress and convergence messages are printed to stdout.
pub verbose: bool,
/// Dissimilarity measure used for assignment and cost computation.
pub distance_metric: DistanceMetric,
/// When `true`, runs use the tracker-based incremental loop; otherwise the
/// classic loop that recomputes centroids from scratch each iteration.
pub use_incremental_updates: bool,
}
/// Outcome of a K-modes fit (the best run when several restarts were made).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct KModesResult<T> {
/// Cluster index assigned to each data row.
pub labels: Array1<usize>,
/// One centroid per row of this array, one column per feature.
pub centroids: Array2<T>,
/// Number of iterations the selected run executed.
pub n_iter: usize,
/// Clustering cost as returned by `calculate_cost` for `labels` and `centroids`.
pub inertia: f64,
/// `true` when the run stopped on a convergence criterion rather than on `max_iter`.
pub converged: bool,
}
impl Default for KModes {
    /// Returns the baseline configuration: 8 clusters, Huang initialization,
    /// 10 restarts of at most 100 iterations each, tolerance `1e-4`, no fixed
    /// seed, no explicit job count, quiet output, the default distance metric
    /// and incremental updates enabled.
    fn default() -> Self {
        let distance_metric = DistanceMetric::default();
        Self {
            n_clusters: 8,
            n_init: 10,
            max_iter: 100,
            tol: 1e-4,
            init_method: InitMethod::Huang,
            distance_metric,
            use_incremental_updates: true,
            random_state: None,
            n_jobs: None,
            verbose: false,
        }
    }
}
impl KModes {
pub fn new(n_clusters: usize) -> Self {
Self {
n_clusters,
..Default::default()
}
}
pub fn init_method(mut self, method: InitMethod) -> Self {
self.init_method = method;
self
}
pub fn max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn tolerance(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn n_init(mut self, n_init: usize) -> Self {
self.n_init = n_init;
self
}
pub fn random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
pub fn n_jobs(mut self, n_jobs: usize) -> Self {
self.n_jobs = Some(n_jobs);
self
}
pub fn verbose(mut self, verbose: bool) -> Self {
self.verbose = verbose;
self
}
pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
self.distance_metric = metric;
self
}
pub fn use_incremental_updates(mut self, use_incremental: bool) -> Self {
self.use_incremental_updates = use_incremental;
self
}
/// Clusters `data` (one row per observation, one column per categorical
/// feature) and returns the best of `n_init` independent runs.
///
/// Run `i` is seeded with `random_state + i` (wrapping on overflow); a missing
/// `random_state` is treated as `0`. Runs execute in parallel when
/// `should_use_parallel` says so. The run with the lowest inertia wins; ties
/// keep the earliest run.
///
/// # Errors
///
/// Returns an error when input validation fails, when any run fails, or when
/// no run was executed.
pub fn fit<T>(&self, data: ArrayView2<T>) -> Result<KModesResult<T>>
where
    T: Clone + Eq + Hash + Send + Sync,
{
    self.validate_input(data)?;

    // `wrapping_add` keeps seeds near `u64::MAX` from panicking in debug builds.
    let base_seed = self.random_state.unwrap_or(0);
    let run = |i: usize| self.fit_single(data, base_seed.wrapping_add(i as u64));

    let results: Vec<Result<KModesResult<T>>> = if self.should_use_parallel() {
        (0..self.n_init).into_par_iter().map(run).collect()
    } else {
        (0..self.n_init).map(run).collect()
    };

    let mut best_result: Option<KModesResult<T>> = None;
    for result in results {
        let candidate = result?;
        // The first run is always accepted so a non-finite inertia cannot turn
        // a completed fit into a spurious "No successful runs" error.
        let improves = best_result
            .as_ref()
            .map_or(true, |best| candidate.inertia < best.inertia);
        if improves {
            best_result = Some(candidate);
        }
    }
    best_result.ok_or_else(|| Error::convergence_failure("No successful runs"))
}
/// Executes one seeded run, dispatching on `use_incremental_updates` to the
/// incremental or the classic loop.
fn fit_single<T>(&self, data: ArrayView2<T>, seed: u64) -> Result<KModesResult<T>>
where
    T: Clone + Eq + Hash,
{
    if !self.use_incremental_updates {
        return self.fit_single_classic(data, seed);
    }
    self.fit_single_incremental(data, seed)
}
/// Runs one K-modes fit that recomputes every centroid from scratch each
/// iteration.
///
/// Each iteration assigns all rows to their nearest centroid and then rebuilds
/// the centroids via `update_centroids`. From the second iteration on, the run
/// stops as converged when the assignment equals the previous one, or when the
/// fraction of changed centroid entries drops below `tol` (in which case the
/// centroids the labels were computed against are kept).
///
/// # Errors
///
/// Propagates errors from initialization, assignment, centroid updates and
/// cost computation.
fn fit_single_classic<T>(&self, data: ArrayView2<T>, seed: u64) -> Result<KModesResult<T>>
where
    T: Clone + Eq + Hash,
{
    let mut rng = StdRng::seed_from_u64(seed);
    let mut centroids = initialize_centroids(data, self.n_clusters, self.init_method, &mut rng)?;
    let mut previous_labels: Option<Array1<usize>> = None;
    // Labels captured on a converged exit; they already match `centroids`,
    // so no further assignment pass is needed for them.
    let mut settled_labels: Option<Array1<usize>> = None;
    let mut n_iter = 0;
    let mut converged = false;

    for iter in 0..self.max_iter {
        n_iter = iter + 1;
        let labels = assign_points_to_centroids(
            data,
            centroids.view(),
            |a, b| self.compute_distance(a, b),
        )?;

        let labels_stable = previous_labels
            .as_ref()
            .map_or(false, |prev| assignments_equal(labels.view(), prev.view()));
        if labels_stable {
            converged = true;
            if self.verbose {
                println!("K-modes converged after {} iterations", n_iter);
            }
            settled_labels = Some(labels);
            break;
        }

        let new_centroids = self.update_centroids(data, &labels)?;
        // The tolerance test is skipped on the first iteration, mirroring the
        // label test above.
        if previous_labels.is_some() {
            let centroid_change = self.calculate_centroid_change(&centroids, &new_centroids)?;
            if centroid_change < self.tol {
                converged = true;
                if self.verbose {
                    println!("K-modes converged (centroid change < tol) after {} iterations", n_iter);
                }
                settled_labels = Some(labels);
                break;
            }
        }

        centroids = new_centroids;
        previous_labels = Some(labels);
        if self.verbose && (iter + 1) % 10 == 0 {
            println!("K-modes iteration {}", iter + 1);
        }
    }

    // When the loop ran out of iterations the centroids were updated after the
    // last assignment, so the labels must be recomputed against them.
    let final_labels = match settled_labels {
        Some(labels) => labels,
        None => assign_points_to_centroids(
            data,
            centroids.view(),
            |a, b| self.compute_distance(a, b),
        )?,
    };
    let inertia = calculate_cost(
        data,
        centroids.view(),
        final_labels.view(),
        |a, b| self.compute_distance(a, b),
    )?;

    Ok(KModesResult {
        labels: final_labels,
        centroids,
        n_iter,
        inertia,
        converged,
    })
}
/// Runs one K-modes fit that keeps per-cluster `CentroidTracker`s and only
/// moves the rows whose assignment changed between iterations.
///
/// The trackers always mirror `current_labels`. The first iteration derives
/// centroids from the initial assignment; every later iteration reassigns all
/// rows, compares the result with the immediately preceding assignment, moves
/// the changed rows between trackers and refreshes the centroids. The run
/// stops as converged when the assignment is unchanged or when the fraction of
/// changed centroid entries drops below `tol` (keeping the centroids the
/// labels were computed against). The returned labels are always the
/// assignment under the returned centroids.
///
/// # Errors
///
/// Propagates errors from initialization, assignment, tracker maintenance and
/// cost computation.
fn fit_single_incremental<T>(&self, data: ArrayView2<T>, seed: u64) -> Result<KModesResult<T>>
where
    T: Clone + Eq + Hash,
{
    let mut rng = StdRng::seed_from_u64(seed);
    let mut centroids = initialize_centroids(data, self.n_clusters, self.init_method, &mut rng)?;
    let mut centroid_trackers: Vec<CentroidTracker<T>> = (0..self.n_clusters)
        .map(|_| CentroidTracker::new(data.ncols()))
        .collect();

    let mut current_labels = assign_points_to_centroids(
        data,
        centroids.view(),
        |a, b| self.compute_distance(a, b),
    )?;
    self.update_trackers_full(&mut centroid_trackers, data, &current_labels)?;

    let mut n_iter = 0;
    let mut converged = false;

    for iter in 0..self.max_iter {
        n_iter = iter + 1;

        // On the first iteration `current_labels` already is the assignment
        // under `centroids`, so repeating the pass would be wasted work.
        if iter > 0 {
            let new_labels = assign_points_to_centroids(
                data,
                centroids.view(),
                |a, b| self.compute_distance(a, b),
            )?;
            if assignments_equal(new_labels.view(), current_labels.view()) {
                converged = true;
                if self.verbose {
                    println!("K-modes converged after {} iterations", n_iter);
                }
                break;
            }
            self.update_trackers_incremental(&mut centroid_trackers, data, &current_labels, &new_labels)?;
            current_labels = new_labels;
        }

        let new_centroids = self.get_centroids_from_trackers(&centroid_trackers, data)?;
        if iter > 0 {
            let centroid_change = self.calculate_centroid_change(&centroids, &new_centroids)?;
            if centroid_change < self.tol {
                // `current_labels` was computed against `centroids`, so the
                // pair stays consistent when the new centroids are not adopted.
                converged = true;
                if self.verbose {
                    println!("K-modes converged (centroid change < tol) after {} iterations", n_iter);
                }
                break;
            }
        }

        centroids = new_centroids;
        if self.verbose && (iter + 1) % 10 == 0 {
            println!("K-modes iteration {}", iter + 1);
        }
    }

    // Without a converged exit the centroids were refreshed after the last
    // assignment; reassign so labels and inertia match the returned centroids.
    if !converged {
        current_labels = assign_points_to_centroids(
            data,
            centroids.view(),
            |a, b| self.compute_distance(a, b),
        )?;
    }

    let inertia = calculate_cost(
        data,
        centroids.view(),
        current_labels.view(),
        |a, b| self.compute_distance(a, b),
    )?;

    Ok(KModesResult {
        labels: current_labels,
        centroids,
        n_iter,
        inertia,
        converged,
    })
}
fn update_centroids<T>(&self, data: ArrayView2<T>, labels: &Array1<usize>) -> Result<Array2<T>>
where
T: Clone + Eq + Hash,
{
let cluster_indices = get_cluster_indices(labels.view(), self.n_clusters);
let mut new_centroids = Array2::uninit((self.n_clusters, data.ncols()));
for (cluster_id, indices) in cluster_indices.iter().enumerate() {
if indices.is_empty() {
let mut rng = StdRng::seed_from_u64(self.random_state.unwrap_or(0) + cluster_id as u64);
let random_idx = rng.gen_range(0..data.nrows());
for feature_idx in 0..data.ncols() {
new_centroids[[cluster_id, feature_idx]].write(data[[random_idx, feature_idx]].clone());
}
} else {
let modes = compute_modes(data, indices)?;
for (feature_idx, mode) in modes.into_iter().enumerate() {
new_centroids[[cluster_id, feature_idx]].write(mode);
}
}
}
Ok(unsafe { new_centroids.assume_init() })
}
/// Resets every tracker and re-registers each row under the cluster named by
/// `labels`. Labels that do not index an existing tracker are skipped.
///
/// # Errors
///
/// Propagates errors from `CentroidTracker::add_point`.
fn update_trackers_full<T>(
    &self,
    trackers: &mut [CentroidTracker<T>],
    data: ArrayView2<T>,
    labels: &Array1<usize>
) -> Result<()>
where
    T: Clone + Eq + Hash,
{
    trackers.iter_mut().for_each(|tracker| {
        tracker.clear();
    });

    let tracker_count = trackers.len();
    let n_features = data.ncols();
    for (row, &cluster) in labels.iter().enumerate() {
        if cluster >= tracker_count {
            continue;
        }
        // Owned copy of the row's feature values for the tracker.
        let mut row_values: Vec<T> = Vec::with_capacity(n_features);
        for col in 0..n_features {
            row_values.push(data[[row, col]].clone());
        }
        trackers[cluster].add_point(row, &row_values)?;
    }
    Ok(())
}
/// Moves every row whose label differs between `old_labels` and `new_labels`
/// from its old tracker to its new one. A label that does not index an
/// existing tracker is skipped on that side of the move.
///
/// # Errors
///
/// Propagates errors from `CentroidTracker::remove_point` and
/// `CentroidTracker::add_point`.
fn update_trackers_incremental<T>(
    &self,
    trackers: &mut [CentroidTracker<T>],
    data: ArrayView2<T>,
    old_labels: &Array1<usize>,
    new_labels: &Array1<usize>
) -> Result<()>
where
    T: Clone + Eq + Hash,
{
    let tracker_count = trackers.len();
    let n_features = data.ncols();
    let moves = old_labels.iter().zip(new_labels.iter()).enumerate();

    for (row, (&source, &target)) in moves {
        if source == target {
            continue;
        }
        let mut row_values: Vec<T> = Vec::with_capacity(n_features);
        for col in 0..n_features {
            row_values.push(data[[row, col]].clone());
        }
        // Remove first, then add, so a row is never registered twice.
        if source < tracker_count {
            trackers[source].remove_point(row)?;
        }
        if target < tracker_count {
            trackers[target].add_point(row, &row_values)?;
        }
    }
    Ok(())
}
fn get_centroids_from_trackers<T>(&self, trackers: &[CentroidTracker<T>], data: ArrayView2<T>) -> Result<Array2<T>>
where
T: Clone + Eq + Hash,
{
if trackers.is_empty() {
return Err(Error::computation_error("No trackers provided"));
}
let num_features = trackers.iter()
.find_map(|tracker| {
if !tracker.is_empty() {
tracker.get_centroid().ok().map(|centroid| centroid.len())
} else {
None
}
})
.unwrap_or(data.ncols());
let mut centroids = Array2::uninit((self.n_clusters, num_features));
for (cluster_id, tracker) in trackers.iter().enumerate() {
if tracker.is_empty() {
let mut rng = StdRng::seed_from_u64(self.random_state.unwrap_or(0) + cluster_id as u64);
let random_idx = rng.gen_range(0..data.nrows());
for feature_idx in 0..num_features {
centroids[[cluster_id, feature_idx]].write(data[[random_idx, feature_idx]].clone());
}
} else {
let centroid_values = tracker.get_centroid()?;
for (feature_idx, value) in centroid_values.into_iter().enumerate() {
centroids[[cluster_id, feature_idx]].write(value);
}
}
}
Ok(unsafe { centroids.assume_init() })
}
fn calculate_centroid_change<T>(&self, old: &Array2<T>, new: &Array2<T>) -> Result<f64>
where
T: Clone + PartialEq,
{
if old.dim() != new.dim() {
return Err(Error::computation_error("Centroid dimension mismatch"));
}
let mut total_changes = 0;
let total_elements = old.nrows() * old.ncols();
for (old_val, new_val) in old.iter().zip(new.iter()) {
if old_val != new_val {
total_changes += 1;
}
}
Ok(total_changes as f64 / total_elements as f64)
}
fn validate_input<T>(&self, data: ArrayView2<T>) -> Result<()> {
validate_parameters(self.n_clusters, self.max_iter, self.tol, self.n_init)?;
validate_data(data)?;
if self.n_clusters > data.nrows() {
return Err(Error::invalid_parameter(
"Number of clusters cannot exceed number of data points",
));
}
Ok(())
}
/// Computes the dissimilarity between rows `a` and `b` using the
/// implementation selected by `self.distance_metric`.
fn compute_distance<T>(&self, a: ArrayView1<T>, b: ArrayView1<T>) -> Result<f64>
where
    T: Clone + Eq + Hash,
{
    match self.distance_metric {
        DistanceMetric::Matching => MatchingDistance.distance(a, b),
        DistanceMetric::Hamming => crate::distance::HammingDistance.distance(a, b),
        DistanceMetric::Jaccard => JaccardDistance.distance(a, b),
    }
}
/// Decides whether the restarts run in parallel: an explicit job count
/// enables it unless it is exactly `1`; without one, parallel execution is
/// used only when there is more than one restart.
fn should_use_parallel(&self) -> bool {
    if let Some(jobs) = self.n_jobs {
        return jobs != 1;
    }
    self.n_init > 1
}
/// Fits the model on `data` and returns only the cluster label of each row.
///
/// # Errors
///
/// Returns whatever error `fit` returns.
pub fn fit_predict<T>(&self, data: ArrayView2<T>) -> Result<Array1<usize>>
where
    T: Clone + Eq + Hash + Send + Sync,
{
    self.fit(data).map(|outcome| outcome.labels)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds a `rows x cols` array of string categories from row-major `cells`.
    fn categorical(rows: usize, cols: usize, cells: &[&'static str]) -> Array2<&'static str> {
        Array2::from_shape_vec((rows, cols), cells.to_vec()).unwrap()
    }

    #[test]
    fn test_kmodes_creation() {
        let model = KModes::new(3);
        assert_eq!(model.n_clusters, 3);
        assert_eq!(model.init_method, InitMethod::Huang);
    }

    #[test]
    fn test_kmodes_builder_pattern() {
        let model = KModes::new(5)
            .init_method(InitMethod::Random)
            .max_iter(50)
            .tolerance(0.001)
            .n_init(5)
            .random_state(42)
            .verbose(true);
        assert_eq!(model.n_clusters, 5);
        assert_eq!(model.init_method, InitMethod::Random);
        assert_eq!(model.max_iter, 50);
        assert_eq!(model.tol, 0.001);
        assert_eq!(model.n_init, 5);
        assert_eq!(model.random_state, Some(42));
        assert!(model.verbose);
    }

    #[test]
    fn test_kmodes_simple_clustering() {
        let data = categorical(
            6,
            2,
            &["A", "X", "A", "X", "B", "Y", "B", "Y", "A", "X", "B", "Y"],
        );
        let model = KModes::new(2).random_state(42).n_init(3).max_iter(10);
        let outcome = model.fit(data.view()).unwrap();
        assert_eq!(outcome.labels.len(), 6);
        assert_eq!(outcome.centroids.nrows(), 2);
        assert_eq!(outcome.centroids.ncols(), 2);
        assert!(outcome.n_iter <= 10);
    }

    #[test]
    fn test_kmodes_convergence() {
        let data = categorical(4, 1, &["A", "A", "B", "B"]);
        let model = KModes::new(2).random_state(42).n_init(1).max_iter(100);
        let outcome = model.fit(data.view()).unwrap();
        assert!(outcome.converged);
        assert!(outcome.n_iter < 100);
    }

    #[test]
    fn test_kmodes_fit_predict() {
        let data = categorical(4, 2, &["A", "X", "A", "X", "B", "Y", "B", "Y"]);
        let model = KModes::new(2).random_state(42);
        let labels = model.fit_predict(data.view()).unwrap();
        assert_eq!(labels.len(), 4);
        assert!(labels.iter().all(|&label| label < 2));
    }

    #[test]
    fn test_invalid_parameters() {
        let data = categorical(2, 1, &["A", "B"]);
        // More clusters than rows.
        let too_many = KModes::new(3);
        assert!(too_many.fit(data.view()).is_err());
        // Zero clusters.
        let none = KModes::new(0);
        assert!(none.fit(data.view()).is_err());
    }

    #[test]
    fn test_empty_data() {
        let data = categorical(0, 0, &[]);
        let model = KModes::new(1);
        assert!(model.fit(data.view()).is_err());
    }

    #[test]
    fn test_jaccard_distance_metric() {
        let data = categorical(
            6,
            2,
            &["A", "X", "A", "X", "B", "Y", "B", "Y", "C", "Z", "C", "Z"],
        );
        let model = KModes::new(3)
            .distance_metric(DistanceMetric::Jaccard)
            .random_state(42)
            .n_init(3)
            .max_iter(10);
        let outcome = model.fit(data.view()).unwrap();
        assert_eq!(outcome.labels.len(), 6);
        assert_eq!(outcome.centroids.nrows(), 3);
        assert_eq!(outcome.centroids.ncols(), 2);
        assert!(outcome.n_iter <= 10);
    }

    #[test]
    fn test_hamming_distance_metric() {
        let data = categorical(4, 2, &["A", "X", "A", "X", "B", "Y", "B", "Y"]);
        let model = KModes::new(2)
            .distance_metric(DistanceMetric::Hamming)
            .random_state(42)
            .n_init(1)
            .max_iter(50);
        let outcome = model.fit(data.view()).unwrap();
        assert_eq!(outcome.labels.len(), 4);
        assert_eq!(outcome.centroids.nrows(), 2);
        assert_eq!(outcome.centroids.ncols(), 2);
    }

    #[test]
    fn test_distance_metric_builder() {
        let model = KModes::new(5)
            .distance_metric(DistanceMetric::Jaccard)
            .init_method(InitMethod::Random);
        assert_eq!(model.distance_metric, DistanceMetric::Jaccard);
        assert_eq!(model.init_method, InitMethod::Random);
    }

    #[test]
    fn test_default_distance_metric() {
        let model = KModes::new(3);
        assert_eq!(model.distance_metric, DistanceMetric::Matching);
    }
}