use crate::ring_messages::{
    K2KCentroidAggregation, K2KCentroidBroadcast, K2KCentroidBroadcastAck, K2KKMeansSync,
    K2KKMeansSyncResponse, K2KPartialCentroid, KMeansAssignResponse, KMeansAssignRing,
    KMeansQueryResponse, KMeansQueryRing, KMeansUpdateResponse, KMeansUpdateRing, from_fixed_point,
    to_fixed_point, unpack_coordinates,
};
use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
use rand::prelude::*;
use ringkernel_core::RingContext;
use rustkernel_core::traits::RingKernelHandler;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};

/// Shared mutable state for the K-Means kernel.
#[derive(Debug, Clone, Default)]
pub struct KMeansState {
    /// Flattened centroid coordinates (`k * n_features` values).
    pub centroids: Vec<f64>,
    /// Data being clustered.
    pub data: Option<DataMatrix>,
    /// Number of clusters.
    pub k: usize,
    /// Number of features per sample.
    pub n_features: usize,
    /// Current iteration count.
    pub iteration: u32,
    /// Sum of squared distances from each point to its assigned centroid.
    pub inertia: f64,
    /// Whether the centroids have converged.
    pub converged: bool,
    /// Cluster assignment for each sample.
    pub labels: Vec<usize>,
}

/// K-Means clustering kernel.
///
/// Supports full batch computation as well as incremental assign/update
/// steps driven by ring messages.
#[derive(Debug)]
pub struct KMeans {
    metadata: KernelMetadata,
    state: std::sync::RwLock<KMeansState>,
}

impl Clone for KMeans {
    fn clone(&self) -> Self {
        Self {
            metadata: self.metadata.clone(),
            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
        }
    }
}

impl Default for KMeans {
    fn default() -> Self {
        Self::new()
    }
}

impl KMeans {
    /// Create a new K-Means kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/kmeans-cluster", Domain::StatisticalML)
                .with_description("K-Means clustering with K-Means++ initialization")
                .with_throughput(20_000)
                .with_latency_us(50.0),
            state: std::sync::RwLock::new(KMeansState::default()),
        }
    }

    /// Initialize the clustering state with data and K-Means++ seeded centroids.
    pub fn initialize(&self, data: DataMatrix, k: usize) {
        let centroids = Self::kmeans_plus_plus_init(&data, k);
        let n = data.n_samples;
        let n_features = data.n_features;

        let mut state = self.state.write().unwrap();
        *state = KMeansState {
            centroids,
            data: Some(data),
            k,
            n_features,
            iteration: 0,
            inertia: 0.0,
            converged: false,
            labels: vec![0; n],
        };
    }

    /// Assignment step: assign every point to its nearest centroid.
    ///
    /// Returns the total inertia (sum of squared point-to-centroid distances).
    pub fn assign_step(&self) -> f64 {
        let mut state = self.state.write().unwrap();

        let data = match state.data {
            Some(ref d) => d.clone(),
            None => return 0.0,
        };

        let n = data.n_samples;
        let d_features = state.n_features;
        let mut total_inertia = 0.0;

        let centroids = state.centroids.clone();

        let mut new_labels = vec![0usize; n];
        for i in 0..n {
            let point = data.row(i);
            let mut min_dist = f64::MAX;
            let mut min_cluster = 0;

            for (c, centroid) in centroids.chunks(d_features).enumerate() {
                let dist = Self::euclidean_distance(point, centroid);
                if dist < min_dist {
                    min_dist = dist;
                    min_cluster = c;
                }
            }
            new_labels[i] = min_cluster;
            total_inertia += min_dist * min_dist;
        }

        state.labels = new_labels;
        state.inertia = total_inertia;
        total_inertia
    }

    /// Update step: recompute each centroid as the mean of its assigned points.
    ///
    /// Returns the maximum centroid shift across all clusters.
    pub fn update_step(&self) -> f64 {
        let mut state = self.state.write().unwrap();
        let Some(ref data) = state.data else {
            return 0.0;
        };

        let n = data.n_samples;
        let d = state.n_features;
        let k = state.k;

        let mut new_centroids = vec![0.0f64; k * d];
        let mut counts = vec![0usize; k];

        for i in 0..n {
            let cluster = state.labels[i];
            counts[cluster] += 1;
            let point = data.row(i);
            for j in 0..d {
                new_centroids[cluster * d + j] += point[j];
            }
        }

        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..d {
                    new_centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        let max_shift = state
            .centroids
            .chunks(d)
            .zip(new_centroids.chunks(d))
            .map(|(old, new)| Self::euclidean_distance(old, new))
            .fold(0.0f64, f64::max);

        state.centroids = new_centroids;
        state.iteration += 1;
        max_shift
    }

    /// Find the nearest centroid to a query point.
    ///
    /// Returns the cluster index and the distance to its centroid.
    pub fn query_point(&self, point: &[f64]) -> (usize, f64) {
        let state = self.state.read().unwrap();
        let d = state.n_features;

        let mut min_dist = f64::MAX;
        let mut min_cluster = 0;

        for (c, centroid) in state.centroids.chunks(d).enumerate() {
            let dist = Self::euclidean_distance(point, centroid);
            if dist < min_dist {
                min_dist = dist;
                min_cluster = c;
            }
        }

        (min_cluster, min_dist)
    }

    /// Current iteration count.
    pub fn current_iteration(&self) -> u32 {
        self.state.read().unwrap().iteration
    }

    /// Current inertia.
    pub fn current_inertia(&self) -> f64 {
        self.state.read().unwrap().inertia
    }

    /// Run full K-Means clustering (Lloyd's algorithm) on `data`.
    ///
    /// Iterates assignment and update steps until the maximum centroid shift
    /// falls below `tolerance` or `max_iterations` is reached.
    pub fn compute(
        data: &DataMatrix,
        k: usize,
        max_iterations: u32,
        tolerance: f64,
    ) -> ClusteringResult {
        let n = data.n_samples;
        let d = data.n_features;

        if n == 0 || k == 0 || k > n {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 0,
                converged: true,
            };
        }

        let mut centroids = Self::kmeans_plus_plus_init(data, k);
        let mut labels = vec![0usize; n];
        let mut converged = false;
        let mut iterations = 0u32;

        for iter in 0..max_iterations {
            iterations = iter + 1;

            // Assignment step: nearest centroid per point.
            for i in 0..n {
                let point = data.row(i);
                let mut min_dist = f64::MAX;
                let mut min_cluster = 0;

                for (c, centroid) in centroids.chunks(d).enumerate() {
                    let dist = Self::euclidean_distance(point, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        min_cluster = c;
                    }
                }
                labels[i] = min_cluster;
            }

            // Update step: recompute centroids as cluster means.
            let mut new_centroids = vec![0.0f64; k * d];
            let mut counts = vec![0usize; k];

            for i in 0..n {
                let cluster = labels[i];
                counts[cluster] += 1;
                let point = data.row(i);
                for j in 0..d {
                    new_centroids[cluster * d + j] += point[j];
                }
            }

            for c in 0..k {
                if counts[c] > 0 {
                    for j in 0..d {
                        new_centroids[c * d + j] /= counts[c] as f64;
                    }
                }
            }

            let max_shift = centroids
                .chunks(d)
                .zip(new_centroids.chunks(d))
                .map(|(old, new)| Self::euclidean_distance(old, new))
                .fold(0.0f64, f64::max);

            centroids = new_centroids;

            if max_shift < tolerance {
                converged = true;
                break;
            }
        }

        let inertia: f64 = (0..n)
            .map(|i| {
                let point = data.row(i);
                let centroid_start = labels[i] * d;
                let centroid = &centroids[centroid_start..centroid_start + d];
                let dist = Self::euclidean_distance(point, centroid);
                dist * dist
            })
            .sum();

        ClusteringResult {
            labels,
            n_clusters: k,
            centroids,
            inertia,
            iterations,
            converged,
        }
    }

    /// K-Means++ initialization: pick each new centroid with probability
    /// proportional to the squared distance from the nearest chosen centroid.
    fn kmeans_plus_plus_init(data: &DataMatrix, k: usize) -> Vec<f64> {
        let n = data.n_samples;
        let d = data.n_features;
        let mut rng = rand::rng();
        let mut centroids = Vec::with_capacity(k * d);

        // First centroid: uniformly random sample.
        let first_idx = rng.random_range(0..n);
        centroids.extend_from_slice(data.row(first_idx));

        let mut distances = vec![f64::MAX; n];

        for _ in 1..k {
            // Update each point's distance to its nearest chosen centroid.
            for i in 0..n {
                let point = data.row(i);
                let last_centroid = &centroids[centroids.len() - d..];
                let dist = Self::euclidean_distance(point, last_centroid);
                distances[i] = distances[i].min(dist);
            }

            // Sample the next centroid proportionally to squared distance.
            let total: f64 = distances.iter().map(|d| d * d).sum();
            let threshold = rng.random::<f64>() * total;

            let mut cumsum = 0.0;
            let mut next_idx = 0;
            for (i, &dist) in distances.iter().enumerate() {
                cumsum += dist * dist;
                if cumsum >= threshold {
                    next_idx = i;
                    break;
                }
            }

            centroids.extend_from_slice(data.row(next_idx));
        }

        centroids
    }

    fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl GpuKernel for KMeans {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<KMeansAssignRing, KMeansAssignResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansAssignRing,
    ) -> Result<KMeansAssignResponse> {
        let inertia = self.assign_step();

        let state = self.state.read().unwrap();
        let points_assigned = state.labels.len() as u32;

        Ok(KMeansAssignResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            inertia_fp: to_fixed_point(inertia),
            points_assigned,
        })
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<KMeansUpdateRing, KMeansUpdateResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansUpdateRing,
    ) -> Result<KMeansUpdateResponse> {
        let max_shift = self.update_step();
        let converged = max_shift < 1e-6;

        if converged {
            let mut state = self.state.write().unwrap();
            state.converged = true;
        }

        Ok(KMeansUpdateResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            max_shift_fp: to_fixed_point(max_shift),
            converged,
        })
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<KMeansQueryRing, KMeansQueryResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansQueryRing,
    ) -> Result<KMeansQueryResponse> {
        let point = unpack_coordinates(&msg.point, msg.n_dims as usize);

        let (cluster, distance) = self.query_point(&point);

        Ok(KMeansQueryResponse {
            request_id: msg.id.0,
            cluster: cluster as u32,
            distance_fp: to_fixed_point(distance),
        })
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<K2KPartialCentroid, K2KCentroidAggregation> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KPartialCentroid,
    ) -> Result<K2KCentroidAggregation> {
        let n_dims = msg.n_dims as usize;
        let cluster_id = msg.cluster_id as usize;
        let mut new_centroid = [0i64; 8];

        // Average the fixed-point coordinate sums over the contributing points.
        if msg.point_count > 0 {
            for i in 0..n_dims.min(8) {
                new_centroid[i] = msg.coord_sum_fp[i] / msg.point_count as i64;
            }
        }

        // Shift of the aggregated centroid relative to the current local one.
        let shift = {
            let state = self.state.read().unwrap();
            let d = state.n_features;
            if cluster_id < state.k && d > 0 {
                let old_centroid = &state.centroids[cluster_id * d..(cluster_id + 1) * d];
                let new_coords: Vec<f64> = new_centroid[..d.min(8)]
                    .iter()
                    .map(|&v| from_fixed_point(v))
                    .collect();
                Self::euclidean_distance(old_centroid, &new_coords)
            } else {
                0.0
            }
        };

        Ok(K2KCentroidAggregation {
            request_id: msg.id.0,
            cluster_id: msg.cluster_id,
            iteration: msg.iteration,
            new_centroid_fp: new_centroid,
            total_points: msg.point_count,
            shift_fp: to_fixed_point(shift),
        })
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<K2KKMeansSync, K2KKMeansSyncResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KKMeansSync,
    ) -> Result<K2KKMeansSyncResponse> {
        let state = self.state.read().unwrap();

        // This worker is in sync if it has reached the requested iteration.
        let current_iteration = state.iteration as u64;
        let all_synced = msg.iteration <= current_iteration;

        // Converged when the reported global shift is below tolerance, or
        // this worker has already converged locally.
        let global_shift = from_fixed_point(msg.max_shift_fp);
        let converged = global_shift < 1e-6 || state.converged;

        Ok(K2KKMeansSyncResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            all_synced,
            global_inertia_fp: msg.local_inertia_fp,
            global_max_shift_fp: msg.max_shift_fp,
            converged,
        })
    }
}

#[async_trait::async_trait]
impl RingKernelHandler<K2KCentroidBroadcast, K2KCentroidBroadcastAck> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KCentroidBroadcast,
    ) -> Result<K2KCentroidBroadcastAck> {
        // Acknowledge receipt of the broadcast centroids.
        Ok(K2KCentroidBroadcastAck {
            request_id: msg.id.0,
            worker_id: 0,
            iteration: msg.iteration,
            applied: true,
        })
    }
}

/// DBSCAN density-based clustering kernel.
#[derive(Debug, Clone)]
pub struct DBSCAN {
    metadata: KernelMetadata,
}

impl Default for DBSCAN {
    fn default() -> Self {
        Self::new()
    }
}

impl DBSCAN {
    /// Create a new DBSCAN kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/dbscan-cluster", Domain::StatisticalML)
                .with_description("Density-based clustering with GPU union-find")
                .with_throughput(1_000)
                .with_latency_us(10_000.0),
        }
    }

    /// Run DBSCAN clustering.
    ///
    /// Points with at least `min_samples` neighbors within `eps` (including
    /// themselves) are core points; unreachable points are labeled as noise.
    pub fn compute(
        data: &DataMatrix,
        eps: f64,
        min_samples: usize,
        metric: DistanceMetric,
    ) -> ClusteringResult {
        let n = data.n_samples;

        if n == 0 {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 1,
                converged: true,
            };
        }

        // Label convention: -1 = unvisited, -2 = noise, >= 0 = cluster id.
        let mut labels = vec![-1i64; n];
        let mut current_cluster = 0i64;

        // Precompute the eps-neighborhood of every point.
        let neighborhoods: Vec<Vec<usize>> = (0..n)
            .map(|i| Self::get_neighbors(data, i, eps, metric))
            .collect();

        for i in 0..n {
            if labels[i] != -1 {
                continue; // Already visited.
            }

            let neighbors = &neighborhoods[i];

            if neighbors.len() < min_samples {
                labels[i] = -2; // Noise (may later become a border point).
                continue;
            }

            // Start a new cluster and expand it from the seed set.
            labels[i] = current_cluster;
            let mut seed_set: Vec<usize> = neighbors.clone();
            let mut j = 0;

            while j < seed_set.len() {
                let q = seed_set[j];
                j += 1;

                if labels[q] == -2 {
                    labels[q] = current_cluster; // Noise becomes a border point.
                }

                if labels[q] != -1 {
                    continue; // Already claimed by a cluster.
                }

                labels[q] = current_cluster;

                let q_neighbors = &neighborhoods[q];
                if q_neighbors.len() >= min_samples {
                    // q is a core point: extend the frontier with its neighbors.
                    for &neighbor in q_neighbors {
                        if !seed_set.contains(&neighbor) {
                            seed_set.push(neighbor);
                        }
                    }
                }
            }

            current_cluster += 1;
        }

        // Map noise (negative labels) to usize::MAX.
        let n_clusters = current_cluster as usize;
        let labels: Vec<usize> = labels
            .iter()
            .map(|&l| if l < 0 { usize::MAX } else { l as usize })
            .collect();

        // Compute per-cluster centroids (noise points are excluded).
        let d = data.n_features;
        let mut centroids = vec![0.0f64; n_clusters * d];
        let mut counts = vec![0usize; n_clusters];

        for i in 0..n {
            if labels[i] < n_clusters {
                let cluster = labels[i];
                counts[cluster] += 1;
                for j in 0..d {
                    centroids[cluster * d + j] += data.row(i)[j];
                }
            }
        }

        for c in 0..n_clusters {
            if counts[c] > 0 {
                for j in 0..d {
                    centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        ClusteringResult {
            labels,
            n_clusters,
            centroids,
            inertia: 0.0,
            iterations: 1,
            converged: true,
        }
    }

    /// Indices of all points within `eps` of `point_idx` (including itself).
    fn get_neighbors(
        data: &DataMatrix,
        point_idx: usize,
        eps: f64,
        metric: DistanceMetric,
    ) -> Vec<usize> {
        let n = data.n_samples;
        let point = data.row(point_idx);

        (0..n)
            .filter(|&i| {
                let other = data.row(i);
                let dist = metric.compute(point, other);
                dist <= eps
            })
            .collect()
    }
}

impl GpuKernel for DBSCAN {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

/// Linkage criterion for agglomerative clustering.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
    /// Minimum pairwise distance between clusters.
    Single,
    /// Maximum pairwise distance between clusters.
    Complete,
    /// Size-weighted average of pairwise distances.
    Average,
    /// Ward's minimum-variance criterion.
    Ward,
}

/// Agglomerative hierarchical clustering kernel.
#[derive(Debug, Clone)]
pub struct HierarchicalClustering {
    metadata: KernelMetadata,
}

impl Default for HierarchicalClustering {
    fn default() -> Self {
        Self::new()
    }
}

impl HierarchicalClustering {
    /// Create a new hierarchical clustering kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/hierarchical-cluster", Domain::StatisticalML)
                .with_description("Agglomerative hierarchical clustering")
                .with_throughput(500)
                .with_latency_us(50_000.0),
        }
    }

    /// Run agglomerative clustering until `n_clusters` clusters remain.
    ///
    /// Starts with every point in its own cluster and repeatedly merges the
    /// closest pair under the chosen linkage criterion.
    pub fn compute(
        data: &DataMatrix,
        n_clusters: usize,
        linkage: LinkageMethod,
        metric: DistanceMetric,
    ) -> ClusteringResult {
        let n = data.n_samples;

        if n == 0 || n_clusters == 0 {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 0,
                converged: true,
            };
        }

        // Each point starts as its own cluster.
        let mut labels: Vec<usize> = (0..n).collect();
        let mut active_clusters: Vec<bool> = vec![true; n];
        let mut cluster_sizes: Vec<usize> = vec![1; n];

        let mut distances = Self::compute_distance_matrix(data, metric);

        let mut current_n_clusters = n;

        while current_n_clusters > n_clusters {
            let (c1, c2) = Self::find_closest_clusters(&distances, &active_clusters, n);

            if c1 == c2 {
                break;
            }

            // Merge c2 into c1.
            for label in &mut labels {
                if *label == c2 {
                    *label = c1;
                }
            }

            Self::update_distances(
                &mut distances,
                c1,
                c2,
                n,
                linkage,
                &cluster_sizes,
                &active_clusters,
            );

            cluster_sizes[c1] += cluster_sizes[c2];
            active_clusters[c2] = false;
            current_n_clusters -= 1;
        }

        // Relabel clusters to consecutive indices starting at 0.
        let mut label_map = std::collections::HashMap::new();
        let mut next_label = 0usize;

        for label in &mut labels {
            let new_label = *label_map.entry(*label).or_insert_with(|| {
                let l = next_label;
                next_label += 1;
                l
            });
            *label = new_label;
        }

        // Compute the centroid of each final cluster.
        let d = data.n_features;
        let final_n_clusters = next_label;
        let mut centroids = vec![0.0f64; final_n_clusters * d];
        let mut counts = vec![0usize; final_n_clusters];

        for i in 0..n {
            let cluster = labels[i];
            counts[cluster] += 1;
            for j in 0..d {
                centroids[cluster * d + j] += data.row(i)[j];
            }
        }

        for c in 0..final_n_clusters {
            if counts[c] > 0 {
                for j in 0..d {
                    centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        ClusteringResult {
            labels,
            n_clusters: final_n_clusters,
            centroids,
            inertia: 0.0,
            // Saturating: n_clusters may exceed n, in which case no merges ran.
            iterations: n.saturating_sub(n_clusters) as u32,
            converged: true,
        }
    }

    /// Dense pairwise distance matrix (`n * n`; self-distances left at MAX).
    fn compute_distance_matrix(data: &DataMatrix, metric: DistanceMetric) -> Vec<f64> {
        let n = data.n_samples;
        let mut distances = vec![f64::MAX; n * n];

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    distances[i * n + j] = metric.compute(data.row(i), data.row(j));
                }
            }
        }

        distances
    }

    /// Find the pair of active clusters with the smallest distance.
    fn find_closest_clusters(distances: &[f64], active: &[bool], n: usize) -> (usize, usize) {
        let mut min_dist = f64::MAX;
        let mut min_i = 0;
        let mut min_j = 0;

        for i in 0..n {
            if !active[i] {
                continue;
            }
            for j in (i + 1)..n {
                if !active[j] {
                    continue;
                }
                let dist = distances[i * n + j];
                if dist < min_dist {
                    min_dist = dist;
                    min_i = i;
                    min_j = j;
                }
            }
        }

        (min_i, min_j)
    }

    /// Lance-Williams distance update after merging `c2` into `c1`.
    fn update_distances(
        distances: &mut [f64],
        c1: usize,
        c2: usize,
        n: usize,
        linkage: LinkageMethod,
        cluster_sizes: &[usize],
        active: &[bool],
    ) {
        for k in 0..n {
            if !active[k] || k == c1 || k == c2 {
                continue;
            }

            let d1 = distances[c1 * n + k];
            let d2 = distances[c2 * n + k];

            let new_dist = match linkage {
                LinkageMethod::Single => d1.min(d2),
                LinkageMethod::Complete => d1.max(d2),
                LinkageMethod::Average => {
                    let n1 = cluster_sizes[c1] as f64;
                    let n2 = cluster_sizes[c2] as f64;
                    (n1 * d1 + n2 * d2) / (n1 + n2)
                }
                LinkageMethod::Ward => {
                    // The Lance-Williams update for Ward's criterion operates
                    // on squared distances; take the square root (clamped at
                    // zero against float error) so the stored value stays
                    // comparable with the other linkage methods.
                    let n1 = cluster_sizes[c1] as f64;
                    let n2 = cluster_sizes[c2] as f64;
                    let nk = cluster_sizes[k] as f64;
                    let total = n1 + n2 + nk;
                    (((n1 + nk) * d1 * d1 + (n2 + nk) * d2 * d2
                        - nk * distances[c1 * n + c2].powi(2))
                        / total)
                        .max(0.0)
                        .sqrt()
                }
            };

            distances[c1 * n + k] = new_dist;
            distances[k * n + c1] = new_dist;
        }
    }
}

impl GpuKernel for HierarchicalClustering {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// Batch (request/response) kernel support.
use crate::messages::{
    DBSCANInput, DBSCANOutput, HierarchicalInput, HierarchicalOutput, KMeansInput, KMeansOutput,
    Linkage,
};
use async_trait::async_trait;
use rustkernel_core::error::Result;
use rustkernel_core::traits::BatchKernel;
use std::time::Instant;

impl KMeans {
    /// Run K-Means as a batch request, timing the computation.
    pub async fn cluster_batch(&self, input: KMeansInput) -> Result<KMeansOutput> {
        let start = Instant::now();
        let result = Self::compute(&input.data, input.k, input.max_iterations, input.tolerance);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(KMeansOutput {
            result,
            compute_time_us,
        })
    }
}

#[async_trait]
impl BatchKernel<KMeansInput, KMeansOutput> for KMeans {
    async fn execute(&self, input: KMeansInput) -> Result<KMeansOutput> {
        self.cluster_batch(input).await
    }
}

#[async_trait]
impl BatchKernel<DBSCANInput, DBSCANOutput> for DBSCAN {
    async fn execute(&self, input: DBSCANInput) -> Result<DBSCANOutput> {
        let start = Instant::now();
        let result = Self::compute(&input.data, input.eps, input.min_samples, input.metric);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(DBSCANOutput {
            result,
            compute_time_us,
        })
    }
}

#[async_trait]
impl BatchKernel<HierarchicalInput, HierarchicalOutput> for HierarchicalClustering {
    async fn execute(&self, input: HierarchicalInput) -> Result<HierarchicalOutput> {
        let start = Instant::now();
        let linkage_method = match input.linkage {
            Linkage::Single => LinkageMethod::Single,
            Linkage::Complete => LinkageMethod::Complete,
            Linkage::Average => LinkageMethod::Average,
            Linkage::Ward => LinkageMethod::Ward,
        };
        let result = Self::compute(&input.data, input.n_clusters, linkage_method, input.metric);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(HierarchicalOutput {
            result,
            compute_time_us,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated groups of three points each.
    fn create_two_clusters() -> DataMatrix {
        DataMatrix::from_rows(&[
            &[0.0, 0.0],
            &[0.1, 0.1],
            &[0.2, 0.0],
            &[10.0, 10.0],
            &[10.1, 10.1],
            &[10.2, 10.0],
        ])
    }

    #[test]
    fn test_kmeans_metadata() {
        let kernel = KMeans::new();
        assert_eq!(kernel.metadata().id, "ml/kmeans-cluster");
        assert_eq!(kernel.metadata().domain, Domain::StatisticalML);
    }

    #[test]
    fn test_kmeans_two_clusters() {
        let data = create_two_clusters();
        let result = KMeans::compute(&data, 2, 100, 1e-6);

        assert_eq!(result.n_clusters, 2);
        assert!(result.converged);

        // Points 0-2 and 3-5 should land in different clusters.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[1], result.labels[2]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_eq!(result.labels[4], result.labels[5]);
        assert_ne!(result.labels[0], result.labels[3]);
    }
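
    // A sketch exercising the stateful, ring-style API (initialize /
    // assign_step / update_step / query_point) on the same fixture. The
    // 100-iteration budget and 1e-6 shift tolerance are test-local
    // assumptions, not values mandated by the kernel.
    #[test]
    fn test_kmeans_stateful_loop() {
        let kernel = KMeans::new();
        kernel.initialize(create_two_clusters(), 2);

        // Drive Lloyd's iteration manually until the centroids stop moving.
        let mut converged = false;
        for _ in 0..100 {
            kernel.assign_step();
            if kernel.update_step() < 1e-6 {
                converged = true;
                break;
            }
        }

        assert!(converged);
        assert!(kernel.current_iteration() > 0);
        assert!(kernel.current_inertia().is_finite());

        // A point near the first group should sit close to its centroid.
        let (_cluster, dist) = kernel.query_point(&[0.1, 0.0]);
        assert!(dist < 1.0);
    }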

    #[test]
    fn test_dbscan_two_clusters() {
        let data = create_two_clusters();
        let result = DBSCAN::compute(&data, 1.0, 2, DistanceMetric::Euclidean);

        assert_eq!(result.n_clusters, 2);

        // Both dense groups form clusters, separated from each other.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_ne!(result.labels[0], result.labels[3]);
    }
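
    // A sketch of DBSCAN's noise handling: a point with fewer than
    // min_samples neighbors inside eps should come back labeled usize::MAX.
    // The outlier coordinates below are arbitrary test values.
    #[test]
    fn test_dbscan_labels_outlier_as_noise() {
        let data = DataMatrix::from_rows(&[
            &[0.0, 0.0],
            &[0.1, 0.1],
            &[0.2, 0.0],
            // Far from everything else.
            &[100.0, 100.0],
        ]);
        let result = DBSCAN::compute(&data, 1.0, 2, DistanceMetric::Euclidean);

        assert_eq!(result.n_clusters, 1);
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[3], usize::MAX);
    }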

    #[test]
    fn test_hierarchical_two_clusters() {
        let data = create_two_clusters();
        let result = HierarchicalClustering::compute(
            &data,
            2,
            LinkageMethod::Complete,
            DistanceMetric::Euclidean,
        );

        assert_eq!(result.n_clusters, 2);

        // Each group should merge internally before merging across groups.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[1], result.labels[2]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_ne!(result.labels[0], result.labels[3]);
    }
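
    // A sketch of the degenerate-input contract, mirroring the early-return
    // guard in KMeans::compute: k == 0 or k > n yields an empty, converged
    // result instead of panicking.
    #[test]
    fn test_kmeans_degenerate_k() {
        let data = create_two_clusters();

        let result = KMeans::compute(&data, 0, 10, 1e-6);
        assert_eq!(result.n_clusters, 0);
        assert!(result.labels.is_empty());
        assert!(result.converged);

        // k larger than the sample count hits the same guard.
        let result = KMeans::compute(&data, 10, 10, 1e-6);
        assert_eq!(result.n_clusters, 0);
        assert!(result.converged);
    }

    // Exercises the Ward branch of update_distances on the same fixture;
    // both dense groups should survive as separate clusters.
    #[test]
    fn test_hierarchical_ward_two_clusters() {
        let data = create_two_clusters();
        let result = HierarchicalClustering::compute(
            &data,
            2,
            LinkageMethod::Ward,
            DistanceMetric::Euclidean,
        );

        assert_eq!(result.n_clusters, 2);
        assert_ne!(result.labels[0], result.labels[3]);
    }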
}