/// Explained variance `1 - SSR / SST` of `fitted` viewed as a reconstruction of `target`.
///
/// `SSR` sums the squared elementwise residuals `target - fitted`; `SST` sums the
/// squared deviations of `target` from its per-column means.
///
/// Returns `None` when the two views differ in shape, when either dimension is zero,
/// when either sum is non-finite, or when `SST` does not exceed `f64::MIN_POSITIVE`
/// (a constant target leaves no variance to explain).
fn reconstruction_explained_variance(
    target: ArrayView2<'_, f64>,
    fitted: ArrayView2<'_, f64>,
) -> Option<f64> {
    let (n, p) = target.dim();
    if fitted.dim() != (n, p) || n == 0 || p == 0 {
        return None;
    }
    // Per-column means, accumulated top-to-bottom.
    let means: Vec<f64> = (0..p)
        .map(|col| (0..n).fold(0.0_f64, |acc, row| acc + target[[row, col]]) / n as f64)
        .collect();
    // Both sums are accumulated in row-major order in a single pass.
    let (ssr, sst) = (0..n)
        .flat_map(|row| (0..p).map(move |col| (row, col)))
        .fold((0.0_f64, 0.0_f64), |(ssr, sst), (row, col)| {
            let observed = target[[row, col]];
            let residual = observed - fitted[[row, col]];
            let centered = observed - means[col];
            (ssr + residual * residual, sst + centered * centered)
        });
    let well_defined = ssr.is_finite() && sst.is_finite() && sst > f64::MIN_POSITIVE;
    well_defined.then(|| 1.0 - ssr / sst)
}
/// Outer (ρ-level) objective wrapped around a [`SaeManifoldTerm`].
///
/// Each evaluation re-solves the term's inner problem at a candidate ρ against
/// `target` and reports the resulting criterion to the outer optimizer.
pub struct SaeManifoldOuterObjective {
    /// Working term; its inner state is overwritten by every evaluation.
    term: SaeManifoldTerm,
    /// Snapshot of the term taken at construction. It is used by `reset`, by the
    /// finite-difference probes, and by the seed re-solve in `into_fitted`.
    baseline_term: SaeManifoldTerm,
    /// Reconstruction target passed to every inner solve.
    target: Array2<f64>,
    /// Optional analytic penalty registry. Its isometry weights are rescaled during
    /// the curvature homotopy walk.
    registry: Option<AnalyticPenaltyRegistry>,
    /// Most recently evaluated ρ (also set at the start of a homotopy entry).
    current_rho: SaeManifoldRho,
    /// ρ supplied at construction; also the template used to decode flat ρ vectors.
    baseline_rho: SaeManifoldRho,
    /// Inner-solver settings forwarded unchanged to each solve.
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
    /// Loss recorded by the most recent successful solve, if any.
    last_loss: Option<SaeManifoldLoss>,
    /// Pending decoder warm start from `seed_inner_state`; consumed by the next evaluation.
    seeded_beta: Option<Array1<f64>>,
}
impl SaeManifoldOuterObjective {
    /// Build an outer objective around `term`, fitted against `target`.
    ///
    /// Clears `term.expected_evidence_gauge_deflated_directions`. It then snapshots the
    /// cleared term and `init_rho` as the baseline that resets, probes and flat-ρ
    /// decoding start from.
    pub fn new(
        mut term: SaeManifoldTerm,
        target: Array2<f64>,
        registry: Option<AnalyticPenaltyRegistry>,
        init_rho: SaeManifoldRho,
        inner_max_iter: usize,
        learning_rate: f64,
        ridge_ext_coord: f64,
        ridge_beta: f64,
    ) -> Self {
        term.expected_evidence_gauge_deflated_directions = None;
        let baseline_term = term.clone();
        let baseline_rho = init_rho.clone();
        Self {
            term,
            baseline_term,
            target,
            registry,
            current_rho: init_rho,
            baseline_rho,
            inner_max_iter,
            learning_rate,
            ridge_ext_coord,
            ridge_beta,
            last_loss: None,
            seeded_beta: None,
        }
    }
    /// Consume the objective and return `(term, ρ, loss)` for the final fit.
    ///
    /// The steps run in this order:
    /// 1. The working term's penalized objective is evaluated at the final ρ.
    /// 2. The construction-time (baseline) term is re-solved starting from that ρ.
    ///    It uses either the streaming in-memory fit or the joint Arrow–Schur fit,
    ///    according to the baseline's streaming plan. The baseline replaces the working
    ///    term when its penalized objective at the final ρ is finite and strictly lower.
    /// 3. Explained-variance safety net. Suppose the kept fit's explained variance is
    ///    below `SAE_PRISTINE_SEED_EV_RETAIN_FLOOR`, while the untouched baseline
    ///    term at the baseline ρ is at or above that floor and better by more than
    ///    `SAE_FINAL_EV_DEGRADATION_TOL`. Then the untouched baseline term, ρ and loss
    ///    are returned instead.
    /// 4. Charts are canonicalized (best effort; a refusal is only logged at debug
    ///    level). When the assignment persists a resolved IBP α, the returned ρ's
    ///    `log_lambda_sparse` is reset to `0.0`.
    ///
    /// If no evaluation ever recorded a loss, an all-zero `SaeManifoldLoss` stands in
    /// for the working term's loss.
    pub fn into_fitted(self) -> (SaeManifoldTerm, SaeManifoldRho, SaeManifoldLoss) {
        let Self {
            term,
            mut baseline_term,
            target,
            registry,
            current_rho,
            baseline_rho,
            inner_max_iter,
            learning_rate,
            ridge_ext_coord,
            ridge_beta,
            last_loss,
            ..
        } = self;
        // Untouched copies for the explained-variance safety net; the seed re-solve
        // below mutates `baseline_term`.
        let pristine_seed_term = baseline_term.clone();
        let pristine_seed_rho = baseline_rho.clone();
        let mut fitted_rho = current_rho;
        let loss = last_loss.unwrap_or_else(|| SaeManifoldLoss {
            data_fit: 0.0,
            assignment_sparsity: 0.0,
            smoothness: 0.0,
            ard: 0.0,
            evidence_gauge_deflated_directions: 0,
        });
        let settled_objective =
            term.penalized_objective_total(target.view(), &fitted_rho, registry.as_ref(), 1.0);
        let mut rho_seed = fitted_rho.clone();
        // Choose the streaming in-memory fit only when the plan streams, the full batch
        // exceeds the in-core budget, and the dense Schur estimate still fits within it.
        // A plan error is carried as a failed seed solve (the seed then cannot win).
        let seed_solve = match baseline_term.streaming_plan().admitted_or_error(
            baseline_term.n_obs(),
            baseline_term.output_dim(),
            baseline_term.k_atoms(),
        ) {
            Ok(plan)
                if plan.streaming
                    && plan.estimated_full_batch_bytes > plan.in_core_budget_bytes
                    && plan.estimated_dense_schur_bytes <= plan.in_core_budget_bytes =>
            {
                baseline_term.fit_streaming_in_memory(
                    target.view(),
                    &mut rho_seed,
                    registry.as_ref(),
                    inner_max_iter,
                    learning_rate,
                    ridge_ext_coord,
                    ridge_beta,
                )
            }
            Ok(_) => baseline_term.run_joint_fit_arrow_schur(
                target.view(),
                &mut rho_seed,
                registry.as_ref(),
                inner_max_iter,
                learning_rate,
                ridge_ext_coord,
                ridge_beta,
            ),
            Err(err) => Err(err),
        };
        let mut seed_won = false;
        if let (Ok(settled_total), Ok(_)) = (&settled_objective, &seed_solve) {
            // Both terms are scored at the same ρ (`fitted_rho`).
            // NOTE(review): the seed solve received `&mut rho_seed`. If it updates ρ, the
            // seed is scored and (if it wins) returned with `fitted_rho` rather than the
            // ρ it was solved at. Confirm this is intended.
            let seed_total = baseline_term.penalized_objective_total(
                target.view(),
                &fitted_rho,
                registry.as_ref(),
                1.0,
            );
            if let Ok(seed_total) = seed_total {
                seed_won = seed_total.is_finite() && seed_total < *settled_total;
            }
        }
        let (mut fitted, mut fitted_loss) = if seed_won {
            let seed_loss = seed_solve.expect("seed_won implies seed_solve is Ok");
            (baseline_term, seed_loss)
        } else {
            (term, loss)
        };
        // Fall back to the untouched baseline when the chosen fit degraded explained
        // variance below the retain floor that the baseline itself clears.
        if let (Ok(seed_fit), Ok(returned_fit)) = (
            pristine_seed_term.try_fitted_for_rho(&pristine_seed_rho),
            fitted.try_fitted_for_rho(&fitted_rho),
        ) && let (Some(seed_ev), Some(returned_ev)) = (
            reconstruction_explained_variance(target.view(), seed_fit.view()),
            reconstruction_explained_variance(target.view(), returned_fit.view()),
        ) && seed_ev >= SAE_PRISTINE_SEED_EV_RETAIN_FLOOR
            && returned_ev < SAE_PRISTINE_SEED_EV_RETAIN_FLOOR
            && returned_ev + SAE_FINAL_EV_DEGRADATION_TOL < seed_ev
            && let Ok(seed_loss) = pristine_seed_term.loss(target.view(), &pristine_seed_rho)
        {
            fitted = pristine_seed_term;
            fitted_rho = pristine_seed_rho;
            fitted_loss = seed_loss;
        }
        if let Err(err) =
            fitted.canonicalize_charts_post_fit(target.view(), &fitted_rho, registry.as_ref())
        {
            log::debug!("into_fitted: chart canonicalization refused: {err}");
        }
        if fitted.assignment.persist_resolved_ibp_alpha(&fitted_rho) {
            fitted_rho.log_lambda_sparse = 0.0;
        }
        (fitted, fitted_rho, fitted_loss)
    }
    /// Cross-check the analytic outer gradient at the current ρ against finite differences.
    ///
    /// The analytic ρ-gradient (with whatever correction is available) is projected
    /// onto a deterministic probe direction `dir`. The criterion is then re-solved at
    /// ρ ± h·dir and ρ ± 2h·dir, and the samples go to `certificate_from_samples`.
    ///
    /// Side effects: re-solves `self.term` at the current ρ. All four probes run on one
    /// clone of the baseline term, so any state one probe leaves on that clone is
    /// seen by the next.
    ///
    /// # Errors
    /// Propagates failures from the solve, the gradient assembly, or any probe.
    pub fn optimality_certificate(&mut self) -> Result<CriterionCertificate, String> {
        let rho_hat_flat = self.current_rho.to_flat();
        let dir = deterministic_probe_direction(rho_hat_flat.view());
        let h = probe_step(rho_hat_flat.view());
        let rho_hat = self.current_rho.clone();
        let (_v_hat, loss_hat, cache) = self.term.reml_criterion_with_cache(
            self.target.view(),
            &rho_hat,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
        )?;
        let solver = self.term.outer_gradient_arrow_solver(&cache)?;
        let components = self.term.analytic_outer_rho_gradient_components(
            self.target.view(),
            &rho_hat,
            &loss_hat,
            &cache,
            &solver,
        )?;
        let grad = components.gradient_with_available_correction();
        let grad_norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        let analytic_directional: f64 = grad.iter().zip(dir.iter()).map(|(g, d)| g * d).sum();
        let mut probe_term = self.baseline_term.clone();
        // Criterion value at ρ̂ + mult·h·dir, solved on `term`.
        let value_at = |term: &mut SaeManifoldTerm, mult: f64| -> Result<f64, String> {
            let flat: Array1<f64> =
                Array1::from_shape_fn(rho_hat_flat.len(), |i| rho_hat_flat[i] + mult * h * dir[i]);
            let rho = self.baseline_rho.from_flat(flat.view());
            let (cost, _loss) = term.reml_criterion(
                self.target.view(),
                &rho,
                self.registry.as_ref(),
                self.inner_max_iter,
                self.learning_rate,
                self.ridge_ext_coord,
                self.ridge_beta,
            )?;
            Ok(cost)
        };
        let plus_h = value_at(&mut probe_term, 1.0)?;
        let minus_h = value_at(&mut probe_term, -1.0)?;
        let plus_2h = value_at(&mut probe_term, 2.0)?;
        let minus_2h = value_at(&mut probe_term, -2.0)?;
        let well_posed = plus_h.is_finite()
            && minus_h.is_finite()
            && plus_2h.is_finite()
            && minus_2h.is_finite();
        let samples = DirectionalSamples {
            plus_h,
            minus_h,
            plus_2h,
            minus_2h,
            step: h,
            grad_norm,
            analytic_directional,
            well_posed,
        };
        Ok(certificate_from_samples(&samples))
    }
    /// Report from the most recent curvature homotopy walk, if one was recorded on the term.
    pub fn curvature_walk_report(&self) -> Option<&CurvatureWalkReport> {
        self.term.curvature_walk_report()
    }
    /// Decoder-shape uncertainty at the current ρ.
    ///
    /// When the streaming plan does not admit a direct log-determinant, a cheaper
    /// summary without decoder covariance is returned. It uses dispersion
    /// `2·data_fit / max(n_obs·output_dim, 1)`, floored at `f64::MIN_POSITIVE`.
    /// Otherwise the term is re-solved and the full uncertainty is assembled from the
    /// factor cache.
    ///
    /// # Errors
    /// Propagates plan-admission, loss, solve, and dispersion failures.
    pub fn decoder_shape_uncertainty(&mut self) -> Result<SaeShapeUncertainty, String> {
        let rho = self.current_rho.clone();
        let plan = self.term.streaming_plan().admitted_or_error(
            self.term.n_obs(),
            self.term.output_dim(),
            self.term.k_atoms(),
        )?;
        if !plan.direct_logdet_admitted() {
            let loss = self.term.loss(self.target.view(), &rho)?;
            let n_scalar = (self.term.n_obs().saturating_mul(self.term.output_dim())).max(1) as f64;
            let dispersion = (2.0 * loss.data_fit / n_scalar).max(f64::MIN_POSITIVE);
            return Ok(self
                .term
                .shape_uncertainty_without_decoder_covariance(dispersion));
        }
        let (_cost, loss, cache) = self.term.reml_criterion_with_cache(
            self.target.view(),
            &rho,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
        )?;
        let dispersion = self.term.reconstruction_dispersion(&loss, &cache, &rho)?;
        self.term.assemble_shape_uncertainty(&cache, dispersion)
    }
    /// Run the curvature homotopy entry at the construction-time ρ.
    pub fn run_curvature_homotopy_entry(&mut self) -> Result<bool, String> {
        let rho = self.baseline_rho.clone();
        self.run_curvature_homotopy_entry_at_rho(&rho)
    }
    /// Curvature homotopy ("walk") entry at a fixed ρ.
    ///
    /// The entry proceeds as follows:
    /// - Sets `current_rho = rho` and computes the linear-span anchor.
    /// - Solves at η = 0 with the isometry weights scaled to 0.
    /// - Walks η up to 1, one step at a time:
    ///   - Predictor: a β update from the η-derivative of the β-gradient, applied
    ///     through the previous factor cache.
    ///   - Corrector: a full inner solve at the next η.
    /// - Steps are halved when the corrector fails, or when a pivot deficit is not
    ///   accepted as gauge. After each accepted step the step size is doubled, capped
    ///   at `CURVATURE_WALK_INITIAL_ETA_STEP`.
    ///
    /// If η = 1 is reached and the reconstruction explained variance is below 0.9, a
    /// decoder least-squares / coordinate-reseed polish is attempted. It is kept only
    /// if explained variance improves; otherwise the snapshot is restored.
    ///
    /// Returns `Ok(true)` when η = 1 was reached without bifurcation. It returns
    /// `Ok(false)` when the anchor or the η = 0 solve failed, or when the walk
    /// bifurcated or exhausted its corrector budget. A `CurvatureWalkReport` is
    /// recorded on the term only when the walk itself ran.
    ///
    /// NOTE(review): an `Err` from `curvature_homotopy_eta_is_inert()?` or the following
    /// `set_homotopy_eta(1.0)?` returns early. That exit leaves η and the isometry
    /// weights at their η = 0 values. Confirm callers tolerate this.
    pub fn run_curvature_homotopy_entry_at_rho(
        &mut self,
        rho: &SaeManifoldRho,
    ) -> Result<bool, String> {
        let rho = rho.clone();
        self.current_rho = rho.clone();
        let isometry_targets = self
            .registry
            .as_ref()
            .map(AnalyticPenaltyRegistry::isometry_scalar_weights)
            .unwrap_or_default();
        self.set_isometry_homotopy_weight(0.0, &isometry_targets);
        let anchor = match linear_span_anchor(&self.term, self.target.view()) {
            Ok(anchor) => anchor,
            Err(err) => {
                log::info!(
                    "[#1007] curvature anchor degenerate ({err}); deferring to seed cascade"
                );
                self.set_isometry_homotopy_weight(1.0, &isometry_targets);
                return Ok(false);
            }
        };
        let anchor_residual_norm_sq = anchor.residual_norm_sq;
        let (_loss0, mut last_cache) = match self.solve_at_eta(&rho, 0.0, &isometry_targets) {
            Ok(pair) => pair,
            Err(err) => {
                log::info!(
                    "[#1007] curvature anchor solve failed at η=0 ({err}); deferring to cascade"
                );
                self.term.set_homotopy_eta(1.0).ok();
                self.set_isometry_homotopy_weight(1.0, &isometry_targets);
                return Ok(false);
            }
        };
        let mut eta = 0.0_f64;
        let mut eta_step = CURVATURE_WALK_INITIAL_ETA_STEP;
        let mut eta_steps = 0usize;
        let mut step_halvings = 0usize;
        let mut total_correctors = 0usize;
        let mut bifurcation: Option<CurvatureBifurcation> = None;
        // With no isometry targets and a term that reports η as inert, there is nothing
        // to walk: jump straight to η = 1.
        if isometry_targets.iter().all(|&target| target == 0.0)
            && self.term.curvature_homotopy_eta_is_inert()?
        {
            self.term.set_homotopy_eta(1.0)?;
            eta = 1.0;
        }
        'walk: while eta < 1.0 {
            let eta_next = (eta + eta_step).min(1.0);
            let d_eta = eta_next - eta;
            // Predictor (best effort): β ← β − d_η · u_β, where u_β is the β-part of the
            // previous cache's inverse applied to (0, ∂_η g_β). It is skipped on any shape
            // mismatch or error, and the result is installed only if every entry is finite.
            // NOTE(review): this β update is not rolled back if the corrector below is
            // rejected; the retry starts from the perturbed / partially solved β.
            if let Ok(dg_beta) = self
                .term
                .curvature_beta_gradient_eta_derivative(self.target.view(), &rho)
                && dg_beta.len() == last_cache.k
            {
                let w_t = Array1::<f64>::zeros(last_cache.delta_t_len());
                if let Ok((_u_t, u_beta)) =
                    last_cache.full_inverse_apply(w_t.view(), dg_beta.view())
                {
                    let mut beta = self.term.flatten_beta();
                    if beta.len() == u_beta.len() {
                        for (b, u) in beta.iter_mut().zip(u_beta.iter()) {
                            *b -= u * d_eta;
                        }
                        if beta.iter().all(|v| v.is_finite()) {
                            self.term.set_flat_beta(beta.view()).ok();
                        }
                    }
                }
            }
            // Corrector: full solve at η_next. On failure, halve the step and restore η and
            // the isometry weights to the last accepted η. At the minimum step, record a
            // bifurcation with min_pivot = 0.0 instead.
            let cache = match self.solve_at_eta(&rho, eta_next, &isometry_targets) {
                Ok((_loss, cache)) => cache,
                Err(err) => {
                    if eta_step <= CURVATURE_WALK_MIN_ETA_STEP {
                        log::info!(
                            "[#1007] curvature corrector failed at η={eta_next:.4} at the minimum \
                             η-step ({err}); recording branch bifurcation"
                        );
                        bifurcation = Some(CurvatureBifurcation {
                            eta: eta_next,
                            min_pivot: 0.0,
                        });
                        break 'walk;
                    }
                    eta_step *= 0.5;
                    step_halvings += 1;
                    self.term.set_homotopy_eta(eta).ok();
                    self.set_isometry_homotopy_weight(eta, &isometry_targets);
                    continue 'walk;
                }
            };
            total_correctors += 1;
            // A minimum pivot below ε·(max pivot) is tolerated only when the outer-gradient
            // solver can still be built on this factor (treated as a gauge deficit).
            // Otherwise the step is halved, or at the minimum step a bifurcation is recorded.
            let pivot = arrow_factor_min_pivot(&cache).min_pivot.unwrap_or(0.0);
            let diag_scale = arrow_factor_max_pivot(&cache).unwrap_or(1.0);
            let floor = f64::EPSILON * diag_scale;
            let pivot_deficit_is_gauge = !(pivot.is_finite() && pivot >= floor)
                && self.term.outer_gradient_arrow_solver(&cache).is_ok();
            if !(pivot.is_finite() && pivot >= floor) && !pivot_deficit_is_gauge {
                if eta_step > CURVATURE_WALK_MIN_ETA_STEP {
                    eta_step *= 0.5;
                    step_halvings += 1;
                    self.term.set_homotopy_eta(eta).ok();
                    self.set_isometry_homotopy_weight(eta, &isometry_targets);
                    continue 'walk;
                }
                log::info!(
                    "[#1007] curvature branch bifurcation at η={eta_next:.4}: min pivot \
                     {pivot:.3e} < floor {floor:.3e}; deferring to seed cascade"
                );
                bifurcation = Some(CurvatureBifurcation {
                    eta: eta_next,
                    min_pivot: pivot,
                });
                break 'walk;
            }
            // Step accepted: advance η and let the step grow back towards its initial size.
            eta = eta_next;
            last_cache = cache;
            eta_steps += 1;
            eta_step = (eta_step * 2.0).min(CURVATURE_WALK_INITIAL_ETA_STEP);
            if total_correctors >= CURVATURE_WALK_MAX_CORRECTORS && eta < 1.0 {
                log::info!(
                    "[#1007] curvature walk hit its corrector budget at η={eta:.4}; deferring to \
                     seed cascade"
                );
                bifurcation = Some(CurvatureBifurcation {
                    eta,
                    min_pivot: pivot,
                });
                break 'walk;
            }
        }
        let arrived = bifurcation.is_none() && eta >= 1.0;
        if !arrived {
            self.term.set_homotopy_eta(1.0).ok();
        }
        self.set_isometry_homotopy_weight(1.0, &isometry_targets);
        // Optional decoder polish after a successful walk with low explained variance.
        // It is kept only if explained variance strictly improves; otherwise the
        // snapshot is restored.
        if arrived
            && let Ok(before_fit) = self.term.try_fitted_for_rho(&rho)
            && let Some(before_ev) =
                reconstruction_explained_variance(self.target.view(), before_fit.view())
            && before_ev < 0.9
        {
            let snapshot = self.term.snapshot_mutable_state();
            let accepted_polish = self
                .term
                .refit_decoder_least_squares_at_current_state(self.target.view(), Some(&rho))
                .and_then(|()| {
                    self.term
                        .seed_coords_by_decoder_projection(self.target.view(), 256)
                })
                .and_then(|()| {
                    self.term.refit_decoder_least_squares_at_current_state(
                        self.target.view(),
                        Some(&rho),
                    )
                })
                .and_then(|()| {
                    let after_fit = self.term.try_fitted_for_rho(&rho)?;
                    let Some(after_ev) =
                        reconstruction_explained_variance(self.target.view(), after_fit.view())
                    else {
                        return Err(
                            "curvature-homotopy decoder LSQ polish produced no EV".to_string()
                        );
                    };
                    if after_ev > before_ev {
                        self.term.loss(self.target.view(), &rho)
                    } else {
                        Err(format!(
                            "curvature-homotopy decoder LSQ polish refused: EV {after_ev:.6} \
                             did not improve from {before_ev:.6}"
                        ))
                    }
                });
            match accepted_polish {
                Ok(loss) => self.last_loss = Some(loss),
                Err(_) => self.term.restore_mutable_state(&snapshot),
            }
        }
        let collapse_events = self.term.collapse_events().len();
        self.term.set_curvature_walk_report(CurvatureWalkReport {
            arrived,
            anchor_residual_norm_sq,
            bifurcation,
            eta_steps,
            step_halvings,
            collapse_events,
            reseeds: 0,
        });
        Ok(arrived)
    }
    /// Set homotopy η (and the matching isometry weights), then solve at `rho`.
    ///
    /// Records the resulting loss in `last_loss` and returns it with the factor cache.
    fn solve_at_eta(
        &mut self,
        rho: &SaeManifoldRho,
        eta: f64,
        isometry_targets: &[f64],
    ) -> Result<(SaeManifoldLoss, ArrowFactorCache), String> {
        self.term.set_homotopy_eta(eta)?;
        self.set_isometry_homotopy_weight(eta, isometry_targets);
        let (_cost, loss, cache) = self.term.reml_criterion_with_cache(
            self.target.view(),
            rho,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
        )?;
        self.last_loss = Some(loss.clone());
        Ok((loss, cache))
    }
    /// Scale the registry's isometry weights to `clamp(eta, 0, 1) · target`.
    ///
    /// This is a no-op when `targets` is empty or there is no registry.
    fn set_isometry_homotopy_weight(&mut self, eta: f64, targets: &[f64]) {
        if targets.is_empty() {
            return;
        }
        if let Some(registry) = self.registry.as_mut() {
            let eta = eta.clamp(0.0, 1.0);
            let weights: Vec<f64> = targets.iter().map(|target| eta * target).collect();
            registry.set_isometry_scalar_weights(&weights);
        }
    }
    /// Add `SAE_FIT_DATA_COLLAPSE_COST` to `cost` when the term records a fit-data
    /// collapse at `rho`; otherwise return `cost` unchanged.
    fn add_fit_data_collapse_penalty(
        &mut self,
        cost: f64,
        rho: &SaeManifoldRho,
    ) -> Result<f64, String> {
        let fitted = self.term.try_fitted_for_rho(rho)?;
        let assignments = self.term.assignment.assignments_for_rho(rho)?;
        let collapsed = self.term.record_fit_data_collapse_if_needed(
            self.target.view(),
            fitted.view(),
            assignments.view(),
            self.inner_max_iter,
        )?;
        if collapsed {
            Ok(cost + SAE_FIT_DATA_COLLAPSE_COST)
        } else {
            Ok(cost)
        }
    }
    /// Whether an inner-solve error message is one a value-only probe may treat as
    /// "infeasible here" rather than fatal.
    ///
    /// The check uses substring matching on the error text, so it depends on those
    /// messages staying verbatim.
    fn is_recoverable_value_probe_refusal(err: &str) -> bool {
        err.contains("inner solve did not converge at fixed ρ")
            || err.contains(
                "undamped evidence factorization hit a non-PD per-row H_tt block before KKT",
            )
    }
    /// Value evaluation at flat ρ with an explicit refine-extension policy.
    ///
    /// Installs a pending warm-start β (if it has the right length), solves, adds
    /// the collapse penalty, and only then commits `current_rho` / `last_loss`.
    /// Returns the cost and the solved β.
    fn evaluate_with_refine_policy(
        &mut self,
        rho_flat: ArrayView1<'_, f64>,
        refine_progress_extension: bool,
    ) -> Result<(f64, Array1<f64>), String> {
        let rho = self.baseline_rho.from_flat(rho_flat);
        if let Some(beta) = self.seeded_beta.take() {
            if beta.len() == self.term.beta_dim() {
                self.term.set_flat_beta(beta.view())?;
            }
        }
        let (cost, loss) = self.term.reml_criterion_with_refine_policy(
            self.target.view(),
            &rho,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
            refine_progress_extension,
        )?;
        let beta_hat = self.term.flatten_beta();
        let cost = self.add_fit_data_collapse_penalty(cost, &rho)?;
        self.current_rho = rho;
        self.last_loss = Some(loss);
        Ok((cost, beta_hat))
    }
    /// EFS-style fixed-point step proposal at flat ρ.
    ///
    /// After solving at ρ, proposes log-scale steps (new log-value minus current):
    /// - index 1 (`log_lambda_smooth`): `λ_new = dispersion · (rank_total − eff_dof) / quad`.
    ///   Proposed only when `quad`, `rank_total − eff_dof` and `λ_smooth` are all positive.
    /// - from index 2 on, per atom and per axis of `log_ard`:
    ///   `α_new = dispersion · n_obs / (sumsq + trace)`. Proposed only when the
    ///   denominator is positive.
    ///
    /// Entries whose update is not finite and positive stay `0.0`; index 0 is never moved.
    /// The reported cost includes the fit-data collapse penalty.
    ///
    /// Note that `current_rho` and `last_loss` are updated before the later fallible
    /// stages (dispersion, traces, ranks, DOF, collapse check).
    fn efs_step(&mut self, rho_flat: ArrayView1<'_, f64>) -> Result<EfsEval, String> {
        let rho = self.baseline_rho.from_flat(rho_flat);
        if let Some(beta) = self.seeded_beta.take()
            && beta.len() == self.term.beta_dim()
        {
            self.term.set_flat_beta(beta.view())?;
        }
        let (cost, loss, cache) = self.term.reml_criterion_with_cache(
            self.target.view(),
            &rho,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
        )?;
        self.current_rho = rho.clone();
        let dispersion = self
            .term
            .reconstruction_dispersion(&loss, &cache, &rho)
            .map_err(|e| format!("SaeManifoldOuterObjective::efs_step: dispersion: {e}"))?;
        self.last_loss = Some(loss);
        let n_obs = self.term.n_obs() as f64;
        let sumsq = self.term.ard_coord_sumsq();
        let traces = self
            .term
            .ard_inverse_traces(&cache)
            .map_err(|e| format!("SaeManifoldOuterObjective::efs_step: ARD traces: {e}"))?;
        let n_params = rho.to_flat().len();
        let mut steps = vec![0.0_f64; n_params];
        // Redundant with the zero-initialised vector; kept to make "index 0 is never
        // moved" explicit (and, as written, it assumes n_params ≥ 1).
        steps[0] = 0.0;
        let lambda_smooth = rho.lambda_smooth();
        let p_out = self.term.output_dim() as f64;
        let mut smooth_rank_total = 0usize;
        for atom in &self.term.atoms {
            smooth_rank_total += SaeManifoldTerm::symmetric_rank(&atom.smooth_penalty)?;
        }
        // Every output column shares each atom's smoothness penalty, hence the p_out factor.
        let rank_total = p_out * (smooth_rank_total as f64);
        let quad = self.term.decoder_smoothness_quadratic_form();
        let eff_dof = self
            .term
            .decoder_smoothness_effective_dof(&cache, lambda_smooth)
            .map_err(|e| format!("SaeManifoldOuterObjective::efs_step: smooth dof: {e}"))?;
        if quad > 0.0 && rank_total - eff_dof > 0.0 && lambda_smooth > 0.0 {
            let lambda_new = dispersion * (rank_total - eff_dof) / quad;
            if lambda_new.is_finite() && lambda_new > 0.0 {
                steps[1] = lambda_new.ln() - rho.log_lambda_smooth;
            }
        }
        // ARD entries follow in flat order: atom by atom, axis by axis, starting at index 2.
        let mut cursor = 2usize;
        for (k, axis_logard) in rho.log_ard.iter().enumerate() {
            let d = axis_logard.len();
            for j in 0..d {
                let denom = sumsq[k][j] + traces[k][j];
                if denom > 0.0 {
                    let alpha_new = dispersion * n_obs / denom;
                    if alpha_new.is_finite() && alpha_new > 0.0 {
                        steps[cursor + j] = alpha_new.ln() - axis_logard[j];
                    }
                }
            }
            cursor += d;
        }
        let beta_hat = self.term.flatten_beta();
        let cost = self.add_fit_data_collapse_penalty(cost, &rho)?;
        Ok(EfsEval {
            cost,
            steps,
            beta: Some(beta_hat),
            psi_gradient: None,
            psi_indices: None,
            inner_hessian_scale: None,
            logdet_enclosure_gap: None,
        })
    }
}
impl OuterObjective for SaeManifoldOuterObjective {
    /// Report what the outer optimizer may request from this objective.
    ///
    /// Analytic ρ-gradients are advertised only when the term's streaming plan admits
    /// the direct path. Hessians and fixed-point iteration are never offered.
    fn capability(&self) -> OuterCapability {
        let direct_path = self.term.streaming_plan().direct_admitted;
        OuterCapability {
            gradient: if direct_path {
                Derivative::Analytic
            } else {
                Derivative::Unavailable
            },
            hessian: DeclaredHessianForm::Unavailable,
            n_params: self.baseline_rho.to_flat().len(),
            psi_dim: 0,
            fixed_point_available: false,
            barrier_config: None,
            prefer_gradient_only: false,
            disable_fixed_point: true,
        }
    }

