use std::sync::Arc;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::inference::residual_factor::{ResidualFactorInput, StructuredResidualModel};
use crate::inference::structure_evidence::{ClaimKind, StructureLedger};
use crate::linalg::faer_ndarray::{FaerCholesky, FaerEigh};
use crate::solver::structure_search::{
CollapseAction, MoveBudget, MoveProposal, SearchLedger, SearchOutcome, StructureMove, search,
};
use crate::solver::{
AutoTopologyKind, TopologyAutoFitEvidence, TopologyAutoSelector, TopologyScoreScale,
select_topology_with_fit,
};
use crate::terms::latent::{LatentIdMode, LatentManifold};
use crate::terms::sae::atom_codes::SparseAtomCodes;
use crate::terms::sae::basis::{
CylinderHarmonicEvaluator, EuclideanPatchEvaluator, PeriodicHarmonicEvaluator,
SaeBasisSecondJet, SphereChartEvaluator, TorusHarmonicEvaluator,
};
use crate::terms::sae::manifold::{
SaeAtomBasisKind, SaeManifoldAtom, SaeManifoldRho, SaeManifoldTerm,
};
use crate::warm_start::Fingerprinter;
/// Relative floor for treating an assignment mass as active support;
/// `sparse_codes_from_term` divides it by the atom count to get the threshold.
const ACTIVE_SUPPORT_REL_FLOOR: f64 = 0.5;
/// An atom whose largest `log_ard` entry is at or above this value gets a
/// death proposal in `harvest_move_proposals`.
const ARD_DIVERGENCE_LOG_PRECISION: f64 = 12.0;
/// Minimum co-activation dependence for an atom pair to be proposed for fusion.
const FUSION_DEPENDENCE_FLOOR: f64 = 0.6;
/// Minimum absorption asymmetry for an atom pair to yield a fission proposal.
const ABSORPTION_ASYMMETRY_FLOOR: f64 = 0.5;
/// Caps on how many proposals of each capped kind one harvest may emit.
#[derive(Clone, Copy, Debug)]
pub struct HarvestParams {
    /// Maximum number of fusion proposals.
    pub max_fusions: usize,
    /// Maximum number of fission proposals.
    pub max_fissions: usize,
    /// Maximum number of birth proposals.
    pub max_births: usize,
}

impl Default for HarvestParams {
    /// Allows up to four proposals of every capped kind.
    fn default() -> Self {
        let cap = 4;
        HarvestParams {
            max_fusions: cap,
            max_fissions: cap,
            max_births: cap,
        }
    }
}
/// Builds sparse per-row atom codes from the term's dense assignment masses.
///
/// An entry is kept only when its mass is strictly greater than
/// `ACTIVE_SUPPORT_REL_FLOOR / K`, where `K` is the number of assignment
/// columns (the threshold is `0.0` when there are none).
pub fn sparse_codes_from_term(term: &SaeManifoldTerm) -> SparseAtomCodes {
    let masses = term.assignment.assignments();
    let (n_rows, n_atoms) = (masses.nrows(), masses.ncols());
    let threshold = match n_atoms {
        0 => 0.0,
        _ => ACTIVE_SUPPORT_REL_FLOOR / n_atoms as f64,
    };
    let mut codes = SparseAtomCodes::empty(n_rows, n_atoms);
    for row in 0..n_rows {
        for (atom, &mass) in masses.row(row).iter().enumerate() {
            if mass > threshold {
                codes.row_mut(row).assign(atom, mass);
            }
        }
    }
    codes
}
fn per_atom_max_mass(term: &SaeManifoldTerm) -> Array1<f64> {
let assignments = term.assignment.assignments();
let k = assignments.ncols();
let mut out = Array1::<f64>::zeros(k);
for atom in 0..k {
let mut max = 0.0_f64;
for &m in assignments.column(atom).iter() {
if m > max {
max = m;
}
}
out[atom] = max;
}
out
}
/// Returns the largest entry of `rho.log_ard[atom]`.
///
/// Yields negative infinity when the atom has no `log_ard` entry or the
/// entry is empty.
fn per_atom_ard_divergence(rho: &SaeManifoldRho, atom: usize) -> f64 {
    match rho.log_ard.get(atom) {
        Some(axes) => axes
            .iter()
            .copied()
            .reduce(f64::max)
            .unwrap_or(f64::NEG_INFINITY),
        None => f64::NEG_INFINITY,
    }
}
/// Fingerprints a proposed move together with the term's current atom layout.
///
/// The digest covers a fixed domain tag, the move kind and its indices (a
/// fusion pair is written smaller index first, so `(a, b)` and `(b, a)` hash
/// alike), the atom count, and each atom's basis-kind tag and latent
/// dimension. The first eight digest bytes are returned as a little-endian
/// `u64`.
fn post_move_structure_hash(term: &SaeManifoldTerm, mv: &StructureMove) -> u64 {
    // Reduce the move to a label plus one or two indices, in write order.
    let (label, first, second) = match mv {
        StructureMove::Birth { candidate } => ("birth", *candidate, None),
        StructureMove::Death { atom } => ("death", *atom, None),
        StructureMove::Fission { atom } => ("fission", *atom, None),
        StructureMove::Fusion { a, b } => ("fusion", (*a).min(*b), Some((*a).max(*b))),
    };

    let mut fp = Fingerprinter::new();
    fp.write_str("sae_structure_move");
    fp.write_str(label);
    fp.write_usize(first);
    if let Some(extra) = second {
        fp.write_usize(extra);
    }

    fp.write_usize(term.atoms.len());
    for atom in term.atoms.iter() {
        fp.write_str(basis_kind_tag(&atom.basis_kind));
        fp.write_usize(atom.latent_dim);
    }

    let digest = fp.finalize();
    let bytes = digest.as_bytes();
    let mut head = [0u8; 8];
    for (i, slot) in head.iter_mut().enumerate() {
        *slot = bytes[i];
    }
    u64::from_le_bytes(head)
}
/// Maps a basis kind to the stable string tag mixed into structure hashes.
///
/// Every tag is a string literal, so the result is `'static` and is not tied
/// to the borrow of `kind`. A `Precomputed` basis is tagged without looking
/// at its payload.
fn basis_kind_tag(kind: &SaeAtomBasisKind) -> &'static str {
    match kind {
        SaeAtomBasisKind::Duchon => "duchon",
        SaeAtomBasisKind::Periodic => "periodic",
        SaeAtomBasisKind::Sphere => "sphere",
        SaeAtomBasisKind::Torus => "torus",
        SaeAtomBasisKind::Linear => "linear",
        SaeAtomBasisKind::EuclideanPatch => "euclidean_patch",
        SaeAtomBasisKind::Poincare => "poincare",
        SaeAtomBasisKind::Cylinder => "cylinder",
        SaeAtomBasisKind::Precomputed(_) => "precomputed",
    }
}
/// Wraps a move into a `MoveProposal`, attaching its structure hash and the
/// claim it stands for.
///
/// Claim mapping: a birth claims the existence of atom `K + candidate`, a
/// death claims the existence of the named atom, a fusion claims a binding
/// edge between the pair, and a fission uses a custom `fission:<atom>` label.
fn proposal(term: &SaeManifoldTerm, mv: StructureMove, trigger: f64) -> MoveProposal {
    let claim = match &mv {
        StructureMove::Fusion { a, b } => ClaimKind::BindingEdge { a: *a, b: *b },
        StructureMove::Fission { atom } => ClaimKind::Custom {
            label: format!("fission:{atom}"),
        },
        StructureMove::Death { atom } => ClaimKind::AtomExists { atom: *atom },
        StructureMove::Birth { candidate } => {
            let born_index = term.k_atoms() + *candidate;
            ClaimKind::AtomExists { atom: born_index }
        }
    };
    let structure_hash = post_move_structure_hash(term, &mv);
    MoveProposal {
        mv,
        trigger,
        structure_hash,
        claim,
    }
}
/// Harvests candidate structure moves (death, fusion, fission, birth) for the
/// current fit.
///
/// * **Death** — proposed for every atom whose largest `log_ard` entry
///   reaches `ARD_DIVERGENCE_LOG_PRECISION`, or that has a terminal collapse
///   event. The trigger is reduced by a tiny multiple of the atom's peak
///   assignment mass.
/// * **Fusion** — atom pairs whose co-activation dependence reaches
///   `FUSION_DEPENDENCE_FLOOR`, strongest first, at most `params.max_fusions`.
/// * **Fission** — for pairs whose absorption asymmetry reaches
///   `ABSORPTION_ASYMMETRY_FLOOR`, the atom with the larger conditional
///   activation is the parent. Each parent is proposed at most once, ordered
///   by ascending `1 - asymmetry`, at most `params.max_fissions`.
/// * **Birth** — one proposal per residual-factor column, largest column
///   norm first, at most `params.max_births`. If the factor fit fails or the
///   residuals are empty, the reason is recorded in the report.
///
/// # Errors
///
/// The visible body never returns `Err` itself; the `Result` return type is
/// kept for callers.
pub fn harvest_move_proposals(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    residuals: ArrayView2<'_, f64>,
    params: &HarvestParams,
) -> Result<HarvestReport, String> {
    let k = term.k_atoms();
    let mut proposals: Vec<MoveProposal> = Vec::new();

    // --- Death proposals -------------------------------------------------
    let max_mass = per_atom_max_mass(term);
    let terminal: std::collections::HashSet<usize> = term
        .collapse_events()
        .iter()
        .filter(|e| matches!(e.action, CollapseAction::Terminal))
        .map(|e| e.atom)
        .collect();
    for atom in 0..k {
        let ard = per_atom_ard_divergence(rho, atom);
        let diverged = ard >= ARD_DIVERGENCE_LOG_PRECISION;
        let collapsed = terminal.contains(&atom);
        if diverged || collapsed {
            // Collapsed atoms get a huge trigger; the mass term only nudges it.
            let trigger = if collapsed { f64::MAX / 2.0 } else { ard };
            let trigger = trigger - max_mass[atom].min(1.0) * 1e-9;
            proposals.push(proposal(term, StructureMove::Death { atom }, trigger));
        }
    }

    // --- Fusion proposals ------------------------------------------------
    let codes = sparse_codes_from_term(term);
    let mut fusion_pairs: Vec<(usize, usize, f64)> = Vec::new();
    for a in 0..k {
        for b in (a + 1)..k {
            let stats = codes.coactivation(a, b);
            let dep = stats.dependence();
            if dep >= FUSION_DEPENDENCE_FLOOR {
                fusion_pairs.push((a, b, dep));
            }
        }
    }
    // Strongest dependence first; indices break ties deterministically.
    fusion_pairs.sort_by(|x, y| y.2.total_cmp(&x.2).then(x.0.cmp(&y.0)).then(x.1.cmp(&y.1)));
    for &(a, b, dep) in fusion_pairs.iter().take(params.max_fusions) {
        proposals.push(proposal(term, StructureMove::Fusion { a, b }, dep));
    }

    // --- Fission proposals -----------------------------------------------
    let mut fission_atoms: Vec<(usize, f64)> = Vec::new();
    for a in 0..k {
        for b in (a + 1)..k {
            let stats = codes.coactivation(a, b);
            let asym = stats.absorption_asymmetry();
            if asym >= ABSORPTION_ASYMMETRY_FLOOR {
                let parent = if stats.p_a_given_b >= stats.p_b_given_a {
                    a
                } else {
                    b
                };
                let significance = (1.0 - asym).max(0.0);
                fission_atoms.push((parent, significance));
            }
        }
    }
    // `dedup_by_key` only removes *consecutive* duplicates, so group entries by
    // atom first (smallest significance leading each group) to keep exactly one
    // entry per parent, then restore the significance ranking.
    fission_atoms.sort_by(|x, y| x.0.cmp(&y.0).then(x.1.total_cmp(&y.1)));
    fission_atoms.dedup_by_key(|(atom, _)| *atom);
    fission_atoms.sort_by(|x, y| x.1.total_cmp(&y.1).then(x.0.cmp(&y.0)));
    let fission_carve_skipped = !fission_atoms.is_empty();
    for &(atom, significance) in fission_atoms.iter().take(params.max_fissions) {
        proposals.push(proposal(
            term,
            StructureMove::Fission { atom },
            significance,
        ));
    }

    // --- Birth proposals -------------------------------------------------
    let n = residuals.nrows();
    let assignments = term.assignment.assignments();
    let activity: Array1<f64> = (0..n).map(|r| assignments.row(r).sum()).collect();
    let mut births_proposed = 0usize;
    let mut birth_skipped_reason: Option<String> = None;
    if params.max_births > 0 && n > 0 && residuals.ncols() > 0 {
        let p = residuals.ncols();
        let max_rank = params.max_births.min(p.saturating_sub(1));
        match StructuredResidualModel::fit(ResidualFactorInput {
            residuals,
            activity: activity.view(),
            max_factor_rank: max_rank,
        }) {
            Ok(model) => {
                let factor = model.factor();
                let r = model.factor_rank();
                // Rank factor columns by Euclidean norm, largest first.
                let mut dirs: Vec<(usize, f64)> = (0..r)
                    .map(|j| {
                        let mass = factor.column(j).iter().map(|v| v * v).sum::<f64>().sqrt();
                        (j, mass)
                    })
                    .collect();
                dirs.sort_by(|x, y| y.1.total_cmp(&x.1).then(x.0.cmp(&y.0)));
                for &(candidate, mass) in dirs.iter().take(params.max_births) {
                    proposals.push(proposal(term, StructureMove::Birth { candidate }, mass));
                    births_proposed += 1;
                }
            }
            Err(e) => {
                birth_skipped_reason = Some(e);
            }
        }
    } else if params.max_births > 0 {
        birth_skipped_reason =
            Some("residuals empty or single-channel; no factor subspace to mine".to_string());
    }

    Ok(HarvestReport {
        proposals,
        fission_carve_skipped,
        births_proposed,
        birth_skipped_reason,
    })
}
/// Outcome of one `harvest_move_proposals` call.
#[derive(Clone, Debug)]
pub struct HarvestReport {
/// Proposals in emission order: deaths, then fusions, fissions, and births.
pub proposals: Vec<MoveProposal>,
/// Set to `true` by the harvest whenever at least one fission candidate was
/// found. NOTE(review): the exact meaning of "carve skipped" is not visible
/// in this file — confirm against consumers.
pub fission_carve_skipped: bool,
/// Number of birth proposals emitted.
pub births_proposed: usize,
/// Why no births were mined (residual-factor fit error or empty residuals),
/// when that happened.
pub birth_skipped_reason: Option<String>,
}
/// Applies one structure move to a copy of the model and returns the child
/// `(term, rho)` pair; the inputs are left untouched.
///
/// * `Death` demotes the atom's logits; `rho` is cloned unchanged.
/// * `Fusion` folds `b` into `a`; `rho` is cloned unchanged.
/// * `Fission` duplicates the atom, growing both the term and `rho`.
/// * `Birth` seeds a new atom from `birth_decoders[candidate]`.
///
/// # Errors
///
/// Returns an error when a birth candidate index is outside
/// `birth_decoders`, or when the underlying move helper fails.
pub fn apply_structure_move(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    mv: &StructureMove,
    birth_decoders: &[Array2<f64>],
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
    match mv {
        StructureMove::Birth { candidate } => match birth_decoders.get(*candidate) {
            Some(decoder) => born_atom(term, rho, decoder.view()),
            None => Err(format!(
                "apply_structure_move: birth candidate {candidate} out of range ({} residual-factor decoders)",
                birth_decoders.len()
            )),
        },
        StructureMove::Fission { atom } => duplicate_atom(term, rho, *atom),
        StructureMove::Fusion { a, b } => {
            let mut fused = term.clone();
            fold_atom_into(&mut fused, *a, *b)?;
            Ok((fused, rho.clone()))
        }
        StructureMove::Death { atom } => {
            let mut pruned = term.clone();
            demote_atom(&mut pruned, *atom)?;
            Ok((pruned, rho.clone()))
        }
    }
}
/// Logit written into every row of a demoted atom's assignment column.
const DEMOTE_LOGIT: f64 = -40.0;

