use crate::inference::row_metric::{MetricProvenance, RowMetric};
use crate::inference::structure_evidence::{StructureCertificate, StructureLedger};
use crate::linalg::faer_ndarray::{
FaerEigh, FaerQr, FaerSvd, default_rrqr_rank_alpha, rrqr_with_permutation,
};
use faer::Side;
use ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, s};
/// Smoothed column-wise group-sparsity penalty on a matrix `w`.
///
/// Each column contributes `sqrt(‖column‖² + epsilon²) - epsilon`, and the
/// column sum is multiplied by `weight` (see `value_and_grad`).
#[derive(Debug, Clone)]
pub struct MechanismSparsityJacobian {
/// Overall penalty multiplier; `new` requires it to be finite and > 0.
pub weight: f64,
/// Smoothing constant inside the square root; `new` requires it to be finite and > 0.
pub epsilon: f64,
}
impl MechanismSparsityJacobian {
    /// Builds the penalty after validating both hyper-parameters.
    ///
    /// # Errors
    /// Returns a message when `weight` or `epsilon` is non-finite or not
    /// strictly positive.
    pub fn new(weight: f64, epsilon: f64) -> Result<Self, String> {
        let strictly_positive = |x: f64| x.is_finite() && x > 0.0;
        if !strictly_positive(weight) {
            return Err(format!(
                "MechanismSparsityJacobian: weight must be finite and >0, got {weight}"
            ));
        }
        if !strictly_positive(epsilon) {
            return Err(format!(
                "MechanismSparsityJacobian: epsilon must be finite and >0, got {epsilon}"
            ));
        }
        Ok(Self { weight, epsilon })
    }

    /// Penalty value and its gradient with respect to `w`.
    ///
    /// The value is `weight * Σ_col (sqrt(‖w[:, col]‖² + ε²) - ε)`; the
    /// returned gradient has the same shape as `w`.
    pub fn value_and_grad(&self, w: ArrayView2<f64>) -> (f64, Array2<f64>) {
        let (_, n_cols) = w.dim();
        let eps_sq = self.epsilon * self.epsilon;
        let mut grad = Array2::<f64>::zeros(w.dim());
        let mut total = 0.0;
        for col in 0..n_cols {
            let column = w.column(col);
            let norm_sq = column.iter().fold(0.0_f64, |acc, &x| acc + x * x);
            let smooth_norm = (norm_sq + eps_sq).sqrt();
            total += smooth_norm - self.epsilon;
            let factor = self.weight / smooth_norm;
            for (g, &x) in grad.column_mut(col).iter_mut().zip(column.iter()) {
                *g = factor * x;
            }
        }
        (self.weight * total, grad)
    }

    /// Entry-wise second derivative of the penalty (diagonal of the Hessian).
    ///
    /// Entry `(row, col)` is `weight * (1/s - w²/s³)` with
    /// `s = sqrt(‖w[:, col]‖² + ε²)`.
    pub fn hessian_diag(&self, w: ArrayView2<f64>) -> Array2<f64> {
        let (_, n_cols) = w.dim();
        let eps_sq = self.epsilon * self.epsilon;
        let mut diag = Array2::<f64>::zeros(w.dim());
        for col in 0..n_cols {
            let column = w.column(col);
            let norm_sq = column.iter().fold(0.0_f64, |acc, &x| acc + x * x);
            let inv = 1.0 / (norm_sq + eps_sq).sqrt();
            let inv3 = inv * inv * inv;
            for (h, &x) in diag.column_mut(col).iter_mut().zip(column.iter()) {
                *h = self.weight * (inv - x * x * inv3);
            }
        }
        diag
    }
}
/// Row-wise diagonal Gaussian prior with per-entry mean and scale.
///
/// `value_and_grad` evaluates the weighted negative log-density of a matrix
/// `t` of the same shape as `mean`.
#[derive(Debug, Clone)]
pub struct ConditionalPriorIvae {
/// Prior means; `new` requires finite entries and the same shape as `scale`.
pub mean: Array2<f64>,
/// Prior standard deviations; `new` requires every entry finite and > 0.
pub scale: Array2<f64>,
/// Multiplier applied to the value and gradient; `new` requires finite and > 0.
pub weight: f64,
}
impl ConditionalPriorIvae {
    /// Validates the prior parameters and builds the prior.
    ///
    /// Checks run in this order: shape agreement, `weight`, `scale` entries,
    /// `mean` entries, the row-count requirement `n_rows >= 2k + 1`, rejection
    /// of a signature whose rows are all identical, and finally a numerical
    /// rank test (via SVD) on the stacked signature `[μ ‖ log σ]`.
    ///
    /// # Errors
    /// Returns a descriptive message for the first check that fails, or when
    /// the SVD itself fails.
    pub fn new(mean: Array2<f64>, scale: Array2<f64>, weight: f64) -> Result<Self, String> {
        if mean.dim() != scale.dim() {
            return Err(format!(
                "ConditionalPriorIvae: mean shape {:?} != scale shape {:?}",
                mean.dim(),
                scale.dim()
            ));
        }
        if !(weight.is_finite() && weight > 0.0) {
            return Err(format!(
                "ConditionalPriorIvae: weight must be finite and >0, got {weight}"
            ));
        }
        if let Some(v) = scale.iter().find(|&&v| !(v.is_finite() && v > 0.0)) {
            return Err(format!(
                "ConditionalPriorIvae: every scale must be finite and >0, got {v}"
            ));
        }
        if mean.iter().any(|v| !v.is_finite()) {
            return Err("ConditionalPriorIvae: mean contains non-finite entry".to_string());
        }

        let (n_rows, latent_dim) = mean.dim();
        let needed_rows = 2 * latent_dim + 1;
        if n_rows < needed_rows {
            return Err(format!(
                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
                 precondition violated: need at least 2k+1 = {needed_rows} distinct \
                 auxiliary states for latent_dim k = {latent_dim}, got n_rows = {n_rows}"
            ));
        }

        // Stacked signature: means in the first k columns, log-scales in the last k.
        let signature = Array2::<f64>::from_shape_fn((n_rows, 2 * latent_dim), |(r, c)| {
            if c < latent_dim {
                mean[[r, c]]
            } else {
                scale[[r, c - latent_dim]].ln()
            }
        });

        let reference = signature.row(0).to_owned();
        let any_row_differs = signature
            .outer_iter()
            .any(|row| row.iter().zip(reference.iter()).any(|(a, b)| a != b));
        if !any_row_differs {
            return Err(format!(
                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
                 precondition violated: all {n_rows} rows of the stacked auxiliary \
                 signature [μ ‖ log σ] are identical, so the conditional prior is the \
                 trivial unconditional N(μ, σ²) — provably non-identifiable (no \
                 auxiliary information)"
            ));
        }

        let (_u, sv, _vt) = signature
            .svd(false, false)
            .map_err(|e| format!("ConditionalPriorIvae: SVD of auxiliary signature failed: {e}"))?;
        let max_sv = sv.iter().cloned().fold(0.0_f64, f64::max);
        // Relative tolerance scaled by the larger matrix dimension.
        let tol = max_sv * (n_rows.max(2 * latent_dim) as f64) * f64::EPSILON;
        let numerical_rank = sv.iter().filter(|&&s| s > tol).count();
        let required = 2 * latent_dim;
        if numerical_rank < required {
            return Err(format!(
                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
                 precondition violated: stacked auxiliary signature [μ ‖ log σ] has \
                 numerical rank {numerical_rank} < 2·latent_dim = {required} \
                 (tolerance {tol:.3e}); the family `p(t|u)` does not span a \
                 2k-dimensional set of natural parameters"
            ));
        }

        Ok(Self {
            mean,
            scale,
            weight,
        })
    }

    /// Weighted Gaussian negative log-density of `t` and its gradient in `t`.
    ///
    /// # Panics
    /// Panics when `t` does not have the same shape as `mean`.
    pub fn value_and_grad(&self, t: ArrayView2<f64>) -> (f64, Array2<f64>) {
        assert_eq!(
            t.dim(),
            self.mean.dim(),
            "ConditionalPriorIvae: t/mean shape mismatch"
        );
        let log_2pi = (2.0 * std::f64::consts::PI).ln();
        let mut grad = Array2::<f64>::zeros(t.dim());
        let mut total = 0.0;
        for ((row, col), g) in grad.indexed_iter_mut() {
            let sigma = self.scale[[row, col]];
            let z = (t[[row, col]] - self.mean[[row, col]]) / sigma;
            total += 0.5 * (z * z + 2.0 * sigma.ln() + log_2pi);
            *g = self.weight * z / sigma;
        }
        (self.weight * total, grad)
    }

