use crate::ring_messages::{
    K2KCentroidAggregation, K2KCentroidBroadcast, K2KCentroidBroadcastAck, K2KKMeansSync,
    K2KKMeansSyncResponse, K2KPartialCentroid, KMeansAssignResponse, KMeansAssignRing,
    KMeansQueryResponse, KMeansQueryRing, KMeansUpdateResponse, KMeansUpdateRing, from_fixed_point,
    to_fixed_point, unpack_coordinates,
};
use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
use rand::prelude::*;
use ringkernel_core::RingContext;
use rustkernel_core::traits::RingKernelHandler;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};

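/// Internal state for a K-Means run, shared behind an `RwLock` on the kernel.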
#[derive(Debug, Clone, Default)]
pub struct KMeansState {
    /// Flattened centroids, `k * n_features` values in row-major order.
    pub centroids: Vec<f64>,
    /// Training data, set by `initialize`.
    pub data: Option<DataMatrix>,
    /// Number of clusters.
    pub k: usize,
    /// Number of features per sample.
    pub n_features: usize,
    /// Iterations completed so far.
    pub iteration: u32,
    /// Sum of squared point-to-centroid distances from the last assign step.
    pub inertia: f64,
    /// Whether the centroids have stopped moving.
    pub converged: bool,
    /// Cluster assignment for each sample.
    pub labels: Vec<usize>,
}

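/// K-Means clustering kernel with K-Means++ initialization.
///
/// Supports one-shot batch computation via [`KMeans::compute`] as well as
/// incremental assign/update steps driven by ring messages.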
#[derive(Debug)]
pub struct KMeans {
    metadata: KernelMetadata,
    /// Interior mutability so `&self` handlers can update clustering state.
    state: std::sync::RwLock<KMeansState>,
}

impl Clone for KMeans {
    fn clone(&self) -> Self {
        Self {
            metadata: self.metadata.clone(),
            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
        }
    }
}

impl Default for KMeans {
    fn default() -> Self {
        Self::new()
    }
}

impl KMeans {
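    /// Creates a new K-Means kernel with default metadata.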
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/kmeans-cluster", Domain::StatisticalML)
                .with_description("K-Means clustering with K-Means++ initialization")
                .with_throughput(20_000)
                .with_latency_us(50.0),
            state: std::sync::RwLock::new(KMeansState::default()),
        }
    }

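    /// Seeds the centroids with K-Means++ and stores the data for
    /// subsequent `assign_step` / `update_step` calls.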
    pub fn initialize(&self, data: DataMatrix, k: usize) {
        let centroids = Self::kmeans_plus_plus_init(&data, k);
        let n = data.n_samples;
        let n_features = data.n_features;

        let mut state = self.state.write().unwrap();
        *state = KMeansState {
            centroids,
            data: Some(data),
            k,
            n_features,
            iteration: 0,
            inertia: 0.0,
            converged: false,
            labels: vec![0; n],
        };
    }

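    /// Assignment step: relabels every point with its nearest centroid and
    /// returns the total inertia (sum of squared distances).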
    #[allow(clippy::needless_range_loop)]
    pub fn assign_step(&self) -> f64 {
        let mut state = self.state.write().unwrap();

        // Clone the data so we can mutate `state` while iterating.
        let data = match state.data {
            Some(ref d) => d.clone(),
            None => return 0.0,
        };

        let n = data.n_samples;
        let d_features = state.n_features;
        let mut total_inertia = 0.0;

        let centroids = state.centroids.clone();

        let mut new_labels = vec![0usize; n];
        for i in 0..n {
            let point = data.row(i);
            let mut min_dist = f64::MAX;
            let mut min_cluster = 0;

            for (c, centroid) in centroids.chunks(d_features).enumerate() {
                let dist = Self::euclidean_distance(point, centroid);
                if dist < min_dist {
                    min_dist = dist;
                    min_cluster = c;
                }
            }
            new_labels[i] = min_cluster;
            total_inertia += min_dist * min_dist;
        }

        state.labels = new_labels;
        state.inertia = total_inertia;
        total_inertia
    }

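    /// Update step: recomputes each centroid as the mean of its assigned
    /// points and returns the largest centroid shift (for convergence checks).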
    pub fn update_step(&self) -> f64 {
        let mut state = self.state.write().unwrap();
        let Some(ref data) = state.data else {
            return 0.0;
        };

        let n = data.n_samples;
        let d = state.n_features;
        let k = state.k;

        let mut new_centroids = vec![0.0f64; k * d];
        let mut counts = vec![0usize; k];

        for i in 0..n {
            let cluster = state.labels[i];
            counts[cluster] += 1;
            let point = data.row(i);
            for j in 0..d {
                new_centroids[cluster * d + j] += point[j];
            }
        }

        // Empty clusters keep a zero centroid rather than dividing by zero.
        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..d {
                    new_centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        let max_shift = state
            .centroids
            .chunks(d)
            .zip(new_centroids.chunks(d))
            .map(|(old, new)| Self::euclidean_distance(old, new))
            .fold(0.0f64, f64::max);

        state.centroids = new_centroids;
        state.iteration += 1;
        max_shift
    }

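    /// Returns the nearest cluster index and distance for a query point.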
    pub fn query_point(&self, point: &[f64]) -> (usize, f64) {
        let state = self.state.read().unwrap();
        let d = state.n_features;

        let mut min_dist = f64::MAX;
        let mut min_cluster = 0;

        for (c, centroid) in state.centroids.chunks(d).enumerate() {
            let dist = Self::euclidean_distance(point, centroid);
            if dist < min_dist {
                min_dist = dist;
                min_cluster = c;
            }
        }

        (min_cluster, min_dist)
    }

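    /// Returns the number of iterations completed so far.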
    pub fn current_iteration(&self) -> u32 {
        self.state.read().unwrap().iteration
    }

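    /// Returns the inertia from the most recent assignment step.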
    pub fn current_inertia(&self) -> f64 {
        self.state.read().unwrap().inertia
    }

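    /// One-shot K-Means via Lloyd's algorithm with K-Means++ seeding.
    ///
    /// Runs at most `max_iterations` assign/update rounds, stopping early once
    /// the largest centroid shift drops below `tolerance`. Degenerate inputs
    /// (`n == 0`, `k == 0`, or `k > n`) yield an empty, converged result.
    ///
    /// ```ignore
    /// // Sketch only: adjust the import paths to this crate's re-exports.
    /// let data = DataMatrix::from_rows(&[&[0.0, 0.0], &[0.1, 0.1], &[10.0, 10.0]]);
    /// let result = KMeans::compute(&data, 2, 100, 1e-6);
    /// assert_eq!(result.n_clusters, 2);
    /// ```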
    #[allow(clippy::needless_range_loop)]
    pub fn compute(
        data: &DataMatrix,
        k: usize,
        max_iterations: u32,
        tolerance: f64,
    ) -> ClusteringResult {
        let n = data.n_samples;
        let d = data.n_features;

        if n == 0 || k == 0 || k > n {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 0,
                converged: true,
            };
        }

        let mut centroids = Self::kmeans_plus_plus_init(data, k);
        let mut labels = vec![0usize; n];
        let mut converged = false;
        let mut iterations = 0u32;

        for iter in 0..max_iterations {
            iterations = iter + 1;

            // Assignment: label each point with its nearest centroid.
            for i in 0..n {
                let point = data.row(i);
                let mut min_dist = f64::MAX;
                let mut min_cluster = 0;

                for (c, centroid) in centroids.chunks(d).enumerate() {
                    let dist = Self::euclidean_distance(point, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        min_cluster = c;
                    }
                }
                labels[i] = min_cluster;
            }

            // Update: recompute each centroid as the mean of its points.
            let mut new_centroids = vec![0.0f64; k * d];
            let mut counts = vec![0usize; k];

            for i in 0..n {
                let cluster = labels[i];
                counts[cluster] += 1;
                let point = data.row(i);
                for j in 0..d {
                    new_centroids[cluster * d + j] += point[j];
                }
            }

            for c in 0..k {
                if counts[c] > 0 {
                    for j in 0..d {
                        new_centroids[c * d + j] /= counts[c] as f64;
                    }
                }
            }

            let max_shift = centroids
                .chunks(d)
                .zip(new_centroids.chunks(d))
                .map(|(old, new)| Self::euclidean_distance(old, new))
                .fold(0.0f64, f64::max);

            centroids = new_centroids;

            if max_shift < tolerance {
                converged = true;
                break;
            }
        }

        let inertia: f64 = (0..n)
            .map(|i| {
                let point = data.row(i);
                let centroid_start = labels[i] * d;
                let centroid = &centroids[centroid_start..centroid_start + d];
                let dist = Self::euclidean_distance(point, centroid);
                dist * dist
            })
            .sum();

        ClusteringResult {
            labels,
            n_clusters: k,
            centroids,
            inertia,
            iterations,
            converged,
        }
    }

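    /// K-Means++ seeding: the first centroid is chosen uniformly at random;
    /// each subsequent centroid is sampled with probability proportional to
    /// the squared distance to the nearest centroid chosen so far.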
    #[allow(clippy::needless_range_loop)]
    fn kmeans_plus_plus_init(data: &DataMatrix, k: usize) -> Vec<f64> {
        let n = data.n_samples;
        let d = data.n_features;
        let mut rng = rand::rng();
        let mut centroids = Vec::with_capacity(k * d);

        let first_idx = rng.random_range(0..n);
        centroids.extend_from_slice(data.row(first_idx));

        let mut distances = vec![f64::MAX; n];

        for _ in 1..k {
            // Refresh each point's distance to its nearest chosen centroid;
            // only the most recently added centroid can lower it.
            for i in 0..n {
                let point = data.row(i);
                let last_centroid = &centroids[centroids.len() - d..];
                let dist = Self::euclidean_distance(point, last_centroid);
                distances[i] = distances[i].min(dist);
            }

            // Weighted sampling proportional to squared distance.
            let total: f64 = distances.iter().map(|d| d * d).sum();
            let threshold = rng.random::<f64>() * total;

            let mut cumsum = 0.0;
            let mut next_idx = 0;
            for (i, &dist) in distances.iter().enumerate() {
                cumsum += dist * dist;
                if cumsum >= threshold {
                    next_idx = i;
                    break;
                }
            }

            centroids.extend_from_slice(data.row(next_idx));
        }

        centroids
    }

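    /// Euclidean (L2) distance between two equal-length slices.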
    fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl GpuKernel for KMeans {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

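/// Ring handler: runs one assignment step and reports the resulting inertia.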
#[async_trait::async_trait]
impl RingKernelHandler<KMeansAssignRing, KMeansAssignResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansAssignRing,
    ) -> Result<KMeansAssignResponse> {
        let inertia = self.assign_step();

        let state = self.state.read().unwrap();
        let points_assigned = state.labels.len() as u32;

        Ok(KMeansAssignResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            inertia_fp: to_fixed_point(inertia),
            points_assigned,
        })
    }
}

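/// Ring handler: runs one update step and reports the maximum centroid shift.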
#[async_trait::async_trait]
impl RingKernelHandler<KMeansUpdateRing, KMeansUpdateResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansUpdateRing,
    ) -> Result<KMeansUpdateResponse> {
        let max_shift = self.update_step();
        // Fixed convergence threshold for ring-driven iteration.
        let converged = max_shift < 1e-6;

        if converged {
            let mut state = self.state.write().unwrap();
            state.converged = true;
        }

        Ok(KMeansUpdateResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            max_shift_fp: to_fixed_point(max_shift),
            converged,
        })
    }
}

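/// Ring handler: classifies a single query point against the current centroids.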
#[async_trait::async_trait]
impl RingKernelHandler<KMeansQueryRing, KMeansQueryResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: KMeansQueryRing,
    ) -> Result<KMeansQueryResponse> {
        let point = unpack_coordinates(&msg.point, msg.n_dims as usize);

        let (cluster, distance) = self.query_point(&point);

        Ok(KMeansQueryResponse {
            request_id: msg.id.0,
            cluster: cluster as u32,
            distance_fp: to_fixed_point(distance),
        })
    }
}

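/// K2K handler: aggregates a partial centroid sum into a mean and reports how
/// far the aggregated centroid moved relative to the locally stored one.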
#[async_trait::async_trait]
impl RingKernelHandler<K2KPartialCentroid, K2KCentroidAggregation> for KMeans {
    #[allow(clippy::needless_range_loop)]
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KPartialCentroid,
    ) -> Result<K2KCentroidAggregation> {
        let n_dims = msg.n_dims as usize;
        let cluster_id = msg.cluster_id as usize;
        let mut new_centroid = [0i64; 8];

        // Mean of the partial sums (fixed-point, capped at 8 dimensions).
        if msg.point_count > 0 {
            for i in 0..n_dims.min(8) {
                new_centroid[i] = msg.coord_sum_fp[i] / msg.point_count as i64;
            }
        }

        // Compare against the locally stored centroid to measure the shift.
        let shift = {
            let state = self.state.read().unwrap();
            let d = state.n_features;
            if cluster_id < state.k && d > 0 {
                let old_centroid = &state.centroids[cluster_id * d..(cluster_id + 1) * d];
                let new_coords: Vec<f64> = new_centroid[..d.min(8)]
                    .iter()
                    .map(|&v| from_fixed_point(v))
                    .collect();
                Self::euclidean_distance(old_centroid, &new_coords)
            } else {
                0.0
            }
        };

        Ok(K2KCentroidAggregation {
            request_id: msg.id.0,
            cluster_id: msg.cluster_id,
            iteration: msg.iteration,
            new_centroid_fp: new_centroid,
            total_points: msg.point_count,
            shift_fp: to_fixed_point(shift),
        })
    }
}

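/// K2K handler: answers an iteration-barrier sync, echoing the sender's local
/// inertia and shift back as the global values.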
#[async_trait::async_trait]
impl RingKernelHandler<K2KKMeansSync, K2KKMeansSyncResponse> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KKMeansSync,
    ) -> Result<K2KKMeansSyncResponse> {
        let state = self.state.read().unwrap();

        // Synced once this node has caught up to the sender's iteration.
        let current_iteration = state.iteration as u64;
        let all_synced = msg.iteration <= current_iteration;

        let global_shift = from_fixed_point(msg.max_shift_fp);
        let converged = global_shift < 1e-6 || state.converged;

        Ok(K2KKMeansSyncResponse {
            request_id: msg.id.0,
            iteration: msg.iteration,
            all_synced,
            global_inertia_fp: msg.local_inertia_fp,
            global_max_shift_fp: msg.max_shift_fp,
            converged,
        })
    }
}

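/// K2K handler: acknowledges a centroid broadcast.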
#[async_trait::async_trait]
impl RingKernelHandler<K2KCentroidBroadcast, K2KCentroidBroadcastAck> for KMeans {
    async fn handle(
        &self,
        _ctx: &mut RingContext,
        msg: K2KCentroidBroadcast,
    ) -> Result<K2KCentroidBroadcastAck> {
        Ok(K2KCentroidBroadcastAck {
            request_id: msg.id.0,
            // Worker id is not tracked by this kernel instance.
            worker_id: 0,
            iteration: msg.iteration,
            applied: true,
        })
    }
}

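/// DBSCAN density-based clustering kernel.
///
/// Points are labeled noise (`usize::MAX` in the result) when they belong to
/// no dense region of at least `min_samples` points within `eps`.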
#[derive(Debug, Clone)]
pub struct DBSCAN {
    metadata: KernelMetadata,
}

impl Default for DBSCAN {
    fn default() -> Self {
        Self::new()
    }
}

impl DBSCAN {
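    /// Creates a new DBSCAN kernel with default metadata.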
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/dbscan-cluster", Domain::StatisticalML)
                .with_description("Density-based clustering with GPU union-find")
                .with_throughput(1_000)
                .with_latency_us(10_000.0),
        }
    }

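    /// Runs DBSCAN over `data`.
    ///
    /// `eps` is the neighborhood radius and `min_samples` the density
    /// threshold (a point's neighborhood includes itself). Noise points map
    /// to `usize::MAX` in the returned labels; centroids are per-cluster
    /// means and inertia is not computed.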
    #[allow(clippy::needless_range_loop)]
    pub fn compute(
        data: &DataMatrix,
        eps: f64,
        min_samples: usize,
        metric: DistanceMetric,
    ) -> ClusteringResult {
        let n = data.n_samples;

        if n == 0 {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 1,
                converged: true,
            };
        }

        // -1 = unvisited, -2 = noise, >= 0 = cluster id.
        let mut labels = vec![-1i64; n];
        let mut current_cluster = 0i64;

        // Precompute every point's eps-neighborhood.
        let neighborhoods: Vec<Vec<usize>> = (0..n)
            .map(|i| Self::get_neighbors(data, i, eps, metric))
            .collect();

        for i in 0..n {
            if labels[i] != -1 {
                continue; // Already visited.
            }

            let neighbors = &neighborhoods[i];

            if neighbors.len() < min_samples {
                labels[i] = -2; // Noise for now; may be reclaimed as a border point.
                continue;
            }

            // Start a new cluster and expand it breadth-first.
            labels[i] = current_cluster;
            let mut seed_set: Vec<usize> = neighbors.clone();
            let mut j = 0;

            while j < seed_set.len() {
                let q = seed_set[j];
                j += 1;

                if labels[q] == -2 {
                    labels[q] = current_cluster; // Noise becomes a border point.
                }

                if labels[q] != -1 {
                    continue; // Already claimed by a cluster.
                }

                labels[q] = current_cluster;

                // Core points extend the frontier with their neighborhoods.
                let q_neighbors = &neighborhoods[q];
                if q_neighbors.len() >= min_samples {
                    for &neighbor in q_neighbors {
                        if !seed_set.contains(&neighbor) {
                            seed_set.push(neighbor);
                        }
                    }
                }
            }

            current_cluster += 1;
        }

        // Map noise to usize::MAX, clusters to 0..n_clusters.
        let n_clusters = current_cluster as usize;
        let labels: Vec<usize> = labels
            .iter()
            .map(|&l| if l < 0 { usize::MAX } else { l as usize })
            .collect();

        // Per-cluster mean positions (noise excluded).
        let d = data.n_features;
        let mut centroids = vec![0.0f64; n_clusters * d];
        let mut counts = vec![0usize; n_clusters];

        for i in 0..n {
            if labels[i] < n_clusters {
                let cluster = labels[i];
                counts[cluster] += 1;
                for j in 0..d {
                    centroids[cluster * d + j] += data.row(i)[j];
                }
            }
        }

        for c in 0..n_clusters {
            if counts[c] > 0 {
                for j in 0..d {
                    centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        ClusteringResult {
            labels,
            n_clusters,
            centroids,
            inertia: 0.0,
            iterations: 1,
            converged: true,
        }
    }

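    /// Indices of all points within `eps` of `point_idx` (including itself).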
    fn get_neighbors(
        data: &DataMatrix,
        point_idx: usize,
        eps: f64,
        metric: DistanceMetric,
    ) -> Vec<usize> {
        let n = data.n_samples;
        let point = data.row(point_idx);

        (0..n)
            .filter(|&i| {
                let other = data.row(i);
                let dist = metric.compute(point, other);
                dist <= eps
            })
            .collect()
    }
}

impl GpuKernel for DBSCAN {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

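/// Linkage criteria for agglomerative hierarchical clustering.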
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
    /// Minimum pairwise distance between clusters.
    Single,
    /// Maximum pairwise distance between clusters.
    Complete,
    /// Size-weighted average of pairwise distances.
    Average,
    /// Ward's minimum-variance criterion.
    Ward,
}

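/// Agglomerative hierarchical clustering kernel.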
#[derive(Debug, Clone)]
pub struct HierarchicalClustering {
    metadata: KernelMetadata,
}

impl Default for HierarchicalClustering {
    fn default() -> Self {
        Self::new()
    }
}

impl HierarchicalClustering {
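    /// Creates a new hierarchical clustering kernel with default metadata.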
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/hierarchical-cluster", Domain::StatisticalML)
                .with_description("Agglomerative hierarchical clustering")
                .with_throughput(500)
                .with_latency_us(50_000.0),
        }
    }

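    /// Agglomerative clustering: starts with every point as its own cluster
    /// and repeatedly merges the closest pair under `linkage` until
    /// `n_clusters` remain.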
    #[allow(clippy::needless_range_loop)]
    pub fn compute(
        data: &DataMatrix,
        n_clusters: usize,
        linkage: LinkageMethod,
        metric: DistanceMetric,
    ) -> ClusteringResult {
        let n = data.n_samples;

        if n == 0 || n_clusters == 0 {
            return ClusteringResult {
                labels: Vec::new(),
                n_clusters: 0,
                centroids: Vec::new(),
                inertia: 0.0,
                iterations: 0,
                converged: true,
            };
        }

        // Each point starts as its own singleton cluster.
        let mut labels: Vec<usize> = (0..n).collect();
        let mut active_clusters: Vec<bool> = vec![true; n];
        let mut cluster_sizes: Vec<usize> = vec![1; n];

        let mut distances = Self::compute_distance_matrix(data, metric);

        let mut current_n_clusters = n;

        while current_n_clusters > n_clusters {
            let (c1, c2) = Self::find_closest_clusters(&distances, &active_clusters, n);

            if c1 == c2 {
                break;
            }

            // Merge c2 into c1.
            for label in &mut labels {
                if *label == c2 {
                    *label = c1;
                }
            }

            Self::update_distances(
                &mut distances,
                c1,
                c2,
                n,
                linkage,
                &cluster_sizes,
                &active_clusters,
            );

            cluster_sizes[c1] += cluster_sizes[c2];
            active_clusters[c2] = false;
            current_n_clusters -= 1;
        }

        // Renumber the surviving cluster ids to a dense 0..k range.
        let mut label_map = std::collections::HashMap::new();
        let mut next_label = 0usize;

        for label in &mut labels {
            let new_label = *label_map.entry(*label).or_insert_with(|| {
                let l = next_label;
                next_label += 1;
                l
            });
            *label = new_label;
        }

        // Per-cluster mean positions.
        let d = data.n_features;
        let final_n_clusters = next_label;
        let mut centroids = vec![0.0f64; final_n_clusters * d];
        let mut counts = vec![0usize; final_n_clusters];

        for i in 0..n {
            let cluster = labels[i];
            counts[cluster] += 1;
            for j in 0..d {
                centroids[cluster * d + j] += data.row(i)[j];
            }
        }

        for c in 0..final_n_clusters {
            if counts[c] > 0 {
                for j in 0..d {
                    centroids[c * d + j] /= counts[c] as f64;
                }
            }
        }

        ClusteringResult {
            labels,
            n_clusters: final_n_clusters,
            centroids,
            inertia: 0.0,
            // Saturating: guards against underflow when n_clusters > n.
            iterations: n.saturating_sub(n_clusters) as u32,
            converged: true,
        }
    }

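    /// Dense `n x n` pairwise distance matrix; the diagonal stays `f64::MAX`
    /// so a cluster is never its own nearest neighbor.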
    fn compute_distance_matrix(data: &DataMatrix, metric: DistanceMetric) -> Vec<f64> {
        let n = data.n_samples;
        let mut distances = vec![f64::MAX; n * n];

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    distances[i * n + j] = metric.compute(data.row(i), data.row(j));
                }
            }
        }

        distances
    }

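    /// Finds the pair of active clusters with the smallest inter-cluster
    /// distance; returns `(0, 0)` when fewer than two clusters remain active.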
    fn find_closest_clusters(distances: &[f64], active: &[bool], n: usize) -> (usize, usize) {
        let mut min_dist = f64::MAX;
        let mut min_i = 0;
        let mut min_j = 0;

        for i in 0..n {
            if !active[i] {
                continue;
            }
            for j in (i + 1)..n {
                if !active[j] {
                    continue;
                }
                let dist = distances[i * n + j];
                if dist < min_dist {
                    min_dist = dist;
                    min_i = i;
                    min_j = j;
                }
            }
        }

        (min_i, min_j)
    }

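    /// Lance-Williams update of the distances from the merged cluster
    /// (stored at `c1`) to every other active cluster.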
    fn update_distances(
        distances: &mut [f64],
        c1: usize,
        c2: usize,
        n: usize,
        linkage: LinkageMethod,
        cluster_sizes: &[usize],
        active: &[bool],
    ) {
        for k in 0..n {
            if !active[k] || k == c1 || k == c2 {
                continue;
            }

            let d1 = distances[c1 * n + k];
            let d2 = distances[c2 * n + k];

            let new_dist = match linkage {
                LinkageMethod::Single => d1.min(d2),
                LinkageMethod::Complete => d1.max(d2),
                LinkageMethod::Average => {
                    let n1 = cluster_sizes[c1] as f64;
                    let n2 = cluster_sizes[c2] as f64;
                    (n1 * d1 + n2 * d2) / (n1 + n2)
                }
                LinkageMethod::Ward => {
                    let n1 = cluster_sizes[c1] as f64;
                    let n2 = cluster_sizes[c2] as f64;
                    let nk = cluster_sizes[k] as f64;
                    let total = n1 + n2 + nk;
                    // Lance-Williams recurrence on squared distances; take the
                    // square root so the matrix stays in distance units.
                    (((n1 + nk) * d1 * d1 + (n2 + nk) * d2 * d2
                        - nk * distances[c1 * n + c2].powi(2))
                        / total)
                        .max(0.0)
                        .sqrt()
                }
            };

            distances[c1 * n + k] = new_dist;
            distances[k * n + c1] = new_dist;
        }
    }
}

impl GpuKernel for HierarchicalClustering {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

use crate::messages::{
    DBSCANInput, DBSCANOutput, HierarchicalInput, HierarchicalOutput, KMeansInput, KMeansOutput,
    Linkage,
};
use async_trait::async_trait;
use rustkernel_core::error::Result;
use rustkernel_core::traits::BatchKernel;
use std::time::Instant;

impl KMeans {
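    /// Batch entry point: runs the full K-Means computation and records the
    /// wall-clock compute time.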
    pub async fn cluster_batch(&self, input: KMeansInput) -> Result<KMeansOutput> {
        let start = Instant::now();
        let result = Self::compute(&input.data, input.k, input.max_iterations, input.tolerance);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(KMeansOutput {
            result,
            compute_time_us,
        })
    }
}

#[async_trait]
impl BatchKernel<KMeansInput, KMeansOutput> for KMeans {
    async fn execute(&self, input: KMeansInput) -> Result<KMeansOutput> {
        self.cluster_batch(input).await
    }
}

#[async_trait]
impl BatchKernel<DBSCANInput, DBSCANOutput> for DBSCAN {
    async fn execute(&self, input: DBSCANInput) -> Result<DBSCANOutput> {
        let start = Instant::now();
        let result = Self::compute(&input.data, input.eps, input.min_samples, input.metric);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(DBSCANOutput {
            result,
            compute_time_us,
        })
    }
}

#[async_trait]
impl BatchKernel<HierarchicalInput, HierarchicalOutput> for HierarchicalClustering {
    async fn execute(&self, input: HierarchicalInput) -> Result<HierarchicalOutput> {
        let start = Instant::now();
        let linkage_method = match input.linkage {
            Linkage::Single => LinkageMethod::Single,
            Linkage::Complete => LinkageMethod::Complete,
            Linkage::Average => LinkageMethod::Average,
            Linkage::Ward => LinkageMethod::Ward,
        };
        let result = Self::compute(&input.data, input.n_clusters, linkage_method, input.metric);
        let compute_time_us = start.elapsed().as_micros() as u64;

        Ok(HierarchicalOutput {
            result,
            compute_time_us,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated blobs of three points each.
    fn create_two_clusters() -> DataMatrix {
        DataMatrix::from_rows(&[
            &[0.0, 0.0],
            &[0.1, 0.1],
            &[0.2, 0.0],
            &[10.0, 10.0],
            &[10.1, 10.1],
            &[10.2, 10.0],
        ])
    }

    #[test]
    fn test_kmeans_metadata() {
        let kernel = KMeans::new();
        assert_eq!(kernel.metadata().id, "ml/kmeans-cluster");
        assert_eq!(kernel.metadata().domain, Domain::StatisticalML);
    }

    #[test]
    fn test_kmeans_two_clusters() {
        let data = create_two_clusters();
        let result = KMeans::compute(&data, 2, 100, 1e-6);

        assert_eq!(result.n_clusters, 2);
        assert!(result.converged);

        // Points within a blob share a label; the blobs get different labels.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[1], result.labels[2]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_eq!(result.labels[4], result.labels[5]);
        assert_ne!(result.labels[0], result.labels[3]);
    }

    #[test]
    fn test_dbscan_two_clusters() {
        let data = create_two_clusters();
        let result = DBSCAN::compute(&data, 1.0, 2, DistanceMetric::Euclidean);

        assert_eq!(result.n_clusters, 2);

        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_ne!(result.labels[0], result.labels[3]);
    }

    #[test]
    fn test_hierarchical_two_clusters() {
        let data = create_two_clusters();
        let result = HierarchicalClustering::compute(
            &data,
            2,
            LinkageMethod::Complete,
            DistanceMetric::Euclidean,
        );

        assert_eq!(result.n_clusters, 2);

        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[1], result.labels[2]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_ne!(result.labels[0], result.labels[3]);
    }
}