Skip to main content

gam_identifiability/
kernel.rs

1//! Numeric kernels for identifiability-theorem diagnostics.
2//!
3//! The kernels return scalar facts for iVAE auxiliary richness, decoder
4//! Jacobian sparsity, and manifold-SAE anchor coverage. Rust, Python, and CLI
5//! layers turn those facts into user-facing reports.
6
7use ndarray::{Array2, ArrayView2, Axis};
8
/// Safety cap on the outer iterations of `jacobi_symmetric_eigen`.
///
/// The matrices handed to the eigensolver here are tiny (normal-equation and
/// Gram blocks, typically well under 64×64 in identifiability checks) and
/// Jacobi converges quadratically once the off-diagonal mass is small, so the
/// `JACOBI_OFFDIAG_TOL` break is meant to fire long before this cap. The cap
/// exists only to bound pathological, non-converging inputs.
const JACOBI_MAX_SWEEPS: usize = 200;

/// Convergence threshold on the largest off-diagonal magnitude in
/// `jacobi_symmetric_eigen`.
///
/// `1e-14` sits roughly two orders of magnitude above f64 unit roundoff
/// (~1.1e-16): tight enough that leftover off-diagonal mass cannot sway the
/// rank and pseudo-inverse cut-offs these diagnostics apply.
const JACOBI_OFFDIAG_TOL: f64 = 1e-14;
23
/// Largest number of distinct values a single aux column may take while the
/// aux matrix still counts as discrete.
///
/// Integer-valued columns with at most this many levels are treated as
/// categorical covariates, which is the setting the iVAE auxiliary-richness
/// theorem is stated for. A single column with more levels marks the whole
/// aux matrix as continuous.
const AUX_DISCRETE_MAX_LEVELS: usize = 64;

