1use std::collections::BTreeMap;
7
8use ndarray::{Array2, ArrayView2, Axis};
9use serde::{Deserialize, Serialize};
10
11use gam_linalg::faer_ndarray::FaerSvd;
12
/// Default for [`Thresholds::ivae_aux_var_floor`]: `check_ivae` flags every
/// aux column whose standard deviation is `<=` this value as constant.
pub const DEFAULT_IVAE_AUX_VAR_FLOOR: f64 = 1e-9;

/// Default for [`Thresholds::ivae_aux_rank_rtol`]: singular values of `aux`
/// at or below `rtol * sigma_max` do not count toward its column rank.
pub const DEFAULT_IVAE_AUX_RANK_RTOL: f64 = 1e-8;

/// Default for [`Thresholds::ivae_min_encoder_layers`]: encoders deeper than
/// one layer but shallower than this produce an iVAE `Warn`.
pub const DEFAULT_IVAE_MIN_ENCODER_LAYERS: i64 = 2;

/// Default for [`Thresholds::mech_sparsity_fraction`]: a decoder free block
/// whose near-zero fraction is below this produces a `Warn`.
pub const DEFAULT_MECH_SPARSITY_FRACTION: f64 = 0.5;

/// Default for [`Thresholds::mech_sparsity_zero_tol`]: a decoder entry counts
/// as zero when `|v| / max|column|` is `<=` this value.
pub const DEFAULT_MECH_SPARSITY_ZERO_TOL: f64 = 1e-3;

/// Default for [`Thresholds::randproj_var_warn`]: a maximum activation
/// variance above this (but not above the ceiling) produces a `Warn`.
pub const DEFAULT_RANDPROJ_VAR_WARN: f64 = 1_000.0;