    /// Value-only evaluation.
    ///
    /// Recoverable inner-solve refusals are reported as an infinite cost instead of
    /// an error; any other failure is surfaced as `RemlOptimizationFailed`.
    fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
        self.evaluate_with_refine_policy(rho.view(), false)
            .map(|(cost, _beta)| cost)
            .or_else(|err| {
                if Self::is_recoverable_value_probe_refusal(&err) {
                    Ok(f64::INFINITY)
                } else {
                    Err(EstimationError::RemlOptimizationFailed(err))
                }
            })
    }

    /// Value plus analytic ρ-gradient at `rho`.
    ///
    /// Consumes any pending warm-start β of the right length first. `current_rho` and
    /// `last_loss` are committed only after every stage has succeeded.
    fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
        let lift = EstimationError::RemlOptimizationFailed;
        let rho_state = self.baseline_rho.from_flat(rho.view());
        if let Some(seed) = self.seeded_beta.take() {
            if seed.len() == self.term.beta_dim() {
                self.term.set_flat_beta(seed.view()).map_err(lift)?;
            }
        }
        let (cost, loss, cache) = self
            .term
            .reml_criterion_with_cache(
                self.target.view(),
                &rho_state,
                self.registry.as_ref(),
                self.inner_max_iter,
                self.learning_rate,
                self.ridge_ext_coord,
                self.ridge_beta,
            )
            .map_err(lift)?;
        let solver = self
            .term
            .outer_gradient_arrow_solver(&cache)
            .map_err(lift)?;
        let components = self
            .term
            .analytic_outer_rho_gradient_components(
                self.target.view(),
                &rho_state,
                &loss,
                &cache,
                &solver,
            )
            .map_err(lift)?;
        let gradient = components.gradient_with_available_correction();
        let inner_beta_hint = Some(self.term.flatten_beta());
        let cost = self
            .add_fit_data_collapse_penalty(cost, &rho_state)
            .map_err(lift)?;
        self.current_rho = rho_state;
        self.last_loss = Some(loss);
        Ok(OuterEval {
            cost,
            gradient,
            hessian: HessianResult::Unavailable,
            inner_beta_hint,
        })
    }

    /// Evaluate to the requested order.
    ///
    /// Value-only requests use the refine-policy path and return a zero gradient (or
    /// an infeasible marker on a recoverable refusal). Higher orders delegate to `eval`.
    fn eval_with_order(
        &mut self,
        rho: &Array1<f64>,
        order: OuterEvalOrder,
    ) -> Result<OuterEval, EstimationError> {
        let value_only = match order {
            OuterEvalOrder::Value => true,
            OuterEvalOrder::ValueAndGradient | OuterEvalOrder::ValueGradientHessian => false,
        };
        if !value_only {
            return self.eval(rho);
        }
        let n_params = rho.len();
        match self.evaluate_with_refine_policy(rho.view(), false) {
            Ok((cost, _beta_hat)) => Ok(OuterEval {
                cost,
                gradient: Array1::zeros(n_params),
                hessian: HessianResult::Unavailable,
                inner_beta_hint: None,
            }),
            Err(err) if Self::is_recoverable_value_probe_refusal(&err) => {
                Ok(OuterEval::infeasible(n_params))
            }
            Err(err) => Err(EstimationError::RemlOptimizationFailed(err)),
        }
    }

    /// Fellner–Schall style step proposal; errors are wrapped as `RemlOptimizationFailed`.
    fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
        let proposal = self.efs_step(rho.view());
        proposal.map_err(EstimationError::RemlOptimizationFailed)
    }

    /// Restore the construction-time term and ρ, and drop any recorded loss or warm start.
    fn reset(&mut self) {
        self.last_loss = None;
        self.seeded_beta = None;
        self.current_rho = self.baseline_rho.clone();
        self.term = self.baseline_term.clone();
    }

    /// Stash a decoder warm start for the next evaluation.
    ///
    /// An empty β reports `NoSlot`. A β whose length differs from the decoder
    /// dimension is rejected. Otherwise the β is stored and consumed by the next
    /// evaluation.
    fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
        match beta.len() {
            0 => Ok(SeedOutcome::NoSlot),
            len if len == self.term.beta_dim() => {
                self.seeded_beta = Some(beta.clone());
                Ok(SeedOutcome::Installed)
            }
            len => Err(EstimationError::RemlOptimizationFailed(format!(
                "SaeManifoldOuterObjective::seed_inner_state: β length {} != decoder dim {}",
                len,
                self.term.beta_dim()
            ))),
        }
    }

    /// A continuation (homotopy) entry is requested for multi-atom terms and for any
    /// Duchon, Euclidean-patch, or Poincaré atom.
    fn requires_continuation_path_entry(&self) -> bool {
        let multi_atom = self.term.k_atoms() >= 2;
        multi_atom
            || self.term.atoms.iter().any(|atom| {
                matches!(
                    atom.basis_kind,
                    SaeAtomBasisKind::Duchon
                        | SaeAtomBasisKind::EuclideanPatch
                        | SaeAtomBasisKind::Poincare
                )
            })
    }

    /// Run the curvature homotopy entry at the decoded flat ρ. This objective always
    /// supports the entry, so the result is always `Some`.
    fn curvature_homotopy_entry(
        &mut self,
        rho: &Array1<f64>,
    ) -> Option<Result<bool, EstimationError>> {
        let rho_state = self.baseline_rho.from_flat(rho.view());
        let outcome = self
            .run_curvature_homotopy_entry_at_rho(&rho_state)
            .map_err(EstimationError::RemlOptimizationFailed);
        Some(outcome)
    }
}
/// Directional decrease `−gᵀδ` of a stacked step `δ = (δ_ext_coord, δ_β)`.
///
/// The coordinate part uses each row's gradient block `row.gt` over its first
/// `row_dims[row]` entries, placed at `row_offsets[row]`. The decoder part uses
/// `sys.gb`.
///
/// # Panics
/// Panics if `delta_ext_coord.len() != sys.row_offsets[sys.rows.len()]` or
/// `delta_beta.len() != sys.k`, and on any out-of-range block index.
fn sae_manifold_newton_directional_decrease(
    sys: &ArrowSchurSystem,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_ext_coord.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);
    // One running sum: all coordinate blocks in row order, then the decoder entries.
    let coord_part = sys
        .rows
        .iter()
        .enumerate()
        .fold(0.0, |acc, (row_idx, row)| {
            let base = sys.row_offsets[row_idx];
            (0..sys.row_dims[row_idx]).fold(acc, |inner, axis| {
                inner + row.gt[axis] * delta_ext_coord[base + axis]
            })
        });
    let total = (0..sys.k).fold(coord_part, |acc, idx| acc + sys.gb[idx] * delta_beta[idx]);
    -total
}
/// Compute `S_k · B_k` for every `(S_k, B_k)` pair in `sb_inputs`, in input order.
///
/// When `symmetrize` is set, each `S_k` is first replaced by `(S_k + S_kᵀ) / 2`.
///
/// If a global GPU runtime exists, pairs sharing the same `(rows(S), cols(B))` shape
/// are batched through `try_fast_abt_strided_batched`. Only groups with at least two
/// members and non-zero dimensions are batched. Every product the GPU path does not
/// deliver is computed on the CPU with `dot`. That covers small groups, failed
/// launches, missing tiles, and everything when no runtime is present.
fn batched_smooth_sb(
    sb_inputs: &[(ArrayView2<'_, f64>, ArrayView2<'_, f64>)],
    symmetrize: bool,
) -> Vec<Array2<f64>> {
    let n_atoms = sb_inputs.len();
    let penalties: Vec<Array2<f64>> = sb_inputs
        .iter()
        .map(|(s, _)| {
            if !symmetrize {
                return s.to_owned();
            }
            let m = s.nrows();
            Array2::from_shape_fn((m, m), |(i, j)| 0.5 * (s[[i, j]] + s[[j, i]]))
        })
        .collect();
    let product_on_cpu = |idx: usize| -> Array2<f64> { penalties[idx].dot(&sb_inputs[idx].1) };
    let Some(rt) = crate::gpu::runtime::GpuRuntime::global() else {
        return (0..n_atoms).map(product_on_cpu).collect();
    };

    // Bucket pairs by (rows of S, cols of B); only same-shaped pairs share a batch.
    let mut groups: std::collections::BTreeMap<(usize, usize), Vec<usize>> =
        std::collections::BTreeMap::new();
    for (idx, (penalty, (_, b))) in penalties.iter().zip(sb_inputs.iter()).enumerate() {
        groups
            .entry((penalty.nrows(), b.ncols()))
            .or_default()
            .push(idx);
    }

    let mut out: Vec<Option<Array2<f64>>> = vec![None; n_atoms];
    for ((m, p), members) in groups {
        if members.len() >= 2 && m > 0 && p > 0 {
            let mut items = members.clone();
            let tile_results: std::sync::Mutex<Vec<(usize, Array2<f64>)>> =
                std::sync::Mutex::new(Vec::with_capacity(members.len()));
            let launched = crate::gpu::pool::scatter_batched(rt, &mut items, |_ordinal, slice| {
                if slice.is_empty() {
                    return Some(());
                }
                let batch = slice.len();
                // a[t] = S and bt[t] = Bᵀ, so the batched A·Btᵀ yields S·B per tile.
                let a = Array3::from_shape_fn((batch, m, m), |(t, i, j)| {
                    penalties[slice[t]][[i, j]]
                });
                let bt = Array3::from_shape_fn((batch, p, m), |(t, i, j)| {
                    sb_inputs[slice[t]].1[[j, i]]
                });
                let prod = crate::gpu::try_fast_abt_strided_batched(a.view(), bt.view())?;
                let mut sink = tile_results.lock().expect("tile_results mutex poisoned");
                sink.extend(
                    slice
                        .iter()
                        .enumerate()
                        .map(|(t, &idx)| (idx, prod.slice(s![t, .., ..]).to_owned())),
                );
                Some(())
            });
            if launched.is_some() {
                let tiles = tile_results
                    .into_inner()
                    .expect("tile_results mutex poisoned");
                for (idx, mat) in tiles {
                    out[idx] = Some(mat);
                }
            }
        }
        // Each index lives in exactly one group, so this fills every member the GPU
        // path did not deliver.
        for &idx in &members {
            if out[idx].is_none() {
                out[idx] = Some(product_on_cpu(idx));
            }
        }
    }
    out.into_iter()
        .enumerate()
        .map(|(idx, slot)| slot.unwrap_or_else(|| product_on_cpu(idx)))
        .collect()
}
/// Point at which the curvature homotopy walk stopped short of η = 1.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CurvatureBifurcation {
    /// η at which the walk gave up. This is the rejected target η for a failed
    /// corrector or pivot deficit, or the last accepted η when the corrector budget
    /// ran out.
    pub eta: f64,
    /// Minimum factor pivot observed there; `0.0` when the corrector solve itself failed.
    pub min_pivot: f64,
}
/// Summary of one curvature homotopy walk, stored on the term after the walk.
#[derive(Debug, Clone)]
pub struct CurvatureWalkReport {
    /// Whether the walk reached η = 1 without a bifurcation.
    pub arrived: bool,
    /// Squared Frobenius norm of the linear-span anchor's final residual.
    pub anchor_residual_norm_sq: f64,
    /// Where and why the walk stopped early, if it did.
    pub bifurcation: Option<CurvatureBifurcation>,
    /// Number of accepted η steps.
    pub eta_steps: usize,
    /// Number of times the η step was halved after a rejected corrector.
    pub step_halvings: usize,
    /// Number of collapse events recorded on the term when the walk finished.
    pub collapse_events: usize,
    /// Re-seed count; the walk itself always records 0 here.
    pub reseeds: usize,
}
/// Per-atom result of `linear_span_anchor`.
#[derive(Debug, Clone)]
pub struct LinearSpanAtomAnchor {
    /// Neutral gate weight used to scale this atom's residual and contribution.
    pub gate_weight: f64,
    /// Oriented frame built from the leading right singular vectors of the gated residual.
    pub frame: GrassmannFrame,
    /// Residual projected onto the frame, divided by the gate (shape n × rank).
    pub decoder_coordinates: Array2<f64>,
    /// Leading `rank` singular values of the gate-scaled residual.
    pub singular_values: Array1<f64>,
}
/// Greedy per-atom linear reconstruction of the targets (see `linear_span_anchor`).
#[derive(Debug, Clone)]
pub struct LinearSpanAnchor {
    /// One entry per atom, in atom order.
    pub atoms: Vec<LinearSpanAtomAnchor>,
    /// Sum of all atoms' contributions (shape n × p).
    pub reconstruction: Array2<f64>,
    /// Squared Frobenius norm of `targets − reconstruction`.
    pub residual_norm_sq: f64,
}
/// Greedy linear-span anchor: deflate `targets` atom by atom along leading singular
/// directions.
///
/// Each atom is processed in order as follows:
/// 1. The current residual is scaled by the atom's neutral gate and decomposed with
///    an SVD.
/// 2. The first `min(basis_size, n, p)` right singular vectors (capped by what the SVD
///    returned) form an oriented frame `F`.
/// 3. The residual's component `residual · F · Fᵀ` is added to the reconstruction and
///    removed from the residual before the next atom.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - the target shape differs from `(n_obs, output_dim)`;
/// - the term has no atoms;
/// - a target value is non-finite;
/// - a neutral gate is non-finite or non-positive;
/// - an atom has zero usable rank;
/// - an SVD fails or returns no right factor.
pub fn linear_span_anchor(
    term: &SaeManifoldTerm,
    targets: ArrayView2<'_, f64>,
) -> Result<LinearSpanAnchor, String> {
    let n = term.n_obs();
    let p = term.output_dim();
    let observed = targets.dim();
    if observed != (n, p) {
        return Err(format!(
            "linear_span_anchor: targets shape {:?} != ({n}, {p})",
            observed
        ));
    }
    let k_atoms = term.k_atoms();
    if k_atoms == 0 {
        return Err("linear_span_anchor: term must contain at least one atom".into());
    }
    if targets.iter().any(|v| !v.is_finite()) {
        return Err("linear_span_anchor: targets must be finite".into());
    }

    let gates = neutral_gate_weights(term.assignment.mode, k_atoms);
    let mut residual = targets.to_owned();
    let mut reconstruction = Array2::<f64>::zeros((n, p));
    let mut atoms = Vec::with_capacity(k_atoms);
    for (atom_idx, atom) in term.atoms.iter().enumerate() {
        let gate = gates[atom_idx];
        if !gate.is_finite() || gate <= 0.0 {
            return Err(format!(
                "linear_span_anchor: neutral gate for atom {atom_idx} must be positive finite; got {gate}"
            ));
        }
        let requested_rank = atom.basis_size().min(n).min(p);
        if requested_rank == 0 {
            return Err(format!(
                "linear_span_anchor: atom {atom_idx} has no recoverable linear span rank"
            ));
        }
        let (_, sigma, vt) = residual
            .mapv(|v| gate * v)
            .svd(false, true)
            .map_err(|err| format!("linear_span_anchor: SVD failed for atom {atom_idx}: {err}"))?;
        let Some(vt) = vt else {
            return Err(format!(
                "linear_span_anchor: SVD returned no right factor for atom {atom_idx}"
            ));
        };
        let rank = requested_rank.min(vt.nrows()).min(sigma.len());
        if rank == 0 {
            return Err(format!(
                "linear_span_anchor: atom {atom_idx} SVD returned rank zero"
            ));
        }
        // Columns of the (p × rank) frame are the leading rows of Vᵀ.
        let oriented = Array2::from_shape_fn((p, rank), |(row, col)| vt[[col, row]]);
        let singular_values = sigma.slice(s![..rank]).to_owned();
        let frame = GrassmannFrame::from_oriented(oriented, singular_values.clone());
        let basis = frame.frame().to_owned();
        let decoder_coordinates = residual.dot(&basis).mapv(|v| v / gate);
        let contribution = fast_abt(&decoder_coordinates, &basis).mapv(|v| gate * v);
        reconstruction += &contribution;
        residual -= &contribution;
        atoms.push(LinearSpanAtomAnchor {
            gate_weight: gate,
            frame,
            decoder_coordinates,
            singular_values,
        });
    }
    let residual_norm_sq = residual.iter().map(|v| v * v).sum();
    Ok(LinearSpanAnchor {
        atoms,
        reconstruction,
        residual_norm_sq,
    })
}
/// Solve `H x = −g` for symmetric positive-definite `H` via a dense Cholesky
/// factorization `H = L Lᵀ`, followed by forward and back substitution.
///
/// Only the lower triangle of `h` (including the diagonal) is read.
///
/// # Errors
/// Returns `Err` on a shape mismatch, on a non-finite or non-positive pivot, or when
/// the solution is non-finite.
fn sae_cholesky_solve_neg_gradient(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    let n = h.nrows();
    if h.ncols() != n || g.len() != n {
        return Err(format!(
            "sae_cholesky_solve_neg_gradient: shape mismatch H={:?}, g={}",
            h.dim(),
            g.len()
        ));
    }
    // Row-by-row factorization: the off-diagonal entries of row i, then its pivot.
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..i {
            let dot = (0..j).fold(h[[i, j]], |acc, k| acc - l[[i, k]] * l[[j, k]]);
            l[[i, j]] = dot / l[[j, j]];
        }
        let sum = (0..i).fold(h[[i, i]], |acc, k| acc - l[[i, k]] * l[[i, k]]);
        if !(sum.is_finite() && sum > 0.0) {
            return Err(format!("non-positive Cholesky pivot at {i}: {sum}"));
        }
        l[[i, i]] = sum.sqrt();
    }
    // Forward substitution: L y = −g.
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let rhs = (0..i).fold(-g[i], |acc, k| acc - l[[i, k]] * y[k]);
        y[i] = rhs / l[[i, i]];
    }
    // Back substitution: Lᵀ x = y.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let rhs = (i + 1..n).fold(y[i], |acc, k| acc - l[[k, i]] * x[k]);
        x[i] = rhs / l[[i, i]];
    }
    if x.iter().any(|v| !v.is_finite()) {
        return Err("sae_cholesky_solve_neg_gradient: non-finite solution".into());
    }
    Ok(x)
}
/// Least-squares transport `X` with `new_phi · X ≈ old_phi`.
///
/// This re-expresses the old basis columns in the new basis (minimum-norm,
/// SVD-truncated; see `solve_design_least_squares`).
fn solve_basis_transport(
    new_phi: ArrayView2<'_, f64>,
    old_phi: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    let transport = solve_design_least_squares(new_phi, old_phi)?;
    Ok(transport)
}
/// Re-express a decoder smoothness penalty after a change of decoder basis.
///
/// Let `T = decoder_transport` (must be square, m × m) and `S = old_smooth_penalty`
/// (must be m × m). Let `T⁺ = solve_design_least_squares(T, I)`, the least-squares
/// inverse of `T`. The result is `fast_atb(T⁺, fast_ab(S, T⁺))`, i.e. `T⁺ᵀ · S · T⁺`
/// per the helpers' naming.
///
/// # Errors
/// Returns `Err` on a non-square transport, a penalty shape mismatch, or a failed
/// inverse solve.
fn transport_smooth_penalty_for_decoder(
    decoder_transport: ArrayView2<'_, f64>,
    old_smooth_penalty: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    let (m, cols) = decoder_transport.dim();
    if cols != m {
        return Err(format!(
            "transport_smooth_penalty_for_decoder: decoder transport must be square; got {:?}",
            decoder_transport.dim()
        ));
    }
    if old_smooth_penalty.dim() != (m, m) {
        return Err(format!(
            "transport_smooth_penalty_for_decoder: smooth penalty shape {:?} != ({m}, {m})",
            old_smooth_penalty.dim()
        ));
    }
    let identity = Array2::<f64>::eye(m);
    let transport_inverse = solve_design_least_squares(decoder_transport, identity.view())?;
    let penalty_times_inverse = fast_ab(&old_smooth_penalty.to_owned(), &transport_inverse);
    Ok(fast_atb(&transport_inverse, &penalty_times_inverse))
}
pub(crate) fn solve_design_least_squares(
design: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
if design.nrows() != rhs.nrows() {
return Err(format!(
"solve_design_least_squares: row mismatch design={} rhs={}",
design.nrows(),
rhs.nrows()
));
}
let (u_opt, sigma, vt_opt) = design
.to_owned()
.svd(true, true)
.map_err(|err| format!("solve_design_least_squares: SVD failed: {err}"))?;
let u = u_opt.ok_or_else(|| "solve_design_least_squares: SVD omitted U".to_string())?;
let vt = vt_opt.ok_or_else(|| "solve_design_least_squares: SVD omitted Vt".to_string())?;
let smax = sigma.iter().fold(0.0_f64, |acc, &v| acc.max(v));
if !(smax.is_finite() && smax > 0.0) {
return Err("solve_design_least_squares: design has zero numerical rank".to_string());
}
let cutoff = smax * f64::EPSILON * (design.nrows().max(design.ncols()) as f64);
let coeffs = u.t().dot(&rhs);
let mut scaled = Array2::<f64>::zeros(coeffs.dim());
for row in 0..sigma.len() {
if sigma[row] > cutoff {
let inv = 1.0 / sigma[row];
for col in 0..rhs.ncols() {
scaled[[row, col]] = inv * coeffs[[row, col]];
}
}
}
Ok(vt.t().dot(&scaled))
}
/// Offset of `atom_idx`'s coordinate block within row `row`.
///
/// With a sparse row layout, the offset is looked up among that row's active atoms.
/// It is `None` when the atom is not active in that row. Without a layout, every row
/// uses the dense offset `dense_off`.
///
/// # Panics
/// Panics if `row` is out of range for the layout's per-row tables.
fn sae_coord_penalty_offset(
    row_layout: Option<&SaeRowLayout>,
    dense_off: usize,
    row: usize,
    atom_idx: usize,
) -> Option<usize> {
    let Some(layout) = row_layout else {
        return Some(dense_off);
    };
    layout.active_atoms[row]
        .iter()
        .zip(layout.coord_starts[row].iter())
        .find(|&(&active_atom, _)| active_atom == atom_idx)
        .map(|(_, &coord_start)| coord_start)
}
fn sae_penalty_is_row_block_supported(penalty: &AnalyticPenaltyKind) -> bool {
matches!(
penalty,
AnalyticPenaltyKind::Ard(_)
| AnalyticPenaltyKind::TopKActivation(_)
| AnalyticPenaltyKind::JumpReLU(_)
| AnalyticPenaltyKind::Sparsity(_)
| AnalyticPenaltyKind::RowPrecisionPrior(_)
| AnalyticPenaltyKind::ParametricRowPrecisionPrior(_)
| AnalyticPenaltyKind::ScadMcp(_)
| AnalyticPenaltyKind::BlockOrthogonality(_)
| AnalyticPenaltyKind::Isometry(_)
)
}
fn sae_coord_penalty_is_origin_anchored_magnitude(penalty: &AnalyticPenaltyKind) -> bool {
matches!(penalty, AnalyticPenaltyKind::ScadMcp(_))
}
/// Restrict latent coordinates to their non-periodic (Euclidean) axes.
///
/// Returns `None` when every axis is already Euclidean, so no restriction is needed.
/// Otherwise returns the Euclidean axis indices together with the coordinates
/// compacted row by row onto those axes. The flat input is indexed as
/// `row * d + axis`, and the output as `row * de + j`.
fn sae_coord_penalty_euclidean_restriction(
    coord: &LatentCoordValues,
) -> Option<(Vec<usize>, Array1<f64>)> {
    let periods = coord.effective_axis_periods();
    let d = periods.len();
    let euclidean_axes: Vec<usize> = (0..d).filter(|&axis| periods[axis].is_none()).collect();
    let de = euclidean_axes.len();
    if de == d {
        return None;
    }
    let n = coord.n_obs();
    let flat = coord.as_flat();
    // When de == 0 the output is empty and the closure never runs (no division by zero).
    let compacted = Array1::from_shape_fn(n * de, |k| {
        let (row, j) = (k / de, k % de);
        flat[row * d + euclidean_axes[j]]
    });
    Some((euclidean_axes, compacted))
}
/// Returns the names of the penalty kinds that support row-block evaluation.
///
/// This is the by-name counterpart of `sae_penalty_is_row_block_supported`; both must
/// describe the same set of kinds, in the same order.
pub fn sae_row_block_penalty_kinds() -> &'static [&'static str] {
    static ROW_BLOCK_KINDS: [&str; 9] = [
        "ard",
        "top_k_activation",
        "jumprelu",
        "sparsity",
        "row_precision_prior",
        "parametric_row_precision_prior",
        "scad_mcp",
        "block_orthogonality",
        "isometry",
    ];
    &ROW_BLOCK_KINDS
}
/// Builds a [`SaeManifoldTerm`] from zero-padded, per-atom stacked arrays.
///
/// Let `K = basis_sizes.len()`. For each atom `k`, with `m = basis_sizes[k]` and
/// `d = latent_dims[k]`, the leading slabs of the padded inputs are copied out:
///
/// * `basis_values[k, ..n_obs, ..m]`
/// * `basis_jacobian[k, ..n_obs, ..m, ..d]`
/// * `decoder_coefficients[k, ..m, ..p_out]`
/// * `smooth_penalties[k, ..m, ..m]`
///
/// Entries beyond those extents are padding and are ignored. When `evaluators` is non-empty,
/// slot `k` (if `Some`) is attached to atom `k` as its basis evaluator. Each atom's latent
/// manifold is `basis_kinds[k].latent_manifold(d)`. The assignment is then built from
/// `logits`, `coords`, those manifolds and `mode`.
///
/// # Errors
///
/// Returns `Err` when:
/// * `latent_dims`, `basis_kinds` or `coords` do not have length `K`;
/// * `evaluators` is neither empty nor of length `K`;
/// * `logits` is not `(n_obs, K)`;
/// * a padded array is too small for the slab some atom needs (previously this panicked
///   inside the slice operation);
/// * atom, assignment or term construction fails.
#[must_use = "build error must be handled"]
pub fn term_from_padded_blocks_with_mode(
    n_obs: usize,
    p_out: usize,
    basis_kinds: &[SaeAtomBasisKind],
    basis_values: ArrayView3<'_, f64>,
    basis_jacobian: ArrayView4<'_, f64>,
    basis_sizes: &[usize],
    latent_dims: &[usize],
    decoder_coefficients: ArrayView3<'_, f64>,
    smooth_penalties: ArrayView3<'_, f64>,
    logits: ArrayView2<'_, f64>,
    coords: &[Array2<f64>],
    mode: AssignmentMode,
    evaluators: &[Option<Arc<dyn SaeBasisEvaluator>>],
) -> Result<SaeManifoldTerm, String> {
    /// Fails unless every axis of `shape` is at least as long as the matching `required` extent.
    fn ensure_extent(
        label: &str,
        atom: usize,
        shape: &[usize],
        required: &[usize],
    ) -> Result<(), String> {
        let fits = shape.len() == required.len()
            && shape
                .iter()
                .zip(required)
                .all(|(&have, &need)| have >= need);
        if fits {
            Ok(())
        } else {
            Err(format!(
                "term_from_padded_blocks: {label} has shape {shape:?}, but atom {atom} needs at least {required:?}"
            ))
        }
    }

    let k_atoms = basis_sizes.len();
    if latent_dims.len() != k_atoms || basis_kinds.len() != k_atoms || coords.len() != k_atoms {
        return Err("term_from_padded_blocks: K-length metadata mismatch".into());
    }
    if !evaluators.is_empty() && evaluators.len() != k_atoms {
        return Err(format!(
            "term_from_padded_blocks: evaluators length {} must equal K={k_atoms} or be empty",
            evaluators.len()
        ));
    }
    if logits.dim() != (n_obs, k_atoms) {
        return Err(format!(
            "term_from_padded_blocks: logits must be ({n_obs}, {k_atoms}); got {:?}",
            logits.dim()
        ));
    }
    let mut atoms = Vec::with_capacity(k_atoms);
    for k in 0..k_atoms {
        let m = basis_sizes[k];
        let d = latent_dims[k];
        // `s![k, 0..a, 0..b, ..]` panics unless k < dim0, a <= dim1, b <= dim2, ...;
        // check exactly those conditions so bad padding surfaces as an error instead.
        ensure_extent("basis_values", k, basis_values.shape(), &[k + 1, n_obs, m])?;
        ensure_extent(
            "basis_jacobian",
            k,
            basis_jacobian.shape(),
            &[k + 1, n_obs, m, d],
        )?;
        ensure_extent(
            "decoder_coefficients",
            k,
            decoder_coefficients.shape(),
            &[k + 1, m, p_out],
        )?;
        ensure_extent("smooth_penalties", k, smooth_penalties.shape(), &[k + 1, m, m])?;
        let phi = basis_values.slice(s![k, 0..n_obs, 0..m]).to_owned();
        let jet = basis_jacobian.slice(s![k, 0..n_obs, 0..m, 0..d]).to_owned();
        let b = decoder_coefficients.slice(s![k, 0..m, 0..p_out]).to_owned();
        let s = smooth_penalties.slice(s![k, 0..m, 0..m]).to_owned();
        let atom = SaeManifoldAtom::new(
            format!("atom_{k}"),
            basis_kinds[k].clone(),
            d,
            phi,
            jet,
            b,
            s,
        )?;
        // An empty `evaluators` slice means "no evaluators": `get` yields None for every k.
        let atom = match evaluators.get(k).and_then(|slot| slot.clone()) {
            Some(evaluator) => atom.with_basis_evaluator(evaluator),
            None => atom,
        };
        atoms.push(atom);
    }
    let manifolds = basis_kinds
        .iter()
        .zip(latent_dims.iter().copied())
        .map(|(kind, d)| kind.latent_manifold(d))
        .collect();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.to_owned(),
        coords.to_vec(),
        manifolds,
        mode,
    )?;
    SaeManifoldTerm::new(atoms, assignment)
}
/// Recomputes an isometry penalty's decoder-derivative caches from one atom.
///
/// Evaluates the atom's basis evaluator at `coords` and contracts the basis jets with the
/// atom's decoder coefficients `B` (indexed `B[basis, output]`), where `d = atom.latent_dim`,
/// `m = atom.basis_size()` and `p = B.ncols()`. It produces:
///
/// * first order (always): an `(n_obs, p * d)` array with
///   `[n, i * d + a] = sum_mm jet[n, mm, a] * B[mm, i]`;
/// * second order (only when `atom.basis_second_jet` is set): an `(n_obs, p * d * d)` array
///   with `[n, (i * d + a) * d + c] = sum_mm hess[n, mm, a, c] * B[mm, i]`;
/// * third order (only when `penalty.duchon_radial_source` is `None` and the evaluator's
///   `third_jet_dyn` returns `Some`): an `(n_obs, p, d * d * d)` array with
///   `[n, i, (a * d + c) * d + e] = sum_mm t3[n, mm, a, c, e] * B[mm, i]`.
///
/// The first- and second-order arrays are handed to `penalty.refresh_caches`. The
/// third-order array goes to `penalty.set_third_decoder_derivative`, which receives `None`
/// when no third jet was computed.
///
/// Returns `Ok(true)` when a second-order cache was installed and `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns `Err` when:
/// * the atom has no basis evaluator;
/// * `penalty.p_out` differs from the decoder's column count;
/// * any jet evaluation fails;
/// * any jet has an unexpected shape.
///
/// All of these are detected before the penalty's caches are touched.
pub fn refresh_isometry_caches_from_atom(
    penalty: &IsometryPenalty,
    atom: &SaeManifoldAtom,
    coords: ArrayView2<'_, f64>,
) -> Result<bool, String> {
    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
        format!(
            "refresh_isometry_caches_from_atom: atom {} has no basis evaluator",
            atom.name
        )
    })?;
    // Only the first jet is used; the basis values themselves are discarded.
    let (_phi, jet) = evaluator.evaluate(coords)?;
    let n_obs = coords.nrows();
    let d = atom.latent_dim;
    let m = atom.basis_size();
    let p = atom.decoder_coefficients.ncols();
    if penalty.p_out != p {
        return Err(format!(
            "refresh_isometry_caches_from_atom: penalty.p_out={} but atom.decoder.cols={p}",
            penalty.p_out
        ));
    }
    if jet.dim() != (n_obs, m, d) {
        return Err(format!(
            "refresh_isometry_caches_from_atom: evaluator first jet has shape {:?}, expected ({n_obs}, {m}, {d})",
            jet.dim()
        ));
    }
    let b = &atom.decoder_coefficients;
    // First-order contraction, flattened row-major over (output i, latent axis a).
    let mut jac = Array2::<f64>::zeros((n_obs, p * d));
    for n in 0..n_obs {
        for i in 0..p {
            for a in 0..d {
                let mut acc = 0.0;
                for mm in 0..m {
                    acc += jet[[n, mm, a]] * b[[mm, i]];
                }
                jac[[n, i * d + a]] = acc;
            }
        }
    }
    // Second-order contraction, only when the atom exposes a second-jet evaluator;
    // flattened row-major over (i, a, c).
    let jac2_opt = if let Some(second_eval) = atom.basis_second_jet.as_ref() {
        let hess = second_eval.second_jet(coords)?;
        if hess.dim() != (n_obs, m, d, d) {
            return Err(format!(
                "refresh_isometry_caches_from_atom: evaluator second jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d})",
                hess.dim()
            ));
        }
        let mut jac2 = Array2::<f64>::zeros((n_obs, p * d * d));
        for n in 0..n_obs {
            for i in 0..p {
                for a in 0..d {
                    for c in 0..d {
                        let mut acc = 0.0;
                        for mm in 0..m {
                            acc += hess[[n, mm, a, c]] * b[[mm, i]];
                        }
                        jac2[[n, (i * d + a) * d + c]] = acc;
                    }
                }
            }
        }
        Some(Arc::new(jac2))
    } else {
        None
    };
    // Third-order contraction is skipped when a Duchon radial source is configured.
    // NOTE(review): presumably that path supplies its own third derivative — confirm.
    // Unlike jac/jac2, the output keeps the output index i as its own axis:
    // shape (n_obs, p, d^3), last axis row-major over (a, c, e).
    let jac3_opt = if penalty.duchon_radial_source.is_none() {
        match evaluator.third_jet_dyn(coords) {
            Some(third) => {
                let t3 = third?;
                if t3.dim() != (n_obs, m, d, d, d) {
                    return Err(format!(
                        "refresh_isometry_caches_from_atom: evaluator third jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d}, {d})",
                        t3.dim()
                    ));
                }
                let mut jac3 = Array3::<f64>::zeros((n_obs, p, d * d * d));
                for n in 0..n_obs {
                    for i in 0..p {
                        for a in 0..d {
                            for c in 0..d {
                                for e in 0..d {
                                    let mut acc = 0.0;
                                    for mm in 0..m {
                                        acc += t3[[n, mm, a, c, e]] * b[[mm, i]];
                                    }
                                    jac3[[n, i, ((a * d) + c) * d + e]] = acc;
                                }
                            }
                        }
                    }
                }
                Some(Arc::new(jac3))
            }
            None => None,
        }
    } else {
        None
    };
    // All fallible work is done; install the caches. The return value reports
    // whether a second-order cache was part of this refresh.
    let installed = jac2_opt.is_some();
    penalty.refresh_caches(Some(Arc::new(jac)), jac2_opt);
    penalty.set_third_decoder_derivative(jac3_opt);
    Ok(installed)
}
/// Refreshes the caches of every latent-dimension-bound isometry penalty in `registry`
/// from the matching atoms of `term`.
///
/// Penalties are paired with atoms by signature `(latent_dim, p_out)`. The j-th isometry
/// penalty with a given signature is paired with the j-th atom, in atom order, that meets
/// all three conditions:
/// * its latent dimension is `latent_dim`;
/// * it has `p_out` decoder columns;
/// * it has a basis evaluator.
///
/// Penalties without `target.latent_dim` are skipped, as are penalties left over once the
/// matching atoms run out. Each pair is refreshed with `coords_per_atom[atom_idx]`.
///
/// Returns how many refreshes installed a second-order cache.
///
/// # Errors
///
/// Returns `Err` if `coords_per_atom` does not have one entry per atom, or if any
/// per-atom refresh fails. In the latter case, earlier pairs have already been refreshed.
pub fn refresh_isometry_caches_from_term(
    registry: &AnalyticPenaltyRegistry,
    term: &SaeManifoldTerm,
    coords_per_atom: &[Array2<f64>],
) -> Result<usize, String> {
    let n_atoms = term.atoms.len();
    if coords_per_atom.len() != n_atoms {
        return Err(format!(
            "refresh_isometry_caches_from_term: coords_per_atom length {} != number of atoms {}",
            coords_per_atom.len(),
            n_atoms
        ));
    }
    // Next unpaired ordinal for each (latent_dim, p_out) signature.
    let mut next_ordinal: std::collections::HashMap<(usize, usize), usize> =
        std::collections::HashMap::new();
    let mut with_second_jet = 0usize;
    let isometries = registry.penalties.iter().filter_map(|entry| match entry {
        AnalyticPenaltyKind::Isometry(penalty) => Some(penalty),
        _ => None,
    });
    for penalty in isometries {
        let Some(latent_dim) = penalty.target.latent_dim else {
            continue;
        };
        let ordinal = next_ordinal
            .entry((latent_dim, penalty.p_out))
            .or_insert(0);
        let candidate = term
            .atoms
            .iter()
            .enumerate()
            .filter(|(_, atom)| {
                atom.latent_dim == latent_dim
                    && atom.decoder_coefficients.ncols() == penalty.p_out
                    && atom.basis_evaluator.is_some()
            })
            .map(|(idx, _)| idx)
            .nth(*ordinal);
        let Some(atom_idx) = candidate else {
            continue;
        };
        *ordinal += 1;
        let installed_second = refresh_isometry_caches_from_atom(
            penalty,
            &term.atoms[atom_idx],
            coords_per_atom[atom_idx].view(),
        )?;
        if installed_second {
            with_second_jet += 1;
        }
    }
    Ok(with_second_jet)
}