use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
/// Diagnostics produced by [`aux_richness_metrics`], describing an auxiliary
/// matrix (`n × aux_dim`) and how it relates to a latent matrix (`n × latent_dim`).
#[derive(Debug, Clone)]
pub struct AuxRichnessMetrics {
/// `true` when every aux entry is finite (i.e. `n_nonfinite_aux == 0`).
pub aux_observed: bool,
/// Number of NaN / ±infinity entries in the aux matrix.
pub n_nonfinite_aux: usize,
/// Number of aux columns.
pub aux_dim: usize,
/// Number of latent columns.
pub latent_dim: usize,
/// Row count shared by the aux and latent matrices.
pub n_rows: usize,
/// Indices of aux columns whose population variance is at most `1e-24`;
/// left empty when the aux is not fully finite or has no rows.
pub constant_columns: Vec<usize>,
/// `true` when every aux entry is integral and no column has more than 64
/// distinct values.
pub aux_is_discrete: bool,
/// Number of distinct aux rows when `aux_is_discrete`, otherwise 0.
pub n_distinct_levels: usize,
/// Numerical rank of the least-squares coefficients obtained by regressing the
/// column-centered latents on the column-centered aux; `usize::MAX` when
/// `jacobian_rank_estimated` is `false`.
pub jacobian_rank: usize,
/// Whether `jacobian_rank` was actually computed.
pub jacobian_rank_estimated: bool,
}
pub fn aux_richness_metrics(aux: ArrayView2<f64>, latents: ArrayView2<f64>) -> AuxRichnessMetrics {
let (n, aux_dim) = aux.dim();
let (n_z, latent_dim) = latents.dim();
assert_eq!(n, n_z, "aux and latents must share row count");
let mut n_nonfinite_aux: usize = 0;
for &v in aux.iter() {
if !v.is_finite() {
n_nonfinite_aux += 1;
}
}
let aux_observed = n_nonfinite_aux == 0;
let mut constant_columns: Vec<usize> = Vec::new();
if aux_observed && n >= 1 {
for j in 0..aux_dim {
let col = aux.column(j);
let mean: f64 = col.sum() / n as f64;
let mut var = 0.0_f64;
for &v in col.iter() {
let d = v - mean;
var += d * d;
}
var /= n as f64;
if var <= 1.0e-24 {
constant_columns.push(j);
}
}
}
let (aux_is_discrete, n_distinct_levels) = if aux_observed && n >= 1 {
let mut discrete = true;
for &v in aux.iter() {
if (v - v.round()).abs() > 0.0 {
discrete = false;
break;
}
}
if discrete {
for j in 0..aux_dim {
let col = aux.column(j);
let mut sorted: Vec<f64> = col.iter().copied().collect();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
sorted.dedup_by(|a, b| (*a - *b).abs() < 1.0e-12);
if sorted.len() > 64 {
discrete = false;
break;
}
}
}
if discrete {
let mut keys: Vec<Vec<i64>> = Vec::with_capacity(n);
for i in 0..n {
let mut row = Vec::with_capacity(aux_dim);
for j in 0..aux_dim {
row.push(aux[[i, j]].round() as i64);
}
keys.push(row);
}
keys.sort();
keys.dedup();
(true, keys.len())
} else {
(false, 0)
}
} else {
(false, 0)
};
let need_rows = aux_dim.max(latent_dim) + 1;
let mut jacobian_rank_estimated = false;
let mut jacobian_rank: usize = usize::MAX;
let z_finite = latents.iter().all(|v| v.is_finite());
if aux_observed && z_finite && n >= need_rows && aux_dim >= 1 && latent_dim >= 1 {
let mut a_c = aux.to_owned();
let mut z_c = latents.to_owned();
let a_mean = a_c.mean_axis(Axis(0)).unwrap();
let z_mean = z_c.mean_axis(Axis(0)).unwrap();
for mut row in a_c.rows_mut() {
row -= &a_mean;
}
for mut row in z_c.rows_mut() {
row -= &z_mean;
}
let ata = a_c.t().dot(&a_c);
let atz = a_c.t().dot(&z_c);
let b_hat = pinv_solve(ata.view(), atz.view());
jacobian_rank = matrix_rank(b_hat.view(), 1.0e-8);
jacobian_rank_estimated = true;
}
AuxRichnessMetrics {
aux_observed,
n_nonfinite_aux,
aux_dim,
latent_dim,
n_rows: n,
constant_columns,
aux_is_discrete,
n_distinct_levels,
jacobian_rank,
jacobian_rank_estimated,
}
}
/// Minimum-norm least-squares solution `x` of `a · x = b`, computed from the
/// eigendecomposition of `a`.
///
/// `a` must be square and is treated as symmetric, because it is passed to
/// [`jacobi_symmetric_eigen`]. With `a = V Λ Vᵀ`, the result is `V Λ⁺ Vᵀ b`.
/// `Λ⁺` inverts every eigenvalue whose magnitude exceeds
/// `1e-12 · max(1, max|λ|)` and zeroes the rest. Because of the `max(1, …)`
/// floor, the cutoff is absolute for matrices whose eigenvalues are all below 1.
///
/// # Panics
///
/// Panics if `a` is not square.
fn pinv_solve(a: ArrayView2<f64>, b: ArrayView2<f64>) -> Array2<f64> {
    let (rows, cols) = a.dim();
    assert_eq!(rows, cols, "pinv_solve expects a square normal-equation matrix");
    let (lambdas, basis) = jacobi_symmetric_eigen(a);
    let largest = lambdas.iter().map(|l| l.abs()).fold(0.0_f64, f64::max);
    let cutoff = 1.0e-12 * largest.max(1.0);
    // Express b in the eigenbasis, then scale each coordinate row by the
    // pseudo-inverse of its eigenvalue.
    let mut coords = basis.t().dot(&b);
    for (mut coord_row, &lambda) in coords.rows_mut().into_iter().zip(lambdas.iter()) {
        let inv = if lambda.abs() > cutoff { 1.0 / lambda } else { 0.0 };
        coord_row *= inv;
    }
    basis.dot(&coords)
}
/// Eigendecomposition of a real symmetric matrix by classical (largest-pivot)
/// Jacobi rotations.
///
/// Each step picks the largest off-diagonal entry `(p, q)` of the upper
/// triangle and applies the plane rotation that zeroes it, accumulating the
/// rotations into the eigenvector matrix. Only the upper triangle is searched,
/// so the input is assumed to be symmetric.
///
/// Returns `(eigenvalues, eigenvectors)`. Eigenvalue `i` pairs with column `i`
/// of the eigenvector matrix, and the eigenvalues are not sorted.
///
/// Iteration stops in either of two cases:
/// * the largest off-diagonal magnitude is at most `1e-14` times the largest
///   absolute entry of `a`;
/// * an iteration cap proportional to `n²` is reached.
///
/// # Panics
///
/// Panics if `a` is not square.
fn jacobi_symmetric_eigen(a: ArrayView2<f64>) -> (Vec<f64>, Array2<f64>) {
    let n = a.nrows();
    assert_eq!(n, a.ncols());
    let mut m = a.to_owned();
    let mut v = Array2::<f64>::eye(n);
    // Threshold relative to the input's scale. A fixed absolute 1e-14 made the
    // routine return the untouched diagonal whenever all entries were tiny
    // (e.g. Gram matrices of small-magnitude Jacobians in `matrix_rank`), and
    // made the amount of work depend on the units of the input.
    let scale = a.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    let tol = 1.0e-14 * scale;
    // Classical Jacobi typically needs a small multiple of n(n-1)/2 rotations.
    // A fixed cap of 200 stopped before convergence once n exceeded roughly a
    // dozen.
    let max_iter = 200_usize.max(10 * n * n);
    for _ in 0..max_iter {
        let mut p = 0usize;
        let mut q = 1usize;
        let mut max_off = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                let av = m[[i, j]].abs();
                if av > max_off {
                    max_off = av;
                    p = i;
                    q = j;
                }
            }
        }
        // `<=` so that an exactly diagonal or all-zero matrix (tol == 0)
        // stops here instead of dividing by a zero pivot below.
        if max_off <= tol {
            break;
        }
        let app = m[[p, p]];
        let aqq = m[[q, q]];
        let apq = m[[p, q]];
        let theta = 0.5 * (aqq - app) / apq;
        // Smaller-magnitude root of t² + 2θt - 1 = 0, for numerical stability.
        let t = if theta >= 0.0 {
            1.0 / (theta + (1.0 + theta * theta).sqrt())
        } else {
            1.0 / (theta - (1.0 + theta * theta).sqrt())
        };
        let c = 1.0 / (1.0 + t * t).sqrt();
        let s = t * c;
        let new_pp = app - t * apq;
        let new_qq = aqq + t * apq;
        m[[p, p]] = new_pp;
        m[[q, q]] = new_qq;
        m[[p, q]] = 0.0;
        m[[q, p]] = 0.0;
        for i in 0..n {
            if i != p && i != q {
                let aip = m[[i, p]];
                let aiq = m[[i, q]];
                m[[i, p]] = c * aip - s * aiq;
                m[[p, i]] = m[[i, p]];
                m[[i, q]] = s * aip + c * aiq;
                m[[q, i]] = m[[i, q]];
            }
        }
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }
    let eigvals: Vec<f64> = (0..n).map(|i| m[[i, i]]).collect();
    (eigvals, v)
}
/// Numerical rank of `m`: the number of singular values strictly greater than `tol`.
///
/// Singular values are taken as square roots of the eigenvalues of the Gram
/// matrix `mᵀm`, after clamping negative, roundoff-induced eigenvalues to zero.
/// Forming `mᵀm` squares the condition number, so very small `tol` relative
/// to the largest singular value is unreliable.
fn matrix_rank(m: ArrayView2<f64>, tol: f64) -> usize {
    let (gram_eigenvalues, _) = jacobi_symmetric_eigen(m.t().dot(&m).view());
    gram_eigenvalues
        .iter()
        .filter(|&&lambda| lambda.max(0.0).sqrt() > tol)
        .count()
}
/// Sparsity and per-sample rank summary of vertically stacked Jacobians,
/// produced by [`jacobian_sparsity_metrics`].
#[derive(Debug, Clone)]
pub struct JacobianSparsityMetrics {
/// Number of stacked Jacobians.
pub n_samples: usize,
/// Rows per Jacobian (total rows divided by `n_samples`).
pub p_features: usize,
/// Number of columns of every Jacobian.
pub latent_dim: usize,
/// Fraction of all entries whose magnitude is strictly below
/// `zero_threshold * max_abs`. It is 1.0 when `max_abs` is 0 and 0.0 for an
/// empty input.
pub mean_sparsity: f64,
/// Largest absolute entry across all Jacobians (NaN entries are ignored).
pub max_abs: f64,
/// Numerical rank of each sample's Jacobian, in sample order.
pub ranks: Vec<usize>,
}
/// Summarise sparsity and per-sample rank of a stack of Jacobian matrices.
///
/// `jacobians_flat` holds `n_samples` Jacobians stacked vertically. Sample `s`
/// occupies rows `s * p_features .. (s + 1) * p_features`, where
/// `p_features = nrows / n_samples`. Every Jacobian has `latent_dim` columns.
///
/// An entry is "near zero" when its absolute value is strictly below
/// `zero_threshold * max_abs`. Here `max_abs` is the largest absolute entry
/// over the whole stack; NaN entries do not contribute to it. When
/// `max_abs == 0`, every entry counts as near zero. `mean_sparsity` is the
/// near-zero fraction, or 0 for an empty input.
///
/// Each sample's rank is the number of singular values above that same cutoff,
/// floored at `1e-300`.
///
/// # Panics
///
/// Panics if `n_samples` is zero or does not divide the number of rows.
pub fn jacobian_sparsity_metrics(
    jacobians_flat: ArrayView2<f64>,
    n_samples: usize,
    zero_threshold: f64,
) -> JacobianSparsityMetrics {
    // Checked explicitly: `np_rows % 0` would otherwise abort with an opaque
    // "remainder with a divisor of zero" arithmetic panic.
    assert!(n_samples > 0, "n_samples must be positive");
    let (np_rows, latent_dim) = jacobians_flat.dim();
    assert!(np_rows % n_samples == 0, "rows not divisible by n_samples");
    let p_features = np_rows / n_samples;

    let max_abs = jacobians_flat
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let cutoff = zero_threshold * max_abs;
    let total_entries = np_rows * latent_dim;
    let total_near_zero = if max_abs > 0.0 {
        jacobians_flat.iter().filter(|v| v.abs() < cutoff).count()
    } else {
        // With a zero cutoff, `v.abs() < cutoff` would count nothing, yet an
        // all-zero stack is entirely sparse.
        total_entries
    };
    let mean_sparsity = if total_entries > 0 {
        total_near_zero as f64 / total_entries as f64
    } else {
        0.0
    };

    // Keep the rank tolerance strictly positive even when the cutoff is zero.
    let rank_tol = cutoff.max(1.0e-300);
    let ranks = (0..n_samples)
        .map(|s| {
            let start = s * p_features;
            let block = jacobians_flat.slice(ndarray::s![start..start + p_features, ..]);
            matrix_rank(block, rank_tol)
        })
        .collect();

    JacobianSparsityMetrics {
        n_samples,
        p_features,
        latent_dim,
        mean_sparsity,
        max_abs,
        ranks,
    }
}
/// Anchor counts for a soft-assignment matrix, produced by
/// [`anchor_consistency_metrics`].
#[derive(Debug, Clone)]
pub struct AnchorConsistencyMetrics {
/// Number of rows in the assignment matrix.
pub n_rows: usize,
/// Number of atoms (columns) in the assignment matrix.
pub n_atoms: usize,
/// Number of rows whose largest absolute entry carries at least
/// `anchor_dominance` of the row's total absolute mass (mass must be positive).
pub n_anchors: usize,
/// Per-atom count of anchor rows, each credited to the column of its
/// dominant entry (the first such column on ties).
pub anchors_per_atom: Vec<usize>,
}
/// Count "anchor" rows in a soft-assignment matrix.
///
/// `assignments` is read as `n` rows by `k` atoms. A row is an anchor when two
/// conditions hold:
/// * its total absolute mass is positive;
/// * its largest absolute entry is at least `anchor_dominance` times that mass.
///
/// Each anchor is credited to the atom holding that largest entry, taking the
/// first such column on ties. Rows of all zeros are never anchors.
pub fn anchor_consistency_metrics(
    assignments: ArrayView2<f64>,
    anchor_dominance: f64,
) -> AnchorConsistencyMetrics {
    let (n_rows, n_atoms) = assignments.dim();
    let mut anchors_per_atom = vec![0_usize; n_atoms];
    for row in assignments.outer_iter() {
        // Single pass: accumulate L1 mass and track the first strict maximum.
        let (mass, peak, peak_atom) = row.iter().enumerate().fold(
            (0.0_f64, 0.0_f64, 0_usize),
            |(mass, peak, peak_atom), (j, &x)| {
                let a = x.abs();
                if a > peak {
                    (mass + a, a, j)
                } else {
                    (mass + a, peak, peak_atom)
                }
            },
        );
        let is_anchor = mass > 0.0 && peak / mass >= anchor_dominance;
        if is_anchor {
            anchors_per_atom[peak_atom] += 1;
        }
    }
    // Every anchor increments exactly one atom, so the total is the sum.
    let n_anchors: usize = anchors_per_atom.iter().sum();
    AnchorConsistencyMetrics {
        n_rows,
        n_atoms,
        n_anchors,
        anchors_per_atom,
    }
}
/// Stack decoder blocks into a single `p × total_k` matrix.
///
/// Every block is `k_i × p` and all blocks must share `p`. The rows of all
/// blocks, in order, become the columns of the output. The result is the
/// transpose of stacking the blocks vertically, and `total_k` is the sum of
/// the `k_i`.
///
/// # Errors
///
/// Returns `Err` if `blocks` is empty, or if some block's column count differs
/// from the first block's (the first offending index is reported).
pub fn concat_decoder_blocks(blocks: &[ArrayView2<f64>]) -> Result<Array2<f64>, String> {
    let p = match blocks.first() {
        Some(first) => first.ncols(),
        None => return Err("concat_decoder_blocks: empty block list".into()),
    };
    if let Some((i, bad)) = blocks.iter().enumerate().find(|(_, b)| b.ncols() != p) {
        return Err(format!(
            "concat_decoder_blocks: block {} has {} cols, expected {}",
            i,
            bad.ncols(),
            p
        ));
    }
    let total_k = blocks.iter().map(|b| b.nrows()).sum::<usize>();
    let mut stacked = Array2::<f64>::zeros((p, total_k));
    let component_rows = blocks.iter().flat_map(|b| b.outer_iter());
    for (col, component) in component_rows.enumerate() {
        stacked.column_mut(col).assign(&component);
    }
    Ok(stacked)
}
pub fn jacobian_view_from_linear(decoder_pk: ArrayView2<f64>) -> Array2<f64> {
decoder_pk.to_owned()
}
/// Count the entries of `v` that are NaN or ±infinity.
pub fn count_nonfinite_1d(v: ArrayView1<f64>) -> usize {
    v.iter()
        .fold(0_usize, |count, x| if x.is_finite() { count } else { count + 1 })
}
// Unit tests for the aux-richness, Jacobian-sparsity and anchor metrics.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // A 3x3 integer grid of aux values whose latents closely track it should be
    // reported as observed, discrete, non-constant and of full rank.
    #[test]
    fn aux_richness_passes_on_rich_2d_aux() {
        let aux = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [2.0, 0.0],
            [2.0, 1.0],
            [0.0, 2.0],
            [1.0, 2.0],
            [2.0, 2.0],
        ];
        let lat = array![
            [0.10, 0.05],
            [0.02, 1.01],
            [1.05, 0.04],
            [1.01, 1.02],
            [2.03, 0.07],
            [2.04, 1.01],
            [0.05, 2.02],
            [1.02, 2.01],
            [2.01, 2.05],
        ];
        let report = aux_richness_metrics(aux.view(), lat.view());
        assert!(report.aux_observed);
        assert_eq!((report.aux_dim, report.latent_dim), (2, 2));
        assert!(report.constant_columns.is_empty());
        assert!(report.aux_is_discrete);
        assert!(report.n_distinct_levels >= 3);
        assert!(report.jacobian_rank_estimated);
        assert_eq!(report.jacobian_rank, 2);
    }

    // An all-zero aux column must be listed as constant.
    #[test]
    fn aux_richness_flags_constant_aux() {
        let aux = Array2::<f64>::zeros((20, 1));
        let lat = Array2::from_shape_fn((20, 2), |(i, col)| {
            let t = i as f64;
            if col == 0 {
                t
            } else {
                t.cos()
            }
        });
        let report = aux_richness_metrics(aux.view(), lat.view());
        assert_eq!(report.aux_dim, 1);
        assert_eq!(report.latent_dim, 2);
        assert_eq!(report.constant_columns, vec![0_usize]);
    }

    // A single NaN makes the aux unobserved and is counted.
    #[test]
    fn aux_richness_flags_nonfinite_aux() {
        let mut aux = Array2::<f64>::zeros((10, 1));
        aux[[3, 0]] = f64::NAN;
        let lat = Array2::<f64>::zeros((10, 1));
        let report = aux_richness_metrics(aux.view(), lat.view());
        assert!(!report.aux_observed);
        assert_eq!(report.n_nonfinite_aux, 1);
    }

    // A 4x3 identity-like Jacobian is mostly zeros and has full column rank.
    #[test]
    fn jacobian_sparsity_passes_on_diagonal() {
        let jac = array![
            [1.0_f64, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0]
        ];
        let report = jacobian_sparsity_metrics(jac.view(), 1, 1.0e-3);
        assert_eq!(report.p_features, 4);
        assert_eq!(report.latent_dim, 3);
        assert!(report.mean_sparsity > 0.5);
        assert_eq!(report.ranks, vec![3_usize]);
    }

    // A Jacobian with every entry near 1 has essentially no sparsity.
    #[test]
    fn jacobian_sparsity_dense_has_low_sparsity() {
        let jac = Array2::from_shape_fn((4, 3), |(i, k)| 1.0 + 0.1 * (i + k) as f64);
        let report = jacobian_sparsity_metrics(jac.view(), 1, 1.0e-3);
        assert!(report.mean_sparsity < 0.1);
    }

    // Three blocks of three rows, each dominated by a different atom.
    #[test]
    fn anchor_consistency_three_clusters() {
        let assignments =
            Array2::from_shape_fn((9, 3), |(i, j)| if j == i / 3 { 1.0 } else { 0.01 });
        let report = anchor_consistency_metrics(assignments.view(), 0.95);
        assert_eq!(report.n_atoms, 3);
        assert_eq!(report.n_anchors, 9);
        assert_eq!(report.anchors_per_atom, vec![3, 3, 3]);
    }

    // Uniform soft assignments never reach the dominance threshold.
    #[test]
    fn anchor_consistency_uniform_has_zero_anchors() {
        let assignments = Array2::<f64>::from_elem((10, 4), 0.25);
        let report = anchor_consistency_metrics(assignments.view(), 0.95);
        assert_eq!(report.n_anchors, 0);
        assert_eq!(report.anchors_per_atom, vec![0, 0, 0, 0]);
    }
}