Skip to main content

oxicuda_ssl/clustering/
deep_cluster.rs

1//! DeepCluster — Caron et al. 2018 — Unsupervised Learning of Visual Features
2//! by Clustering with Convolutions.
3//!
4//! This module implements the CPU-side components of DeepCluster and its
5//! hierarchical extension DeeperCluster:
6//!
7//! 1. **PCA whitening** — removes dominant directions from feature space.
8//! 2. **k-means clustering** with k-means++ initialisation and empty-cluster
9//!    reassignment, operating on (optionally whitened) L2-normalised features.
10//! 3. **Pseudo-label assignment** — cluster id per sample for cross-entropy
11//!    supervision.
12//! 4. **DeeperCluster** — multi-scale hierarchical clustering, yielding one set
13//!    of pseudo-labels per scale.
14//!
15//! Reference: Caron et al., *Deep Clustering for Unsupervised Learning of
16//! Visual Features*, ECCV 2018.
17
18use crate::{
19    error::{SslError, SslResult},
20    handle::LcgRng,
21};
22
23// ─── Configuration ────────────────────────────────────────────────────────────
24
/// Tunable parameters for the CPU-side DeepCluster pipeline.
#[derive(Debug, Clone)]
pub struct DeepClusterConfig {
    /// How many clusters (`k`) k-means should produce. Default: 1000.
    pub n_clusters: usize,
    /// Dimensionality of the PCA-whitening projection; `0` turns whitening
    /// off. Default: 256.
    pub n_pca_components: usize,
    /// Upper bound on the number of k-means iterations. Default: 100.
    pub kmeans_max_iter: usize,
    /// Convergence threshold, expressed as the fraction of samples whose
    /// assignment changed during an iteration. Default: 1e-4.
    pub kmeans_tol: f64,
    /// When `true`, clusters that end up empty are re-seeded so the solution
    /// does not degenerate. Default: true.
    pub reassign_empty: bool,
    /// Seed for the deterministic LCG random number generator. Default: 42.
    pub seed: u64,
}
41
42impl Default for DeepClusterConfig {
43    fn default() -> Self {
44        Self {
45            n_clusters: 1000,
46            n_pca_components: 256,
47            kmeans_max_iter: 100,
48            kmeans_tol: 1e-4,
49            reassign_empty: true,
50            seed: 42,
51        }
52    }
53}
54
55impl DeepClusterConfig {
56    /// Construct a validated DeepClusterConfig.
57    ///
58    /// # Errors
59    /// - [`SslError::InvalidParameter`] when `n_clusters == 0` or `kmeans_max_iter == 0`.
60    pub fn new(
61        n_clusters: usize,
62        n_pca_components: usize,
63        kmeans_max_iter: usize,
64        kmeans_tol: f64,
65        reassign_empty: bool,
66        seed: u64,
67    ) -> SslResult<Self> {
68        if n_clusters == 0 {
69            return Err(SslError::InvalidParameter {
70                name: "n_clusters".to_string(),
71                reason: "must be >= 1".to_string(),
72            });
73        }
74        if kmeans_max_iter == 0 {
75            return Err(SslError::InvalidParameter {
76                name: "kmeans_max_iter".to_string(),
77                reason: "must be >= 1".to_string(),
78            });
79        }
80        Ok(Self {
81            n_clusters,
82            n_pca_components,
83            kmeans_max_iter,
84            kmeans_tol,
85            reassign_empty,
86            seed,
87        })
88    }
89}
90
/// Everything produced by a single DeepCluster run.
#[derive(Debug, Clone)]
pub struct DeepClusterResult {
    /// Cluster id assigned to each sample; one entry per sample
    /// (`n_samples` in total).
    pub labels: Vec<usize>,
    /// Row-major `[k × d]` matrix of cluster centroids.
    pub centroids: Vec<f64>,
    /// Inertia: summed squared distance from every sample to the centroid it
    /// was assigned to.
    pub inertia: f64,
    /// How many k-means iterations were actually executed.
    pub n_iter: usize,
    /// `true` when the run stopped on the tolerance test rather than by
    /// exhausting `kmeans_max_iter`.
    pub converged: bool,
    /// How many samples changed cluster in the final iteration.
    pub n_reassignments: usize,
    /// How many clusters hold no samples after the final iteration.
    pub empty_clusters: usize,
}
109
110// ─── DeeperCluster configuration and result ───────────────────────────────────
111
112/// Configuration for DeeperCluster (multi-scale hierarchical clustering).
113#[derive(Debug, Clone)]
114pub struct DeeperClusterConfig {
115    /// Cluster counts per scale, e.g. `[100, 1000, 10000]`.
116    pub cluster_scales: Vec<usize>,
117    /// Base DeepCluster config shared across all scales (except `n_clusters`).
118    pub base_config: DeepClusterConfig,
119}
120
121impl Default for DeeperClusterConfig {
122    fn default() -> Self {
123        Self {
124            cluster_scales: vec![100, 1000],
125            base_config: DeepClusterConfig::default(),
126        }
127    }
128}
129
130/// Output of a DeeperCluster run.
131#[derive(Debug, Clone)]
132pub struct DeeperClusterResult {
133    /// One [`DeepClusterResult`] per scale in `cluster_scales`.
134    pub per_scale: Vec<DeepClusterResult>,
135    /// `[n_scales][n_samples]` pseudo-labels — one label list per scale.
136    pub multi_labels: Vec<Vec<usize>>,
137}
138
139// ─── PCA whitening ────────────────────────────────────────────────────────────
140
/// Build the `[d × d]` covariance matrix of already-centred data
/// `X ∈ ℝ^{n × d}` (row-major, one sample per row).
///
/// Each outer product is scaled by `1 / max(n - 1, 1)`. Only the upper
/// triangle is accumulated; it is then copied into the lower triangle so the
/// returned matrix is symmetric.
fn compute_covariance(x_centered: &[f64], n: usize, d: usize) -> Vec<f64> {
    let scale = 1.0 / (n as f64 - 1.0).max(1.0);
    let mut cov = vec![0.0_f64; d * d];

    for sample in 0..n {
        let values = &x_centered[sample * d..(sample + 1) * d];
        for (i, &a) in values.iter().enumerate() {
            let upper = &mut cov[i * d + i..(i + 1) * d];
            for (slot, &b) in upper.iter_mut().zip(&values[i..]) {
                *slot += a * b * scale;
            }
        }
    }

    // Reflect the accumulated upper triangle across the diagonal.
    for i in 0..d {
        for j in (i + 1)..d {
            cov[j * d + i] = cov[i * d + j];
        }
    }
    cov
}
162
/// Write the matrix–vector product `A·v` into `out`.
///
/// `a` is read as a row-major `[d × d]` matrix; the first `d` entries of `v`
/// and `out` are used.
#[inline]
fn matvec(a: &[f64], v: &[f64], out: &mut [f64], d: usize) {
    let vector = &v[..d];
    for (i, slot) in out[..d].iter_mut().enumerate() {
        let row = &a[i * d..(i + 1) * d];
        *slot = row
            .iter()
            .zip(vector)
            .fold(0.0_f64, |acc, (&m, &x)| acc + m * x);
    }
}
174
/// Scale `v` to unit Euclidean length in place and return the length it had
/// before scaling.
///
/// Vectors whose norm does not exceed `1e-12` are left untouched.
fn l2_normalize_inplace(v: &mut [f64]) -> f64 {
    let sum_sq: f64 = v.iter().map(|&c| c * c).sum();
    let norm = sum_sq.sqrt();
    if norm > 1e-12 {
        v.iter_mut().for_each(|c| *c /= norm);
    }
    norm
}
185
/// Euclidean (L2) length of `v`.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let sum_sq: f64 = v.iter().map(|&c| c * c).sum();
    sum_sq.sqrt()
}
191
192/// Power iteration to find the dominant eigenvector of a symmetric matrix.
193/// Initialises with `init_vec` and runs `n_iter` steps.
194fn power_iteration(cov: &[f64], d: usize, init_vec: &[f64], n_iter: usize) -> (f64, Vec<f64>) {
195    let mut v = init_vec.to_vec();
196    l2_normalize_inplace(&mut v);
197    let mut av = vec![0.0_f64; d];
198    let mut eigenvalue = 0.0_f64;
199    for _ in 0..n_iter {
200        matvec(cov, &v, &mut av, d);
201        eigenvalue = av.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
202        let norm = l2_norm(&av);
203        if norm < 1e-14 {
204            break;
205        }
206        for i in 0..d {
207            v[i] = av[i] / norm;
208        }
209    }
210    (eigenvalue, v)
211}
212
/// Remove one eigen-component from a `[d × d]` row-major matrix in place:
/// `cov -= λ v vᵀ`.
fn deflate(cov: &mut [f64], eigenvalue: f64, eigenvec: &[f64], d: usize) {
    let v = &eigenvec[..d];
    for (i, &vi) in v.iter().enumerate() {
        let row = &mut cov[i * d..(i + 1) * d];
        for (entry, &vj) in row.iter_mut().zip(v) {
            *entry -= eigenvalue * vi * vj;
        }
    }
}
221
222/// PCA whitening: project `X` onto the top `n_components` principal directions,
223/// then divide each projected dimension by `sqrt(eigenvalue + eps)`.
224///
225/// # Arguments
226/// * `features`    — `[n_samples × feat_dim]` row-major input features.
227/// * `n_samples`   — number of data points.
228/// * `feat_dim`    — feature dimensionality.
229/// * `n_components`— number of principal components to retain.
230/// * `eps`         — regularisation added to eigenvalues before sqrt (prevents /0).
231///
232/// # Returns
233/// `[n_samples × n_components]` whitened and projected feature matrix.
234///
235/// # Errors
236/// - [`SslError::EmptyInput`] if `n_samples == 0`.
237/// - [`SslError::InvalidFeatureDim`] if `feat_dim == 0`.
238/// - [`SslError::InvalidParameter`] if `n_components == 0` or `> feat_dim`.
239/// - [`SslError::DimensionMismatch`] if `features.len() != n_samples * feat_dim`.
240pub fn pca_whiten(
241    features: &[f64],
242    n_samples: usize,
243    feat_dim: usize,
244    n_components: usize,
245    eps: f64,
246) -> SslResult<Vec<f64>> {
247    if n_samples == 0 {
248        return Err(SslError::EmptyInput);
249    }
250    if feat_dim == 0 {
251        return Err(SslError::InvalidFeatureDim);
252    }
253    if n_components == 0 || n_components > feat_dim {
254        return Err(SslError::InvalidParameter {
255            name: "n_components".to_string(),
256            reason: format!("must be in [1, feat_dim={feat_dim}]"),
257        });
258    }
259    if features.len() != n_samples * feat_dim {
260        return Err(SslError::DimensionMismatch {
261            expected: n_samples * feat_dim,
262            got: features.len(),
263        });
264    }
265
266    // Center the data.
267    let mut mean = vec![0.0_f64; feat_dim];
268    for i in 0..n_samples {
269        for j in 0..feat_dim {
270            mean[j] += features[i * feat_dim + j];
271        }
272    }
273    let inv_n = 1.0 / n_samples as f64;
274    for m in mean.iter_mut() {
275        *m *= inv_n;
276    }
277    let mut x_centered = features.to_vec();
278    for i in 0..n_samples {
279        for j in 0..feat_dim {
280            x_centered[i * feat_dim + j] -= mean[j];
281        }
282    }
283
284    // Covariance matrix.
285    let mut cov = compute_covariance(&x_centered, n_samples, feat_dim);
286
287    // Deflated power iteration: extract top-n_components eigenpairs.
288    let power_iter_steps = 30_usize.max(n_components * 2);
289    let mut eigenvecs: Vec<Vec<f64>> = Vec::with_capacity(n_components);
290    let mut eigenvalues: Vec<f64> = Vec::with_capacity(n_components);
291
292    // Initialise first eigenvector from a deterministic vector to avoid RNG.
293    let mut init = vec![0.0_f64; feat_dim];
294    for (i, v) in init.iter_mut().enumerate() {
295        *v = ((i as f64 + 1.0) * 0.618_033_988).fract() * 2.0 - 1.0;
296    }
297
298    for k in 0..n_components {
299        // Perturb init slightly per component.
300        let perturb = (k as f64 + 1.0) * 0.01;
301        let mut v_init: Vec<f64> = init
302            .iter()
303            .enumerate()
304            .map(|(i, &v)| v + perturb * ((i as f64 + k as f64 * 17.0).sin()))
305            .collect();
306        // Orthogonalise against previously found eigenvectors (classical Gram-Schmidt).
307        for ev in &eigenvecs {
308            let dot: f64 = v_init.iter().zip(ev.iter()).map(|(a, b)| a * b).sum();
309            for (vi, ei) in v_init.iter_mut().zip(ev.iter()) {
310                *vi -= dot * ei;
311            }
312        }
313        l2_normalize_inplace(&mut v_init);
314        let (lambda, eigvec) = power_iteration(&cov, feat_dim, &v_init, power_iter_steps);
315        let lambda_pos = lambda.max(0.0);
316        deflate(&mut cov, lambda, &eigvec, feat_dim);
317        eigenvecs.push(eigvec);
318        eigenvalues.push(lambda_pos);
319    }
320
321    // Project X_centered onto eigenvecs and whiten.
322    // eigenvecs[k] is a d-dimensional row; projection: z[i, k] = dot(x_centered[i], ev[k])
323    let mut out = vec![0.0_f64; n_samples * n_components];
324    for i in 0..n_samples {
325        let xi = &x_centered[i * feat_dim..(i + 1) * feat_dim];
326        for k in 0..n_components {
327            let dot: f64 = xi.iter().zip(eigenvecs[k].iter()).map(|(a, b)| a * b).sum();
328            out[i * n_components + k] = dot / (eigenvalues[k] + eps).sqrt();
329        }
330    }
331    Ok(out)
332}
333
334// ─── k-means++ initialisation ─────────────────────────────────────────────────
335
336/// D² sampling: returns indices of the initial `k` centroids.
337fn kmeans_pp_init(
338    features: &[f64],
339    n_samples: usize,
340    d: usize,
341    k: usize,
342    rng: &mut LcgRng,
343) -> Vec<usize> {
344    let mut chosen = Vec::with_capacity(k);
345    // First centroid: uniform random.
346    chosen.push(rng.next_usize(n_samples));
347
348    let mut min_sq_dists = vec![f64::MAX; n_samples];
349
350    for c_idx in 1..k {
351        // Update min distances to nearest chosen centroid.
352        let last = chosen[c_idx - 1];
353        let c_row = &features[last * d..(last + 1) * d];
354        for i in 0..n_samples {
355            let xi = &features[i * d..(i + 1) * d];
356            let sq_dist = sq_dist_slices(xi, c_row);
357            if sq_dist < min_sq_dists[i] {
358                min_sq_dists[i] = sq_dist;
359            }
360        }
361        // Weighted random selection proportional to D².
362        let total: f64 = min_sq_dists.iter().sum();
363        if total <= 0.0 {
364            // All points on top of the already-chosen centroids: fall back to random.
365            chosen.push(rng.next_usize(n_samples));
366            continue;
367        }
368        let threshold = rng.next_f32() as f64 * total;
369        let mut cumsum = 0.0_f64;
370        let mut selected = n_samples - 1;
371        for (i, &dist) in min_sq_dists.iter().enumerate() {
372            cumsum += dist;
373            if cumsum >= threshold {
374                selected = i;
375                break;
376            }
377        }
378        chosen.push(selected);
379    }
380    chosen
381}
382
383// ─── k-means internals ────────────────────────────────────────────────────────
384
/// Squared Euclidean distance between `a` and `b`, summed over the positions
/// the two slices have in common.
#[inline]
fn sq_dist_slices(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&p, &q)| {
            let diff = p - q;
            diff * diff
        })
        .sum()
}
390
391/// Assign each sample to its nearest centroid.
392/// Returns `(labels, inertia, n_changed)`.
393fn assign_step(
394    features: &[f64],
395    centroids: &[f64],
396    labels: &[usize],
397    n_samples: usize,
398    d: usize,
399    k: usize,
400) -> (Vec<usize>, f64, usize) {
401    let mut new_labels = vec![0_usize; n_samples];
402    let mut inertia = 0.0_f64;
403    let mut n_changed = 0_usize;
404    for i in 0..n_samples {
405        let xi = &features[i * d..(i + 1) * d];
406        let mut best_dist = f64::MAX;
407        let mut best_c = 0_usize;
408        for c in 0..k {
409            let dist = sq_dist_slices(xi, &centroids[c * d..(c + 1) * d]);
410            if dist < best_dist {
411                best_dist = dist;
412                best_c = c;
413            }
414        }
415        new_labels[i] = best_c;
416        inertia += best_dist;
417        if best_c != labels[i] {
418            n_changed += 1;
419        }
420    }
421    (new_labels, inertia, n_changed)
422}
423
/// k-means update step: recompute each centroid as the mean of the samples
/// currently labelled with it.
///
/// Returns the row-major `[k × d]` centroid matrix together with the number
/// of samples per cluster. A cluster with no samples keeps an all-zero row.
fn update_step(
    features: &[f64],
    labels: &[usize],
    n_samples: usize,
    d: usize,
    k: usize,
) -> (Vec<f64>, Vec<usize>) {
    let mut sums = vec![0.0_f64; k * d];
    let mut sizes = vec![0_usize; k];

    for (i, &cluster) in labels[..n_samples].iter().enumerate() {
        sizes[cluster] += 1;
        let point = &features[i * d..(i + 1) * d];
        let acc = &mut sums[cluster * d..(cluster + 1) * d];
        for (a, &x) in acc.iter_mut().zip(point) {
            *a += x;
        }
    }

    // Turn per-cluster sums into means; empty clusters are skipped.
    for (c, &size) in sizes.iter().enumerate() {
        if size == 0 {
            continue;
        }
        let inv = 1.0 / size as f64;
        for value in &mut sums[c * d..(c + 1) * d] {
            *value *= inv;
        }
    }
    (sums, sizes)
}
453
/// Index of the cluster holding the most samples.
///
/// When several clusters share the maximum, the one with the highest index is
/// returned; an empty `counts` yields `0`.
fn largest_cluster(counts: &[usize]) -> usize {
    let mut winner = 0_usize;
    let mut top = 0_usize;
    for (idx, &count) in counts.iter().enumerate() {
        // `>=` lets later clusters take over on ties.
        if count >= top {
            winner = idx;
            top = count;
        }
    }
    winner
}
463
464/// Reassign empty clusters: place an empty cluster's centroid at a random member
465/// of the largest cluster, perturbed slightly, then update the largest cluster
466/// centroid.
467fn reassign_empty_clusters(
468    centroids: &mut [f64],
469    counts: &mut [usize],
470    features: &[f64],
471    labels: &mut [usize],
472    n_samples: usize,
473    d: usize,
474    k: usize,
475    rng: &mut LcgRng,
476) {
477    for c in 0..k {
478        if counts[c] == 0 {
479            let src = largest_cluster(counts);
480            // Pick a random sample from the source cluster.
481            let members: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == src).collect();
482            if members.is_empty() {
483                continue;
484            }
485            let rand_idx = members[rng.next_usize(members.len())];
486            // Perturb the picked sample slightly by ±1e-6 in first dimension.
487            let src_row = &features[rand_idx * d..(rand_idx + 1) * d];
488            for j in 0..d {
489                // Small perturbation alternating sign per dimension.
490                let perturb = 1e-6 * if j % 2 == 0 { 1.0 } else { -1.0 };
491                centroids[c * d + j] = src_row[j] + perturb;
492            }
493            // Also nudge the source centroid slightly.
494            for j in 0..d {
495                let perturb = 1e-6 * if j % 2 == 0 { -1.0 } else { 1.0 };
496                centroids[src * d + j] = features[rand_idx * d + j] + perturb;
497            }
498            counts[c] = 0; // Will be picked up at next assignment.
499        }
500    }
501}
502
503// ─── Public API ───────────────────────────────────────────────────────────────
504
505/// Run k-means clustering (DeepCluster pipeline) on pre-normalised features.
506///
507/// If `config.n_pca_components > 0`, the input features are first whitened via
508/// [`pca_whiten`] before clustering (using a small `eps = 1e-6` for numerical
509/// stability). The returned centroids are in the PCA-whitened space when PCA is
510/// applied.
511///
512/// # Arguments
513/// * `features`  — `[n_samples × feat_dim]` row-major, ideally L2-normalised.
514/// * `n_samples` — number of data points.
515/// * `feat_dim`  — feature dimensionality.
516/// * `config`    — DeepCluster parameters.
517///
518/// # Errors
519/// - [`SslError::EmptyInput`] if `n_samples == 0`.
520/// - [`SslError::InvalidFeatureDim`] if `feat_dim == 0`.
521/// - [`SslError::InvalidParameter`] if `n_clusters == 0` or
522///   `n_clusters > n_samples`.
523/// - [`SslError::DimensionMismatch`] on length mismatch.
524pub fn deep_cluster(
525    features: &[f64],
526    n_samples: usize,
527    feat_dim: usize,
528    config: &DeepClusterConfig,
529) -> SslResult<DeepClusterResult> {
530    // ── Validation ────────────────────────────────────────────────────────────
531    if n_samples == 0 {
532        return Err(SslError::EmptyInput);
533    }
534    if feat_dim == 0 {
535        return Err(SslError::InvalidFeatureDim);
536    }
537    if config.n_clusters == 0 {
538        return Err(SslError::InvalidParameter {
539            name: "n_clusters".to_string(),
540            reason: "must be >= 1".to_string(),
541        });
542    }
543    if config.n_clusters > n_samples {
544        return Err(SslError::InvalidParameter {
545            name: "n_clusters".to_string(),
546            reason: format!(
547                "must be <= n_samples ({n_samples}), got {}",
548                config.n_clusters
549            ),
550        });
551    }
552    if features.len() != n_samples * feat_dim {
553        return Err(SslError::DimensionMismatch {
554            expected: n_samples * feat_dim,
555            got: features.len(),
556        });
557    }
558
559    let mut rng = LcgRng::new(config.seed);
560    let k = config.n_clusters;
561
562    // ── Optional PCA whitening ────────────────────────────────────────────────
563    let (work_features, work_dim) = if config.n_pca_components > 0
564        && config.n_pca_components < feat_dim
565    {
566        let whitened = pca_whiten(features, n_samples, feat_dim, config.n_pca_components, 1e-6)?;
567        let dim = config.n_pca_components;
568        (whitened, dim)
569    } else {
570        (features.to_vec(), feat_dim)
571    };
572
573    // ── k-means++ initialisation ──────────────────────────────────────────────
574    let init_indices = kmeans_pp_init(&work_features, n_samples, work_dim, k, &mut rng);
575    let mut centroids = vec![0.0_f64; k * work_dim];
576    for (c, &idx) in init_indices.iter().enumerate() {
577        centroids[c * work_dim..(c + 1) * work_dim]
578            .copy_from_slice(&work_features[idx * work_dim..(idx + 1) * work_dim]);
579    }
580
581    // ── k-means iterations ────────────────────────────────────────────────────
582    let mut labels = vec![0_usize; n_samples];
583    let mut n_iter = 0_usize;
584    let mut converged = false;
585    let mut final_n_reassignments = n_samples;
586
587    for iter in 0..config.kmeans_max_iter {
588        // Assignment step.
589        let (new_labels, _iter_inertia, n_changed) =
590            assign_step(&work_features, &centroids, &labels, n_samples, work_dim, k);
591        final_n_reassignments = n_changed;
592        labels = new_labels;
593        n_iter = iter + 1;
594
595        // Update step.
596        let (new_centroids, mut counts) =
597            update_step(&work_features, &labels, n_samples, work_dim, k);
598        centroids = new_centroids;
599
600        // Empty-cluster reassignment.
601        if config.reassign_empty {
602            reassign_empty_clusters(
603                &mut centroids,
604                &mut counts,
605                &work_features,
606                &mut labels,
607                n_samples,
608                work_dim,
609                k,
610                &mut rng,
611            );
612        }
613
614        // Convergence check.
615        let frac_changed = n_changed as f64 / n_samples as f64;
616        if frac_changed <= config.kmeans_tol {
617            converged = true;
618            break;
619        }
620    }
621
622    // Final assignment to recompute accurate inertia and empty-cluster count.
623    let (final_labels, final_inertia, final_changed) =
624        assign_step(&work_features, &centroids, &labels, n_samples, work_dim, k);
625    labels = final_labels;
626    // Only update reassignment count for the last pass if we ran at least one iter.
627    if n_iter > 0 {
628        final_n_reassignments = final_changed;
629    }
630
631    let (_, final_counts) = update_step(&work_features, &labels, n_samples, work_dim, k);
632    let empty_clusters = final_counts.iter().filter(|&&c| c == 0).count();
633
634    Ok(DeepClusterResult {
635        labels,
636        centroids,
637        inertia: final_inertia,
638        n_iter,
639        converged,
640        n_reassignments: final_n_reassignments,
641        empty_clusters,
642    })
643}
644
645/// Run DeeperCluster — hierarchical multi-scale clustering.
646///
647/// Applies [`deep_cluster`] independently at each scale in
648/// `config.cluster_scales`, collecting one set of pseudo-labels per scale.
649/// Each scale uses `base_config` except with `n_clusters` overridden to the
650/// scale value. A unique per-scale seed is derived from `base_config.seed`.
651///
652/// # Errors
653/// Propagates all errors from [`deep_cluster`].
654/// - Additionally returns [`SslError::InvalidParameter`] if `cluster_scales` is
655///   empty.
656pub fn deeper_cluster(
657    features: &[f64],
658    n_samples: usize,
659    feat_dim: usize,
660    config: &DeeperClusterConfig,
661) -> SslResult<DeeperClusterResult> {
662    if config.cluster_scales.is_empty() {
663        return Err(SslError::InvalidParameter {
664            name: "cluster_scales".to_string(),
665            reason: "must contain at least one scale".to_string(),
666        });
667    }
668
669    let mut per_scale = Vec::with_capacity(config.cluster_scales.len());
670    let mut multi_labels = Vec::with_capacity(config.cluster_scales.len());
671
672    for (scale_idx, &n_clusters) in config.cluster_scales.iter().enumerate() {
673        // Derive a unique seed per scale by mixing base seed with scale index.
674        let scale_seed = config
675            .base_config
676            .seed
677            .wrapping_add(scale_idx as u64 * 0x9e37_79b9_7f4a_7c15);
678
679        let scale_config = DeepClusterConfig {
680            n_clusters,
681            n_pca_components: config.base_config.n_pca_components,
682            kmeans_max_iter: config.base_config.kmeans_max_iter,
683            kmeans_tol: config.base_config.kmeans_tol,
684            reassign_empty: config.base_config.reassign_empty,
685            seed: scale_seed,
686        };
687
688        let result = deep_cluster(features, n_samples, feat_dim, &scale_config)?;
689        multi_labels.push(result.labels.clone());
690        per_scale.push(result);
691    }
692
693    Ok(DeeperClusterResult {
694        per_scale,
695        multi_labels,
696    })
697}
698
699// ─── Loss functions ───────────────────────────────────────────────────────────
700
701/// Compute the DeepCluster cross-entropy loss.
702///
703/// The classifier outputs `logits ∈ ℝ^{n × n_clusters}` (unnormalised) and the
704/// pseudo-labels are the cluster assignments from [`deep_cluster`].
705/// Loss = `(1/n) Σ_i −log softmax(logits[i])[pseudo_labels[i]]`.
706///
707/// # Arguments
708/// * `logits`       — `[n_samples × n_clusters]` row-major unnormalised scores.
709/// * `pseudo_labels`— cluster assignment per sample (output of `deep_cluster`).
710/// * `n_samples`    — number of data points.
711/// * `n_clusters`   — number of cluster classes.
712///
713/// # Errors
714/// - [`SslError::EmptyInput`] if `n_samples == 0`.
715/// - [`SslError::NumPrototypesTooSmall`] if `n_clusters < 2`.
716/// - [`SslError::DimensionMismatch`] on length mismatch.
717/// - [`SslError::InvalidParameter`] if any pseudo-label ≥ `n_clusters`.
718/// - [`SslError::NanEncountered`] if the loss is non-finite.
719pub fn deep_cluster_loss(
720    logits: &[f32],
721    pseudo_labels: &[usize],
722    n_samples: usize,
723    n_clusters: usize,
724) -> SslResult<f32> {
725    if n_samples == 0 {
726        return Err(SslError::EmptyInput);
727    }
728    if n_clusters < 2 {
729        return Err(SslError::NumPrototypesTooSmall);
730    }
731    if logits.len() != n_samples * n_clusters {
732        return Err(SslError::DimensionMismatch {
733            expected: n_samples * n_clusters,
734            got: logits.len(),
735        });
736    }
737    if pseudo_labels.len() != n_samples {
738        return Err(SslError::DimensionMismatch {
739            expected: n_samples,
740            got: pseudo_labels.len(),
741        });
742    }
743    for (i, &lbl) in pseudo_labels.iter().enumerate() {
744        if lbl >= n_clusters {
745            return Err(SslError::InvalidParameter {
746                name: format!("pseudo_labels[{i}]"),
747                reason: format!("label {lbl} >= n_clusters {n_clusters}"),
748            });
749        }
750    }
751
752    let mut total_loss = 0.0_f64;
753    for i in 0..n_samples {
754        let row = &logits[i * n_clusters..(i + 1) * n_clusters];
755        // Numerically stable softmax.
756        let max_v = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
757        let mut sum_exp = 0.0_f64;
758        let mut exps = Vec::with_capacity(n_clusters);
759        for &v in row {
760            let e = ((v - max_v) as f64).exp();
761            exps.push(e);
762            sum_exp += e;
763        }
764        let log_sum_exp = sum_exp.max(1e-300).ln();
765        // CE = -(logit[label] - max_v) + log_sum_exp
766        let target_score = (row[pseudo_labels[i]] - max_v) as f64;
767        total_loss += log_sum_exp - target_score;
768    }
769
770    let loss = (total_loss / n_samples as f64) as f32;
771    if !loss.is_finite() {
772        return Err(SslError::NanEncountered {
773            location: "deep_cluster_loss",
774        });
775    }
776    Ok(loss)
777}
778
779// ─── Tests ────────────────────────────────────────────────────────────────────
780
781#[cfg(test)]
782mod tests {
783    use super::*;
784
785    /// Build two clearly separated 2D clusters.
786    /// Cluster 0: n points near (+5, 0); Cluster 1: n points near (-5, 0).
787    fn two_cluster_data(n_per_cluster: usize) -> Vec<f64> {
788        let mut data = Vec::with_capacity(2 * n_per_cluster * 2);
789        for i in 0..n_per_cluster {
790            let offset = (i as f64) * 0.01;
791            data.push(5.0 + offset);
792            data.push(0.0 + offset);
793        }
794        for i in 0..n_per_cluster {
795            let offset = (i as f64) * 0.01;
796            data.push(-5.0 - offset);
797            data.push(0.0 + offset);
798        }
799        data
800    }
801
802    // ── Test 1: k=2 both clusters non-empty ──────────────────────────────────
803    #[test]
804    fn both_clusters_non_empty_on_separated_data() {
805        let n_per = 20_usize;
806        let n = 2 * n_per;
807        let d = 2_usize;
808        let data = two_cluster_data(n_per);
809        let config = DeepClusterConfig {
810            n_clusters: 2,
811            n_pca_components: 0, // skip PCA to keep test simple
812            kmeans_max_iter: 100,
813            kmeans_tol: 1e-5,
814            reassign_empty: true,
815            seed: 7,
816        };
817        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
818        // Count points per label
819        let mut count = [0_usize; 2];
820        for &l in &result.labels {
821            count[l] += 1;
822        }
823        assert!(count[0] > 0, "cluster 0 should be non-empty");
824        assert!(count[1] > 0, "cluster 1 should be non-empty");
825        assert_eq!(count[0] + count[1], n);
826    }
827
828    // ── Test 2: convergence on easy data ─────────────────────────────────────
829    #[test]
830    fn converges_before_max_iter_on_easy_data() {
831        let n_per = 30_usize;
832        let n = 2 * n_per;
833        let d = 2_usize;
834        let data = two_cluster_data(n_per);
835        let config = DeepClusterConfig {
836            n_clusters: 2,
837            n_pca_components: 0,
838            kmeans_max_iter: 200,
839            kmeans_tol: 1e-3,
840            reassign_empty: true,
841            seed: 13,
842        };
843        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
844        assert!(
845            result.converged,
846            "should converge; n_iter = {}",
847            result.n_iter
848        );
849        assert!(result.n_iter < 200, "n_iter = {}", result.n_iter);
850    }
851
852    // ── Test 3: labels length == n_samples ───────────────────────────────────
853    #[test]
854    fn labels_length_equals_n_samples() {
855        let n = 50_usize;
856        let d = 4_usize;
857        let features: Vec<f64> = (0..n * d).map(|i| (i as f64) * 0.01).collect();
858        let config = DeepClusterConfig {
859            n_clusters: 5,
860            n_pca_components: 0,
861            kmeans_max_iter: 20,
862            kmeans_tol: 1e-4,
863            reassign_empty: true,
864            seed: 17,
865        };
866        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
867        assert_eq!(result.labels.len(), n);
868    }
869
870    // ── Test 4: centroids shape == [k * d] ───────────────────────────────────
871    #[test]
872    fn centroids_shape_correct() {
873        let n = 40_usize;
874        let d = 6_usize;
875        let k = 4_usize;
876        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.17).sin()).collect();
877        let config = DeepClusterConfig {
878            n_clusters: k,
879            n_pca_components: 0,
880            kmeans_max_iter: 30,
881            kmeans_tol: 1e-4,
882            reassign_empty: true,
883            seed: 23,
884        };
885        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
886        assert_eq!(result.centroids.len(), k * d);
887    }
888
889    // ── Test 5: deep_cluster_loss finite and non-negative ────────────────────
890    #[test]
891    fn loss_finite_and_non_negative() {
892        let n = 8_usize;
893        let k = 4_usize;
894        let logits: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.1).collect();
895        let labels = vec![0_usize, 1, 2, 3, 0, 1, 2, 3];
896        let loss =
897            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
898        assert!(loss.is_finite(), "loss = {loss}");
899        assert!(loss >= 0.0, "loss = {loss}");
900    }
901
902    // ── Test 6: uniform logits → loss ≈ ln(k) ────────────────────────────────
903    #[test]
904    fn uniform_logits_give_ln_k_loss() {
905        let n = 16_usize;
906        let k = 8_usize;
907        let logits = vec![0.0_f32; n * k]; // all equal → softmax = 1/k
908        let labels: Vec<usize> = (0..n).map(|i| i % k).collect();
909        let loss =
910            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
911        let expected = (k as f32).ln();
912        assert!(
913            (loss - expected).abs() < 1e-4,
914            "loss = {loss}, expected = {expected}"
915        );
916    }
917
918    // ── Test 7: DeeperCluster with 2 scales returns 2 results ────────────────
919    #[test]
920    fn deeper_cluster_two_scales() {
921        let n = 60_usize;
922        let d = 4_usize;
923        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.23).sin()).collect();
924        let base = DeepClusterConfig {
925            n_clusters: 2, // will be overridden per scale
926            n_pca_components: 0,
927            kmeans_max_iter: 20,
928            kmeans_tol: 1e-3,
929            reassign_empty: true,
930            seed: 31,
931        };
932        let config = DeeperClusterConfig {
933            cluster_scales: vec![2, 3],
934            base_config: base,
935        };
936        let result =
937            deeper_cluster(&features, n, d, &config).expect("deeper_cluster should succeed");
938        assert_eq!(result.per_scale.len(), 2);
939        assert_eq!(result.multi_labels.len(), 2);
940        assert_eq!(result.multi_labels[0].len(), n);
941        assert_eq!(result.multi_labels[1].len(), n);
942        // Cluster counts should match requested scales.
943        for &lbl in &result.multi_labels[0] {
944            assert!(lbl < 2, "scale-0 label {lbl} out of range");
945        }
946        for &lbl in &result.multi_labels[1] {
947            assert!(lbl < 3, "scale-1 label {lbl} out of range");
948        }
949    }
950
951    // ── Test 8: pca_whiten output is approximately whitened ──────────────────
952    #[test]
953    fn pca_whiten_output_unit_variance_columns() {
954        // Create 2D data with variance = [4, 1] (axis-aligned).
955        let n = 200_usize;
956        let d = 2_usize;
957        let mut features = Vec::with_capacity(n * d);
958        for i in 0..n {
959            let t = i as f64;
960            features.push(2.0 * (t * 0.031).sin()); // σ≈√2 in x
961            features.push(1.0 * (t * 0.073).cos()); // σ≈1/√2 in y
962        }
963        let n_comp = 2_usize;
964        let whitened =
965            pca_whiten(&features, n, d, n_comp, 1e-6).expect("pca_whiten should succeed");
966        assert_eq!(whitened.len(), n * n_comp);
967        // Check each column has roughly unit variance.
968        for col in 0..n_comp {
969            let mean: f64 = whitened.iter().skip(col).step_by(n_comp).sum::<f64>() / n as f64;
970            let var: f64 = whitened
971                .iter()
972                .skip(col)
973                .step_by(n_comp)
974                .map(|&v| (v - mean) * (v - mean))
975                .sum::<f64>()
976                / (n as f64 - 1.0);
977            assert!(
978                var > 0.0 && var.is_finite(),
979                "col {col} variance = {var} should be finite and positive"
980            );
981        }
982    }
983
984    // ── Test 9: empty cluster reassignment doesn't crash ─────────────────────
985    #[test]
986    fn empty_cluster_reassignment_does_not_crash() {
987        // Duplicate data — guaranteed empty clusters with many k.
988        let n = 10_usize;
989        let d = 2_usize;
990        // All points at same location → lots of empty clusters.
991        let features = vec![1.0_f64; n * d];
992        let config = DeepClusterConfig {
993            n_clusters: 5,
994            n_pca_components: 0,
995            kmeans_max_iter: 10,
996            kmeans_tol: 0.0, // always run max_iter
997            reassign_empty: true,
998            seed: 37,
999        };
1000        // Should complete without panic.
1001        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
1002        assert_eq!(result.labels.len(), n);
1003    }
1004
1005    // ── Test 10: n_clusters > n_samples → error ───────────────────────────────
1006    #[test]
1007    fn error_on_more_clusters_than_samples() {
1008        let n = 5_usize;
1009        let d = 2_usize;
1010        let features = vec![1.0_f64; n * d];
1011        let config = DeepClusterConfig {
1012            n_clusters: 10, // > n
1013            n_pca_components: 0,
1014            kmeans_max_iter: 10,
1015            kmeans_tol: 1e-4,
1016            reassign_empty: true,
1017            seed: 41,
1018        };
1019        assert!(deep_cluster(&features, n, d, &config).is_err());
1020    }
1021
1022    // ── Test 11: n_clusters = 0 → error ──────────────────────────────────────
1023    #[test]
1024    fn error_on_zero_clusters() {
1025        let result = DeepClusterConfig::new(0, 0, 10, 1e-4, true, 42);
1026        assert!(result.is_err(), "n_clusters=0 should return an error");
1027    }
1028
1029    // ── Test 12: inertia non-negative and finite ──────────────────────────────
1030    #[test]
1031    fn inertia_non_negative_and_finite() {
1032        let n = 50_usize;
1033        let d = 3_usize;
1034        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.11).sin()).collect();
1035        let config = DeepClusterConfig {
1036            n_clusters: 5,
1037            n_pca_components: 0,
1038            kmeans_max_iter: 50,
1039            kmeans_tol: 1e-4,
1040            reassign_empty: true,
1041            seed: 53,
1042        };
1043        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
1044        assert!(result.inertia.is_finite(), "inertia = {}", result.inertia);
1045        assert!(result.inertia >= 0.0, "inertia = {}", result.inertia);
1046    }
1047
1048    // ── Test 13: converged=true when data is already clustered ────────────────
1049    #[test]
1050    fn converged_true_when_stable() {
1051        let n_per = 20_usize;
1052        let n = 2 * n_per;
1053        let d = 2_usize;
1054        let data = two_cluster_data(n_per);
1055        let config = DeepClusterConfig {
1056            n_clusters: 2,
1057            n_pca_components: 0,
1058            kmeans_max_iter: 500,
1059            kmeans_tol: 0.01, // 1% tolerance — well-separated data converges easily
1060            reassign_empty: true,
1061            seed: 61,
1062        };
1063        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
1064        assert!(result.converged, "should have converged");
1065    }
1066
1067    // ── Test 14: loss rejects invalid label ───────────────────────────────────
1068    #[test]
1069    fn loss_rejects_out_of_range_label() {
1070        let n = 4_usize;
1071        let k = 3_usize;
1072        let logits = vec![0.0_f32; n * k];
1073        let labels = vec![0_usize, 1, 2, 3]; // 3 >= k=3 → invalid
1074        assert!(deep_cluster_loss(&logits, &labels, n, k).is_err());
1075    }
1076
1077    // ── Test 15: pca_whiten rejects bad n_components ──────────────────────────
1078    #[test]
1079    fn pca_whiten_rejects_invalid_n_components() {
1080        let n = 10_usize;
1081        let d = 4_usize;
1082        let features = vec![1.0_f64; n * d];
1083        // n_components == 0
1084        assert!(pca_whiten(&features, n, d, 0, 1e-6).is_err());
1085        // n_components > feat_dim
1086        assert!(pca_whiten(&features, n, d, d + 1, 1e-6).is_err());
1087    }
1088
1089    // ── Test 16: deep_cluster_loss with k=1 → error ───────────────────────────
1090    #[test]
1091    fn loss_rejects_single_cluster() {
1092        let logits = vec![1.0_f32; 4];
1093        let labels = vec![0_usize; 4];
1094        assert!(deep_cluster_loss(&logits, &labels, 4, 1).is_err());
1095    }
1096}