    /// Value-only convenience wrapper around `value_and_grad`.
    pub fn value(&self, t: ArrayView2<f64>) -> f64 {
        let (total, _grad) = self.value_and_grad(t);
        total
    }
}
/// Evaluates a piecewise-linear interpolant on a uniform grid of centres.
///
/// `coeffs` holds one row per centre; the centres are spaced evenly between
/// `u_min` and `u_max`. Each value in `u` is mapped to a fractional grid
/// position, clamped to the grid, and the two neighbouring coefficient rows
/// are blended linearly. The result has one row per entry of `u` and one
/// column per column of `coeffs`.
///
/// # Panics
/// Panics when `coeffs` has fewer than two rows.
pub fn piecewise_linear_eval(
    u: ArrayView1<f64>,
    coeffs: ArrayView2<f64>,
    u_min: f64,
    u_max: f64,
) -> Array2<f64> {
    let (k, d) = coeffs.dim();
    assert!(k >= 2, "piecewise_linear_eval: need ≥2 centres");
    let last_pos = (k - 1) as f64;
    let spacing = (u_max - u_min) / (k - 1) as f64;
    let mut out = Array2::<f64>::zeros((u.len(), d));
    for (mut out_row, &val) in out.outer_iter_mut().zip(u.iter()) {
        let pos = ((val - u_min) / spacing).clamp(0.0, last_pos);
        // Cap the left index so that `lo + 1` is always a valid row.
        let lo = (pos.floor() as usize).min(k - 2);
        let frac = pos - lo as f64;
        let left = coeffs.row(lo);
        let right = coeffs.row(lo + 1);
        for ((slot, &a), &b) in out_row.iter_mut().zip(left.iter()).zip(right.iter()) {
            *slot = a * (1.0 - frac) + b * frac;
        }
    }
    out
}
/// Outcome of the grid search performed by `identifiable_factor_select_weights`.
#[derive(Debug, Clone)]
pub struct WeightSearchResult {
/// Row index (into `lam1_grid`) of the selected cell.
pub best_i: usize,
/// Column index (into `lam2_grid`) of the selected cell.
pub best_j: usize,
/// `lam1_grid[best_i]`.
pub best_lam1: f64,
/// `lam2_grid[best_j]`.
pub best_lam2: f64,
/// Evidence score at the selected cell (the grid maximum).
pub best_evidence: f64,
/// Evidence score for every grid cell, same shape as the input `rss_grid`.
pub evidence_grid: Array2<f64>,
}
/// Picks the `(λ₁, λ₂)` grid cell with the highest evidence score.
///
/// For each cell the score is `-0.5·n·ln(max(rss/n, 1e-300)) - 0.5·penalty`
/// with `n = n_obs`. Exact ties are broken in favour of the smaller index sum
/// `i + j`, then the smaller `i`, then the smaller `j`.
///
/// # Errors
/// Returns a message when shapes disagree, a grid is empty, `n_obs == 0`,
/// `rss_grid` holds a non-finite or negative value, `penalty_grid` holds a
/// non-finite value, or a λ grid holds a non-finite or non-positive value.
pub fn identifiable_factor_select_weights(
    rss_grid: ArrayView2<'_, f64>,
    penalty_grid: ArrayView2<'_, f64>,
    lam1_grid: ArrayView1<'_, f64>,
    lam2_grid: ArrayView1<'_, f64>,
    n_obs: usize,
) -> Result<WeightSearchResult, String> {
    let (g1, g2) = rss_grid.dim();
    if penalty_grid.dim() != (g1, g2) {
        return Err(format!(
            "identifiable_factor_select_weights: penalty_grid shape {:?} \
             must match rss_grid shape ({}, {})",
            penalty_grid.dim(),
            g1,
            g2
        ));
    }
    if lam1_grid.len() != g1 {
        return Err(format!(
            "identifiable_factor_select_weights: lam1_grid len {} must \
             equal rss_grid rows {}",
            lam1_grid.len(),
            g1
        ));
    }
    if lam2_grid.len() != g2 {
        return Err(format!(
            "identifiable_factor_select_weights: lam2_grid len {} must \
             equal rss_grid cols {}",
            lam2_grid.len(),
            g2
        ));
    }
    if g1 == 0 || g2 == 0 {
        return Err("identifiable_factor_select_weights: grids must be non-empty".to_string());
    }
    if n_obs == 0 {
        return Err("identifiable_factor_select_weights: n_obs must be > 0".to_string());
    }
    if let Some(v) = rss_grid.iter().find(|v| !v.is_finite() || **v < 0.0) {
        return Err(format!(
            "identifiable_factor_select_weights: rss_grid contains non-finite or \
             negative value {v}"
        ));
    }
    if let Some(v) = penalty_grid.iter().find(|v| !v.is_finite()) {
        return Err(format!(
            "identifiable_factor_select_weights: penalty_grid contains non-finite value {v}"
        ));
    }
    let bad_lambda = lam1_grid
        .iter()
        .chain(lam2_grid.iter())
        .find(|v| !v.is_finite() || **v <= 0.0);
    if let Some(v) = bad_lambda {
        return Err(format!(
            "identifiable_factor_select_weights: λ grids must contain finite positive \
             values, got {v}"
        ));
    }

    let n = n_obs as f64;
    // Floor on the mean squared residual so the logarithm stays finite at rss = 0.
    let rss_floor = 1.0e-300_f64;
    let mut evidence_grid = Array2::<f64>::zeros((g1, g2));
    let mut best: Option<(usize, usize, f64)> = None;
    for ((i, j), slot) in evidence_grid.indexed_iter_mut() {
        let mean_sq = (rss_grid[[i, j]] / n).max(rss_floor);
        let ev = -0.5 * n * mean_sq.ln() - 0.5 * penalty_grid[[i, j]];
        *slot = ev;
        let replaces_best = match best {
            None => true,
            // Lexicographic compare on (i + j, i, j) implements the tie-break.
            Some((bi, bj, bev)) => ev > bev || (ev == bev && (i + j, i, j) < (bi + bj, bi, bj)),
        };
        if replaces_best {
            best = Some((i, j, ev));
        }
    }

    let (best_i, best_j, best_evidence) = best.ok_or_else(|| {
        "identifiable_factor_select_weights: empty search (this is a bug)".to_string()
    })?;
    Ok(WeightSearchResult {
        best_i,
        best_j,
        best_lam1: lam1_grid[best_i],
        best_lam2: lam2_grid[best_j],
        best_evidence,
        evidence_grid,
    })
}
/// Leading `k` principal-component scores of `x`.
///
/// Columns of `x` are mean-centred, the centred matrix is decomposed by SVD,
/// and the result is the first `k` columns of `U` each multiplied by its
/// singular value. `k == 0` yields an `(n, 0)` matrix without running an SVD.
///
/// # Errors
/// Returns a message when `k > min(n, p)`, when the SVD fails, or when the
/// SVD does not return `U`.
pub fn thin_svd_scores(x: ArrayView2<f64>, k: usize) -> Result<Array2<f64>, String> {
    let (n, p) = x.dim();
    if k == 0 {
        return Ok(Array2::<f64>::zeros((n, 0)));
    }
    let max_components = n.min(p);
    if k > max_components {
        return Err(format!(
            "thin_svd_scores: requested {k} components but min(n={n}, p={p}) limits to {}",
            n.min(p)
        ));
    }

    // Column means: accumulate row by row, then scale by 1/n.
    let mut col_mean = Array1::<f64>::zeros(p);
    for row in x.outer_iter() {
        for (acc, &v) in col_mean.iter_mut().zip(row.iter()) {
            *acc += v;
        }
    }
    if n > 0 {
        let inv_n = 1.0 / (n as f64);
        for acc in col_mean.iter_mut() {
            *acc *= inv_n;
        }
    }

    let centred = Array2::<f64>::from_shape_fn((n, p), |(row, col)| x[[row, col]] - col_mean[col]);
    let (u_opt, sigma, _vt_opt) = centred
        .svd(true, false)
        .map_err(|e| format!("thin_svd_scores: SVD failed: {e}"))?;
    let u = u_opt.ok_or_else(|| "thin_svd_scores: SVD did not return U".to_string())?;

    Ok(Array2::<f64>::from_shape_fn((n, k), |(row, col)| {
        u[[row, col]] * sigma[col]
    }))
}
/// How `partial_supervision_solve` maps the supervised latent block onto `aux`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialSupervisionSupMethod {
/// Map `R = U·Vᵀ` taken from the SVD of `t_supᵀ·aux`.
Procrustes,
/// Affine map (matrix plus offset) fitted by least squares on the anchor rows only.
Anchor,
/// Ridge-type linear map whose weight is chosen by a log-spaced grid search.
SoftL2,
}
/// Post-processing applied to the free latent block by `partial_supervision_solve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialSupervisionFreeConstraint {
/// Subtract from `t_free` its projection `Q·Qᵀ·t_free`, with `Q` from a QR
/// factorization of the aligned supervised block.
OrthogonalToSup,
/// Return `t_free` unchanged.
None,
}
/// Output of `partial_supervision_solve`.
#[derive(Debug, Clone)]
pub struct PartialSupervisionResult {
/// Supervised block after applying the fitted alignment map.
pub t_supervised: Array2<f64>,
/// Free block, possibly with its projection onto the supervised block removed.
pub t_free: Array2<f64>,
/// `1 - ‖t_supervised - aux‖² / ‖aux‖²` (squared Frobenius norms).
pub alignment_score: f64,
/// Weight selected by the grid search; `Some` only for the `SoftL2` method.
pub selected_weight: Option<f64>,
/// Fitted map for the `Procrustes` method; `None` otherwise.
pub map_r: Option<Array2<f64>>,
/// Linear part of the fitted map for the `Anchor` and `SoftL2` methods.
pub map_a: Option<Array2<f64>>,
/// Offset of the affine map; `Some` only for the `Anchor` method.
pub map_b: Option<Array1<f64>>,
}
/// Aligns a supervised latent block to auxiliary targets and optionally
/// removes that block's directions from the free latent block.
///
/// * `t_sup` — supervised latent block, shape `(n, d_sup)`.
/// * `aux` — alignment targets, same shape as `t_sup`; must have a finite,
///   non-zero Frobenius norm.
/// * `t_free` — free latent block with `n` rows.
/// * `method` — how the map from `t_sup` to `aux` is fitted:
///   * `Procrustes`: `R = U·Vᵀ` from the SVD of `t_supᵀ·aux`.
///   * `Anchor`: affine least squares (pseudo-inverse via SVD) fitted on the
///     rows listed in `anchor_idx`, then applied to every row.
///   * `SoftL2`: ridge-type map `(G + λI)⁻¹·t_supᵀ·aux` with `G = t_supᵀ·t_sup`,
///     where `λ` minimises `n·ln(residual) + log det(I + G/λ)` over a
///     64-point log-spaced grid.
/// * `anchor_idx` — row indices used by the `Anchor` method (ignored otherwise).
/// * `free_constraint` — whether to subtract from `t_free` its projection
///   onto the span reported by a QR factorization of the aligned block.
///
/// # Errors
/// Returns a message on shape mismatch, zero/non-finite `aux` norm, empty or
/// out-of-range anchors, a failed factorization, or when no grid weight
/// yields a finite score.
pub fn partial_supervision_solve(
    t_sup: ArrayView2<f64>,
    aux: ArrayView2<f64>,
    t_free: ArrayView2<f64>,
    method: PartialSupervisionSupMethod,
    anchor_idx: &[usize],
    free_constraint: PartialSupervisionFreeConstraint,
) -> Result<PartialSupervisionResult, String> {
    let (n, d_sup) = t_sup.dim();
    if aux.dim() != (n, d_sup) {
        return Err(format!(
            "partial_supervision_solve: aux shape {:?} must equal t_sup shape ({}, {})",
            aux.dim(),
            n,
            d_sup
        ));
    }
    if t_free.nrows() != n {
        return Err(format!(
            "partial_supervision_solve: t_free has {} rows, expected {}",
            t_free.nrows(),
            n
        ));
    }
    let aux_norm_sq: f64 = aux.iter().map(|x| x * x).sum();
    if !(aux_norm_sq.is_finite() && aux_norm_sq > 0.0) {
        return Err(
            "partial_supervision_solve: aux has zero or non-finite Frobenius norm".to_string(),
        );
    }

    let mut t_sup_aligned = Array2::<f64>::zeros((n, d_sup));
    let mut map_r: Option<Array2<f64>> = None;
    let mut map_a: Option<Array2<f64>> = None;
    let mut map_b: Option<Array1<f64>> = None;
    let mut selected_weight: Option<f64> = None;

    match method {
        PartialSupervisionSupMethod::Procrustes => {
            let m = t_sup.t().dot(&aux);
            let (u_opt, _sigma, vt_opt) = m
                .svd(true, true)
                .map_err(|e| format!("partial_supervision_solve: Procrustes SVD failed: {e}"))?;
            let u = u_opt
                .ok_or_else(|| "partial_supervision_solve: SVD did not return U".to_string())?;
            let vt = vt_opt
                .ok_or_else(|| "partial_supervision_solve: SVD did not return Vᵀ".to_string())?;
            let r = u.dot(&vt);
            t_sup_aligned = t_sup.dot(&r);
            map_r = Some(r);
        }
        PartialSupervisionSupMethod::Anchor => {
            if anchor_idx.is_empty() {
                return Err(
                    "partial_supervision_solve: anchor method requires anchor_idx with at \
                     least one row"
                        .to_string(),
                );
            }
            for &idx in anchor_idx {
                if idx >= n {
                    return Err(format!(
                        "partial_supervision_solve: anchor index {idx} out of bounds (n={n})"
                    ));
                }
            }
            // Design matrix [t_sup | 1] restricted to the anchor rows.
            let m_rows = anchor_idx.len();
            let mut design = Array2::<f64>::zeros((m_rows, d_sup + 1));
            let mut targets = Array2::<f64>::zeros((m_rows, d_sup));
            for (row_out, &row_in) in anchor_idx.iter().enumerate() {
                for c in 0..d_sup {
                    design[[row_out, c]] = t_sup[[row_in, c]];
                    targets[[row_out, c]] = aux[[row_in, c]];
                }
                design[[row_out, d_sup]] = 1.0;
            }
            let (u_opt, sigma, vt_opt) = design
                .svd(true, true)
                .map_err(|e| format!("partial_supervision_solve: Anchor SVD failed: {e}"))?;
            let u = u_opt
                .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked U".to_string())?;
            let vt = vt_opt
                .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked Vᵀ".to_string())?;
            // Pseudo-inverse: singular values at or below the cutoff contribute nothing.
            let leading = sigma.iter().cloned().fold(0.0_f64, f64::max);
            let cutoff = leading * f64::EPSILON * (m_rows.max(d_sup + 1) as f64);
            let rank = sigma.len();
            let ut_targets = u.t().dot(&targets);
            let mut scaled = Array2::<f64>::zeros((rank, d_sup));
            for r in 0..rank {
                let s = sigma[r];
                if s > cutoff {
                    let inv = 1.0 / s;
                    for c in 0..d_sup {
                        scaled[[r, c]] = inv * ut_targets[[r, c]];
                    }
                }
            }
            let coef = vt.t().dot(&scaled);
            // First d_sup rows are the linear part, the last row is the offset.
            let a = coef.slice(s![..d_sup, ..]).to_owned();
            let b_vec = coef.slice(s![d_sup, ..]).to_owned();
            for row in 0..n {
                for c in 0..d_sup {
                    let mut acc = b_vec[c];
                    for k in 0..d_sup {
                        acc += t_sup[[row, k]] * a[[k, c]];
                    }
                    t_sup_aligned[[row, c]] = acc;
                }
            }
            map_a = Some(a);
            map_b = Some(b_vec);
        }
        PartialSupervisionSupMethod::SoftL2 => {
            let g = t_sup.t().dot(&t_sup);
            let (eigvals, eigvecs) = g
                .eigh(Side::Lower)
                .map_err(|e| format!("partial_supervision_solve: eigh on Gram failed: {e}"))?;
            let rhs = t_sup.t().dot(&aux);
            let ut_aux = eigvecs.t().dot(&rhs);
            // Squared row norms of the right-hand side in the eigenbasis.
            let m_row: Array1<f64> = Array1::from_vec(
                (0..d_sup)
                    .map(|r| (0..d_sup).map(|c| ut_aux[[r, c]] * ut_aux[[r, c]]).sum())
                    .collect(),
            );
            let lam_max = eigvals.iter().cloned().fold(0.0_f64, f64::max);
            let floor = (lam_max * 1.0e-10).max(1.0e-12);
            let top = (lam_max * 1.0e3).max(floor * 1.0e6);
            let grid_n: usize = 64;
            let log_floor = floor.ln();
            let log_top = top.ln();
            let mut best_score = f64::INFINITY;
            let mut best_lam = floor;
            for k in 0..grid_n {
                let frac = if grid_n == 1 {
                    0.0
                } else {
                    (k as f64) / ((grid_n - 1) as f64)
                };
                let lam = (log_floor + frac * (log_top - log_floor)).exp();
                let mut shrunk = 0.0_f64;
                let mut logdet = 0.0_f64;
                for r in 0..d_sup {
                    // Clamp round-off negatives: a Gram matrix is positive semi-definite.
                    let g = eigvals[r].max(0.0);
                    shrunk += m_row[r] / (g + lam);
                    logdet += (1.0 + g / lam).ln();
                }
                let s = aux_norm_sq - shrunk;
                if !(s.is_finite() && s > 0.0) {
                    continue;
                }
                let score = (n as f64) * s.ln() + logdet;
                if score < best_score {
                    best_score = score;
                    best_lam = lam;
                }
            }
            if !best_score.is_finite() {
                return Err(
                    "partial_supervision_solve: REML grid did not find a finite-score weight"
                        .to_string(),
                );
            }
            // Use the same clamped spectrum the grid search scored, so the
            // returned map matches the model that selected `best_lam`.
            let denom: Array1<f64> = eigvals.mapv(|v| v.max(0.0) + best_lam);
            let mut a_eig = Array2::<f64>::zeros((d_sup, d_sup));
            for r in 0..d_sup {
                for c in 0..d_sup {
                    a_eig[[r, c]] = ut_aux[[r, c]] / denom[r];
                }
            }
            let best_a = eigvecs.dot(&a_eig);
            t_sup_aligned = t_sup.dot(&best_a);
            map_a = Some(best_a);
            selected_weight = Some(best_lam);
        }
    }

    let mut sq_resid = 0.0_f64;
    for row in 0..n {
        for c in 0..d_sup {
            let r = t_sup_aligned[[row, c]] - aux[[row, c]];
            sq_resid += r * r;
        }
    }
    let alignment_score = 1.0 - sq_resid / aux_norm_sq;

    let t_free_out = match free_constraint {
        PartialSupervisionFreeConstraint::None => t_free.to_owned(),
        PartialSupervisionFreeConstraint::OrthogonalToSup => {
            if t_sup_aligned.ncols() == 0 || t_free.ncols() == 0 {
                t_free.to_owned()
            } else {
                let qr_pair = t_sup_aligned
                    .qr()
                    .map_err(|e| format!("partial_supervision_solve: QR on T_sup failed: {e}"))?;
                let q = qr_pair.0;
                // Remove Q·Qᵀ·t_free from the free block.
                let qt_free = q.t().dot(&t_free);
                let proj = q.dot(&qt_free);
                let mut out = t_free.to_owned();
                out -= &proj;
                out
            }
        }
    };

    Ok(PartialSupervisionResult {
        t_supervised: t_sup_aligned,
        t_free: t_free_out,
        alignment_score,
        selected_weight,
        map_r,
        map_a,
        map_b,
    })
}
/// Topology tag of a fitted atom's latent manifold.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AtomTopology {
    /// Circle; latent dimension 1.
    Circle,
    /// Sphere; latent dimension 2.
    Sphere,
    /// Torus with an explicit latent dimension.
    Torus { latent_dim: usize },
    /// Euclidean patch with an explicit latent dimension.
    EuclideanPatch { latent_dim: usize },
}

