use crate::{
error::{SslError, SslResult},
handle::LcgRng,
};
/// Hyper-parameters for a single DeepCluster clustering round.
#[derive(Debug, Clone)]
pub struct DeepClusterConfig {
    /// Number of k-means clusters.
    pub n_clusters: usize,
    /// Target dimension for PCA-whitening before k-means (only applied when it is non-zero
    /// and smaller than the feature dimension).
    pub n_pca_components: usize,
    /// Maximum number of k-means (Lloyd) iterations.
    pub kmeans_max_iter: usize,
    /// Convergence threshold on the fraction of samples whose label changed in an iteration.
    pub kmeans_tol: f64,
    /// Whether empty clusters are re-seeded after each update step.
    pub reassign_empty: bool,
    /// Seed for the random generator used during seeding/re-seeding.
    pub seed: u64,
}

impl Default for DeepClusterConfig {
    /// 1000 clusters on 256 PCA components, at most 100 iterations, tolerance 1e-4,
    /// empty-cluster re-seeding on, seed 42.
    fn default() -> Self {
        DeepClusterConfig {
            n_clusters: 1_000,
            n_pca_components: 256,
            kmeans_max_iter: 100,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 42,
        }
    }
}

impl DeepClusterConfig {
    /// Builds a configuration, rejecting zero `n_clusters` and zero `kmeans_max_iter`.
    ///
    /// # Errors
    /// `SslError::InvalidParameter` naming the first offending field (`n_clusters` is checked
    /// before `kmeans_max_iter`). No other field is validated here.
    pub fn new(
        n_clusters: usize,
        n_pca_components: usize,
        kmeans_max_iter: usize,
        kmeans_tol: f64,
        reassign_empty: bool,
        seed: u64,
    ) -> SslResult<Self> {
        let at_least_one = |field: &str| SslError::InvalidParameter {
            name: field.to_string(),
            reason: "must be >= 1".to_string(),
        };
        match (n_clusters, kmeans_max_iter) {
            (0, _) => Err(at_least_one("n_clusters")),
            (_, 0) => Err(at_least_one("kmeans_max_iter")),
            _ => Ok(Self {
                n_clusters,
                n_pca_components,
                kmeans_max_iter,
                kmeans_tol,
                reassign_empty,
                seed,
            }),
        }
    }
}
/// Output of one DeepCluster clustering round.
#[derive(Debug, Clone)]
pub struct DeepClusterResult {
    /// Cluster index assigned to each sample by the final assignment pass.
    pub labels: Vec<usize>,
    /// Row-major `n_clusters x work_dim` centroids, expressed in the space k-means ran in
    /// (PCA-whitened features when whitening was applied, raw features otherwise).
    pub centroids: Vec<f64>,
    /// Sum of squared distances from each sample to its assigned centroid (final pass).
    pub inertia: f64,
    /// Number of k-means iterations actually performed.
    pub n_iter: usize,
    /// Whether the fraction of changed labels fell to `kmeans_tol` before the iteration cap.
    pub converged: bool,
    /// Number of samples whose label changed in the final assignment pass.
    pub n_reassignments: usize,
    /// Number of clusters left with no samples after the final assignment pass.
    pub empty_clusters: usize,
}
/// Configuration for multi-scale clustering: one clustering round per entry of
/// `cluster_scales`, each sharing `base_config` except for `n_clusters` and the seed.
#[derive(Debug, Clone)]
pub struct DeeperClusterConfig {
    /// Cluster counts to run, one clustering round per entry, in order.
    pub cluster_scales: Vec<usize>,
    /// Template configuration shared by every scale.
    pub base_config: DeepClusterConfig,
}

