use ndarray::{Array1, Array2, ArrayView2};
use crate::cache::Fingerprinter;
use crate::inference::residual_factor::{ResidualFactorInput, StructuredResidualModel};
use crate::inference::structure_evidence::{ClaimKind, StructureLedger};
use crate::solver::structure_search::{
CollapseAction, MoveBudget, MoveProposal, SearchLedger, SearchOutcome, StructureMove, search,
};
use crate::terms::atom_codes::SparseAtomCodes;
use crate::terms::sae_manifold::{SaeAtomBasisKind, SaeManifoldRho, SaeManifoldTerm};
/// Relative activity floor: an atom counts as active on a row when its assignment
/// mass strictly exceeds `ACTIVE_SUPPORT_REL_FLOOR / K` (see `sparse_codes_from_term`).
const ACTIVE_SUPPORT_REL_FLOOR: f64 = 0.5_f64;
/// Log-precision at or above which an atom's largest ARD axis counts as diverged,
/// which triggers a death proposal.
const ARD_DIVERGENCE_LOG_PRECISION: f64 = 12.0_f64;
/// Minimum co-activation dependence for an atom pair to be proposed for fusion.
const FUSION_DEPENDENCE_FLOOR: f64 = 0.6_f64;
/// Minimum absorption asymmetry for a pair to flag its parent atom for a fission audit.
const ABSORPTION_ASYMMETRY_FLOOR: f64 = 0.5_f64;
/// Per-round caps on how many proposals of each move kind `harvest_move_proposals` emits.
///
/// Death proposals are not capped here: every diverged or terminally collapsed atom yields one.
#[derive(Clone, Copy, Debug)]
pub struct HarvestParams {
    /// Maximum number of fusion proposals (highest co-activation dependence first).
    pub max_fusions: usize,
    /// Maximum number of fission-audit proposals.
    pub max_fissions: usize,
    /// Maximum number of birth proposals; also caps the residual-factor rank that is fit.
    pub max_births: usize,
}

