Skip to main content

gam_identifiability/
precondition.rs

1//! Runnable identifiability-theorem precondition checks.
2//!
3//! Python, the CLI, and other bindings consume the same `Vec<TheoremResult>`
4//! from this module. Thresholds are explicit call inputs via [`Thresholds`].
5
6use std::collections::BTreeMap;
7
8use ndarray::{Array2, ArrayView2, Axis};
9use serde::{Deserialize, Serialize};
10
11use gam_linalg::faer_ndarray::FaerSvd;
12
/// Standard-deviation floor for auxiliary columns: a column at or below this
/// value is reported as "constant" (Khemakhem 2107.10098 Thm. 1 — a constant
/// column carries zero conditioning information). Despite the `VAR` in the
/// name, the iVAE check compares this against a per-column std.
pub const DEFAULT_IVAE_AUX_VAR_FLOOR: f64 = 1e-9;

/// Relative tolerance handed to the SVD-based rank routine when checking the
/// aux column rank (Khemakhem 2107.10098 §3 parametric-richness assumption).
pub const DEFAULT_IVAE_AUX_RANK_RTOL: f64 = 1e-8;

/// Minimum number of affine encoder layers. Khemakhem 2107.10098 §3 needs a
/// "non-trivially nonlinear" encoder — a bare linear map (1 affine layer)
/// does not satisfy the universal-approximation argument that pushes
/// identifiability through the encoder.
pub const DEFAULT_IVAE_MIN_ENCODER_LAYERS: i64 = 2;

/// Lachapelle 2401.04890 §2.4: at L1 equilibrium at least 50% of the decoder
/// Jacobian entries on the free block are near zero.
pub const DEFAULT_MECH_SPARSITY_FRACTION: f64 = 0.5;

/// Relative cut-off under which a decoder entry counts as "near-zero" —
/// mirrors the paper's column-relative thresholding.
pub const DEFAULT_MECH_SPARSITY_ZERO_TOL: f64 = 1e-3;

/// Khemakhem App. A.3: encoder activation variance must be bounded. Any
/// activation variance above this ceiling is treated as a hard fail.
pub const DEFAULT_RANDPROJ_VAR_CEILING: f64 = 1e6;

