1use crate::{
19 error::{SslError, SslResult},
20 handle::LcgRng,
21};
22
/// Hyper-parameters for a single DeepCluster k-means run (see `deep_cluster`).
#[derive(Clone, Debug)]
pub struct DeepClusterConfig {
    /// Number of k-means clusters; `deep_cluster` requires `1 <= n_clusters <= n_samples`.
    pub n_clusters: usize,
    /// Target dimension for PCA whitening. Whitening is applied only when
    /// `0 < n_pca_components < feat_dim`; otherwise the raw features are clustered.
    pub n_pca_components: usize,
    /// Upper bound on the number of Lloyd (assign + update) iterations.
    pub kmeans_max_iter: usize,
    /// Convergence threshold on the fraction of samples whose label changed in an iteration.
    pub kmeans_tol: f64,
    /// When `true`, clusters that end up empty are re-seeded by splitting the largest cluster.
    pub reassign_empty: bool,
    /// Seed for the k-means RNG (k-means++ seeding and empty-cluster re-seeding).
    pub seed: u64,
}

impl Default for DeepClusterConfig {
    /// 1000 clusters on 256-d whitened features, up to 100 iterations, tolerance 1e-4,
    /// empty-cluster re-seeding enabled, seed 42.
    fn default() -> Self {
        Self {
            seed: 42,
            reassign_empty: true,
            kmeans_tol: 1e-4,
            kmeans_max_iter: 100,
            n_pca_components: 256,
            n_clusters: 1000,
        }
    }
}
54
55impl DeepClusterConfig {
56 pub fn new(
61 n_clusters: usize,
62 n_pca_components: usize,
63 kmeans_max_iter: usize,
64 kmeans_tol: f64,
65 reassign_empty: bool,
66 seed: u64,
67 ) -> SslResult<Self> {
68 if n_clusters == 0 {
69 return Err(SslError::InvalidParameter {
70 name: "n_clusters".to_string(),
71 reason: "must be >= 1".to_string(),
72 });
73 }
74 if kmeans_max_iter == 0 {
75 return Err(SslError::InvalidParameter {
76 name: "kmeans_max_iter".to_string(),
77 reason: "must be >= 1".to_string(),
78 });
79 }
80 Ok(Self {
81 n_clusters,
82 n_pca_components,
83 kmeans_max_iter,
84 kmeans_tol,
85 reassign_empty,
86 seed,
87 })
88 }
89}
90
/// Outcome of one `deep_cluster` run.
#[derive(Clone, Debug)]
pub struct DeepClusterResult {
    /// Cluster index assigned to each sample by the final assignment pass.
    pub labels: Vec<usize>,
    /// Centroids as a flattened row-major `n_clusters x dim` matrix. `dim` is the
    /// whitened dimension when PCA whitening ran, otherwise the input feature dimension.
    pub centroids: Vec<f64>,
    /// Sum of squared distances from each sample to its assigned centroid.
    pub inertia: f64,
    /// Number of Lloyd iterations that were executed.
    pub n_iter: usize,
    /// Whether the tolerance criterion stopped the loop before `kmeans_max_iter`.
    pub converged: bool,
    /// Number of labels changed by the final assignment pass.
    pub n_reassignments: usize,
    /// Number of clusters with no samples after the final assignment.
    pub empty_clusters: usize,
}
109
110#[derive(Debug, Clone)]
114pub struct DeeperClusterConfig {
115 pub cluster_scales: Vec<usize>,
117 pub base_config: DeepClusterConfig,
119}
120
121impl Default for DeeperClusterConfig {
122 fn default() -> Self {
123 Self {
124 cluster_scales: vec![100, 1000],
125 base_config: DeepClusterConfig::default(),
126 }
127 }
128}
129
130#[derive(Debug, Clone)]
132pub struct DeeperClusterResult {
133 pub per_scale: Vec<DeepClusterResult>,
135 pub multi_labels: Vec<Vec<usize>>,
137}
138
/// Sample covariance of `n` already-centred rows of width `d`.
///
/// The divisor is `n - 1`, clamped to at least 1 so a single sample does not divide
/// by zero. Returns a dense row-major `d x d` matrix. Only the upper triangle is
/// accumulated; the lower triangle is filled by mirroring afterwards.
fn compute_covariance(x_centered: &[f64], n: usize, d: usize) -> Vec<f64> {
    let denom = (n as f64 - 1.0).max(1.0);
    let scale = 1.0 / denom;
    let mut out = vec![0.0_f64; d * d];

    for r in 0..n {
        let sample = &x_centered[r * d..(r + 1) * d];
        for (a, &xa) in sample.iter().enumerate() {
            for (b, &xb) in sample.iter().enumerate().skip(a) {
                out[a * d + b] += xa * xb * scale;
            }
        }
    }

    // Mirror the upper triangle into the lower one (the matrix is symmetric).
    for a in 1..d {
        for b in 0..a {
            out[a * d + b] = out[b * d + a];
        }
    }
    out
}
162
/// Writes `a * v` into `out`, where `a` is a row-major `d x d` matrix.
///
/// Only the first `d` entries of `v` and `out` are used.
#[inline]
fn matvec(a: &[f64], v: &[f64], out: &mut [f64], d: usize) {
    for (row_idx, slot) in out[..d].iter_mut().enumerate() {
        let row = &a[row_idx * d..(row_idx + 1) * d];
        *slot = row
            .iter()
            .zip(&v[..d])
            .fold(0.0_f64, |acc, (m, x)| acc + m * x);
    }
}
174
/// Rescales `v` to unit Euclidean length in place and returns its original norm.
///
/// Vectors whose norm is not above `1e-12` (including the zero vector) are left
/// unchanged, so nothing is ever divided by a (near-)zero length.
fn l2_normalize_inplace(v: &mut [f64]) -> f64 {
    let squared: f64 = v.iter().map(|&x| x * x).sum();
    let length = squared.sqrt();
    if length > 1e-12 {
        v.iter_mut().for_each(|x| *x /= length);
    }
    length
}
185
/// Euclidean (L2) length of `v`.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let sum_sq: f64 = v.iter().map(|&c| c * c).sum();
    sum_sq.sqrt()
}
191
/// Estimates the dominant eigenpair of the row-major `d x d` matrix `cov` by power iteration.
///
/// The seed `init_vec` is first scaled to unit length (left as-is if its norm is
/// `<= 1e-12`). Each step computes `cov * v`, records the Rayleigh quotient `v . (cov v)`,
/// and replaces `v` with the normalised product. Iteration stops early when the product's
/// norm drops below `1e-14`.
///
/// Returns `(eigenvalue, eigenvector)`. The eigenvalue is the Rayleigh quotient of the
/// last iterate *before* its final update.
fn power_iteration(cov: &[f64], d: usize, init_vec: &[f64], n_iter: usize) -> (f64, Vec<f64>) {
    let mut current = init_vec.to_vec();
    let seed_len = current.iter().map(|x| x * x).sum::<f64>().sqrt();
    if seed_len > 1e-12 {
        current.iter_mut().for_each(|x| *x /= seed_len);
    }

    let mut product = vec![0.0_f64; d];
    let mut rayleigh = 0.0_f64;
    for _ in 0..n_iter {
        // product = cov * current
        for (r, slot) in product.iter_mut().enumerate() {
            let mut acc = 0.0_f64;
            for c in 0..d {
                acc += cov[r * d + c] * current[c];
            }
            *slot = acc;
        }
        rayleigh = product.iter().zip(current.iter()).map(|(p, q)| p * q).sum();

        let product_len = product.iter().map(|x| x * x).sum::<f64>().sqrt();
        if product_len < 1e-14 {
            break;
        }
        for (dst, &src) in current.iter_mut().zip(product.iter()) {
            *dst = src / product_len;
        }
    }
    (rayleigh, current)
}
212
/// Hotelling deflation: subtracts `eigenvalue * e e^T` from the row-major `d x d` matrix `cov`.
fn deflate(cov: &mut [f64], eigenvalue: f64, eigenvec: &[f64], d: usize) {
    for (i, &ei) in eigenvec[..d].iter().enumerate() {
        let row_scale = eigenvalue * ei;
        let row = &mut cov[i * d..(i + 1) * d];
        for (cell, &ej) in row.iter_mut().zip(&eigenvec[..d]) {
            *cell -= row_scale * ej;
        }
    }
}
221
222pub fn pca_whiten(
241 features: &[f64],
242 n_samples: usize,
243 feat_dim: usize,
244 n_components: usize,
245 eps: f64,
246) -> SslResult<Vec<f64>> {
247 if n_samples == 0 {
248 return Err(SslError::EmptyInput);
249 }
250 if feat_dim == 0 {
251 return Err(SslError::InvalidFeatureDim);
252 }
253 if n_components == 0 || n_components > feat_dim {
254 return Err(SslError::InvalidParameter {
255 name: "n_components".to_string(),
256 reason: format!("must be in [1, feat_dim={feat_dim}]"),
257 });
258 }
259 if features.len() != n_samples * feat_dim {
260 return Err(SslError::DimensionMismatch {
261 expected: n_samples * feat_dim,
262 got: features.len(),
263 });
264 }
265
266 let mut mean = vec![0.0_f64; feat_dim];
268 for i in 0..n_samples {
269 for j in 0..feat_dim {
270 mean[j] += features[i * feat_dim + j];
271 }
272 }
273 let inv_n = 1.0 / n_samples as f64;
274 for m in mean.iter_mut() {
275 *m *= inv_n;
276 }
277 let mut x_centered = features.to_vec();
278 for i in 0..n_samples {
279 for j in 0..feat_dim {
280 x_centered[i * feat_dim + j] -= mean[j];
281 }
282 }
283
284 let mut cov = compute_covariance(&x_centered, n_samples, feat_dim);
286
287 let power_iter_steps = 30_usize.max(n_components * 2);
289 let mut eigenvecs: Vec<Vec<f64>> = Vec::with_capacity(n_components);
290 let mut eigenvalues: Vec<f64> = Vec::with_capacity(n_components);
291
292 let mut init = vec![0.0_f64; feat_dim];
294 for (i, v) in init.iter_mut().enumerate() {
295 *v = ((i as f64 + 1.0) * 0.618_033_988).fract() * 2.0 - 1.0;
296 }
297
298 for k in 0..n_components {
299 let perturb = (k as f64 + 1.0) * 0.01;
301 let mut v_init: Vec<f64> = init
302 .iter()
303 .enumerate()
304 .map(|(i, &v)| v + perturb * ((i as f64 + k as f64 * 17.0).sin()))
305 .collect();
306 for ev in &eigenvecs {
308 let dot: f64 = v_init.iter().zip(ev.iter()).map(|(a, b)| a * b).sum();
309 for (vi, ei) in v_init.iter_mut().zip(ev.iter()) {
310 *vi -= dot * ei;
311 }
312 }
313 l2_normalize_inplace(&mut v_init);
314 let (lambda, eigvec) = power_iteration(&cov, feat_dim, &v_init, power_iter_steps);
315 let lambda_pos = lambda.max(0.0);
316 deflate(&mut cov, lambda, &eigvec, feat_dim);
317 eigenvecs.push(eigvec);
318 eigenvalues.push(lambda_pos);
319 }
320
321 let mut out = vec![0.0_f64; n_samples * n_components];
324 for i in 0..n_samples {
325 let xi = &x_centered[i * feat_dim..(i + 1) * feat_dim];
326 for k in 0..n_components {
327 let dot: f64 = xi.iter().zip(eigenvecs[k].iter()).map(|(a, b)| a * b).sum();
328 out[i * n_components + k] = dot / (eigenvalues[k] + eps).sqrt();
329 }
330 }
331 Ok(out)
332}
333
334fn kmeans_pp_init(
338 features: &[f64],
339 n_samples: usize,
340 d: usize,
341 k: usize,
342 rng: &mut LcgRng,
343) -> Vec<usize> {
344 let mut chosen = Vec::with_capacity(k);
345 chosen.push(rng.next_usize(n_samples));
347
348 let mut min_sq_dists = vec![f64::MAX; n_samples];
349
350 for c_idx in 1..k {
351 let last = chosen[c_idx - 1];
353 let c_row = &features[last * d..(last + 1) * d];
354 for i in 0..n_samples {
355 let xi = &features[i * d..(i + 1) * d];
356 let sq_dist = sq_dist_slices(xi, c_row);
357 if sq_dist < min_sq_dists[i] {
358 min_sq_dists[i] = sq_dist;
359 }
360 }
361 let total: f64 = min_sq_dists.iter().sum();
363 if total <= 0.0 {
364 chosen.push(rng.next_usize(n_samples));
366 continue;
367 }
368 let threshold = rng.next_f32() as f64 * total;
369 let mut cumsum = 0.0_f64;
370 let mut selected = n_samples - 1;
371 for (i, &dist) in min_sq_dists.iter().enumerate() {
372 cumsum += dist;
373 if cumsum >= threshold {
374 selected = i;
375 break;
376 }
377 }
378 chosen.push(selected);
379 }
380 chosen
381}
382
/// Squared Euclidean distance between `a` and `b`, over their common prefix.
#[inline]
fn sq_dist_slices(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
390
/// Assigns every sample to its nearest centroid by squared Euclidean distance.
///
/// `features` is row-major `n_samples x d` and `centroids` row-major `k x d`. Ties go to
/// the lowest cluster index. A sample whose distances are all NaN never finds a smaller
/// distance, so it falls back to cluster 0 with distance `f64::MAX`.
///
/// Returns `(new_labels, inertia, n_changed)`, where `n_changed` counts samples whose new
/// label differs from `labels`.
fn assign_step(
    features: &[f64],
    centroids: &[f64],
    labels: &[usize],
    n_samples: usize,
    d: usize,
    k: usize,
) -> (Vec<usize>, f64, usize) {
    let mut assignments = Vec::with_capacity(n_samples);
    let mut total = 0.0_f64;
    let mut moved = 0_usize;

    for idx in 0..n_samples {
        let point = &features[idx * d..(idx + 1) * d];
        let (winner, winner_dist) = (0..k).fold((0_usize, f64::MAX), |(bc, bd), c| {
            let centre = &centroids[c * d..(c + 1) * d];
            let dist: f64 = point
                .iter()
                .zip(centre.iter())
                .map(|(x, y)| (x - y) * (x - y))
                .sum();
            if dist < bd {
                (c, dist)
            } else {
                (bc, bd)
            }
        });
        assignments.push(winner);
        total += winner_dist;
        if winner != labels[idx] {
            moved += 1;
        }
    }
    (assignments, total, moved)
}
423
/// Recomputes each centroid as the mean of its member samples.
///
/// `features` is row-major `n_samples x d`. Clusters with no members keep an all-zero
/// row. Returns `(centroids, counts)`, with `centroids` row-major `k x d`.
fn update_step(
    features: &[f64],
    labels: &[usize],
    n_samples: usize,
    d: usize,
    k: usize,
) -> (Vec<f64>, Vec<usize>) {
    let mut sums = vec![0.0_f64; k * d];
    let mut sizes = vec![0_usize; k];

    for (idx, &cluster) in labels[..n_samples].iter().enumerate() {
        sizes[cluster] += 1;
        let point = &features[idx * d..(idx + 1) * d];
        for (acc, &x) in sums[cluster * d..(cluster + 1) * d].iter_mut().zip(point) {
            *acc += x;
        }
    }

    for (cluster, &size) in sizes.iter().enumerate() {
        if size == 0 {
            continue;
        }
        let inv = 1.0 / size as f64;
        sums[cluster * d..(cluster + 1) * d]
            .iter_mut()
            .for_each(|x| *x *= inv);
    }
    (sums, sizes)
}
453
/// Index of the cluster with the most members.
///
/// Ties go to the highest index. An empty slice yields 0.
fn largest_cluster(counts: &[usize]) -> usize {
    let mut best: Option<(usize, usize)> = None;
    for (idx, &size) in counts.iter().enumerate() {
        match best {
            Some((_, top)) if top > size => {}
            _ => best = Some((idx, size)),
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
463
464fn reassign_empty_clusters(
468 centroids: &mut [f64],
469 counts: &mut [usize],
470 features: &[f64],
471 labels: &mut [usize],
472 n_samples: usize,
473 d: usize,
474 k: usize,
475 rng: &mut LcgRng,
476) {
477 for c in 0..k {
478 if counts[c] == 0 {
479 let src = largest_cluster(counts);
480 let members: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == src).collect();
482 if members.is_empty() {
483 continue;
484 }
485 let rand_idx = members[rng.next_usize(members.len())];
486 let src_row = &features[rand_idx * d..(rand_idx + 1) * d];
488 for j in 0..d {
489 let perturb = 1e-6 * if j % 2 == 0 { 1.0 } else { -1.0 };
491 centroids[c * d + j] = src_row[j] + perturb;
492 }
493 for j in 0..d {
495 let perturb = 1e-6 * if j % 2 == 0 { -1.0 } else { 1.0 };
496 centroids[src * d + j] = features[rand_idx * d + j] + perturb;
497 }
498 counts[c] = 0; }
500 }
501}
502
503pub fn deep_cluster(
525 features: &[f64],
526 n_samples: usize,
527 feat_dim: usize,
528 config: &DeepClusterConfig,
529) -> SslResult<DeepClusterResult> {
530 if n_samples == 0 {
532 return Err(SslError::EmptyInput);
533 }
534 if feat_dim == 0 {
535 return Err(SslError::InvalidFeatureDim);
536 }
537 if config.n_clusters == 0 {
538 return Err(SslError::InvalidParameter {
539 name: "n_clusters".to_string(),
540 reason: "must be >= 1".to_string(),
541 });
542 }
543 if config.n_clusters > n_samples {
544 return Err(SslError::InvalidParameter {
545 name: "n_clusters".to_string(),
546 reason: format!(
547 "must be <= n_samples ({n_samples}), got {}",
548 config.n_clusters
549 ),
550 });
551 }
552 if features.len() != n_samples * feat_dim {
553 return Err(SslError::DimensionMismatch {
554 expected: n_samples * feat_dim,
555 got: features.len(),
556 });
557 }
558
559 let mut rng = LcgRng::new(config.seed);
560 let k = config.n_clusters;
561
562 let (work_features, work_dim) = if config.n_pca_components > 0
564 && config.n_pca_components < feat_dim
565 {
566 let whitened = pca_whiten(features, n_samples, feat_dim, config.n_pca_components, 1e-6)?;
567 let dim = config.n_pca_components;
568 (whitened, dim)
569 } else {
570 (features.to_vec(), feat_dim)
571 };
572
573 let init_indices = kmeans_pp_init(&work_features, n_samples, work_dim, k, &mut rng);
575 let mut centroids = vec![0.0_f64; k * work_dim];
576 for (c, &idx) in init_indices.iter().enumerate() {
577 centroids[c * work_dim..(c + 1) * work_dim]
578 .copy_from_slice(&work_features[idx * work_dim..(idx + 1) * work_dim]);
579 }
580
581 let mut labels = vec![0_usize; n_samples];
583 let mut n_iter = 0_usize;
584 let mut converged = false;
585 let mut final_n_reassignments = n_samples;
586
587 for iter in 0..config.kmeans_max_iter {
588 let (new_labels, _iter_inertia, n_changed) =
590 assign_step(&work_features, ¢roids, &labels, n_samples, work_dim, k);
591 final_n_reassignments = n_changed;
592 labels = new_labels;
593 n_iter = iter + 1;
594
595 let (new_centroids, mut counts) =
597 update_step(&work_features, &labels, n_samples, work_dim, k);
598 centroids = new_centroids;
599
600 if config.reassign_empty {
602 reassign_empty_clusters(
603 &mut centroids,
604 &mut counts,
605 &work_features,
606 &mut labels,
607 n_samples,
608 work_dim,
609 k,
610 &mut rng,
611 );
612 }
613
614 let frac_changed = n_changed as f64 / n_samples as f64;
616 if frac_changed <= config.kmeans_tol {
617 converged = true;
618 break;
619 }
620 }
621
622 let (final_labels, final_inertia, final_changed) =
624 assign_step(&work_features, ¢roids, &labels, n_samples, work_dim, k);
625 labels = final_labels;
626 if n_iter > 0 {
628 final_n_reassignments = final_changed;
629 }
630
631 let (_, final_counts) = update_step(&work_features, &labels, n_samples, work_dim, k);
632 let empty_clusters = final_counts.iter().filter(|&&c| c == 0).count();
633
634 Ok(DeepClusterResult {
635 labels,
636 centroids,
637 inertia: final_inertia,
638 n_iter,
639 converged,
640 n_reassignments: final_n_reassignments,
641 empty_clusters,
642 })
643}
644
645pub fn deeper_cluster(
657 features: &[f64],
658 n_samples: usize,
659 feat_dim: usize,
660 config: &DeeperClusterConfig,
661) -> SslResult<DeeperClusterResult> {
662 if config.cluster_scales.is_empty() {
663 return Err(SslError::InvalidParameter {
664 name: "cluster_scales".to_string(),
665 reason: "must contain at least one scale".to_string(),
666 });
667 }
668
669 let mut per_scale = Vec::with_capacity(config.cluster_scales.len());
670 let mut multi_labels = Vec::with_capacity(config.cluster_scales.len());
671
672 for (scale_idx, &n_clusters) in config.cluster_scales.iter().enumerate() {
673 let scale_seed = config
675 .base_config
676 .seed
677 .wrapping_add(scale_idx as u64 * 0x9e37_79b9_7f4a_7c15);
678
679 let scale_config = DeepClusterConfig {
680 n_clusters,
681 n_pca_components: config.base_config.n_pca_components,
682 kmeans_max_iter: config.base_config.kmeans_max_iter,
683 kmeans_tol: config.base_config.kmeans_tol,
684 reassign_empty: config.base_config.reassign_empty,
685 seed: scale_seed,
686 };
687
688 let result = deep_cluster(features, n_samples, feat_dim, &scale_config)?;
689 multi_labels.push(result.labels.clone());
690 per_scale.push(result);
691 }
692
693 Ok(DeeperClusterResult {
694 per_scale,
695 multi_labels,
696 })
697}
698
699pub fn deep_cluster_loss(
720 logits: &[f32],
721 pseudo_labels: &[usize],
722 n_samples: usize,
723 n_clusters: usize,
724) -> SslResult<f32> {
725 if n_samples == 0 {
726 return Err(SslError::EmptyInput);
727 }
728 if n_clusters < 2 {
729 return Err(SslError::NumPrototypesTooSmall);
730 }
731 if logits.len() != n_samples * n_clusters {
732 return Err(SslError::DimensionMismatch {
733 expected: n_samples * n_clusters,
734 got: logits.len(),
735 });
736 }
737 if pseudo_labels.len() != n_samples {
738 return Err(SslError::DimensionMismatch {
739 expected: n_samples,
740 got: pseudo_labels.len(),
741 });
742 }
743 for (i, &lbl) in pseudo_labels.iter().enumerate() {
744 if lbl >= n_clusters {
745 return Err(SslError::InvalidParameter {
746 name: format!("pseudo_labels[{i}]"),
747 reason: format!("label {lbl} >= n_clusters {n_clusters}"),
748 });
749 }
750 }
751
752 let mut total_loss = 0.0_f64;
753 for i in 0..n_samples {
754 let row = &logits[i * n_clusters..(i + 1) * n_clusters];
755 let max_v = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
757 let mut sum_exp = 0.0_f64;
758 let mut exps = Vec::with_capacity(n_clusters);
759 for &v in row {
760 let e = ((v - max_v) as f64).exp();
761 exps.push(e);
762 sum_exp += e;
763 }
764 let log_sum_exp = sum_exp.max(1e-300).ln();
765 let target_score = (row[pseudo_labels[i]] - max_v) as f64;
767 total_loss += log_sum_exp - target_score;
768 }
769
770 let loss = (total_loss / n_samples as f64) as f32;
771 if !loss.is_finite() {
772 return Err(SslError::NanEncountered {
773 location: "deep_cluster_loss",
774 });
775 }
776 Ok(loss)
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated 2-D blobs, flattened row-major: `n_per_cluster` points near
    /// (5, 0), followed by `n_per_cluster` points near (-5, 0).
    fn two_cluster_data(n_per_cluster: usize) -> Vec<f64> {
        let mut data = Vec::with_capacity(2 * n_per_cluster * 2);
        for i in 0..n_per_cluster {
            let offset = (i as f64) * 0.01;
            data.push(5.0 + offset);
            data.push(0.0 + offset);
        }
        for i in 0..n_per_cluster {
            let offset = (i as f64) * 0.01;
            data.push(-5.0 - offset);
            data.push(0.0 + offset);
        }
        data
    }

    /// With k = 2 on two separated blobs, both clusters receive samples.
    #[test]
    fn both_clusters_non_empty_on_separated_data() {
        let n_per = 20_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0, // 0 disables PCA whitening
            kmeans_max_iter: 100,
            kmeans_tol: 1e-5,
            reassign_empty: true,
            seed: 7,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        let mut count = [0_usize; 2];
        for &l in &result.labels {
            count[l] += 1;
        }
        assert!(count[0] > 0, "cluster 0 should be non-empty");
        assert!(count[1] > 0, "cluster 1 should be non-empty");
        assert_eq!(count[0] + count[1], n);
    }

    /// Easy, separated data converges well before the iteration cap.
    #[test]
    fn converges_before_max_iter_on_easy_data() {
        let n_per = 30_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0,
            kmeans_max_iter: 200,
            kmeans_tol: 1e-3,
            reassign_empty: true,
            seed: 13,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        assert!(
            result.converged,
            "should converge; n_iter = {}",
            result.n_iter
        );
        assert!(result.n_iter < 200, "n_iter = {}", result.n_iter);
    }

    /// One label is produced per sample.
    #[test]
    fn labels_length_equals_n_samples() {
        let n = 50_usize;
        let d = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| (i as f64) * 0.01).collect();
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 20,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 17,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.labels.len(), n);
    }

    /// Without whitening, centroids are `k x feat_dim`.
    #[test]
    fn centroids_shape_correct() {
        let n = 40_usize;
        let d = 6_usize;
        let k = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.17).sin()).collect();
        let config = DeepClusterConfig {
            n_clusters: k,
            n_pca_components: 0,
            kmeans_max_iter: 30,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 23,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.centroids.len(), k * d);
    }

    /// Cross-entropy is finite and never negative.
    #[test]
    fn loss_finite_and_non_negative() {
        let n = 8_usize;
        let k = 4_usize;
        let logits: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.1).collect();
        let labels = vec![0_usize, 1, 2, 3, 0, 1, 2, 3];
        let loss =
            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
        assert!(loss.is_finite(), "loss = {loss}");
        assert!(loss >= 0.0, "loss = {loss}");
    }

    /// Uniform logits give a uniform softmax, so the loss is ln(k).
    #[test]
    fn uniform_logits_give_ln_k_loss() {
        let n = 16_usize;
        let k = 8_usize;
        let logits = vec![0.0_f32; n * k]; // all-equal logits -> uniform softmax
        let labels: Vec<usize> = (0..n).map(|i| i % k).collect();
        let loss =
            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
        let expected = (k as f32).ln();
        assert!(
            (loss - expected).abs() < 1e-4,
            "loss = {loss}, expected = {expected}"
        );
    }

    /// Each scale yields one label per sample, in range for that scale's k.
    #[test]
    fn deeper_cluster_two_scales() {
        let n = 60_usize;
        let d = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.23).sin()).collect();
        let base = DeepClusterConfig {
            n_clusters: 2, // overridden per scale by deeper_cluster
            n_pca_components: 0,
            kmeans_max_iter: 20,
            kmeans_tol: 1e-3,
            reassign_empty: true,
            seed: 31,
        };
        let config = DeeperClusterConfig {
            cluster_scales: vec![2, 3],
            base_config: base,
        };
        let result =
            deeper_cluster(&features, n, d, &config).expect("deeper_cluster should succeed");
        assert_eq!(result.per_scale.len(), 2);
        assert_eq!(result.multi_labels.len(), 2);
        assert_eq!(result.multi_labels[0].len(), n);
        assert_eq!(result.multi_labels[1].len(), n);
        for &lbl in &result.multi_labels[0] {
            assert!(lbl < 2, "scale-0 label {lbl} out of range");
        }
        for &lbl in &result.multi_labels[1] {
            assert!(lbl < 3, "scale-1 label {lbl} out of range");
        }
    }

    /// Whitened columns are zero-mean with (approximately) unit variance.
    ///
    /// The two features have clearly different variances (scales 2 and 1), so the
    /// eigenvalues are well separated and power iteration converges tightly. The expected
    /// variance is lambda / (lambda + eps), which is within 1e-6 of 1.
    #[test]
    fn pca_whiten_output_unit_variance_columns() {
        let n = 200_usize;
        let d = 2_usize;
        let mut features = Vec::with_capacity(n * d);
        for i in 0..n {
            let t = i as f64;
            features.push(2.0 * (t * 0.031).sin()); // larger-variance feature
            features.push(1.0 * (t * 0.073).cos()); // smaller-variance feature
        }
        let n_comp = 2_usize;
        let whitened =
            pca_whiten(&features, n, d, n_comp, 1e-6).expect("pca_whiten should succeed");
        assert_eq!(whitened.len(), n * n_comp);
        for col in 0..n_comp {
            let mean: f64 = whitened.iter().skip(col).step_by(n_comp).sum::<f64>() / n as f64;
            let var: f64 = whitened
                .iter()
                .skip(col)
                .step_by(n_comp)
                .map(|&v| (v - mean) * (v - mean))
                .sum::<f64>()
                / (n as f64 - 1.0);
            assert!(mean.abs() < 1e-9, "col {col} mean = {mean} should be ~0");
            assert!(
                var.is_finite() && (var - 1.0).abs() < 1e-3,
                "col {col} variance = {var} should be ~1"
            );
        }
    }

    /// All-identical points force empty clusters; re-seeding must not panic.
    #[test]
    fn empty_cluster_reassignment_does_not_crash() {
        let n = 10_usize;
        let d = 2_usize;
        let features = vec![1.0_f64; n * d]; // every sample identical
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 10,
            kmeans_tol: 0.0, // require zero label changes to converge
            reassign_empty: true,
            seed: 37,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.labels.len(), n);
    }

    /// Requesting more clusters than samples is rejected.
    #[test]
    fn error_on_more_clusters_than_samples() {
        let n = 5_usize;
        let d = 2_usize;
        let features = vec![1.0_f64; n * d];
        let config = DeepClusterConfig {
            n_clusters: 10, // more clusters than samples
            n_pca_components: 0,
            kmeans_max_iter: 10,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 41,
        };
        assert!(deep_cluster(&features, n, d, &config).is_err());
    }

    /// The validating constructor rejects `n_clusters == 0`.
    #[test]
    fn error_on_zero_clusters() {
        let result = DeepClusterConfig::new(0, 0, 10, 1e-4, true, 42);
        assert!(result.is_err(), "n_clusters=0 should return an error");
    }

    /// Inertia is a sum of squared distances, so it is finite and non-negative.
    #[test]
    fn inertia_non_negative_and_finite() {
        let n = 50_usize;
        let d = 3_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.11).sin()).collect();
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 50,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 53,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert!(result.inertia.is_finite(), "inertia = {}", result.inertia);
        assert!(result.inertia >= 0.0, "inertia = {}", result.inertia);
    }

    /// A loose tolerance on stable data reports convergence.
    #[test]
    fn converged_true_when_stable() {
        let n_per = 20_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0,
            kmeans_max_iter: 500,
            kmeans_tol: 0.01, // up to 1% of labels may change
            reassign_empty: true,
            seed: 61,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        assert!(result.converged, "should have converged");
    }

    /// A pseudo-label equal to `n_clusters` is rejected.
    #[test]
    fn loss_rejects_out_of_range_label() {
        let n = 4_usize;
        let k = 3_usize;
        let logits = vec![0.0_f32; n * k];
        let labels = vec![0_usize, 1, 2, 3]; // 3 is out of range for k = 3
        assert!(deep_cluster_loss(&logits, &labels, n, k).is_err());
    }

    /// `n_components` must lie in `[1, feat_dim]`.
    #[test]
    fn pca_whiten_rejects_invalid_n_components() {
        let n = 10_usize;
        let d = 4_usize;
        let features = vec![1.0_f64; n * d];
        // zero components
        assert!(pca_whiten(&features, n, d, 0, 1e-6).is_err());
        // more components than feature dimensions
        assert!(pca_whiten(&features, n, d, d + 1, 1e-6).is_err());
    }

    /// The loss needs at least two clusters.
    #[test]
    fn loss_rejects_single_cluster() {
        let logits = vec![1.0_f32; 4];
        let labels = vec![0_usize; 4];
        assert!(deep_cluster_loss(&logits, &labels, 4, 1).is_err());
    }
}