// rustkernel_ml/clustering.rs

1//! Clustering kernels.
2//!
3//! This module provides machine learning clustering algorithms:
4//! - K-Means (Lloyd's algorithm with K-Means++ initialization)
5//! - DBSCAN (density-based clustering)
6//! - Hierarchical clustering (agglomerative)
7
8use crate::ring_messages::{
9    K2KCentroidAggregation, K2KCentroidBroadcast, K2KCentroidBroadcastAck, K2KKMeansSync,
10    K2KKMeansSyncResponse, K2KPartialCentroid, KMeansAssignResponse, KMeansAssignRing,
11    KMeansQueryResponse, KMeansQueryRing, KMeansUpdateResponse, KMeansUpdateRing, from_fixed_point,
12    to_fixed_point, unpack_coordinates,
13};
14use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
15use rand::prelude::*;
16use ringkernel_core::RingContext;
17use rustkernel_core::traits::RingKernelHandler;
18use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
19
20// ============================================================================
21// K-Means Clustering Kernel
22// ============================================================================
23
/// K-Means clustering state for Ring mode operations.
///
/// Snapshot of one in-progress clustering run, held behind an `RwLock`
/// inside [`KMeans`] and mutated incrementally by the E-/M-step methods.
#[derive(Debug, Clone, Default)]
pub struct KMeansState {
    /// Current centroids, row-major (k * n_features values).
    pub centroids: Vec<f64>,
    /// Input data reference (stored for query operations).
    pub data: Option<DataMatrix>,
    /// Number of clusters.
    pub k: usize,
    /// Number of features per point.
    pub n_features: usize,
    /// Current iteration (incremented by each M-step).
    pub iteration: u32,
    /// Current inertia (sum of squared point-to-centroid distances).
    pub inertia: f64,
    /// Whether the run has converged (set once max centroid shift is tiny).
    pub converged: bool,
    /// Current cluster assignments, one label per sample.
    pub labels: Vec<usize>,
}
44
/// K-Means clustering kernel.
///
/// Implements Lloyd's algorithm with K-Means++ initialization.
/// Batch runs go through the associated `compute` function; Ring mode
/// drives the E-/M-steps incrementally against the interior-mutable state.
#[derive(Debug)]
pub struct KMeans {
    metadata: KernelMetadata,
    /// Internal state for Ring mode operations.
    /// `RwLock` lets `&self` handlers mutate state between ring messages.
    state: std::sync::RwLock<KMeansState>,
}
54
55impl Clone for KMeans {
56    fn clone(&self) -> Self {
57        Self {
58            metadata: self.metadata.clone(),
59            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
60        }
61    }
62}
63
64impl Default for KMeans {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl KMeans {
71    /// Create a new K-Means kernel.
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            metadata: KernelMetadata::batch("ml/kmeans-cluster", Domain::StatisticalML)
76                .with_description("K-Means clustering with K-Means++ initialization")
77                .with_throughput(20_000)
78                .with_latency_us(50.0),
79            state: std::sync::RwLock::new(KMeansState::default()),
80        }
81    }
82
83    /// Initialize the kernel with data and k for Ring mode operations.
84    pub fn initialize(&self, data: DataMatrix, k: usize) {
85        let centroids = Self::kmeans_plus_plus_init(&data, k);
86        let n = data.n_samples;
87        let n_features = data.n_features;
88
89        let mut state = self.state.write().unwrap();
90        *state = KMeansState {
91            centroids,
92            data: Some(data),
93            k,
94            n_features,
95            iteration: 0,
96            inertia: 0.0,
97            converged: false,
98            labels: vec![0; n],
99        };
100    }
101
102    /// Perform one E-step (assignment) on internal state.
103    /// Returns the total inertia (sum of squared distances).
104    #[allow(clippy::needless_range_loop)]
105    pub fn assign_step(&self) -> f64 {
106        let mut state = self.state.write().unwrap();
107
108        // Check if data exists
109        let data = match state.data {
110            Some(ref d) => d.clone(),
111            None => return 0.0,
112        };
113
114        let n = data.n_samples;
115        let d_features = state.n_features;
116        let mut total_inertia = 0.0;
117
118        // Clone centroids to avoid borrow conflict
119        let centroids = state.centroids.clone();
120
121        // Compute assignments
122        let mut new_labels = vec![0usize; n];
123        for i in 0..n {
124            let point = data.row(i);
125            let mut min_dist = f64::MAX;
126            let mut min_cluster = 0;
127
128            for (c, centroid) in centroids.chunks(d_features).enumerate() {
129                let dist = Self::euclidean_distance(point, centroid);
130                if dist < min_dist {
131                    min_dist = dist;
132                    min_cluster = c;
133                }
134            }
135            new_labels[i] = min_cluster;
136            total_inertia += min_dist * min_dist;
137        }
138
139        // Update state
140        state.labels = new_labels;
141        state.inertia = total_inertia;
142        total_inertia
143    }
144
145    /// Perform one M-step (centroid update) on internal state.
146    /// Returns the maximum centroid shift.
147    pub fn update_step(&self) -> f64 {
148        let mut state = self.state.write().unwrap();
149        let Some(ref data) = state.data else {
150            return 0.0;
151        };
152
153        let n = data.n_samples;
154        let d = state.n_features;
155        let k = state.k;
156
157        let mut new_centroids = vec![0.0f64; k * d];
158        let mut counts = vec![0usize; k];
159
160        for i in 0..n {
161            let cluster = state.labels[i];
162            counts[cluster] += 1;
163            let point = data.row(i);
164            for j in 0..d {
165                new_centroids[cluster * d + j] += point[j];
166            }
167        }
168
169        // Normalize centroids
170        for c in 0..k {
171            if counts[c] > 0 {
172                for j in 0..d {
173                    new_centroids[c * d + j] /= counts[c] as f64;
174                }
175            }
176        }
177
178        // Calculate maximum shift
179        let max_shift = state
180            .centroids
181            .chunks(d)
182            .zip(new_centroids.chunks(d))
183            .map(|(old, new)| Self::euclidean_distance(old, new))
184            .fold(0.0f64, f64::max);
185
186        state.centroids = new_centroids;
187        state.iteration += 1;
188        max_shift
189    }
190
191    /// Query the nearest cluster for a point.
192    pub fn query_point(&self, point: &[f64]) -> (usize, f64) {
193        let state = self.state.read().unwrap();
194        let d = state.n_features;
195
196        let mut min_dist = f64::MAX;
197        let mut min_cluster = 0;
198
199        for (c, centroid) in state.centroids.chunks(d).enumerate() {
200            let dist = Self::euclidean_distance(point, centroid);
201            if dist < min_dist {
202                min_dist = dist;
203                min_cluster = c;
204            }
205        }
206
207        (min_cluster, min_dist)
208    }
209
210    /// Get current iteration count.
211    pub fn current_iteration(&self) -> u32 {
212        self.state.read().unwrap().iteration
213    }
214
215    /// Get current inertia.
216    pub fn current_inertia(&self) -> f64 {
217        self.state.read().unwrap().inertia
218    }
219
220    /// Run K-Means clustering.
221    ///
222    /// # Arguments
223    /// * `data` - Input data matrix (n_samples x n_features)
224    /// * `k` - Number of clusters
225    /// * `max_iterations` - Maximum number of iterations
226    /// * `tolerance` - Convergence threshold for centroid movement
227    #[allow(clippy::needless_range_loop)]
228    pub fn compute(
229        data: &DataMatrix,
230        k: usize,
231        max_iterations: u32,
232        tolerance: f64,
233    ) -> ClusteringResult {
234        let n = data.n_samples;
235        let d = data.n_features;
236
237        if n == 0 || k == 0 || k > n {
238            return ClusteringResult {
239                labels: Vec::new(),
240                n_clusters: 0,
241                centroids: Vec::new(),
242                inertia: 0.0,
243                iterations: 0,
244                converged: true,
245            };
246        }
247
248        // K-Means++ initialization
249        let mut centroids = Self::kmeans_plus_plus_init(data, k);
250        let mut labels = vec![0usize; n];
251        let mut converged = false;
252        let mut iterations = 0u32;
253
254        for iter in 0..max_iterations {
255            iterations = iter + 1;
256
257            // Assignment step: assign each point to nearest centroid
258            for i in 0..n {
259                let point = data.row(i);
260                let mut min_dist = f64::MAX;
261                let mut min_cluster = 0;
262
263                for (c, centroid) in centroids.chunks(d).enumerate() {
264                    let dist = Self::euclidean_distance(point, centroid);
265                    if dist < min_dist {
266                        min_dist = dist;
267                        min_cluster = c;
268                    }
269                }
270                labels[i] = min_cluster;
271            }
272
273            // Update step: recalculate centroids
274            let mut new_centroids = vec![0.0f64; k * d];
275            let mut counts = vec![0usize; k];
276
277            for i in 0..n {
278                let cluster = labels[i];
279                counts[cluster] += 1;
280                let point = data.row(i);
281                for j in 0..d {
282                    new_centroids[cluster * d + j] += point[j];
283                }
284            }
285
286            // Normalize centroids
287            for c in 0..k {
288                if counts[c] > 0 {
289                    for j in 0..d {
290                        new_centroids[c * d + j] /= counts[c] as f64;
291                    }
292                }
293            }
294
295            // Check convergence
296            let max_shift = centroids
297                .chunks(d)
298                .zip(new_centroids.chunks(d))
299                .map(|(old, new)| Self::euclidean_distance(old, new))
300                .fold(0.0f64, f64::max);
301
302            centroids = new_centroids;
303
304            if max_shift < tolerance {
305                converged = true;
306                break;
307            }
308        }
309
310        // Calculate inertia (sum of squared distances to centroids)
311        let inertia: f64 = (0..n)
312            .map(|i| {
313                let point = data.row(i);
314                let centroid_start = labels[i] * d;
315                let centroid = &centroids[centroid_start..centroid_start + d];
316                let dist = Self::euclidean_distance(point, centroid);
317                dist * dist
318            })
319            .sum();
320
321        ClusteringResult {
322            labels,
323            n_clusters: k,
324            centroids,
325            inertia,
326            iterations,
327            converged,
328        }
329    }
330
331    /// K-Means++ initialization.
332    #[allow(clippy::needless_range_loop)]
333    fn kmeans_plus_plus_init(data: &DataMatrix, k: usize) -> Vec<f64> {
334        let n = data.n_samples;
335        let d = data.n_features;
336        let mut rng = rand::rng();
337        let mut centroids = Vec::with_capacity(k * d);
338
339        // Choose first centroid randomly
340        let first_idx = rng.random_range(0..n);
341        centroids.extend_from_slice(data.row(first_idx));
342
343        let mut distances = vec![f64::MAX; n];
344
345        // Choose remaining centroids
346        for _ in 1..k {
347            // Update distances to nearest centroid
348            for i in 0..n {
349                let point = data.row(i);
350                let last_centroid = &centroids[centroids.len() - d..];
351                let dist = Self::euclidean_distance(point, last_centroid);
352                distances[i] = distances[i].min(dist);
353            }
354
355            // Choose next centroid with probability proportional to D^2
356            let total: f64 = distances.iter().map(|d| d * d).sum();
357            let threshold = rng.random::<f64>() * total;
358
359            let mut cumsum = 0.0;
360            let mut next_idx = 0;
361            for (i, &dist) in distances.iter().enumerate() {
362                cumsum += dist * dist;
363                if cumsum >= threshold {
364                    next_idx = i;
365                    break;
366                }
367            }
368
369            centroids.extend_from_slice(data.row(next_idx));
370        }
371
372        centroids
373    }
374
375    /// Euclidean distance between two vectors.
376    fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
377        a.iter()
378            .zip(b.iter())
379            .map(|(x, y)| (x - y).powi(2))
380            .sum::<f64>()
381            .sqrt()
382    }
383}
384
impl GpuKernel for KMeans {
    /// Expose the kernel's registration metadata (id, domain, throughput).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
390
391// ============================================================================
392// KMeans RingKernelHandler Implementations
393// ============================================================================
394
395/// RingKernelHandler for KMeans assignment step (E-step).
396#[async_trait::async_trait]
397impl RingKernelHandler<KMeansAssignRing, KMeansAssignResponse> for KMeans {
398    async fn handle(
399        &self,
400        _ctx: &mut RingContext,
401        msg: KMeansAssignRing,
402    ) -> Result<KMeansAssignResponse> {
403        // Perform assignment step on internal state
404        let inertia = self.assign_step();
405
406        let state = self.state.read().unwrap();
407        let points_assigned = state.labels.len() as u32;
408
409        Ok(KMeansAssignResponse {
410            request_id: msg.id.0,
411            iteration: msg.iteration,
412            inertia_fp: to_fixed_point(inertia),
413            points_assigned,
414        })
415    }
416}
417
418/// RingKernelHandler for KMeans update step (M-step).
419#[async_trait::async_trait]
420impl RingKernelHandler<KMeansUpdateRing, KMeansUpdateResponse> for KMeans {
421    async fn handle(
422        &self,
423        _ctx: &mut RingContext,
424        msg: KMeansUpdateRing,
425    ) -> Result<KMeansUpdateResponse> {
426        // Perform update step on internal state
427        let max_shift = self.update_step();
428        let converged = max_shift < 1e-6;
429
430        // Update convergence status in state
431        if converged {
432            let mut state = self.state.write().unwrap();
433            state.converged = true;
434        }
435
436        Ok(KMeansUpdateResponse {
437            request_id: msg.id.0,
438            iteration: msg.iteration,
439            max_shift_fp: to_fixed_point(max_shift),
440            converged,
441        })
442    }
443}
444
445/// RingKernelHandler for point queries.
446#[async_trait::async_trait]
447impl RingKernelHandler<KMeansQueryRing, KMeansQueryResponse> for KMeans {
448    async fn handle(
449        &self,
450        _ctx: &mut RingContext,
451        msg: KMeansQueryRing,
452    ) -> Result<KMeansQueryResponse> {
453        // Unpack the query point coordinates
454        let point = unpack_coordinates(&msg.point, msg.n_dims as usize);
455
456        // Query the nearest cluster using internal state
457        let (cluster, distance) = self.query_point(&point);
458
459        Ok(KMeansQueryResponse {
460            request_id: msg.id.0,
461            cluster: cluster as u32,
462            distance_fp: to_fixed_point(distance),
463        })
464    }
465}
466
467/// RingKernelHandler for K2K partial centroid updates.
468///
469/// Aggregates partial centroid contributions from distributed workers.
470#[async_trait::async_trait]
471impl RingKernelHandler<K2KPartialCentroid, K2KCentroidAggregation> for KMeans {
472    #[allow(clippy::needless_range_loop)]
473    async fn handle(
474        &self,
475        _ctx: &mut RingContext,
476        msg: K2KPartialCentroid,
477    ) -> Result<K2KCentroidAggregation> {
478        let n_dims = msg.n_dims as usize;
479        let cluster_id = msg.cluster_id as usize;
480        let mut new_centroid = [0i64; 8];
481
482        // Compute new centroid from partial sums
483        if msg.point_count > 0 {
484            for i in 0..n_dims.min(8) {
485                new_centroid[i] = msg.coord_sum_fp[i] / msg.point_count as i64;
486            }
487        }
488
489        // Calculate shift from old centroid in internal state
490        let shift = {
491            let state = self.state.read().unwrap();
492            let d = state.n_features;
493            if cluster_id < state.k && d > 0 {
494                let old_centroid = &state.centroids[cluster_id * d..(cluster_id + 1) * d];
495                let new_coords: Vec<f64> = new_centroid[..d.min(8)]
496                    .iter()
497                    .map(|&v| from_fixed_point(v))
498                    .collect();
499                Self::euclidean_distance(old_centroid, &new_coords)
500            } else {
501                0.0
502            }
503        };
504
505        Ok(K2KCentroidAggregation {
506            request_id: msg.id.0,
507            cluster_id: msg.cluster_id,
508            iteration: msg.iteration,
509            new_centroid_fp: new_centroid,
510            total_points: msg.point_count,
511            shift_fp: to_fixed_point(shift),
512        })
513    }
514}
515
516/// RingKernelHandler for K2K iteration sync.
517///
518/// Synchronizes distributed KMeans workers after each iteration.
519/// In a single-instance setting, validates iteration state and returns convergence status.
520#[async_trait::async_trait]
521impl RingKernelHandler<K2KKMeansSync, K2KKMeansSyncResponse> for KMeans {
522    async fn handle(
523        &self,
524        _ctx: &mut RingContext,
525        msg: K2KKMeansSync,
526    ) -> Result<K2KKMeansSyncResponse> {
527        let state = self.state.read().unwrap();
528
529        // Verify iteration matches internal state
530        let current_iteration = state.iteration as u64;
531        let all_synced = msg.iteration <= current_iteration;
532
533        // Use reported values for single-worker case
534        // In distributed setting, would aggregate across workers
535        let global_shift = from_fixed_point(msg.max_shift_fp);
536        let converged = global_shift < 1e-6 || state.converged;
537
538        Ok(K2KKMeansSyncResponse {
539            request_id: msg.id.0,
540            iteration: msg.iteration,
541            all_synced,
542            global_inertia_fp: msg.local_inertia_fp,
543            global_max_shift_fp: msg.max_shift_fp,
544            converged,
545        })
546    }
547}
548
549/// RingKernelHandler for K2K centroid broadcast.
550///
551/// Receives new centroids broadcast from coordinator.
552#[async_trait::async_trait]
553impl RingKernelHandler<K2KCentroidBroadcast, K2KCentroidBroadcastAck> for KMeans {
554    async fn handle(
555        &self,
556        _ctx: &mut RingContext,
557        msg: K2KCentroidBroadcast,
558    ) -> Result<K2KCentroidBroadcastAck> {
559        // In a distributed setting, this would update local centroids
560        Ok(K2KCentroidBroadcastAck {
561            request_id: msg.id.0,
562            worker_id: 0, // Would be actual worker ID
563            iteration: msg.iteration,
564            applied: true,
565        })
566    }
567}
568
569// ============================================================================
570// DBSCAN Clustering Kernel
571// ============================================================================
572
/// DBSCAN clustering kernel.
///
/// Density-based spatial clustering of applications with noise.
/// Stateless: all run parameters are supplied per `compute` call.
#[derive(Debug, Clone)]
pub struct DBSCAN {
    /// Registration metadata served through [`GpuKernel::metadata`].
    metadata: KernelMetadata,
}
580
581impl Default for DBSCAN {
582    fn default() -> Self {
583        Self::new()
584    }
585}
586
587impl DBSCAN {
588    /// Create a new DBSCAN kernel.
589    #[must_use]
590    pub fn new() -> Self {
591        Self {
592            metadata: KernelMetadata::batch("ml/dbscan-cluster", Domain::StatisticalML)
593                .with_description("Density-based clustering with GPU union-find")
594                .with_throughput(1_000)
595                .with_latency_us(10_000.0),
596        }
597    }
598
599    /// Run DBSCAN clustering.
600    ///
601    /// # Arguments
602    /// * `data` - Input data matrix
603    /// * `eps` - Maximum distance for neighborhood
604    /// * `min_samples` - Minimum points to form a dense region
605    /// * `metric` - Distance metric to use
606    #[allow(clippy::needless_range_loop)]
607    pub fn compute(
608        data: &DataMatrix,
609        eps: f64,
610        min_samples: usize,
611        metric: DistanceMetric,
612    ) -> ClusteringResult {
613        let n = data.n_samples;
614
615        if n == 0 {
616            return ClusteringResult {
617                labels: Vec::new(),
618                n_clusters: 0,
619                centroids: Vec::new(),
620                inertia: 0.0,
621                iterations: 1,
622                converged: true,
623            };
624        }
625
626        // -1 = unvisited, -2 = noise, >= 0 = cluster label
627        let mut labels = vec![-1i64; n];
628        let mut current_cluster = 0i64;
629
630        // Precompute neighborhoods (for efficiency)
631        let neighborhoods: Vec<Vec<usize>> = (0..n)
632            .map(|i| Self::get_neighbors(data, i, eps, metric))
633            .collect();
634
635        for i in 0..n {
636            if labels[i] != -1 {
637                continue; // Already processed
638            }
639
640            let neighbors = &neighborhoods[i];
641
642            if neighbors.len() < min_samples {
643                labels[i] = -2; // Mark as noise
644                continue;
645            }
646
647            // Start new cluster
648            labels[i] = current_cluster;
649            let mut seed_set: Vec<usize> = neighbors.clone();
650            let mut j = 0;
651
652            while j < seed_set.len() {
653                let q = seed_set[j];
654                j += 1;
655
656                if labels[q] == -2 {
657                    labels[q] = current_cluster; // Change noise to border
658                }
659
660                if labels[q] != -1 {
661                    continue; // Already processed
662                }
663
664                labels[q] = current_cluster;
665
666                let q_neighbors = &neighborhoods[q];
667                if q_neighbors.len() >= min_samples {
668                    // Add new neighbors to seed set
669                    for &neighbor in q_neighbors {
670                        if !seed_set.contains(&neighbor) {
671                            seed_set.push(neighbor);
672                        }
673                    }
674                }
675            }
676
677            current_cluster += 1;
678        }
679
680        // Convert labels to usize (noise stays as max value)
681        let n_clusters = current_cluster as usize;
682        let labels: Vec<usize> = labels
683            .iter()
684            .map(|&l| if l < 0 { usize::MAX } else { l as usize })
685            .collect();
686
687        // Calculate centroids for each cluster
688        let d = data.n_features;
689        let mut centroids = vec![0.0f64; n_clusters * d];
690        let mut counts = vec![0usize; n_clusters];
691
692        for i in 0..n {
693            if labels[i] < n_clusters {
694                let cluster = labels[i];
695                counts[cluster] += 1;
696                for j in 0..d {
697                    centroids[cluster * d + j] += data.row(i)[j];
698                }
699            }
700        }
701
702        for c in 0..n_clusters {
703            if counts[c] > 0 {
704                for j in 0..d {
705                    centroids[c * d + j] /= counts[c] as f64;
706                }
707            }
708        }
709
710        ClusteringResult {
711            labels,
712            n_clusters,
713            centroids,
714            inertia: 0.0,
715            iterations: 1,
716            converged: true,
717        }
718    }
719
720    /// Get neighbors within eps distance.
721    fn get_neighbors(
722        data: &DataMatrix,
723        point_idx: usize,
724        eps: f64,
725        metric: DistanceMetric,
726    ) -> Vec<usize> {
727        let n = data.n_samples;
728        let point = data.row(point_idx);
729
730        (0..n)
731            .filter(|&i| {
732                let other = data.row(i);
733                let dist = metric.compute(point, other);
734                dist <= eps
735            })
736            .collect()
737    }
738}
739
impl GpuKernel for DBSCAN {
    /// Expose the kernel's registration metadata (id, domain, throughput).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
745
746// ============================================================================
747// Hierarchical Clustering Kernel
748// ============================================================================
749
/// Linkage method for hierarchical clustering.
///
/// Determines how the distance between two merged clusters and a third
/// cluster is recomputed after each agglomeration step.
// `Eq` added alongside `PartialEq`: the variants are unit-only, so equality
// is total (clippy::derive_partial_eq_without_eq).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkageMethod {
    /// Single linkage (minimum distance between members)
    Single,
    /// Complete linkage (maximum distance between members)
    Complete,
    /// Average linkage (UPGMA, size-weighted mean distance)
    Average,
    /// Ward's method (minimize within-cluster variance increase)
    Ward,
}
762
/// Hierarchical clustering kernel.
///
/// Agglomerative hierarchical clustering with various linkage methods.
/// Stateless: all run parameters are supplied per `compute` call.
#[derive(Debug, Clone)]
pub struct HierarchicalClustering {
    /// Registration metadata served through [`GpuKernel::metadata`].
    metadata: KernelMetadata,
}
770
771impl Default for HierarchicalClustering {
772    fn default() -> Self {
773        Self::new()
774    }
775}
776
777impl HierarchicalClustering {
778    /// Create a new hierarchical clustering kernel.
779    #[must_use]
780    pub fn new() -> Self {
781        Self {
782            metadata: KernelMetadata::batch("ml/hierarchical-cluster", Domain::StatisticalML)
783                .with_description("Agglomerative hierarchical clustering")
784                .with_throughput(500)
785                .with_latency_us(50_000.0),
786        }
787    }
788
789    /// Run hierarchical clustering.
790    ///
791    /// # Arguments
792    /// * `data` - Input data matrix
793    /// * `n_clusters` - Number of clusters to form
794    /// * `linkage` - Linkage method
795    /// * `metric` - Distance metric
796    #[allow(clippy::needless_range_loop)]
797    pub fn compute(
798        data: &DataMatrix,
799        n_clusters: usize,
800        linkage: LinkageMethod,
801        metric: DistanceMetric,
802    ) -> ClusteringResult {
803        let n = data.n_samples;
804
805        if n == 0 || n_clusters == 0 {
806            return ClusteringResult {
807                labels: Vec::new(),
808                n_clusters: 0,
809                centroids: Vec::new(),
810                inertia: 0.0,
811                iterations: 0,
812                converged: true,
813            };
814        }
815
816        // Initialize each point as its own cluster
817        let mut labels: Vec<usize> = (0..n).collect();
818        let mut active_clusters: Vec<bool> = vec![true; n];
819        let mut cluster_sizes: Vec<usize> = vec![1; n];
820
821        // Compute initial distance matrix
822        let mut distances = Self::compute_distance_matrix(data, metric);
823
824        // Merge clusters until we have n_clusters
825        let mut current_n_clusters = n;
826
827        while current_n_clusters > n_clusters {
828            // Find closest pair of clusters
829            let (c1, c2) = Self::find_closest_clusters(&distances, &active_clusters, n);
830
831            if c1 == c2 {
832                break;
833            }
834
835            // Merge c2 into c1
836            for label in &mut labels {
837                if *label == c2 {
838                    *label = c1;
839                }
840            }
841
842            // Update distances based on linkage
843            Self::update_distances(
844                &mut distances,
845                c1,
846                c2,
847                n,
848                linkage,
849                &cluster_sizes,
850                &active_clusters,
851            );
852
853            cluster_sizes[c1] += cluster_sizes[c2];
854            active_clusters[c2] = false;
855            current_n_clusters -= 1;
856        }
857
858        // Renumber labels to be contiguous
859        let mut label_map = std::collections::HashMap::new();
860        let mut next_label = 0usize;
861
862        for label in &mut labels {
863            let new_label = *label_map.entry(*label).or_insert_with(|| {
864                let l = next_label;
865                next_label += 1;
866                l
867            });
868            *label = new_label;
869        }
870
871        // Calculate centroids
872        let d = data.n_features;
873        let final_n_clusters = next_label;
874        let mut centroids = vec![0.0f64; final_n_clusters * d];
875        let mut counts = vec![0usize; final_n_clusters];
876
877        for i in 0..n {
878            let cluster = labels[i];
879            counts[cluster] += 1;
880            for j in 0..d {
881                centroids[cluster * d + j] += data.row(i)[j];
882            }
883        }
884
885        for c in 0..final_n_clusters {
886            if counts[c] > 0 {
887                for j in 0..d {
888                    centroids[c * d + j] /= counts[c] as f64;
889                }
890            }
891        }
892
893        ClusteringResult {
894            labels,
895            n_clusters: final_n_clusters,
896            centroids,
897            inertia: 0.0,
898            iterations: (n - n_clusters) as u32,
899            converged: true,
900        }
901    }
902
903    fn compute_distance_matrix(data: &DataMatrix, metric: DistanceMetric) -> Vec<f64> {
904        let n = data.n_samples;
905        let mut distances = vec![f64::MAX; n * n];
906
907        for i in 0..n {
908            for j in 0..n {
909                if i != j {
910                    distances[i * n + j] = metric.compute(data.row(i), data.row(j));
911                }
912            }
913        }
914
915        distances
916    }
917
918    fn find_closest_clusters(distances: &[f64], active: &[bool], n: usize) -> (usize, usize) {
919        let mut min_dist = f64::MAX;
920        let mut min_i = 0;
921        let mut min_j = 0;
922
923        for i in 0..n {
924            if !active[i] {
925                continue;
926            }
927            for j in (i + 1)..n {
928                if !active[j] {
929                    continue;
930                }
931                let dist = distances[i * n + j];
932                if dist < min_dist {
933                    min_dist = dist;
934                    min_i = i;
935                    min_j = j;
936                }
937            }
938        }
939
940        (min_i, min_j)
941    }
942
943    fn update_distances(
944        distances: &mut [f64],
945        c1: usize,
946        c2: usize,
947        n: usize,
948        linkage: LinkageMethod,
949        cluster_sizes: &[usize],
950        active: &[bool],
951    ) {
952        for k in 0..n {
953            if !active[k] || k == c1 || k == c2 {
954                continue;
955            }
956
957            let d1 = distances[c1 * n + k];
958            let d2 = distances[c2 * n + k];
959
960            let new_dist = match linkage {
961                LinkageMethod::Single => d1.min(d2),
962                LinkageMethod::Complete => d1.max(d2),
963                LinkageMethod::Average => {
964                    let n1 = cluster_sizes[c1] as f64;
965                    let n2 = cluster_sizes[c2] as f64;
966                    (n1 * d1 + n2 * d2) / (n1 + n2)
967                }
968                LinkageMethod::Ward => {
969                    let n1 = cluster_sizes[c1] as f64;
970                    let n2 = cluster_sizes[c2] as f64;
971                    let nk = cluster_sizes[k] as f64;
972                    let total = n1 + n2 + nk;
973                    ((n1 + nk) * d1 * d1 + (n2 + nk) * d2 * d2
974                        - nk * distances[c1 * n + c2].powi(2))
975                        / total
976                }
977            };
978
979            distances[c1 * n + k] = new_dist;
980            distances[k * n + c1] = new_dist;
981        }
982    }
983}
984
impl GpuKernel for HierarchicalClustering {
    /// Returns this kernel's registration metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
990
991// ============================================================================
992// BatchKernel Implementations
993// ============================================================================
994
995use crate::messages::{
996    DBSCANInput, DBSCANOutput, HierarchicalInput, HierarchicalOutput, KMeansInput, KMeansOutput,
997    Linkage,
998};
999use async_trait::async_trait;
1000use rustkernel_core::error::Result;
1001use rustkernel_core::traits::BatchKernel;
1002use std::time::Instant;
1003
1004/// K-Means batch kernel implementation.
1005impl KMeans {
1006    /// Execute K-Means clustering as a batch operation.
1007    ///
1008    /// Convenience method for batch clustering.
1009    pub async fn cluster_batch(&self, input: KMeansInput) -> Result<KMeansOutput> {
1010        let start = Instant::now();
1011        let result = Self::compute(&input.data, input.k, input.max_iterations, input.tolerance);
1012        let compute_time_us = start.elapsed().as_micros() as u64;
1013
1014        Ok(KMeansOutput {
1015            result,
1016            compute_time_us,
1017        })
1018    }
1019}
1020
#[async_trait]
impl BatchKernel<KMeansInput, KMeansOutput> for KMeans {
    /// Executes the batch request by delegating to [`KMeans::cluster_batch`].
    async fn execute(&self, input: KMeansInput) -> Result<KMeansOutput> {
        self.cluster_batch(input).await
    }
}
1027
1028/// DBSCAN batch kernel implementation.
1029#[async_trait]
1030impl BatchKernel<DBSCANInput, DBSCANOutput> for DBSCAN {
1031    async fn execute(&self, input: DBSCANInput) -> Result<DBSCANOutput> {
1032        let start = Instant::now();
1033        let result = Self::compute(&input.data, input.eps, input.min_samples, input.metric);
1034        let compute_time_us = start.elapsed().as_micros() as u64;
1035
1036        Ok(DBSCANOutput {
1037            result,
1038            compute_time_us,
1039        })
1040    }
1041}
1042
1043/// Hierarchical clustering batch kernel implementation.
1044#[async_trait]
1045impl BatchKernel<HierarchicalInput, HierarchicalOutput> for HierarchicalClustering {
1046    async fn execute(&self, input: HierarchicalInput) -> Result<HierarchicalOutput> {
1047        let start = Instant::now();
1048        let linkage_method = match input.linkage {
1049            Linkage::Single => LinkageMethod::Single,
1050            Linkage::Complete => LinkageMethod::Complete,
1051            Linkage::Average => LinkageMethod::Average,
1052            Linkage::Ward => LinkageMethod::Ward,
1053        };
1054        let result = Self::compute(&input.data, input.n_clusters, linkage_method, input.metric);
1055        let compute_time_us = start.elapsed().as_micros() as u64;
1056
1057        Ok(HierarchicalOutput {
1058            result,
1059            compute_time_us,
1060        })
1061    }
1062}
1063
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 2-D points: three near the origin, three near (10, 10).
    fn create_two_clusters() -> DataMatrix {
        DataMatrix::from_rows(&[
            &[0.0, 0.0],
            &[0.1, 0.1],
            &[0.2, 0.0],
            &[10.0, 10.0],
            &[10.1, 10.1],
            &[10.2, 10.0],
        ])
    }