/// Activation variances above this level (but not above the ceiling)
/// downgrade the random-projection check to a warn — the encoder is large
/// but not yet unbounded.
pub const DEFAULT_RANDPROJ_VAR_WARN: f64 = 1e3;
41
/// Tunable thresholds — every field has a paper-backed default and can be
/// overridden per call (constructor kwargs in Python, struct literal here).
///
/// The [`Default`] impl fills each field from the matching `DEFAULT_*`
/// constant in this module.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub struct Thresholds {
    /// iVAE: an aux column whose standard deviation is `<=` this value is
    /// reported as constant, which fails the check.
    pub ivae_aux_var_floor: f64,
    /// iVAE: relative tolerance for the aux rank computation; singular
    /// values must exceed `rtol * max_singular_value` to count.
    pub ivae_aux_rank_rtol: f64,
    /// iVAE: encoder depths of 2 or more that are still below this value
    /// downgrade the check to a warn. Depths of 0 or 1 fail regardless of
    /// this setting.
    pub ivae_min_encoder_layers: i64,
    /// Mechanism sparsity: minimum fraction of near-zero entries on the
    /// decoder's free block; a lower fraction downgrades to a warn.
    pub mech_sparsity_fraction: f64,
    /// Mechanism sparsity: an entry is near-zero when its magnitude divided
    /// by its column's largest magnitude is `<=` this value.
    pub mech_sparsity_zero_tol: f64,
    /// Random projection: a maximum activation variance above this value
    /// (but not above the ceiling) yields a warn.
    pub randproj_var_warn: f64,
    /// Random projection: a maximum activation variance above this value
    /// fails the check.
    pub randproj_var_ceiling: f64,
}
54
55impl Default for Thresholds {
56    fn default() -> Self {
57        Self {
58            ivae_aux_var_floor: DEFAULT_IVAE_AUX_VAR_FLOOR,
59            ivae_aux_rank_rtol: DEFAULT_IVAE_AUX_RANK_RTOL,
60            ivae_min_encoder_layers: DEFAULT_IVAE_MIN_ENCODER_LAYERS,
61            mech_sparsity_fraction: DEFAULT_MECH_SPARSITY_FRACTION,
62            mech_sparsity_zero_tol: DEFAULT_MECH_SPARSITY_ZERO_TOL,
63            randproj_var_warn: DEFAULT_RANDPROJ_VAR_WARN,
64            randproj_var_ceiling: DEFAULT_RANDPROJ_VAR_CEILING,
65        }
66    }
67}
68
/// Outcome of a single per-theorem precondition check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TheoremResult {
    /// Identifier of the check that produced this result. The checks in this
    /// module use `"iVAE"`, `"MechanismSparsity"` and `"RandomProjection"`.
    pub theorem_name: String,
    /// Overall verdict for the theorem's preconditions.
    pub status: TheoremStatus,
    /// Human-readable explanation of the verdict: a fixed sentence on pass,
    /// otherwise the reason(s) the check warned, failed or was skipped.
    pub reason: String,
    /// Numerical evidence gathered by the check, keyed by metric name. May
    /// be empty when the check exits before computing anything.
    pub metric: BTreeMap<String, f64>,
}
77
/// Verdict of a precondition check, ordered by severity:
/// `Pass` < `Warn` < `Fail`.
///
/// The `rename_all = "lowercase"` attribute asks serde to use the lowercase
/// variant names (`"pass"`, `"warn"`, `"fail"`) on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TheoremStatus {
    /// Every precondition that was checked holds.
    Pass,
    /// The check was skipped, or a precondition could not be confirmed or is
    /// only weakly satisfied.
    Warn,
    /// At least one precondition is violated or the input was malformed.
    Fail,
}
85
86impl TheoremStatus {
87    fn rank(&self) -> u8 {
88        match self {
89            TheoremStatus::Pass => 0,
90            TheoremStatus::Warn => 1,
91            TheoremStatus::Fail => 2,
92        }
93    }
94    fn worse(self, other: TheoremStatus) -> TheoremStatus {
95        if other.rank() > self.rank() {
96            other
97        } else {
98            self
99        }
100    }
101}
102
/// Caller-supplied summary of the fit. All numerical evidence is in here —
/// the Rust check needs no Python objects.
///
/// Every field is optional; the per-theorem checks skip or downgrade when
/// the evidence they need is absent, as described on each field.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct FitSummary {
    /// Auxiliary covariates of shape `(n_obs, n_supervised)`. Row-major. If
    /// `None`, the iVAE check downgrades to a warn-with-skip.
    pub aux: Option<Vec<Vec<f64>>>,
    /// Declared supervised latent dim. If `None`, the iVAE check is skipped
    /// with a warn and the mechanism-sparsity check treats it as 0.
    pub n_supervised: Option<i64>,
    /// Declared free latent dim. If `None`, the mechanism-sparsity check is
    /// skipped with a warn.
    pub n_free: Option<i64>,
    /// Decoder of shape `(n_features, n_supervised + n_free)`. Row-major.
    /// The mechanism-sparsity check reads columns
    /// `n_supervised..n_supervised + n_free` as the free block. If `None`,
    /// that check is skipped with a warn.
    pub decoder: Option<Vec<Vec<f64>>>,
    /// Number of affine (Linear) layers in the encoder. If `None`, the iVAE
    /// check warns that the depth requirement cannot be verified.
    pub encoder_depth: Option<i64>,
    /// Sparsity penalty weight used at fit time. If `None`, the
    /// mechanism-sparsity check warns; a value that is not strictly positive
    /// fails it.
    pub mech_sparsity_weight: Option<f64>,
    /// Latent samples / encoder activations of shape `(n_obs, latent_dim)`.
    /// Row-major. If `None`, the random-projection check is skipped with a
    /// warn.
    pub activations: Option<Vec<Vec<f64>>>,
    /// Ground-truth latent dim (e.g. from a simulator). Optional. When
    /// present, the mechanism-sparsity check fails if
    /// `n_supervised + n_free` is smaller than this.
    pub ground_truth_dim: Option<i64>,
    /// Threshold overrides (defaults to paper-cited values when missing).
    #[serde(default)]
    pub thresholds: Option<Thresholds>,
}
128
129fn rows_to_array(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
130    if rows.is_empty() {
131        return Ok(Array2::<f64>::zeros((0, 0)));
132    }
133    let ncols = rows[0].len();
134    for (i, row) in rows.iter().enumerate() {
135        if row.len() != ncols {
136            return Err(format!(
137                "ragged matrix: row 0 has {ncols} cols but row {i} has {} cols",
138                row.len()
139            ));
140        }
141    }
142    let nrows = rows.len();
143    let mut flat = Vec::with_capacity(nrows * ncols);
144    for row in rows {
145        flat.extend_from_slice(row);
146    }
147    Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| e.to_string())
148}
149
150fn column_std(mat: ArrayView2<f64>) -> Vec<f64> {
151    let n = mat.nrows() as f64;
152    if n <= 0.0 {
153        return vec![0.0; mat.ncols()];
154    }
155    let mut out = Vec::with_capacity(mat.ncols());
156    for col in mat.axis_iter(Axis(1)) {
157        let mean = col.sum() / n;
158        let mut var = 0.0_f64;
159        for v in col.iter() {
160            let d = v - mean;
161            var += d * d;
162        }
163        out.push((var / n).sqrt());
164    }
165    out
166}
167
168fn column_var(mat: ArrayView2<f64>) -> Vec<f64> {
169    column_std(mat).into_iter().map(|s| s * s).collect()
170}
171
172/// Tolerance-based rank via faer SVD: count singular values larger than
173/// `rtol * max_singular_value`.
174fn matrix_rank(mat: ArrayView2<f64>, rtol: f64) -> Result<usize, String> {
175    if mat.nrows() == 0 || mat.ncols() == 0 {
176        return Ok(0);
177    }
178    let owned = mat.to_owned();
179    let (_u, sigma, _vt) = owned.svd(false, false).map_err(|e| format!("{e:?}"))?;
180    if sigma.is_empty() {
181        return Ok(0);
182    }
183    let smax = sigma.iter().cloned().fold(0.0_f64, f64::max);
184    if smax <= 0.0 {
185        return Ok(0);
186    }
187    let cutoff = smax * rtol;
188    Ok(sigma.iter().filter(|s| **s > cutoff).count())
189}
190
191/// Khemakhem 2107.10098 Theorem 1 preconditions.
192pub fn check_ivae(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
193    let mut metric: BTreeMap<String, f64> = BTreeMap::new();
194    let mut issues: Vec<String> = Vec::new();
195    let mut status = TheoremStatus::Pass;
196
197    let aux_rows = match summary.aux.as_ref() {
198        Some(a) => a,
199        None => {
200            return TheoremResult {
201                theorem_name: "iVAE".to_string(),
202                status: TheoremStatus::Warn,
203                reason: "iVAE check skipped: no aux provided in fit summary.".to_string(),
204                metric,
205            };
206        }
207    };
208    let n_supervised = match summary.n_supervised {
209        Some(v) => v,
210        None => {
211            return TheoremResult {
212                theorem_name: "iVAE".to_string(),
213                status: TheoremStatus::Warn,
214                reason: "iVAE check skipped: n_supervised missing.".to_string(),
215                metric,
216            };
217        }
218    };
219
220    let aux = match rows_to_array(aux_rows) {
221        Ok(a) => a,
222        Err(e) => {
223            return TheoremResult {
224                theorem_name: "iVAE".to_string(),
225                status: TheoremStatus::Fail,
226                reason: format!("aux is malformed: {e}"),
227                metric,
228            };
229        }
230    };
231
232    let stds = column_std(aux.view());
233    let min_std = stds.iter().cloned().fold(f64::INFINITY, f64::min);
234    metric.insert(
235        "aux_min_std".to_string(),
236        if stds.is_empty() { 0.0 } else { min_std },
237    );
238    if stds.is_empty() || stds.iter().any(|s| *s <= thr.ivae_aux_var_floor) {
239        let zeros: Vec<usize> = stds
240            .iter()
241            .enumerate()
242            .filter(|(_, s)| **s <= thr.ivae_aux_var_floor)
243            .map(|(i, _)| i)
244            .collect();
245        issues.push(format!(
246            "iVAE identifiability requires auxiliary covariate variation; \
247             aux axes {zeros:?} are constant across observations (min std \
248             {min_std:.3e} <= {:.0e}); Khemakhem 2107.10098 Thm. 1 \
249             conditioning rank is zero.",
250            thr.ivae_aux_var_floor,
251        ));
252        status = status.worse(TheoremStatus::Fail);
253    }
254
255    let rank = match matrix_rank(aux.view(), thr.ivae_aux_rank_rtol) {
256        Ok(r) => r,
257        Err(e) => {
258            return TheoremResult {
259                theorem_name: "iVAE".to_string(),
260                status: TheoremStatus::Fail,
261                reason: format!("aux SVD failed: {e}"),
262                metric,
263            };
264        }
265    };
266    metric.insert("aux_column_rank".to_string(), rank as f64);
267    metric.insert("n_supervised".to_string(), n_supervised as f64);
268    if (rank as i64) < n_supervised {
269        issues.push(format!(
270            "aux column rank {rank} < n_supervised={n_supervised}: \
271             Khemakhem 2107.10098 §3 parametric-richness fails."
272        ));
273        status = status.worse(TheoremStatus::Fail);
274    }
275
276    match summary.encoder_depth {
277        None => {
278            issues.push(
279                "encoder depth unknown — cannot verify the >=2-layer \
280                 requirement of Khemakhem 2107.10098 §3."
281                    .to_string(),
282            );
283            status = status.worse(TheoremStatus::Warn);
284        }
285        Some(depth) => {
286            metric.insert("encoder_depth".to_string(), depth as f64);
287            if depth < 1 {
288                issues.push(format!("encoder depth {depth} < 1; no encoder is present."));
289                status = status.worse(TheoremStatus::Fail);
290            } else if depth == 1 {
291                issues.push(
292                    "encoder depth == 1 (bare linear); Khemakhem 2107.10098 \
293                     §3 requires non-linear encoder. Identifiability voided."
294                        .to_string(),
295                );
296                status = status.worse(TheoremStatus::Fail);
297            } else if depth < thr.ivae_min_encoder_layers {
298                issues.push(format!(
299                    "encoder depth {depth} < canonical min={}: \
300                     Khemakhem 2107.10098 §3 universal-approximation \
301                     argument is weakened.",
302                    thr.ivae_min_encoder_layers,
303                ));
304                status = status.worse(TheoremStatus::Warn);
305            }
306        }
307    }
308
309    let reason = if matches!(status, TheoremStatus::Pass) {
310        "all Khemakhem 2107.10098 Thm. 1 preconditions hold".to_string()
311    } else {
312        issues.join(" | ")
313    };
314    TheoremResult {
315        theorem_name: "iVAE".to_string(),
316        status,
317        reason,
318        metric,
319    }
320}
321
322/// Lachapelle 2401.04890 Theorem preconditions.
323pub fn check_mechanism_sparsity(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
324    let mut metric: BTreeMap<String, f64> = BTreeMap::new();
325    let mut issues: Vec<String> = Vec::new();
326    let mut status = TheoremStatus::Pass;
327
328    let decoder_rows = match summary.decoder.as_ref() {
329        Some(d) => d,
330        None => {
331            return TheoremResult {
332                theorem_name: "MechanismSparsity".to_string(),
333                status: TheoremStatus::Warn,
334                reason: "MechanismSparsity skipped: no decoder in fit summary.".to_string(),
335                metric,
336            };
337        }
338    };
339    let n_sup = summary.n_supervised.unwrap_or(0);
340    let n_free = match summary.n_free {
341        Some(v) => v,
342        None => {
343            return TheoremResult {
344                theorem_name: "MechanismSparsity".to_string(),
345                status: TheoremStatus::Warn,
346                reason: "MechanismSparsity skipped: n_free missing.".to_string(),
347                metric,
348            };
349        }
350    };
351
352    let decoder = match rows_to_array(decoder_rows) {
353        Ok(d) => d,
354        Err(e) => {
355            return TheoremResult {
356                theorem_name: "MechanismSparsity".to_string(),
357                status: TheoremStatus::Fail,
358                reason: format!("decoder is malformed: {e}"),
359                metric,
360            };
361        }
362    };
363
364    let total_cols = decoder.ncols() as i64;
365    if n_sup + n_free > total_cols || n_sup < 0 || n_free < 0 {
366        return TheoremResult {
367            theorem_name: "MechanismSparsity".to_string(),
368            status: TheoremStatus::Fail,
369            reason: format!(
370                "decoder has {total_cols} columns but n_supervised + n_free \
371                 = {} + {}.",
372                n_sup, n_free,
373            ),
374            metric,
375        };
376    }
377    let free_cols = decoder.slice(ndarray::s![
378        ..,
379        (n_sup as usize)..((n_sup + n_free) as usize)
380    ]);
381    metric.insert(
382        "free_block_shape_rows".to_string(),
383        free_cols.nrows() as f64,
384    );
385    metric.insert(
386        "free_block_shape_cols".to_string(),
387        free_cols.ncols() as f64,
388    );
389
390    // Column-relative thresholded zero-fraction.
391    let mut col_max = vec![0.0_f64; free_cols.ncols()];
392    for col_idx in 0..free_cols.ncols() {
393        let col = free_cols.column(col_idx);
394        col_max[col_idx] = col.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
395    }
396    let mut zeros: u64 = 0;
397    let mut total: u64 = 0;
398    for col_idx in 0..free_cols.ncols() {
399        let safe_max = if col_max[col_idx] > 0.0 {
400            col_max[col_idx]
401        } else {
402            1.0
403        };
404        for row_idx in 0..free_cols.nrows() {
405            let rel = free_cols[[row_idx, col_idx]].abs() / safe_max;
406            if rel <= thr.mech_sparsity_zero_tol {
407                zeros += 1;
408            }
409            total += 1;
410        }
411    }
412    let zero_fraction = if total == 0 {
413        0.0
414    } else {
415        zeros as f64 / total as f64
416    };
417    metric.insert("decoder_zero_fraction".to_string(), zero_fraction);
418
419    let rank = match matrix_rank(free_cols.view(), 1.0e-8) {
420        Ok(r) => r,
421        Err(e) => {
422            return TheoremResult {
423                theorem_name: "MechanismSparsity".to_string(),
424                status: TheoremStatus::Fail,
425                reason: format!("decoder SVD failed: {e}"),
426                metric,
427            };
428        }
429    };
430    metric.insert("decoder_free_rank".to_string(), rank as f64);
431    if (rank as i64) < n_free {
432        issues.push(format!(
433            "decoder Jacobian on the free block has rank {rank} < \
434             n_free={n_free}; Lachapelle 2401.04890 Thm. requires full \
435             rank on the free latents."
436        ));
437        status = status.worse(TheoremStatus::Fail);
438    }
439
440    match summary.mech_sparsity_weight {
441        None => {
442            issues.push(
443                "mech sparsity weight unknown — cannot confirm L1 prox \
444                 was active."
445                    .to_string(),
446            );
447            status = status.worse(TheoremStatus::Warn);
448        }
449        Some(w) => {
450            metric.insert("mech_sparsity_weight".to_string(), w);
451            if !(w > 0.0) {
452                issues.push(format!(
453                    "mech sparsity weight = {w} is not strictly positive; \
454                     Lachapelle 2401.04890 identification voided."
455                ));
456                status = status.worse(TheoremStatus::Fail);
457            }
458        }
459    }
460
461    if zero_fraction < thr.mech_sparsity_fraction {
462        issues.push(format!(
463            "decoder zero-fraction {zero_fraction:.3} < {:.2} threshold \
464             from Lachapelle 2401.04890 §2.4: L1 prox has not reached \
465             equilibrium, identification weakened.",
466            thr.mech_sparsity_fraction,
467        ));
468        status = status.worse(TheoremStatus::Warn);
469    }
470
471    let state_dim = n_sup + n_free;
472    if let Some(gt) = summary.ground_truth_dim {
473        metric.insert("state_dim".to_string(), state_dim as f64);
474        metric.insert("ground_truth_dim".to_string(), gt as f64);
475        if state_dim < gt {
476            issues.push(format!(
477                "state_dim={state_dim} < ground_truth_dim={gt}: Lachapelle \
478                 2401.04890 requires at least as many latents as the data \
479                 generating process."
480            ));
481            status = status.worse(TheoremStatus::Fail);
482        }
483    }
484
485    let reason = if matches!(status, TheoremStatus::Pass) {
486        "all Lachapelle 2401.04890 preconditions hold".to_string()
487    } else {
488        issues.join(" | ")
489    };
490    TheoremResult {
491        theorem_name: "MechanismSparsity".to_string(),
492        status,
493        reason,
494        metric,
495    }
496}
497
498/// Random-projection identifiability precondition (Khemakhem App. A.3).
499pub fn check_random_projection(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
500    let mut metric: BTreeMap<String, f64> = BTreeMap::new();
501
502    let act_rows = match summary.activations.as_ref() {
503        Some(a) => a,
504        None => {
505            return TheoremResult {
506                theorem_name: "RandomProjection".to_string(),
507                status: TheoremStatus::Warn,
508                reason: "RandomProjection skipped: no activations provided.".to_string(),
509                metric,
510            };
511        }
512    };
513    let act = match rows_to_array(act_rows) {
514        Ok(a) => a,
515        Err(e) => {
516            return TheoremResult {
517                theorem_name: "RandomProjection".to_string(),
518                status: TheoremStatus::Fail,
519                reason: format!("activations malformed: {e}"),
520                metric,
521            };
522        }
523    };
524    if act.nrows() == 0 || act.ncols() == 0 {
525        return TheoremResult {
526            theorem_name: "RandomProjection".to_string(),
527            status: TheoremStatus::Fail,
528            reason: "activations are empty.".to_string(),
529            metric,
530        };
531    }
532    let variances = column_var(act.view());
533    let var_max = variances.iter().cloned().fold(0.0_f64, f64::max);
534    let var_min = variances.iter().cloned().fold(f64::INFINITY, f64::min);
535    metric.insert("activation_var_max".to_string(), var_max);
536    metric.insert("activation_var_min".to_string(), var_min);
537    if variances.iter().any(|v| !v.is_finite()) {
538        return TheoremResult {
539            theorem_name: "RandomProjection".to_string(),
540            status: TheoremStatus::Fail,
541            reason: "activations contain non-finite variance; Khemakhem App. A.3 \
542                 requires bounded variance."
543                .to_string(),
544            metric,
545        };
546    }
547    if var_max > thr.randproj_var_ceiling {
548        return TheoremResult {
549            theorem_name: "RandomProjection".to_string(),
550            status: TheoremStatus::Fail,
551            reason: format!(
552                "max activation variance {var_max:.3e} > ceiling \
553                 {:.3e}; encoder is unbounded.",
554                thr.randproj_var_ceiling,
555            ),
556            metric,
557        };
558    }
559    if var_max > thr.randproj_var_warn {
560        return TheoremResult {
561            theorem_name: "RandomProjection".to_string(),
562            status: TheoremStatus::Warn,
563            reason: format!(
564                "max activation variance {var_max:.3e} > warn-floor \
565                 {:.3e}; encoder is large but not yet unbounded.",
566                thr.randproj_var_warn,
567            ),
568            metric,
569        };
570    }
571    TheoremResult {
572        theorem_name: "RandomProjection".to_string(),
573        status: TheoremStatus::Pass,
574        reason: "encoder activation variance is bounded.".to_string(),
575        metric,
576    }
577}
578
579/// Run every applicable identifiability theorem check.
580pub fn identifiability_check(summary: &FitSummary) -> Vec<TheoremResult> {
581    let thr = summary.thresholds.unwrap_or_default();
582    vec![
583        check_ivae(summary, &thr),
584        check_mechanism_sparsity(summary, &thr),
585        check_random_projection(summary, &thr),
586    ]
587}
588
589/// JSON adaptor: caller serializes a `FitSummary`, gets back a JSON array of
590/// `TheoremResult`. The single FFI surface — Python, the CLI, and any
591/// future binding all consume this.
592pub fn identifiability_check_json(input: &str) -> Result<String, String> {
593    let summary: FitSummary =
594        serde_json::from_str(input).map_err(|e| format!("invalid FitSummary JSON: {e}"))?;
595    let report = identifiability_check(&summary);
596    serde_json::to_string(&report).map_err(|e| format!("serialise: {e}"))
597}
598
// Unit tests for the status ordering, the three theorem checks, the JSON
// adaptor and the private matrix helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture whose iVAE check passes: one varying aux column, one
    // supervised latent, and a 3-layer encoder.
    fn passing_ivae_summary() -> FitSummary {
        FitSummary {
            aux: Some(vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]),
            n_supervised: Some(1),
            n_free: Some(0),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0]]),
            activations: Some(vec![vec![0.1], vec![0.2], vec![0.3], vec![0.4]]),
            ground_truth_dim: None,
            thresholds: None,
        }
    }

    // -----------------------------------------------------------------------
    // TheoremStatus ordering
    // -----------------------------------------------------------------------

    #[test]
    fn theorem_status_worse_is_monotone() {
        use TheoremStatus::{Fail, Pass, Warn};
        assert_eq!(Pass.worse(Pass), Pass);
        assert_eq!(Pass.worse(Warn), Warn);
        assert_eq!(Pass.worse(Fail), Fail);
        assert_eq!(Warn.worse(Pass), Warn);
        assert_eq!(Warn.worse(Fail), Fail);
        assert_eq!(Fail.worse(Pass), Fail);
        assert_eq!(Fail.worse(Warn), Fail);
    }

    // -----------------------------------------------------------------------
    // check_ivae
    // -----------------------------------------------------------------------

    #[test]
    fn constant_aux_fails_ivae() {
        // A single aux column holding the same value in all 32 rows.
        let summary = FitSummary {
            aux: Some(vec![vec![1.0]; 32]),
            n_supervised: Some(1),
            n_free: Some(2),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0, 0.5, 0.0, 0.0, 0.0]; 12]),
            activations: Some(vec![vec![0.0; 3]; 32]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let report = identifiability_check(&summary);
        let ivae = report.iter().find(|t| t.theorem_name == "iVAE").unwrap();
        assert_eq!(ivae.status, TheoremStatus::Fail);
        assert!(ivae.reason.to_lowercase().contains("constant"));
        assert_eq!(
            ivae.metric.get("aux_min_std").copied().unwrap_or(f64::NAN),
            0.0
        );
    }

    #[test]
    fn linear_encoder_depth_one_fails_ivae() {
        let mut summary = passing_ivae_summary();
        summary.encoder_depth = Some(1);
        let thr = Thresholds::default();
        let result = check_ivae(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("linear"), "reason: {}", result.reason);
    }

    #[test]
    fn missing_aux_warns_ivae() {
        let mut summary = passing_ivae_summary();
        summary.aux = None;
        let thr = Thresholds::default();
        let result = check_ivae(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    #[test]
    fn varying_aux_with_deep_encoder_passes_ivae() {
        let thr = Thresholds::default();
        let result = check_ivae(&passing_ivae_summary(), &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    // -----------------------------------------------------------------------
    // check_mechanism_sparsity
    // -----------------------------------------------------------------------

    #[test]
    fn missing_decoder_warns_mechanism_sparsity() {
        let mut summary = passing_ivae_summary();
        summary.decoder = None;
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    #[test]
    fn mechanism_sparsity_passes_with_sparse_decoder() {
        // Decoder: 4 rows x 1 col free block; 3 out of 4 entries are zero
        // (zero fraction = 0.75 > 0.5 threshold). Full rank (1).
        let summary = FitSummary {
            n_supervised: Some(0),
            n_free: Some(1),
            decoder: Some(vec![
                vec![1.0],
                vec![0.0],
                vec![0.0],
                vec![0.0],
            ]),
            mech_sparsity_weight: Some(1.0),
            aux: Some(vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]),
            encoder_depth: Some(3),
            activations: Some(vec![vec![0.1], vec![0.2], vec![0.3], vec![0.4]]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    #[test]
    fn zero_mech_sparsity_weight_fails() {
        // Same sparse decoder as above, but the penalty weight is zero.
        let summary = FitSummary {
            n_supervised: Some(0),
            n_free: Some(1),
            decoder: Some(vec![vec![1.0], vec![0.0], vec![0.0], vec![0.0]]),
            mech_sparsity_weight: Some(0.0),
            aux: None,
            encoder_depth: Some(3),
            activations: None,
            ground_truth_dim: None,
            thresholds: None,
        };
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("not strictly positive"), "reason: {}", result.reason);
    }

    // -----------------------------------------------------------------------
    // check_random_projection
    // -----------------------------------------------------------------------

    #[test]
    fn low_variance_activations_pass_random_projection() {
        let summary = FitSummary {
            activations: Some(vec![
                vec![0.1, 0.2],
                vec![0.15, 0.25],
                vec![0.12, 0.22],
            ]),
            ..FitSummary::default()
        };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    #[test]
    fn very_high_variance_activations_fail_random_projection() {
        // Column [0, 1e6]: mean = 5e5, population variance = (5e5)^2 = 2.5e11,
        // far above DEFAULT_RANDPROJ_VAR_CEILING = 1e6.
        let summary = FitSummary {
            activations: Some(vec![
                vec![0.0],
                vec![1_000_000.0],
            ]),
            ..FitSummary::default()
        };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("unbounded"), "reason: {}", result.reason);
    }

    #[test]
    fn missing_activations_warn_random_projection() {
        let summary = FitSummary { activations: None, ..FitSummary::default() };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    // ── identifiability_check_json ─────────────────────────────────────────────

    #[test]
    fn json_roundtrip() {
        // Serialize a summary, run the JSON entry point, and parse the
        // report back; only the number of results is asserted.
        let summary = FitSummary {
            aux: Some(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]),
            n_supervised: Some(2),
            n_free: Some(1),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]]),
            activations: Some(vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let json = serde_json::to_string(&summary).unwrap();
        let out = identifiability_check_json(&json).unwrap();
        let parsed: Vec<TheoremResult> = serde_json::from_str(&out).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    // ── rows_to_array ──────────────────────────────────────────────────────────

    #[test]
    fn rows_to_array_empty_returns_0x0() {
        let a = rows_to_array(&[]).unwrap();
        assert_eq!(a.dim(), (0, 0));
    }

    #[test]
    fn rows_to_array_rectangular_shape_and_values() {
        let rows = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let a = rows_to_array(&rows).unwrap();
        assert_eq!(a.dim(), (2, 3));
        assert_eq!(a[[0, 0]], 1.0);
        assert_eq!(a[[0, 2]], 3.0);
        assert_eq!(a[[1, 1]], 5.0);
    }

    #[test]
    fn rows_to_array_ragged_returns_err() {
        let rows = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(rows_to_array(&rows).is_err());
    }

    #[test]
    fn rows_to_array_ragged_error_mentions_row_indices() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0]];
        let err = rows_to_array(&rows).unwrap_err();
        assert!(err.contains('2'), "error should mention row 2, got: {err}");
    }

    // ── column_std / column_var ────────────────────────────────────────────────

    #[test]
    fn column_std_constant_column_is_zero() {
        use ndarray::array;
        let m = array![[3.0_f64], [3.0], [3.0]];
        let std = column_std(m.view());
        assert_eq!(std.len(), 1);
        assert!(std[0].abs() < 1e-14, "constant column std should be 0, got {}", std[0]);
    }

    #[test]
    fn column_std_known_value() {
        use ndarray::array;
        // Column [0, 2]: mean=1, deviations [-1,1], pop var=1, std=1
        let m = array![[0.0_f64], [2.0]];
        let std = column_std(m.view());
        assert!((std[0] - 1.0).abs() < 1e-14, "expected std=1.0, got {}", std[0]);
    }

    #[test]
    fn column_var_equals_std_squared() {
        use ndarray::array;
        let m = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let std = column_std(m.view());
        let var = column_var(m.view());
        assert_eq!(std.len(), var.len());
        for (s, v) in std.iter().zip(var.iter()) {
            assert!((v - s * s).abs() < 1e-14, "var={v} should equal std²={}", s*s);
        }
    }

    #[test]
    fn column_std_empty_rows_returns_zeros() {
        use ndarray::Array2;
        // Zero rows, three columns: one 0.0 per column is expected.
        let m: Array2<f64> = Array2::zeros((0, 3));
        let std = column_std(m.view());
        assert_eq!(std, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn column_std_two_columns_independently() {
        use ndarray::array;
        // col0: [0,2] → std=1; col1: [1,1] → std=0
        let m = array![[0.0_f64, 1.0], [2.0, 1.0]];
        let std = column_std(m.view());
        assert!((std[0] - 1.0).abs() < 1e-14, "col0 std={}", std[0]);
        assert!(std[1].abs() < 1e-14, "col1 std={}", std[1]);
    }
}