impl Default for HarvestParams {
    /// Allows up to four proposals of each capped move kind.
    fn default() -> Self {
        const PER_KIND_CAP: usize = 4;
        HarvestParams {
            max_fusions: PER_KIND_CAP,
            max_fissions: PER_KIND_CAP,
            max_births: PER_KIND_CAP,
        }
    }
}
/// Sparsifies the term's soft assignment matrix into per-row atom codes.
///
/// An `(row, atom)` entry is kept, with its mass, when that mass strictly exceeds
/// `ACTIVE_SUPPORT_REL_FLOOR / K`. Here `K` is the number of assignment columns, and the
/// floor is `0.0` when `K == 0`.
pub fn sparse_codes_from_term(term: &SaeManifoldTerm) -> SparseAtomCodes {
    let assignments = term.assignment.assignments();
    let (n_rows, n_atoms) = (assignments.nrows(), assignments.ncols());
    let floor = match n_atoms {
        0 => 0.0,
        _ => ACTIVE_SUPPORT_REL_FLOOR / n_atoms as f64,
    };
    let mut codes = SparseAtomCodes::empty(n_rows, n_atoms);
    for row in 0..n_rows {
        for (atom, &mass) in assignments.row(row).iter().enumerate() {
            if mass > floor {
                codes.row_mut(row).assign(atom, mass);
            }
        }
    }
    codes
}
/// Largest assignment mass each atom receives on any row.
///
/// The running maximum starts at `0.0` and only grows on a strictly larger mass, so an atom
/// whose masses are all non-positive (or NaN) reports `0.0`.
fn per_atom_max_mass(term: &SaeManifoldTerm) -> Array1<f64> {
    let assignments = term.assignment.assignments();
    Array1::from_shape_fn(assignments.ncols(), |atom| {
        assignments
            .column(atom)
            .iter()
            .fold(0.0_f64, |best, &m| if m > best { m } else { best })
    })
}
/// Largest log-ARD precision across `atom`'s axes.
///
/// Returns `-inf` when `rho` has no ARD entry for the atom, or the entry has no axes.
fn per_atom_ard_divergence(rho: &SaeManifoldRho, atom: usize) -> f64 {
    match rho.log_ard.get(atom) {
        Some(axes) => axes
            .iter()
            .copied()
            .reduce(f64::max)
            .unwrap_or(f64::NEG_INFINITY),
        None => f64::NEG_INFINITY,
    }
}
/// Deterministic 64-bit fingerprint of a move together with the term's current atom layout.
///
/// The fingerprint covers the following, in order:
/// - a fixed prefix;
/// - the move tag and its indices (fusion indices are ordered, so `Fusion { a, b }` and
///   `Fusion { b, a }` hash identically);
/// - the atom count;
/// - each atom's basis-kind tag and latent dimension.
///
/// The result is the first eight digest bytes read as little-endian.
fn post_move_structure_hash(term: &SaeManifoldTerm, mv: &StructureMove) -> u64 {
    let mut fp = Fingerprinter::new();
    fp.write_str("sae_structure_move");
    match *mv {
        StructureMove::Birth { candidate } => {
            fp.write_str("birth");
            fp.write_usize(candidate);
        }
        StructureMove::Death { atom } => {
            fp.write_str("death");
            fp.write_usize(atom);
        }
        StructureMove::Fission { atom } => {
            fp.write_str("fission");
            fp.write_usize(atom);
        }
        StructureMove::Fusion { a, b } => {
            let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
            fp.write_str("fusion");
            fp.write_usize(lo);
            fp.write_usize(hi);
        }
    }
    fp.write_usize(term.atoms.len());
    term.atoms.iter().for_each(|atom| {
        fp.write_str(basis_kind_tag(&atom.basis_kind));
        fp.write_usize(atom.latent_dim);
    });
    let digest = fp.finalize();
    let mut head = [0u8; 8];
    head.copy_from_slice(&digest.as_bytes()[..8]);
    u64::from_le_bytes(head)
}
fn basis_kind_tag(kind: &SaeAtomBasisKind) -> &str {
match kind {
SaeAtomBasisKind::Duchon => "duchon",
SaeAtomBasisKind::Periodic => "periodic",
SaeAtomBasisKind::Sphere => "sphere",
SaeAtomBasisKind::Torus => "torus",
SaeAtomBasisKind::EuclideanPatch => "euclidean_patch",
SaeAtomBasisKind::Precomputed(_) => "precomputed",
}
}
/// Packages a move into a `MoveProposal` with its structure fingerprint and evidence claim.
///
/// The claim depends on the move kind:
/// - Birth: the atom index it would occupy (`K + candidate`).
/// - Death: the existing atom.
/// - Fusion: a binding edge between the pair.
/// - Fission: a custom `fission:<atom>` label.
fn proposal(term: &SaeManifoldTerm, mv: StructureMove, trigger: f64) -> MoveProposal {
    let structure_hash = post_move_structure_hash(term, &mv);
    let claim = match mv {
        StructureMove::Birth { candidate } => ClaimKind::AtomExists {
            atom: term.k_atoms() + candidate,
        },
        StructureMove::Death { atom } => ClaimKind::AtomExists { atom },
        StructureMove::Fusion { a, b } => ClaimKind::BindingEdge { a, b },
        StructureMove::Fission { atom } => ClaimKind::Custom {
            label: format!("fission:{atom}"),
        },
    };
    MoveProposal {
        mv,
        trigger,
        structure_hash,
        claim,
    }
}
pub fn harvest_move_proposals(
term: &SaeManifoldTerm,
rho: &SaeManifoldRho,
residuals: ArrayView2<'_, f64>,
params: &HarvestParams,
) -> Result<HarvestReport, String> {
let k = term.k_atoms();
let mut proposals: Vec<MoveProposal> = Vec::new();
let max_mass = per_atom_max_mass(term);
let terminal: std::collections::HashSet<usize> = term
.collapse_events()
.iter()
.filter(|e| matches!(e.action, CollapseAction::Terminal))
.map(|e| e.atom)
.collect();
for atom in 0..k {
let ard = per_atom_ard_divergence(rho, atom);
let diverged = ard >= ARD_DIVERGENCE_LOG_PRECISION;
let collapsed = terminal.contains(&atom);
if diverged || collapsed {
let trigger = if collapsed { f64::MAX / 2.0 } else { ard };
let trigger = trigger - max_mass[atom].min(1.0) * 1e-9;
proposals.push(proposal(term, StructureMove::Death { atom }, trigger));
}
}
let codes = sparse_codes_from_term(term);
let mut fusion_pairs: Vec<(usize, usize, f64)> = Vec::new();
for a in 0..k {
for b in (a + 1)..k {
let stats = codes.coactivation(a, b);
let dep = stats.dependence();
if dep >= FUSION_DEPENDENCE_FLOOR {
fusion_pairs.push((a, b, dep));
}
}
}
fusion_pairs.sort_by(|x, y| y.2.total_cmp(&x.2).then(x.0.cmp(&y.0)).then(x.1.cmp(&y.1)));
for &(a, b, dep) in fusion_pairs.iter().take(params.max_fusions) {
proposals.push(proposal(term, StructureMove::Fusion { a, b }, dep));
}
let mut fission_atoms: Vec<(usize, f64)> = Vec::new();
for a in 0..k {
for b in (a + 1)..k {
let stats = codes.coactivation(a, b);
let asym = stats.absorption_asymmetry();
if asym >= ABSORPTION_ASYMMETRY_FLOOR {
let parent = if stats.p_a_given_b >= stats.p_b_given_a {
a
} else {
b
};
let significance = (1.0 - asym).max(0.0);
fission_atoms.push((parent, significance));
}
}
}
fission_atoms.sort_by(|x, y| x.1.total_cmp(&y.1).then(x.0.cmp(&y.0)));
fission_atoms.dedup_by_key(|(atom, _)| *atom);
let fission_carve_skipped = !fission_atoms.is_empty();
for &(atom, significance) in fission_atoms.iter().take(params.max_fissions) {
proposals.push(proposal(
term,
StructureMove::Fission { atom },
significance,
));
}
let n = residuals.nrows();
let assignments = term.assignment.assignments();
let activity: Array1<f64> = (0..n).map(|r| assignments.row(r).sum()).collect();
let mut births_proposed = 0usize;
let mut birth_skipped_reason: Option<String> = None;
if params.max_births > 0 && n > 0 && residuals.ncols() > 0 {
let p = residuals.ncols();
let max_rank = params.max_births.min(p.saturating_sub(1));
match StructuredResidualModel::fit(ResidualFactorInput {
residuals,
activity: activity.view(),
max_factor_rank: max_rank,
}) {
Ok(model) => {
let factor = model.factor();
let r = model.factor_rank();
let mut dirs: Vec<(usize, f64)> = (0..r)
.map(|j| {
let mass = factor.column(j).iter().map(|v| v * v).sum::<f64>().sqrt();
(j, mass)
})
.collect();
dirs.sort_by(|x, y| y.1.total_cmp(&x.1).then(x.0.cmp(&y.0)));
for &(candidate, mass) in dirs.iter().take(params.max_births) {
proposals.push(proposal(term, StructureMove::Birth { candidate }, mass));
births_proposed += 1;
}
}
Err(e) => {
birth_skipped_reason = Some(e);
}
}
} else if params.max_births > 0 {
birth_skipped_reason =
Some("residuals empty or single-channel; no factor subspace to mine".to_string());
}
Ok(HarvestReport {
proposals,
fission_carve_skipped,
births_proposed,
birth_skipped_reason,
})
}
/// Result of one `harvest_move_proposals` pass.
#[derive(Clone, Debug)]
pub struct HarvestReport {
    /// Proposed moves in harvest order: deaths, fusions, fissions, births.
    pub proposals: Vec<MoveProposal>,
    /// `true` when some atom cleared the absorption-asymmetry floor, whatever `max_fissions` is.
    /// Records that the within-atom carve for those fission audits did not run.
    pub fission_carve_skipped: bool,
    /// Number of birth proposals included in `proposals`.
    pub births_proposed: usize,
    /// Why birth mining produced nothing, when births were requested but the residuals were
    /// empty or the residual-factor fit failed.
    pub birth_skipped_reason: Option<String>,
}
/// Builds the child state produced by applying `mv` to `(term, rho)`.
///
/// The inputs are not modified.
/// - Death and fusion keep `K` fixed and reuse `rho` unchanged.
/// - Fission and birth append one atom and one inherited log-ARD entry.
/// - Birth moves take their decoder from `birth_decoders[candidate]`.
///
/// # Errors
/// Propagates index/construction errors from the helpers, and rejects birth candidates that
/// have no corresponding decoder.
pub fn apply_structure_move(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    mv: &StructureMove,
    birth_decoders: &[Array2<f64>],
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
    match *mv {
        StructureMove::Death { atom } => {
            let mut child = term.clone();
            demote_atom(&mut child, atom)?;
            Ok((child, rho.clone()))
        }
        StructureMove::Fusion { a, b } => {
            let mut child = term.clone();
            fold_atom_into(&mut child, a, b)?;
            Ok((child, rho.clone()))
        }
        StructureMove::Fission { atom } => duplicate_atom(term, rho, atom),
        StructureMove::Birth { candidate } => match birth_decoders.get(candidate) {
            Some(decoder) => born_atom(term, rho, decoder.view()),
            None => Err(format!(
                "apply_structure_move: birth candidate {candidate} out of range \
                 ({} residual-factor decoders)",
                birth_decoders.len()
            )),
        },
    }
}
/// Very negative logit written into a demoted atom's column on every row.
const DEMOTE_LOGIT: f64 = -40.0;