impl Default for DeeperClusterConfig {
    /// Two scales (100 and 1000 clusters) on top of `DeepClusterConfig::default()`.
    fn default() -> Self {
        let cluster_scales = vec![100, 1_000];
        let base_config = DeepClusterConfig::default();
        Self {
            cluster_scales,
            base_config,
        }
    }
}
/// Output of multi-scale clustering.
#[derive(Debug, Clone)]
pub struct DeeperClusterResult {
    /// One clustering result per entry of `cluster_scales`, in the same order.
    pub per_scale: Vec<DeepClusterResult>,
    /// Copy of each scale's `labels`, in the same order as `per_scale`.
    pub multi_labels: Vec<Vec<usize>>,
}
/// Sample covariance of `n` already-centred rows of length `d` (row-major `x_centered`).
///
/// Uses the Bessel divisor `max(n - 1, 1)` and returns a dense symmetric `d x d` matrix in
/// row-major order. Only the upper triangle is accumulated; the lower one is mirrored.
fn compute_covariance(x_centered: &[f64], n: usize, d: usize) -> Vec<f64> {
    let scale = 1.0 / (n as f64 - 1.0).max(1.0);
    let mut cov = vec![0.0_f64; d * d];
    for row in (0..n).map(|r| &x_centered[r * d..(r + 1) * d]) {
        for (i, &a) in row.iter().enumerate() {
            for (j, &b) in row.iter().enumerate().skip(i) {
                cov[i * d + j] += a * b * scale;
            }
        }
    }
    // Mirror the upper triangle into the lower triangle.
    for i in 1..d {
        for j in 0..i {
            cov[i * d + j] = cov[j * d + i];
        }
    }
    cov
}
/// Writes `a * v` into `out[..d]`, where `a` is a row-major `d x d` matrix.
#[inline]
fn matvec(a: &[f64], v: &[f64], out: &mut [f64], d: usize) {
    for (i, slot) in out[..d].iter_mut().enumerate() {
        let row = &a[i * d..(i + 1) * d];
        *slot = row
            .iter()
            .zip(&v[..d])
            .fold(0.0_f64, |acc, (&m, &x)| acc + m * x);
    }
}
/// Scales `v` to unit Euclidean length in place and returns its original norm.
///
/// Vectors whose norm is not greater than `1e-12` (including NaN norms) are left untouched.
fn l2_normalize_inplace(v: &mut [f64]) -> f64 {
    let sum_sq: f64 = v.iter().map(|&x| x * x).sum();
    let length = sum_sq.sqrt();
    let normalizable = length > 1e-12;
    if normalizable {
        v.iter_mut().for_each(|x| *x /= length);
    }
    length
}
/// Euclidean (L2) norm of `v`.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Estimates the dominant eigenpair of the symmetric row-major `d x d` matrix `cov`.
///
/// Starts from `init_vec` (normalised first) and performs at most `n_iter`
/// multiply-and-normalise steps, stopping early when `cov * v` collapses below `1e-14`.
///
/// Returns `(eigenvalue, eigenvector)`. The eigenvalue is the Rayleigh quotient `v^T cov v`
/// of the *returned* vector. The pair is therefore consistent even when the iteration has
/// not fully converged, which matters for the subsequent deflation and whitening.
fn power_iteration(cov: &[f64], d: usize, init_vec: &[f64], n_iter: usize) -> (f64, Vec<f64>) {
    let mut v = init_vec.to_vec();
    l2_normalize_inplace(&mut v);
    let mut av = vec![0.0_f64; d];
    for _ in 0..n_iter {
        matvec(cov, &v, &mut av, d);
        let norm = l2_norm(&av);
        if norm < 1e-14 {
            // `cov` annihilates `v` (e.g. a fully deflated direction); keep `v` as is.
            break;
        }
        for (vi, &ai) in v.iter_mut().zip(&av) {
            *vi = ai / norm;
        }
    }
    // Rayleigh quotient of the vector actually returned.
    matvec(cov, &v, &mut av, d);
    let eigenvalue: f64 = av.iter().zip(&v).map(|(a, b)| a * b).sum();
    (eigenvalue, v)
}
/// Hotelling deflation: subtracts `eigenvalue * u u^T` from the row-major `d x d` matrix `cov`,
/// where `u = eigenvec[..d]`.
fn deflate(cov: &mut [f64], eigenvalue: f64, eigenvec: &[f64], d: usize) {
    let u = &eigenvec[..d];
    for (i, &ui) in u.iter().enumerate() {
        let scaled = eigenvalue * ui;
        for (j, &uj) in u.iter().enumerate() {
            cov[i * d + j] -= scaled * uj;
        }
    }
}
/// PCA-whitens a row-major `n_samples x feat_dim` matrix down to `n_components` columns.
///
/// The features are mean-centred and their sample covariance is computed. The top
/// `n_components` eigenpairs are then extracted by power iteration with deflation. Each
/// centred sample is projected onto every eigenvector and divided by
/// `sqrt(max(eigenvalue, 0) + eps)`. The result is a row-major
/// `n_samples x n_components` matrix.
///
/// # Errors
/// - `SslError::EmptyInput` if `n_samples == 0`.
/// - `SslError::InvalidFeatureDim` if `feat_dim == 0`.
/// - `SslError::InvalidParameter` if `n_components` is outside `[1, feat_dim]`.
/// - `SslError::DimensionMismatch` if `features.len() != n_samples * feat_dim`.
pub fn pca_whiten(
    features: &[f64],
    n_samples: usize,
    feat_dim: usize,
    n_components: usize,
    eps: f64,
) -> SslResult<Vec<f64>> {
    if n_samples == 0 {
        return Err(SslError::EmptyInput);
    }
    if feat_dim == 0 {
        return Err(SslError::InvalidFeatureDim);
    }
    if !(1..=feat_dim).contains(&n_components) {
        return Err(SslError::InvalidParameter {
            name: "n_components".to_string(),
            reason: format!("must be in [1, feat_dim={feat_dim}]"),
        });
    }
    let expected_len = n_samples * feat_dim;
    if features.len() != expected_len {
        return Err(SslError::DimensionMismatch {
            expected: expected_len,
            got: features.len(),
        });
    }

    // Column means, accumulated row by row.
    let mut mean = vec![0.0_f64; feat_dim];
    for row in features.chunks_exact(feat_dim) {
        for (m, &x) in mean.iter_mut().zip(row) {
            *m += x;
        }
    }
    let inv_n = 1.0 / n_samples as f64;
    mean.iter_mut().for_each(|m| *m *= inv_n);

    let mut x_centered = features.to_vec();
    for row in x_centered.chunks_exact_mut(feat_dim) {
        for (x, &m) in row.iter_mut().zip(&mean) {
            *x -= m;
        }
    }

    let mut cov = compute_covariance(&x_centered, n_samples, feat_dim);
    let steps = usize::max(30, 2 * n_components);

    // Deterministic quasi-random start vector (golden-ratio fractional parts in [-1, 1)).
    let base: Vec<f64> = (0..feat_dim)
        .map(|i| ((i as f64 + 1.0) * 0.618_033_988).fract() * 2.0 - 1.0)
        .collect();

    let mut eigenvecs: Vec<Vec<f64>> = Vec::with_capacity(n_components);
    let mut eigenvalues: Vec<f64> = Vec::with_capacity(n_components);
    for k in 0..n_components {
        let kf = k as f64;
        let perturb = (kf + 1.0) * 0.01;
        let mut start: Vec<f64> = base
            .iter()
            .enumerate()
            .map(|(i, &b)| b + perturb * (i as f64 + kf * 17.0).sin())
            .collect();
        // Gram-Schmidt against the eigenvectors found so far.
        for prev in &eigenvecs {
            let proj: f64 = start.iter().zip(prev).map(|(a, b)| a * b).sum();
            start.iter_mut().zip(prev).for_each(|(s, p)| *s -= proj * p);
        }
        l2_normalize_inplace(&mut start);
        let (lambda, vec_k) = power_iteration(&cov, feat_dim, &start, steps);
        deflate(&mut cov, lambda, &vec_k, feat_dim);
        // Negative estimates (numerical noise) are clamped before whitening.
        eigenvalues.push(lambda.max(0.0));
        eigenvecs.push(vec_k);
    }

    let scales: Vec<f64> = eigenvalues.iter().map(|&l| (l + eps).sqrt()).collect();
    let mut out = Vec::with_capacity(n_samples * n_components);
    for xi in x_centered.chunks_exact(feat_dim) {
        for (u, &s) in eigenvecs.iter().zip(&scales) {
            let dot: f64 = xi.iter().zip(u).map(|(a, b)| a * b).sum();
            out.push(dot / s);
        }
    }
    Ok(out)
}
/// k-means++ seeding: returns `k` row indices of `features` (row-major, `n_samples x d`)
/// to use as initial centroids.
///
/// The first index is drawn uniformly. Each further index is drawn with probability
/// proportional to its squared distance to the nearest already-chosen centre (D² weighting).
/// Rows with zero weight, including every already-chosen row, are never picked by the
/// weighted draw, so a draw of exactly `0.0` or float rounding in the cumulative sum cannot
/// duplicate a centre. When every weight is zero (all rows coincide with chosen centres) the
/// index is drawn uniformly instead, which may repeat an index.
fn kmeans_pp_init(
    features: &[f64],
    n_samples: usize,
    d: usize,
    k: usize,
    rng: &mut LcgRng,
) -> Vec<usize> {
    let mut chosen = Vec::with_capacity(k);
    chosen.push(rng.next_usize(n_samples));
    // Squared distance from each row to its nearest chosen centre so far.
    let mut min_sq_dists = vec![f64::MAX; n_samples];
    for c_idx in 1..k {
        let last = chosen[c_idx - 1];
        let c_row = &features[last * d..(last + 1) * d];
        for (i, best) in min_sq_dists.iter_mut().enumerate() {
            let sq_dist = sq_dist_slices(&features[i * d..(i + 1) * d], c_row);
            if sq_dist < *best {
                *best = sq_dist;
            }
        }
        let total: f64 = min_sq_dists.iter().sum();
        if total <= 0.0 {
            chosen.push(rng.next_usize(n_samples));
            continue;
        }
        let threshold = rng.next_f32() as f64 * total;
        let mut cumsum = 0.0_f64;
        // Overwritten by the last positive-weight row visited; since `total > 0` at least one
        // exists, which also serves as the fallback if rounding keeps `cumsum < threshold`.
        let mut selected = n_samples - 1;
        for (i, &w) in min_sq_dists.iter().enumerate().filter(|&(_, &w)| w > 0.0) {
            selected = i;
            cumsum += w;
            if cumsum >= threshold {
                break;
            }
        }
        chosen.push(selected);
    }
    chosen
}
/// Squared Euclidean distance between `a` and `b` over their common length.
#[inline]
fn sq_dist_slices(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
fn assign_step(
features: &[f64],
centroids: &[f64],
labels: &[usize],
n_samples: usize,
d: usize,
k: usize,
) -> (Vec<usize>, f64, usize) {
let mut new_labels = vec![0_usize; n_samples];
let mut inertia = 0.0_f64;
let mut n_changed = 0_usize;
for i in 0..n_samples {
let xi = &features[i * d..(i + 1) * d];
let mut best_dist = f64::MAX;
let mut best_c = 0_usize;
for c in 0..k {
let dist = sq_dist_slices(xi, ¢roids[c * d..(c + 1) * d]);
if dist < best_dist {
best_dist = dist;
best_c = c;
}
}
new_labels[i] = best_c;
inertia += best_dist;
if best_c != labels[i] {
n_changed += 1;
}
}
(new_labels, inertia, n_changed)
}
/// Recomputes each centroid as the mean of the samples labelled with it.
///
/// Returns `(centroids, counts)`:
/// - `centroids` is row-major `k x d`; a cluster with no samples keeps an all-zero row.
/// - `counts[c]` is the number of samples with label `c`.
fn update_step(
    features: &[f64],
    labels: &[usize],
    n_samples: usize,
    d: usize,
    k: usize,
) -> (Vec<f64>, Vec<usize>) {
    let mut sums = vec![0.0_f64; k * d];
    let mut counts = vec![0_usize; k];
    for (i, &c) in labels[..n_samples].iter().enumerate() {
        counts[c] += 1;
        let sample = &features[i * d..(i + 1) * d];
        for (acc, &x) in sums[c * d..(c + 1) * d].iter_mut().zip(sample) {
            *acc += x;
        }
    }
    for (c, &size) in counts.iter().enumerate() {
        if size == 0 {
            continue;
        }
        let inv = 1.0 / size as f64;
        for v in &mut sums[c * d..(c + 1) * d] {
            *v *= inv;
        }
    }
    (sums, counts)
}
/// Index of the cluster with the most samples; on ties the highest such index wins.
/// Returns 0 for an empty slice.
fn largest_cluster(counts: &[usize]) -> usize {
    let mut best = 0_usize;
    for (i, &size) in counts.iter().enumerate() {
        if size >= counts[best] {
            best = i;
        }
    }
    best
}
/// Re-seeds every empty cluster by splitting off a sample of the currently largest cluster.
///
/// For each cluster `c` with `counts[c] == 0`:
/// 1. A random member `p` of the largest cluster `src` is picked.
/// 2. `c`'s centroid is placed at `p + delta` and `src`'s at `p - delta`, where `delta`
///    alternates `+1e-6` / `-1e-6` across dimensions.
/// 3. `p` is relabelled to `c` and `counts` are updated accordingly.
///
/// Keeping `labels`/`counts` in sync means later empty clusters in the same pass see current
/// cluster sizes and never reuse the same seed sample. A source with fewer than two samples
/// is never split, because that would only empty it.
#[allow(clippy::too_many_arguments)] // private helper threading the k-means state through
fn reassign_empty_clusters(
    centroids: &mut [f64],
    counts: &mut [usize],
    features: &[f64],
    labels: &mut [usize],
    n_samples: usize,
    d: usize,
    k: usize,
    rng: &mut LcgRng,
) {
    for c in 0..k {
        if counts[c] != 0 {
            continue;
        }
        let src = largest_cluster(counts);
        // Also covers `src == c` (every count zero), since `counts[c] == 0 < 2`.
        if counts[src] < 2 {
            continue;
        }
        let members: Vec<usize> = (0..n_samples).filter(|&i| labels[i] == src).collect();
        if members.is_empty() {
            continue;
        }
        let pick = members[rng.next_usize(members.len())];
        let seed_row = &features[pick * d..(pick + 1) * d];
        for (j, &x) in seed_row.iter().enumerate() {
            let delta = if j % 2 == 0 { 1e-6 } else { -1e-6 };
            centroids[c * d + j] = x + delta;
            centroids[src * d + j] = x - delta;
        }
        labels[pick] = c;
        counts[src] = counts[src].saturating_sub(1);
        counts[c] = 1;
    }
}
/// Runs one DeepCluster clustering round: optional PCA-whitening followed by k-means.
///
/// `features` is a row-major `n_samples x feat_dim` matrix.
/// - When `0 < config.n_pca_components < feat_dim`, the features are first PCA-whitened to
///   `config.n_pca_components` dimensions (eps `1e-6`). Otherwise k-means runs on the raw
///   features.
/// - Centroids are seeded with k-means++ and refined by Lloyd iterations. Iteration stops when
///   the fraction of samples whose label changed is `<= config.kmeans_tol`, or after
///   `config.kmeans_max_iter` iterations.
/// - With `config.reassign_empty`, empty clusters are re-seeded after every update step.
///
/// The returned `centroids` live in the working space (whitened or raw) and have
/// `n_clusters * work_dim` entries. `labels`, `inertia`, `n_reassignments` and
/// `empty_clusters` come from one final assignment pass against those centroids.
///
/// # Errors
/// - `SslError::EmptyInput` if `n_samples == 0`.
/// - `SslError::InvalidFeatureDim` if `feat_dim == 0`.
/// - `SslError::InvalidParameter` if `n_clusters` is 0 or exceeds `n_samples`.
/// - `SslError::DimensionMismatch` if `features.len() != n_samples * feat_dim`.
/// - Any error returned by [`pca_whiten`].
pub fn deep_cluster(
    features: &[f64],
    n_samples: usize,
    feat_dim: usize,
    config: &DeepClusterConfig,
) -> SslResult<DeepClusterResult> {
    if n_samples == 0 {
        return Err(SslError::EmptyInput);
    }
    if feat_dim == 0 {
        return Err(SslError::InvalidFeatureDim);
    }
    if config.n_clusters == 0 {
        return Err(SslError::InvalidParameter {
            name: "n_clusters".to_string(),
            reason: "must be >= 1".to_string(),
        });
    }
    if config.n_clusters > n_samples {
        return Err(SslError::InvalidParameter {
            name: "n_clusters".to_string(),
            reason: format!(
                "must be <= n_samples ({n_samples}), got {}",
                config.n_clusters
            ),
        });
    }
    if features.len() != n_samples * feat_dim {
        return Err(SslError::DimensionMismatch {
            expected: n_samples * feat_dim,
            got: features.len(),
        });
    }
    let mut rng = LcgRng::new(config.seed);
    let k = config.n_clusters;
    // Whitening is only applied when it actually reduces the dimension.
    let (work_features, work_dim) =
        if config.n_pca_components > 0 && config.n_pca_components < feat_dim {
            let whitened =
                pca_whiten(features, n_samples, feat_dim, config.n_pca_components, 1e-6)?;
            (whitened, config.n_pca_components)
        } else {
            (features.to_vec(), feat_dim)
        };
    let init_indices = kmeans_pp_init(&work_features, n_samples, work_dim, k, &mut rng);
    let mut centroids = vec![0.0_f64; k * work_dim];
    for (c, &idx) in init_indices.iter().enumerate() {
        centroids[c * work_dim..(c + 1) * work_dim]
            .copy_from_slice(&work_features[idx * work_dim..(idx + 1) * work_dim]);
    }
    // Every sample starts in cluster 0, so the first pass counts each move out of it.
    let mut labels = vec![0_usize; n_samples];
    let mut n_iter = 0_usize;
    let mut converged = false;
    let mut final_n_reassignments = n_samples;
    for iter in 0..config.kmeans_max_iter {
        let (new_labels, _iter_inertia, n_changed) =
            assign_step(&work_features, &centroids, &labels, n_samples, work_dim, k);
        final_n_reassignments = n_changed;
        labels = new_labels;
        n_iter = iter + 1;
        let (new_centroids, mut counts) =
            update_step(&work_features, &labels, n_samples, work_dim, k);
        centroids = new_centroids;
        if config.reassign_empty {
            reassign_empty_clusters(
                &mut centroids,
                &mut counts,
                &work_features,
                &mut labels,
                n_samples,
                work_dim,
                k,
                &mut rng,
            );
        }
        let frac_changed = n_changed as f64 / n_samples as f64;
        if frac_changed <= config.kmeans_tol {
            converged = true;
            break;
        }
    }
    // Final pass so labels, inertia and empty-cluster counts match the returned centroids.
    let (final_labels, final_inertia, final_changed) =
        assign_step(&work_features, &centroids, &labels, n_samples, work_dim, k);
    labels = final_labels;
    if n_iter > 0 {
        final_n_reassignments = final_changed;
    }
    let (_, final_counts) = update_step(&work_features, &labels, n_samples, work_dim, k);
    let empty_clusters = final_counts.iter().filter(|&&c| c == 0).count();
    Ok(DeepClusterResult {
        labels,
        centroids,
        inertia: final_inertia,
        n_iter,
        converged,
        n_reassignments: final_n_reassignments,
        empty_clusters,
    })
}
/// Multi-scale clustering: runs [`deep_cluster`] once per entry of `config.cluster_scales`.
///
/// Each scale reuses `config.base_config`, with two overrides:
/// - `n_clusters` is set to the scale.
/// - The seed is `base_seed + scale_idx * 0x9e37_79b9_7f4a_7c15`, computed with wrapping
///   arithmetic, so each scale draws a different random stream.
///
/// Results and label vectors are returned in scale order.
///
/// # Errors
/// - `SslError::InvalidParameter` if `cluster_scales` is empty.
/// - The first error returned by [`deep_cluster`] for any scale.
pub fn deeper_cluster(
    features: &[f64],
    n_samples: usize,
    feat_dim: usize,
    config: &DeeperClusterConfig,
) -> SslResult<DeeperClusterResult> {
    if config.cluster_scales.is_empty() {
        return Err(SslError::InvalidParameter {
            name: "cluster_scales".to_string(),
            reason: "must contain at least one scale".to_string(),
        });
    }
    let mut per_scale = Vec::with_capacity(config.cluster_scales.len());
    let mut multi_labels = Vec::with_capacity(config.cluster_scales.len());
    for (scale_idx, &n_clusters) in config.cluster_scales.iter().enumerate() {
        // The multiplier exceeds u64::MAX / 2, so a plain `*` overflows from the third scale
        // on (a panic in debug builds); the seed derivation is meant to wrap.
        let seed = config
            .base_config
            .seed
            .wrapping_add((scale_idx as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15));
        let scale_config = DeepClusterConfig {
            n_clusters,
            seed,
            ..config.base_config.clone()
        };
        let result = deep_cluster(features, n_samples, feat_dim, &scale_config)?;
        multi_labels.push(result.labels.clone());
        per_scale.push(result);
    }
    Ok(DeeperClusterResult {
        per_scale,
        multi_labels,
    })
}
/// Mean softmax cross-entropy of `logits` against cluster pseudo-labels.
///
/// `logits` is a row-major `n_samples x n_clusters` matrix. For each row the loss is
/// `logsumexp(row) - row[label]`, evaluated in f64 with a max-shift for stability. The
/// returned value is the mean over rows, cast to f32.
///
/// # Errors
/// - `SslError::EmptyInput` if `n_samples == 0`.
/// - `SslError::NumPrototypesTooSmall` if `n_clusters < 2`.
/// - `SslError::DimensionMismatch` if `logits` or `pseudo_labels` have the wrong length.
/// - `SslError::InvalidParameter` if any label is `>= n_clusters`.
/// - `SslError::NanEncountered` if the loss is not finite, e.g. because a logit is NaN
///   or infinite.
pub fn deep_cluster_loss(
    logits: &[f32],
    pseudo_labels: &[usize],
    n_samples: usize,
    n_clusters: usize,
) -> SslResult<f32> {
    if n_samples == 0 {
        return Err(SslError::EmptyInput);
    }
    if n_clusters < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }
    if logits.len() != n_samples * n_clusters {
        return Err(SslError::DimensionMismatch {
            expected: n_samples * n_clusters,
            got: logits.len(),
        });
    }
    if pseudo_labels.len() != n_samples {
        return Err(SslError::DimensionMismatch {
            expected: n_samples,
            got: pseudo_labels.len(),
        });
    }
    for (i, &lbl) in pseudo_labels.iter().enumerate() {
        if lbl >= n_clusters {
            return Err(SslError::InvalidParameter {
                name: format!("pseudo_labels[{i}]"),
                reason: format!("label {lbl} >= n_clusters {n_clusters}"),
            });
        }
    }
    let mut total_loss = 0.0_f64;
    for (row, &target) in logits.chunks_exact(n_clusters).zip(pseudo_labels) {
        // `f32::max` skips NaN, so a NaN logit is not hidden here; it reaches the sum below.
        let max_v = f64::from(row.iter().copied().fold(f32::NEG_INFINITY, f32::max));
        // For finite rows the max element contributes exp(0) = 1, so `sum_exp >= 1` and needs
        // no clamp. Clamping with `f64::max` would silently turn NaN into a finite value.
        let sum_exp: f64 = row.iter().map(|&v| (f64::from(v) - max_v).exp()).sum();
        total_loss += sum_exp.ln() - (f64::from(row[target]) - max_v);
    }
    let loss = (total_loss / n_samples as f64) as f32;
    if !loss.is_finite() {
        return Err(SslError::NanEncountered {
            location: "deep_cluster_loss",
        });
    }
    Ok(loss)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated 2-D blobs, row-major: `n_per_cluster` points starting at (5, 0)
    /// followed by `n_per_cluster` points starting at (-5, 0), each shifted by `0.01 * i`.
    fn two_cluster_data(n_per_cluster: usize) -> Vec<f64> {
        let mut data = Vec::with_capacity(2 * n_per_cluster * 2);
        for i in 0..n_per_cluster {
            let offset = (i as f64) * 0.01;
            data.push(5.0 + offset);
            data.push(0.0 + offset);
        }
        for i in 0..n_per_cluster {
            let offset = (i as f64) * 0.01;
            data.push(-5.0 - offset);
            data.push(0.0 + offset);
        }
        data
    }

    #[test]
    fn both_clusters_non_empty_on_separated_data() {
        let n_per = 20_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0, // 0 skips PCA-whitening
            kmeans_max_iter: 100,
            kmeans_tol: 1e-5,
            reassign_empty: true,
            seed: 7,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        let mut count = [0_usize; 2];
        for &l in &result.labels {
            count[l] += 1;
        }
        assert!(count[0] > 0, "cluster 0 should be non-empty");
        assert!(count[1] > 0, "cluster 1 should be non-empty");
        assert_eq!(count[0] + count[1], n);
    }

    #[test]
    fn converges_before_max_iter_on_easy_data() {
        let n_per = 30_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0,
            kmeans_max_iter: 200,
            kmeans_tol: 1e-3,
            reassign_empty: true,
            seed: 13,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        assert!(
            result.converged,
            "should converge; n_iter = {}",
            result.n_iter
        );
        assert!(result.n_iter < 200, "n_iter = {}", result.n_iter);
    }

    #[test]
    fn labels_length_equals_n_samples() {
        let n = 50_usize;
        let d = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| (i as f64) * 0.01).collect();
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 20,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 17,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.labels.len(), n);
    }

    #[test]
    fn centroids_shape_correct() {
        let n = 40_usize;
        let d = 6_usize;
        let k = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.17).sin()).collect();
        let config = DeepClusterConfig {
            n_clusters: k,
            n_pca_components: 0,
            kmeans_max_iter: 30,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 23,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.centroids.len(), k * d);
    }

    #[test]
    fn pca_path_produces_reduced_centroids() {
        let n_per = 20_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 1, // whiten 2-D features down to 1-D before k-means
            kmeans_max_iter: 50,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 71,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.labels.len(), n);
        assert_eq!(
            result.centroids.len(),
            config.n_clusters * config.n_pca_components
        );
    }

    #[test]
    fn loss_finite_and_non_negative() {
        let n = 8_usize;
        let k = 4_usize;
        let logits: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.1).collect();
        let labels = vec![0_usize, 1, 2, 3, 0, 1, 2, 3];
        let loss =
            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
        assert!(loss.is_finite(), "loss = {loss}");
        assert!(loss >= 0.0, "loss = {loss}");
    }

    #[test]
    fn uniform_logits_give_ln_k_loss() {
        let n = 16_usize;
        let k = 8_usize;
        let logits = vec![0.0_f32; n * k]; // uniform softmax over k classes
        let labels: Vec<usize> = (0..n).map(|i| i % k).collect();
        let loss =
            deep_cluster_loss(&logits, &labels, n, k).expect("deep_cluster_loss should succeed");
        let expected = (k as f32).ln();
        assert!(
            (loss - expected).abs() < 1e-4,
            "loss = {loss}, expected = {expected}"
        );
    }

    #[test]
    fn loss_matches_closed_form() {
        // logsumexp([1, 0]) - 1 = ln(1 + e^-1)
        let loss =
            deep_cluster_loss(&[1.0, 0.0], &[0], 1, 2).expect("deep_cluster_loss should succeed");
        let expected = (1.0_f32 + (-1.0_f32).exp()).ln();
        assert!(
            (loss - expected).abs() < 1e-6,
            "loss = {loss}, expected = {expected}"
        );
    }

    #[test]
    fn deeper_cluster_two_scales() {
        let n = 60_usize;
        let d = 4_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.23).sin()).collect();
        let base = DeepClusterConfig {
            n_clusters: 2, // overridden per scale by deeper_cluster
            n_pca_components: 0,
            kmeans_max_iter: 20,
            kmeans_tol: 1e-3,
            reassign_empty: true,
            seed: 31,
        };
        let config = DeeperClusterConfig {
            cluster_scales: vec![2, 3],
            base_config: base,
        };
        let result =
            deeper_cluster(&features, n, d, &config).expect("deeper_cluster should succeed");
        assert_eq!(result.per_scale.len(), 2);
        assert_eq!(result.multi_labels.len(), 2);
        assert_eq!(result.multi_labels[0].len(), n);
        assert_eq!(result.multi_labels[1].len(), n);
        for &lbl in &result.multi_labels[0] {
            assert!(lbl < 2, "scale-0 label {lbl} out of range");
        }
        for &lbl in &result.multi_labels[1] {
            assert!(lbl < 3, "scale-1 label {lbl} out of range");
        }
    }

    #[test]
    fn pca_whiten_output_unit_variance_columns() {
        let n = 200_usize;
        let d = 2_usize;
        let mut features = Vec::with_capacity(n * d);
        for i in 0..n {
            let t = i as f64;
            features.push(2.0 * (t * 0.031).sin()); // higher-variance axis
            features.push(1.0 * (t * 0.073).cos()); // lower-variance axis
        }
        let n_comp = 2_usize;
        let whitened =
            pca_whiten(&features, n, d, n_comp, 1e-6).expect("pca_whiten should succeed");
        assert_eq!(whitened.len(), n * n_comp);
        for col in 0..n_comp {
            let column: Vec<f64> = whitened.iter().skip(col).step_by(n_comp).copied().collect();
            let mean = column.iter().sum::<f64>() / n as f64;
            let var = column.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>()
                / (n as f64 - 1.0);
            assert!(mean.abs() < 1e-9, "col {col} mean = {mean} should be ~0");
            assert!(
                var.is_finite() && (var - 1.0).abs() < 1e-2,
                "col {col} variance = {var} should be ~1 after whitening"
            );
        }
    }

    #[test]
    fn empty_cluster_reassignment_does_not_crash() {
        let n = 10_usize;
        let d = 2_usize;
        let features = vec![1.0_f64; n * d]; // all points identical: clusters must go empty
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 10,
            kmeans_tol: 0.0, // only stops early when no label changes
            reassign_empty: true,
            seed: 37,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert_eq!(result.labels.len(), n);
    }

    #[test]
    fn error_on_more_clusters_than_samples() {
        let n = 5_usize;
        let d = 2_usize;
        let features = vec![1.0_f64; n * d];
        let config = DeepClusterConfig {
            n_clusters: 10, // more clusters than samples
            n_pca_components: 0,
            kmeans_max_iter: 10,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 41,
        };
        assert!(deep_cluster(&features, n, d, &config).is_err());
    }

    #[test]
    fn error_on_zero_clusters() {
        let result = DeepClusterConfig::new(0, 0, 10, 1e-4, true, 42);
        assert!(result.is_err(), "n_clusters=0 should return an error");
    }

    #[test]
    fn inertia_non_negative_and_finite() {
        let n = 50_usize;
        let d = 3_usize;
        let features: Vec<f64> = (0..n * d).map(|i| ((i as f64) * 0.11).sin()).collect();
        let config = DeepClusterConfig {
            n_clusters: 5,
            n_pca_components: 0,
            kmeans_max_iter: 50,
            kmeans_tol: 1e-4,
            reassign_empty: true,
            seed: 53,
        };
        let result = deep_cluster(&features, n, d, &config).expect("deep_cluster should succeed");
        assert!(result.inertia.is_finite(), "inertia = {}", result.inertia);
        assert!(result.inertia >= 0.0, "inertia = {}", result.inertia);
    }

    #[test]
    fn converged_true_when_stable() {
        let n_per = 20_usize;
        let n = 2 * n_per;
        let d = 2_usize;
        let data = two_cluster_data(n_per);
        let config = DeepClusterConfig {
            n_clusters: 2,
            n_pca_components: 0,
            kmeans_max_iter: 500,
            kmeans_tol: 0.01, // converged once at most 1% of labels change
            reassign_empty: true,
            seed: 61,
        };
        let result = deep_cluster(&data, n, d, &config).expect("deep_cluster should succeed");
        assert!(result.converged, "should have converged");
    }

    #[test]
    fn loss_rejects_out_of_range_label() {
        let n = 4_usize;
        let k = 3_usize;
        let logits = vec![0.0_f32; n * k];
        let labels = vec![0_usize, 1, 2, 3]; // 3 is out of range for k = 3
        assert!(deep_cluster_loss(&logits, &labels, n, k).is_err());
    }

    #[test]
    fn pca_whiten_rejects_invalid_n_components() {
        let n = 10_usize;
        let d = 4_usize;
        let features = vec![1.0_f64; n * d];
        assert!(pca_whiten(&features, n, d, 0, 1e-6).is_err());
        assert!(pca_whiten(&features, n, d, d + 1, 1e-6).is_err());
    }

    #[test]
    fn loss_rejects_single_cluster() {
        let logits = vec![1.0_f32; 4];
        let labels = vec![0_usize; 4];
        assert!(deep_cluster_loss(&logits, &labels, 4, 1).is_err());
    }
}