use crate::error::{SslError, SslResult};
/// Rescales each of the first `n` rows of `z` (row-major, `d` values per row)
/// to unit Euclidean length, in place.
///
/// Rows whose norm is not greater than `1e-12` (including all-zero rows and rows
/// whose norm is NaN) are left exactly as they are, so they never turn into NaN
/// through a division by zero. Values beyond index `n * d` are not touched.
///
/// # Panics
///
/// Panics if `n * d > z.len()`.
fn l2_normalise_rows(z: &mut [f32], n: usize, d: usize) {
    if d == 0 {
        // Zero-width rows have nothing to scale (and `chunks_exact_mut(0)` would panic).
        return;
    }
    for row in z[..n * d].chunks_exact_mut(d) {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        // Positive comparison on purpose: a NaN norm must skip the rescale.
        if norm > 1e-12 {
            row.iter_mut().for_each(|v| *v /= norm);
        }
    }
}
/// Inner product of `a` and `b`, accumulated in `f32`.
///
/// The slices are paired element-wise; if their lengths differ, the surplus
/// tail of the longer slice is ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&x, &y)| x * y);
    products.sum::<f32>()
}
/// Uniformity loss with the conventional kernel parameter `t = 2`.
///
/// Equivalent to [`uniformity_loss_with_t`]`(z, n, d, 2.0)`; see that function
/// for the definition and edge-case behaviour.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n < 2` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z.len() != n * d`.
pub fn uniformity_loss(z: &[f32], n: usize, d: usize) -> SslResult<f32> {
    uniformity_loss_with_t(z, n, d, 2.0)
}