impl AtomTopology {
    /// Latent dimension implied by the topology.
    fn latent_dim(&self) -> usize {
        match self {
            Self::Circle => 1,
            Self::Sphere => 2,
            Self::Torus { latent_dim } | Self::EuclideanPatch { latent_dim } => *latent_dim,
        }
    }
}
/// One fitted atom: its topology tag, frame matrix and optional ARD variances.
#[derive(Debug, Clone)]
pub struct FittedAtom {
/// Label used in generator descriptions and error messages.
pub name: String,
/// Topology tag; the generator builders compare its latent dimension with
/// the number of columns of `frame`.
pub topology: AtomTopology,
/// Frame matrix; generators flatten it row-major as `row * ncols + col`.
pub frame: Array2<f64>,
/// Optional per-latent-axis variances; only consulted when the length equals
/// the latent dimension being processed.
pub ard_variances: Option<Array1<f64>>,
/// NOTE(review): not read anywhere in the visible part of this file —
/// presumably a fit-quality diagnostic; confirm against the producer.
pub lowering_error: f64,
}
/// Fitted model made of several atoms plus the data needed to build the
/// stacked curvature root.
pub struct FittedSaeManifold {
/// Atoms in parameter order; each contributes `frame.len()` parameters.
pub atoms: Vec<FittedAtom>,
/// One flattened Jacobian per data row, of length `p_out * param_dim`, indexed
/// as `output * param_dim + parameter` (see `stacked_curvature_root`).
pub jacobian_rows: Vec<Vec<f64>>,
/// Extra rows appended to the curvature root; must have `param_dim` columns
/// unless it has zero columns, in which case it is ignored.
pub isometry_penalty_root: Array2<f64>,
/// Row metric supplying `p_out()` and `whiten_residual_row(..)`.
pub metric: RowMetric,
}
impl FittedSaeManifold {
    /// Total number of parameters: the sum of all atoms' frame element counts.
    pub fn param_dim(&self) -> usize {
        self.atoms
            .iter()
            .fold(0, |total, atom| total + atom.frame.len())
    }

    /// Index of atom `k`'s first parameter in the flattened parameter vector.
    ///
    /// Panics when `k` exceeds the number of atoms.
    fn atom_offset(&self, k: usize) -> usize {
        let preceding = &self.atoms[..k];
        preceding
            .iter()
            .fold(0, |total, atom| total + atom.frame.len())
    }
}
/// Family a candidate gauge generator belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratorFamily {
    /// Isometry of a single atom's manifold.
    IsomAtom,
    /// Rotation between latent axes whose ARD variances are equal.
    EqualArdRotation,
    /// Rotation of the shared output frame.
    FrameRotation,
    /// Exchange of two atoms.
    AtomPermutation,
}

impl GeneratorFamily {
    /// Human-readable label used when composing group signatures.
    fn label(self) -> &'static str {
        let text: &'static str = match self {
            Self::IsomAtom => "Isom(M_k)",
            Self::EqualArdRotation => "equal-ARD rotation",
            Self::FrameRotation => "frame rotation O(output_dim)",
            Self::AtomPermutation => "Sym(F) atom permutation",
        };
        text
    }
}
/// Threshold on `pinned_energy_fraction`: `exact_orbit_verdicts` marks a
/// generator as unpinned when its fraction is at or below this value.
pub const GENERATOR_FLAT_ENERGY_TOL: f64 = 1.0e-3;
/// Verdict on whether one candidate gauge generator is pinned.
#[derive(Debug, Clone)]
pub struct GeneratorVerdict {
/// Family the generator belongs to.
pub family: GeneratorFamily,
/// Human-readable description of the generator.
pub description: String,
/// `true` when the generator is judged not pinned; only unpinned verdicts
/// are counted in the group signature.
pub unpinned: bool,
/// Norm of the generator's induced motion (0 when the motion vanishes).
pub generator_norm: f64,
/// Fraction of the motion's energy that is pinned; compared against
/// `GENERATOR_FLAT_ENERGY_TOL` in `exact_orbit_verdicts`.
pub pinned_energy_fraction: f64,
/// NOTE(review): always set to 0.0 by the code visible here; other
/// producers may populate it — confirm.
pub lowering_error_scale: f64,
}
/// Bookkeeping for the frame inner-rotation gauge: per-atom ranks and the
/// total dimension derived from them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameInnerRotationGauge {
/// One rank `r_k` per atom.
pub per_atom_ranks: Vec<usize>,
/// `Σ_k r_k·(r_k − 1)/2`, as computed by `frame_inner_rotation_dim`.
pub dim: usize,
}
impl FrameInnerRotationGauge {
    /// Builds the gauge record, deriving `dim` from the supplied ranks.
    pub fn from_ranks(per_atom_ranks: Vec<usize>) -> Self {
        Self {
            // Evaluated before `per_atom_ranks` is moved into the struct.
            dim: frame_inner_rotation_dim(&per_atom_ranks),
            per_atom_ranks,
        }
    }
}
/// Total dimension `Σ r·(r − 1)/2` over the given ranks.
///
/// A rank of 0 or 1 contributes nothing.
pub fn frame_inner_rotation_dim(ranks: &[usize]) -> usize {
    ranks
        .iter()
        .fold(0, |total, &r| total + r * r.saturating_sub(1) / 2)
}
/// Report describing which gauge generators remain unpinned.
///
/// NOTE(review): the code that builds this report is outside the visible part
/// of the file; field notes below state only what the visible methods use.
#[derive(Debug, Clone)]
pub struct ResidualGaugeReport {
/// Provenance of the row metric the report was computed under.
pub metric_provenance: MetricProvenance,
/// Per-generator verdicts; `group_signature` counts the unpinned ones by family.
pub generators: Vec<GeneratorVerdict>,
/// Rank of the pinning curvature — presumably the value held by
/// `CurvatureReduction`; confirm against the constructor.
pub pinning_rank: usize,
/// Dimension of the residual gauge — TODO confirm how it is derived.
pub residual_gauge_dim: usize,
/// When `true`, `group_signature` wraps the signature in a `Diff(M) ⊇ { … }` form.
pub diffeomorphism_unpinned: bool,
/// NOTE(review): not read in the visible code; meaning inferred from the name only.
pub sym_f_trivial_under_output_fisher: Option<bool>,
/// Set by `with_frame_inner_rotation`; appended to the signature when its `dim > 0`.
pub frame_inner_rotation: Option<FrameInnerRotationGauge>,
/// Free-text summary; `with_frame_inner_rotation` may append to it.
pub summary: String,
}
impl ResidualGaugeReport {
    /// Textual signature of the residual gauge group.
    ///
    /// Starts from `group_signature_of` and, when a frame inner-rotation gauge
    /// of positive dimension is recorded, appends a term describing it.
    pub fn group_signature(&self) -> String {
        let base = group_signature_of(&self.generators, self.diffeomorphism_unpinned);
        let Some(gauge) = self.frame_inner_rotation.as_ref().filter(|g| g.dim > 0) else {
            return base;
        };
        format!(
            "{base} ⊕ frame-inner ∏O(r_k)×{} [dim {}, canonical-fixed]",
            gauge.per_atom_ranks.len(),
            gauge.dim
        )
    }