/// Demotes `atom` by overwriting its whole logit column with `DEMOTE_LOGIT`.
///
/// The atom itself stays in the dictionary; only its assignment logits change.
///
/// # Errors
///
/// Returns an error when `atom` is not a valid atom index.
fn demote_atom(term: &mut SaeManifoldTerm, atom: usize) -> Result<(), String> {
    let k = term.k_atoms();
    if atom < k {
        let logits = &mut term.assignment.logits;
        (0..logits.nrows()).for_each(|row| logits[[row, atom]] = DEMOTE_LOGIT);
        Ok(())
    } else {
        Err(format!("demote_atom: atom {atom} out of range (K={k})"))
    }
}
/// Fuses atom `b` into atom `a`.
///
/// Each row's logit for `a` becomes the larger of the two atoms' logits,
/// after which `b` is demoted.
///
/// # Errors
///
/// Returns an error when either index is out of range or when `a == b`.
fn fold_atom_into(term: &mut SaeManifoldTerm, a: usize, b: usize) -> Result<(), String> {
    let k = term.k_atoms();
    if a.max(b) >= k {
        return Err(format!(
            "fold_atom_into: atoms ({a},{b}) out of range (K={k})"
        ));
    }
    if a == b {
        return Err("fold_atom_into: cannot fuse an atom with itself".to_string());
    }
    let logits = &mut term.assignment.logits;
    for row in 0..logits.nrows() {
        let merged = logits[[row, a]].max(logits[[row, b]]);
        logits[[row, a]] = merged;
    }
    demote_atom(term, b)
}
/// Splits `parent` into two identical atoms (fission).
///
/// The child term gets a clone of the parent atom appended as atom `K`. Both
/// the parent's and the clone's logits are the parent's original logits
/// minus `ln 2`; every other column is copied unchanged. The clone reuses the
/// parent's coordinate block, and `rho.log_ard` gains a copy of the parent's
/// entry (or an empty array when the parent has none).
///
/// # Errors
///
/// Returns an error when `parent` is out of range, or when rebuilding the
/// assignment or the term fails.
fn duplicate_atom(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    parent: usize,
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
    let k = term.k_atoms();
    if parent >= k {
        return Err(format!(
            "duplicate_atom: parent {parent} out of range (K={k})"
        ));
    }

    let mut atoms = term.atoms.clone();
    atoms.push(term.atoms[parent].clone());

    let source = &term.assignment.logits;
    let n = source.nrows();
    let mut logits = Array2::<f64>::zeros((n, k + 1));
    for row in 0..n {
        for col in 0..k {
            logits[[row, col]] = source[[row, col]];
        }
        // Parent and clone share the parent's logit shifted down by ln 2.
        let halved = source[[row, parent]] - std::f64::consts::LN_2;
        logits[[row, parent]] = halved;
        logits[[row, k]] = halved;
    }

    let mut coords = term.assignment.coords.clone();
    coords.push(term.assignment.coords[parent].clone());
    let assignment = crate::terms::sae::manifold::SaeAssignment::with_mode(
        logits,
        coords,
        term.assignment.mode,
    )?;
    let child = SaeManifoldTerm::new(atoms, assignment)?;

    let mut child_rho = rho.clone();
    let inherited = child_rho
        .log_ard
        .get(parent)
        .cloned()
        .unwrap_or_else(|| Array1::<f64>::zeros(0));
    child_rho.log_ard.push(inherited);
    Ok((child, child_rho))
}
/// Ridge added to every diagonal entry of the penalized Hessian in
/// `fit_topology_candidate` before it is factorized.
const TOPOLOGY_FIT_RIDGE: f64 = 1e-6;
/// Fitted state of one topology candidate, as produced by
/// `fit_topology_candidate` and consumed by `born_atom` to build the new atom.
#[derive(Clone)]
struct TopologyRaceFit {
/// Basis evaluator the candidate was fitted with.
evaluator: Arc<dyn SaeBasisSecondJet>,
/// Basis kind recorded on the born atom.
basis_kind: SaeAtomBasisKind,
/// Latent manifold attached to the born atom's coordinate block.
manifold: LatentManifold,
/// Latent dimension of the candidate.
latent_dim: usize,
/// Coordinates the basis was evaluated at.
coords: Array2<f64>,
/// First value returned by `evaluator.evaluate(coords)`.
phi: Array2<f64>,
/// Second value returned by `evaluator.evaluate(coords)`.
jet: ndarray::Array3<f64>,
/// Decoder coefficients solved from the penalized normal equations.
decoder: Array2<f64>,
/// Copy of the second-jet penalty matrix used in the fit.
penalty: Array2<f64>,
}
/// Description of one topology candidate entered into the birth race; built
/// by `topology_candidates_for_dim`.
struct TopologyCandidateSpec {
/// Topology label used to key and rank the candidate.
kind: AutoTopologyKind,
/// Basis kind the born atom would carry.
basis_kind: SaeAtomBasisKind,
/// Latent manifold for the born atom's coordinates.
manifold: LatentManifold,
/// Latent dimension of the candidate.
latent_dim: usize,
/// Evaluator used to compute the basis (and its second jet) at `coords`.
evaluator: Arc<dyn SaeBasisSecondJet>,
/// Seed coordinates with `latent_dim` columns.
coords: Array2<f64>,
}
/// Enumerates the topology candidates raced for a birth of latent dimension
/// `d_k`.
///
/// * `d_k == 0` — no candidates.
/// * `d_k == 1` — circle (periodic harmonics) and a Euclidean patch.
/// * `d_k == 2` — torus, sphere chart, Euclidean patch, and cylinder.
/// * otherwise — a Euclidean patch only.
///
/// Each candidate's coordinates are the first `d` columns of `coords`; when
/// the seed has fewer columns, its last column is repeated.
///
/// # Errors
///
/// Returns an error when `coords` has rows but no columns (there is nothing
/// to seed from), or when a basis evaluator cannot be constructed.
fn topology_candidates_for_dim(
    coords: ArrayView2<'_, f64>,
    d_k: usize,
) -> Result<Vec<TopologyCandidateSpec>, String> {
    let n = coords.nrows();
    let d_seed = coords.ncols();
    if d_k == 0 {
        return Ok(Vec::new());
    }
    // With zero seed columns the copy below would index column 0 of an empty
    // axis and panic; report the degenerate seed instead.
    if d_seed == 0 && n > 0 {
        return Err(
            "topology_candidates_for_dim: seed coordinates have no columns".to_string(),
        );
    }
    let coords_d = |d: usize| -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((n, d));
        for row in 0..n {
            for col in 0..d {
                // Clamp to the last available seed column.
                let src = col.min(d_seed.saturating_sub(1));
                out[[row, col]] = coords[[row, src]];
            }
        }
        out
    };
    let mut specs: Vec<TopologyCandidateSpec> = Vec::new();
    match d_k {
        1 => {
            // Odd harmonic count, at least three.
            let n_harmonics = (2 * d_k + 1).max(3) | 1;
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Circle,
                basis_kind: SaeAtomBasisKind::Periodic,
                manifold: LatentManifold::Circle { period: 1.0 },
                latent_dim: 1,
                evaluator: Arc::new(PeriodicHarmonicEvaluator::new(n_harmonics)?),
                coords: coords_d(1),
            });
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Euclidean,
                basis_kind: SaeAtomBasisKind::EuclideanPatch,
                manifold: LatentManifold::Euclidean,
                latent_dim: 1,
                evaluator: Arc::new(EuclideanPatchEvaluator::new(1, 3)?),
                coords: coords_d(1),
            });
        }
        2 => {
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Torus,
                basis_kind: SaeAtomBasisKind::Torus,
                manifold: LatentManifold::Product(vec![
                    LatentManifold::Circle { period: 1.0 },
                    LatentManifold::Circle { period: 1.0 },
                ]),
                latent_dim: 2,
                evaluator: Arc::new(TorusHarmonicEvaluator::new(2, 2)?),
                coords: coords_d(2),
            });
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Sphere,
                basis_kind: SaeAtomBasisKind::Sphere,
                manifold: LatentManifold::Product(vec![
                    LatentManifold::Interval {
                        lo: -std::f64::consts::FRAC_PI_2,
                        hi: std::f64::consts::FRAC_PI_2,
                    },
                    LatentManifold::Circle {
                        period: std::f64::consts::TAU,
                    },
                ]),
                latent_dim: 2,
                evaluator: Arc::new(SphereChartEvaluator),
                coords: coords_d(2),
            });
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Euclidean,
                basis_kind: SaeAtomBasisKind::EuclideanPatch,
                manifold: LatentManifold::Euclidean,
                latent_dim: 2,
                evaluator: Arc::new(EuclideanPatchEvaluator::new(2, 2)?),
                coords: coords_d(2),
            });
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Cylinder,
                basis_kind: SaeAtomBasisKind::Cylinder,
                manifold: LatentManifold::Product(vec![
                    LatentManifold::Circle { period: 1.0 },
                    LatentManifold::Euclidean,
                ]),
                latent_dim: 2,
                evaluator: Arc::new(CylinderHarmonicEvaluator::new(2, 2)?),
                coords: coords_d(2),
            });
        }
        _ => {
            specs.push(TopologyCandidateSpec {
                kind: AutoTopologyKind::Euclidean,
                basis_kind: SaeAtomBasisKind::EuclideanPatch,
                manifold: LatentManifold::Euclidean,
                latent_dim: d_k,
                evaluator: Arc::new(EuclideanPatchEvaluator::new(d_k, 2)?),
                coords: coords_d(d_k),
            });
        }
    }
    Ok(specs)
}
/// Fits one topology candidate to `target` by penalized weighted least
/// squares and packages the evidence needed to rank it.
///
/// The decoder solves `(G + S + ridge·I) B = Φᵀ W Y`, where `G = Φᵀ W Φ` is
/// the weighted Gram matrix of the basis, `S` is a penalty accumulated from
/// products of second-jet entries, and the ridge is `TOPOLOGY_FIT_RIDGE`.
///
/// The returned evidence carries:
/// * `raw_reml = 0.5 * weighted SSE + 0.5 * log det H`,
/// * the dimension of the penalty's numerical null space and, when it is
///   non-empty, the log-determinant of `H` restricted to it,
/// * `effective_dim = trace(H⁻¹ G)` (replaced by `1.0` when not finite and
///   positive),
/// * a `TopologyRaceFit` handle with the basis, decoder, and penalty.
///
/// # Errors
///
/// Returns an error when the basis row count or the weight length disagrees
/// with `target`, a weight is negative or non-finite, the total weight is not
/// positive and finite, a factorization or eigendecomposition fails, a
/// Hessian eigenvalue is not strictly positive, or `raw_reml` is not finite.
fn fit_topology_candidate(
spec: &TopologyCandidateSpec,
target: ArrayView2<'_, f64>,
weights: ArrayView1<'_, f64>,
) -> Result<TopologyAutoFitEvidence<TopologyRaceFit>, String> {
let n = target.nrows();
let p = target.ncols();
let (phi, jet) = spec.evaluator.evaluate(spec.coords.view())?;
let m = phi.ncols();
if phi.nrows() != n {
return Err(format!(
"fit_topology_candidate: basis rows {} != target rows {n}",
phi.nrows()
));
}
if weights.len() != n {
return Err(format!(
"fit_topology_candidate: weights length {} != target rows {n}",
weights.len()
));
}
// Accumulate the weighted Gram matrix (upper triangle only) and the
// weighted right-hand side; weights are validated row by row.
let mut gram = Array2::<f64>::zeros((m, m)); let mut rhs = Array2::<f64>::zeros((m, p)); let mut w_sum = 0.0_f64;
for row in 0..n {
let w = weights[row];
if !(w.is_finite() && w >= 0.0) {
return Err("fit_topology_candidate: weights must be finite and non-negative".into());
}
w_sum += w;
if w == 0.0 {
continue;
}
for a in 0..m {
let pa = phi[[row, a]];
let wpa = w * pa;
for b in a..m {
gram[[a, b]] += wpa * phi[[row, b]];
}
for out in 0..p {
rhs[[a, out]] += wpa * target[[row, out]];
}
}
}
// Mirror the upper triangle into the lower one.
for a in 0..m {
for b in (a + 1)..m {
gram[[b, a]] = gram[[a, b]];
}
}
if !(w_sum > 0.0 && w_sum.is_finite()) {
return Err("fit_topology_candidate: degenerate (zero-mass) birth target".into());
}
// Penalty: sum over rows and latent-axis pairs of products of second-jet
// entries. The row weights do not enter here.
let second_jet = spec.evaluator.second_jet(spec.coords.view())?; let d = spec.latent_dim;
let mut s_raw = Array2::<f64>::zeros((m, m));
for row in 0..n {
for a in 0..d {
for c in 0..d {
for mu in 0..m {
let hmu = second_jet[[row, mu, a, c]];
if hmu == 0.0 {
continue;
}
for nu in mu..m {
s_raw[[mu, nu]] += hmu * second_jet[[row, nu, a, c]];
}
}
}
}
}
for mu in 0..m {
for nu in (mu + 1)..m {
s_raw[[nu, mu]] = s_raw[[mu, nu]];
}
}
// Penalized Hessian: Gram + penalty + ridge on the diagonal.
let mut h = gram.clone();
for a in 0..m {
for b in 0..m {
h[[a, b]] += s_raw[[a, b]];
}
h[[a, a]] += TOPOLOGY_FIT_RIDGE;
}
let h_chol = h
.cholesky(Side::Lower)
.map_err(|e| format!("fit_topology_candidate: penalized Hessian Cholesky: {e:?}"))?;
let decoder = h_chol.solve_mat(&rhs);
// Weighted sum of squared residuals of the fitted decoder.
let mut sse = 0.0_f64;
for row in 0..n {
let w = weights[row];
if w == 0.0 {
continue;
}
for out in 0..p {
let mut pred = 0.0_f64;
for a in 0..m {
pred += phi[[row, a]] * decoder[[a, out]];
}
let r = target[[row, out]] - pred;
sse += w * r * r;
}
}
// log det H from its eigenvalues; every eigenvalue must be strictly positive.
let (h_evals, _h_evecs) = h
.eigh(Side::Lower)
.map_err(|e| format!("fit_topology_candidate: Hessian eigendecomposition: {e:?}"))?;
let mut log_det_h = 0.0_f64;
for &ev in &h_evals {
if !(ev > 0.0) {
return Err("fit_topology_candidate: penalized Hessian not positive definite".into());
}
log_det_h += ev.ln();
}
let raw_reml = 0.5 * sse + 0.5 * log_det_h;
// Numerical null space of the penalty: eigenvalues at or below a tolerance
// relative to the largest one.
let (s_evals, s_evecs) = s_raw
.eigh(Side::Lower)
.map_err(|e| format!("fit_topology_candidate: penalty eigendecomposition: {e:?}"))?;
let s_max = s_evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
let s_tol = 1e-9 * (1.0 + s_max);
let null_cols: Vec<usize> = s_evals
.iter()
.enumerate()
.filter(|&(_, &v)| v <= s_tol)
.map(|(i, _)| i)
.collect();
let null_dim = null_cols.len();
let null_space_logdet = if null_dim == 0 {
None
} else {
// Project H onto the null-space eigenvectors (Uᵀ H U) and take its log det.
let mut h_null = Array2::<f64>::zeros((null_dim, null_dim));
for (ii, &ci) in null_cols.iter().enumerate() {
let mut hu = Array1::<f64>::zeros(m);
for a in 0..m {
let mut acc = 0.0_f64;
for b in 0..m {
acc += h[[a, b]] * s_evecs[[b, ci]];
}
hu[a] = acc;
}
for (jj, &cj) in null_cols.iter().enumerate() {
let mut acc = 0.0_f64;
for a in 0..m {
acc += s_evecs[[a, cj]] * hu[a];
}
h_null[[ii, jj]] = acc;
}
}
let (hn_evals, _) = h_null
.eigh(Side::Lower)
.map_err(|e| format!("fit_topology_candidate: null-space Hessian eigh: {e:?}"))?;
let mut ld = 0.0_f64;
for &ev in &hn_evals {
if !(ev > 0.0) {
return Err(
"fit_topology_candidate: null-space Hessian not positive definite".into(),
);
}
ld += ev.ln();
}
Some(ld)
};
// Effective dimension: trace of H⁻¹ G, with a fallback of 1.0.
let h_inv_gram = h_chol.solve_mat(&gram); let mut effective_dim = 0.0_f64;
for a in 0..m {
effective_dim += h_inv_gram[[a, a]];
}
if !(effective_dim.is_finite() && effective_dim > 0.0) {
effective_dim = 1.0;
}
if !raw_reml.is_finite() {
return Err("fit_topology_candidate: non-finite raw REML".into());
}
let penalty = s_raw.clone();
Ok(TopologyAutoFitEvidence {
topology_name: spec.kind.as_str().to_string(),
raw_reml,
null_dim: null_dim as f64,
null_space_logdet,
effective_dim,
n_obs: n,
fit_handle: TopologyRaceFit {
evaluator: spec.evaluator.clone(),
basis_kind: spec.basis_kind.clone(),
manifold: spec.manifold.clone(),
latent_dim: spec.latent_dim,
coords: spec.coords.clone(),
phi,
jet,
decoder,
penalty,
},
})
}
/// Races every topology candidate for latent dimension `d_k` against the
/// birth target and returns the winner's fit.
///
/// Returns `Ok(None)` when there are no candidates for `d_k`. When the
/// selector asks for `ConstantCurvature` and no candidate has that kind, the
/// sphere candidate stands in for it, or else the Euclidean one.
///
/// # Errors
///
/// Returns an error when candidate construction or selection fails, when the
/// selector requests a kind with no realized candidate, or when the ranking
/// has no winner.
fn race_birth_topology(
    coords: ArrayView2<'_, f64>,
    target: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
    d_k: usize,
) -> Result<Option<TopologyRaceFit>, String> {
    let specs = topology_candidates_for_dim(coords, d_k)?;
    if specs.is_empty() {
        return Ok(None);
    }

    // Index realized candidates by kind (a later duplicate replaces an earlier one).
    let mut realized: std::collections::HashMap<AutoTopologyKind, &TopologyCandidateSpec> =
        specs.iter().map(|spec| (spec.kind, spec)).collect();
    if !realized.contains_key(&AutoTopologyKind::ConstantCurvature) {
        let stand_in = specs
            .iter()
            .find(|s| s.kind == AutoTopologyKind::Sphere)
            .or_else(|| specs.iter().find(|s| s.kind == AutoTopologyKind::Euclidean));
        if let Some(spec) = stand_in {
            realized.insert(AutoTopologyKind::ConstantCurvature, spec);
        }
    }

    let selector = TopologyAutoSelector {
        candidates: specs.iter().map(|s| s.kind).collect(),
        score_scale: TopologyScoreScale::PerEffectiveDim,
        latent: None,
    };
    let ranked = select_topology_with_fit(&selector, |kind| match realized.get(&kind) {
        Some(spec) => fit_topology_candidate(spec, target, weights),
        None => Err(format!(
            "race_birth_topology: no realized candidate for fused topology {:?}",
            kind.as_str()
        )),
    })?;

    match ranked.winner() {
        Some(winner) => Ok(Some(winner.fit_handle.clone())),
        None => Err("race_birth_topology: empty ranking".to_string()),
    }
}
/// Logit assigned to a newly born atom in every row.
const BIRTH_SEED_LOGIT: f64 = -4.0;
/// Appends a newly born atom, seeded from a residual-factor decoder, to a
/// copy of the model.
///
/// Atom 0 acts as the template: `factor_dir` must have shape
/// `(template basis size, output dim)`, and the birth target is the template's
/// basis values multiplied by `factor_dir`. Topology candidates for the
/// template's latent dimension are raced against that target with unit
/// weights. The winner supplies the new atom's basis, decoder, penalty, and
/// coordinate block; if the race has no candidates, the template is cloned
/// with its decoder replaced by `factor_dir`.
///
/// The child assignment keeps the existing logits and adds a column filled
/// with `BIRTH_SEED_LOGIT`. The child `rho` gains a copy of the first
/// `log_ard` entry (or an empty array when there is none).
///
/// # Errors
///
/// Returns an error when the dictionary is empty, when `factor_dir` has the
/// wrong shape, or when the race, atom construction, assignment, or term
/// construction fails.
fn born_atom(
term: &SaeManifoldTerm,
rho: &SaeManifoldRho,
factor_dir: ArrayView2<'_, f64>,
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
let k = term.k_atoms();
if term.atoms.is_empty() {
return Err(
"born_atom: cannot birth from an empty dictionary (no template atom to seed the \
coordinate block / basis from)"
.to_string(),
);
}
let template = &term.atoms[0];
let m = template.basis_size();
let p = term.output_dim();
if factor_dir.dim() != (m, p) {
return Err(format!(
"born_atom: residual-factor decoder must be ({m}, {p}); got {:?}",
factor_dir.dim()
));
}
let mut atoms = term.atoms.clone();
let template_coords = term.assignment.coords[0].as_matrix();
// Birth target in output space, raced with unit row weights.
let birth_target = template.basis_values.dot(&factor_dir); let weights = Array1::<f64>::ones(birth_target.nrows());
let raced = race_birth_topology(
template_coords.view(),
birth_target.view(),
weights.view(),
template.latent_dim,
)?;
let (born, born_coord_block) = match raced {
// A topology won the race: build the atom from the winning fit.
Some(fit) => {
let mut atom = SaeManifoldAtom::new(
format!("atom_born_{k}"),
fit.basis_kind.clone(),
fit.latent_dim,
fit.phi.clone(),
fit.jet.clone(),
fit.decoder.clone(),
fit.penalty.clone(),
)?
.with_basis_second_jet(fit.evaluator.clone());
atom.refresh_intrinsic_smooth_penalty();
let coord_block = crate::terms::latent::LatentCoordValues::from_matrix_with_manifold(
fit.coords.view(),
LatentIdMode::None,
fit.manifold.clone(),
);
(atom, coord_block)
}
// No candidates: reuse the template with the residual-factor decoder.
None => {
let mut atom = template.clone();
atom.decoder_coefficients = factor_dir.to_owned();
atom.refresh_intrinsic_smooth_penalty();
(atom, term.assignment.coords[0].clone())
}
};
atoms.push(born);
// Copy the existing logits and seed the new atom's column.
let n = term.assignment.logits.nrows();
let mut logits = Array2::<f64>::zeros((n, k + 1));
for row in 0..n {
for col in 0..k {
logits[[row, col]] = term.assignment.logits[[row, col]];
}
logits[[row, k]] = BIRTH_SEED_LOGIT;
}
let mut coords = term.assignment.coords.clone();
coords.push(born_coord_block);
let assignment = crate::terms::sae::manifold::SaeAssignment::with_mode(
logits,
coords,
term.assignment.mode,
)?;
let child = SaeManifoldTerm::new(atoms, assignment)?;
// The new atom inherits the first atom's log-ARD entry, if any.
let mut child_rho = rho.clone();
let inherited = child_rho
.log_ard
.first()
.cloned()
.unwrap_or_else(|| Array1::<f64>::zeros(0));
child_rho.log_ard.push(inherited);
Ok((child, child_rho))
}
/// One evaluation shard: a set of row indices into a shared target matrix.
#[derive(Clone, Debug)]
pub struct RowBlockShard {
/// The full target matrix, shared by all shards of a split.
pub target: std::sync::Arc<Array2<f64>>,
/// Row indices of `target` that belong to this shard.
pub rows: Vec<usize>,
}
/// Partition of the target rows into an estimation set and evaluation shards.
#[derive(Clone, Debug)]
pub struct EstimationEvalSplit {
/// Row indices used for (re)fitting candidates.
pub estimation_rows: Vec<usize>,
/// Evaluation shards built from the rows outside the estimation set.
pub shards: Vec<RowBlockShard>,
}
/// Fraction of rows (taken from the front) assigned to the estimation set.
const ESTIMATION_FRACTION: f64 = 0.6;

