//! Clustering kernels: K-Means (batch and ring-based), DBSCAN, and
//! agglomerative hierarchical clustering.

use crate::messages::{
    DBSCANInput, DBSCANOutput, HierarchicalInput, HierarchicalOutput, KMeansInput, KMeansOutput,
    Linkage,
};
use crate::ring_messages::{
    K2KCentroidAggregation, K2KCentroidBroadcast, K2KCentroidBroadcastAck, K2KKMeansSync,
    K2KKMeansSyncResponse, K2KPartialCentroid, KMeansAssignResponse, KMeansAssignRing,
    KMeansQueryResponse, KMeansQueryRing, KMeansUpdateResponse, KMeansUpdateRing, from_fixed_point,
    to_fixed_point, unpack_coordinates,
};
use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
use async_trait::async_trait;
use rand::prelude::*;
use ringkernel_core::RingContext;
use rustkernel_core::error::Result;
use rustkernel_core::traits::{BatchKernel, GpuKernel, RingKernelHandler};
use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
use std::time::Instant;
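/// Mutable state for the ring-based K-Means kernel: current centroids, the
/// loaded data matrix, per-point labels, and convergence bookkeeping.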
#[derive(Debug, Clone, Default)]
pub struct KMeansState {
pub centroids: Vec<f64>,
pub data: Option<DataMatrix>,
pub k: usize,
pub n_features: usize,
pub iteration: u32,
pub inertia: f64,
pub converged: bool,
pub labels: Vec<usize>,
}
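/// K-Means clustering kernel with K-Means++ initialization. Batch use goes
/// through [`KMeans::compute`]; the ring handlers drive the stateful protocol
/// via `initialize`, `assign_step`, `update_step`, and `query_point`.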
#[derive(Debug)]
pub struct KMeans {
metadata: KernelMetadata,
state: std::sync::RwLock<KMeansState>,
}
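// RwLock is not Clone, so clone by snapshotting the current state.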
impl Clone for KMeans {
fn clone(&self) -> Self {
Self {
metadata: self.metadata.clone(),
state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
}
}
}
impl Default for KMeans {
fn default() -> Self {
Self::new()
}
}
impl KMeans {
#[must_use]
pub fn new() -> Self {
Self {
metadata: KernelMetadata::batch("ml/kmeans-cluster", Domain::StatisticalML)
.with_description("K-Means clustering with K-Means++ initialization")
.with_throughput(20_000)
.with_latency_us(50.0),
state: std::sync::RwLock::new(KMeansState::default()),
}
}
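    /// Seeds `k` centroids with K-Means++ from `data` and resets the
    /// iteration state.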
pub fn initialize(&self, data: DataMatrix, k: usize) {
let centroids = Self::kmeans_plus_plus_init(&data, k);
let n = data.n_samples;
let n_features = data.n_features;
let mut state = self.state.write().unwrap();
*state = KMeansState {
centroids,
data: Some(data),
k,
n_features,
iteration: 0,
inertia: 0.0,
converged: false,
labels: vec![0; n],
};
}
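    /// Assignment step of Lloyd's algorithm: relabels every point to its
    /// nearest centroid and returns the resulting inertia (sum of squared
    /// point-to-centroid distances). Returns 0.0 if no data is loaded.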
#[allow(clippy::needless_range_loop)]
pub fn assign_step(&self) -> f64 {
        let mut state = self.state.write().unwrap();
        // Borrow the data and centroids in place rather than cloning them.
        let Some(ref data) = state.data else {
            return 0.0;
        };
        let n = data.n_samples;
        let d_features = state.n_features;
        let mut total_inertia = 0.0;
        let mut new_labels = vec![0usize; n];
        for i in 0..n {
            let point = data.row(i);
            let mut min_dist = f64::MAX;
            let mut min_cluster = 0;
            for (c, centroid) in state.centroids.chunks(d_features).enumerate() {
let dist = Self::euclidean_distance(point, centroid);
if dist < min_dist {
min_dist = dist;
min_cluster = c;
}
}
new_labels[i] = min_cluster;
total_inertia += min_dist * min_dist;
}
state.labels = new_labels;
state.inertia = total_inertia;
total_inertia
}
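    /// Update step of Lloyd's algorithm: recomputes each centroid as the mean
    /// of its assigned points and returns the largest centroid shift, which
    /// callers compare against a tolerance to detect convergence. Empty
    /// clusters are left at the origin.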
pub fn update_step(&self) -> f64 {
let mut state = self.state.write().unwrap();
let Some(ref data) = state.data else {
return 0.0;
};
let n = data.n_samples;
let d = state.n_features;
let k = state.k;
let mut new_centroids = vec![0.0f64; k * d];
let mut counts = vec![0usize; k];
for i in 0..n {
let cluster = state.labels[i];
counts[cluster] += 1;
let point = data.row(i);
for j in 0..d {
new_centroids[cluster * d + j] += point[j];
}
}
for c in 0..k {
if counts[c] > 0 {
for j in 0..d {
new_centroids[c * d + j] /= counts[c] as f64;
}
}
}
let max_shift = state
.centroids
.chunks(d)
.zip(new_centroids.chunks(d))
.map(|(old, new)| Self::euclidean_distance(old, new))
.fold(0.0f64, f64::max);
state.centroids = new_centroids;
state.iteration += 1;
max_shift
}
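    /// Returns the index of the nearest centroid and the distance to it.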
pub fn query_point(&self, point: &[f64]) -> (usize, f64) {
let state = self.state.read().unwrap();
let d = state.n_features;
let mut min_dist = f64::MAX;
let mut min_cluster = 0;
for (c, centroid) in state.centroids.chunks(d).enumerate() {
let dist = Self::euclidean_distance(point, centroid);
if dist < min_dist {
min_dist = dist;
min_cluster = c;
}
}
(min_cluster, min_dist)
}
pub fn current_iteration(&self) -> u32 {
self.state.read().unwrap().iteration
}
pub fn current_inertia(&self) -> f64 {
self.state.read().unwrap().inertia
}
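    /// Runs batch K-Means (Lloyd's algorithm) with K-Means++ seeding.
    ///
    /// Iterates until the largest centroid shift falls below `tolerance` or
    /// `max_iterations` is reached. Degenerate inputs (`n == 0`, `k == 0`,
    /// or `k > n`) yield an empty, converged result.
    ///
    /// Illustrative sketch (marked `ignore` since crate paths vary):
    ///
    /// ```ignore
    /// let data = DataMatrix::from_rows(&[&[0.0, 0.0], &[0.1, 0.1], &[10.0, 10.0], &[10.1, 10.1]]);
    /// let result = KMeans::compute(&data, 2, 100, 1e-6);
    /// assert_eq!(result.n_clusters, 2);
    /// ```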
#[allow(clippy::needless_range_loop)]
pub fn compute(
data: &DataMatrix,
k: usize,
max_iterations: u32,
tolerance: f64,
) -> ClusteringResult {
let n = data.n_samples;
let d = data.n_features;
if n == 0 || k == 0 || k > n {
return ClusteringResult {
labels: Vec::new(),
n_clusters: 0,
centroids: Vec::new(),
inertia: 0.0,
iterations: 0,
converged: true,
};
}
let mut centroids = Self::kmeans_plus_plus_init(data, k);
let mut labels = vec![0usize; n];
let mut converged = false;
let mut iterations = 0u32;
for iter in 0..max_iterations {
iterations = iter + 1;
for i in 0..n {
let point = data.row(i);
let mut min_dist = f64::MAX;
let mut min_cluster = 0;
for (c, centroid) in centroids.chunks(d).enumerate() {
let dist = Self::euclidean_distance(point, centroid);
if dist < min_dist {
min_dist = dist;
min_cluster = c;
}
}
labels[i] = min_cluster;
}
let mut new_centroids = vec![0.0f64; k * d];
let mut counts = vec![0usize; k];
for i in 0..n {
let cluster = labels[i];
counts[cluster] += 1;
let point = data.row(i);
for j in 0..d {
new_centroids[cluster * d + j] += point[j];
}
}
for c in 0..k {
if counts[c] > 0 {
for j in 0..d {
new_centroids[c * d + j] /= counts[c] as f64;
}
}
}
let max_shift = centroids
.chunks(d)
.zip(new_centroids.chunks(d))
.map(|(old, new)| Self::euclidean_distance(old, new))
.fold(0.0f64, f64::max);
centroids = new_centroids;
if max_shift < tolerance {
converged = true;
break;
}
}
let inertia: f64 = (0..n)
.map(|i| {
let point = data.row(i);
let centroid_start = labels[i] * d;
                let centroid = &centroids[centroid_start..centroid_start + d];
let dist = Self::euclidean_distance(point, centroid);
dist * dist
})
.sum();
ClusteringResult {
labels,
n_clusters: k,
centroids,
inertia,
iterations,
converged,
}
}
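    /// K-Means++ seeding: the first centroid is drawn uniformly at random;
    /// each subsequent centroid is drawn with probability proportional to its
    /// squared distance from the nearest centroid chosen so far.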
#[allow(clippy::needless_range_loop)]
fn kmeans_plus_plus_init(data: &DataMatrix, k: usize) -> Vec<f64> {
let n = data.n_samples;
let d = data.n_features;
let mut rng = rand::rng();
let mut centroids = Vec::with_capacity(k * d);
let first_idx = rng.random_range(0..n);
centroids.extend_from_slice(data.row(first_idx));
let mut distances = vec![f64::MAX; n];
for _ in 1..k {
for i in 0..n {
let point = data.row(i);
                let last_centroid = &centroids[centroids.len() - d..];
let dist = Self::euclidean_distance(point, last_centroid);
distances[i] = distances[i].min(dist);
}
let total: f64 = distances.iter().map(|d| d * d).sum();
let threshold = rng.random::<f64>() * total;
let mut cumsum = 0.0;
let mut next_idx = 0;
for (i, &dist) in distances.iter().enumerate() {
cumsum += dist * dist;
if cumsum >= threshold {
next_idx = i;
break;
}
}
centroids.extend_from_slice(data.row(next_idx));
}
centroids
}
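    /// Euclidean (L2) distance between two equal-length coordinate slices.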
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
}
}
impl GpuKernel for KMeans {
fn metadata(&self) -> &KernelMetadata {
&self.metadata
}
}
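// Ring protocol: per-iteration assign/update/query handlers plus K2K
// aggregation, sync, and broadcast messages.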
#[async_trait::async_trait]
impl RingKernelHandler<KMeansAssignRing, KMeansAssignResponse> for KMeans {
async fn handle(
&self,
_ctx: &mut RingContext,
msg: KMeansAssignRing,
) -> Result<KMeansAssignResponse> {
let inertia = self.assign_step();
let state = self.state.read().unwrap();
let points_assigned = state.labels.len() as u32;
Ok(KMeansAssignResponse {
request_id: msg.id.0,
iteration: msg.iteration,
inertia_fp: to_fixed_point(inertia),
points_assigned,
})
}
}
#[async_trait::async_trait]
impl RingKernelHandler<KMeansUpdateRing, KMeansUpdateResponse> for KMeans {
async fn handle(
&self,
_ctx: &mut RingContext,
msg: KMeansUpdateRing,
) -> Result<KMeansUpdateResponse> {
let max_shift = self.update_step();
let converged = max_shift < 1e-6;
if converged {
let mut state = self.state.write().unwrap();
state.converged = true;
}
Ok(KMeansUpdateResponse {
request_id: msg.id.0,
iteration: msg.iteration,
max_shift_fp: to_fixed_point(max_shift),
converged,
})
}
}
#[async_trait::async_trait]
impl RingKernelHandler<KMeansQueryRing, KMeansQueryResponse> for KMeans {
async fn handle(
&self,
_ctx: &mut RingContext,
msg: KMeansQueryRing,
) -> Result<KMeansQueryResponse> {
let point = unpack_coordinates(&msg.point, msg.n_dims as usize);
let (cluster, distance) = self.query_point(&point);
Ok(KMeansQueryResponse {
request_id: msg.id.0,
cluster: cluster as u32,
distance_fp: to_fixed_point(distance),
})
}
}
#[async_trait::async_trait]
impl RingKernelHandler<K2KPartialCentroid, K2KCentroidAggregation> for KMeans {
#[allow(clippy::needless_range_loop)]
async fn handle(
&self,
_ctx: &mut RingContext,
msg: K2KPartialCentroid,
) -> Result<K2KCentroidAggregation> {
let n_dims = msg.n_dims as usize;
let cluster_id = msg.cluster_id as usize;
let mut new_centroid = [0i64; 8];
if msg.point_count > 0 {
for i in 0..n_dims.min(8) {
new_centroid[i] = msg.coord_sum_fp[i] / msg.point_count as i64;
}
}
let shift = {
let state = self.state.read().unwrap();
let d = state.n_features;
if cluster_id < state.k && d > 0 {
let old_centroid = &state.centroids[cluster_id * d..(cluster_id + 1) * d];
let new_coords: Vec<f64> = new_centroid[..d.min(8)]
.iter()
.map(|&v| from_fixed_point(v))
.collect();
Self::euclidean_distance(old_centroid, &new_coords)
} else {
0.0
}
};
Ok(K2KCentroidAggregation {
request_id: msg.id.0,
cluster_id: msg.cluster_id,
iteration: msg.iteration,
new_centroid_fp: new_centroid,
total_points: msg.point_count,
shift_fp: to_fixed_point(shift),
})
}
}
#[async_trait::async_trait]
impl RingKernelHandler<K2KKMeansSync, K2KKMeansSyncResponse> for KMeans {
async fn handle(
&self,
_ctx: &mut RingContext,
msg: K2KKMeansSync,
) -> Result<K2KKMeansSyncResponse> {
let state = self.state.read().unwrap();
let current_iteration = state.iteration as u64;
let all_synced = msg.iteration <= current_iteration;
let global_shift = from_fixed_point(msg.max_shift_fp);
let converged = global_shift < 1e-6 || state.converged;
Ok(K2KKMeansSyncResponse {
request_id: msg.id.0,
iteration: msg.iteration,
all_synced,
global_inertia_fp: msg.local_inertia_fp,
global_max_shift_fp: msg.max_shift_fp,
converged,
})
}
}
#[async_trait::async_trait]
impl RingKernelHandler<K2KCentroidBroadcast, K2KCentroidBroadcastAck> for KMeans {
async fn handle(
&self,
_ctx: &mut RingContext,
msg: K2KCentroidBroadcast,
) -> Result<K2KCentroidBroadcastAck> {
Ok(K2KCentroidBroadcastAck {
request_id: msg.id.0,
            worker_id: 0,
            iteration: msg.iteration,
applied: true,
})
}
}
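/// Density-based clustering (DBSCAN) kernel.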
#[derive(Debug, Clone)]
pub struct DBSCAN {
metadata: KernelMetadata,
}
impl Default for DBSCAN {
fn default() -> Self {
Self::new()
}
}
impl DBSCAN {
#[must_use]
pub fn new() -> Self {
Self {
metadata: KernelMetadata::batch("ml/dbscan-cluster", Domain::StatisticalML)
.with_description("Density-based clustering with GPU union-find")
.with_throughput(1_000)
.with_latency_us(10_000.0),
}
}
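    /// Classic DBSCAN over a brute-force neighbor search.
    ///
    /// Points with at least `min_samples` neighbors within `eps` become core
    /// points and seed clusters; density-reachable points join them; the rest
    /// are noise and appear as `usize::MAX` in the returned labels.
    ///
    /// Illustrative sketch, assuming `data` is a `DataMatrix` (marked
    /// `ignore` since crate paths vary):
    ///
    /// ```ignore
    /// let result = DBSCAN::compute(&data, 1.0, 2, DistanceMetric::Euclidean);
    /// ```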
#[allow(clippy::needless_range_loop)]
pub fn compute(
data: &DataMatrix,
eps: f64,
min_samples: usize,
metric: DistanceMetric,
) -> ClusteringResult {
let n = data.n_samples;
if n == 0 {
return ClusteringResult {
labels: Vec::new(),
n_clusters: 0,
centroids: Vec::new(),
inertia: 0.0,
iterations: 1,
converged: true,
};
}
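        // Label convention: -1 = unvisited, -2 = noise, >= 0 = cluster id.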
let mut labels = vec![-1i64; n];
let mut current_cluster = 0i64;
let neighborhoods: Vec<Vec<usize>> = (0..n)
.map(|i| Self::get_neighbors(data, i, eps, metric))
.collect();
for i in 0..n {
            if labels[i] != -1 {
                continue;
            }
let neighbors = &neighborhoods[i];
            if neighbors.len() < min_samples {
                // Tentatively noise; may be relabeled as a border point below.
                labels[i] = -2;
                continue;
            }
labels[i] = current_cluster;
let mut seed_set: Vec<usize> = neighbors.clone();
let mut j = 0;
while j < seed_set.len() {
let q = seed_set[j];
j += 1;
                if labels[q] == -2 {
                    // A former noise point becomes a border point of this cluster.
                    labels[q] = current_cluster;
                }
                if labels[q] != -1 {
                    continue;
                }
labels[q] = current_cluster;
let q_neighbors = &neighborhoods[q];
if q_neighbors.len() >= min_samples {
for &neighbor in q_neighbors {
if !seed_set.contains(&neighbor) {
seed_set.push(neighbor);
}
}
}
}
current_cluster += 1;
}
let n_clusters = current_cluster as usize;
let labels: Vec<usize> = labels
.iter()
.map(|&l| if l < 0 { usize::MAX } else { l as usize })
.collect();
let d = data.n_features;
let mut centroids = vec![0.0f64; n_clusters * d];
let mut counts = vec![0usize; n_clusters];
for i in 0..n {
if labels[i] < n_clusters {
let cluster = labels[i];
counts[cluster] += 1;
for j in 0..d {
centroids[cluster * d + j] += data.row(i)[j];
}
}
}
for c in 0..n_clusters {
if counts[c] > 0 {
for j in 0..d {
centroids[c * d + j] /= counts[c] as f64;
}
}
}
ClusteringResult {
labels,
n_clusters,
centroids,
inertia: 0.0,
iterations: 1,
converged: true,
}
}
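    /// Indices of all points within `eps` of `point_idx` (the point itself is
    /// included, as in standard DBSCAN).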
fn get_neighbors(
data: &DataMatrix,
point_idx: usize,
eps: f64,
metric: DistanceMetric,
) -> Vec<usize> {
let n = data.n_samples;
let point = data.row(point_idx);
(0..n)
.filter(|&i| {
let other = data.row(i);
let dist = metric.compute(point, other);
dist <= eps
})
.collect()
}
}
impl GpuKernel for DBSCAN {
fn metadata(&self) -> &KernelMetadata {
&self.metadata
}
}
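/// Linkage criterion used when merging clusters in agglomerative clustering.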
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
Single,
Complete,
Average,
Ward,
}
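/// Agglomerative (bottom-up) hierarchical clustering kernel.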
#[derive(Debug, Clone)]
pub struct HierarchicalClustering {
metadata: KernelMetadata,
}
impl Default for HierarchicalClustering {
fn default() -> Self {
Self::new()
}
}
impl HierarchicalClustering {
#[must_use]
pub fn new() -> Self {
Self {
metadata: KernelMetadata::batch("ml/hierarchical-cluster", Domain::StatisticalML)
.with_description("Agglomerative hierarchical clustering")
.with_throughput(500)
.with_latency_us(50_000.0),
}
}
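    /// Agglomerative clustering: every point starts as its own cluster and
    /// the closest pair (under `linkage`) is merged until `n_clusters`
    /// remain. Labels are compacted to `0..n_clusters` before centroids are
    /// computed.
    ///
    /// Illustrative sketch, assuming `data` is a `DataMatrix` (marked
    /// `ignore` since crate paths vary):
    ///
    /// ```ignore
    /// let result = HierarchicalClustering::compute(
    ///     &data, 2, LinkageMethod::Ward, DistanceMetric::Euclidean,
    /// );
    /// ```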
#[allow(clippy::needless_range_loop)]
pub fn compute(
data: &DataMatrix,
n_clusters: usize,
linkage: LinkageMethod,
metric: DistanceMetric,
) -> ClusteringResult {
let n = data.n_samples;
if n == 0 || n_clusters == 0 {
return ClusteringResult {
labels: Vec::new(),
n_clusters: 0,
centroids: Vec::new(),
inertia: 0.0,
iterations: 0,
converged: true,
};
}
let mut labels: Vec<usize> = (0..n).collect();
let mut active_clusters: Vec<bool> = vec![true; n];
let mut cluster_sizes: Vec<usize> = vec![1; n];
let mut distances = Self::compute_distance_matrix(data, metric);
let mut current_n_clusters = n;
while current_n_clusters > n_clusters {
let (c1, c2) = Self::find_closest_clusters(&distances, &active_clusters, n);
if c1 == c2 {
break;
}
for label in &mut labels {
if *label == c2 {
*label = c1;
}
}
Self::update_distances(
&mut distances,
c1,
c2,
n,
linkage,
&cluster_sizes,
&active_clusters,
);
cluster_sizes[c1] += cluster_sizes[c2];
active_clusters[c2] = false;
current_n_clusters -= 1;
}
let mut label_map = std::collections::HashMap::new();
let mut next_label = 0usize;
for label in &mut labels {
let new_label = *label_map.entry(*label).or_insert_with(|| {
let l = next_label;
next_label += 1;
l
});
*label = new_label;
}
let d = data.n_features;
let final_n_clusters = next_label;
let mut centroids = vec![0.0f64; final_n_clusters * d];
let mut counts = vec![0usize; final_n_clusters];
for i in 0..n {
let cluster = labels[i];
counts[cluster] += 1;
for j in 0..d {
centroids[cluster * d + j] += data.row(i)[j];
}
}
for c in 0..final_n_clusters {
if counts[c] > 0 {
for j in 0..d {
centroids[c * d + j] /= counts[c] as f64;
}
}
}
ClusteringResult {
labels,
n_clusters: final_n_clusters,
centroids,
inertia: 0.0,
            // Use the surviving cluster count so this cannot underflow when
            // `n_clusters > n`.
            iterations: (n - current_n_clusters) as u32,
converged: true,
}
}
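    /// Dense `n x n` pairwise distance matrix. The diagonal is left at
    /// `f64::MAX` so a cluster is never paired with itself.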
fn compute_distance_matrix(data: &DataMatrix, metric: DistanceMetric) -> Vec<f64> {
let n = data.n_samples;
let mut distances = vec![f64::MAX; n * n];
for i in 0..n {
for j in 0..n {
if i != j {
distances[i * n + j] = metric.compute(data.row(i), data.row(j));
}
}
}
distances
}
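    /// Scans the upper triangle of the distance matrix for the closest pair
    /// of still-active clusters.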
fn find_closest_clusters(distances: &[f64], active: &[bool], n: usize) -> (usize, usize) {
let mut min_dist = f64::MAX;
let mut min_i = 0;
let mut min_j = 0;
for i in 0..n {
if !active[i] {
continue;
}
for j in (i + 1)..n {
if !active[j] {
continue;
}
let dist = distances[i * n + j];
if dist < min_dist {
min_dist = dist;
min_i = i;
min_j = j;
}
}
}
(min_i, min_j)
}
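    /// Lance-Williams update: rewrites row/column `c1` with the distance from
    /// the merged cluster (`c2` folded into `c1`) to every other active
    /// cluster.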
fn update_distances(
distances: &mut [f64],
c1: usize,
c2: usize,
n: usize,
linkage: LinkageMethod,
cluster_sizes: &[usize],
active: &[bool],
) {
for k in 0..n {
if !active[k] || k == c1 || k == c2 {
continue;
}
let d1 = distances[c1 * n + k];
let d2 = distances[c2 * n + k];
let new_dist = match linkage {
LinkageMethod::Single => d1.min(d2),
LinkageMethod::Complete => d1.max(d2),
LinkageMethod::Average => {
let n1 = cluster_sizes[c1] as f64;
let n2 = cluster_sizes[c2] as f64;
(n1 * d1 + n2 * d2) / (n1 + n2)
}
LinkageMethod::Ward => {
let n1 = cluster_sizes[c1] as f64;
let n2 = cluster_sizes[c2] as f64;
let nk = cluster_sizes[k] as f64;
let total = n1 + n2 + nk;
                    // The Lance-Williams Ward recurrence operates on squared
                    // distances; take the square root so the matrix stays in
                    // the same (unsquared) space as the other linkages.
                    (((n1 + nk) * d1 * d1 + (n2 + nk) * d2 * d2
                        - nk * distances[c1 * n + c2].powi(2))
                        / total)
                        .sqrt()
}
};
distances[c1 * n + k] = new_dist;
distances[k * n + c1] = new_dist;
}
}
}
impl GpuKernel for HierarchicalClustering {
fn metadata(&self) -> &KernelMetadata {
&self.metadata
}
}
impl KMeans {
pub async fn cluster_batch(&self, input: KMeansInput) -> Result<KMeansOutput> {
let start = Instant::now();
let result = Self::compute(&input.data, input.k, input.max_iterations, input.tolerance);
let compute_time_us = start.elapsed().as_micros() as u64;
Ok(KMeansOutput {
result,
compute_time_us,
})
}
}
#[async_trait]
impl BatchKernel<KMeansInput, KMeansOutput> for KMeans {
async fn execute(&self, input: KMeansInput) -> Result<KMeansOutput> {
self.cluster_batch(input).await
}
}
#[async_trait]
impl BatchKernel<DBSCANInput, DBSCANOutput> for DBSCAN {
async fn execute(&self, input: DBSCANInput) -> Result<DBSCANOutput> {
let start = Instant::now();
let result = Self::compute(&input.data, input.eps, input.min_samples, input.metric);
let compute_time_us = start.elapsed().as_micros() as u64;
Ok(DBSCANOutput {
result,
compute_time_us,
})
}
}
#[async_trait]
impl BatchKernel<HierarchicalInput, HierarchicalOutput> for HierarchicalClustering {
async fn execute(&self, input: HierarchicalInput) -> Result<HierarchicalOutput> {
let start = Instant::now();
let linkage_method = match input.linkage {
Linkage::Single => LinkageMethod::Single,
Linkage::Complete => LinkageMethod::Complete,
Linkage::Average => LinkageMethod::Average,
Linkage::Ward => LinkageMethod::Ward,
};
let result = Self::compute(&input.data, input.n_clusters, linkage_method, input.metric);
let compute_time_us = start.elapsed().as_micros() as u64;
Ok(HierarchicalOutput {
result,
compute_time_us,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_two_clusters() -> DataMatrix {
DataMatrix::from_rows(&[
&[0.0, 0.0],
&[0.1, 0.1],
&[0.2, 0.0],
&[10.0, 10.0],
&[10.1, 10.1],
&[10.2, 10.0],
])
}
#[test]
fn test_kmeans_metadata() {
let kernel = KMeans::new();
assert_eq!(kernel.metadata().id, "ml/kmeans-cluster");
assert_eq!(kernel.metadata().domain, Domain::StatisticalML);
}
#[test]
fn test_kmeans_two_clusters() {
let data = create_two_clusters();
let result = KMeans::compute(&data, 2, 100, 1e-6);
assert_eq!(result.n_clusters, 2);
assert!(result.converged);
assert_eq!(result.labels[0], result.labels[1]);
assert_eq!(result.labels[1], result.labels[2]);
assert_eq!(result.labels[3], result.labels[4]);
assert_eq!(result.labels[4], result.labels[5]);
assert_ne!(result.labels[0], result.labels[3]);
}
#[test]
fn test_dbscan_two_clusters() {
let data = create_two_clusters();
let result = DBSCAN::compute(&data, 1.0, 2, DistanceMetric::Euclidean);
assert_eq!(result.n_clusters, 2);
assert_eq!(result.labels[0], result.labels[1]);
assert_eq!(result.labels[3], result.labels[4]);
assert_ne!(result.labels[0], result.labels[3]);
}
#[test]
fn test_hierarchical_two_clusters() {
let data = create_two_clusters();
let result = HierarchicalClustering::compute(
&data,
2,
LinkageMethod::Complete,
DistanceMetric::Euclidean,
);
assert_eq!(result.n_clusters, 2);
assert_eq!(result.labels[0], result.labels[1]);
assert_eq!(result.labels[1], result.labels[2]);
assert_eq!(result.labels[3], result.labels[4]);
assert_ne!(result.labels[0], result.labels[3]);
}
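    #[test]
    fn test_kmeans_stateful_steps() {
        // Illustrative sketch of the stateful loop the ring handlers drive;
        // the 1e-6 threshold mirrors the one used in the update handler.
        let kernel = KMeans::new();
        kernel.initialize(create_two_clusters(), 2);
        let mut shift = f64::MAX;
        for _ in 0..100 {
            kernel.assign_step();
            shift = kernel.update_step();
            if shift < 1e-6 {
                break;
            }
        }
        assert!(shift < 1e-6);
        assert!(kernel.current_iteration() >= 1);
        let (cluster_a, _) = kernel.query_point(&[0.05, 0.05]);
        let (cluster_b, _) = kernel.query_point(&[10.05, 10.05]);
        assert_ne!(cluster_a, cluster_b);
    }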
}