    /// Records the frame inner-rotation gauge for the given per-atom ranks.
    ///
    /// The gauge is always stored; a note is appended to `summary` only when
    /// its dimension is positive.
    pub fn with_frame_inner_rotation(mut self, ranks: Vec<usize>) -> Self {
        let gauge = FrameInnerRotationGauge::from_ranks(ranks);
        if gauge.dim > 0 {
            let note = format!(
                "; frame inner-rotation gauge ∏O(r_k) of dim {} enumerated \
                 (exact reparameterization, fixed by the canonical orientation gauge)",
                gauge.dim
            );
            self.summary.push_str(&note);
        }
        self.frame_inner_rotation = Some(gauge);
        self
    }
}
/// Formats the unpinned generators as a group signature string.
///
/// Unpinned verdicts are counted per family label and rendered as
/// `label×count` terms joined by ` ⊕ `, in label order. With no unpinned
/// generator the body is the "fully pinned" marker. When
/// `diffeomorphism_unpinned` is set the body is wrapped in a `Diff(M) ⊇ { … }` form.
fn group_signature_of(generators: &[GeneratorVerdict], diffeomorphism_unpinned: bool) -> String {
    use std::collections::BTreeMap;

    let mut tally: BTreeMap<&'static str, usize> = BTreeMap::new();
    for verdict in generators.iter().filter(|v| v.unpinned) {
        *tally.entry(verdict.family.label()).or_default() += 1;
    }

    let body = if tally.is_empty() {
        String::from("{e} [fully pinned: rigid up to nothing]")
    } else {
        let terms: Vec<String> = tally
            .iter()
            .map(|(name, mult)| format!("{name}×{mult}"))
            .collect();
        terms.join(" ⊕ ")
    };

    if !diffeomorphism_unpinned {
        return body;
    }
    format!("Diff(M) ⊇ {{ {body} }} [diffeomorphism-unpinned: isometry pin inactive]")
}
/// Candidate isometry generators of one atom, expressed on its flattened frame.
///
/// Returns nothing when the frame's column count differs from the topology's
/// latent dimension. Otherwise: a circle yields one generator copying frame
/// column 0; every other topology yields one rotation generator per axis pair
/// `(a, b)` with `a < b`; a torus additionally yields one generator per axis
/// copying that frame column. Each generator has length `rows * cols`,
/// indexed row-major.
fn atom_isometry_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
    let (p, d) = atom.frame.dim();
    if d != atom.topology.latent_dim() {
        return Vec::new();
    }
    let frame = &atom.frame;

    // Generator that copies one frame column into the matching slots.
    let axis_copy = |axis: usize| -> Array1<f64> {
        let mut g = Array1::<f64>::zeros(p * d);
        for i in 0..p {
            g[i * d + axis] = frame[[i, axis]];
        }
        g
    };
    // Infinitesimal rotation in the (a, b) latent plane.
    let plane_rotation = |a: usize, b: usize| -> Array1<f64> {
        let mut g = Array1::<f64>::zeros(p * d);
        for i in 0..p {
            g[i * d + a] = -frame[[i, b]];
            g[i * d + b] = frame[[i, a]];
        }
        g
    };

    let mut generators: Vec<(Array1<f64>, String)> = Vec::new();
    if matches!(atom.topology, AtomTopology::Circle) {
        if d >= 1 {
            generators.push((axis_copy(0), format!("{}: S¹ U(1) phase shift", atom.name)));
        }
        return generators;
    }

    let rotation_kind = match &atom.topology {
        AtomTopology::Sphere => "S² so(3)",
        AtomTopology::Torus { .. } => "Tᵈ frame",
        AtomTopology::EuclideanPatch { .. } | AtomTopology::Circle => "patch so(d)",
    };
    for a in 0..d {
        for b in (a + 1)..d {
            generators.push((
                plane_rotation(a, b),
                format!("{}: {} rotation axes ({a},{b})", atom.name, rotation_kind),
            ));
        }
    }
    if matches!(atom.topology, AtomTopology::Torus { .. }) {
        for a in 0..d {
            generators.push((
                axis_copy(a),
                format!("{}: Tᵈ circle shift axis {a}", atom.name),
            ));
        }
    }
    generators
}
/// Rotation generators between latent axes whose ARD variances coincide.
///
/// Returns nothing when the atom has no ARD variances or their count differs
/// from the frame's column count. Two variances count as equal when their
/// difference is within `1e-9` of the larger magnitude.
fn equal_ard_rotation_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
    const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;

    let (p, d) = atom.frame.dim();
    let ard = match atom.ard_variances.as_ref() {
        Some(values) if values.len() == d => values,
        _ => return Vec::new(),
    };

    let mut generators: Vec<(Array1<f64>, String)> = Vec::new();
    for a in 0..d {
        for b in (a + 1)..d {
            let (va, vb) = (ard[a], ard[b]);
            let scale = va.abs().max(vb.abs()).max(f64::MIN_POSITIVE);
            if (va - vb).abs() > ARD_EQUAL_REL_TOL * scale {
                continue;
            }
            let mut g = Array1::<f64>::zeros(p * d);
            for i in 0..p {
                g[i * d + a] = -atom.frame[[i, b]];
                g[i * d + b] = atom.frame[[i, a]];
            }
            generators.push((
                g,
                format!("{}: equal-ARD rotation axes ({a},{b})", atom.name),
            ));
        }
    }
    generators
}
/// Rotation generators of the shared output frame, one per pair of output
/// axes `(oi, oj)` with `oi < oj`.
///
/// Each generator spans the full parameter vector. An atom whose frame has
/// too few rows to contain both axes contributes zeros to that generator.
/// The number of output axes is the largest frame row count among the atoms.
fn frame_rotation_generators(model: &FittedSaeManifold) -> Vec<(Array1<f64>, String)> {
    let p = model
        .atoms
        .iter()
        .map(|a| a.frame.nrows())
        .max()
        .unwrap_or(0);
    let param_dim = model.param_dim();

    // Starting offset of each atom's block, computed once instead of being
    // re-summed for every (axis pair, atom) combination.
    let offsets: Vec<usize> = model
        .atoms
        .iter()
        .scan(0_usize, |next, atom| {
            let start = *next;
            *next += atom.frame.len();
            Some(start)
        })
        .collect();

    let mut out: Vec<(Array1<f64>, String)> = Vec::with_capacity(p * p.saturating_sub(1) / 2);
    for oi in 0..p {
        for oj in (oi + 1)..p {
            let mut g = Array1::<f64>::zeros(param_dim);
            for (atom, &base) in model.atoms.iter().zip(offsets.iter()) {
                let (ap, ad) = atom.frame.dim();
                if oi >= ap || oj >= ap {
                    continue;
                }
                for c in 0..ad {
                    g[base + oi * ad + c] = -atom.frame[[oj, c]];
                    g[base + oj * ad + c] = atom.frame[[oi, c]];
                }
            }
            out.push((g, format!("output-frame rotation axes ({oi},{oj})")));
        }
    }
    out
}
/// Places `local` at `offset` inside an otherwise-zero vector of length `param_dim`.
///
/// Panics when `offset + local.len()` exceeds `param_dim`.
fn embed_local_generator(offset: usize, local: &Array1<f64>, param_dim: usize) -> Array1<f64> {
    let end = offset + local.len();
    let mut embedded = Array1::<f64>::zeros(param_dim);
    embedded.slice_mut(s![offset..end]).assign(local);
    embedded
}
/// Exchange generators for every pair of atoms with the same topology and
/// frame shape.
///
/// For a compatible pair `(ka, kb)` with `ka < kb`, the generator holds
/// `frame_b − frame_a` in atom `ka`'s block and its negation in atom `kb`'s
/// block. Each entry is `(generator, description, ka, kb)`.
fn atom_permutation_generators(
    model: &FittedSaeManifold,
) -> Vec<(Array1<f64>, String, usize, usize)> {
    let mut out: Vec<(Array1<f64>, String, usize, usize)> = Vec::new();
    let param_dim = model.param_dim();

    // Starting offset of each atom's block, computed once instead of being
    // re-summed for every compatible pair.
    let offsets: Vec<usize> = model
        .atoms
        .iter()
        .scan(0_usize, |next, atom| {
            let start = *next;
            *next += atom.frame.len();
            Some(start)
        })
        .collect();

    for (ka, a) in model.atoms.iter().enumerate() {
        for (kb, b) in model.atoms.iter().enumerate().skip(ka + 1) {
            if a.topology != b.topology || a.frame.dim() != b.frame.dim() {
                continue;
            }
            let (ap, ad) = a.frame.dim();
            let base_a = offsets[ka];
            let base_b = offsets[kb];
            let mut g = Array1::<f64>::zeros(param_dim);
            for i in 0..ap {
                for c in 0..ad {
                    let diff = b.frame[[i, c]] - a.frame[[i, c]];
                    g[base_a + i * ad + c] = diff;
                    g[base_b + i * ad + c] = -diff;
                }
            }
            out.push((g, format!("atom-exchange {} ↔ {}", a.name, b.name), ka, kb));
        }
    }
    out
}
/// Per-atom arrays consumed by the exact-orbit checks.
///
/// Shapes below use `n` rows, `m` basis functions, `d` latent coordinates and
/// `p` decoder columns, as validated by `exact_orbit_verdicts`.
#[derive(Debug, Clone)]
pub struct AtomParameterView {
/// Basis values, shape `(n, m)`.
pub basis_values: Array2<f64>,
/// Basis first derivatives w.r.t. the latent coordinates, shape `(n, m, d)`.
pub basis_jacobian: Array3<f64>,
/// Decoder matrix, shape `(m, p)`.
pub decoder: Array2<f64>,
/// Latent coordinates, shape `(n, d)`.
pub coords: Array2<f64>,
/// Per-row activation multipliers, length `n`.
pub activations: Array1<f64>,
/// Optional basis second derivatives, shape `(n, m, d, d)`; required by
/// `isometry_orbit_penalty_operator`.
pub basis_second_jet: Option<Array4<f64>>,
}
/// Linearized penalty operator used to test whether an orbit direction is
/// pinned by the isometry penalty.
pub struct OrbitPenaltyOperator {
/// Maps a perturbation pair `(delta_b, delta_t)` to a flat image vector whose
/// squared norm is the penalty cost of that perturbation.
#[allow(clippy::type_complexity)]
pub apply: Box<dyn Fn(ArrayView2<f64>, ArrayView2<f64>) -> Array1<f64> + Send + Sync>,
/// Normalizer for that cost; `exact_orbit_verdicts` divides by it and clamps to `[0, 1]`.
pub stiffness_sq: f64,
}
/// Builds the linearized isometry-penalty operator for one atom.
///
/// With `J = ∂Φ·B` the decoded Jacobian (`Φ` the basis, `B` the decoder), the
/// returned `apply(delta_b, delta_t)` computes the first-order change of the
/// pulled-back metric `JᵀJ` at every row, scaled by `sqrt(weight)` and
/// flattened to length `n·d·d`. `stiffness_sq` is `weight` times the largest
/// squared Frobenius norm, over rows and latent axes, of the metric change
/// caused by a unit coordinate shift (floored at `f64::MIN_POSITIVE`).
///
/// Returns `None` when the second jet is absent, when any array shape is
/// inconsistent — including a decoder whose row count differs from the number
/// of basis functions — or when `weight` is not finite and positive.
///
/// Inside `apply`, a `delta_t` of the wrong shape yields an all-zero image,
/// and a `delta_b` of the wrong shape is treated as zero.
pub fn isometry_orbit_penalty_operator(
    view: &AtomParameterView,
    weight: f64,
) -> Option<OrbitPenaltyOperator> {
    let second_ref = view.basis_second_jet.as_ref()?;
    let (n, m) = view.basis_values.dim();
    let d = view.coords.ncols();
    let p = view.decoder.ncols();
    if second_ref.dim() != (n, m, d, d) || view.basis_jacobian.dim() != (n, m, d) {
        return None;
    }
    // The decoder is indexed as `decoder[[basis, output]]` for every basis
    // function below, so its row count must equal `m`.
    if view.decoder.nrows() != m {
        return None;
    }
    if !(weight.is_finite() && weight > 0.0) {
        return None;
    }

    // Clone only after validation so rejected inputs cost no allocation.
    let second = second_ref.clone();
    let sqrt_w = weight.sqrt();
    let jac = view.basis_jacobian.clone();
    let decoder = view.decoder.clone();

    // Decoded Jacobian J[row, output, coord] = Σ_basis ∂Φ[row, basis, coord]·B[basis, output].
    let mut j_base = Array3::<f64>::zeros((n, p, d));
    for row in 0..n {
        for i in 0..p {
            for c in 0..d {
                let mut acc = 0.0;
                for mm in 0..m {
                    acc += jac[[row, mm, c]] * decoder[[mm, i]];
                }
                j_base[[row, i, c]] = acc;
            }
        }
    }

    let mut max_curv_sq = 0.0_f64;
    for row in 0..n {
        // Decoded second derivative, flattened as (output, coord, coord).
        let mut hn = vec![0.0_f64; p * d * d];
        for i in 0..p {
            for c in 0..d {
                for e in 0..d {
                    let mut acc = 0.0;
                    for mm in 0..m {
                        acc += second[[row, mm, c, e]] * decoder[[mm, i]];
                    }
                    hn[(i * d + c) * d + e] = acc;
                }
            }
        }
        // Metric change for a unit shift along coordinate `e`.
        for e in 0..d {
            let mut frob = 0.0_f64;
            for a in 0..d {
                for b in 0..d {
                    let mut g = 0.0;
                    for i in 0..p {
                        g += hn[(i * d + a) * d + e] * j_base[[row, i, b]];
                        g += j_base[[row, i, a]] * hn[(i * d + b) * d + e];
                    }
                    frob += g * g;
                }
            }
            max_curv_sq = max_curv_sq.max(frob);
        }
    }
    let stiffness_sq = (weight * max_curv_sq).max(f64::MIN_POSITIVE);

    let apply = move |delta_b: ArrayView2<f64>, delta_t: ArrayView2<f64>| -> Array1<f64> {
        let mut image = Array1::<f64>::zeros(n * d * d);
        let valid_b = delta_b.dim() == (m, p);
        let valid_t = delta_t.dim() == (n, d);
        if !valid_t {
            return image;
        }
        for row in 0..n {
            // First-order change of the decoded Jacobian at this row.
            let mut dj = vec![0.0_f64; p * d];
            for i in 0..p {
                for c in 0..d {
                    let mut acc = 0.0;
                    if valid_b {
                        for mm in 0..m {
                            acc += jac[[row, mm, c]] * delta_b[[mm, i]];
                        }
                    }
                    for e in 0..d {
                        let dte = delta_t[[row, e]];
                        if dte == 0.0 {
                            continue;
                        }
                        for mm in 0..m {
                            acc += second[[row, mm, c, e]] * dte * decoder[[mm, i]];
                        }
                    }
                    dj[i * d + c] = acc;
                }
            }
            // δ(JᵀJ) = δJᵀ·J + Jᵀ·δJ.
            for a in 0..d {
                for b in 0..d {
                    let mut dg = 0.0;
                    for i in 0..p {
                        dg += dj[i * d + a] * j_base[[row, i, b]];
                        dg += j_base[[row, i, a]] * dj[i * d + b];
                    }
                    image[(row * d + a) * d + b] = sqrt_w * dg;
                }
            }
        }
        image
    };

    Some(OrbitPenaltyOperator {
        apply: Box::new(apply),
        stiffness_sq,
    })
}
/// Coordinate-space vector fields of the exact orbits to test for one atom.
///
/// Each entry is `(family, field, description)` where `field` gives the
/// coordinate perturbation per row:
/// * circle — a single all-ones column;
/// * torus — one constant shift per axis;
/// * Euclidean patch — one rotation field per axis pair;
/// * sphere — nothing.
///
/// For atoms that are neither circle nor sphere, one rotation field is added
/// per axis pair whose ARD variances agree to within `1e-9` relative
/// tolerance, provided the ARD vector length equals the coordinate dimension.
fn exact_orbit_fields(
    atom: &FittedAtom,
    view: &AtomParameterView,
) -> Vec<(GeneratorFamily, Array2<f64>, String)> {
    const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;

    let coords = &view.coords;
    let (n, d) = (coords.nrows(), coords.ncols());

    // Infinitesimal rotation of the coordinates in the (a, b) plane.
    let rotation_field = |a: usize, b: usize| -> Array2<f64> {
        let mut field = Array2::<f64>::zeros((n, d));
        for row in 0..n {
            field[[row, a]] = -coords[[row, b]];
            field[[row, b]] = coords[[row, a]];
        }
        field
    };

    let mut fields: Vec<(GeneratorFamily, Array2<f64>, String)> = Vec::new();
    match &atom.topology {
        AtomTopology::Circle => fields.push((
            GeneratorFamily::IsomAtom,
            Array2::<f64>::ones((n, 1)),
            format!("{}: S¹ U(1) phase shift [exact orbit]", atom.name),
        )),
        AtomTopology::Torus { .. } => {
            for ax in 0..d {
                let mut shift = Array2::<f64>::zeros((n, d));
                shift.column_mut(ax).fill(1.0);
                fields.push((
                    GeneratorFamily::IsomAtom,
                    shift,
                    format!("{}: Tᵈ circle shift axis {ax} [exact orbit]", atom.name),
                ));
            }
        }
        AtomTopology::EuclideanPatch { .. } => {
            for a in 0..d {
                for b in (a + 1)..d {
                    fields.push((
                        GeneratorFamily::IsomAtom,
                        rotation_field(a, b),
                        format!(
                            "{}: patch so(d) rotation axes ({a},{b}) [exact orbit]",
                            atom.name
                        ),
                    ));
                }
            }
        }
        AtomTopology::Sphere => {}
    }

    let ard_applicable = !matches!(atom.topology, AtomTopology::Circle | AtomTopology::Sphere);
    let usable_ard = atom
        .ard_variances
        .as_ref()
        .filter(|ard| ard_applicable && ard.len() == d);
    if let Some(ard) = usable_ard {
        for a in 0..d {
            for b in (a + 1)..d {
                let scale = ard[a].abs().max(ard[b].abs()).max(f64::MIN_POSITIVE);
                if (ard[a] - ard[b]).abs() > ARD_EQUAL_REL_TOL * scale {
                    continue;
                }
                fields.push((
                    GeneratorFamily::EqualArdRotation,
                    rotation_field(a, b),
                    format!(
                        "{}: equal-ARD rotation axes ({a},{b}) [exact orbit]",
                        atom.name
                    ),
                ));
            }
        }
    }
    fields
}
/// Decides, for each exact orbit field of an atom, whether it is pinned.
///
/// For every field from `exact_orbit_fields` the function:
/// 1. computes the output motion the coordinate perturbation induces
///    (activation × step × basis derivative × decoder);
/// 2. projects that motion onto the column space of the activation-weighted
///    basis matrix (via its SVD) — the part a decoder change can absorb — and
///    takes the unabsorbed energy fraction as the data-pinned fraction;
/// 3. if a penalty operator with usable stiffness is supplied, evaluates the
///    penalty cost of the compensating decoder change plus the coordinate
///    field, normalized by the stiffness;
/// 4. reports the larger of the two fractions, and marks the generator
///    unpinned when it is at most `GENERATOR_FLAT_ENERGY_TOL`.
///
/// A field whose induced motion is numerically zero is reported as pinned
/// with zero norm.
///
/// # Errors
/// Returns a message on inconsistent array shapes or a failed SVD.
fn exact_orbit_verdicts(
    atom: &FittedAtom,
    view: &AtomParameterView,
    penalty: Option<&OrbitPenaltyOperator>,
) -> Result<Vec<GeneratorVerdict>, String> {
    let (n, m) = view.basis_values.dim();
    let d = view.coords.ncols();
    let p = view.decoder.ncols();
    if view.basis_jacobian.dim() != (n, m, d) {
        return Err(format!(
            "exact_orbit_verdicts({}): basis_jacobian shape {:?} must be ({n}, {m}, {d})",
            atom.name,
            view.basis_jacobian.dim()
        ));
    }
    if view.decoder.nrows() != m {
        return Err(format!(
            "exact_orbit_verdicts({}): decoder has {} rows but basis has {m} columns",
            atom.name,
            view.decoder.nrows()
        ));
    }
    if view.coords.nrows() != n || view.activations.len() != n {
        return Err(format!(
            "exact_orbit_verdicts({}): coords/activations rows must match basis rows {n}",
            atom.name
        ));
    }

    let fields = exact_orbit_fields(atom, view);
    if fields.is_empty() {
        return Ok(Vec::new());
    }

    // Activation-weighted basis matrix D[row, basis] = activation[row]·Φ[row, basis].
    let design = Array2::<f64>::from_shape_fn((n, m), |(row, c)| {
        view.activations[row] * view.basis_values[[row, c]]
    });
    let (u_opt, sigma, vt_opt) = design
        .svd(true, true)
        .map_err(|e| format!("exact_orbit_verdicts({}): SVD of D failed: {e}", atom.name))?;
    let u_svd =
        u_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked U", atom.name))?;
    let vt = vt_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked Vᵀ", atom.name))?;
    let smax = sigma.iter().cloned().fold(0.0_f64, f64::max);
    let cutoff = smax * f64::EPSILON * (n.max(m) as f64);

    let mut verdicts: Vec<GeneratorVerdict> = Vec::with_capacity(fields.len());
    for (family, dt, description) in fields {
        // Output motion induced by the coordinate field.
        let mut motion = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let activation = view.activations[row];
            if activation == 0.0 {
                continue;
            }
            for ax in 0..d {
                let step = dt[[row, ax]];
                if step == 0.0 {
                    continue;
                }
                for bm in 0..m {
                    let dphi = view.basis_jacobian[[row, bm, ax]];
                    if dphi == 0.0 {
                        continue;
                    }
                    let w = activation * step * dphi;
                    for j in 0..p {
                        motion[[row, j]] += w * view.decoder[[bm, j]];
                    }
                }
            }
        }

        let raw: f64 = motion.iter().map(|v| v * v).sum();
        if raw <= f64::MIN_POSITIVE {
            verdicts.push(GeneratorVerdict {
                family,
                description,
                unpinned: false,
                generator_norm: 0.0,
                pinned_energy_fraction: 1.0,
                lowering_error_scale: 0.0,
            });
            continue;
        }

        // Split the motion into the part reachable by a decoder change (kept)
        // and the remainder; `scaled` holds the compensating change in the
        // singular basis.
        let coeffs = u_svd.t().dot(&motion);
        let mut kept_sq = 0.0_f64;
        let mut scaled = Array2::<f64>::zeros((sigma.len(), p));
        for r in 0..sigma.len() {
            if sigma[r] <= cutoff {
                continue;
            }
            let inv = 1.0 / sigma[r];
            for j in 0..p {
                kept_sq += coeffs[[r, j]] * coeffs[[r, j]];
                scaled[[r, j]] = -inv * coeffs[[r, j]];
            }
        }
        let resid_sq = (raw - kept_sq).max(0.0);
        let data_fraction = (resid_sq / raw).clamp(0.0, 1.0);

        let penalty_fraction = penalty
            .filter(|op| op.stiffness_sq > f64::MIN_POSITIVE)
            .map_or(0.0, |op| {
                let delta_b = vt.t().dot(&scaled);
                let image = (op.apply)(delta_b.view(), dt.view());
                let cost: f64 = image.iter().map(|v| v * v).sum();
                (cost / op.stiffness_sq).clamp(0.0, 1.0)
            });

        let pinned_energy_fraction = data_fraction.max(penalty_fraction);
        verdicts.push(GeneratorVerdict {
            family,
            description,
            unpinned: pinned_energy_fraction <= GENERATOR_FLAT_ENERGY_TOL,
            generator_norm: raw.sqrt(),
            pinned_energy_fraction,
            lowering_error_scale: 0.0,
        });
    }
    Ok(verdicts)
}
/// Assembles the stacked curvature root with `param_dim` columns.
///
/// For every data row, each parameter column of the flattened Jacobian is
/// whitened through the model's row metric and the whitened columns are
/// transposed into rows. The rows of `isometry_penalty_root` are appended
/// afterwards unless that matrix has zero columns. Returns a `(0, 0)` matrix
/// when the model has no parameters and `(0, param_dim)` when no row was
/// produced.
///
/// # Errors
/// Returns a message when a Jacobian row's length differs from
/// `p_out * param_dim`, or when the penalty root's column count is neither
/// zero nor `param_dim`.
fn stacked_curvature_root(model: &FittedSaeManifold) -> Result<Array2<f64>, String> {
    let param_dim = model.param_dim();
    if param_dim == 0 {
        return Ok(Array2::<f64>::zeros((0, 0)));
    }
    let p = model.metric.p_out();
    let expected_len = p * param_dim;

    let mut stacked_rows: Vec<Array1<f64>> = Vec::new();
    for (n, j_flat) in model.jacobian_rows.iter().enumerate() {
        if j_flat.len() != expected_len {
            return Err(format!(
                "stacked_curvature_root: jacobian_rows[{n}] has len {} but expected p*param_dim = {}*{} = {}",
                j_flat.len(),
                p,
                param_dim,
                p * param_dim
            ));
        }
        // Whiten each parameter column of this row's Jacobian.
        let whitened: Vec<Vec<f64>> = (0..param_dim)
            .map(|c| {
                let col: Vec<f64> = (0..p).map(|i| j_flat[i * param_dim + c]).collect();
                model.metric.whiten_residual_row(n, ArrayView1::from(&col))
            })
            .collect();
        // Transpose: one output row per whitened component.
        let whit_len = whitened.first().map_or(0, |c| c.len());
        for r in 0..whit_len {
            let mut row = Array1::<f64>::zeros(param_dim);
            for (slot, col) in row.iter_mut().zip(whitened.iter()) {
                *slot = col[r];
            }
            stacked_rows.push(row);
        }
    }

    let penalty_root = &model.isometry_penalty_root;
    if penalty_root.ncols() != 0 {
        if penalty_root.ncols() != param_dim {
            return Err(format!(
                "stacked_curvature_root: isometry_penalty_root has {} cols but param_dim = {param_dim}",
                model.isometry_penalty_root.ncols()
            ));
        }
        stacked_rows.extend(penalty_root.outer_iter().map(|row| row.to_owned()));
    }

    if stacked_rows.is_empty() {
        return Ok(Array2::<f64>::zeros((0, param_dim)));
    }
    let mut r_mat = Array2::<f64>::zeros((stacked_rows.len(), param_dim));
    for (mut target, source) in r_mat.outer_iter_mut().zip(stacked_rows.iter()) {
        target.assign(source);
    }
    Ok(r_mat)
}
/// Reduced form of the pinning curvature, held either as a root matrix `R`
/// or as a Gram matrix, together with its rank and largest squared singular
/// value.
enum CurvatureReduction {
/// Curvature held as the stacked root matrix.
Root {
/// Rank reported by the rank-revealing QR of the transposed root
/// (0 when the root has no rows).
pinning_rank: usize,
/// Square of the root's largest singular value.
sigma_max_sq: f64,
/// The stacked curvature root itself.
root: Array2<f64>,
},
/// Curvature held as a `param_dim × param_dim` Gram matrix.
Gram {
/// Number of eigenvalues above the squared rank tolerance.
pinning_rank: usize,
/// Largest eigenvalue of the Gram matrix, floored at 0.
sigma_max_sq: f64,
/// The Gram matrix itself.
gram: Array2<f64>,
},
}
impl CurvatureReduction {
    /// Builds the root-backed reduction directly from a fitted model.
    ///
    /// The pinning rank comes from a rank-revealing QR of the transposed root and
    /// `sigma_max_sq` from the largest singular value of the root. A root without
    /// rows short-circuits to rank 0 and `sigma_max_sq == 0.0`.
    ///
    /// # Errors
    /// Propagates failures from `stacked_curvature_root`, the RRQR, or the SVD.
    fn from_model(model: &FittedSaeManifold) -> Result<Self, String> {
        let root = stacked_curvature_root(model)?;
        if root.nrows() == 0 {
            return Ok(Self::Root {
                pinning_rank: 0,
                sigma_max_sq: 0.0,
                root,
            });
        }
        // The RRQR is run on Rᵀ, so an owned transpose is materialised first.
        let transposed = root.t().to_owned();
        let pinning_rank = match rrqr_with_permutation(&transposed, default_rrqr_rank_alpha()) {
            Ok(factorization) => factorization.rank,
            Err(e) => return Err(format!("residual_gauge: RRQR on Rᵀ failed: {e:?}")),
        };
        // Singular values only; neither factor matrix is requested.
        let singular_values = match root.svd(false, false) {
            Ok((_, sv, _)) => sv,
            Err(e) => {
                return Err(format!("residual_gauge: SVD of curvature root failed: {e}"));
            }
        };
        let smax = singular_values.iter().fold(0.0_f64, |acc, &s| acc.max(s));
        Ok(Self::Root {
            pinning_rank,
            sigma_max_sq: smax * smax,
            root,
        })
    }