/// Switches `atom` off by overwriting its logit with `DEMOTE_LOGIT` on every row.
///
/// The atom stays in the term (`K` is unchanged); only its routing logits change.
///
/// # Errors
/// Returns an error when `atom >= term.k_atoms()`.
fn demote_atom(term: &mut SaeManifoldTerm, atom: usize) -> Result<(), String> {
    let k = term.k_atoms();
    if atom < k {
        let logits = &mut term.assignment.logits;
        for row in 0..logits.nrows() {
            logits[[row, atom]] = DEMOTE_LOGIT;
        }
        Ok(())
    } else {
        Err(format!("demote_atom: atom {atom} out of range (K={k})"))
    }
}
/// Fuses atom `b` into atom `a`.
///
/// On every row, `a` takes the larger of the two logits; `b` is then demoted.
///
/// # Errors
/// Returns an error if either index is out of range, `a == b`, or the demotion fails.
fn fold_atom_into(term: &mut SaeManifoldTerm, a: usize, b: usize) -> Result<(), String> {
    let k = term.k_atoms();
    if a >= k || b >= k {
        return Err(format!(
            "fold_atom_into: atoms ({a},{b}) out of range (K={k})"
        ));
    }
    if a == b {
        return Err("fold_atom_into: cannot fuse an atom with itself".to_string());
    }
    let logits = &mut term.assignment.logits;
    for row in 0..logits.nrows() {
        let merged = logits[[row, a]].max(logits[[row, b]]);
        logits[[row, a]] = merged;
    }
    demote_atom(term, b)
}
/// Fission: clones atom `parent` into a new trailing atom (index `K`) and splits the parent's
/// routing between the two.
///
/// Both the parent's column and the new column get the parent's logit minus `ln 2`. Under a
/// unit-temperature softmax this keeps their combined mass equal to the parent's original
/// mass; other assignment modes are not checked here.
///
/// The new atom copies the parent's coordinate block. It also inherits the parent's log-ARD
/// entry, or an empty entry when `rho` has none for `parent`.
///
/// # Errors
/// Returns an error for an out-of-range `parent`, or if assignment/term construction fails.
fn duplicate_atom(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    parent: usize,
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
    let k = term.k_atoms();
    if parent >= k {
        return Err(format!(
            "duplicate_atom: parent {parent} out of range (K={k})"
        ));
    }
    let mut atoms = term.atoms.clone();
    atoms.push(term.atoms[parent].clone());

    let src = &term.assignment.logits;
    let split = std::f64::consts::LN_2;
    let logits = Array2::<f64>::from_shape_fn((src.nrows(), k + 1), |(row, col)| {
        if col == parent || col == k {
            src[[row, parent]] - split
        } else {
            src[[row, col]]
        }
    });

    let mut coords = term.assignment.coords.clone();
    coords.push(term.assignment.coords[parent].clone());
    let assignment =
        crate::terms::sae_manifold::SaeAssignment::with_mode(logits, coords, term.assignment.mode)?;
    let child = SaeManifoldTerm::new(atoms, assignment)?;

    let mut child_rho = rho.clone();
    let inherited = child_rho
        .log_ard
        .get(parent)
        .cloned()
        .unwrap_or_else(|| Array1::<f64>::zeros(0));
    child_rho.log_ard.push(inherited);
    Ok((child, child_rho))
}
/// Logit seeded into a newly born atom's column on every row; low, so the atom starts with
/// little assignment mass.
const BIRTH_SEED_LOGIT: f64 = -4.0;