/// Default for [`Thresholds::randproj_var_ceiling`]: a maximum activation
/// variance above this produces a `Fail`.
pub const DEFAULT_RANDPROJ_VAR_CEILING: f64 = 1_000_000.0;
41
42#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
45pub struct Thresholds {
46 pub ivae_aux_var_floor: f64,
47 pub ivae_aux_rank_rtol: f64,
48 pub ivae_min_encoder_layers: i64,
49 pub mech_sparsity_fraction: f64,
50 pub mech_sparsity_zero_tol: f64,
51 pub randproj_var_warn: f64,
52 pub randproj_var_ceiling: f64,
53}
54
55impl Default for Thresholds {
56 fn default() -> Self {
57 Self {
58 ivae_aux_var_floor: DEFAULT_IVAE_AUX_VAR_FLOOR,
59 ivae_aux_rank_rtol: DEFAULT_IVAE_AUX_RANK_RTOL,
60 ivae_min_encoder_layers: DEFAULT_IVAE_MIN_ENCODER_LAYERS,
61 mech_sparsity_fraction: DEFAULT_MECH_SPARSITY_FRACTION,
62 mech_sparsity_zero_tol: DEFAULT_MECH_SPARSITY_ZERO_TOL,
63 randproj_var_warn: DEFAULT_RANDPROJ_VAR_WARN,
64 randproj_var_ceiling: DEFAULT_RANDPROJ_VAR_CEILING,
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct TheoremResult {
72 pub theorem_name: String,
73 pub status: TheoremStatus,
74 pub reason: String,
75 pub metric: BTreeMap<String, f64>,
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79#[serde(rename_all = "lowercase")]
80pub enum TheoremStatus {
81 Pass,
82 Warn,
83 Fail,
84}
85
86impl TheoremStatus {
87 fn rank(&self) -> u8 {
88 match self {
89 TheoremStatus::Pass => 0,
90 TheoremStatus::Warn => 1,
91 TheoremStatus::Fail => 2,
92 }
93 }
94 fn worse(self, other: TheoremStatus) -> TheoremStatus {
95 if other.rank() > self.rank() {
96 other
97 } else {
98 self
99 }
100 }
101}
102
103#[derive(Debug, Clone, Default, Deserialize, Serialize)]
106pub struct FitSummary {
107 pub aux: Option<Vec<Vec<f64>>>,
110 pub n_supervised: Option<i64>,
112 pub n_free: Option<i64>,
114 pub decoder: Option<Vec<Vec<f64>>>,
116 pub encoder_depth: Option<i64>,
118 pub mech_sparsity_weight: Option<f64>,
120 pub activations: Option<Vec<Vec<f64>>>,
122 pub ground_truth_dim: Option<i64>,
124 #[serde(default)]
126 pub thresholds: Option<Thresholds>,
127}
128
129fn rows_to_array(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
130 if rows.is_empty() {
131 return Ok(Array2::<f64>::zeros((0, 0)));
132 }
133 let ncols = rows[0].len();
134 for (i, row) in rows.iter().enumerate() {
135 if row.len() != ncols {
136 return Err(format!(
137 "ragged matrix: row 0 has {ncols} cols but row {i} has {} cols",
138 row.len()
139 ));
140 }
141 }
142 let nrows = rows.len();
143 let mut flat = Vec::with_capacity(nrows * ncols);
144 for row in rows {
145 flat.extend_from_slice(row);
146 }
147 Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| e.to_string())
148}
149
150fn column_std(mat: ArrayView2<f64>) -> Vec<f64> {
151 let n = mat.nrows() as f64;
152 if n <= 0.0 {
153 return vec![0.0; mat.ncols()];
154 }
155 let mut out = Vec::with_capacity(mat.ncols());
156 for col in mat.axis_iter(Axis(1)) {
157 let mean = col.sum() / n;
158 let mut var = 0.0_f64;
159 for v in col.iter() {
160 let d = v - mean;
161 var += d * d;
162 }
163 out.push((var / n).sqrt());
164 }
165 out
166}
167
168fn column_var(mat: ArrayView2<f64>) -> Vec<f64> {
169 column_std(mat).into_iter().map(|s| s * s).collect()
170}
171
172fn matrix_rank(mat: ArrayView2<f64>, rtol: f64) -> Result<usize, String> {
175 if mat.nrows() == 0 || mat.ncols() == 0 {
176 return Ok(0);
177 }
178 let owned = mat.to_owned();
179 let (_u, sigma, _vt) = owned.svd(false, false).map_err(|e| format!("{e:?}"))?;
180 if sigma.is_empty() {
181 return Ok(0);
182 }
183 let smax = sigma.iter().cloned().fold(0.0_f64, f64::max);
184 if smax <= 0.0 {
185 return Ok(0);
186 }
187 let cutoff = smax * rtol;
188 Ok(sigma.iter().filter(|s| **s > cutoff).count())
189}
190
191pub fn check_ivae(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
193 let mut metric: BTreeMap<String, f64> = BTreeMap::new();
194 let mut issues: Vec<String> = Vec::new();
195 let mut status = TheoremStatus::Pass;
196
197 let aux_rows = match summary.aux.as_ref() {
198 Some(a) => a,
199 None => {
200 return TheoremResult {
201 theorem_name: "iVAE".to_string(),
202 status: TheoremStatus::Warn,
203 reason: "iVAE check skipped: no aux provided in fit summary.".to_string(),
204 metric,
205 };
206 }
207 };
208 let n_supervised = match summary.n_supervised {
209 Some(v) => v,
210 None => {
211 return TheoremResult {
212 theorem_name: "iVAE".to_string(),
213 status: TheoremStatus::Warn,
214 reason: "iVAE check skipped: n_supervised missing.".to_string(),
215 metric,
216 };
217 }
218 };
219
220 let aux = match rows_to_array(aux_rows) {
221 Ok(a) => a,
222 Err(e) => {
223 return TheoremResult {
224 theorem_name: "iVAE".to_string(),
225 status: TheoremStatus::Fail,
226 reason: format!("aux is malformed: {e}"),
227 metric,
228 };
229 }
230 };
231
232 let stds = column_std(aux.view());
233 let min_std = stds.iter().cloned().fold(f64::INFINITY, f64::min);
234 metric.insert(
235 "aux_min_std".to_string(),
236 if stds.is_empty() { 0.0 } else { min_std },
237 );
238 if stds.is_empty() || stds.iter().any(|s| *s <= thr.ivae_aux_var_floor) {
239 let zeros: Vec<usize> = stds
240 .iter()
241 .enumerate()
242 .filter(|(_, s)| **s <= thr.ivae_aux_var_floor)
243 .map(|(i, _)| i)
244 .collect();
245 issues.push(format!(
246 "iVAE identifiability requires auxiliary covariate variation; \
247 aux axes {zeros:?} are constant across observations (min std \
248 {min_std:.3e} <= {:.0e}); Khemakhem 2107.10098 Thm. 1 \
249 conditioning rank is zero.",
250 thr.ivae_aux_var_floor,
251 ));
252 status = status.worse(TheoremStatus::Fail);
253 }
254
255 let rank = match matrix_rank(aux.view(), thr.ivae_aux_rank_rtol) {
256 Ok(r) => r,
257 Err(e) => {
258 return TheoremResult {
259 theorem_name: "iVAE".to_string(),
260 status: TheoremStatus::Fail,
261 reason: format!("aux SVD failed: {e}"),
262 metric,
263 };
264 }
265 };
266 metric.insert("aux_column_rank".to_string(), rank as f64);
267 metric.insert("n_supervised".to_string(), n_supervised as f64);
268 if (rank as i64) < n_supervised {
269 issues.push(format!(
270 "aux column rank {rank} < n_supervised={n_supervised}: \
271 Khemakhem 2107.10098 §3 parametric-richness fails."
272 ));
273 status = status.worse(TheoremStatus::Fail);
274 }
275
276 match summary.encoder_depth {
277 None => {
278 issues.push(
279 "encoder depth unknown — cannot verify the >=2-layer \
280 requirement of Khemakhem 2107.10098 §3."
281 .to_string(),
282 );
283 status = status.worse(TheoremStatus::Warn);
284 }
285 Some(depth) => {
286 metric.insert("encoder_depth".to_string(), depth as f64);
287 if depth < 1 {
288 issues.push(format!("encoder depth {depth} < 1; no encoder is present."));
289 status = status.worse(TheoremStatus::Fail);
290 } else if depth == 1 {
291 issues.push(
292 "encoder depth == 1 (bare linear); Khemakhem 2107.10098 \
293 §3 requires non-linear encoder. Identifiability voided."
294 .to_string(),
295 );
296 status = status.worse(TheoremStatus::Fail);
297 } else if depth < thr.ivae_min_encoder_layers {
298 issues.push(format!(
299 "encoder depth {depth} < canonical min={}: \
300 Khemakhem 2107.10098 §3 universal-approximation \
301 argument is weakened.",
302 thr.ivae_min_encoder_layers,
303 ));
304 status = status.worse(TheoremStatus::Warn);
305 }
306 }
307 }
308
309 let reason = if matches!(status, TheoremStatus::Pass) {
310 "all Khemakhem 2107.10098 Thm. 1 preconditions hold".to_string()
311 } else {
312 issues.join(" | ")
313 };
314 TheoremResult {
315 theorem_name: "iVAE".to_string(),
316 status,
317 reason,
318 metric,
319 }
320}
321
322pub fn check_mechanism_sparsity(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
324 let mut metric: BTreeMap<String, f64> = BTreeMap::new();
325 let mut issues: Vec<String> = Vec::new();
326 let mut status = TheoremStatus::Pass;
327
328 let decoder_rows = match summary.decoder.as_ref() {
329 Some(d) => d,
330 None => {
331 return TheoremResult {
332 theorem_name: "MechanismSparsity".to_string(),
333 status: TheoremStatus::Warn,
334 reason: "MechanismSparsity skipped: no decoder in fit summary.".to_string(),
335 metric,
336 };
337 }
338 };
339 let n_sup = summary.n_supervised.unwrap_or(0);
340 let n_free = match summary.n_free {
341 Some(v) => v,
342 None => {
343 return TheoremResult {
344 theorem_name: "MechanismSparsity".to_string(),
345 status: TheoremStatus::Warn,
346 reason: "MechanismSparsity skipped: n_free missing.".to_string(),
347 metric,
348 };
349 }
350 };
351
352 let decoder = match rows_to_array(decoder_rows) {
353 Ok(d) => d,
354 Err(e) => {
355 return TheoremResult {
356 theorem_name: "MechanismSparsity".to_string(),
357 status: TheoremStatus::Fail,
358 reason: format!("decoder is malformed: {e}"),
359 metric,
360 };
361 }
362 };
363
364 let total_cols = decoder.ncols() as i64;
365 if n_sup + n_free > total_cols || n_sup < 0 || n_free < 0 {
366 return TheoremResult {
367 theorem_name: "MechanismSparsity".to_string(),
368 status: TheoremStatus::Fail,
369 reason: format!(
370 "decoder has {total_cols} columns but n_supervised + n_free \
371 = {} + {}.",
372 n_sup, n_free,
373 ),
374 metric,
375 };
376 }
377 let free_cols = decoder.slice(ndarray::s![
378 ..,
379 (n_sup as usize)..((n_sup + n_free) as usize)
380 ]);
381 metric.insert(
382 "free_block_shape_rows".to_string(),
383 free_cols.nrows() as f64,
384 );
385 metric.insert(
386 "free_block_shape_cols".to_string(),
387 free_cols.ncols() as f64,
388 );
389
390 let mut col_max = vec![0.0_f64; free_cols.ncols()];
392 for col_idx in 0..free_cols.ncols() {
393 let col = free_cols.column(col_idx);
394 col_max[col_idx] = col.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
395 }
396 let mut zeros: u64 = 0;
397 let mut total: u64 = 0;
398 for col_idx in 0..free_cols.ncols() {
399 let safe_max = if col_max[col_idx] > 0.0 {
400 col_max[col_idx]
401 } else {
402 1.0
403 };
404 for row_idx in 0..free_cols.nrows() {
405 let rel = free_cols[[row_idx, col_idx]].abs() / safe_max;
406 if rel <= thr.mech_sparsity_zero_tol {
407 zeros += 1;
408 }
409 total += 1;
410 }
411 }
412 let zero_fraction = if total == 0 {
413 0.0
414 } else {
415 zeros as f64 / total as f64
416 };
417 metric.insert("decoder_zero_fraction".to_string(), zero_fraction);
418
419 let rank = match matrix_rank(free_cols.view(), 1.0e-8) {
420 Ok(r) => r,
421 Err(e) => {
422 return TheoremResult {
423 theorem_name: "MechanismSparsity".to_string(),
424 status: TheoremStatus::Fail,
425 reason: format!("decoder SVD failed: {e}"),
426 metric,
427 };
428 }
429 };
430 metric.insert("decoder_free_rank".to_string(), rank as f64);
431 if (rank as i64) < n_free {
432 issues.push(format!(
433 "decoder Jacobian on the free block has rank {rank} < \
434 n_free={n_free}; Lachapelle 2401.04890 Thm. requires full \
435 rank on the free latents."
436 ));
437 status = status.worse(TheoremStatus::Fail);
438 }
439
440 match summary.mech_sparsity_weight {
441 None => {
442 issues.push(
443 "mech sparsity weight unknown — cannot confirm L1 prox \
444 was active."
445 .to_string(),
446 );
447 status = status.worse(TheoremStatus::Warn);
448 }
449 Some(w) => {
450 metric.insert("mech_sparsity_weight".to_string(), w);
451 if !(w > 0.0) {
452 issues.push(format!(
453 "mech sparsity weight = {w} is not strictly positive; \
454 Lachapelle 2401.04890 identification voided."
455 ));
456 status = status.worse(TheoremStatus::Fail);
457 }
458 }
459 }
460
461 if zero_fraction < thr.mech_sparsity_fraction {
462 issues.push(format!(
463 "decoder zero-fraction {zero_fraction:.3} < {:.2} threshold \
464 from Lachapelle 2401.04890 §2.4: L1 prox has not reached \
465 equilibrium, identification weakened.",
466 thr.mech_sparsity_fraction,
467 ));
468 status = status.worse(TheoremStatus::Warn);
469 }
470
471 let state_dim = n_sup + n_free;
472 if let Some(gt) = summary.ground_truth_dim {
473 metric.insert("state_dim".to_string(), state_dim as f64);
474 metric.insert("ground_truth_dim".to_string(), gt as f64);
475 if state_dim < gt {
476 issues.push(format!(
477 "state_dim={state_dim} < ground_truth_dim={gt}: Lachapelle \
478 2401.04890 requires at least as many latents as the data \
479 generating process."
480 ));
481 status = status.worse(TheoremStatus::Fail);
482 }
483 }
484
485 let reason = if matches!(status, TheoremStatus::Pass) {
486 "all Lachapelle 2401.04890 preconditions hold".to_string()
487 } else {
488 issues.join(" | ")
489 };
490 TheoremResult {
491 theorem_name: "MechanismSparsity".to_string(),
492 status,
493 reason,
494 metric,
495 }
496}
497
498pub fn check_random_projection(summary: &FitSummary, thr: &Thresholds) -> TheoremResult {
500 let mut metric: BTreeMap<String, f64> = BTreeMap::new();
501
502 let act_rows = match summary.activations.as_ref() {
503 Some(a) => a,
504 None => {
505 return TheoremResult {
506 theorem_name: "RandomProjection".to_string(),
507 status: TheoremStatus::Warn,
508 reason: "RandomProjection skipped: no activations provided.".to_string(),
509 metric,
510 };
511 }
512 };
513 let act = match rows_to_array(act_rows) {
514 Ok(a) => a,
515 Err(e) => {
516 return TheoremResult {
517 theorem_name: "RandomProjection".to_string(),
518 status: TheoremStatus::Fail,
519 reason: format!("activations malformed: {e}"),
520 metric,
521 };
522 }
523 };
524 if act.nrows() == 0 || act.ncols() == 0 {
525 return TheoremResult {
526 theorem_name: "RandomProjection".to_string(),
527 status: TheoremStatus::Fail,
528 reason: "activations are empty.".to_string(),
529 metric,
530 };
531 }
532 let variances = column_var(act.view());
533 let var_max = variances.iter().cloned().fold(0.0_f64, f64::max);
534 let var_min = variances.iter().cloned().fold(f64::INFINITY, f64::min);
535 metric.insert("activation_var_max".to_string(), var_max);
536 metric.insert("activation_var_min".to_string(), var_min);
537 if variances.iter().any(|v| !v.is_finite()) {
538 return TheoremResult {
539 theorem_name: "RandomProjection".to_string(),
540 status: TheoremStatus::Fail,
541 reason: "activations contain non-finite variance; Khemakhem App. A.3 \
542 requires bounded variance."
543 .to_string(),
544 metric,
545 };
546 }
547 if var_max > thr.randproj_var_ceiling {
548 return TheoremResult {
549 theorem_name: "RandomProjection".to_string(),
550 status: TheoremStatus::Fail,
551 reason: format!(
552 "max activation variance {var_max:.3e} > ceiling \
553 {:.3e}; encoder is unbounded.",
554 thr.randproj_var_ceiling,
555 ),
556 metric,
557 };
558 }
559 if var_max > thr.randproj_var_warn {
560 return TheoremResult {
561 theorem_name: "RandomProjection".to_string(),
562 status: TheoremStatus::Warn,
563 reason: format!(
564 "max activation variance {var_max:.3e} > warn-floor \
565 {:.3e}; encoder is large but not yet unbounded.",
566 thr.randproj_var_warn,
567 ),
568 metric,
569 };
570 }
571 TheoremResult {
572 theorem_name: "RandomProjection".to_string(),
573 status: TheoremStatus::Pass,
574 reason: "encoder activation variance is bounded.".to_string(),
575 metric,
576 }
577}
578
579pub fn identifiability_check(summary: &FitSummary) -> Vec<TheoremResult> {
581 let thr = summary.thresholds.unwrap_or_default();
582 vec![
583 check_ivae(summary, &thr),
584 check_mechanism_sparsity(summary, &thr),
585 check_random_projection(summary, &thr),
586 ]
587}
588
589pub fn identifiability_check_json(input: &str) -> Result<String, String> {
593 let summary: FitSummary =
594 serde_json::from_str(input).map_err(|e| format!("invalid FitSummary JSON: {e}"))?;
595 let report = identifiability_check(&summary);
596 serde_json::to_string(&report).map_err(|e| format!("serialise: {e}"))
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    // Baseline summary that passes the iVAE check. It has four distinct
    // one-column aux values (column rank 1 == n_supervised) and an encoder
    // depth of 3.
    fn passing_ivae_summary() -> FitSummary {
        FitSummary {
            aux: Some(vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]),
            n_supervised: Some(1),
            n_free: Some(0),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0]]),
            activations: Some(vec![vec![0.1], vec![0.2], vec![0.3], vec![0.4]]),
            ground_truth_dim: None,
            thresholds: None,
        }
    }

    // `worse` must return the more severe status (Pass < Warn < Fail).
    #[test]
    fn theorem_status_worse_is_monotone() {
        use TheoremStatus::{Fail, Pass, Warn};
        assert_eq!(Pass.worse(Pass), Pass);
        assert_eq!(Pass.worse(Warn), Warn);
        assert_eq!(Pass.worse(Fail), Fail);
        assert_eq!(Warn.worse(Pass), Warn);
        assert_eq!(Warn.worse(Fail), Fail);
        assert_eq!(Fail.worse(Pass), Fail);
        assert_eq!(Fail.worse(Warn), Fail);
    }

    // A single constant aux column (std 0) must fail iVAE through the full
    // `identifiability_check` entry point and record aux_min_std == 0.
    #[test]
    fn constant_aux_fails_ivae() {
        let summary = FitSummary {
            aux: Some(vec![vec![1.0]; 32]),
            n_supervised: Some(1),
            n_free: Some(2),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0, 0.5, 0.0, 0.0, 0.0]; 12]),
            activations: Some(vec![vec![0.0; 3]; 32]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let report = identifiability_check(&summary);
        let ivae = report.iter().find(|t| t.theorem_name == "iVAE").unwrap();
        assert_eq!(ivae.status, TheoremStatus::Fail);
        assert!(ivae.reason.to_lowercase().contains("constant"));
        assert_eq!(
            ivae.metric.get("aux_min_std").copied().unwrap_or(f64::NAN),
            0.0
        );
    }

    // Depth 1 is treated as a bare linear encoder and must fail.
    #[test]
    fn linear_encoder_depth_one_fails_ivae() {
        let mut summary = passing_ivae_summary();
        summary.encoder_depth = Some(1);
        let thr = Thresholds::default();
        let result = check_ivae(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("linear"), "reason: {}", result.reason);
    }

    // Missing aux means the check is skipped (Warn), not failed.
    #[test]
    fn missing_aux_warns_ivae() {
        let mut summary = passing_ivae_summary();
        summary.aux = None;
        let thr = Thresholds::default();
        let result = check_ivae(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    #[test]
    fn varying_aux_with_deep_encoder_passes_ivae() {
        let thr = Thresholds::default();
        let result = check_ivae(&passing_ivae_summary(), &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    // Missing decoder means the check is skipped (Warn).
    #[test]
    fn missing_decoder_warns_mechanism_sparsity() {
        let mut summary = passing_ivae_summary();
        summary.decoder = None;
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    // A one-column free block with 3 of 4 entries zero (fraction 0.75 >= 0.5)
    // and rank 1 == n_free passes.
    #[test]
    fn mechanism_sparsity_passes_with_sparse_decoder() {
        let summary = FitSummary {
            n_supervised: Some(0),
            n_free: Some(1),
            decoder: Some(vec![
                vec![1.0],
                vec![0.0],
                vec![0.0],
                vec![0.0],
            ]),
            mech_sparsity_weight: Some(1.0),
            aux: Some(vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]),
            encoder_depth: Some(3),
            activations: Some(vec![vec![0.1], vec![0.2], vec![0.3], vec![0.4]]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    // A zero L1 weight must fail even when the decoder itself looks sparse.
    #[test]
    fn zero_mech_sparsity_weight_fails() {
        let summary = FitSummary {
            n_supervised: Some(0),
            n_free: Some(1),
            decoder: Some(vec![vec![1.0], vec![0.0], vec![0.0], vec![0.0]]),
            mech_sparsity_weight: Some(0.0),
            aux: None,
            encoder_depth: Some(3),
            activations: None,
            ground_truth_dim: None,
            thresholds: None,
        };
        let thr = Thresholds::default();
        let result = check_mechanism_sparsity(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("not strictly positive"), "reason: {}", result.reason);
    }

    #[test]
    fn low_variance_activations_pass_random_projection() {
        let summary = FitSummary {
            activations: Some(vec![
                vec![0.1, 0.2],
                vec![0.15, 0.25],
                vec![0.12, 0.22],
            ]),
            ..FitSummary::default()
        };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Pass, "reason: {}", result.reason);
    }

    // Population variance of [0, 1e6] is 2.5e11, far above the 1e6 ceiling.
    #[test]
    fn very_high_variance_activations_fail_random_projection() {
        let summary = FitSummary {
            activations: Some(vec![
                vec![0.0],
                vec![1_000_000.0],
            ]),
            ..FitSummary::default()
        };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Fail);
        assert!(result.reason.contains("unbounded"), "reason: {}", result.reason);
    }

    #[test]
    fn missing_activations_warn_random_projection() {
        let summary = FitSummary { activations: None, ..FitSummary::default() };
        let thr = Thresholds::default();
        let result = check_random_projection(&summary, &thr);
        assert_eq!(result.status, TheoremStatus::Warn);
    }

    // Serialise a summary, run the JSON entry point, and confirm the output
    // parses back into exactly three results.
    #[test]
    fn json_roundtrip() {
        let summary = FitSummary {
            aux: Some(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]),
            n_supervised: Some(2),
            n_free: Some(1),
            encoder_depth: Some(3),
            mech_sparsity_weight: Some(1.0),
            decoder: Some(vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]]),
            activations: Some(vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]),
            ground_truth_dim: None,
            thresholds: None,
        };
        let json = serde_json::to_string(&summary).unwrap();
        let out = identifiability_check_json(&json).unwrap();
        let parsed: Vec<TheoremResult> = serde_json::from_str(&out).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    #[test]
    fn rows_to_array_empty_returns_0x0() {
        let a = rows_to_array(&[]).unwrap();
        assert_eq!(a.dim(), (0, 0));
    }

    #[test]
    fn rows_to_array_rectangular_shape_and_values() {
        let rows = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let a = rows_to_array(&rows).unwrap();
        assert_eq!(a.dim(), (2, 3));
        assert_eq!(a[[0, 0]], 1.0);
        assert_eq!(a[[0, 2]], 3.0);
        assert_eq!(a[[1, 1]], 5.0);
    }

    #[test]
    fn rows_to_array_ragged_returns_err() {
        let rows = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(rows_to_array(&rows).is_err());
    }

    // The ragged-row error should name the first offending row (index 2 here).
    #[test]
    fn rows_to_array_ragged_error_mentions_row_indices() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0]];
        let err = rows_to_array(&rows).unwrap_err();
        assert!(err.contains('2'), "error should mention row 2, got: {err}");
    }

    #[test]
    fn column_std_constant_column_is_zero() {
        use ndarray::array;
        let m = array![[3.0_f64], [3.0], [3.0]];
        let std = column_std(m.view());
        assert_eq!(std.len(), 1);
        assert!(std[0].abs() < 1e-14, "constant column std should be 0, got {}", std[0]);
    }

    // Population std (divisor n) of [0, 2]: mean 1, variance 1, std 1.
    #[test]
    fn column_std_known_value() {
        use ndarray::array;
        let m = array![[0.0_f64], [2.0]];
        let std = column_std(m.view());
        assert!((std[0] - 1.0).abs() < 1e-14, "expected std=1.0, got {}", std[0]);
    }

    #[test]
    fn column_var_equals_std_squared() {
        use ndarray::array;
        let m = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let std = column_std(m.view());
        let var = column_var(m.view());
        assert_eq!(std.len(), var.len());
        for (s, v) in std.iter().zip(var.iter()) {
            assert!((v - s * s).abs() < 1e-14, "var={v} should equal std²={}", s*s);
        }
    }

    // With no rows, every column's std is reported as 0.0.
    #[test]
    fn column_std_empty_rows_returns_zeros() {
        use ndarray::Array2;
        let m: Array2<f64> = Array2::zeros((0, 3));
        let std = column_std(m.view());
        assert_eq!(std, vec![0.0, 0.0, 0.0]);
    }

    // Column 0 = [0, 2] (std 1); column 1 = [1, 1] (std 0).
    #[test]
    fn column_std_two_columns_independently() {
        use ndarray::array;
        let m = array![[0.0_f64, 1.0], [2.0, 1.0]];
        let std = column_std(m.view());
        assert!((std[0] - 1.0).abs() < 1e-14, "col0 std={}", std[0]);
        assert!(std[1].abs() < 1e-14, "col1 std={}", std[1]);
    }
}