    /// Builds the Gram-backed reduction from a precomputed curvature Gram matrix.
    ///
    /// `root_rows` is the row count of the root the Gram matrix was derived from; it
    /// only enters the rank tolerance and the empty-input shortcut.
    ///
    /// NOTE(review): the eigenvalue cut-off is the *square* of a tolerance that
    /// already contains `f64::EPSILON`; confirm it sits above the round-off floor of
    /// the eigensolver for rank-deficient Gram matrices.
    ///
    /// # Errors
    /// Fails when `gram` is not `param_dim × param_dim` or the symmetric
    /// eigendecomposition fails.
    fn from_gram(gram: Array2<f64>, root_rows: usize, param_dim: usize) -> Result<Self, String> {
        let (gram_rows, gram_cols) = gram.dim();
        if !(gram_rows == param_dim && gram_cols == param_dim) {
            return Err(format!(
                "residual_gauge: curvature gram has shape ({}, {}) but param_dim = {param_dim}",
                gram_rows, gram_cols
            ));
        }
        if param_dim == 0 || root_rows == 0 {
            return Ok(Self::Gram {
                pinning_rank: 0,
                sigma_max_sq: 0.0,
                gram,
            });
        }
        let evals = match gram.eigh(Side::Lower) {
            Ok((values, _)) => values,
            Err(e) => {
                return Err(format!(
                    "residual_gauge: eigendecomposition of curvature gram failed: {e}"
                ));
            }
        };
        let sigma_max_sq = evals
            .iter()
            .fold(0.0_f64, |acc, &lambda| acc.max(lambda))
            .max(0.0);
        let sigma_max = sigma_max_sq.sqrt();
        // Singular-value tolerance, then squared to compare against eigenvalues.
        let alpha = default_rrqr_rank_alpha();
        let dim_scale = root_rows.max(param_dim).max(1) as f64;
        let magnitude = sigma_max.max(1.0);
        let rank_tol = alpha * f64::EPSILON * dim_scale * magnitude;
        let lambda_tol = rank_tol * rank_tol;
        let mut pinning_rank = 0_usize;
        for &lambda in evals.iter() {
            if lambda.max(0.0) > lambda_tol {
                pinning_rank += 1;
            }
        }
        Ok(Self::Gram {
            pinning_rank,
            sigma_max_sq,
            gram,
        })
    }