/// Birth: appends a new atom whose decoder is `factor_dir`, using atom 0 as the template.
///
/// Steps:
/// - Clone atom 0, replace its decoder coefficients with `factor_dir`, and refresh its
///   intrinsic smoothness penalty.
/// - Copy existing logits unchanged; the new column is `BIRTH_SEED_LOGIT` on every row.
/// - The new atom takes atom 0's coordinate block.
/// - It inherits the first log-ARD entry of `rho`, or an empty entry when there is none.
///
/// # Errors
/// Returns an error when:
/// - the term has no atom or no coordinate block to use as a template (previously a panic);
/// - `factor_dir` is not `(basis_size, output_dim)`;
/// - assignment/term construction fails.
fn born_atom(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
    factor_dir: ArrayView2<'_, f64>,
) -> Result<(SaeManifoldTerm, SaeManifoldRho), String> {
    let k = term.k_atoms();
    // Without a template atom there is no basis to attach the decoder to; report it as an
    // error instead of panicking on `atoms[0]` / `coords[0]`.
    let template = term
        .atoms
        .first()
        .ok_or_else(|| "born_atom: term has no atoms to template a birth from".to_string())?;
    let template_coords = term.assignment.coords.first().ok_or_else(|| {
        "born_atom: term has no coordinate blocks to template a birth from".to_string()
    })?;
    let m = template.basis_size();
    let p = term.output_dim();
    if factor_dir.dim() != (m, p) {
        return Err(format!(
            "born_atom: residual-factor decoder must be ({m}, {p}); got {:?}",
            factor_dir.dim()
        ));
    }
    let mut atoms = term.atoms.clone();
    let mut born = template.clone();
    born.decoder_coefficients = factor_dir.to_owned();
    born.refresh_intrinsic_smooth_penalty();
    atoms.push(born);

    let n = term.assignment.logits.nrows();
    let mut logits = Array2::<f64>::zeros((n, k + 1));
    for row in 0..n {
        for col in 0..k {
            logits[[row, col]] = term.assignment.logits[[row, col]];
        }
        logits[[row, k]] = BIRTH_SEED_LOGIT;
    }

    let mut coords = term.assignment.coords.clone();
    coords.push(template_coords.clone());
    let assignment =
        crate::terms::sae_manifold::SaeAssignment::with_mode(logits, coords, term.assignment.mode)?;
    let child = SaeManifoldTerm::new(atoms, assignment)?;

    let mut child_rho = rho.clone();
    let inherited = child_rho
        .log_ard
        .first()
        .cloned()
        .unwrap_or_else(|| Array1::<f64>::zeros(0));
    child_rho.log_ard.push(inherited);
    Ok((child, child_rho))
}
/// One evaluation shard: a set of row indices into a target matrix held behind an `Arc`.
#[derive(Clone, Debug)]
pub struct RowBlockShard {
    /// Full target matrix; `estimation_eval_split` gives every shard a clone of one shared `Arc`.
    pub target: std::sync::Arc<Array2<f64>>,
    /// Row indices into `target` that belong to this shard.
    pub rows: Vec<usize>,
}
/// Row partition produced by `estimation_eval_split`.
#[derive(Clone, Debug)]
pub struct EstimationEvalSplit {
    /// Rows used to (re)fit candidates.
    pub estimation_rows: Vec<usize>,
    /// Evaluation shards covering the remaining rows; disjoint from `estimation_rows`.
    pub shards: Vec<RowBlockShard>,
}
/// Fraction of rows (rounded) placed in the estimation set by `estimation_eval_split`.
const ESTIMATION_FRACTION: f64 = 0.6;

