use super::*;
/// Outcome of the sufficient-condition test for global optimality.
///
/// Both variants carry the signed `margin` of the test. The verdict function
/// in this module produces `CertifiedGlobal` only when the margin is strictly
/// positive; otherwise it produces `Uncertified`, whose margin is
/// `f64::NEG_INFINITY` when the test could not be evaluated at all.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GlobalOptimalityVerdict {
    /// The test passed; `margin` is the slack by which it passed.
    CertifiedGlobal { margin: f64 },
    /// The test did not pass; `margin` is the (non-positive) shortfall.
    Uncertified { margin: f64 },
}

impl GlobalOptimalityVerdict {
    /// Returns the margin stored in the verdict, whichever variant it is.
    pub fn margin(&self) -> f64 {
        match *self {
            Self::CertifiedGlobal { margin } => margin,
            Self::Uncertified { margin } => margin,
        }
    }

    /// Returns `true` only for [`GlobalOptimalityVerdict::CertifiedGlobal`].
    pub fn is_certified(&self) -> bool {
        // Spelled out per variant so that adding a variant forces a decision here.
        match self {
            Self::CertifiedGlobal { .. } => true,
            Self::Uncertified { .. } => false,
        }
    }
}
/// Weight on the curvature bound inside
/// `curved_dictionary_global_optimality_verdict`: the curvature factor there is
/// `1 - SAE_CERT_CURVATURE_CONSTANT * max(kappa_max, 0)`.
pub const SAE_CERT_CURVATURE_CONSTANT: f64 = 1.0;
/// Leading coefficient of the incoherence budget computed by
/// `curved_dictionary_global_optimality_verdict`; it multiplies the squared
/// activity floor, the SNR factor and the curvature factor.
pub const SAE_CERT_INCOHERENCE_BUDGET: f64 = 0.125;
/// Evaluates the sufficient condition for global optimality of a curved
/// dictionary fit and returns the verdict together with its margin.
///
/// The test computes
///
/// ```text
/// budget = SAE_CERT_INCOHERENCE_BUDGET
///          * max(activity_floor, 0)^2
///          * (1 - 1 / snr_proxy)
///          * (1 - SAE_CERT_CURVATURE_CONSTANT * max(kappa_max, 0))
///          / k_atoms
/// margin = budget - mu_hat
/// ```
///
/// and certifies only when `margin > 0`.
///
/// The verdict is `Uncertified` with margin `f64::NEG_INFINITY` whenever the
/// test cannot be evaluated meaningfully:
///
/// * any floating-point input is NaN or infinite, or `k_atoms == 0`;
/// * `snr_proxy <= 0` (this includes `-0.0`);
/// * the SNR factor or the curvature factor is not strictly positive.
pub fn curved_dictionary_global_optimality_verdict(
    mu_hat: f64,
    kappa_max: f64,
    activity_floor: f64,
    snr_proxy: f64,
    k_atoms: usize,
) -> GlobalOptimalityVerdict {
    const UNDECIDABLE: GlobalOptimalityVerdict = GlobalOptimalityVerdict::Uncertified {
        margin: f64::NEG_INFINITY,
    };

    let inputs_finite = [mu_hat, kappa_max, activity_floor, snr_proxy]
        .iter()
        .all(|value| value.is_finite());
    if !inputs_finite || k_atoms == 0 {
        return UNDECIDABLE;
    }

    // A signal-to-noise ratio is non-negative by construction. Without this
    // guard a negative proxy makes `1 - 1/snr` larger than one, and `-0.0`
    // makes it `+inf`, so a nonsensical input could be certified.
    if snr_proxy <= 0.0 {
        return UNDECIDABLE;
    }

    let curvature_factor = 1.0 - SAE_CERT_CURVATURE_CONSTANT * kappa_max.max(0.0);
    let snr_factor = 1.0 - 1.0 / snr_proxy;
    if curvature_factor <= 0.0 || snr_factor <= 0.0 {
        return UNDECIDABLE;
    }

    // Negative activity floors are treated as zero activity.
    let a = activity_floor.max(0.0);
    let budget =
        SAE_CERT_INCOHERENCE_BUDGET * a * a * snr_factor * curvature_factor / k_atoms as f64;
    let margin = budget - mu_hat;
    if margin > 0.0 {
        GlobalOptimalityVerdict::CertifiedGlobal { margin }
    } else {
        GlobalOptimalityVerdict::Uncertified { margin }
    }
}
/// Inputs to, and result of, the global-optimality certificate, as assembled by
/// `dictionary_incoherence_report_with_dispersion`.
#[derive(Clone, Debug)]
pub struct CertificateInputs {
/// Largest singular value of the frame overlap over all atom pairs, as
/// returned by `dictionary_frame_incoherence`.
pub mu_hat: f64,
/// Per-atom curvature bound returned by `atom_curvature_bound`.
pub per_atom_kappa_hat: Vec<f64>,
/// Per-atom mean of the assignment column (0.0 when there are no rows).
pub per_atom_mean_activity: Vec<f64>,
/// Per-atom maximum of the assignment column, never below 0.0.
pub per_atom_peak_activity: Vec<f64>,
/// Minimum of `per_atom_mean_activity` over atoms (0.0 when there are no atoms).
pub mean_activity_floor: f64,
/// Minimum of `per_atom_peak_activity` over atoms (0.0 when there are no atoms);
/// this is the activity floor passed to the verdict function.
pub peak_activity_floor: f64,
/// Mean squared fitted value divided by `dispersion`.
pub snr_proxy: f64,
/// Dispersion used for `snr_proxy`; validated as finite and strictly positive.
pub dispersion: f64,
/// Verdict computed from the quantities above.
pub global_optimality: GlobalOptimalityVerdict,
/// Human-readable summary of the verdict and the numbers behind it.
pub note: String,
}
/// Groups four diagnostic reports for a manifold SAE fit.
// NOTE(review): nothing in this section constructs this struct; the field
// notes below are read off the field types only — confirm against the producer.
#[derive(Clone, Debug)]
pub struct SaeManifoldFitDiagnostics {
/// Two-lens atom report from `crate::inference::atom_lens`.
pub atom_two_lens: crate::inference::atom_lens::AtomTwoLensReport,
/// Residual gauge report from `crate::identifiability`.
pub residual_gauge: crate::identifiability::ResidualGaugeReport,
/// Optional certificate inputs (the type returned by
/// `dictionary_incoherence_report`); presumably `None` when that report
/// could not be computed — TODO confirm.
pub incoherence_report: Option<CertificateInputs>,
/// Atom inference reports; presumably one entry per atom — TODO confirm.
pub atom_inference: Vec<crate::identifiability::AtomInferenceReport>,
}
/// Trust diagnostics: a vector of scores plus a vector of detailed records.
// NOTE(review): the producer is not in this section. Presumably both vectors
// have one entry per atom and `atom_trust[i]` mirrors `atoms[i].trust_score`
// — confirm.
#[derive(Clone, Debug)]
pub struct SaeTrustDiagnostics {
/// Trust scores.
pub atom_trust: Vec<f64>,
/// Detailed trust records.
pub atoms: Vec<SaeAtomTrustDiagnostics>,
}
/// Trust record for a single atom.
// NOTE(review): the code that fills these fields is not in this section, so
// units, ranges and exact definitions below are unverified — confirm against
// the producer before relying on them.
#[derive(Clone, Debug)]
pub struct SaeAtomTrustDiagnostics {
/// Overall trust score for the atom.
pub trust_score: f64,
/// Presumably the smallest singular value of the atom's tangent map.
pub sigma_min_tangent: f64,
/// Presumably the largest singular value of the atom's tangent map.
pub sigma_max_tangent: f64,
/// Score derived from tangent conditioning (definition not shown here).
pub tangent_condition_score: f64,
/// Coverage measure (definition not shown here).
pub coverage: f64,
/// Activation frequency (definition not shown here).
pub activation_frequency: f64,
/// Flag marking the atom as untyped (meaning not shown here).
pub untyped: bool,
/// Count of active tokens for the atom.
pub active_token_count: usize,
}
/// Builds the certificate inputs for `term` using the dispersion stored on the
/// term itself.
///
/// # Errors
///
/// Returns an error when `term.certificate_dispersion` is `None`, and forwards
/// any error from `dictionary_incoherence_report_with_dispersion`.
pub fn dictionary_incoherence_report(term: &SaeManifoldTerm) -> Result<CertificateInputs, String> {
    match term.certificate_dispersion {
        Some(dispersion) => dictionary_incoherence_report_with_dispersion(term, dispersion),
        None => Err(
            "dictionary_incoherence_report: fitted reconstruction dispersion is unavailable"
                .to_string(),
        ),
    }
}
/// Assembles every quantity the global-optimality certificate needs for `term`
/// and evaluates the verdict, using the caller-supplied `dispersion`.
///
/// Steps, in order: validate `dispersion`; compute the dictionary incoherence;
/// compute one curvature bound per atom; scan the assignment matrix for
/// per-atom mean and peak activity; form the SNR proxy as mean squared fitted
/// value over `dispersion`; evaluate the verdict with the peak-activity floor.
///
/// # Errors
///
/// Returns an error when `dispersion` is not finite and strictly positive, and
/// forwards errors from `dictionary_frame_incoherence` and
/// `atom_curvature_bound`.
pub fn dictionary_incoherence_report_with_dispersion(
    term: &SaeManifoldTerm,
    dispersion: f64,
) -> Result<CertificateInputs, String> {
    let dispersion_usable = dispersion.is_finite() && dispersion > 0.0;
    if !dispersion_usable {
        return Err(format!(
            "dictionary_incoherence_report: dispersion must be finite and positive, got {dispersion}"
        ));
    }

    let mu_hat = dictionary_frame_incoherence(term)?;

    let mut per_atom_kappa_hat = Vec::new();
    for (atom_idx, _) in term.atoms.iter().enumerate() {
        per_atom_kappa_hat.push(atom_curvature_bound(term, atom_idx)?);
    }

    // One pass per assignment column: running sum for the mean, running max
    // (floored at zero) for the peak.
    let assignments = term.assignment.assignments();
    let n = assignments.nrows();
    let k_atoms = assignments.ncols();
    let mut per_atom_mean_activity = Vec::with_capacity(k_atoms);
    let mut per_atom_peak_activity = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        let (total, peak) = (0..n).fold((0.0_f64, 0.0_f64), |(total, peak), row| {
            let value = assignments[[row, atom_idx]];
            (total + value, peak.max(value))
        });
        let mean = if n > 0 { total / n as f64 } else { 0.0 };
        per_atom_mean_activity.push(mean);
        per_atom_peak_activity.push(peak);
    }

    // With no atoms the minimum stays at +inf; report that as a zero floor.
    let finite_or_zero = |value: f64| if value.is_finite() { value } else { 0.0 };
    let mean_activity_floor = finite_or_zero(
        per_atom_mean_activity
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min),
    );
    let peak_activity_floor = finite_or_zero(
        per_atom_peak_activity
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min),
    );

    let fitted = term.fitted();
    let signal_power = if fitted.is_empty() {
        0.0
    } else {
        fitted.iter().map(|v| v * v).sum::<f64>() / fitted.len() as f64
    };
    let snr_proxy = signal_power / dispersion;

    let kappa_max = per_atom_kappa_hat.iter().copied().fold(0.0_f64, f64::max);
    let global_optimality = curved_dictionary_global_optimality_verdict(
        mu_hat,
        kappa_max,
        peak_activity_floor,
        snr_proxy,
        k_atoms,
    );

    let note = match global_optimality {
        GlobalOptimalityVerdict::CertifiedGlobal { margin } => format!(
            "global optimality CERTIFIED up to the residual gauge group \
(margin {margin:.3e}); μ̂={mu_hat:.3e}, κ̂_max={kappa_max:.3e}, \
a_floor={peak_activity_floor:.3e}, SNR={snr_proxy:.3e}"
        ),
        GlobalOptimalityVerdict::Uncertified { margin } => format!(
            "global optimality UNCERTIFIED (margin {margin:.3e}; cannot decide — \
multistart/homotopy genuinely needed); μ̂={mu_hat:.3e}, \
κ̂_max={kappa_max:.3e}, a_floor={peak_activity_floor:.3e}, \
SNR={snr_proxy:.3e}"
        ),
    };

    Ok(CertificateInputs {
        mu_hat,
        per_atom_kappa_hat,
        per_atom_mean_activity,
        per_atom_peak_activity,
        mean_activity_floor,
        peak_activity_floor,
        snr_proxy,
        dispersion,
        global_optimality,
        note,
    })
}
/// Returns the largest singular value of `frames[j]ᵀ · frames[k]` over all
/// atom pairs `j < k`, where each frame comes from `certificate_output_frame`.
///
/// Pairs in which either frame has zero columns are skipped; the result is
/// `0.0` when no pair contributes.
///
/// # Errors
///
/// Forwards errors from `certificate_output_frame` and reports SVD failures
/// together with the offending atom pair.
pub(crate) fn dictionary_frame_incoherence(term: &SaeManifoldTerm) -> Result<f64, String> {
    let k_atoms = term.k_atoms();
    let mut frames = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        frames.push(certificate_output_frame(term, atom_idx)?);
    }

    let mut mu = 0.0_f64;
    for (j, left) in frames.iter().enumerate() {
        if left.ncols() == 0 {
            // An empty left frame cannot overlap with anything.
            continue;
        }
        for (k, right) in frames.iter().enumerate().skip(j + 1) {
            if right.ncols() == 0 {
                continue;
            }
            let overlap = fast_atb(left, right);
            let (_u, s, _vt) = overlap.svd(false, false).map_err(|e| {
                format!("dictionary_frame_incoherence: SVD failed for atom pair ({j}, {k}): {e}")
            })?;
            let pair = s.iter().copied().fold(0.0_f64, f64::max);
            mu = mu.max(pair);
        }
    }
    Ok(mu)
}
/// Returns a matrix with `output_dim` rows whose columns span the output-space
/// frame of atom `atom_idx`.
///
/// When the atom carries a `decoder_frame`, the result is
/// `term.frame_output_matrix(atom_idx)`. Otherwise the frame is built from the
/// right singular vectors of `decoder_coefficients` whose singular values
/// exceed `SAE_FRAME_RANK_CUTOFF` times the largest one. A decoder with no
/// positive singular value yields a frame with zero columns.
///
/// # Errors
///
/// Reports an SVD failure, or an SVD that returns no right factor.
pub(crate) fn certificate_output_frame(
    term: &SaeManifoldTerm,
    atom_idx: usize,
) -> Result<Array2<f64>, String> {
    let atom = &term.atoms[atom_idx];
    if atom.decoder_frame.is_some() {
        return Ok(term.frame_output_matrix(atom_idx));
    }

    let p = atom.output_dim();
    let (_u, sigma, vt_opt) = atom
        .decoder_coefficients
        .svd(false, true)
        .map_err(|e| format!("certificate_output_frame: SVD failed for atom {atom_idx}: {e}"))?;

    let sigma_max = sigma.iter().copied().fold(0.0_f64, f64::max);
    if !(sigma_max > 0.0) {
        return Ok(Array2::<f64>::zeros((p, 0)));
    }

    let cutoff = SAE_FRAME_RANK_CUTOFF * sigma_max;
    let numerical_rank = sigma.iter().filter(|&&value| value > cutoff).count();
    let vt = vt_opt.ok_or_else(|| {
        format!("certificate_output_frame: SVD returned no right factor for atom {atom_idx}")
    })?;
    let rank = numerical_rank.min(vt.nrows());

    // Column `col` of the frame is row `col` of Vᵀ, i.e. the transpose of the
    // leading `rank` rows.
    Ok(Array2::from_shape_fn((p, rank), |(row, col)| vt[[col, row]]))
}
/// Computes the curvature bound of atom `atom_idx` from its analytic second
/// jet, evaluated at the atom's assigned coordinates, and its own decoder
/// coefficients.
///
/// # Errors
///
/// Fails when the atom has no basis evaluator or the evaluator offers no
/// second jet, when the second-jet evaluation itself fails, and forwards
/// errors from `atom_curvature_bound_with_decoder`.
pub(crate) fn atom_curvature_bound(term: &SaeManifoldTerm, atom_idx: usize) -> Result<f64, String> {
    let atom = &term.atoms[atom_idx];
    let coords = term.assignment.coords[atom_idx].as_matrix();

    let jet = atom
        .basis_evaluator
        .as_ref()
        .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()));
    let second = match jet {
        Some(Ok(second)) => second,
        Some(Err(e)) => {
            return Err(format!(
                "atom_curvature_bound: atom {atom_idx} second jet failed: {e}"
            ));
        }
        None => {
            return Err(format!(
                "atom_curvature_bound: atom {atom_idx} has no analytic second jet; cannot compute kappa_hat"
            ));
        }
    };

    atom_curvature_bound_with_decoder(
        atom,
        atom_idx,
        second.view(),
        atom.decoder_coefficients.view(),
    )
}
/// Computes an upper bound on the curvature of one atom for a given second jet
/// and decoder.
///
/// For every observation row the function
///
/// 1. builds the `p × d` tangent matrix by contracting `atom.basis_jacobian`
///    with `decoder`,
/// 2. obtains from `tangent_frame_rank` the squared smallest retained singular
///    value (`tangent_scale`) and an orthonormal basis `q` of the tangent span,
/// 3. for every axis pair `(a, b)` contracts `second` with `decoder`, removes
///    the component inside the tangent span, and divides the norm of the
///    remainder by `tangent_scale`.
///
/// The result is the maximum of those ratios over all rows and axis pairs.
///
/// `f64::INFINITY` is returned when no finite bound can be stated: the tangent
/// is degenerate (`tangent_scale == 0`) while a normal component exists, or
/// the normal component is NaN. The NaN case must be handled explicitly
/// because `f64::max` discards a NaN operand and `NaN > 0.0` is false, so a
/// NaN would otherwise be skipped and the bound under-reported.
///
/// # Errors
///
/// Fails when `second` is not `(n, m, d, d)` or `decoder` is not `(m, p)`, and
/// forwards errors from `tangent_frame_rank`.
// NOTE(review): `atom.basis_jacobian` is indexed as `[row, basis, axis]`
// without a shape check; presumably it is `(n, m, d)` — confirm.
pub(crate) fn atom_curvature_bound_with_decoder(
    atom: &SaeManifoldAtom,
    atom_idx: usize,
    second: ArrayView4<'_, f64>,
    decoder: ArrayView2<'_, f64>,
) -> Result<f64, String> {
    let n = atom.n_obs();
    let m = atom.basis_size();
    let d = atom.latent_dim;
    let p = atom.output_dim();
    if second.dim() != (n, m, d, d) {
        return Err(format!(
            "atom_curvature_bound: atom {atom_idx} second jet shape {:?} must be ({n}, {m}, {d}, {d})",
            second.dim()
        ));
    }
    if decoder.dim() != (m, p) {
        return Err(format!(
            "atom_curvature_bound: atom {atom_idx} decoder shape {:?} must be ({m}, {p})",
            decoder.dim()
        ));
    }

    let mut max_kappa = 0.0_f64;
    // Scratch buffers reused across rows and axis pairs.
    let mut tangent = Array2::<f64>::zeros((p, d));
    let mut second_vec = vec![0.0_f64; p];

    for row in 0..n {
        // tangent[out, axis] = Σ_basis ∂φ[row, basis, axis] · decoder[basis, out]
        tangent.fill(0.0);
        for basis_col in 0..m {
            for axis in 0..d {
                let dphi = atom.basis_jacobian[[row, basis_col, axis]];
                if dphi == 0.0 {
                    continue;
                }
                for out in 0..p {
                    tangent[[out, axis]] += dphi * decoder[[basis_col, out]];
                }
            }
        }
        let (tangent_scale, q) = tangent_frame_rank(tangent.view())?;

        for axis_a in 0..d {
            for axis_b in 0..d {
                // second_vec[out] = Σ_basis ∂²φ[row, basis, a, b] · decoder[basis, out]
                second_vec.fill(0.0);
                for basis_col in 0..m {
                    let h = second[[row, basis_col, axis_a, axis_b]];
                    if h == 0.0 {
                        continue;
                    }
                    for out in 0..p {
                        second_vec[out] += h * decoder[[basis_col, out]];
                    }
                }

                let perp_norm = projected_perp_norm(&second_vec, q.view());
                if perp_norm.is_nan() {
                    // Fail closed: an undefined normal component cannot be bounded.
                    return Ok(f64::INFINITY);
                }
                if tangent_scale > 0.0 {
                    max_kappa = max_kappa.max(perp_norm / tangent_scale);
                } else if perp_norm > 0.0 {
                    // Degenerate tangent with a non-zero normal component.
                    return Ok(f64::INFINITY);
                }
            }
        }
    }
    Ok(max_kappa)
}
/// Decomposes a `p × d` tangent matrix into a scale and an orthonormal basis.
///
/// Despite the name, the first element of the returned pair is not a rank: it
/// is the square of the smallest singular value that exceeds
/// `SAE_FRAME_RANK_CUTOFF` times the largest singular value. The second
/// element holds the matching leading left singular vectors as columns.
///
/// An empty matrix, or one with no positive singular value, yields
/// `(0.0, p × 0 matrix)`.
///
/// # Errors
///
/// Reports an SVD failure, or an SVD that returns no left factor.
pub(crate) fn tangent_frame_rank(
    tangent: ArrayView2<'_, f64>,
) -> Result<(f64, Array2<f64>), String> {
    let (p, d) = tangent.dim();
    let degenerate = || (0.0_f64, Array2::<f64>::zeros((p, 0)));
    if p == 0 || d == 0 {
        return Ok(degenerate());
    }

    let (u_opt, sigma, _vt) = tangent
        .to_owned()
        .svd(true, false)
        .map_err(|e| format!("tangent_frame_rank: SVD failed: {e}"))?;

    let sigma_max = sigma.iter().copied().fold(0.0_f64, f64::max);
    if !(sigma_max > 0.0) {
        return Ok(degenerate());
    }

    // Single pass over the retained singular values: count them and track
    // the smallest.
    let cutoff = SAE_FRAME_RANK_CUTOFF * sigma_max;
    let (retained, sigma_min) = sigma
        .iter()
        .copied()
        .filter(|value| *value > cutoff)
        .fold((0_usize, f64::INFINITY), |(count, smallest), value| {
            (count + 1, smallest.min(value))
        });

    let u = u_opt.ok_or_else(|| "tangent_frame_rank: SVD returned no U".to_string())?;
    let rank = retained.min(u.ncols());
    let q = Array2::from_shape_fn((p, rank), |(row, col)| u[[row, col]]);
    Ok((sigma_min * sigma_min, q))
}
/// Returns the Euclidean norm of `vector` after subtracting its components
/// along each column of `tangent_frame`.
///
/// Every coefficient is the dot product of a column with the original
/// `vector`, not with the running residual.
// NOTE(review): this equals the norm of the orthogonal complement only when
// the columns are orthonormal — confirm callers always pass such a frame.
pub(crate) fn projected_perp_norm(vector: &[f64], tangent_frame: ArrayView2<'_, f64>) -> f64 {
    let (rows, cols) = tangent_frame.dim();
    let mut residual = vector.to_vec();
    for axis in 0..cols {
        let coeff = (0..rows).fold(0.0_f64, |acc, out| {
            acc + tangent_frame[[out, axis]] * vector[out]
        });
        // A zero coefficient leaves the residual untouched; skip the pass.
        if coeff != 0.0 {
            for (out, slot) in residual.iter_mut().enumerate().take(rows) {
                *slot -= coeff * tangent_frame[[out, axis]];
            }
        }
    }
    residual.iter().map(|v| v * v).sum::<f64>().sqrt()
}