/// Uniformity loss with a configurable Gaussian-kernel parameter `t`.
///
/// Each of the `n` rows of `z` (row-major, `d` values per row) is L2-normalised.
/// The result is
/// `ln( mean over all unordered pairs i < j of exp(-t * ||ẑ_i - ẑ_j||²) )`.
/// The squared distance is obtained from the cosine via `2 * (1 - cos)`, with
/// the cosine clamped to `[-1, 1]`. The value is therefore `<= 0`: it is `0`
/// when all rows point the same way and more negative the more spread out they are.
///
/// Rows with norm `<= 1e-12` are not rescaled by the normalisation step; a zero
/// row has cosine 0 with every other row and so counts as squared distance 2.
///
/// The mean is reduced with a streaming log-sum-exp. This keeps the result finite
/// for large `t`, where the plain sum of kernel values would underflow to 0.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n < 2` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z.len() != n * d`.
/// * [`SslError::InvalidParameter`] if `t` is not finite or is `<= 0`.
pub fn uniformity_loss_with_t(z: &[f32], n: usize, d: usize, t: f32) -> SslResult<f32> {
    if n < 2 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if z.len() != n * d {
        return Err(SslError::DimensionMismatch {
            expected: n * d,
            got: z.len(),
        });
    }
    if !t.is_finite() || t <= 0.0 {
        return Err(SslError::InvalidParameter {
            name: "t".into(),
            reason: "must be finite and > 0".into(),
        });
    }

    let mut z_hat = z.to_vec();
    l2_normalise_rows(&mut z_hat, n, d);
    let t = f64::from(t);

    // Streaming log-sum-exp over exponents e = -t * dist²: track the running
    // maximum exponent and the sum of exp(e - max).
    let mut max_exp = f64::NEG_INFINITY;
    let mut scaled_sum = 0.0_f64;
    for i in 0..n {
        let zi = &z_hat[i * d..(i + 1) * d];
        for j in (i + 1)..n {
            let zj = &z_hat[j * d..(j + 1) * d];
            let cos = f64::from(dot(zi, zj).clamp(-1.0, 1.0));
            // For unit vectors ||a - b||² = 2 (1 - a·b).
            let dist_sq = 2.0 * (1.0 - cos);
            let e = -t * dist_sq;
            if e > max_exp {
                // Rescale the accumulated sum to the new maximum.
                scaled_sum = scaled_sum * (max_exp - e).exp() + 1.0;
                max_exp = e;
            } else {
                scaled_sum += (e - max_exp).exp();
            }
        }
    }

    let num_pairs = (n * (n - 1) / 2) as f64;
    Ok((max_exp + scaled_sum.ln() - num_pairs.ln()) as f32)
}
/// Alignment loss: the mean over the `n` positive pairs of `||ẑ1_i - ẑ2_i||^alpha`.
///
/// Row `i` of `z1` is paired with row `i` of `z2`. Both are row-major with
/// `d` values per row and are L2-normalised before comparison.
///
/// The squared distance is computed directly from the element-wise difference,
/// in `f64`. The `2 * (1 - cos)` identity is avoided on purpose: for
/// near-identical rows the f32 cosine rounds to just below 1. That leaves a
/// spurious squared distance around 1e-7, which `powf(alpha / 2)` amplifies
/// badly for `alpha < 2` (≈3.5e-4 for `alpha = 1`). With the difference
/// form, identical rows yield exactly 0.
///
/// Rows with norm `<= 1e-12` are not rescaled by the normalisation step, so their
/// distance is measured from the raw (near-zero) vector.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z1` or `z2` does not hold `n * d` values
///   (`z1` is checked first).
/// * [`SslError::InvalidParameter`] if `alpha` is not finite or is `<= 0`.
pub fn alignment_loss(z1: &[f32], z2: &[f32], n: usize, d: usize, alpha: f32) -> SslResult<f32> {
    if n == 0 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if z1.len() != n * d {
        return Err(SslError::DimensionMismatch {
            expected: n * d,
            got: z1.len(),
        });
    }
    if z2.len() != n * d {
        return Err(SslError::DimensionMismatch {
            expected: n * d,
            got: z2.len(),
        });
    }
    if !alpha.is_finite() || alpha <= 0.0 {
        return Err(SslError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be finite and > 0".into(),
        });
    }

    let mut z1_hat = z1.to_vec();
    let mut z2_hat = z2.to_vec();
    l2_normalise_rows(&mut z1_hat, n, d);
    l2_normalise_rows(&mut z2_hat, n, d);

    // ||x||^alpha == (||x||²)^(alpha / 2)
    let half_alpha = f64::from(alpha) / 2.0;
    let total: f64 = z1_hat
        .chunks_exact(d)
        .zip(z2_hat.chunks_exact(d))
        .map(|(a, b)| {
            let dist_sq: f64 = a
                .iter()
                .zip(b)
                .map(|(&x, &y)| {
                    let diff = f64::from(x) - f64::from(y);
                    diff * diff
                })
                .sum();
            dist_sq.powf(half_alpha)
        })
        .sum();

    Ok((total / n as f64) as f32)
}
/// Variance-entropy "effective rank" of a batch of embeddings.
///
/// The `n` rows of `z` (row-major, `d` values per row) are L2-normalised. The
/// population variance of every column is then computed, and the variances are
/// turned into a probability distribution `p_j = var_j / Σ var`. The result is
/// `exp(H(p))`, where `H` is the Shannon entropy in nats; columns with zero
/// variance contribute nothing. The value lies in `[1, d]`. It equals `d` when all
/// columns carry equal variance and approaches 1 when the variance sits in one column.
///
/// Only the per-column variances (the diagonal of the covariance) are used, so
/// correlated columns are counted separately. This differs from
/// singular-value-based definitions of effective rank.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n < 2` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z.len() != n * d`.
/// * [`SslError::InvalidParameter`] if the total column variance is below
///   `1e-30` (e.g. every row normalises to the same vector).
pub fn effective_rank(z: &[f32], n: usize, d: usize) -> SslResult<f32> {
    if n < 2 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if z.len() != n * d {
        return Err(SslError::DimensionMismatch {
            expected: n * d,
            got: z.len(),
        });
    }

    let mut unit_rows = z.to_vec();
    l2_normalise_rows(&mut unit_rows, n, d);
    let count = n as f64;

    // Per-column mean, accumulated row by row.
    let mut means = vec![0.0_f64; d];
    for row in unit_rows.chunks_exact(d) {
        for (acc, &x) in means.iter_mut().zip(row) {
            *acc += f64::from(x);
        }
    }
    means.iter_mut().for_each(|m| *m /= count);

    // Per-column population variance (divided by n, not n - 1).
    let mut variances = vec![0.0_f64; d];
    for row in unit_rows.chunks_exact(d) {
        for ((acc, &x), &mu) in variances.iter_mut().zip(row).zip(&means) {
            let centred = f64::from(x) - mu;
            *acc += centred * centred;
        }
    }
    variances.iter_mut().for_each(|v| *v /= count);

    let total_var = variances.iter().sum::<f64>();
    if total_var < 1e-30 {
        return Err(SslError::InvalidParameter {
            name: "z".into(),
            reason: "total column variance is zero; features appear to be all-zero".into(),
        });
    }

    // Shannon entropy of the variance distribution; zero-variance columns skipped.
    let entropy = variances
        .iter()
        .filter(|&&v| v > 0.0)
        .fold(0.0_f64, |h, &v| {
            let p = v / total_var;
            h - p * p.ln()
        });
    Ok(entropy.exp() as f32)
}
/// Combined collapse diagnostic: `alignment_loss(z1, z2, n, d, 2.0) - uniformity_loss(z1, n, d)`.
///
/// Uniformity is measured on `z1` only. When `z1 == z2` and every row points in
/// the same direction, both terms are 0 and so is the score. When identical
/// pairs are spread out, uniformity becomes negative and the score rises above 0.
///
/// NOTE(review): the score also increases with worse (larger) alignment loss,
/// so it mixes two effects with the same sign; confirm this is the intended
/// reading before using it as a single "higher is better/worse" metric.
///
/// # Errors
///
/// Errors from [`uniformity_loss`] are reported first (e.g. `n < 2`), then
/// errors from [`alignment_loss`].
pub fn collapse_score(z1: &[f32], z2: &[f32], n: usize, d: usize) -> SslResult<f32> {
    let uniformity = uniformity_loss(z1, n, d)?;
    alignment_loss(z1, z2, n, d, 2.0).map(|alignment| alignment - uniformity)
}
/// Summary statistics of the cosine similarity between every unordered pair of
/// distinct rows.
///
/// The `n` rows of `z` (row-major, `d` values per row) are L2-normalised, and
/// each pairwise dot product is clamped to `[-1, 1]`. Returns
/// `(mean, population_std, max)` over all `n * (n - 1) / 2` pairs. Rows with norm
/// `<= 1e-12` are not rescaled by the normalisation step; an all-zero row has
/// cosine 0 with every other row.
///
/// Mean and variance are accumulated with Welford's online algorithm in `f64`.
/// Extra memory is therefore O(n·d) for the normalised copy, rather than the
/// O(n²) needed to buffer every cosine.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n < 2` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z.len() != n * d`.
pub fn pairwise_cosine_stats(z: &[f32], n: usize, d: usize) -> SslResult<(f32, f32, f32)> {
    if n < 2 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if z.len() != n * d {
        return Err(SslError::DimensionMismatch {
            expected: n * d,
            got: z.len(),
        });
    }

    let mut z_hat = z.to_vec();
    l2_normalise_rows(&mut z_hat, n, d);

    let mut count = 0.0_f64;
    let mut mean = 0.0_f64;
    // Running sum of squared deviations from the current mean (Welford's M2).
    let mut m2 = 0.0_f64;
    let mut max_cos = f32::NEG_INFINITY;
    for i in 0..n {
        let zi = &z_hat[i * d..(i + 1) * d];
        for j in (i + 1)..n {
            let zj = &z_hat[j * d..(j + 1) * d];
            let c = dot(zi, zj).clamp(-1.0, 1.0);
            if c > max_cos {
                max_cos = c;
            }
            count += 1.0;
            let x = f64::from(c);
            let delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean);
        }
    }

    // n >= 2 guarantees at least one pair, so `count >= 1`.
    let std = (m2 / count).sqrt();
    Ok((mean as f32, std as f32, max_cos))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major `d × d` identity: row `i` is the standard basis vector `e_i`.
    fn basis(d: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; d * d];
        for i in 0..d {
            v[i * d + i] = 1.0;
        }
        v
    }

    /// `n` copies of `unit` stacked into a row-major `n × unit.len()` matrix.
    fn all_same(unit: &[f32], n: usize) -> Vec<f32> {
        let d = unit.len();
        let mut v = Vec::with_capacity(n * d);
        for _ in 0..n {
            v.extend_from_slice(unit);
        }
        v
    }

    #[test]
    fn uniformity_perfectly_uniform_sphere() {
        // Four points evenly spaced on the unit circle.
        let z = vec![
            1.0_f32, 0.0, // (1, 0)
            -1.0, 0.0, // (-1, 0)
            0.0, 1.0, // (0, 1)
            0.0, -1.0, // (0, -1)
        ];
        let l = uniformity_loss(&z, 4, 2).expect("uniformity_loss should succeed");
        assert!(l.is_finite(), "l = {l}");
        assert!(l < 0.0, "expected negative uniformity, got {l}");
    }

    #[test]
    fn uniformity_all_same_point_high() {
        // Every pair coincides, so each kernel term is exp(0) = 1 and ln(1) = 0,
        // the largest value the loss can take.
        let z = all_same(&[1.0_f32, 0.0, 0.0], 6);
        let l = uniformity_loss(&z, 6, 3).expect("uniformity_loss should succeed");
        assert!(l.abs() < 1e-5, "expected ≈ 0, got {l}");
    }

    #[test]
    fn uniformity_two_orthogonal_points() {
        // Orthogonal unit vectors are at squared distance 2: ln(exp(-2 * 2)) = -4.
        let z = vec![1.0_f32, 0.0, 0.0, 1.0];
        let l = uniformity_loss(&z, 2, 2).expect("uniformity_loss should succeed");
        assert!((l - (-4.0)).abs() < 1e-4, "expected -4, got {l}");
    }

    #[test]
    fn alignment_identical_pairs_zero_loss() {
        let z: Vec<f32> = (0..16).map(|i| (i as f32) * 0.3 + 0.1).collect();
        let l = alignment_loss(&z, &z, 4, 4, 2.0).expect("alignment_loss should succeed");
        assert!(l.abs() < 1e-5, "expected 0, got {l}");
    }

    #[test]
    fn alignment_orthogonal_pairs_max_loss() {
        // ||e_0 - e_1||² = 2.
        let z1 = vec![1.0_f32, 0.0];
        let z2 = vec![0.0_f32, 1.0];
        let l = alignment_loss(&z1, &z2, 1, 2, 2.0).expect("alignment_loss should succeed");
        assert!((l - 2.0).abs() < 1e-5, "expected 2, got {l}");
    }

    #[test]
    fn alignment_n1_works() {
        let z1 = vec![1.0_f32, 0.0, 0.0];
        let z2 = vec![0.0_f32, 1.0, 0.0];
        let l = alignment_loss(&z1, &z2, 1, 3, 2.0).expect("alignment_loss should succeed");
        assert!(l.is_finite() && l >= 0.0, "l = {l}");
    }

    #[test]
    fn alignment_alpha_one_finite() {
        // ||e_0 - e_1||^1 = sqrt(2).
        let z1 = vec![1.0_f32, 0.0];
        let z2 = vec![0.0_f32, 1.0];
        let l = alignment_loss(&z1, &z2, 1, 2, 1.0).expect("alignment_loss should succeed");
        assert!((l - 2.0_f32.sqrt()).abs() < 1e-4, "l = {l}");
    }

    #[test]
    fn alignment_invalid_alpha_returns_error() {
        let z1 = vec![1.0_f32, 0.0];
        let z2 = vec![0.0_f32, 1.0];
        let r = alignment_loss(&z1, &z2, 1, 2, 0.0);
        assert!(
            matches!(r, Err(SslError::InvalidParameter { .. })),
            "expected InvalidParameter, got {r:?}"
        );
    }

    #[test]
    fn effective_rank_uniform_full_rank() {
        // Every column of the identity carries the same variance, so the variance
        // distribution is uniform over all `d` columns.
        let d = 8_usize;
        let z = basis(d);
        let er = effective_rank(&z, d, d).expect("effective_rank should succeed");
        assert!((er - d as f32).abs() < 0.5, "expected ≈ {d}, got {er}");
    }

    #[test]
    fn effective_rank_rank1_collapsed() {
        let d = 8_usize;
        let n = 16_usize;

        // Positive multiples of e_0 all normalise to e_0, leaving zero variance.
        let mut z_collapsed = vec![0.0_f32; n * d];
        for i in 0..n {
            z_collapsed[i * d] = 1.0 + (i as f32) * 0.1;
        }
        assert!(
            matches!(
                effective_rank(&z_collapsed, n, d),
                Err(SslError::InvalidParameter { .. })
            ),
            "expected InvalidParameter for completely collapsed input"
        );

        // Rows close to e_0 with a small, varying second coordinate: the variance
        // is concentrated in very few columns.
        let mut z_near_rank1 = vec![0.0_f32; n * d];
        for i in 0..n {
            z_near_rank1[i * d] = 1.0;
            z_near_rank1[i * d + 1] = (i as f32) * 0.001;
        }
        let er = effective_rank(&z_near_rank1, n, d).expect("effective_rank should succeed");
        assert!(er < 2.0, "expected eff_rank ≈ 1, got {er}");
        assert!(er >= 1.0, "eff_rank must be >= 1, got {er}");
    }

    #[test]
    fn effective_rank_in_range() {
        let d = 4_usize;
        let n = 8_usize;
        // Deterministic LCG. Taking the top 32 bits and dividing by u32::MAX spans
        // [0, 1], so samples cover [-1, 1]; a 33-bit shift would keep only 31 bits
        // and confine every sample to [-1, 0).
        let mut state: u64 = 0xdead_beef_cafe;
        let z: Vec<f32> = (0..n * d)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
            })
            .collect();
        let er = effective_rank(&z, n, d).expect("effective_rank should succeed");
        assert!(er >= 1.0 - 1e-4, "eff_rank below 1: {er}");
        assert!(er <= d as f32 + 1e-4, "eff_rank above d: {er}");
    }

    #[test]
    fn pairwise_cosine_orthogonal_basis() {
        // Distinct basis vectors are orthogonal: every pairwise cosine is 0.
        let d = 4_usize;
        let z = basis(d);
        let (mean, std, max) =
            pairwise_cosine_stats(&z, d, d).expect("pairwise_cosine_stats should succeed");
        assert!(mean.abs() < 1e-5, "mean = {mean}");
        assert!(std.abs() < 1e-5, "std = {std}");
        assert!(max.abs() < 1e-5, "max = {max}");
    }

    #[test]
    fn pairwise_cosine_same_direction() {
        // Identical rows: every pairwise cosine is 1.
        let n = 5_usize;
        let unit = [1.0_f32, 0.0, 0.0];
        let z = all_same(&unit, n);
        let (mean, std, max) =
            pairwise_cosine_stats(&z, n, 3).expect("pairwise_cosine_stats should succeed");
        assert!((mean - 1.0).abs() < 1e-5, "mean = {mean}");
        assert!(std.abs() < 1e-5, "std = {std}");
        assert!((max - 1.0).abs() < 1e-5, "max = {max}");
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(
            matches!(uniformity_loss(&[], 0, 4), Err(SslError::EmptyInput)),
            "uniformity n=0"
        );
        assert!(
            matches!(
                uniformity_loss(&[1.0, 0.0, 0.0, 0.0], 1, 4),
                Err(SslError::EmptyInput)
            ),
            "uniformity n=1"
        );
        assert!(
            matches!(
                alignment_loss(&[], &[], 0, 4, 2.0),
                Err(SslError::EmptyInput)
            ),
            "alignment n=0"
        );
        assert!(
            matches!(effective_rank(&[], 0, 4), Err(SslError::EmptyInput)),
            "eff_rank n=0"
        );
        assert!(
            matches!(
                effective_rank(&[1.0, 0.0, 0.0, 0.0], 1, 4),
                Err(SslError::EmptyInput)
            ),
            "eff_rank n=1"
        );
        assert!(
            matches!(pairwise_cosine_stats(&[], 0, 4), Err(SslError::EmptyInput)),
            "pairwise_cosine n=0"
        );
        assert!(
            matches!(
                pairwise_cosine_stats(&[1.0, 0.0, 0.0, 0.0], 1, 4),
                Err(SslError::EmptyInput)
            ),
            "pairwise_cosine n=1"
        );
    }

    #[test]
    fn dimension_mismatch_returns_error() {
        let z = vec![1.0_f32; 4];
        assert!(
            matches!(
                uniformity_loss(&z, 2, 3),
                Err(SslError::DimensionMismatch {
                    expected: 6,
                    got: 4
                })
            ),
            "uniformity mismatch"
        );
        assert!(
            matches!(
                effective_rank(&z, 2, 3),
                Err(SslError::DimensionMismatch {
                    expected: 6,
                    got: 4
                })
            ),
            "eff_rank mismatch"
        );
        assert!(
            matches!(
                pairwise_cosine_stats(&z, 2, 3),
                Err(SslError::DimensionMismatch {
                    expected: 6,
                    got: 4
                })
            ),
            "pairwise_cosine mismatch"
        );
        let z_ok = vec![1.0_f32; 6];
        assert!(
            matches!(
                alignment_loss(&z, &z_ok, 2, 3, 2.0),
                Err(SslError::DimensionMismatch {
                    expected: 6,
                    got: 4
                })
            ),
            "alignment z1 mismatch"
        );
        assert!(
            matches!(
                alignment_loss(&z_ok, &z, 2, 3, 2.0),
                Err(SslError::DimensionMismatch {
                    expected: 6,
                    got: 4
                })
            ),
            "alignment z2 mismatch"
        );
    }

    #[test]
    fn collapse_score_identical_pairs_finite() {
        // Identical pairs give zero alignment loss; orthogonal rows give uniformity
        // -4, so the score is about 4.
        let d = 4_usize;
        let n = 4_usize;
        let z = basis(d);
        let score = collapse_score(&z, &z, n, d).expect("collapse_score should succeed");
        assert!(score.is_finite(), "score = {score}");
        assert!(
            score >= 0.0,
            "score should be >= 0 for identical pairs, got {score}"
        );
    }
}