use ndarray::{Array1, Array2, ArrayView2};
use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use gam_solve::inference::residual_factor::{ResidualFactorInput, StructuredResidualModel};
use gam_solve::structure_search::StructureMove;
use crate::structure_harvest::apply_structure_move;
use super::*;
/// Tuning parameters for `fit_stagewise`.
///
/// The optimiser settings (`learning_rate` and the two ridges) are forwarded
/// unchanged to the fits and to the frozen REML evaluation; the remaining
/// fields control the birth loop, the backfit phase and the residual model.
#[derive(Debug, Clone, Copy)]
pub struct StagewiseConfig {
    /// Iteration budget handed to every `run_joint_fit_arrow_schur` call.
    pub inner_max_iter: usize,
    /// Step size forwarded to the fits and to the REML evaluation.
    pub learning_rate: f64,
    /// Value forwarded as the `ridge_ext_coord` argument of fits and REML evaluation.
    pub ridge_ext_coord: f64,
    /// Value forwarded as the `ridge_beta` argument of joint fits and REML evaluation.
    pub ridge_beta: f64,
    /// The birth loop stops once this many births have been accepted.
    pub max_births: usize,
    /// Upper bound on the number of backfit sweeps run after the birth loop.
    pub max_backfit_sweeps: usize,
    /// Minimum explained-variance gain a birth candidate must deliver to be accepted.
    pub min_effect_ev: f64,
    /// Cap on the residual factor-model rank (it is also capped at `p - 1`, floor 1).
    pub max_factor_rank: usize,
    /// When `true`, each birth round installs the residual model's row metric on the term.
    pub structured_whitening: bool,
}

impl Default for StagewiseConfig {
    /// 64 inner iterations, unit step, `1e-6` ridges, up to 32 births and 4 backfit
    /// sweeps, no minimum EV gain, residual rank cap 4, structured whitening on.
    fn default() -> Self {
        StagewiseConfig {
            inner_max_iter: 64,
            learning_rate: 1.0,
            ridge_ext_coord: 1e-6,
            ridge_beta: 1e-6,
            max_births: 32,
            max_backfit_sweeps: 4,
            min_effect_ev: 0.0,
            max_factor_rank: 4,
            structured_whitening: true,
        }
    }
}
/// Which kind of structural change a birth round proposes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BirthKind {
    /// Append a new atom seeded from the residual structure.
    NewAtom,
    /// Refit the most recently added atom against its partial residual and
    /// re-run the joint fit; the atom count does not change.
    ChartExtension,
}
/// Why the birth loop of `fit_stagewise` ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StagewiseStop {
    /// Two birth rounds in a row produced no acceptable candidate.
    TwoConsecutiveRejections,
    /// `StagewiseConfig::max_births` births were accepted.
    MaxBirths,
    /// No residual factor model could be fitted, or no birth seed could be
    /// derived from the residual.
    NoResidualStructure,
}
/// Outcome of one decided birth round (accepted or rejected).
#[derive(Debug, Clone, Copy)]
pub struct BirthRecord {
    /// Kind of the accepted candidate; rejected rounds are recorded as `NewAtom`.
    pub kind: BirthKind,
    /// EV of the accepted term minus EV before the birth (`0.0` for rejections).
    pub delta_ev: f64,
    /// Energy of the birth seed derived from the residual model this round.
    pub factor_energy: f64,
    /// Frozen joint REML of the term before the birth.
    pub joint_reml_before: f64,
    /// Frozen joint REML after the birth (equal to `joint_reml_before` when rejected).
    pub joint_reml_after: f64,
    /// Whether a candidate was accepted in this round.
    pub accepted: bool,
}
/// Diagnostics collected by `fit_stagewise`.
#[derive(Debug, Clone)]
pub struct StagewiseReport {
    /// Number of accepted births (of either kind).
    pub births_accepted: usize,
    /// Number of birth rounds in which no candidate was accepted.
    pub births_rejected: usize,
    /// One record per decided birth round, in order.
    pub birth_records: Vec<BirthRecord>,
    /// EV of the seed followed by the EV after each accepted birth.
    pub ev_trace: Vec<f64>,
    /// EV after each accepted backfit sweep.
    pub backfit_ev_trace: Vec<f64>,
    /// Why the birth loop ended.
    pub stopped_reason: StagewiseStop,
    /// Frozen joint REML of the returned term.
    pub terminal_joint_reml: f64,
    /// Loss value returned alongside `terminal_joint_reml`.
    pub terminal_joint_loss: SaeManifoldLoss,
}
/// Output of `fit_stagewise`: the grown term, its smoothing state, and the report.
#[derive(Debug, Clone)]
pub struct StagewiseResult {
    /// Final term (guards re-enabled).
    pub term: SaeManifoldTerm,
    /// Smoothing / penalty parameters matching `term`.
    pub rho: SaeManifoldRho,
    /// Birth/backfit diagnostics.
    pub report: StagewiseReport,
}
/// Progress events emitted by `fit_stagewise`, in roughly the order they occur.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StagewiseEventKind {
    /// The seed term is prepared and about to enter the birth loop.
    SeedReady,
    /// A new birth round begins.
    BirthRoundStarted,
    /// The residual factor model is about to be fitted.
    ResidualModelStarted,
    /// The residual model and a birth seed are available.
    ResidualModelFitted,
    /// Frozen joint evidence of the current term is about to be computed.
    CurrentEvidenceStarted,
    /// Frozen joint evidence of the current term is available.
    CurrentEvidenceFinished,
    /// A birth candidate is about to be built.
    CandidateStarted,
    /// A birth candidate was built and scored, or failed to build.
    CandidateFinished,
    /// A candidate was accepted for this round.
    BirthAccepted,
    /// No candidate was accepted for this round.
    BirthRejected,
    /// A backfit sweep begins.
    BackfitSweepStarted,
    /// The backfit sweep raised EV and was kept.
    BackfitSweepAccepted,
    /// The backfit sweep was rolled back.
    BackfitSweepRejected,
    /// Terminal frozen joint evidence of the final term is available.
    TerminalEvidenceCompleted,
}
/// A single progress notification passed to the `fit_stagewise` callback.
///
/// Optional fields are `None` when the event does not carry that quantity.
pub struct StagewiseProgress<'a> {
    /// What happened.
    pub event: StagewiseEventKind,
    /// Zero-based birth round; backfit and terminal events report the total
    /// number of rounds started.
    pub birth_round: usize,
    /// Zero-based backfit sweep (0 during births); the terminal event reports the
    /// number of accepted sweeps.
    pub backfit_sweep: usize,
    /// Candidate kind the event refers to, if any.
    pub candidate: Option<BirthKind>,
    /// Outcome carried by the event, if any.
    pub accepted: Option<bool>,
    /// Set on seed, round-start, birth-decision, sweep-decision and terminal
    /// events. NOTE(review): presumably marks states that are safe to persist.
    pub checkpoint: bool,
    /// Atom count of `term`.
    pub k_atoms: usize,
    /// Accepted births so far.
    pub births_accepted: usize,
    /// Rejected birth rounds so far.
    pub births_rejected: usize,
    /// Explained variance relevant to the event.
    pub ev: Option<f64>,
    /// Energy of the birth seed for the current round.
    pub factor_energy: Option<f64>,
    /// Frozen joint REML before the proposed change.
    pub joint_reml_before: Option<f64>,
    /// Frozen joint REML after the proposed change.
    pub joint_reml_after: Option<f64>,
    /// Frozen joint REML of the final term (terminal event only).
    pub terminal_joint_reml: Option<f64>,
    /// Term the event describes (the candidate's term for a built candidate).
    pub term: &'a SaeManifoldTerm,
    /// Smoothing parameters matching `term`.
    pub rho: &'a SaeManifoldRho,
}
/// Progress sink for `fit_stagewise`: receives every event (borrowing the current
/// term for the duration of the call) and may return `Err` to abort the fit.
pub type StagewiseProgressCallback<'cb> =
    dyn for<'ev> FnMut(StagewiseProgress<'ev>) -> Result<(), String> + 'cb;