/// Splits the rows of `target` into a leading estimation block and trailing evaluation shards.
///
/// - Estimation set: the first `round(ESTIMATION_FRACTION * n)` rows, clamped to
///   `1..=max(n - 1, 1)`.
/// - Evaluation shards: the remaining rows, dealt into at most `n_shards` contiguous shards.
///   Shard sizes differ by at most one, with the earlier shards taking the extra rows.
/// - At least one shard is produced whenever evaluation rows exist, even for `n_shards == 0`.
/// - An empty `target` yields an empty split; a one-row `target` yields one estimation row
///   and no shards.
/// - All shards share a single owned copy of `target`.
pub fn estimation_eval_split(target: ArrayView2<'_, f64>, n_shards: usize) -> EstimationEvalSplit {
    let n = target.nrows();
    if n == 0 {
        return EstimationEvalSplit {
            estimation_rows: Vec::new(),
            shards: Vec::new(),
        };
    }
    let upper = n.saturating_sub(1).max(1);
    let n_est = ((n as f64 * ESTIMATION_FRACTION).round() as usize).clamp(1, upper);
    let eval_rows: Vec<usize> = (n_est..n).collect();
    let n_eval = eval_rows.len();
    let shard_count = if n_eval == 0 {
        0
    } else {
        n_shards.min(n_eval).max(1)
    };

    let shared = std::sync::Arc::new(target.to_owned());
    let mut shards = Vec::with_capacity(shard_count);
    if shard_count > 0 {
        let (base, extra) = (n_eval / shard_count, n_eval % shard_count);
        let mut remaining: &[usize] = &eval_rows;
        for idx in 0..shard_count {
            let (head, tail) = remaining.split_at(base + usize::from(idx < extra));
            shards.push(RowBlockShard {
                target: std::sync::Arc::clone(&shared),
                rows: head.to_vec(),
            });
            remaining = tail;
        }
    }
    EstimationEvalSplit {
        estimation_rows: (0..n_est).collect(),
        shards,
    }
}
/// Outcome of `run_structure_search_rounds`.
pub struct StructureSearchResult {
    /// Term after the last executed round.
    pub term: SaeManifoldTerm,
    /// `rho` after the last executed round.
    pub rho: SaeManifoldRho,
    /// One search ledger per executed round, in execution order.
    pub rounds: Vec<SearchLedger>,
}
/// Configuration for `run_structure_search_rounds`.
#[derive(Clone, Copy, Debug)]
pub struct RoundDriverConfig {
    /// Requested number of evaluation shards (passed to `estimation_eval_split`).
    pub n_shards: usize,
    /// Move budget handed to each round's `search`; its `alpha` is also recorded on an
    /// empty terminal round.
    pub budget: MoveBudget,
    /// Upper bound on the number of harvest/search rounds.
    pub max_rounds: usize,
    /// Proposal caps applied when harvesting each round.
    pub harvest_params: HarvestParams,
}
/// Runs up to `config.max_rounds` rounds of harvest-then-search over structure moves.
///
/// Each round proceeds as follows:
/// 1. Compute residuals of the current term against `target` and harvest move proposals.
/// 2. Hand the proposals to `search` together with the evaluation shards from a fixed
///    estimation/evaluation split.
/// 3. Candidate states are built with `apply_structure_move` and refit through
///    `candidate_fit`, which receives only the estimation rows.
/// 4. Shards are scored with `eval_log_lik`.
/// 5. The round's ledger records the collapse events of the term the round started from.
///
/// The loop stops early when a round has no proposals, the split has no evaluation shards (an
/// empty ledger is recorded in both cases), or a round accepts/demotes no move.
///
/// # Errors
/// Fails if the term cannot produce fitted values, if harvesting or building birth decoders
/// fails, or if `search` fails.
pub fn run_structure_search_rounds(
    mut term: SaeManifoldTerm,
    mut rho: SaeManifoldRho,
    target: ArrayView2<'_, f64>,
    config: RoundDriverConfig,
    ledger: &mut StructureLedger,
    mut candidate_fit: impl FnMut(
        SaeManifoldTerm,
        SaeManifoldRho,
        &[usize],
    ) -> (SaeManifoldTerm, SaeManifoldRho),
) -> Result<StructureSearchResult, String> {
    let RoundDriverConfig {
        n_shards,
        budget,
        max_rounds,
        harvest_params,
    } = config;
    let split = estimation_eval_split(target, n_shards);
    let mut rounds: Vec<SearchLedger> = Vec::new();
    for _ in 0..max_rounds {
        let fitted = term.try_fitted()?;
        let residuals = &target.to_owned() - &fitted;
        let report = harvest_move_proposals(&term, &rho, residuals.view(), &harvest_params)?;
        if report.proposals.is_empty() || split.shards.is_empty() {
            rounds.push(SearchLedger {
                alpha: budget.alpha,
                moves: Vec::new(),
                collapse_events: term.collapse_events().to_vec(),
            });
            break;
        }
        // Decoders are only looked up by `Birth` moves. When the harvest proposed none, skip
        // the second residual-factor fit that building them would cost.
        let decoders = if report.births_proposed > 0 {
            build_birth_decoders(&term, residuals.view(), &harvest_params)?
        } else {
            Vec::new()
        };
        type State = (SaeManifoldTerm, SaeManifoldRho);
        let collapse_events = term.collapse_events().to_vec();
        let estimation_rows = split.estimation_rows.clone();
        let outcome: SearchOutcome<State> = search(
            (term, rho),
            report.proposals,
            &split.shards,
            &budget,
            ledger,
            |state: &State, mv: &StructureMove| {
                let (cand_term, cand_rho) =
                    apply_structure_move(&state.0, &state.1, mv, &decoders)?;
                Ok(candidate_fit(cand_term, cand_rho, &estimation_rows))
            },
            |state: &State, shard: &RowBlockShard| eval_log_lik(&state.0, shard),
            |state: &State, shard: &RowBlockShard| eval_log_lik(&state.0, shard),
            |state: State, _: &RowBlockShard| state,
        )?;
        let (next_term, next_rho) = outcome.state;
        let mut round_ledger = outcome.ledger;
        round_ledger.collapse_events = collapse_events;
        let applied = round_ledger.moves.iter().any(|m| {
            matches!(
                m.verdict,
                crate::solver::structure_search::MoveVerdict::Accepted { .. }
                    | crate::solver::structure_search::MoveVerdict::Demoted { .. }
            )
        });
        rounds.push(round_ledger);
        term = next_term;
        rho = next_rho;
        if !applied {
            break;
        }
    }
    Ok(StructureSearchResult { term, rho, rounds })
}
/// Builds one birth decoder per residual-factor column.
///
/// The residual-factor model is fit with the same residuals, row activity and rank cap as
/// `harvest_move_proposals`, so decoder `j` lines up with `Birth { candidate: j }`. Each decoder
/// is a `(basis_size, p)` zero matrix, where `basis_size` is atom 0's basis size, except row 0,
/// which holds factor column `j`.
///
/// Returns an empty list (best effort, never an error) when:
/// - births are disabled or the residuals are empty;
/// - the term has no atom or its first atom has a zero-size basis (previously these panicked);
/// - the residual-factor fit fails.
fn build_birth_decoders(
    term: &SaeManifoldTerm,
    residuals: ArrayView2<'_, f64>,
    params: &HarvestParams,
) -> Result<Vec<Array2<f64>>, String> {
    let n = residuals.nrows();
    let p = residuals.ncols();
    if params.max_births == 0 || n == 0 || p == 0 {
        return Ok(Vec::new());
    }
    // The decoder shape follows atom 0's basis. With no atom, or an empty basis, there is no
    // row 0 to write the factor into, so no birth decoder can be built.
    let basis_size = match term.atoms.first() {
        Some(template) => template.basis_size(),
        None => return Ok(Vec::new()),
    };
    if basis_size == 0 {
        return Ok(Vec::new());
    }
    let assignments = term.assignment.assignments();
    let activity: Array1<f64> = (0..n).map(|r| assignments.row(r).sum()).collect();
    let max_rank = params.max_births.min(p.saturating_sub(1));
    let model = match StructuredResidualModel::fit(ResidualFactorInput {
        residuals,
        activity: activity.view(),
        max_factor_rank: max_rank,
    }) {
        Ok(model) => model,
        // Best effort: `harvest_move_proposals` reports the fit failure as the birth skip reason.
        Err(_) => return Ok(Vec::new()),
    };
    let factor = model.factor();
    let decoders = (0..factor.ncols())
        .map(|j| {
            let mut decoder = Array2::<f64>::zeros((basis_size, p));
            for out in 0..p {
                decoder[[0, out]] = factor[[out, j]];
            }
            decoder
        })
        .collect();
    Ok(decoders)
}
/// Unit-variance Gaussian log-likelihood score of `term` on one shard, without constants.
///
/// The score is `-0.5 * SSE`, summed over the shard's in-range rows and all output channels.
/// Returns `-inf` when any of these hold:
/// - the term cannot produce fitted values;
/// - the fitted matrix's shape differs from the shard target's;
/// - no shard row (or no output channel) contributes.
fn eval_log_lik(term: &SaeManifoldTerm, shard: &RowBlockShard) -> f64 {
    let fitted = match term.try_fitted() {
        Ok(values) => values,
        Err(_) => return f64::NEG_INFINITY,
    };
    let (n_full, p) = (fitted.nrows(), fitted.ncols());
    if shard.target.nrows() != n_full || shard.target.ncols() != p {
        return f64::NEG_INFINITY;
    }
    let mut sse = 0.0_f64;
    let mut used = 0usize;
    for &row in shard.rows.iter().filter(|&&row| row < n_full) {
        for out in 0..p {
            sse_accumulate(&mut sse, fitted[[row, out]] - shard.target[[row, out]]);
        }
        used += p;
    }
    if used == 0 {
        f64::NEG_INFINITY
    } else {
        -0.5 * sse
    }
}
/// Adds the square of residual `d` to the running sum of squared errors `sse`.
#[inline]
fn sse_accumulate(sse: &mut f64, d: f64) {
    let squared = d * d;
    *sse += squared;
}
/// Inner-fit settings that `run_production_structure_search` forwards verbatim to
/// `run_joint_fit_arrow_schur` when refitting each candidate.
#[derive(Clone, Copy, Debug)]
pub struct ProductionRefitParams {
    /// Forwarded as the joint fit's inner iteration limit (presumably a max iteration count).
    pub inner_max_iter: usize,
    /// Forwarded as the joint fit's learning rate.
    pub learning_rate: f64,
    /// Forwarded as the joint fit's `ridge_ext_coord` argument; semantics defined by the fitter.
    pub ridge_ext_coord: f64,
    /// Forwarded as the joint fit's `ridge_beta` argument; semantics defined by the fitter.
    pub ridge_beta: f64,
}
/// Production entry point: `run_structure_search_rounds` with a joint Arrow–Schur refit as
/// the candidate fitter.
///
/// Each candidate is refit on the full target with per-row loss weights of `1.0` on the
/// estimation rows and `1e-12` on every other row, so held-out rows barely influence the refit.
/// This is best effort: if setting the weights or the refit fails, the candidate is handed back
/// as it stands and the search's scoring judges it.
pub fn run_production_structure_search(
    term: SaeManifoldTerm,
    rho: SaeManifoldRho,
    target: ArrayView2<'_, f64>,
    config: RoundDriverConfig,
    refit_params: ProductionRefitParams,
    ledger: &mut StructureLedger,
) -> Result<StructureSearchResult, String> {
    let full_target = target.to_owned();
    let n_rows = full_target.nrows();
    run_structure_search_rounds(
        term,
        rho,
        target,
        config,
        ledger,
        move |mut cand_term, mut cand_rho, estimation_rows| {
            const HELD_OUT_WEIGHT: f64 = 1e-12;
            let mut weights = vec![HELD_OUT_WEIGHT; n_rows];
            estimation_rows
                .iter()
                .copied()
                .filter(|&r| r < n_rows)
                .for_each(|r| weights[r] = 1.0);
            if cand_term.set_row_loss_weights(weights).is_ok() {
                // A failed refit still returns the candidate for scoring.
                let _ = cand_term.run_joint_fit_arrow_schur(
                    full_target.view(),
                    &mut cand_rho,
                    None,
                    refit_params.inner_max_iter,
                    refit_params.learning_rate,
                    refit_params.ridge_ext_coord,
                    refit_params.ridge_beta,
                );
            }
            (cand_term, cand_rho)
        },
    )
}
/// Serializes the per-round search ledgers to a compact JSON string.
///
/// # Errors
/// Returns a message prefixed with `rounds_to_json:` when serialization fails.
pub fn rounds_to_json(rounds: &[SearchLedger]) -> Result<String, String> {
    match serde_json::to_string(rounds) {
        Ok(json) => Ok(json),
        Err(err) => Err(format!("rounds_to_json: serialize search ledger: {err}")),
    }
}
#[cfg(test)]
mod tests {
    //! Planted-structure tests for proposal harvesting, move application, the round driver's
    //! ledger determinism, and the estimation/evaluation split.
    use super::*;
    use crate::solver::structure_search::{CollapseAction, CollapseEvent};
    use crate::terms::latent_coord::LatentManifold;
    use crate::terms::sae_manifold::{
        AssignmentMode, PeriodicHarmonicEvaluator, SaeAssignment, SaeAtomBasisKind,
        SaeBasisEvaluator, SaeManifoldAtom,
    };
    use ndarray::Array2;
    use std::sync::Arc;
    /// Logit for an atom planted as active on a row.
    const ON: f64 = 6.0;
    /// Logit for an atom planted as inactive on a row.
    const OFF: f64 = -6.0;
    /// Builds a periodic-basis term whose logits follow the `active[row][atom]` pattern.
    fn planted_term(active: &[Vec<bool>]) -> (SaeManifoldTerm, SaeManifoldRho) {
        let n = active.len();
        let k = active[0].len();
        let p = 4usize;
        let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
        let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let mut atoms = Vec::with_capacity(k);
        let mut coord_blocks = Vec::with_capacity(k);
        for atom_idx in 0..k {
            let mut decoder = Array2::<f64>::zeros((3, p));
            decoder[[1, atom_idx % p]] = 1.0;
            decoder[[2, (atom_idx + 1) % p]] = 1.0;
            let atom = SaeManifoldAtom::new(
                format!("atom_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi.clone(),
                jet.clone(),
                decoder,
                Array2::<f64>::eye(3),
            )
            .unwrap()
            .with_basis_second_jet(evaluator.clone());
            atoms.push(atom);
            coord_blocks.push(coords.clone());
        }
        let mut logits = Array2::<f64>::zeros((n, k));
        for (row, atom_active) in active.iter().enumerate() {
            for (atom, &on) in atom_active.iter().enumerate() {
                logits[[row, atom]] = if on { ON } else { OFF };
            }
        }
        let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            logits,
            coord_blocks,
            vec![LatentManifold::Circle { period: 1.0 }; k],
            AssignmentMode::softmax(1.0),
        )
        .unwrap();
        let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
        let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); k]);
        (term, rho)
    }
    /// Residuals against an all-zero target: simply the negated fit.
    fn residuals_of(term: &SaeManifoldTerm) -> Array2<f64> {
        let fitted = term.try_fitted().unwrap();
        -&fitted
    }
    #[test]
    fn planted_shatter_harvests_fusion_not_fission() {
        let n = 30usize;
        let active: Vec<Vec<bool>> = (0..n)
            .map(|row| {
                let dup = row % 3 == 0;
                vec![dup, dup, row % 2 == 0]
            })
            .collect();
        let (term, rho) = planted_term(&active);
        let residuals = residuals_of(&term);
        let params = HarvestParams {
            max_fusions: 4,
            max_fissions: 4,
            max_births: 0,
        };
        let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();
        let has_fusion_01 = report.proposals.iter().any(|p| {
            matches!(p.mv, StructureMove::Fusion { a, b } if (a, b) == (0, 1) || (a, b) == (1, 0))
        });
        assert!(
            has_fusion_01,
            "shattered duplicate pair (0,1) must yield a fusion proposal; got {:?}",
            report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
        );
        let has_fission = report
            .proposals
            .iter()
            .any(|p| matches!(p.mv, StructureMove::Fission { .. }));
        assert!(
            !has_fission,
            "symmetric duplicate supports must not trigger an absorption fission audit"
        );
    }
    #[test]
    fn planted_absorption_harvests_fission_audit_with_loud_carve_skip() {
        let n = 40usize;
        let active: Vec<Vec<bool>> = (0..n)
            .map(|row| {
                let child = row % 4 == 0;
                let parent = row % 2 == 0 || row % 4 == 1;
                vec![parent, child, row % 5 == 0]
            })
            .collect();
        let (term, rho) = planted_term(&active);
        let residuals = residuals_of(&term);
        let params = HarvestParams {
            max_fusions: 4,
            max_fissions: 4,
            max_births: 0,
        };
        let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();
        let fissioned_parent = report
            .proposals
            .iter()
            .any(|p| matches!(p.mv, StructureMove::Fission { atom: 0 }));
        assert!(
            fissioned_parent,
            "nested-support parent (atom 0) must be flagged for a fission audit; got {:?}",
            report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
        );
        assert!(
            report.fission_carve_skipped,
            "the #993 within-atom carve is unwired; the skip must be recorded, not silent"
        );
    }
    #[test]
    fn independent_atoms_harvest_no_fusion() {
        let n = 60usize;
        let active: Vec<Vec<bool>> = (0..n)
            .map(|row| vec![row % 2 == 0, row % 3 == 0, row % 5 == 0])
            .collect();
        let (term, rho) = planted_term(&active);
        let residuals = residuals_of(&term);
        let params = HarvestParams {
            max_fusions: 4,
            max_fissions: 4,
            max_births: 0,
        };
        let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();
        let has_fusion = report
            .proposals
            .iter()
            .any(|p| matches!(p.mv, StructureMove::Fusion { .. }));
        assert!(
            !has_fusion,
            "independent atom supports must not produce fusion proposals; got {:?}",
            report.proposals.iter().map(|p| &p.mv).collect::<Vec<_>>()
        );
    }
    #[test]
    fn diverged_ard_and_terminal_collapse_harvest_deaths() {
        let n = 20usize;
        let active: Vec<Vec<bool>> = (0..n).map(|row| vec![true, row % 2 == 0, false]).collect();
        let (mut term, mut rho) = planted_term(&active);
        rho.log_ard[2] = Array1::from_elem(1, ARD_DIVERGENCE_LOG_PRECISION + 5.0);
        term.record_collapse_event(CollapseEvent {
            iteration: 3,
            atom: 1,
            max_active_mass: 1e-6,
            floor: 1e-3,
            action: CollapseAction::Terminal,
        });
        let residuals = residuals_of(&term);
        let params = HarvestParams {
            max_fusions: 0,
            max_fissions: 0,
            max_births: 0,
        };
        let report = harvest_move_proposals(&term, &rho, residuals.view(), &params).unwrap();
        let death_atoms: Vec<usize> = report
            .proposals
            .iter()
            .filter_map(|p| match p.mv {
                StructureMove::Death { atom } => Some(atom),
                _ => None,
            })
            .collect();
        assert!(
            death_atoms.contains(&2),
            "diverged ARD on atom 2 must yield a death proposal; got {death_atoms:?}"
        );
        assert!(
            death_atoms.contains(&1),
            "terminal collapse on atom 1 must yield a death proposal; got {death_atoms:?}"
        );
    }
    #[test]
    fn apply_move_restructures_warm() {
        let n = 12usize;
        let active: Vec<Vec<bool>> = (0..n).map(|row| vec![true, row % 2 == 0]).collect();
        let (term, rho) = planted_term(&active);
        let k0 = term.k_atoms();
        let (fissioned, fissioned_rho) =
            apply_structure_move(&term, &rho, &StructureMove::Fission { atom: 0 }, &[]).unwrap();
        assert_eq!(fissioned.k_atoms(), k0 + 1);
        assert_eq!(fissioned_rho.log_ard.len(), k0 + 1);
        let (fused, _) =
            apply_structure_move(&term, &rho, &StructureMove::Fusion { a: 0, b: 1 }, &[]).unwrap();
        assert_eq!(fused.k_atoms(), k0);
        let fused_assign = fused.assignment.assignments();
        assert!(
            fused_assign.column(1).iter().all(|&m| m < 1e-6),
            "fused-away atom 1 must route to ~0 mass"
        );
        let (dead, _) =
            apply_structure_move(&term, &rho, &StructureMove::Death { atom: 1 }, &[]).unwrap();
        assert_eq!(dead.k_atoms(), k0);
        let dead_assign = dead.assignment.assignments();
        assert!(dead_assign.column(1).iter().all(|&m| m < 1e-6));
        let p = term.output_dim();
        let m = term.atoms[0].basis_size();
        let mut decoder = Array2::<f64>::zeros((m, p));
        decoder[[0, 0]] = 0.7;
        let (born, born_rho) = apply_structure_move(
            &term,
            &rho,
            &StructureMove::Birth { candidate: 0 },
            &[decoder],
        )
        .unwrap();
        assert_eq!(born.k_atoms(), k0 + 1);
        assert_eq!(born_rho.log_ard.len(), k0 + 1);
        assert_eq!(born.atoms[k0].decoder_coefficients[[0, 0]], 0.7);
    }
    #[test]
    fn round_driver_ledger_is_byte_deterministic() {
        let n = 24usize;
        let active: Vec<Vec<bool>> = (0..n)
            .map(|row| {
                let dup = row % 3 == 0;
                vec![dup, dup, row % 2 == 0]
            })
            .collect();
        let run = || {
            let (term, rho) = planted_term(&active);
            let target = Array2::<f64>::zeros((n, term.output_dim()));
            let mut ledger = crate::inference::structure_evidence::StructureLedger::new();
            let budget = MoveBudget {
                max_moves: 4,
                alpha: 0.05,
            };
            let params = HarvestParams {
                max_fusions: 4,
                max_fissions: 0,
                max_births: 0,
            };
            let config = RoundDriverConfig {
                n_shards: 3,
                budget,
                max_rounds: 2,
                harvest_params: params,
            };
            run_structure_search_rounds(term, rho, target.view(), config, &mut ledger, |t, r, _| {
                (t, r)
            })
            .unwrap()
        };
        let a = run();
        let b = run();
        let sa = serde_json::to_string(&a.rounds).unwrap();
        let sb = serde_json::to_string(&b.rounds).unwrap();
        assert_eq!(
            sa, sb,
            "identical inputs must produce a byte-identical ledger"
        );
        assert_eq!(a.term.k_atoms(), b.term.k_atoms());
    }
    #[test]
    fn estimation_eval_split_is_disjoint() {
        let target = Array2::<f64>::zeros((20, 3));
        let split = estimation_eval_split(target.view(), 4);
        assert!(!split.estimation_rows.is_empty());
        assert!(!split.shards.is_empty());
        let est: std::collections::HashSet<usize> = split.estimation_rows.iter().copied().collect();
        for shard in &split.shards {
            for &row in &shard.rows {
                assert!(
                    !est.contains(&row),
                    "eval shard row {row} must not be in the estimation set"
                );
            }
        }
    }
}