    #[test]
    fn test_kmeans_metadata() {
        let kernel = KMeans::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "ml/kmeans-cluster");
        assert_eq!(meta.domain, Domain::StatisticalML);
    }

    #[test]
    fn test_kmeans_two_clusters() {
        let result = KMeans::compute(&create_two_clusters(), 2, 100, 1e-6);
        let labels = &result.labels;

        assert_eq!(result.n_clusters, 2);
        assert!(result.converged);

        // Points 0-2 form one cluster, points 3-5 the other.
        assert!(labels[..3].iter().all(|&l| l == labels[0]));
        assert!(labels[3..].iter().all(|&l| l == labels[3]));
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_dbscan_two_clusters() {
        let result = DBSCAN::compute(&create_two_clusters(), 1.0, 2, DistanceMetric::Euclidean);
        let labels = &result.labels;

        assert_eq!(result.n_clusters, 2);

        // Each group shares a label; the groups differ.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[3], labels[4]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_hierarchical_two_clusters() {
        let result = HierarchicalClustering::compute(
            &create_two_clusters(),
            2,
            LinkageMethod::Complete,
            DistanceMetric::Euclidean,
        );
        let labels = &result.labels;

        assert_eq!(result.n_clusters, 2);

        // Each group shares a label; the groups differ.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_ne!(labels[0], labels[3]);
    }
}