/// Forwards `event` to the progress callback when one is installed.
///
/// With no callback this is a no-op returning `Ok(())`; otherwise the
/// callback's result is returned unchanged, so an `Err` aborts the caller.
fn emit_stagewise_progress(
    progress: &mut Option<&mut StagewiseProgressCallback<'_>>,
    event: StagewiseProgress<'_>,
) -> Result<(), String> {
    match progress.as_deref_mut() {
        Some(callback) => callback(event),
        None => Ok(()),
    }
}
/// Returns the reconstruction residual `target - fitted` of the current term.
///
/// # Errors
/// Propagates the error from `SaeManifoldTerm::try_fitted`.
fn current_residual(
    term: &SaeManifoldTerm,
    target: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    let fitted = term.try_fitted()?;
    // Subtract straight from the view; materialising an owned copy of `target`
    // first (as `&target.to_owned() - &fitted` did) costs an extra n×p buffer.
    Ok(&target - &fitted)
}
pub fn frozen_joint_evidence(
term: &mut SaeManifoldTerm,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
registry: Option<&AnalyticPenaltyRegistry>,
config: &StagewiseConfig,
) -> Result<(f64, SaeManifoldLoss), String> {
term.reml_criterion(
target,
rho,
registry,
0,
config.learning_rate,
config.ridge_ext_coord,
config.ridge_beta,
)
}
/// Explained variance of `term`'s reconstruction of `target`.
///
/// Returns `NaN` when either the fitted values or the explained variance
/// cannot be computed, so callers must treat non-finite values as "unknown".
fn ev_of(term: &SaeManifoldTerm, target: ArrayView2<'_, f64>) -> f64 {
    term.try_fitted()
        .map(|fitted| {
            reconstruction_explained_variance(target, fitted.view()).unwrap_or(f64::NAN)
        })
        .unwrap_or(f64::NAN)
}
/// Per-row activity: the sum of each row of the term's assignment matrix.
fn activity_of(term: &SaeManifoldTerm) -> Array1<f64> {
    term.assignment
        .assignments()
        .outer_iter()
        .map(|weights| weights.sum())
        .collect()
}
/// Fits the structured residual factor model to the current residual
/// `target - fitted`.
///
/// The factor rank is `config.max_factor_rank`, capped at `p - 1` and floored
/// at 1. Per-row activity from the term's assignments is passed to the fit.
///
/// Returns `Ok(None)` when the residual has no rows or fewer than two columns,
/// and also when `StructuredResidualModel::fit` fails. That failure is
/// deliberately read as "no residual structure" rather than propagated.
/// Otherwise it returns the residual together with the fitted model.
///
/// # Errors
/// Propagates failures to compute the term's fitted values.
fn fit_residual_covariance(
    term: &SaeManifoldTerm,
    target: ArrayView2<'_, f64>,
    config: &StagewiseConfig,
) -> Result<Option<(Array2<f64>, StructuredResidualModel)>, String> {
    let residual = current_residual(term, target)?;
    let (rows, cols) = residual.dim();
    if rows == 0 || cols < 2 {
        return Ok(None);
    }
    let activity = activity_of(term);
    let rank_cap = config.max_factor_rank.min(cols.saturating_sub(1)).max(1);
    let fitted = StructuredResidualModel::fit(ResidualFactorInput {
        residuals: residual.view(),
        activity: activity.view(),
        max_factor_rank: rank_cap,
    });
    Ok(fitted.ok().map(|model| (residual, model)))
}
/// Fits atom `atom_idx` of `term` on its own against `response` and writes the
/// result back into `term` and `rho`.
///
/// The sub-term is built from copies of the atom, its coordinate block and its
/// logit column:
/// - it uses the parent's assignment mode, row loss weights and row metric;
/// - its guards are disabled;
/// - its sub-`rho` carries the shared `log_lambda_sparse` plus this atom's
///   smoothing and ARD entries, defaulting to `0.0` and an empty ARD vector
///   when `rho` has no entry for the atom.
///
/// The sub-term is then fit with `run_joint_fit_arrow_schur`.
///
/// On success the atom, its coordinates and its logit column are copied back.
/// The atom's smoothing and ARD entries in `rho` are updated only where they
/// exist; `log_lambda_sparse` is not updated. The method then clears
/// `frozen_logits`, `last_row_layout`, `last_frames_active` and
/// `border_hbb_workspace` on `term`. NOTE(review): this presumably invalidates
/// state derived from the previous atom.
///
/// # Errors
/// Fails if `atom_idx` is out of range. It also propagates errors from building
/// the sub-term or from the sub-fit. Nothing in `term` or `rho` is modified
/// before the sub-fit succeeds.
fn fit_single_atom_response_in_place(
    term: &mut SaeManifoldTerm,
    rho: &mut SaeManifoldRho,
    atom_idx: usize,
    response: ArrayView2<'_, f64>,
    registry: Option<&AnalyticPenaltyRegistry>,
    config: &StagewiseConfig,
) -> Result<(), String> {
    let n = term.n_obs();
    let k = term.k_atoms();
    if atom_idx >= k {
        return Err(format!(
            "fit_single_atom_response_in_place: atom {atom_idx} out of range (K={k})"
        ));
    }

    // Build the isolated single-atom problem.
    let sub_logits = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| {
        term.assignment.logits[[row, atom_idx]]
    });
    let sub_assignment = SaeAssignment::with_mode(
        sub_logits,
        vec![term.assignment.coords[atom_idx].clone()],
        term.assignment.mode,
    )?;
    let mut sub_term =
        SaeManifoldTerm::new(vec![term.atoms[atom_idx].clone()], sub_assignment)?;
    sub_term.set_guards_enabled(false);
    if let Some(weights) = term.row_loss_weights() {
        sub_term.set_row_loss_weights(weights.to_vec())?;
    }
    if let Some(metric) = term.row_metric() {
        sub_term.set_row_metric(metric.clone())?;
    }

    let atom_smooth = rho
        .log_lambda_smooth
        .get(atom_idx)
        .copied()
        .unwrap_or(0.0);
    let atom_ard = rho
        .log_ard
        .get(atom_idx)
        .cloned()
        .unwrap_or_else(|| Array1::zeros(0));
    let mut sub_rho = SaeManifoldRho::with_per_atom_smooth(
        rho.log_lambda_sparse,
        vec![atom_smooth],
        vec![atom_ard],
    );

    sub_term.run_joint_fit_arrow_schur(
        response,
        &mut sub_rho,
        registry,
        config.inner_max_iter,
        config.learning_rate,
        config.ridge_ext_coord,
        config.ridge_beta,
    )?;

    // Write the refitted atom back into the parent.
    term.atoms[atom_idx] = sub_term.atoms[0].clone();
    term.assignment.coords[atom_idx] = sub_term.assignment.coords[0].clone();
    for row in 0..n {
        term.assignment.logits[[row, atom_idx]] = sub_term.assignment.logits[[row, 0]];
    }
    if let Some(slot) = rho.log_lambda_smooth.get_mut(atom_idx) {
        *slot = sub_rho.log_lambda_smooth[0];
    }
    if let Some(slot) = rho.log_ard.get_mut(atom_idx) {
        *slot = sub_rho.log_ard[0].clone();
    }

    term.assignment.frozen_logits = None;
    term.last_row_layout = None;
    term.last_frames_active = false;
    term.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
    Ok(())
}
/// Per-row weights favouring rows with low total assignment activity.
///
/// Each row gets `max_activity - activity`, clamped at zero. If no row has
/// positive activity, every row gets weight 1.
fn birth_anchor_weights(term: &SaeManifoldTerm) -> Array1<f64> {
    let activity = activity_of(term);
    let peak = activity.iter().fold(0.0_f64, |acc, &m| acc.max(m));
    if !(peak > 0.0) {
        return Array1::ones(activity.len());
    }
    activity.mapv(|m| (peak - m).max(0.0))
}
/// Chooses one factor of the structured residual model as the decoder seed for a
/// new atom.
///
/// Anchor scoring applies when the routing is uneven, i.e. the anchor weights
/// from `birth_anchor_weights` are non-zero and have one entry per residual row.
/// Each non-zero factor column is then scored by the share of its residual
/// projection energy that falls on low-activity rows, and the best-scoring
/// column wins.
///
/// Otherwise, or when no column could be scored, the dominant-energy column is
/// used: the one with the largest squared norm, ties going to the lowest index.
///
/// Returns `(decoder, energy)`. `decoder` is `basis_size × p`, holds the chosen
/// factor column in row 0 and zeros elsewhere, and `energy` is that column's
/// squared norm.
///
/// Returns `None` when any of these hold:
/// - the model has no factors;
/// - the residual is empty or its width differs from the factor matrix;
/// - the term has no atoms, or its first atom's basis is empty;
/// - every factor column is zero.
fn top_factor_birth_decoder(
    term: &SaeManifoldTerm,
    model: &StructuredResidualModel,
    residual: ArrayView2<'_, f64>,
) -> Option<(Array2<f64>, f64)> {
    let r = model.factor_rank();
    if r == 0 {
        return None;
    }
    let factor = model.factor();
    let p = factor.nrows();
    let (n, p_res) = residual.dim();
    if p_res != p || n == 0 {
        return None;
    }
    // The decoder seed lives in row 0 of the first atom's basis.
    let m = term.atoms.first()?.basis_size();
    if m == 0 {
        return None;
    }
    let column_energy = |j: usize| -> f64 { factor.column(j).iter().map(|v| v * v).sum() };

    let anchor_w = birth_anchor_weights(term);
    // Anchor scoring indexes `anchor_w` per residual row; skip it instead of
    // panicking if the term's row count does not match the residual.
    let use_anchor = anchor_w.len() == n && anchor_w.iter().sum::<f64>() > 0.0;
    let mut anchored: Option<(usize, f64)> = None;
    if use_anchor {
        for j in 0..r {
            let energy = column_energy(j);
            if !(energy > 0.0) {
                continue;
            }
            let col = factor.column(j);
            let inv_norm = 1.0 / energy.sqrt();
            let mut num = 0.0_f64;
            let mut den = 0.0_f64;
            for i in 0..n {
                let mut proj = 0.0_f64;
                for out in 0..p {
                    proj += residual[[i, out]] * col[out];
                }
                let s = (proj * inv_norm) * (proj * inv_norm);
                num += anchor_w[i] * s;
                den += s;
            }
            if den <= 0.0 {
                continue;
            }
            // Fraction of this factor's residual energy on weakly-covered rows.
            let score = num / den;
            if score > anchored.map_or(f64::NEG_INFINITY, |(_, best)| best) {
                anchored = Some((j, score));
            }
        }
    }

    let chosen = match anchored {
        Some((j, _)) => j,
        None => {
            // Dominant-energy fallback; first maximum wins so a tie keeps the
            // lowest-index column.
            let mut dominant: Option<(usize, f64)> = None;
            for j in 0..r {
                let energy = column_energy(j);
                if energy > dominant.map_or(0.0, |(_, best)| best) {
                    dominant = Some((j, energy));
                }
            }
            dominant?.0
        }
    };

    let energy = column_energy(chosen);
    if !(energy > 0.0) {
        return None;
    }
    let mut decoder = Array2::<f64>::zeros((m, p));
    decoder.row_mut(0).assign(&factor.column(chosen));
    Some((decoder, energy))
}
/// Fallback birth seed taken from the residual's principal components.
///
/// 1. Centre the residual columns and form the biased (`1/n`) sample covariance.
/// 2. Eigendecompose the covariance.
/// 3. Estimate the noise level σ² as the median eigenvalue, floored at
///    `f64::MIN_POSITIVE`.
/// 4. Keep eigenvalues strictly above the Marchenko–Pastur edge
///    `σ² (1 + √(p/n))²` as signal.
/// 5. Among the signal eigenvectors, pick the one whose residual projection
///    energy falls most on low-activity rows when the anchor weights are
///    non-trivial. Otherwise pick the largest-eigenvalue eigenvector.
///
/// Returns `(decoder, λ)`. `decoder` is `basis_size × p` and holds `√λ · v` in
/// row 0. `λ` is the chosen eigenvalue.
///
/// Returns `None` when any of these hold:
/// - there are fewer than 2 rows, no columns, or no atoms;
/// - the eigendecomposition fails;
/// - nothing clears the edge.
fn residual_principal_birth_candidate(
    term: &SaeManifoldTerm,
    residual: ArrayView2<'_, f64>,
) -> Option<(Array2<f64>, f64)> {
    let (n, p) = residual.dim();
    if n < 2 || p == 0 || term.atoms.is_empty() {
        return None;
    }
    let rows = n as f64;

    // Column means, accumulated row by row.
    let mut centre = Array1::<f64>::zeros(p);
    for sample in residual.outer_iter() {
        for (acc, &x) in centre.iter_mut().zip(sample.iter()) {
            *acc += x;
        }
    }
    centre.mapv_inplace(|total| total / rows);

    // Biased sample covariance of the centred residual.
    let mut cov = Array2::<f64>::zeros((p, p));
    let mut dev = vec![0.0_f64; p];
    for sample in residual.outer_iter() {
        for ((d, &x), &mu) in dev.iter_mut().zip(sample.iter()).zip(centre.iter()) {
            *d = x - mu;
        }
        for a in 0..p {
            for b in 0..p {
                cov[[a, b]] += dev[a] * dev[b];
            }
        }
    }
    cov.mapv_inplace(|total| total / rows);

    let (evals, evecs) = cov.eigh(Side::Lower).ok()?;
    if evals.is_empty() {
        return None;
    }

    // Noise floor: median eigenvalue, then the Marchenko–Pastur upper edge.
    let mut spectrum: Vec<f64> = evals.iter().copied().collect();
    spectrum.sort_by(|a, b| a.total_cmp(b));
    let half = spectrum.len() / 2;
    let median = if spectrum.len() % 2 == 0 {
        0.5 * (spectrum[half - 1] + spectrum[half])
    } else {
        spectrum[half]
    };
    let sigma2 = median.max(f64::MIN_POSITIVE);
    let mp_edge = sigma2 * (1.0 + (p as f64 / rows).sqrt()).powi(2);

    // Signal directions, largest eigenvalue first (stable order on ties).
    let mut signal: Vec<usize> = (0..evals.len()).filter(|&k| evals[k] > mp_edge).collect();
    signal.sort_by(|&a, &b| evals[b].total_cmp(&evals[a]));
    let mut best = *signal.first()?;

    let anchor_w = birth_anchor_weights(term);
    if anchor_w.iter().sum::<f64>() > 0.0 {
        let mut best_score = f64::NEG_INFINITY;
        for &k in &signal {
            let direction = evecs.column(k);
            let mut weighted = 0.0_f64;
            let mut total = 0.0_f64;
            for i in 0..n {
                let mut proj = 0.0_f64;
                for j in 0..p {
                    proj += residual[[i, j]] * direction[j];
                }
                let si = proj * proj;
                weighted += anchor_w[i] * si;
                total += si;
            }
            if total > 0.0 {
                let score = weighted / total;
                if score > best_score {
                    best_score = score;
                    best = k;
                }
            }
        }
    }

    let energy = evals[best].max(0.0);
    if !(energy > 0.0) {
        return None;
    }
    let amp = energy.sqrt();
    let mut decoder = Array2::<f64>::zeros((term.atoms[0].basis_size(), p));
    for (j, slot) in decoder.row_mut(0).iter_mut().enumerate() {
        *slot = amp * evecs[[j, best]];
    }
    Some((decoder, energy))
}
/// Refits atom `atom_idx` of `term` against its partial residual.
///
/// The partial residual is `target - fitted(rho)` with the atom's own weighted
/// contribution `a_k * decoded_row` added back on every row where its
/// assignment weight `a_k` is non-zero.
///
/// The refit goes through `fit_single_atom_response_in_place` on a scratch
/// clone of `rho`, so any smoothing-parameter updates it makes are discarded.
/// Only `term` changes.
///
/// # Errors
/// Fails if `atom_idx >= term.k_atoms()`. It also propagates errors from the
/// fitted values, the per-row assignment weights and the single-atom fit.
fn refit_single_atom_in_place(
    term: &mut SaeManifoldTerm,
    rho: &SaeManifoldRho,
    atom_idx: usize,
    target: ArrayView2<'_, f64>,
    registry: Option<&AnalyticPenaltyRegistry>,
    config: &StagewiseConfig,
) -> Result<(), String> {
    let n = term.n_obs();
    let p = term.output_dim();
    let k = term.k_atoms();
    if atom_idx >= k {
        return Err(format!(
            "refit_single_atom_in_place: atom {atom_idx} out of range (K={k})"
        ));
    }

    let full = term.try_fitted_for_rho(rho)?;
    let mut partial = target.to_owned() - &full;
    let mut decoded = vec![0.0_f64; p];
    for row in 0..n {
        let a_k = term.assignment.try_assignments_row_for_rho(row, rho)?[atom_idx];
        if a_k != 0.0 {
            term.atoms[atom_idx].fill_decoded_row(row, &mut decoded);
            for (slot, &g) in partial.row_mut(row).iter_mut().zip(decoded.iter()) {
                *slot += a_k * g;
            }
        }
    }

    let mut scratch_rho = rho.clone();
    fit_single_atom_response_in_place(
        term,
        &mut scratch_rho,
        atom_idx,
        partial.view(),
        registry,
        config,
    )
}
fn backfit_sweep(
term: &mut SaeManifoldTerm,
rho: &mut SaeManifoldRho,
target: ArrayView2<'_, f64>,
registry: Option<&AnalyticPenaltyRegistry>,
config: &StagewiseConfig,
) -> Result<(), String> {
term.set_guards_enabled(false);
term.run_fixed_decoder_arrow_schur(
target,
rho,
registry,
1,
config.learning_rate,
config.ridge_ext_coord,
)
.ok();
term.run_joint_fit_arrow_schur(
target,
rho,
registry,
config.inner_max_iter,
config.learning_rate,
config.ridge_ext_coord,
config.ridge_beta,
)?;
Ok(())
}
pub fn fit_stagewise(
seed: SaeManifoldTerm,
mut rho: SaeManifoldRho,
target: ArrayView2<'_, f64>,
registry: Option<&AnalyticPenaltyRegistry>,
sample_weights: Option<&[f64]>,
config: &StagewiseConfig,
mut progress: Option<&mut StagewiseProgressCallback<'_>>,
) -> Result<StagewiseResult, String> {
let n = target.nrows();
if seed.k_atoms() != 1 {
return Err(format!(
"fit_stagewise: seed must be a single-atom (K=1) term; got K={}",
seed.k_atoms()
));
}
if seed.n_obs() != n {
return Err(format!(
"fit_stagewise: seed n_obs {} != target rows {n}",
seed.n_obs()
));
}
let mut term = seed;
term.set_guards_enabled(false);
if let Some(w) = sample_weights {
if w.len() != n {
return Err(format!(
"fit_stagewise: sample_weights length {} != target rows {n}",
w.len()
));
}
term.set_row_loss_weights(w.to_vec())?;
}
let mut ev_trace = vec![ev_of(&term, target)];
let mut birth_records: Vec<BirthRecord> = Vec::new();
let mut births_accepted = 0usize;
let mut births_rejected = 0usize;
let mut consecutive_rejections = 0usize;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::SeedReady,
birth_round: 0,
backfit_sweep: 0,
candidate: None,
accepted: Some(true),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: ev_trace.last().copied(),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let mut birth_round = 0usize;
let stopped_reason = loop {
if births_accepted >= config.max_births {
break StagewiseStop::MaxBirths;
}
if consecutive_rejections >= 2 {
break StagewiseStop::TwoConsecutiveRejections;
}
let round = birth_round;
birth_round += 1;
let entry_ev = ev_of(&term, target);
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BirthRoundStarted,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: None,
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(entry_ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::ResidualModelStarted,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(entry_ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let Some((residual, model)) = fit_residual_covariance(&term, target, config)? else {
break StagewiseStop::NoResidualStructure;
};
let Some((birth_decoder, factor_energy)) =
top_factor_birth_decoder(&term, &model, residual.view())
.or_else(|| residual_principal_birth_candidate(&term, residual.view()))
else {
break StagewiseStop::NoResidualStructure;
};
if config.structured_whitening {
term.set_row_metric(model.row_metric(n)?)?;
}
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::ResidualModelFitted,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(entry_ev),
factor_energy: Some(factor_energy),
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CurrentEvidenceStarted,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(entry_ev),
factor_energy: Some(factor_energy),
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let (cur_reml, _) = frozen_joint_evidence(&mut term, target, &rho, registry, config)?;
let cur_ev = ev_of(&term, target);
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CurrentEvidenceFinished,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateStarted,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::NewAtom),
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let mut cand_a = apply_structure_move(
&term,
&rho,
&StructureMove::Birth { candidate: 0 },
std::slice::from_ref(&birth_decoder),
)
.and_then(|(mut cand_term, mut cand_rho)| {
cand_term.set_guards_enabled(false);
let born = cand_term.k_atoms() - 1;
fit_single_atom_response_in_place(
&mut cand_term,
&mut cand_rho,
born,
residual.view(),
registry,
config,
)?;
let (reml, _) =
frozen_joint_evidence(&mut cand_term, target, &cand_rho, registry, config)?;
let ev = ev_of(&cand_term, target);
Ok((cand_term, cand_rho, reml, ev))
})
.ok();
if let Some((cand_term, cand_rho, reml, ev)) = cand_a.as_ref() {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateFinished,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::NewAtom),
accepted: None,
checkpoint: false,
k_atoms: cand_term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(*ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: Some(*reml),
terminal_joint_reml: None,
term: cand_term,
rho: cand_rho,
},
)?;
} else {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateFinished,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::NewAtom),
accepted: Some(false),
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
}
let mut cand_b = if term.k_atoms() > 1 {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateStarted,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::ChartExtension),
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let last = term.k_atoms() - 1;
let mut cand_term = term.clone();
let mut cand_rho = rho.clone();
let built = (|| -> Result<(SaeManifoldTerm, SaeManifoldRho, f64, f64), String> {
refit_single_atom_in_place(
&mut cand_term,
&cand_rho,
last,
target,
registry,
config,
)?;
cand_term.set_guards_enabled(false);
cand_term.run_joint_fit_arrow_schur(
target,
&mut cand_rho,
registry,
config.inner_max_iter,
config.learning_rate,
config.ridge_ext_coord,
config.ridge_beta,
)?;
let (reml, _) =
frozen_joint_evidence(&mut cand_term, target, &cand_rho, registry, config)?;
let ev = ev_of(&cand_term, target);
Ok((cand_term, cand_rho, reml, ev))
})();
let out = built.ok();
if let Some((cand_term, cand_rho, reml, ev)) = out.as_ref() {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateFinished,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::ChartExtension),
accepted: None,
checkpoint: false,
k_atoms: cand_term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(*ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: Some(*reml),
terminal_joint_reml: None,
term: cand_term,
rho: cand_rho,
},
)?;
} else {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::CandidateFinished,
birth_round: round,
backfit_sweep: 0,
candidate: Some(BirthKind::ChartExtension),
accepted: Some(false),
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
}
out
} else {
None
};
let passes = |reml: f64, ev: f64| -> bool {
reml.is_finite()
&& reml < cur_reml
&& ev.is_finite()
&& (ev - cur_ev) >= config.min_effect_ev
};
let a_ok = cand_a
.as_ref()
.map(|&(_, _, r, e)| passes(r, e))
.unwrap_or(false);
let b_ok = cand_b
.as_ref()
.map(|&(_, _, r, e)| passes(r, e))
.unwrap_or(false);
let choose_a = match (a_ok, b_ok) {
(true, true) => {
let ar = cand_a.as_ref().unwrap().2;
let br = cand_b.as_ref().unwrap().2;
ar <= br
}
(true, false) => true,
(false, true) => false,
(false, false) => {
births_rejected += 1;
consecutive_rejections += 1;
birth_records.push(BirthRecord {
kind: BirthKind::NewAtom,
delta_ev: 0.0,
factor_energy,
joint_reml_before: cur_reml,
joint_reml_after: cur_reml,
accepted: false,
});
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BirthRejected,
birth_round: round,
backfit_sweep: 0,
candidate: None,
accepted: Some(false),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(cur_ev),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: Some(cur_reml),
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
continue;
}
};
let (kind, (cand_term, cand_rho, reml_after, ev_after)) = if choose_a {
(BirthKind::NewAtom, cand_a.take().unwrap())
} else {
(BirthKind::ChartExtension, cand_b.take().unwrap())
};
term = cand_term;
rho = cand_rho;
births_accepted += 1;
consecutive_rejections = 0;
birth_records.push(BirthRecord {
kind,
delta_ev: ev_after - cur_ev,
factor_energy,
joint_reml_before: cur_reml,
joint_reml_after: reml_after,
accepted: true,
});
ev_trace.push(ev_after);
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BirthAccepted,
birth_round: round,
backfit_sweep: 0,
candidate: Some(kind),
accepted: Some(true),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(ev_after),
factor_energy: Some(factor_energy),
joint_reml_before: Some(cur_reml),
joint_reml_after: Some(reml_after),
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
};
let mut backfit_ev_trace: Vec<f64> = Vec::new();
let mut prev_ev = *ev_trace.last().unwrap_or(&f64::NEG_INFINITY);
for sweep in 0..config.max_backfit_sweeps {
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BackfitSweepStarted,
birth_round,
backfit_sweep: sweep,
candidate: None,
accepted: None,
checkpoint: false,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(prev_ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
let term_snapshot = term.clone();
let rho_snapshot = rho.clone();
backfit_sweep(&mut term, &mut rho, target, registry, config)?;
let ev = ev_of(&term, target);
if ev > prev_ev {
backfit_ev_trace.push(ev);
prev_ev = ev;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BackfitSweepAccepted,
birth_round,
backfit_sweep: sweep,
candidate: None,
accepted: Some(true),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
} else {
term = term_snapshot;
rho = rho_snapshot;
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::BackfitSweepRejected,
birth_round,
backfit_sweep: sweep,
candidate: None,
accepted: Some(false),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(prev_ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: None,
terminal_joint_reml: None,
term: &term,
rho: &rho,
},
)?;
break;
}
}
let (terminal_joint_reml, terminal_joint_loss) =
frozen_joint_evidence(&mut term, target, &rho, registry, config)?;
term.set_guards_enabled(true);
emit_stagewise_progress(
&mut progress,
StagewiseProgress {
event: StagewiseEventKind::TerminalEvidenceCompleted,
birth_round,
backfit_sweep: backfit_ev_trace.len(),
candidate: None,
accepted: Some(true),
checkpoint: true,
k_atoms: term.k_atoms(),
births_accepted,
births_rejected,
ev: Some(prev_ev),
factor_energy: None,
joint_reml_before: None,
joint_reml_after: Some(terminal_joint_reml),
terminal_joint_reml: Some(terminal_joint_reml),
term: &term,
rho: &rho,
},
)?;
Ok(StagewiseResult {
term,
rho,
report: StagewiseReport {
births_accepted,
births_rejected,
birth_records,
ev_trace,
backfit_ev_trace,
stopped_reason,
terminal_joint_reml,
terminal_joint_loss,
},
})
}
/// Merges two fitted tiers into one term and scores it with frozen joint evidence.
///
/// The tiers are combined with `SaeManifoldTerm::merge_tiers`. Guards are
/// disabled while the merged term is scored and re-enabled on the returned
/// term. Returns the merged term, its merged `rho`, the joint REML criterion
/// and the accompanying `SaeManifoldLoss`.
///
/// # Errors
/// Propagates failures from merging the tiers or evaluating the evidence.
pub fn terminal_joint_assembly(
    primary: SaeManifoldTerm,
    primary_rho: &SaeManifoldRho,
    secondary: SaeManifoldTerm,
    secondary_rho: &SaeManifoldRho,
    target: ArrayView2<'_, f64>,
    registry: Option<&AnalyticPenaltyRegistry>,
    config: &StagewiseConfig,
) -> Result<(SaeManifoldTerm, SaeManifoldRho, f64, SaeManifoldLoss), String> {
    let (mut joint_term, joint_rho) =
        SaeManifoldTerm::merge_tiers(primary, primary_rho, secondary, secondary_rho)?;
    joint_term.set_guards_enabled(false);
    let (joint_reml, joint_loss) =
        frozen_joint_evidence(&mut joint_term, target, &joint_rho, registry, config)?;
    joint_term.set_guards_enabled(true);
    Ok((joint_term, joint_rho, joint_reml, joint_loss))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::manifold::{
AssignmentMode, PeriodicHarmonicEvaluator, SaeAssignment, SaeAtomBasisKind,
SaeBasisEvaluator, SaeManifoldAtom,
};
use gam_terms::latent::LatentManifold;
use ndarray::Array2;
use std::sync::Arc;
const ON: f64 = 6.0;
const OFF: f64 = -6.0;
fn test_config() -> StagewiseConfig {
StagewiseConfig {
inner_max_iter: 24,
learning_rate: 1.0,
ridge_ext_coord: 1e-6,
ridge_beta: 1e-6,
max_births: 3,
max_backfit_sweeps: 2,
min_effect_ev: 0.0,
max_factor_rank: 3,
structured_whitening: false,
}
}
fn circle_atom(
name: &str,
evaluator: &Arc<PeriodicHarmonicEvaluator>,
coords: &Array2<f64>,
dir_a: usize,
dir_b: usize,
p: usize,
) -> (SaeManifoldAtom, Array2<f64>) {
let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
let mut decoder = Array2::<f64>::zeros((3, p));
decoder[[1, dir_a % p]] = 1.0;
decoder[[2, dir_b % p]] = 1.0;
let atom = SaeManifoldAtom::new(
name.to_string(),
SaeAtomBasisKind::Periodic,
1,
phi,
jet,
decoder,
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_second_jet(evaluator.clone());
(atom, coords.clone())
}
fn build_term(
atoms: Vec<SaeManifoldAtom>,
coord_blocks: Vec<Array2<f64>>,
active: &[Vec<bool>],
) -> (SaeManifoldTerm, SaeManifoldRho) {
let n = active.len();
let k = atoms.len();
let mut logits = Array2::<f64>::zeros((n, k));
for (row, atom_active) in active.iter().enumerate() {
for (atom, &on) in atom_active.iter().enumerate() {
logits[[row, atom]] = if on { ON } else { OFF };
}
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits,
coord_blocks,
vec![LatentManifold::Circle { period: 1.0 }; k],
AssignmentMode::softmax(1.0),
)
.unwrap();
let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); k]);
(term, rho)
}
fn fitted_seed(
mut seed: SaeManifoldTerm,
mut rho: SaeManifoldRho,
target: ArrayView2<'_, f64>,
config: &StagewiseConfig,
) -> (SaeManifoldTerm, SaeManifoldRho) {
seed.set_guards_enabled(false);
seed.run_joint_fit_arrow_schur(
target,
&mut rho,
None,
config.inner_max_iter,
config.learning_rate,
config.ridge_ext_coord,
config.ridge_beta,
)
.expect("test seed K=1 fit must complete before stagewise entry");
(seed, rho)
}
fn is_non_decreasing(xs: &[f64]) -> bool {
xs.windows(2).all(|w| {
let tol = 1e-9 * (1.0 + w[0].abs());
w[1] >= w[0] - tol
})
}
#[test]
fn stagewise_recovers_planted_two_circles_ev_monotone() {
let n = 48usize;
let p = 4usize;
let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);
let (atom1, cb1) = circle_atom("t1", &evaluator, &coords, 2, 3, p);
let active_truth: Vec<Vec<bool>> = (0..n).map(|r| vec![r < n / 2, r >= n / 2]).collect();
let (truth, _truth_rho) = build_term(
vec![atom0.clone(), atom1.clone()],
vec![cb0.clone(), cb1.clone()],
&active_truth,
);
let target = truth.fitted();
let config = test_config();
let (seed, rho) = build_term(vec![atom0], vec![cb0], &vec![vec![true]; n]);
let (seed, rho) = fitted_seed(seed, rho, target.view(), &config);
let result = fit_stagewise(seed, rho, target.view(), None, None, &config, None)
.expect("fit_stagewise must complete on planted two-circles");
assert!(
is_non_decreasing(&result.report.ev_trace),
"EV must be monotone non-decreasing in births by construction; got {:?}",
result.report.ev_trace
);
assert!(
result.report.terminal_joint_reml.is_finite(),
"terminal frozen joint REML must be finite"
);
let seed_ev = result.report.ev_trace[0];
let final_ev = *result.report.ev_trace.last().unwrap();
assert!(
final_ev >= seed_ev - 1e-9,
"final EV {final_ev} must not fall below the seed EV {seed_ev}"
);
assert_eq!(
result.term.k_atoms(),
1 + result.report.births_accepted,
"K must equal the seed atom plus the accepted new-atom births"
);
}
#[test]
fn duplicate_atom_birth_is_rejected() {
let n = 40usize;
let p = 4usize;
let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);
let (truth, _rho) =
build_term(vec![atom0.clone()], vec![cb0.clone()], &vec![vec![true]; n]);
let target = truth.fitted();
let config = StagewiseConfig {
min_effect_ev: 0.01,
..test_config()
};
let (seed, rho) = build_term(vec![atom0], vec![cb0], &vec![vec![true]; n]);
let (seed, rho) = fitted_seed(seed, rho, target.view(), &config);
let result = fit_stagewise(seed, rho, target.view(), None, None, &config, None)
.expect("fit_stagewise must complete on a fully-explained target");
assert_eq!(
result.report.births_accepted, 0,
"a duplicate/empty residual must yield no accepted births"
);
assert_eq!(
result.term.k_atoms(),
1,
"K must stay at the single seed atom"
);
assert!(
is_non_decreasing(&result.report.ev_trace),
"EV trace must remain monotone"
);
}
#[test]
fn anchor_scored_birth_prefers_uncontested_factor_2080() {
    use gam_solve::inference::residual_factor::{ResidualFactorInput, StructuredResidualModel};
    // Fixture: two planted rank-1 factors on disjoint row halves.
    //   rows [0, half): direction d_a (channels 0,1), amplitude 3 -> dominant energy
    //   rows [half, n): direction d_b (channels 2,3), amplitude 2
    // Every entry also gets a small deterministic sinusoidal perturbation.
    let n = 120usize;
    let p = 6usize;
    let half = n / 2;
    let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
    let d_a = [inv_sqrt2, inv_sqrt2, 0.0, 0.0, 0.0, 0.0];
    let d_b = [0.0, 0.0, inv_sqrt2, inv_sqrt2, 0.0, 0.0];
    let residual = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        let wave = (std::f64::consts::TAU * i as f64 / 11.0).cos();
        let (dir, amp) = if i < half { (&d_a, 3.0) } else { (&d_b, 2.0) };
        amp * wave * dir[j] + 0.04 * ((i * 7 + j * 13) as f64).sin()
    });

    // Fit the residual factor model with uniform activity; both factors must be recovered.
    let ones = Array1::<f64>::ones(n);
    let model = StructuredResidualModel::fit(ResidualFactorInput {
        residuals: residual.view(),
        activity: ones.view(),
        max_factor_rank: 2,
    })
    .unwrap();
    assert!(model.factor_rank() >= 2, "need both planted factors");

    // A single circle seed atom on channels 0,1, shared by both routing variants below.
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
    let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);

    // Builds a K=1 IBP-routed term whose per-row logit is supplied by `logit_of`.
    let build_ibp = |logit_of: &dyn Fn(usize) -> f64| -> SaeManifoldTerm {
        let logits = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| logit_of(row));
        let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            logits,
            vec![cb0.clone()],
            vec![LatentManifold::Circle { period: 1.0 }],
            AssignmentMode::ibp_map(1.0, 1.0, false),
        )
        .unwrap();
        SaeManifoldTerm::new(vec![atom0.clone()], assignment).unwrap()
    };

    // Contrast routing: the seed claims the d_a half (high activity) and not the d_b half.
    let contested = build_ibp(&|row| if row < half { 3.0 } else { -3.0 });
    let act = activity_of(&contested);
    assert!(
        act[0] > act[n - 1] + 1e-6,
        "IBP activity must be higher on contested rows (got {} vs {})",
        act[0],
        act[n - 1]
    );

    let (decoder, _energy) =
        top_factor_birth_decoder(&contested, &model, residual.view()).unwrap();
    // Mass of the first decoder row on each planted direction's channel pair.
    let pick_strong = decoder[[0, 0]].hypot(decoder[[0, 1]]);
    let pick_anchor = decoder[[0, 2]].hypot(decoder[[0, 3]]);
    assert!(
        pick_anchor > pick_strong,
        "anchor-scored birth must pick the UNCONTESTED (dB, channels 2,3) factor, not the \
         dominant-variance (dA, channels 0,1) one: |dB|={pick_anchor:.4} |dA|={pick_strong:.4}"
    );

    // Uniform routing: nothing is contested, so the birth falls back to raw energy.
    let uniform = build_ibp(&|_| 0.5);
    let (decoder_u, _e) = top_factor_birth_decoder(&uniform, &model, residual.view()).unwrap();
    let u_strong = decoder_u[[0, 0]].hypot(decoder_u[[0, 1]]);
    let u_anchor = decoder_u[[0, 2]].hypot(decoder_u[[0, 3]]);
    assert!(
        u_strong > u_anchor,
        "uniform routing must fall back to the dominant-energy factor (dA, channels 0,1): \
         |dA|={u_strong:.4} |dB|={u_anchor:.4}"
    );
}
#[test]
fn residual_principal_fallback_fires_on_disjoint_not_noise_2080() {
    let n = 400usize;
    let p = 8usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
    let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);
    let (term, _rho) = build_term(vec![atom0], vec![cb0], &vec![vec![true]; n]);

    // Deterministic 64-bit LCG (Knuth MMIX multiplier/increment).
    // - The top 31 bits of the state are scaled to [0, 1) and then mapped affinely to [-1, 1),
    //   so every draw is zero-mean.
    // - A one-sided [-1, 0) draw would plant a constant-offset rank-1 component in the
    //   "pure noise" matrix below.
    // - It would also couple the two "disjoint" signal blocks through their shared mean.
    let mut state = 0xC0FFEE_1234_5678_u64;
    let mut rng = || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = ((state >> 33) as f64) / ((1u64 << 31) as f64);
        2.0 * unit - 1.0
    };

    // Block-diagonal signal: channels 0,1 share latent `a`, channels 2,3 share latent `b`,
    // plus small isotropic jitter on every channel.
    let mut residual = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let a = rng();
        let b = rng();
        residual[[i, 0]] = 2.0 * a;
        residual[[i, 1]] = 2.0 * a;
        residual[[i, 2]] = 1.5 * b;
        residual[[i, 3]] = 1.5 * b;
        for j in 0..p {
            residual[[i, j]] += 0.03 * rng();
        }
    }
    let got = residual_principal_birth_candidate(&term, residual.view());
    let (decoder, energy) = got.expect(
        "disjoint block-diagonal residual must yield a fallback candidate \
         (structure above the derived MP noise floor)",
    );
    assert!(energy > 0.0 && energy.is_finite());
    // Squared decoder mass on the signal channels (0-3) versus the pure-jitter channels.
    let sig_mass: f64 = (0..4).map(|j| decoder[[0, j]].powi(2)).sum();
    let noise_mass: f64 = (4..p).map(|j| decoder[[0, j]].powi(2)).sum();
    assert!(
        sig_mass > noise_mass,
        "fallback birth direction must land on the signal block (0-3), not noise: \
         sig={sig_mass:.3e} noise={noise_mass:.3e}"
    );

    // Unstructured i.i.d. zero-mean noise: nothing should clear the noise floor.
    let mut noise = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            noise[[i, j]] = rng();
        }
    }
    assert!(
        residual_principal_birth_candidate(&term, noise.view()).is_none(),
        "pure-noise residual must be below the derived MP floor ⇒ no candidate (stop)"
    );
}
#[test]
fn progress_callback_emits_pre_birth_checkpoints() {
    let n = 32usize;
    let p = 4usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
    let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);
    let (atom1, cb1) = circle_atom("t1", &evaluator, &coords, 2, 3, p);
    // Ground truth: atom 0 on the first half of the rows, atom 1 on the second half.
    let active_truth: Vec<Vec<bool>> = (0..n).map(|r| vec![r < n / 2, r >= n / 2]).collect();
    let (truth, _rho) = build_term(
        vec![atom0.clone(), atom1],
        vec![cb0.clone(), cb1],
        &active_truth,
    );
    let target = truth.fitted();
    // One birth round, no backfitting: keeps the event stream short and ordered.
    let config = StagewiseConfig {
        max_births: 1,
        max_backfit_sweeps: 0,
        ..test_config()
    };
    let (seed, rho) = build_term(vec![atom0], vec![cb0], &vec![vec![true]; n]);
    let (seed, rho) = fitted_seed(seed, rho, target.view(), &config);

    // Record (event kind, checkpoint flag, K, candidate kind) for every callback.
    let mut events: Vec<(StagewiseEventKind, bool, usize, Option<BirthKind>)> = Vec::new();
    let mut progress = |event: StagewiseProgress<'_>| -> Result<(), String> {
        events.push((
            event.event,
            event.checkpoint,
            event.k_atoms,
            event.candidate,
        ));
        Ok(())
    };
    fit_stagewise(
        seed,
        rho,
        target.view(),
        None,
        None,
        &config,
        Some(&mut progress),
    )
    .expect("fit_stagewise must complete while emitting progress");

    assert_eq!(
        events.first().map(|event| event.0),
        Some(StagewiseEventKind::SeedReady),
        "the first callback must expose the fitted K=1 seed"
    );
    assert_eq!(
        events.get(1).map(|event| event.0),
        Some(StagewiseEventKind::BirthRoundStarted),
        "the second callback must expose a durable birth-round checkpoint"
    );
    assert!(events[0].1, "seed_ready must be checkpointable");
    assert!(
        events[1].1,
        "birth_round_started must be checkpointable before residual work"
    );
    assert_eq!(events[0].2, 1, "seed checkpoint must be K=1");

    // Index of the first event of `kind`; the panic names the missing kind.
    let pos = |kind: StagewiseEventKind| -> usize {
        events
            .iter()
            .position(|event| event.0 == kind)
            .unwrap_or_else(|| panic!("expected progress event {kind:?}"))
    };
    assert!(
        pos(StagewiseEventKind::ResidualModelStarted)
            < pos(StagewiseEventKind::CurrentEvidenceStarted),
        "residual-fit progress must precede current-evidence progress"
    );
    assert!(
        pos(StagewiseEventKind::CurrentEvidenceStarted)
            < pos(StagewiseEventKind::CandidateStarted),
        "current-evidence progress must precede candidate fitting"
    );
    assert!(
        events
            .iter()
            .any(|event| event.0 == StagewiseEventKind::CandidateStarted
                && event.3 == Some(BirthKind::NewAtom)),
        "first birth must report the new-atom candidate"
    );
}
#[test]
fn backfitting_ev_is_monotone() {
    let n = 48usize;
    let p = 4usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
    let (atom0, cb0) = circle_atom("t0", &evaluator, &coords, 0, 1, p);
    let (atom1, cb1) = circle_atom("t1", &evaluator, &coords, 2, 3, p);
    // Ground truth: atom 0 on the first half of the rows, atom 1 on the second half.
    let active_truth: Vec<Vec<bool>> = (0..n).map(|r| vec![r < n / 2, r >= n / 2]).collect();
    // atom1/cb1 are not needed afterwards, so they are moved rather than cloned;
    // atom0/cb0 are reused for the seed below and must be cloned here.
    let (truth, _rho) = build_term(
        vec![atom0.clone(), atom1],
        vec![cb0.clone(), cb1],
        &active_truth,
    );
    let target = truth.fitted();
    let config = test_config();
    // Start from a K=1 seed active on every row and let the stagewise fit grow it.
    let (seed, rho) = build_term(vec![atom0], vec![cb0], &vec![vec![true]; n]);
    let (seed, rho) = fitted_seed(seed, rho, target.view(), &config);
    let result = fit_stagewise(seed, rho, target.view(), None, None, &config, None)
        .expect("fit_stagewise must complete");
    assert!(
        is_non_decreasing(&result.report.backfit_ev_trace),
        "backfitting EV must be monotone non-decreasing; got {:?}",
        result.report.backfit_ev_trace
    );
}
}