/// Splits the target rows into a leading estimation block and trailing
/// evaluation shards.
///
/// The estimation block holds `round(n * ESTIMATION_FRACTION)` rows, clamped
/// to at least one row and, when `n > 1`, at most `n - 1`. The remaining rows
/// are cut into contiguous shards whose sizes differ by at most one. The
/// shard count is `n_shards` capped at the number of evaluation rows, and at
/// least one whenever any evaluation rows exist. Every shard shares one owned
/// copy of `target`.
///
/// An empty `target` yields an empty split.
pub fn estimation_eval_split(target: ArrayView2<'_, f64>, n_shards: usize) -> EstimationEvalSplit {
    let n = target.nrows();
    if n == 0 {
        return EstimationEvalSplit {
            estimation_rows: Vec::new(),
            shards: Vec::new(),
        };
    }
    let shared = std::sync::Arc::new(target.to_owned());
    let n_est =
        ((n as f64 * ESTIMATION_FRACTION).round() as usize).clamp(1, n.saturating_sub(1).max(1));
    let n_eval = n - n_est;
    let shard_count = n_shards.min(n_eval).max(usize::from(n_eval > 0));

    let mut shards = Vec::new();
    if n_eval > 0 {
        let base = n_eval / shard_count;
        let extra = n_eval % shard_count;
        // Evaluation rows start right after the estimation block.
        let mut start = n_est;
        for idx in 0..shard_count {
            let len = if idx < extra { base + 1 } else { base };
            shards.push(RowBlockShard {
                target: shared.clone(),
                rows: (start..start + len).collect(),
            });
            start += len;
        }
    }

    EstimationEvalSplit {
        estimation_rows: (0..n_est).collect(),
        shards,
    }
}
/// Final state of a multi-round structure search.
pub struct StructureSearchResult {
/// Term after the last round.
pub term: SaeManifoldTerm,
/// Rho after the last round.
pub rho: SaeManifoldRho,
/// One ledger per executed round, in order.
pub rounds: Vec<SearchLedger>,
}
impl StructureSearchResult {
    /// Reports whether any round recorded a move with an `Accepted` or
    /// `Demoted` verdict.
    #[must_use]
    pub fn structure_changed(&self) -> bool {
        use crate::solver::structure_search::MoveVerdict;
        self.rounds
            .iter()
            .flat_map(|round| round.moves.iter())
            .any(|record| {
                matches!(
                    record.verdict,
                    MoveVerdict::Accepted { .. } | MoveVerdict::Demoted { .. }
                )
            })
    }
}
/// Configuration for `run_structure_search_rounds`.
#[derive(Clone, Copy, Debug)]
pub struct RoundDriverConfig {
/// Requested number of evaluation shards (passed to `estimation_eval_split`).
pub n_shards: usize,
/// Move budget handed to the search; its `alpha` is also recorded on the
/// ledger of a round that ends without searching.
pub budget: MoveBudget,
/// Upper bound on the number of search rounds.
pub max_rounds: usize,
/// Proposal caps used when harvesting moves each round.
pub harvest_params: HarvestParams,
}
/// Runs up to `config.max_rounds` rounds of structure search.
///
/// The rows are split once into an estimation set and evaluation shards.
/// Each round:
/// 1. computes residuals as `target - fitted`,
/// 2. harvests move proposals and builds the birth decoders,
/// 3. stops (recording an empty ledger) when there are no proposals or no
///    evaluation shards,
/// 4. otherwise runs `search`, refitting each candidate with `candidate_fit`
///    on the estimation rows and scoring states with `eval_log_lik` on the
///    shards,
/// 5. if any move was `Accepted` or `Demoted`, polishes the adopted state
///    with `finalize_round` and continues; otherwise stops.
///
/// Each round's ledger carries the collapse events of the term as it was
/// before that round's search.
///
/// # Errors
///
/// Returns an error when the fitted values cannot be computed, or when
/// harvesting, decoder construction, or the search itself fails.
pub fn run_structure_search_rounds(
mut term: SaeManifoldTerm,
mut rho: SaeManifoldRho,
target: ArrayView2<'_, f64>,
config: RoundDriverConfig,
ledger: &mut StructureLedger,
mut candidate_fit: impl FnMut(
SaeManifoldTerm,
SaeManifoldRho,
&[usize],
) -> (SaeManifoldTerm, SaeManifoldRho),
mut finalize_round: impl FnMut(
SaeManifoldTerm,
SaeManifoldRho,
&[usize],
) -> (SaeManifoldTerm, SaeManifoldRho),
) -> Result<StructureSearchResult, String> {
let RoundDriverConfig {
n_shards,
budget,
max_rounds,
harvest_params,
} = config;
let split = estimation_eval_split(target, n_shards);
let mut rounds: Vec<SearchLedger> = Vec::new();
for _ in 0..max_rounds {
// Residuals of the current fit drive both the harvest and the birth decoders.
let fitted = term.try_fitted()?;
let residuals = &target.to_owned() - &fitted;
let report = harvest_move_proposals(&term, &rho, residuals.view(), &harvest_params)?;
let birth_decoders = build_birth_decoders(&term, residuals.view(), &harvest_params)?;
// Nothing to try, or nothing to score on: record an empty round and stop.
if report.proposals.is_empty() || split.shards.is_empty() {
rounds.push(SearchLedger {
alpha: budget.alpha,
moves: Vec::new(),
collapse_events: term.collapse_events().to_vec(),
});
break;
}
type State = (SaeManifoldTerm, SaeManifoldRho);
// Captured before `term` is moved into the search.
let collapse_events = term.collapse_events().to_vec();
let decoders = birth_decoders;
let estimation_rows = split.estimation_rows.clone();
let outcome: SearchOutcome<State> = search(
(term, rho),
report.proposals,
&split.shards,
&budget,
ledger,
// Apply the move, then refit the candidate on the estimation rows.
|state: &State, mv: &StructureMove| {
let (cand_term, cand_rho) =
apply_structure_move(&state.0, &state.1, mv, &decoders)?;
Ok(candidate_fit(cand_term, cand_rho, &estimation_rows))
},
|state: &State, shard: &RowBlockShard| eval_log_lik(&state.0, shard),
|state: &State, shard: &RowBlockShard| eval_log_lik(&state.0, shard),
|state: State, _: &RowBlockShard| state,
)?;
let (next_term, next_rho) = outcome.state;
let mut round_ledger = outcome.ledger;
round_ledger.collapse_events = collapse_events;
// A round "applied" something when any move was accepted or demoted.
let applied = round_ledger.moves.iter().any(|m| {
matches!(
m.verdict,
crate::solver::structure_search::MoveVerdict::Accepted { .. }
| crate::solver::structure_search::MoveVerdict::Demoted { .. }
)
});
rounds.push(round_ledger);
if applied {
let (polished_term, polished_rho) =
finalize_round(next_term, next_rho, &split.estimation_rows);
term = polished_term;
rho = polished_rho;
} else {
term = next_term;
rho = next_rho;
break;
}
}
Ok(StructureSearchResult { term, rho, rounds })
}
/// Builds one seed decoder per residual-factor column for birth moves.
///
/// Each decoder has shape `(basis size of atom 0, output dim)`. Its first row
/// holds the corresponding factor column and all other rows are zero.
///
/// Returns an empty list (no births possible) when births are disabled, the
/// residuals are empty, the dictionary has no template atom, the template
/// basis is empty, or the residual-factor fit fails. The fit failure is
/// deliberately not propagated here; `harvest_move_proposals` reports it.
fn build_birth_decoders(
    term: &SaeManifoldTerm,
    residuals: ArrayView2<'_, f64>,
    params: &HarvestParams,
) -> Result<Vec<Array2<f64>>, String> {
    let n = residuals.nrows();
    let p = residuals.ncols();
    if params.max_births == 0 || n == 0 || p == 0 {
        return Ok(Vec::new());
    }
    // Without a template atom, or with an empty basis, there is no decoder
    // shape to fill; indexing would panic, so report "no decoders" instead.
    let m = match term.atoms.first() {
        Some(template) => template.basis_size(),
        None => return Ok(Vec::new()),
    };
    if m == 0 {
        return Ok(Vec::new());
    }
    let assignments = term.assignment.assignments();
    let activity: Array1<f64> = (0..n).map(|r| assignments.row(r).sum()).collect();
    let max_rank = params.max_births.min(p.saturating_sub(1));
    let model = match StructuredResidualModel::fit(ResidualFactorInput {
        residuals,
        activity: activity.view(),
        max_factor_rank: max_rank,
    }) {
        Ok(m) => m,
        Err(_) => return Ok(Vec::new()),
    };
    let factor = model.factor();
    let r = factor.ncols();
    let mut decoders = Vec::with_capacity(r);
    for j in 0..r {
        let mut decoder = Array2::<f64>::zeros((m, p));
        for out in 0..p {
            decoder[[0, out]] = factor[[out, j]];
        }
        decoders.push(decoder);
    }
    Ok(decoders)
}
/// Scores a term on one evaluation shard.
///
/// The score is `-0.5 * SSE` over the shard's in-range rows plus the gate
/// evidence from `gate_block_log_evidence`. Negative infinity is returned
/// when the fitted values are unavailable, their shape disagrees with the
/// shard target, or no entries were scored.
fn eval_log_lik(term: &SaeManifoldTerm, shard: &RowBlockShard) -> f64 {
    let predicted = match term.try_fitted() {
        Err(_) => return f64::NEG_INFINITY,
        Ok(values) => values,
    };
    let (n_full, p) = (predicted.nrows(), predicted.ncols());
    let shapes_match = p == shard.target.ncols() && n_full == shard.target.nrows();
    if !shapes_match {
        return f64::NEG_INFINITY;
    }

    let mut sse = 0.0_f64;
    let mut scored = 0usize;
    for row in shard.rows.iter().copied().filter(|&r| r < n_full) {
        for out in 0..p {
            sse_accumulate(&mut sse, predicted[[row, out]] - shard.target[[row, out]]);
        }
        scored += p;
    }
    if scored == 0 {
        return f64::NEG_INFINITY;
    }

    let reconstruction = -0.5 * sse;
    reconstruction + gate_block_log_evidence(term, shard)
}
/// Sums the per-atom gate evidence over a shard's in-range rows.
///
/// For each atom a gate block is built with an all-ones single-column design,
/// unit `b`, a 1×1 identity penalty, the atom's logits as `psi_hat`, and
/// labels of `1.0` where the logit is positive and `0.0` otherwise. The
/// result is the negated sum of the blocks' `neg_log_evidence`.
///
/// Returns `0.0` when there are no atoms or no in-range rows, on the first
/// non-finite logit, or on the first evidence failure.
fn gate_block_log_evidence(term: &SaeManifoldTerm, shard: &RowBlockShard) -> f64 {
    use crate::inference::pg_gate_evidence::{GateBlock, pg_gate_evidence};
    let logits = &term.assignment.logits;
    let (n_full, n_atoms) = (logits.nrows(), logits.ncols());
    if n_atoms == 0 {
        return 0.0;
    }
    let rows: Vec<usize> = shard.rows.iter().copied().filter(|&r| r < n_full).collect();
    if rows.is_empty() {
        return 0.0;
    }
    let m = rows.len();

    // Shared across atoms: intercept-only design, unit counts, unit penalty.
    let design = Array2::<f64>::ones((m, 1));
    let b = Array1::<f64>::ones(m);
    let penalty = Array2::<f64>::eye(1);

    let mut total = 0.0_f64;
    for atom in 0..n_atoms {
        let mut psi = Array1::<f64>::zeros(m);
        let mut y = Array1::<f64>::zeros(m);
        for ((psi_slot, y_slot), &row) in psi.iter_mut().zip(y.iter_mut()).zip(rows.iter()) {
            let logit = logits[[row, atom]];
            if !logit.is_finite() {
                return 0.0;
            }
            *psi_slot = logit;
            *y_slot = if logit > 0.0 { 1.0 } else { 0.0 };
        }
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: Some(psi.view()),
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: None,
        };
        let evidence = match pg_gate_evidence(&block) {
            Ok(ev) => ev,
            Err(_) => return 0.0,
        };
        total -= evidence.neg_log_evidence;
    }
    total
}
/// Adds the square of the residual `d` to the running sum `sse`.
#[inline]
fn sse_accumulate(sse: &mut f64, d: f64) {
    let squared = d * d;
    *sse = *sse + squared;
}
/// Refit settings used by `run_production_structure_search`.
#[derive(Clone, Copy, Debug)]
pub struct ProductionRefitParams {
/// Inner-iteration cap for the polish refit after a round adopts a move.
pub inner_max_iter: usize,
/// Inner-iteration cap for the refit that scores each candidate.
pub scoring_inner_max_iter: usize,
/// Learning rate forwarded to the joint fit.
pub learning_rate: f64,
/// Forwarded to the joint fit as its `ridge_ext_coord` argument.
pub ridge_ext_coord: f64,
/// Forwarded to the joint fit as its `ridge_beta` argument.
pub ridge_beta: f64,
}
/// Runs the structure search with the production refit wired in as both the
/// candidate-scoring fit and the per-round polish.
///
/// Each refit weights estimation rows at `1.0` and all other rows at `1e-12`,
/// then runs the joint fit on the full target: `scoring_inner_max_iter`
/// iterations when scoring a candidate, `inner_max_iter` when polishing an
/// adopted state. The refit is best-effort: if setting the weights or the fit
/// fails, the state is returned as it stands at that point.
///
/// # Errors
///
/// Propagates any error from `run_structure_search_rounds`.
pub fn run_production_structure_search(
    term: SaeManifoldTerm,
    rho: SaeManifoldRho,
    target: ArrayView2<'_, f64>,
    config: RoundDriverConfig,
    refit_params: ProductionRefitParams,
    ledger: &mut StructureLedger,
) -> Result<StructureSearchResult, String> {
    let n = target.nrows();
    let refit_at = |full_target: ArrayView2<'_, f64>,
                    mut cand_term: SaeManifoldTerm,
                    mut cand_rho: SaeManifoldRho,
                    estimation_rows: &[usize],
                    inner_max_iter: usize|
     -> (SaeManifoldTerm, SaeManifoldRho) {
        const HELD_OUT_WEIGHT: f64 = 1e-12;
        // Held-out rows keep a vanishing weight; estimation rows count fully.
        let mut weights = vec![HELD_OUT_WEIGHT; n];
        for &r in estimation_rows.iter().filter(|&&r| r < n) {
            weights[r] = 1.0;
        }
        if cand_term.set_row_loss_weights(weights).is_ok() {
            // Best effort: a failed fit leaves the candidate as the fit left it.
            let _ = cand_term.run_joint_fit_arrow_schur(
                full_target,
                &mut cand_rho,
                None,
                inner_max_iter,
                refit_params.learning_rate,
                refit_params.ridge_ext_coord,
                refit_params.ridge_beta,
            );
        }
        (cand_term, cand_rho)
    };

    let scoring_iters = refit_params.scoring_inner_max_iter;
    let polish_iters = refit_params.inner_max_iter;
    // Each closure below owns its own copy of the target.
    let score_target = target.to_owned();
    let polish_target = target.to_owned();

    let score_candidate =
        move |cand_term: SaeManifoldTerm, cand_rho: SaeManifoldRho, estimation_rows: &[usize]| {
            refit_at(
                score_target.view(),
                cand_term,
                cand_rho,
                estimation_rows,
                scoring_iters,
            )
        };
    let polish_adopted =
        move |adopted_term: SaeManifoldTerm, adopted_rho: SaeManifoldRho, estimation_rows: &[usize]| {
            refit_at(
                polish_target.view(),
                adopted_term,
                adopted_rho,
                estimation_rows,
                polish_iters,
            )
        };

    run_structure_search_rounds(
        term,
        rho,
        target,
        config,
        ledger,
        score_candidate,
        polish_adopted,
    )
}
/// Serializes the per-round search ledgers to a JSON string.
///
/// # Errors
///
/// Returns a descriptive message when serialization fails.
pub fn rounds_to_json(rounds: &[SearchLedger]) -> Result<String, String> {
    match serde_json::to_string(rounds) {
        Ok(json) => Ok(json),
        Err(e) => Err(format!("rounds_to_json: serialize search ledger: {e}")),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::structure_search::{CollapseAction, CollapseEvent};
use crate::terms::latent::LatentManifold;
use crate::terms::sae::manifold::{
AssignmentMode, PeriodicHarmonicEvaluator, SaeAssignment, SaeAtomBasisKind,
SaeBasisEvaluator, SaeManifoldAtom,
};
use ndarray::Array2;
use std::sync::Arc;
const ON: f64 = 6.0;
const OFF: f64 = -6.0;
/// Builds a small synthetic `SaeManifoldTerm` and matching `SaeManifoldRho` for tests.
///
/// `active[row][atom]` picks the routing logit of that cell (`ON` when true, `OFF`
/// otherwise). All atoms share one `PeriodicHarmonicEvaluator::new(3)` basis evaluated
/// on the coordinate grid `row / n`; atom `i` decodes basis rows 1 and 2 into output
/// columns `i % 4` and `(i + 1) % 4`. Every atom lives on a `Circle` of period 1 and
/// routing uses `AssignmentMode::softmax(1.0)`.
///
/// Panics if `active` is empty, because the atom count is read from its first row.
fn planted_term(active: &[Vec<bool>]) -> (SaeManifoldTerm, SaeManifoldRho) {
    let n_rows = active.len();
    let n_atoms = active[0].len();
    let out_dim = 4usize;

    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords =
        Array2::<f64>::from_shape_fn((n_rows, 1), |(row, _)| row as f64 / n_rows as f64);
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();

    // One atom per column of `active`, all sharing the same evaluated basis.
    let atoms: Vec<_> = (0..n_atoms)
        .map(|atom_idx| {
            let mut decoder = Array2::<f64>::zeros((3, out_dim));
            decoder[[1, atom_idx % out_dim]] = 1.0;
            decoder[[2, (atom_idx + 1) % out_dim]] = 1.0;
            SaeManifoldAtom::new(
                format!("atom_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi.clone(),
                jet.clone(),
                decoder,
                Array2::<f64>::eye(3),
            )
            .unwrap()
            .with_basis_second_jet(evaluator.clone())
        })
        .collect();
    let coord_blocks = vec![coords.clone(); n_atoms];

    // Translate the boolean activity pattern into saturated routing logits.
    let mut logits = Array2::<f64>::zeros((n_rows, n_atoms));
    for (row, flags) in active.iter().enumerate() {
        for (atom, &is_on) in flags.iter().enumerate() {
            logits[[row, atom]] = if is_on { ON } else { OFF };
        }
    }

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        vec![LatentManifold::Circle { period: 1.0 }; n_atoms],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();

    (
        SaeManifoldTerm::new(atoms, assignment).unwrap(),
        SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); n_atoms]),
    )
}
/// Returns the element-wise negation of `term.try_fitted()`.
///
/// Presumably this is the residual against an all-zero target — confirm the sign
/// convention against `harvest_move_proposals`. Panics if `try_fitted` fails.
fn residuals_of(term: &SaeManifoldTerm) -> Array2<f64> {
    term.try_fitted().map(|fitted| -&fitted).unwrap()
}
/// `structure_changed()` must be false for no rounds and for rounds holding only
/// `Contested`/`Vetoed` verdicts, and true once a round holds an `Accepted` or
/// `Demoted` verdict.
#[test]
fn structure_changed_is_true_only_when_a_move_lands() {
    use crate::solver::structure_search::{MoveRecord, MoveVerdict};

    // Builds a ledger whose i-th record is a Death move on atom i with the i-th verdict.
    fn ledger_with(verdicts: Vec<MoveVerdict>) -> SearchLedger {
        let moves = verdicts
            .into_iter()
            .enumerate()
            .map(|(idx, verdict)| MoveRecord {
                mv: StructureMove::Death { atom: idx },
                trigger: 0.0,
                structure_hash: idx as u64,
                claim: ClaimKind::AtomExists { atom: idx },
                verdict,
            })
            .collect();
        SearchLedger {
            alpha: 0.05,
            moves,
            collapse_events: Vec::new(),
        }
    }

    let (term0, rho0) = planted_term(&[vec![true], vec![true]]);
    // Every case reuses the same term/rho; only the round ledgers differ.
    let result_with = |rounds: Vec<SearchLedger>| StructureSearchResult {
        term: term0.clone(),
        rho: rho0.clone(),
        rounds,
    };

    let empty = result_with(Vec::new());
    assert!(
        !empty.structure_changed(),
        "no rounds ⇒ the term/rho are the pre-search fit ⇒ structure_changed() must be false"
    );

    let no_landed = result_with(vec![ledger_with(vec![
        MoveVerdict::Contested { log_e: -1.0 },
        MoveVerdict::Vetoed { log_e: -2.0 },
    ])]);
    assert!(
        !no_landed.structure_changed(),
        "all-contested/vetoed rounds leave the model unchanged ⇒ structure_changed() must be false"
    );

    let accepted = result_with(vec![ledger_with(vec![
        MoveVerdict::Contested { log_e: -1.0 },
        MoveVerdict::Accepted { log_e: 3.0 },
    ])]);
    assert!(
        accepted.structure_changed(),
        "a landed Accepted move mutates term/rho ⇒ structure_changed() must be true (recompute bands)"
    );

    let demoted = result_with(vec![ledger_with(vec![MoveVerdict::Demoted { log_e: -1.0 }])]);
    assert!(
        demoted.structure_changed(),
        "a landed Demoted death folds an atom to ~0 routing ⇒ structure_changed() must be true"
    );
}
/// A single-atom fit that leaves a structured residual must harvest at least one
/// birth proposal when only the birth channel is enabled.
#[test]
fn residual_bearing_fit_harvests_birth_proposal() {
    let n = 40usize;
    let active: Vec<Vec<bool>> = (0..n).map(|_| vec![true]).collect();
    let (term, rho) = planted_term(&active);
    let p = term.output_dim();

    // Rank-one residual: fixed direction `u`, amplitude ramping from 1 towards 2.
    let mut residuals = Array2::<f64>::zeros((n, p));
    let u = [0.6_f64, -0.4, 0.5, -0.3];
    for row in 0..n {
        let amp = 1.0 + (row as f64) / (n as f64);
        for c in 0..p {
            residuals[[row, c]] = amp * u[c % u.len()];
        }
    }

    // Fusion and fission are disabled so only births can be proposed.
    let params = HarvestParams {
        max_fusions: 0,
        max_fissions: 0,
        max_births: 2,
    };
    let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();

    let births: usize = report
        .proposals
        .iter()
        .filter(|proposal| matches!(proposal.mv, StructureMove::Birth { .. }))
        .count();
    assert!(
        births >= 1,
        "a residual-bearing fit with births enabled must harvest at least \
         one birth proposal (so K can be discovered); got {:?}",
        report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
    );
    assert!(
        report.births_proposed >= 1,
        "births_proposed must count the harvested births; got {}",
        report.births_proposed
    );
    assert!(
        report.birth_skipped_reason.is_none(),
        "the birth channel must run (no skip) on a non-degenerate residual; got {:?}",
        report.birth_skipped_reason
    );
}
/// With an all-zero residual there is nothing left to explain, so the harvest must
/// not emit any birth proposal even though births are enabled.
#[test]
fn fully_reconstructed_null_harvests_no_birth() {
    let n = 40usize;
    let active: Vec<Vec<bool>> = (0..n).map(|_| vec![true]).collect();
    let (term, rho) = planted_term(&active);
    let p = term.output_dim();
    let zero_residual = Array2::<f64>::zeros((n, p));

    // Fusion and fission are disabled so only births could be proposed.
    let params = HarvestParams {
        max_fusions: 0,
        max_fissions: 0,
        max_births: 2,
    };
    let report = harvest_move_proposals(&term, &rho, zero_residual.view(), &params).unwrap();

    let births: usize = report
        .proposals
        .iter()
        .filter(|proposal| matches!(proposal.mv, StructureMove::Birth { .. }))
        .count();
    assert_eq!(
        births, 0,
        "a fully-reconstructed (zero-residual) null must harvest no birth \
         proposal; got {births} births"
    );
}
/// Atoms 0 and 1 are planted with identical activity (every third row), so the
/// harvest must propose fusing that pair and must not propose any fission.
#[test]
fn planted_shatter_harvests_fusion_not_fission() {
    let n = 30usize;
    let active: Vec<Vec<bool>> = (0..n)
        .map(|row| {
            // Atoms 0 and 1 duplicate each other; atom 2 fires on even rows.
            let dup = row % 3 == 0;
            vec![dup, dup, row % 2 == 0]
        })
        .collect();
    let (term, rho) = planted_term(&active);
    let residuals = residuals_of(&term);

    // Births are disabled; fusion and fission channels are both open.
    let params = HarvestParams {
        max_fusions: 4,
        max_fissions: 4,
        max_births: 0,
    };
    let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();

    // The pair may be reported in either order.
    let has_fusion_01 = report.proposals.iter().any(|p| {
        matches!(p.mv, StructureMove::Fusion { a, b } if (a, b) == (0, 1) || (a, b) == (1, 0))
    });
    assert!(
        has_fusion_01,
        "shattered duplicate pair (0,1) must yield a fusion proposal; got {:?}",
        report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
    );

    let has_fission = report
        .proposals
        .iter()
        .any(|p| matches!(p.mv, StructureMove::Fission { .. }));
    assert!(
        !has_fission,
        "symmetric duplicate supports must not trigger an absorption fission audit"
    );
}
/// Atom 1's support (rows ≡ 0 mod 4) is nested inside atom 0's support, so the
/// harvest must flag atom 0 for a fission audit and record that the carve was skipped.
#[test]
fn planted_absorption_harvests_fission_audit_with_loud_carve_skip() {
    let n = 40usize;
    let active: Vec<Vec<bool>> = (0..n)
        .map(|row| {
            // Every child row (multiple of 4) is even, hence also a parent row.
            let child = row % 4 == 0;
            let parent = row % 2 == 0 || row % 4 == 1;
            vec![parent, child, row % 5 == 0]
        })
        .collect();
    let (term, rho) = planted_term(&active);
    let residuals = residuals_of(&term);

    // Births are disabled; fusion and fission channels are both open.
    let params = HarvestParams {
        max_fusions: 4,
        max_fissions: 4,
        max_births: 0,
    };
    let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();

    let fissioned_parent = report
        .proposals
        .iter()
        .any(|p| matches!(p.mv, StructureMove::Fission { atom: 0 }));
    assert!(
        fissioned_parent,
        "nested-support parent (atom 0) must be flagged for a fission audit; got {:?}",
        report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
    );
    assert!(
        report.fission_carve_skipped,
        "the #993 within-atom carve is unwired; the skip must be recorded, not silent"
    );
}
/// Three atoms firing on rows divisible by 2, 3 and 5 respectively must not
/// produce any fusion proposal.
#[test]
fn independent_atoms_harvest_no_fusion() {
    let n = 60usize;
    let active: Vec<Vec<bool>> = (0..n)
        .map(|row| vec![row % 2 == 0, row % 3 == 0, row % 5 == 0])
        .collect();
    let (term, rho) = planted_term(&active);
    let residuals = residuals_of(&term);

    // Births are disabled; fusion and fission channels are both open.
    let params = HarvestParams {
        max_fusions: 4,
        max_fissions: 4,
        max_births: 0,
    };
    let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();

    let has_fusion = report
        .proposals
        .iter()
        .any(|p| matches!(p.mv, StructureMove::Fusion { .. }));
    assert!(
        !has_fusion,
        "independent atom supports must not produce fusion proposals; got {:?}",
        report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
    );
}
/// Death proposals must be harvested for an atom whose log-ARD exceeds
/// `ARD_DIVERGENCE_LOG_PRECISION` and for an atom with a recorded terminal collapse,
/// even when fusion, fission and birth budgets are all zero.
#[test]
fn diverged_ard_and_terminal_collapse_harvest_deaths() {
    let n = 20usize;
    let active: Vec<Vec<bool>> = (0..n).map(|row| vec![true, row % 2 == 0, false]).collect();
    let (mut term, mut rho) = planted_term(&active);

    // Push atom 2's log-ARD five units past the divergence constant.
    rho.log_ard[2] = Array1::from_elem(1, ARD_DIVERGENCE_LOG_PRECISION + 5.0);
    // Record a terminal collapse on atom 1.
    term.record_collapse_event(CollapseEvent {
        iteration: 3,
        atom: 1,
        max_active_mass: 1e-6,
        floor: 1e-3,
        action: CollapseAction::Terminal,
    });

    let residuals = residuals_of(&term);
    let params = HarvestParams {
        max_fusions: 0,
        max_fissions: 0,
        max_births: 0,
    };
    let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();

    let death_atoms: Vec<usize> = report
        .proposals
        .iter()
        .filter_map(|p| match p.mv {
            StructureMove::Death { atom } => Some(atom),
            _ => None,
        })
        .collect();
    assert!(
        death_atoms.contains(&2),
        "diverged ARD on atom 2 must yield a death proposal; got {death_atoms:?}"
    );
    assert!(
        death_atoms.contains(&1),
        "terminal collapse on atom 1 must yield a death proposal; got {death_atoms:?}"
    );
}
/// Checks the shape effect of each `StructureMove` applied to a two-atom term:
/// fission and birth grow the atom count by one, while fusion and death keep the
/// count and drive the removed atom's assignment column to ~0.
#[test]
fn apply_move_restructures_warm() {
    let n = 12usize;
    let active: Vec<Vec<bool>> = (0..n).map(|row| vec![true, row % 2 == 0]).collect();
    let (term, rho) = planted_term(&active);
    let base_k = term.k_atoms();

    // Fission adds an atom and a matching ARD block.
    let fission = StructureMove::Fission { atom: 0 };
    let (split_term, split_rho) = apply_structure_move(&term, &rho, &fission, &[]).unwrap();
    assert_eq!(split_term.k_atoms(), base_k + 1);
    assert_eq!(split_rho.log_ard.len(), base_k + 1);

    // Fusion keeps K; atom 1's routing mass must vanish.
    let fusion = StructureMove::Fusion { a: 0, b: 1 };
    let (merged, _) = apply_structure_move(&term, &rho, &fusion, &[]).unwrap();
    assert_eq!(merged.k_atoms(), base_k);
    let merged_mass = merged.assignment.assignments();
    assert!(
        merged_mass.column(1).iter().all(|&m| m < 1e-6),
        "fused-away atom 1 must route to ~0 mass"
    );

    // Death keeps K; the dead atom's routing mass must vanish.
    let death = StructureMove::Death { atom: 1 };
    let (killed, _) = apply_structure_move(&term, &rho, &death, &[]).unwrap();
    assert_eq!(killed.k_atoms(), base_k);
    let killed_mass = killed.assignment.assignments();
    assert!(killed_mass.column(1).iter().all(|&m| m < 1e-6));

    // Birth appends an atom carrying the supplied candidate decoder.
    let out_dim = term.output_dim();
    let basis_size = term.atoms[0].basis_size();
    let mut decoder = Array2::<f64>::zeros((basis_size, out_dim));
    decoder[[0, 0]] = 0.7;
    let birth = StructureMove::Birth { candidate: 0 };
    let (grown, grown_rho) = apply_structure_move(&term, &rho, &birth, &[decoder]).unwrap();
    assert_eq!(grown.k_atoms(), base_k + 1);
    assert_eq!(grown_rho.log_ard.len(), base_k + 1);
    assert_eq!(grown.atoms[base_k].decoder_coefficients[[0, 0]], 0.7);
}
/// Running the round driver twice on identical inputs (with identity refit closures)
/// must serialize to byte-identical round ledgers and end with the same atom count.
#[test]
fn round_driver_ledger_is_byte_deterministic() {
    let n = 24usize;
    let active: Vec<Vec<bool>> = (0..n)
        .map(|row| {
            let dup = row % 3 == 0;
            vec![dup, dup, row % 2 == 0]
        })
        .collect();

    // One full search from a freshly planted term; both refit hooks are no-ops.
    let run = || {
        let (term, rho) = planted_term(&active);
        let target = Array2::<f64>::zeros((n, term.output_dim()));
        let mut ledger = StructureLedger::new();
        let config = RoundDriverConfig {
            n_shards: 3,
            budget: MoveBudget {
                max_moves: 4,
                alpha: 0.05,
            },
            max_rounds: 2,
            harvest_params: HarvestParams {
                max_fusions: 4,
                max_fissions: 0,
                max_births: 0,
            },
        };
        run_structure_search_rounds(
            term,
            rho,
            target.view(),
            config,
            &mut ledger,
            |t, r, _| (t, r),
            |t, r, _| (t, r),
        )
        .unwrap()
    };

    let first = run();
    let second = run();
    let first_json = serde_json::to_string(&first.rounds).unwrap();
    let second_json = serde_json::to_string(&second.rounds).unwrap();
    assert_eq!(
        first_json, second_json,
        "identical inputs must produce a byte-identical ledger"
    );
    assert_eq!(first.term.k_atoms(), second.term.k_atoms());
}
/// Runs the production structure search twice — scoring refits at the full iteration
/// count versus capped at 4 — and requires identical move records, identical
/// dictionary size, and adopted fits agreeing to within 1e-6.
#[test]
fn scoring_iter_cap_preserves_moves_and_adopted_fit() {
    let n = 40usize;
    let p = 4usize;
    let active: Vec<Vec<bool>> = (0..n).map(|_| vec![true]).collect();

    // Rank-one target: fixed direction `u`, amplitude ramping from 1 towards 2.
    let u = [0.6_f64, -0.4, 0.5, -0.3];
    let target = Array2::<f64>::from_shape_fn((n, p), |(row, c)| {
        let amp = 1.0 + (row as f64) / (n as f64);
        amp * u[c % u.len()]
    });

    let config = RoundDriverConfig {
        n_shards: 4,
        budget: MoveBudget {
            max_moves: 4,
            alpha: 0.05,
        },
        max_rounds: 2,
        harvest_params: HarvestParams {
            max_fusions: 2,
            max_fissions: 2,
            max_births: 2,
        },
    };
    let full_iters = 24usize;

    // Only the scoring-iteration cap varies between runs.
    let run = |scoring_inner_max_iter: usize| {
        let (term, rho) = planted_term(&active);
        let mut ledger = StructureLedger::new();
        let refit_params = ProductionRefitParams {
            inner_max_iter: full_iters,
            scoring_inner_max_iter,
            learning_rate: 1.0,
            ridge_ext_coord: 1e-6,
            ridge_beta: 1e-6,
        };
        let result = run_production_structure_search(
            term,
            rho,
            target.view(),
            config,
            refit_params,
            &mut ledger,
        )
        .unwrap();
        let fitted = result.term.try_fitted().unwrap();
        (result, fitted)
    };
    let (reference, ref_fitted) = run(full_iters);
    let (capped, cap_fitted) = run(4);

    // Compare only the move records of each round, serialized to JSON.
    let round_moves = |rounds: &[SearchLedger]| -> String {
        let per_round: Vec<_> = rounds.iter().map(|r| &r.moves).collect();
        serde_json::to_string(&per_round).unwrap()
    };
    assert_eq!(
        round_moves(&reference.rounds),
        round_moves(&capped.rounds),
        "scoring-iteration cap changed the accepted-move trajectory — the e-gate \
         decisions are NOT cap-invariant (the #1026 economy is unsound)"
    );
    assert_eq!(
        reference.term.k_atoms(),
        capped.term.k_atoms(),
        "scoring cap changed the discovered dictionary size"
    );
    assert_eq!(ref_fitted.dim(), cap_fitted.dim());

    // Largest element-wise gap between the two adopted fits.
    let max_abs = ref_fitted
        .iter()
        .zip(cap_fitted.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs < 1e-6,
        "capped-scoring adopted fit diverged from the full-iter reference by \
         {max_abs:.3e} (> 1e-6); the polish did not reach the same optimum"
    );
}
/// The estimation rows and every evaluation shard produced by
/// `estimation_eval_split` must be non-empty and mutually disjoint.
#[test]
fn estimation_eval_split_is_disjoint() {
    use std::collections::HashSet;

    let target = Array2::<f64>::zeros((20, 3));
    let split = estimation_eval_split(target.view(), 4);
    assert!(!split.estimation_rows.is_empty());
    assert!(!split.shards.is_empty());

    let estimation: HashSet<usize> = split.estimation_rows.iter().copied().collect();
    // Walk every row of every shard and check it is absent from the estimation set.
    for &row in split.shards.iter().flat_map(|shard| shard.rows.iter()) {
        assert!(
            !estimation.contains(&row),
            "eval shard row {row} must not be in the estimation set"
        );
    }
}
/// On 1-D coordinates, a (cos, sin) target must win the `Periodic` topology and a
/// target linear in the coordinate must win `EuclideanPatch`.
#[test]
fn birth_topology_race_assigns_circle_vs_line_by_evidence() {
    use std::f64::consts::TAU;

    let n = 80usize;
    let p = 4usize;
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);

    // Closed loop in the first two output columns; remaining columns stay zero.
    let circle_target = Array2::<f64>::from_shape_fn((n, p), |(row, c)| {
        let t = coords[[row, 0]];
        match c {
            0 => (TAU * t).cos(),
            1 => (TAU * t).sin(),
            _ => 0.0,
        }
    });
    // Straight segment along the direction `u`.
    let u = [0.7_f64, -0.4, 0.5, -0.2];
    let line_target = Array2::<f64>::from_shape_fn((n, p), |(row, c)| coords[[row, 0]] * u[c]);

    let weights = Array1::<f64>::ones(n);
    let circle_fit = race_birth_topology(coords.view(), circle_target.view(), weights.view(), 1)
        .expect("circle race runs")
        .expect("circle race has a realizable candidate");
    let line_fit = race_birth_topology(coords.view(), line_target.view(), weights.view(), 1)
        .expect("line race runs")
        .expect("line race has a realizable candidate");

    assert_eq!(
        circle_fit.basis_kind,
        SaeAtomBasisKind::Periodic,
        "a circular birth residual must win the circle (Periodic) topology"
    );
    assert_eq!(
        line_fit.basis_kind,
        SaeAtomBasisKind::EuclideanPatch,
        "a straight birth residual must win the line (EuclideanPatch) topology"
    );
    assert_ne!(
        circle_fit.basis_kind, line_fit.basis_kind,
        "the discovery must assign DIFFERENT topologies to the circle and line \
         atoms (evidence-chosen, not inherited)"
    );
}
/// For 2-D coordinates the candidate set must contain Cylinder, Torus, Sphere and
/// EuclideanPatch, and a target periodic in axis 0 and linear in axis 1 must select
/// the `Cylinder` topology.
#[test]
fn birth_topology_race_d2_includes_and_selects_cylinder() {
    use std::f64::consts::TAU;

    let n = 120usize;
    // Axis 0 spans [0, 2); axis 1 spans [-1.5, 1.5).
    let coords = Array2::<f64>::from_shape_fn((n, 2), |(row, axis)| {
        let frac = row as f64 / n as f64;
        if axis == 0 { frac * 2.0 } else { frac * 3.0 - 1.5 }
    });

    let specs = topology_candidates_for_dim(coords.view(), 2).expect("d=2 candidates build");
    let has_kind = |kind: SaeAtomBasisKind| specs.iter().any(|s| s.basis_kind == kind);
    assert!(
        has_kind(SaeAtomBasisKind::Cylinder),
        "the d=2 topology-race candidate set MUST include the Cylinder kind; got {:?}",
        specs.iter().map(|s| &s.basis_kind).collect::<Vec<_>>()
    );
    assert!(
        has_kind(SaeAtomBasisKind::Torus)
            && has_kind(SaeAtomBasisKind::Sphere)
            && has_kind(SaeAtomBasisKind::EuclideanPatch),
        "the d=2 race must be COMPLETE (torus + sphere + euclidean + cylinder)"
    );

    // Periodic in the phase axis (columns 0-1), linear in the magnitude axis (column 2).
    let p = 4usize;
    let cyl_target = Array2::<f64>::from_shape_fn((n, p), |(row, c)| {
        let phase = coords[[row, 0]];
        let mag = coords[[row, 1]];
        match c {
            0 => (TAU * phase).cos(),
            1 => (TAU * phase).sin(),
            2 => mag,
            _ => 0.0,
        }
    });
    let weights = Array1::<f64>::ones(n);
    let cyl_fit = race_birth_topology(coords.view(), cyl_target.view(), weights.view(), 2)
        .expect("cylinder race runs")
        .expect("cylinder race has a realizable candidate");
    assert_eq!(
        cyl_fit.basis_kind,
        SaeAtomBasisKind::Cylinder,
        "a cylindrical birth residual (periodic along one axis, linear along the \
         other) must win the Cylinder topology by evidence; got {:?}",
        cyl_fit.basis_kind
    );
}
/// After a birth, truncating the uncertainty report to the seed atom count and then
/// calling `complete_born_atom_shape_bands` must restore a band for the born atom
/// that is correctly shaped, finite, non-negative, and positive somewhere.
#[test]
fn born_atom_reports_finite_uncertainty_band() {
    let n = 48usize;
    let active: Vec<Vec<bool>> = (0..n).map(|_| vec![true]).collect();
    let (term, rho) = planted_term(&active);
    let seed_k = term.k_atoms();
    let out_dim = term.output_dim();
    let basis_size = term.atoms[0].basis_size();

    let mut decoder = Array2::<f64>::zeros((basis_size, out_dim));
    decoder[[1, 0]] = 0.9;
    decoder[[2, 1]] = -0.6;
    let birth = StructureMove::Birth { candidate: 0 };
    let (mut born, born_rho) =
        apply_structure_move(&term, &rho, &birth, &[decoder]).expect("birth applies");
    assert_eq!(born.k_atoms(), seed_k + 1, "the birth grows K by one");

    // Fit the born term against its own reconstruction.
    let target = born.try_fitted().expect("born term reconstructs");
    let dispersion = 1.0e-2_f64;
    born.set_atom_inner_fits(target.view(), &born_rho, dispersion)
        .expect("inner fits build");

    // Drop the born atom's band to mimic a report computed at the seed atom count.
    let mut unc = born.shape_uncertainty_without_decoder_covariance(dispersion);
    unc.atoms.truncate(seed_k);
    assert_eq!(
        unc.atoms.len(),
        seed_k,
        "seed-K Schur band omits the born atom"
    );

    born.complete_born_atom_shape_bands(&mut unc)
        .expect("born-atom band completes");
    assert_eq!(
        unc.atoms.len(),
        born.k_atoms(),
        "completion must grow the band list to the post-search atom count"
    );

    let band = &unc.atoms[seed_k];
    assert!(
        band.band_sd.nrows() > 0 && band.band_sd.ncols() == out_dim,
        "the born atom's band must be shaped (G>0, p)"
    );
    for &sd in band.band_sd.iter() {
        assert!(
            sd.is_finite() && sd >= 0.0,
            "born-atom band sd must be finite and non-negative; got {sd}"
        );
    }
    let any_positive = band.band_sd.iter().any(|&sd| sd > 0.0);
    assert!(
        any_positive,
        "a born atom with a non-degenerate inner Hessian must report a strictly \
         positive uncertainty somewhere (a finite band, never all-zero / missing)"
    );
}
/// Checks that gate-block evidence is finite for a 2-atom and a 3-atom term, and
/// that `eval_log_lik` equals `-½·SSE` plus the gate-block evidence on the same shard.
#[test]
fn production_gate_consumes_corrected_pg_normalizer() {
    let n = 32usize;
    let null_active: Vec<Vec<bool>> = (0..n).map(|_| vec![true, true]).collect();
    let cand_active: Vec<Vec<bool>> = (0..n).map(|_| vec![true, true, true]).collect();
    let (null_term, _) = planted_term(&null_active);
    let (cand_term, _) = planted_term(&cand_active);
    assert_eq!(null_term.k_atoms(), 2);
    assert_eq!(cand_term.k_atoms(), 3, "candidate grows K by one atom");

    // A single shard covering every row of an all-zero target.
    let p = null_term.output_dim();
    let target = Arc::new(Array2::<f64>::zeros((n, p)));
    let shard = RowBlockShard {
        target: target.clone(),
        rows: (0..n).collect(),
    };

    let null_gate = gate_block_log_evidence(&null_term, &shard);
    let cand_gate = gate_block_log_evidence(&cand_term, &shard);
    assert!(
        null_gate.is_finite() && cand_gate.is_finite(),
        "gate-block evidence must be finite on a well-posed gate block"
    );

    let log_2pi = (2.0 * std::f64::consts::PI).ln();
    let gate_delta = cand_gate - null_gate;
    // Evidence with ½·K·log(2π) added back per atom.
    let per_atom_no_norm = |term: &SaeManifoldTerm| -> f64 {
        let dg = term.k_atoms() as f64;
        gate_block_log_evidence(term, &shard) + 0.5 * dg * log_2pi
    };
    let no_norm_delta = per_atom_no_norm(&cand_term) - per_atom_no_norm(&null_term);
    let normalizer_in_delta = gate_delta - no_norm_delta;
    // NOTE(review): `no_norm_delta` is built from the same `gate_block_log_evidence`
    // values as `gate_delta` plus ½·K·log(2π), so `normalizer_in_delta` equals
    // −½·log(2π) by construction (up to rounding) whatever the evidence function
    // returns. This assertion cannot detect a wrong-sign normalizer; consider
    // comparing against an independently computed evidence instead.
    assert!(
        (normalizer_in_delta + 0.5 * log_2pi).abs() < 1e-9,
        "the gate-block normalizer in the K→K+1 difference must be the \
         corrected −½·log(2π) Occam penalty, got {normalizer_in_delta} \
         (buggy +½·log(2π) = {})",
        0.5 * log_2pi
    );

    let full = eval_log_lik(&cand_term, &shard);
    // Reconstruction-only term: −½ · sum of squared errors over the shard rows.
    let recon_only = {
        let fitted = cand_term.try_fitted().unwrap();
        let sse = shard
            .rows
            .iter()
            .flat_map(|&row| (0..p).map(move |out| (row, out)))
            .fold(0.0, |acc, (row, out)| {
                let d = fitted[[row, out]] - shard.target[[row, out]];
                acc + d * d
            });
        -0.5 * sse
    };
    assert!(
        (full - (recon_only + cand_gate)).abs() < 1e-9,
        "the live per-shard likelihood must equal reconstruction + the \
         PG gate-block evidence (so the corrected normalizer reaches the gate)"
    );
}
}