/// Gap under which two adjacent sorted aux values collapse into one level
/// during per-column level counting.
///
/// Level counting only runs after every aux entry has passed an exact
/// integrality test in `aux_richness_metrics`, so genuinely distinct levels
/// are at least 1 apart. This tolerance is a defensive guard, not a
/// meaningful merge radius.
const AUX_LEVEL_DEDUP_TOL: f64 = 1e-12;
35
/// Numeric facts about the auxiliary covariates and latents fed to an iVAE,
/// as produced by [`aux_richness_metrics`].
#[derive(Debug, Clone)]
pub struct AuxRichnessMetrics {
    /// Whether the aux matrix is free of NaN / ±inf entries.
    pub aux_observed: bool,
    /// How many aux entries are NaN or infinite.
    pub n_nonfinite_aux: usize,
    /// Number of aux columns.
    pub aux_dim: usize,
    /// Number of latent columns.
    pub latent_dim: usize,
    /// Shared row count `N` of the aux and latent matrices.
    pub n_rows: usize,
    /// Ascending indices of aux columns judged constant across all rows.
    /// Left empty when aux contains non-finite entries or has no rows.
    pub constant_columns: Vec<usize>,
    /// Whether every aux entry is integer-valued and each column has at most
    /// `AUX_DISCRETE_MAX_LEVELS` distinct values. `false` when aux contains
    /// non-finite entries or has no rows.
    pub aux_is_discrete: bool,
    /// Number of distinct aux rows. Populated only when `aux_is_discrete`;
    /// otherwise `0`.
    pub n_distinct_levels: usize,
    /// Numeric rank of the centred least-squares slope `B = (AᵀA)⁺ AᵀZ`.
    /// Holds the sentinel `usize::MAX` whenever `jacobian_rank_estimated`
    /// is `false`.
    pub jacobian_rank: usize,
    /// Whether the inputs allowed `jacobian_rank` to be estimated. This needs
    /// finite aux and latents, non-zero dimensions, and more rows than either
    /// dimension.
    pub jacobian_rank_estimated: bool,
}
61
62/// Compute the iVAE auxiliary-richness numeric facts.
63///
64/// `aux` is `(N, aux_dim)`; `latents` is `(N, latent_dim)`. The empirical
65/// Jacobian is the linear-regression slope ``B`` of ``Z ~ A`` (centred). For
66/// a non-linear iVAE encoder this is a first-order surrogate; a deficient
67/// rank here forecloses identifiability regardless of nonlinear postproc.
68pub fn aux_richness_metrics(aux: ArrayView2<f64>, latents: ArrayView2<f64>) -> AuxRichnessMetrics {
69    let (n, aux_dim) = aux.dim();
70    let (n_z, latent_dim) = latents.dim();
71    assert_eq!(n, n_z, "aux and latents must share row count");
72
73    // 1. Finiteness.
74    let mut n_nonfinite_aux: usize = 0;
75    for &v in aux.iter() {
76        if !v.is_finite() {
77            n_nonfinite_aux += 1;
78        }
79    }
80    let aux_observed = n_nonfinite_aux == 0;
81
82    // 2. Constant columns. Skip non-finite columns entirely (they will be
83    //    flagged by `aux_observed=false`).
84    let mut constant_columns: Vec<usize> = Vec::new();
85    if aux_observed && n >= 1 {
86        for j in 0..aux_dim {
87            let col = aux.column(j);
88            // sample std (population formula — exact zero iff constant).
89            let mean: f64 = col.sum() / n as f64;
90            let mut var = 0.0_f64;
91            for &v in col.iter() {
92                let d = v - mean;
93                var += d * d;
94            }
95            var /= n as f64;
96            if var <= 1.0e-24 {
97                constant_columns.push(j);
98            }
99        }
100    }
101
102    // 3. Discreteness + distinct level count.
103    let (aux_is_discrete, n_distinct_levels) = if aux_observed && n >= 1 {
104        let mut discrete = true;
105        for &v in aux.iter() {
106            if (v - v.round()).abs() > 0.0 {
107                discrete = false;
108                break;
109            }
110        }
111        if discrete {
112            for j in 0..aux_dim {
113                let col = aux.column(j);
114                let mut sorted: Vec<f64> = col.iter().copied().collect();
115                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
116                sorted.dedup_by(|a, b| (*a - *b).abs() < AUX_LEVEL_DEDUP_TOL);
117                if sorted.len() > AUX_DISCRETE_MAX_LEVELS {
118                    discrete = false;
119                    break;
120                }
121            }
122        }
123        if discrete {
124            // Joint distinct rows.
125            let mut keys: Vec<Vec<i64>> = Vec::with_capacity(n);
126            for i in 0..n {
127                let mut row = Vec::with_capacity(aux_dim);
128                for j in 0..aux_dim {
129                    row.push(aux[[i, j]].round() as i64);
130                }
131                keys.push(row);
132            }
133            keys.sort();
134            keys.dedup();
135            (true, keys.len())
136        } else {
137            (false, 0)
138        }
139    } else {
140        (false, 0)
141    };
142
143    // 4. Empirical Jacobian rank.
144    let need_rows = aux_dim.max(latent_dim) + 1;
145    let mut jacobian_rank_estimated = false;
146    let mut jacobian_rank: usize = usize::MAX;
147    let z_finite = latents.iter().all(|v| v.is_finite());
148    if aux_observed && z_finite && n >= need_rows && aux_dim >= 1 && latent_dim >= 1 {
149        // Centre A and Z.
150        let mut a_c = aux.to_owned();
151        let mut z_c = latents.to_owned();
152        let a_mean = a_c.mean_axis(Axis(0)).unwrap();
153        let z_mean = z_c.mean_axis(Axis(0)).unwrap();
154        for mut row in a_c.rows_mut() {
155            row -= &a_mean;
156        }
157        for mut row in z_c.rows_mut() {
158            row -= &z_mean;
159        }
160        // Solve B = (Aᵀ A)^{+} Aᵀ Z via SVD on (Aᵀ A) — small (aux_dim x aux_dim).
161        let ata = a_c.t().dot(&a_c);
162        let atz = a_c.t().dot(&z_c);
163        let b_hat = pinv_solve(ata.view(), atz.view());
164        jacobian_rank = matrix_rank(b_hat.view(), 1.0e-8);
165        jacobian_rank_estimated = true;
166    }
167
168    AuxRichnessMetrics {
169        aux_observed,
170        n_nonfinite_aux,
171        aux_dim,
172        latent_dim,
173        n_rows: n,
174        constant_columns,
175        aux_is_discrete,
176        n_distinct_levels,
177        jacobian_rank,
178        jacobian_rank_estimated,
179    }
180}
181
182/// Moore-Penrose pseudo-inverse times rhs via SVD. Stable for the small
183/// `(aux_dim x aux_dim)` normal-equation matrices encountered here. Tolerance
184/// is `1e-12 * max_singular_value`.
185fn pinv_solve(a: ArrayView2<f64>, b: ArrayView2<f64>) -> Array2<f64> {
186    let (m, n) = a.dim();
187    assert_eq!(m, n, "pinv_solve expects a square normal-equation matrix");
188    // Symmetric eigen-decomposition via Jacobi (matrices are small, < 64x64
189    // in any realistic identifiability check — Jacobi is robust and avoids
190    // pulling in a heavier dependency for this code path).
191    let (eigvals, eigvecs) = jacobi_symmetric_eigen(a);
192    let max_abs = eigvals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
193    let tol = 1.0e-12 * max_abs.max(1.0);
194    // Build A^+ = V diag(1/λ_i if |λ_i|>tol else 0) Vᵀ.
195    let k = eigvals.len();
196    let mut inv_diag = vec![0.0_f64; k];
197    for i in 0..k {
198        if eigvals[i].abs() > tol {
199            inv_diag[i] = 1.0 / eigvals[i];
200        }
201    }
202    // A^+ b  =  V D Vᵀ b  where D = diag(inv_diag).
203    let vtb = eigvecs.t().dot(&b);
204    let mut dvtb = vtb.clone();
205    for i in 0..k {
206        let scale = inv_diag[i];
207        for j in 0..dvtb.ncols() {
208            dvtb[[i, j]] *= scale;
209        }
210    }
211    eigvecs.dot(&dvtb)
212}
213
214/// Jacobi rotation eigen-decomposition for small symmetric matrices.
215/// Returns `(eigenvalues, eigenvectors)` with `A = V diag(λ) Vᵀ`.
216fn jacobi_symmetric_eigen(a: ArrayView2<f64>) -> (Vec<f64>, Array2<f64>) {
217    let n = a.nrows();
218    assert_eq!(n, a.ncols());
219    let mut m = a.to_owned();
220    let mut v = Array2::<f64>::eye(n);
221    for _ in 0..JACOBI_MAX_SWEEPS {
222        // Find largest off-diagonal.
223        let mut p = 0usize;
224        let mut q = 1usize;
225        let mut max_off = 0.0_f64;
226        for i in 0..n {
227            for j in (i + 1)..n {
228                let av = m[[i, j]].abs();
229                if av > max_off {
230                    max_off = av;
231                    p = i;
232                    q = j;
233                }
234            }
235        }
236        if max_off < JACOBI_OFFDIAG_TOL {
237            break;
238        }
239        let app = m[[p, p]];
240        let aqq = m[[q, q]];
241        let apq = m[[p, q]];
242        let theta = 0.5 * (aqq - app) / apq;
243        let t = if theta >= 0.0 {
244            1.0 / (theta + (1.0 + theta * theta).sqrt())
245        } else {
246            1.0 / (theta - (1.0 + theta * theta).sqrt())
247        };
248        let c = 1.0 / (1.0 + t * t).sqrt();
249        let s = t * c;
250        // Update M.
251        let new_pp = app - t * apq;
252        let new_qq = aqq + t * apq;
253        m[[p, p]] = new_pp;
254        m[[q, q]] = new_qq;
255        m[[p, q]] = 0.0;
256        m[[q, p]] = 0.0;
257        for i in 0..n {
258            if i != p && i != q {
259                let aip = m[[i, p]];
260                let aiq = m[[i, q]];
261                m[[i, p]] = c * aip - s * aiq;
262                m[[p, i]] = m[[i, p]];
263                m[[i, q]] = s * aip + c * aiq;
264                m[[q, i]] = m[[i, q]];
265            }
266        }
267        // Update V.
268        for i in 0..n {
269            let vip = v[[i, p]];
270            let viq = v[[i, q]];
271            v[[i, p]] = c * vip - s * viq;
272            v[[i, q]] = s * vip + c * viq;
273        }
274    }
275    let eigvals: Vec<f64> = (0..n).map(|i| m[[i, i]]).collect();
276    (eigvals, v)
277}
278
279/// Numeric rank of `m` via its singular values (computed as
280/// `sqrt(eig(MᵀM))`). `tol` is absolute; entries with singular value
281/// `<= tol` are considered zero.
282fn matrix_rank(m: ArrayView2<f64>, tol: f64) -> usize {
283    let gram = m.t().dot(&m);
284    let (eigvals, _) = jacobi_symmetric_eigen(gram.view());
285    let mut rank = 0usize;
286    for &lam in eigvals.iter() {
287        if lam.max(0.0).sqrt() > tol {
288            rank += 1;
289        }
290    }
291    rank
292}
293
/// Sparsity and rank summary of a stack of decoder Jacobians, as produced by
/// [`jacobian_sparsity_metrics`].
#[derive(Debug, Clone)]
pub struct JacobianSparsityMetrics {
    /// Number of stacked Jacobians: the leading axis of the
    /// `(N_samples, P, latent_dim)` stack.
    pub n_samples: usize,
    /// Rows per Jacobian (`P`, one per output feature).
    pub p_features: usize,
    /// Columns per Jacobian (the latent dimension).
    pub latent_dim: usize,
    /// Share of entries with `|J| < zero_threshold * max|J|`, pooled over the
    /// whole stack. Every sample has the same size, so this equals the mean of
    /// the per-sample fractions.
    pub mean_sparsity: f64,
    /// Largest absolute entry anywhere in the stack.
    pub max_abs: f64,
    /// Numeric column rank of each Jacobian, one entry per sample, each in
    /// `[0, latent_dim]`.
    pub ranks: Vec<usize>,
}
309
310/// Compute mean sparsity and per-sample rank of a stack of Jacobians.
311///
312/// `jacobians` is `(N_samples, P, latent_dim)`, flattened to a `(N*P, latent_dim)`
313/// row-major view. `n_samples` is the leading axis size.
314pub fn jacobian_sparsity_metrics(
315    jacobians_flat: ArrayView2<f64>,
316    n_samples: usize,
317    zero_threshold: f64,
318) -> JacobianSparsityMetrics {
319    let (np_rows, latent_dim) = jacobians_flat.dim();
320    assert!(np_rows % n_samples == 0, "rows not divisible by n_samples");
321    let p_features = np_rows / n_samples;
322
323    // Max abs.
324    let mut max_abs = 0.0_f64;
325    for &v in jacobians_flat.iter() {
326        let a = v.abs();
327        if a > max_abs {
328            max_abs = a;
329        }
330    }
331    let cutoff = zero_threshold * max_abs;
332
333    let mut total_near_zero: usize = 0;
334    let total_entries = np_rows * latent_dim;
335    if max_abs > 0.0 {
336        for &v in jacobians_flat.iter() {
337            if v.abs() < cutoff {
338                total_near_zero += 1;
339            }
340        }
341    } else {
342        // All zero Jacobian: maximally "sparse" but degenerate; caller flags this.
343        total_near_zero = total_entries;
344    }
345    let mean_sparsity = if total_entries > 0 {
346        total_near_zero as f64 / total_entries as f64
347    } else {
348        0.0
349    };
350
351    // Per-sample rank.
352    let mut ranks = Vec::with_capacity(n_samples);
353    for s in 0..n_samples {
354        let start = s * p_features;
355        let end = start + p_features;
356        let view = jacobians_flat.slice(ndarray::s![start..end, ..]);
357        // Use `cutoff` (absolute) as the rank tolerance: an entry below it is
358        // considered zero, which matches the sparsity decision.
359        ranks.push(matrix_rank(view, cutoff.max(1.0e-300)));
360    }
361
362    JacobianSparsityMetrics {
363        n_samples,
364        p_features,
365        latent_dim,
366        mean_sparsity,
367        max_abs,
368        ranks,
369    }
370}
371
/// Anchor statistics of an `(N, K)` assignment matrix, as produced by
/// [`anchor_consistency_metrics`].
#[derive(Debug, Clone)]
pub struct AnchorConsistencyMetrics {
    /// Row count `N` of the assignment matrix.
    pub n_rows: usize,
    /// Column count `K`, i.e. the number of atoms.
    pub n_atoms: usize,
    /// How many rows are anchors, meaning rows whose largest `|a|` carries at
    /// least `anchor_dominance` of the row's L1 mass.
    pub n_anchors: usize,
    /// Length-`K` tally: for every anchor row, its dominant atom's count is
    /// incremented, so the entries sum to `n_anchors`.
    pub anchors_per_atom: Vec<usize>,
}
386
387/// Compute anchor counts from an assignment matrix.
388///
389/// `assignments` is `(N, K)`. A row is an anchor when its maximum-magnitude
390/// entry contributes at least `anchor_dominance ∈ (0, 1]` of the row's L1
391/// mass. Zero-mass rows are *not* anchors.
392pub fn anchor_consistency_metrics(
393    assignments: ArrayView2<f64>,
394    anchor_dominance: f64,
395) -> AnchorConsistencyMetrics {
396    let (n, k) = assignments.dim();
397    let mut anchors_per_atom = vec![0_usize; k];
398    let mut n_anchors = 0_usize;
399    for i in 0..n {
400        let row = assignments.row(i);
401        let mut mass = 0.0_f64;
402        let mut max_val = 0.0_f64;
403        let mut max_j = 0_usize;
404        for j in 0..k {
405            let a = row[j].abs();
406            mass += a;
407            if a > max_val {
408                max_val = a;
409                max_j = j;
410            }
411        }
412        if mass > 0.0 && max_val / mass >= anchor_dominance {
413            n_anchors += 1;
414            anchors_per_atom[max_j] += 1;
415        }
416    }
417    AnchorConsistencyMetrics {
418        n_rows: n,
419        n_atoms: k,
420        n_anchors,
421        anchors_per_atom,
422    }
423}
424
425/// Stack a list of per-atom decoder blocks (each shape `(basis_size_k, P)`)
426/// column-wise into a single Jacobian of shape `(P, sum_k basis_size_k)`.
427/// Used by the Python diagnostics dispatcher to feed
428/// [`jacobian_sparsity_metrics`] from a `ManifoldSAE.decoder_blocks` payload
429/// without doing the concatenation in Python.
430pub fn concat_decoder_blocks(blocks: &[ArrayView2<f64>]) -> Result<Array2<f64>, String> {
431    if blocks.is_empty() {
432        return Err("concat_decoder_blocks: empty block list".into());
433    }
434    let p = blocks[0].ncols();
435    for (i, b) in blocks.iter().enumerate() {
436        if b.ncols() != p {
437            return Err(format!(
438                "concat_decoder_blocks: block {} has {} cols, expected {}",
439                i,
440                b.ncols(),
441                p
442            ));
443        }
444    }
445    let total_k: usize = blocks.iter().map(|b| b.nrows()).sum();
446    let mut out = Array2::<f64>::zeros((p, total_k));
447    let mut col = 0_usize;
448    for b in blocks {
449        // Block has shape (basis_size, P); transpose into columns of out.
450        for k in 0..b.nrows() {
451            for row in 0..p {
452                out[[row, col]] = b[[k, row]];
453            }
454            col += 1;
455        }
456    }
457    Ok(out)
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn aux_richness_passes_on_rich_2d_aux() {
        let aux = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [2.0, 0.0],
            [2.0, 1.0],
            [0.0, 2.0],
            [1.0, 2.0],
            [2.0, 2.0],
        ];
        let lat = array![
            [0.10, 0.05],
            [0.02, 1.01],
            [1.05, 0.04],
            [1.01, 1.02],
            [2.03, 0.07],
            [2.04, 1.01],
            [0.05, 2.02],
            [1.02, 2.01],
            [2.01, 2.05],
        ];
        let m = aux_richness_metrics(aux.view(), lat.view());
        assert!(m.aux_observed);
        assert_eq!(m.aux_dim, 2);
        assert_eq!(m.latent_dim, 2);
        assert!(m.constant_columns.is_empty());
        assert!(m.aux_is_discrete);
        assert!(m.n_distinct_levels >= 3);
        assert!(m.jacobian_rank_estimated);
        assert_eq!(m.jacobian_rank, 2);
    }

    #[test]
    fn aux_richness_flags_constant_aux() {
        let aux = Array2::<f64>::zeros((20, 1));
        let mut lat = Array2::<f64>::zeros((20, 2));
        for i in 0..20 {
            lat[[i, 0]] = i as f64;
            lat[[i, 1]] = (i as f64).cos();
        }
        let m = aux_richness_metrics(aux.view(), lat.view());
        assert_eq!(m.aux_dim, 1);
        assert_eq!(m.latent_dim, 2);
        assert_eq!(m.constant_columns, vec![0_usize]);
    }

    #[test]
    fn aux_richness_flags_nonfinite_aux() {
        let mut aux = Array2::<f64>::zeros((10, 1));
        aux[[3, 0]] = f64::NAN;
        let lat = Array2::<f64>::zeros((10, 1));
        let m = aux_richness_metrics(aux.view(), lat.view());
        assert!(!m.aux_observed);
        assert_eq!(m.n_nonfinite_aux, 1);
    }

    #[test]
    fn aux_richness_counts_joint_levels_and_merges_signed_zero() {
        // Rows [0, 1] and [-0, 1] are the same level; 3 distinct rows overall.
        let aux = array![[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [-0.0, 1.0]];
        let lat = array![[0.1], [0.2], [1.1], [1.3], [0.15]];
        let m = aux_richness_metrics(aux.view(), lat.view());
        assert!(m.aux_is_discrete);
        assert_eq!(m.n_distinct_levels, 3);
        assert!(m.jacobian_rank_estimated);
    }

    #[test]
    fn jacobian_sparsity_passes_on_diagonal() {
        // P=4, K=3, n_samples=1; mostly zero.
        let j = array![
            [1.0_f64, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0]
        ];
        let m = jacobian_sparsity_metrics(j.view(), 1, 1.0e-3);
        assert_eq!(m.p_features, 4);
        assert_eq!(m.latent_dim, 3);
        assert!(m.mean_sparsity > 0.5);
        assert_eq!(m.ranks, vec![3_usize]);
    }

    #[test]
    fn jacobian_sparsity_dense_has_low_sparsity() {
        let mut j = Array2::<f64>::zeros((4, 3));
        for i in 0..4 {
            for k in 0..3 {
                j[[i, k]] = 1.0 + 0.1 * (i + k) as f64;
            }
        }
        let m = jacobian_sparsity_metrics(j.view(), 1, 1.0e-3);
        assert!(m.mean_sparsity < 0.1);
    }

    #[test]
    fn jacobian_sparsity_ranks_each_sample_separately() {
        // Two samples, P=2, K=2: identity (rank 2), then all-ones (rank 1).
        let j = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]];
        let m = jacobian_sparsity_metrics(j.view(), 2, 1.0e-3);
        assert_eq!(m.n_samples, 2);
        assert_eq!(m.p_features, 2);
        assert_eq!(m.ranks, vec![2_usize, 1]);
        assert!((m.mean_sparsity - 0.25).abs() < 1.0e-12);
    }

    #[test]
    fn jacobi_eigen_reconstructs_small_symmetric_matrix() {
        let a = array![[4.0_f64, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0]];
        let (vals, vecs) = jacobi_symmetric_eigen(a.view());
        assert_eq!(vals.len(), 3);
        let lambda = Array2::from_diag(&ndarray::Array1::from(vals.clone()));
        let recon = vecs.dot(&lambda).dot(&vecs.t());
        for (x, y) in recon.iter().zip(a.iter()) {
            assert!((x - y).abs() < 1.0e-10);
        }
        let gram = vecs.t().dot(&vecs);
        for i in 0..3 {
            for k in 0..3 {
                let expected = if i == k { 1.0 } else { 0.0 };
                assert!((gram[[i, k]] - expected).abs() < 1.0e-10);
            }
        }
    }

    #[test]
    fn pinv_solve_zeroes_null_directions() {
        let a = array![[2.0_f64, 0.0], [0.0, 0.0]];
        let b = array![[2.0_f64], [5.0]];
        let x = pinv_solve(a.view(), b.view());
        assert!((x[[0, 0]] - 1.0).abs() < 1.0e-12);
        assert_eq!(x[[1, 0]], 0.0);
    }

    #[test]
    fn anchor_consistency_three_clusters() {
        let mut a = Array2::<f64>::from_elem((9, 3), 0.01);
        for i in 0..3 {
            a[[i, 0]] = 1.0;
        }
        for i in 3..6 {
            a[[i, 1]] = 1.0;
        }
        for i in 6..9 {
            a[[i, 2]] = 1.0;
        }
        let m = anchor_consistency_metrics(a.view(), 0.95);
        assert_eq!(m.n_atoms, 3);
        assert_eq!(m.n_anchors, 9);
        assert_eq!(m.anchors_per_atom, vec![3, 3, 3]);
    }

    #[test]
    fn anchor_consistency_uniform_has_zero_anchors() {
        let a = Array2::<f64>::from_elem((10, 4), 0.25);
        let m = anchor_consistency_metrics(a.view(), 0.95);
        assert_eq!(m.n_anchors, 0);
        assert_eq!(m.anchors_per_atom, vec![0, 0, 0, 0]);
    }

    #[test]
    fn anchor_consistency_zero_rows_are_not_anchors() {
        let a = Array2::<f64>::zeros((3, 2));
        let m = anchor_consistency_metrics(a.view(), 0.5);
        assert_eq!(m.n_rows, 3);
        assert_eq!(m.n_anchors, 0);
        assert_eq!(m.anchors_per_atom, vec![0, 0]);
    }

    #[test]
    fn concat_decoder_blocks_stacks_transposed_blocks() {
        let b0 = array![[1.0_f64, 2.0, 3.0]];
        let b1 = array![[4.0_f64, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let out = concat_decoder_blocks(&[b0.view(), b1.view()]).unwrap();
        assert_eq!(
            out,
            array![[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]]
        );
    }

    #[test]
    fn concat_decoder_blocks_rejects_bad_input() {
        assert!(concat_decoder_blocks(&[]).is_err());
        let b0 = array![[1.0_f64, 2.0, 3.0]];
        let b1 = array![[4.0_f64, 5.0]];
        let err = concat_decoder_blocks(&[b0.view(), b1.view()]).unwrap_err();
        assert!(err.contains("block 1 has 2 cols, expected 3"));
    }
}