    /// Cached numerical rank of the curvature.
    fn pinning_rank(&self) -> usize {
        match self {
            Self::Root { pinning_rank, .. } => *pinning_rank,
            Self::Gram { pinning_rank, .. } => *pinning_rank,
        }
    }

    /// Cached squared largest singular value (root) or largest eigenvalue (Gram).
    fn sigma_max_sq(&self) -> f64 {
        match self {
            Self::Root { sigma_max_sq, .. } => *sigma_max_sq,
            Self::Gram { sigma_max_sq, .. } => *sigma_max_sq,
        }
    }

    /// Curvature energy of the direction `unit`: `‖root·unit‖²` for the root variant,
    /// `unit·(gram·unit)` floored at zero for the Gram variant.
    fn unit_generator_energy(&self, unit: &Array1<f64>) -> f64 {
        match self {
            Self::Root { root, .. } => root
                .dot(unit)
                .iter()
                .map(|component| component * component)
                .sum::<f64>(),
            Self::Gram { gram, .. } => {
                let quadratic_form = unit.dot(&gram.dot(unit));
                quadratic_form.max(0.0)
            }
        }
    }
}
/// Computes the residual-gauge report for `model` with no exact per-atom verdicts
/// and with the curvature derived from the model itself.
///
/// # Errors
/// Propagates any error from `residual_gauge_inner`.
pub fn residual_gauge(model: &FittedSaeManifold) -> Result<ResidualGaugeReport, String> {
    let exact_inputs = None;
    let precomputed_curvature = None;
    residual_gauge_inner(model, exact_inputs, precomputed_curvature)
}
/// Residual-gauge report in which atoms that come with a parameter view are judged
/// by their exact orbit verdicts; the curvature is still derived from `model`.
///
/// `views` and `penalty_ops` must each hold one entry per atom of the model.
///
/// # Errors
/// Propagates errors from `residual_gauge_exact_inputs` and `residual_gauge_inner`.
pub fn residual_gauge_exact(
    model: &FittedSaeManifold,
    views: &[Option<AtomParameterView>],
    penalty_ops: &[Option<OrbitPenaltyOperator>],
) -> Result<ResidualGaugeReport, String> {
    residual_gauge_exact_inputs(model, views, penalty_ops)
        .and_then(|exact| residual_gauge_inner(model, Some(exact), None))
}
/// Same as `residual_gauge_exact`, but the curvature is taken from a caller-supplied
/// Gram matrix instead of being rebuilt from the model.
///
/// `root_rows` is forwarded to `CurvatureReduction::from_gram` together with the
/// model's parameter dimension. The Gram matrix is validated before the exact
/// inputs are assembled, so a malformed Gram matrix is reported first.
///
/// # Errors
/// Propagates errors from the Gram reduction, `residual_gauge_exact_inputs` and
/// `residual_gauge_inner`.
pub fn residual_gauge_exact_from_curvature_gram(
    model: &FittedSaeManifold,
    views: &[Option<AtomParameterView>],
    penalty_ops: &[Option<OrbitPenaltyOperator>],
    curvature_gram: Array2<f64>,
    root_rows: usize,
) -> Result<ResidualGaugeReport, String> {
    let expected_dim = model.param_dim();
    let reduction = match CurvatureReduction::from_gram(curvature_gram, root_rows, expected_dim) {
        Ok(reduction) => reduction,
        Err(message) => return Err(message),
    };
    let exact_inputs = residual_gauge_exact_inputs(model, views, penalty_ops)?;
    residual_gauge_inner(model, Some(exact_inputs), Some(reduction))
}
/// Collects the exact orbit verdicts for every atom that has a parameter view.
///
/// Returns a per-atom mask (`true` where exact verdicts were produced) together
/// with the concatenated verdicts. Atoms without a view, and atoms whose topology
/// is `AtomTopology::Sphere`, are left unmasked and contribute no verdicts.
///
/// # Errors
/// Fails when `views` or `penalty_ops` does not have one entry per atom, or when
/// `exact_orbit_verdicts` fails for an atom.
fn residual_gauge_exact_inputs(
    model: &FittedSaeManifold,
    views: &[Option<AtomParameterView>],
    penalty_ops: &[Option<OrbitPenaltyOperator>],
) -> Result<(Vec<bool>, Vec<GeneratorVerdict>), String> {
    let atom_count = model.atoms.len();
    if !(views.len() == atom_count && penalty_ops.len() == atom_count) {
        return Err(format!(
            "residual_gauge_exact: views ({}) and penalty_ops ({}) must align with atoms ({})",
            views.len(),
            penalty_ops.len(),
            model.atoms.len()
        ));
    }
    let mut handled = vec![false; atom_count];
    let mut collected: Vec<GeneratorVerdict> = Vec::new();
    let per_atom = model.atoms.iter().zip(views).zip(penalty_ops);
    for (slot, ((atom, maybe_view), penalty)) in handled.iter_mut().zip(per_atom) {
        let Some(view) = maybe_view else { continue };
        if let AtomTopology::Sphere = atom.topology {
            continue;
        }
        let verdicts = exact_orbit_verdicts(atom, view, penalty.as_ref())?;
        collected.extend(verdicts);
        *slot = true;
    }
    Ok((handled, collected))
}
/// Shared implementation behind the public `residual_gauge*` entry points.
///
/// Steps, in order:
/// 1. enumerate candidate gauge generators (per-atom isometries and equal-ARD
///    rotations for atoms not covered by `exact`, then frame rotations and atom
///    permutations), each tagged with a lowering-error scale in `[0, 1]`;
/// 2. obtain the curvature reduction (`precomputed_curvature`, or built from `model`);
/// 3. score each generator by its normalised curvature energy and mark it unpinned
///    when that fraction does not exceed
///    `max(GENERATOR_FLAT_ENERGY_TOL, lowering_error_scale)`;
/// 4. append the exact verdicts and assemble the report and its summary line.
///
/// A generator of (effectively) zero norm is reported as pinned with fraction 1.0;
/// when the curvature has no scale (`sigma_max_sq` ≈ 0) every fraction is 0.0.
///
/// # Errors
/// Propagates the error of `CurvatureReduction::from_model` when no precomputed
/// curvature is supplied.
fn residual_gauge_inner(
    model: &FittedSaeManifold,
    exact: Option<(Vec<bool>, Vec<GeneratorVerdict>)>,
    precomputed_curvature: Option<CurvatureReduction>,
) -> Result<ResidualGaugeReport, String> {
    /// One enumerated generator awaiting its verdict.
    struct Candidate {
        family: GeneratorFamily,
        direction: Array1<f64>,
        description: String,
        lowering_error_scale: f64,
    }

    let metric_provenance = model.metric.provenance();
    let param_dim = model.param_dim();
    let (exact_mask, exact_verdicts) = exact.unzip();
    let exact_verdicts: Vec<GeneratorVerdict> = exact_verdicts.unwrap_or_default();

    // Per-atom lowering error, clamped into [0, 1]; the global scale is the worst atom.
    let scale_of = |k: usize| -> f64 { model.atoms[k].lowering_error.clamp(0.0, 1.0) };
    let global_scale = model
        .atoms
        .iter()
        .map(|atom| atom.lowering_error.clamp(0.0, 1.0))
        .fold(0.0_f64, f64::max);

    let mut candidates: Vec<Candidate> = Vec::new();
    for (k, atom) in model.atoms.iter().enumerate() {
        // Atoms with exact verdicts are not re-examined through the curvature.
        let handled_exactly = exact_mask.as_ref().is_some_and(|mask| mask[k]);
        if handled_exactly {
            continue;
        }
        let base = model.atom_offset(k);
        for (g, description) in atom_isometry_generators(atom) {
            candidates.push(Candidate {
                family: GeneratorFamily::IsomAtom,
                direction: embed_local_generator(base, &g, param_dim),
                description,
                lowering_error_scale: scale_of(k),
            });
        }
        for (g, description) in equal_ard_rotation_generators(atom) {
            candidates.push(Candidate {
                family: GeneratorFamily::EqualArdRotation,
                direction: embed_local_generator(base, &g, param_dim),
                description,
                lowering_error_scale: scale_of(k),
            });
        }
    }
    for (direction, description) in frame_rotation_generators(model) {
        candidates.push(Candidate {
            family: GeneratorFamily::FrameRotation,
            direction,
            description,
            lowering_error_scale: global_scale,
        });
    }
    for (direction, description, ka, kb) in atom_permutation_generators(model) {
        candidates.push(Candidate {
            family: GeneratorFamily::AtomPermutation,
            direction,
            description,
            lowering_error_scale: scale_of(ka).max(scale_of(kb)),
        });
    }

    let curvature = precomputed_curvature
        .map_or_else(|| CurvatureReduction::from_model(model), Ok)?;
    let pinning_rank = curvature.pinning_rank();
    let sigma_max_sq = curvature.sigma_max_sq();
    let diffeomorphism_unpinned = model.isometry_penalty_root.nrows() == 0;

    let mut verdicts: Vec<GeneratorVerdict> = Vec::with_capacity(candidates.len());
    for candidate in candidates {
        let Candidate {
            family,
            direction,
            description,
            lowering_error_scale,
        } = candidate;
        let norm = direction.iter().map(|v| v * v).sum::<f64>().sqrt();
        let (unpinned, generator_norm, pinned_energy_fraction) = if norm <= f64::MIN_POSITIVE {
            // Degenerate generator: reported as fully pinned with zero norm.
            (false, 0.0, 1.0)
        } else {
            let fraction = if sigma_max_sq <= f64::MIN_POSITIVE {
                0.0
            } else {
                let unit = direction.mapv(|v| v / norm);
                (curvature.unit_generator_energy(&unit) / sigma_max_sq).clamp(0.0, 1.0)
            };
            let tolerance = GENERATOR_FLAT_ENERGY_TOL.max(lowering_error_scale);
            (fraction <= tolerance, norm, fraction)
        };
        verdicts.push(GeneratorVerdict {
            family,
            description,
            unpinned,
            generator_norm,
            pinned_energy_fraction,
            lowering_error_scale,
        });
    }
    verdicts.extend(exact_verdicts);

    let residual_gauge_dim = verdicts.iter().filter(|v| v.unpinned).count();
    // Only meaningful under an output-Fisher metric: is every permutation pinned?
    let is_output_fisher = matches!(
        metric_provenance,
        MetricProvenance::OutputFisher { .. } | MetricProvenance::OutputFisherDownstream { .. }
    );
    let sym_f_trivial_under_output_fisher = is_output_fisher.then(|| {
        !verdicts
            .iter()
            .any(|v| v.family == GeneratorFamily::AtomPermutation && v.unpinned)
    });

    let group = group_signature_of(&verdicts, diffeomorphism_unpinned);
    let sym_f_note = match sym_f_trivial_under_output_fisher {
        Some(true) => "; Sym(F) trivially pinned under OutputFisher",
        Some(false) => "; ⚠ Sym(F) NON-trivial under OutputFisher (certificate violation)",
        None => "",
    };
    let isometry_note = if diffeomorphism_unpinned {
        "; ⚠ isometry pin inactive"
    } else {
        ""
    };
    let summary = format!(
        "residual gauge certificate (computed in metric {metric_provenance:?}): \
         pinning rank {pinning_rank}, {residual_gauge_dim} unpinned residual gauge \
         generator(s) of {} enumerated; group = {}{}{}",
        verdicts.len(),
        group,
        sym_f_note,
        isometry_note,
    );

    Ok(ResidualGaugeReport {
        metric_provenance,
        generators: verdicts,
        pinning_rank,
        residual_gauge_dim,
        diffeomorphism_unpinned,
        sym_f_trivial_under_output_fisher,
        frame_inner_rotation: None,
        summary,
    })
}
/// Combined output of `dictionary_report`: the residual-gauge report of a fitted
/// model next to the structure certificate issued by a `StructureLedger`.
#[derive(Debug, Clone)]
pub struct DictionaryReport {
/// Residual-gauge report computed by `residual_gauge`.
pub gauge: ResidualGaugeReport,
/// Certificate returned by `StructureLedger::certify`.
pub structure: StructureCertificate,
}
/// Produces a `DictionaryReport` by running `residual_gauge` on `model` and asking
/// `ledger` to certify at `alpha`.
///
/// The gauge report is computed first; if it fails, the ledger is not consulted.
///
/// # Errors
/// Propagates the error of `residual_gauge`.
pub fn dictionary_report(
    model: &FittedSaeManifold,
    ledger: &StructureLedger,
    alpha: f64,
) -> Result<DictionaryReport, String> {
    let gauge = residual_gauge(model)?;
    let structure = ledger.certify(alpha);
    Ok(DictionaryReport { gauge, structure })
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, array};
#[test]
fn mechanism_sparsity_jacobian_value_matches_closed_form() {
    // Column 0 is a 3-4-5 triangle (norm 5); column 1 is all zeros and contributes
    // only the vanishing epsilon smoothing.
    let pen = MechanismSparsityJacobian::new(1.0, 1.0e-8).unwrap();
    let w = array![[3.0_f64, 0.0], [4.0, 0.0]];
    let v = pen.value_and_grad(w.view()).0;
    assert!((v - 5.0).abs() < 1e-6, "value {v} expected ≈5");
}
#[test]
fn mechanism_sparsity_jacobian_grad_matches_finite_diff() {
    // Each analytic gradient entry is compared with a central difference of the value.
    let pen = MechanismSparsityJacobian::new(2.5, 1.0e-6).unwrap();
    let w = array![[0.5_f64, -1.2, 0.3], [1.1, 0.4, -0.7]];
    let h = 1.0e-5;
    let g = pen.value_and_grad(w.view()).1;
    for ((i, j), _) in w.indexed_iter() {
        let (mut wp, mut wm) = (w.clone(), w.clone());
        wp[[i, j]] += h;
        wm[[i, j]] -= h;
        let vp = pen.value_and_grad(wp.view()).0;
        let vm = pen.value_and_grad(wm.view()).0;
        let fd = (vp - vm) / (2.0 * h);
        let analytic = g[[i, j]];
        assert!(
            (analytic - fd).abs() < 1e-4,
            "grad[{i},{j}] = {} vs fd {}",
            analytic,
            fd
        );
    }
}
#[test]
fn mechanism_sparsity_jacobian_rejects_bad_input() {
    // A negative weight and a zero epsilon must both be refused by the constructor.
    for (weight, epsilon) in [(-1.0, 1e-6), (1.0, 0.0)] {
        assert!(MechanismSparsityJacobian::new(weight, epsilon).is_err());
    }
}
#[test]
fn frame_inner_rotation_dim_is_sum_of_so_r_dims() {
// The expected values follow r·(r−1)/2 per rank r, summed over the ranks:
// an empty list and rank 1 give 0, rank 2 gives 1, rank 4 gives 6.
assert_eq!(frame_inner_rotation_dim(&[]), 0);
assert_eq!(frame_inner_rotation_dim(&[1]), 0);
assert_eq!(frame_inner_rotation_dim(&[2]), 1);
assert_eq!(frame_inner_rotation_dim(&[4]), 6);
// Mixed ranks add up per atom: 0 (rank 1) + 6 (rank 4) + 28 (rank 8).
assert_eq!(frame_inner_rotation_dim(&[1, 4, 8]), 0 + 6 + 28);
// The gauge constructor must report the same total in its `dim` field.
assert_eq!(
FrameInnerRotationGauge::from_ranks(vec![3, 3]).dim,
6,
"two rank-3 frames carry 2·3 inner-rotation dims"
);
}
#[test]
fn frame_inner_rotation_attaches_to_the_certificate_without_verdict_change() {
    // Minimal Euclidean report with no generators; only the pinning rank varies.
    let blank_report = |pinning_rank| ResidualGaugeReport {
        metric_provenance: MetricProvenance::Euclidean,
        generators: Vec::new(),
        pinning_rank,
        residual_gauge_dim: 0,
        diffeomorphism_unpinned: false,
        sym_f_trivial_under_output_fisher: None,
        frame_inner_rotation: None,
        summary: "base".to_string(),
    };

    // Non-trivial ranks: the gauge is attached and surfaces in signature and summary,
    // while the verdict-derived fields stay untouched.
    let base = blank_report(5);
    let sig_before = base.group_signature();
    let report = base.with_frame_inner_rotation(vec![1, 4, 8]);
    let expected_gauge = FrameInnerRotationGauge {
        per_atom_ranks: vec![1, 4, 8],
        dim: 34,
    };
    assert_eq!(report.frame_inner_rotation, Some(expected_gauge));
    assert_eq!(report.residual_gauge_dim, 0);
    assert!(report.generators.is_empty());
    let sig_after = report.group_signature();
    assert_ne!(sig_before, sig_after);
    assert!(sig_after.contains("frame-inner"), "got: {sig_after}");
    assert!(sig_after.contains("dim 34"), "got: {sig_after}");
    assert!(sig_after.contains("canonical-fixed"), "got: {sig_after}");
    assert!(report.summary.contains("inner-rotation gauge"));

    // All-rank-1 frames: a zero-dimensional gauge is recorded but changes neither the
    // signature nor the summary.
    let trivial = blank_report(0);
    let sig_trivial_before = trivial.group_signature();
    let trivial = trivial.with_frame_inner_rotation(vec![1, 1, 1]);
    let trivial_dim = trivial.frame_inner_rotation.as_ref().map(|g| g.dim);
    assert_eq!(trivial_dim, Some(0));
    assert_eq!(trivial.group_signature(), sig_trivial_before);
    assert_eq!(trivial.summary, "base");
}
/// Test fixture: an `(n, d)` mean/scale pair that varies with the row index.
///
/// Row `r` maps to `t = r / (n - 1)`; column `c` uses frequency `c + 1`. The mean is
/// `sin(π·ω·t)` and the scale is `exp(0.4·cos(π·ω·t))`, so every scale is positive.
/// Panics unless `n >= 2d + 1`.
fn ivae_precondition_pair(n: usize, d: usize) -> (Array2<f64>, Array2<f64>) {
    assert!(n >= 2 * d + 1, "need at least 2d+1 rows");
    let phase = |(r, c): (usize, usize)| {
        let t = r as f64 / (n as f64 - 1.0);
        let omega = (c + 1) as f64;
        std::f64::consts::PI * omega * t
    };
    let mean = Array2::<f64>::from_shape_fn((n, d), |idx| phase(idx).sin());
    let scale = Array2::<f64>::from_shape_fn((n, d), |idx| (0.4 * phase(idx).cos()).exp());
    (mean, scale)
}
#[test]
fn conditional_prior_ivae_zero_mean_unit_scale_matches_standard_gaussian() {
    // Evaluated exactly at the prior mean: the quadratic term vanishes, leaving the
    // log-scale sum plus the Gaussian constant, and the gradient must be zero.
    let (n, d) = (7, 3);
    let (mean, scale) = ivae_precondition_pair(n, d);
    let at_mean = mean.clone();
    let log_norm: f64 = scale.iter().map(|s| s.ln()).sum();
    let expected = log_norm + 0.5 * (n * d) as f64 * (2.0 * std::f64::consts::PI).ln();
    let pen = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap();
    let (v, g) = pen.value_and_grad(at_mean.view());
    assert!(
        (v - expected).abs() < 1e-9,
        "value {v} vs expected {expected}"
    );
    assert!(g.iter().all(|gv| gv.abs() < 1e-12));
}
#[test]
fn conditional_prior_ivae_grad_matches_finite_diff() {
    // Shift the evaluation point off the prior mean so the gradient is non-trivial,
    // then compare every entry with a central difference of `value`.
    let (mean, scale) = ivae_precondition_pair(5, 2);
    let mut t = mean.clone();
    t.column_mut(0).mapv_inplace(|x| x + 0.4);
    t.column_mut(1).mapv_inplace(|x| x - 0.3);
    let pen = ConditionalPriorIvae::new(mean, scale, 1.7).unwrap();
    let g = pen.value_and_grad(t.view()).1;
    let h = 1.0e-5;
    for ((i, j), _) in t.indexed_iter() {
        let (mut tp, mut tm) = (t.clone(), t.clone());
        tp[[i, j]] += h;
        tm[[i, j]] -= h;
        let fd = (pen.value(tp.view()) - pen.value(tm.view())) / (2.0 * h);
        assert!((g[[i, j]] - fd).abs() < 1e-5);
    }
}
#[test]
fn conditional_prior_ivae_rejects_nonpositive_scale() {
    // A single negative scale entry is enough for the constructor to refuse.
    let scale = array![[-0.1_f64, 1.0], [1.0, 1.0]];
    let mean = Array2::<f64>::zeros((2, 2));
    let outcome = ConditionalPriorIvae::new(mean, scale, 1.0);
    assert!(outcome.is_err());
}
#[test]
fn conditional_prior_ivae_accepts_when_signature_full_rank() {
    // The 7×3 fixture has exactly 2d+1 rows and row-varying mean/scale columns.
    let (mean, scale) = ivae_precondition_pair(7, 3);
    let failure = ConditionalPriorIvae::new(mean, scale, 1.0).err();
    assert!(
        failure.is_none(),
        "full-rank signature should satisfy Khemakhem Theorem 1, got {:?}",
        failure,
    );
}
#[test]
fn conditional_prior_ivae_rejects_trivial_constant_prior() {
    // Mean and scale are constant across all rows, i.e. the prior ignores the
    // auxiliary variable entirely.
    let shape = (9, 3);
    let mean = Array2::<f64>::from_elem(shape, 0.25);
    let scale = Array2::<f64>::from_elem(shape, 1.5);
    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
    for needle in ["trivial unconditional", "Khemakhem"] {
        assert!(err.contains(needle), "unexpected error: {err}");
    }
}
#[test]
fn conditional_prior_ivae_rejects_too_few_auxiliary_states() {
    // Keep only the first 4 of 7 rows: fewer than the 2·3+1 rows the fixture needs.
    let (full_mean, full_scale) = ivae_precondition_pair(7, 3);
    let keep = s![..4, ..];
    let mean = full_mean.slice(keep).to_owned();
    let scale = full_scale.slice(keep).to_owned();
    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
    for needle in ["2k+1", "Khemakhem"] {
        assert!(err.contains(needle), "unexpected error: {err}");
    }
}
#[test]
fn conditional_prior_ivae_rejects_rank_deficient_signature() {
    // Only column 0 varies with the row index; the remaining columns stay constant
    // (mean 0, scale 1), so the prior signature cannot have full rank.
    let (n, d) = (9, 3);
    let signal = Array1::from_shape_fn(n, |r| ((r as f64) * 0.5).sin());
    let mut mean = Array2::<f64>::zeros((n, d));
    let mut scale = Array2::<f64>::from_elem((n, d), 1.0);
    mean.column_mut(0).assign(&signal);
    scale.column_mut(0).assign(&signal.mapv(f64::exp));
    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
    for needle in ["numerical rank", "Khemakhem"] {
        assert!(err.contains(needle), "unexpected error: {err}");
    }
}
#[test]
fn piecewise_linear_eval_endpoints_and_midpoint() {
    // Three knots on [0, 1]; evaluating at 0, 0.5 and 1 must return the knot rows.
    let coeffs = array![[0.0_f64, 10.0], [1.0, 20.0], [2.0, 30.0]];
    let u = array![0.0, 0.5, 1.0];
    let out = piecewise_linear_eval(u.view(), coeffs.view(), 0.0, 1.0);
    let expectations = [((0, 0), 0.0), ((1, 0), 1.0), ((2, 0), 2.0), ((1, 1), 20.0)];
    for ((row, col), want) in expectations {
        assert!((out[[row, col]] - want).abs() < 1e-12);
    }
}
#[test]
fn select_weights_picks_max_evidence() {
    // With a zero penalty grid, the cell with the smallest residual (row 1, col 1)
    // is expected to win, which maps to weight 1.0 on both axes.
    let grid = Array1::from(vec![0.1, 1.0, 10.0]);
    let pen = Array2::<f64>::zeros((3, 3));
    let rss = array![[10.0, 9.0, 9.5], [8.0, 4.0, 5.0], [9.0, 6.0, 7.0]];
    let res = identifiable_factor_select_weights(
        rss.view(),
        pen.view(),
        grid.view(),
        grid.view(),
        80,
    )
    .unwrap();
    assert_eq!((res.best_i, res.best_j), (1, 1));
    for chosen in [res.best_lam1, res.best_lam2] {
        assert!((chosen - 1.0).abs() < 1e-12);
    }
    assert!(res.best_evidence.is_finite());
}
#[test]
fn select_weights_breaks_ties_by_smallest_log_weight_sum() {
    // Residual and penalty grids are constant, so the first (smallest-weight) cell
    // is the expected winner.
    let grid = Array1::from(vec![0.1, 10.0]);
    let rss = Array2::<f64>::from_elem((2, 2), 4.0);
    let pen = Array2::<f64>::from_elem((2, 2), 1.0);
    let res = identifiable_factor_select_weights(
        rss.view(),
        pen.view(),
        grid.view(),
        grid.view(),
        8,
    )
    .unwrap();
    assert_eq!((res.best_i, res.best_j), (0, 0));
}
#[test]
fn select_weights_rejects_shape_mismatch() {
    // The residual grid is 2×3 (matching the weight axes) but the penalty grid is
    // 2×2; the error message must name the penalty grid.
    let l1 = Array1::from(vec![1.0, 1.0]);
    let l2 = Array1::from(vec![1.0, 1.0, 1.0]);
    let rss = Array2::<f64>::zeros((2, 3));
    let pen = Array2::<f64>::zeros((2, 2));
    let outcome =
        identifiable_factor_select_weights(rss.view(), pen.view(), l1.view(), l2.view(), 8);
    let err = outcome.unwrap_err();
    assert!(err.contains("penalty_grid"));
}
#[test]
fn partial_supervision_procrustes_recovers_rotation_and_orthogonalizes_free() {
    // The supervised block is `aux` multiplied by the transpose of a 90° rotation in
    // the first two coordinates; the solve must map it back onto `aux`.
    let q = array![[0.0_f64, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]];
    let aux = array![
        [1.0_f64, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
        [-1.0, 1.0, 2.0],
    ];
    let t_free = array![
        [1.5_f64, 0.0],
        [0.0, 1.0],
        [-1.0, 2.0],
        [0.3, -0.7],
        [2.0, 1.0],
    ];
    let t_sup = aux.dot(&q.t());
    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::Procrustes,
        &[],
        PartialSupervisionFreeConstraint::OrthogonalToSup,
    )
    .expect("procrustes solve should succeed");

    // Supervised coordinates land on the auxiliary targets entry by entry.
    for ((r, c), &target) in aux.indexed_iter() {
        assert!(
            (result.t_supervised[[r, c]] - target).abs() < 1.0e-10,
            "sup[{r},{c}] = {} vs aux {}",
            result.t_supervised[[r, c]],
            target
        );
    }

    // Free coordinates are orthogonal to the supervised ones.
    let cross = result.t_free.t().dot(&result.t_supervised);
    let frob: f64 = cross.iter().map(|x| x * x).sum::<f64>().sqrt();
    assert!(frob < 1.0e-8, "cross frobenius = {frob}");
    assert!(result.alignment_score > 1.0 - 1.0e-10);
    assert!(result.map_r.is_some());
}
#[test]
fn partial_supervision_anchor_pins_exact_anchors_when_full_rank() {
    // Rows 0..3 are anchors; after the solve their supervised coordinates must
    // coincide with the auxiliary targets.
    let t_sup = array![[0.5_f64, 1.0], [-0.5, 0.25], [1.5, -1.0], [0.35, 0.6],];
    let aux = array![[1.0_f64, 2.0], [-1.0, 0.5], [3.0, -2.0], [0.7, 1.2],];
    let t_free = Array2::<f64>::zeros((4, 1));
    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::Anchor,
        &[0, 1, 2],
        PartialSupervisionFreeConstraint::None,
    )
    .expect("anchor solve should succeed");
    for row in 0..3 {
        for c in 0..aux.ncols() {
            let pinned = result.t_supervised[[row, c]];
            assert!(
                (pinned - aux[[row, c]]).abs() < 1.0e-9,
                "anchor row {row} col {c} not pinned: {} vs {}",
                pinned,
                aux[[row, c]]
            );
        }
    }
    assert!(result.map_a.is_some() && result.map_b.is_some());
}
#[test]
fn partial_supervision_softl2_selects_a_finite_weight() {
    // The supervised block differs from `aux` only slightly (0.1 off-diagonal in the
    // first two rows); the soft-L2 solve must report a positive, finite weight.
    let t_sup = array![
        [1.0_f64, 0.1],
        [0.1, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
        [0.5, -0.5],
    ];
    let aux = array![
        [1.0_f64, 0.0],
        [0.0, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
        [0.5, -0.5],
    ];
    let t_free = Array2::<f64>::from_elem((5, 1), 0.5);
    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::SoftL2,
        &[],
        PartialSupervisionFreeConstraint::OrthogonalToSup,
    )
    .expect("soft_l2 solve should succeed");
    let lam = result.selected_weight.unwrap();
    assert!(lam.is_finite() && lam > 0.0, "lam={lam}");
    assert!(result.map_a.is_some());
}
}