use ndarray::{Array1, ArrayView2};
use crate::solver::estimate::reml::continuation::{
ContinuationFailure, ContinuationState, PATH_BUDGET, continue_path_from, fit_with_continuation,
};
use crate::solver::rho_optimizer::{OuterEvalOrder, OuterObjective};
use crate::terms::analytic_penalties::ScalarWeightSchedule;
use crate::terms::sae::manifold::{GumbelTemperatureSchedule, ScheduleKind};
/// Number of evenly spaced waypoints between the heavy-smoothing entry
/// (`s = 1`) and the target (`s = 0`); a fresh walk steps `s` by
/// `1 / CONTINUATION_WAYPOINTS`.
pub const CONTINUATION_WAYPOINTS: usize = 8;
/// How far `s` is raised on every re-entry into a heavier regime: exactly one
/// waypoint width.
pub const REENTRY_BACKOFF: f64 = 1.0 / (CONTINUATION_WAYPOINTS as f64);
/// Twice the waypoint count. Not referenced in this section — presumably a
/// step budget for callers driving the walk (TODO confirm).
pub const CONTINUATION_WALK_BUDGET: usize = CONTINUATION_WAYPOINTS * 2;
/// Evaluation budget handed to each warm-started leg.
pub const WARM_LEG_EVAL_BUDGET: usize = 8;
/// Ceiling on the total evaluations a walk may budget: three times one cold
/// path fit plus one warm leg per waypoint. `ContinuationPath::step` arrives
/// early once its running `evals_budgeted` reaches this and a warm state exists.
pub const WALK_EVAL_CEILING: usize =
    3 * (PATH_BUDGET + WARM_LEG_EVAL_BUDGET * CONTINUATION_WAYPOINTS);
/// Smallest step in `s` the walk will use.
pub const S_STEP_FLOOR: f64 = 1.0 / 256.0;
/// A scalar continuation leg, interpolated linearly between its value at the
/// heavy-smoothing entry (`s = 1`) and its value at the target (`s = 0`).
#[derive(Debug, Clone, Copy)]
pub struct LegEndpoints {
    /// Value of the leg at `s = 1`.
    pub at_entry: f64,
    /// Value of the leg at `s = 0`.
    pub at_target: f64,
}

impl LegEndpoints {
    /// Builds a leg from its entry (`s = 1`) and target (`s = 0`) values.
    #[must_use]
    pub fn new(at_entry: f64, at_target: f64) -> Self {
        Self {
            at_target,
            at_entry,
        }
    }

    /// Value of the leg at homotopy parameter `s`.
    ///
    /// `s` is first clamped to `[0, 1]`, so out-of-range inputs saturate at
    /// the endpoints. A NaN `s` passes through `clamp` unchanged and yields NaN.
    #[must_use]
    pub fn at(&self, s: f64) -> f64 {
        let span = self.at_entry - self.at_target;
        let weight = s.clamp(0.0, 1.0);
        self.at_target + weight * span
    }
}
/// All continuation legs, indexed by one shared homotopy parameter `s ∈ [0, 1]`.
///
/// The legs are the ρ vector, the Gumbel temperature τ and the isometry
/// weight. `s = 1` is the heavy-smoothing entry and `s = 0` the target.
#[derive(Debug, Clone)]
pub struct CoupledSchedules {
    /// ρ at the entry (`s = 1`).
    pub rho_entry: Array1<f64>,
    /// ρ at the target (`s = 0`).
    pub rho_target: Array1<f64>,
    /// Upper bounds on ρ; `ContinuationPath::step` passes these to
    /// `fit_with_continuation` for the cold leg.
    pub rho_bounds_upper: Array1<f64>,
    /// τ schedule; `tau_start` / `tau_min` are the leg's entry / target values.
    pub temperature: GumbelTemperatureSchedule,
    /// Isometry-weight schedule; `w_start` / `w_end` are the entry / target values.
    pub isometry: ScalarWeightSchedule,
}

impl CoupledSchedules {
    /// τ leg: `tau_start` at `s = 1`, `tau_min` at `s = 0`.
    #[must_use]
    pub fn temperature_endpoints(&self) -> LegEndpoints {
        let schedule = &self.temperature;
        LegEndpoints::new(schedule.tau_start, schedule.tau_min)
    }

    /// Isometry leg: `w_start` at `s = 1`, `w_end` at `s = 0`.
    #[must_use]
    pub fn isometry_endpoints(&self) -> LegEndpoints {
        let schedule = &self.isometry;
        LegEndpoints::new(schedule.w_start, schedule.w_end)
    }

    /// Both scalar legs evaluated at `s` (clamped to `[0, 1]`).
    #[must_use]
    pub fn scalar_targets_at(&self, s: f64) -> ScalarLegTargets {
        let tau = self.temperature_endpoints().at(s);
        let isometry_weight = self.isometry_endpoints().at(s);
        ScalarLegTargets {
            tau,
            isometry_weight,
        }
    }

    /// ρ at `s`, interpolated element-wise between `rho_target` (`s = 0`) and
    /// `rho_entry` (`s = 1`). Uses the same formula as [`LegEndpoints::at`].
    ///
    /// # Panics
    /// If `rho_entry` and `rho_target` have different lengths.
    #[must_use]
    pub fn rho_target_at(&self, s: f64) -> Array1<f64> {
        assert_eq!(
            self.rho_entry.len(),
            self.rho_target.len(),
            "ContinuationPath: ρ entry/target dimension mismatch"
        );
        let weight = s.clamp(0.0, 1.0);
        let mut interpolated = self.rho_target.clone();
        interpolated.zip_mut_with(&self.rho_entry, |target, &entry| {
            *target += weight * (entry - *target);
        });
        interpolated
    }
}

/// Values of the scalar legs at one point along the walk.
#[derive(Debug, Clone, Copy)]
pub struct ScalarLegTargets {
    /// Gumbel temperature τ.
    pub tau: f64,
    /// Isometry penalty weight.
    pub isometry_weight: f64,
}
/// Trust region bounding the infinity norm of a step in logit space.
#[derive(Debug, Clone, Copy)]
pub struct LogitTrustRegion {
    /// Largest infinity norm a step may have before it is scaled down.
    pub radius: f64,
}

/// Verdict of [`LogitTrustRegion::cap_step`].
#[derive(Debug, Clone, Copy)]
pub enum LogitStepCap {
    /// Apply the step as-is.
    Within,
    /// Multiply the step by `scale` (`radius / step_inf_norm`) so its infinity
    /// norm equals the radius. For a positive radius, `scale ∈ (0, 1)`.
    Scaled { scale: f64 },
}

impl LogitTrustRegion {
    /// Radius proportional to the temperature: `4·τ`, clamped to `[1e-2, 8]`.
    ///
    /// A cooler τ gives a tighter region. A NaN `tau` yields a NaN radius,
    /// because `f64::clamp` propagates NaN.
    #[must_use]
    pub fn for_tau(tau: f64) -> Self {
        const LOGIT_TR_TAU_GAIN: f64 = 4.0;
        const LOGIT_TR_MIN: f64 = 1.0e-2;
        const LOGIT_TR_MAX: f64 = 8.0;
        Self {
            radius: (LOGIT_TR_TAU_GAIN * tau).clamp(LOGIT_TR_MIN, LOGIT_TR_MAX),
        }
    }

    /// Decides whether a step with the given infinity norm must be shrunk.
    ///
    /// Returns `Within` in these cases:
    /// - the norm is non-finite (NaN or ±∞) — it is passed through untouched.
    ///   Scaling an infinite step by `radius / ∞ = 0` would only turn its
    ///   infinite entries into NaN;
    /// - the norm is exactly zero, which also avoids dividing by zero when
    ///   `radius` is non-positive;
    /// - the norm is within the radius.
    ///
    /// Otherwise returns `Scaled { scale: radius / step_inf_norm }`.
    #[must_use]
    pub fn cap_step(&self, step_inf_norm: f64) -> LogitStepCap {
        let leave_alone = !step_inf_norm.is_finite()
            || step_inf_norm == 0.0
            || step_inf_norm <= self.radius;
        if leave_alone {
            return LogitStepCap::Within;
        }
        LogitStepCap::Scaled {
            scale: self.radius / step_inf_norm,
        }
    }
}
/// Lower bound on the mean active routing mass, e.g. as computed by
/// [`mean_active_mass`].
#[derive(Debug, Clone, Copy)]
pub struct ActiveMassFloor {
    /// Smallest mean active mass considered healthy (inclusive).
    pub floor: f64,
}

impl ActiveMassFloor {
    /// Floor used by [`ActiveMassFloor::default_floor`].
    pub const DEFAULT_FLOOR: f64 = 0.1;

    /// A floor at [`ActiveMassFloor::DEFAULT_FLOOR`].
    #[must_use]
    pub fn default_floor() -> Self {
        Self {
            floor: Self::DEFAULT_FLOOR,
        }
    }

    /// Returns `None` when `mean_active_mass` is finite and at least `floor`.
    /// Otherwise returns a breach record carrying the observed value.
    ///
    /// Non-finite observations (NaN, ±∞) always count as breaches.
    #[must_use]
    pub fn check(&self, mean_active_mass: f64) -> Option<MassFloorBreach> {
        let healthy = mean_active_mass.is_finite() && mean_active_mass >= self.floor;
        (!healthy).then_some(MassFloorBreach {
            observed_mean_mass: mean_active_mass,
            floor: self.floor,
        })
    }
}

/// A single violation of an [`ActiveMassFloor`].
#[derive(Debug, Clone, Copy)]
pub struct MassFloorBreach {
    /// The mean active mass that triggered the breach.
    pub observed_mean_mass: f64,
    /// The floor in force at the time.
    pub floor: f64,
}

/// Append-only log of mass-floor breaches, in the order they were recorded.
#[derive(Debug, Clone, Default)]
pub struct ReseedLedger {
    entries: Vec<ReseedEvent>,
}

/// One ledger entry: the path parameter `s` at which `breach` was recorded.
#[derive(Debug, Clone, Copy)]
pub struct ReseedEvent {
    pub s: f64,
    pub breach: MassFloorBreach,
}

impl ReseedLedger {
    /// An empty ledger.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a breach observed at path parameter `s`.
    pub fn record(&mut self, s: f64, breach: MassFloorBreach) {
        let event = ReseedEvent { s, breach };
        self.entries.push(event);
    }

    /// Number of breaches recorded so far.
    #[must_use]
    pub fn reseed_count(&self) -> usize {
        self.events().len()
    }

    /// All recorded events, oldest first.
    #[must_use]
    pub fn events(&self) -> &[ReseedEvent] {
        self.entries.as_slice()
    }
}
/// Coarse band of the homotopy parameter `s` (after clamping to `[0, 1]`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathRegime {
    /// `s ≤ 0.25` — at or near the target objective. A NaN `s` also lands here.
    Target,
    /// `0.25 < s ≤ 0.75`.
    Annealing,
    /// `s > 0.75` — the heavy-smoothing regime.
    Heavy,
}

impl PathRegime {
    /// Classifies `s` into a band; boundaries belong to the lighter band.
    #[must_use]
    fn from_s(s: f64) -> Self {
        match s.clamp(0.0, 1.0) {
            clamped if clamped > 0.75 => PathRegime::Heavy,
            clamped if clamped > 0.25 => PathRegime::Annealing,
            _ => PathRegime::Target,
        }
    }
}
/// Why a caller is demoting the path to a heavier regime.
/// `ContinuationPath::demote_with_reason` currently treats both reasons the same way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathDemotionReason {
    UniformStructural,
    PrewarmStructural,
}
/// Outcome of one continuation step.
///
/// There is intentionally no "rejected" outcome: every step either moves
/// toward the target or backs off to a heavier regime.
#[derive(Debug, Clone)]
pub(crate) enum ContinuationStep {
    /// The spine converged at waypoint `s > 0`; `state` is the converged state.
    Descended { s: f64, state: ContinuationState },
    /// The walk reached `s = 0`. When the eval ceiling forces arrival, `state`
    /// is the most recent converged waypoint state, not a solve at the target.
    Arrived { state: ContinuationState },
    /// The walk backed off to the heavier parameter `s` for `reason`.
    Reentered { s: f64, reason: ReentryReason },
}
/// Why a step backed off to a heavier regime.
#[derive(Debug, Clone)]
pub(crate) enum ReentryReason {
    /// The ρ spine solve (`continue_path_from` / `fit_with_continuation`) failed.
    SpineStruggled(ContinuationFailure),
    /// The active routing mass fell below its floor.
    MassFloorBreached(MassFloorBreach),
    /// `s_step` fell below `S_STEP_FLOOR`.
    StepUnderflow,
}
/// Homotopy walk from the heavy-smoothing entry (`s = 1`) to the target (`s = 0`).
///
/// The ρ vector and the scalar τ / isometry legs move in lockstep. Whenever a
/// leg fails or a guard trips, the walk backs off to a heavier `s`.
#[derive(Debug, Clone)]
pub struct ContinuationPath {
    /// Endpoints of every leg. The scalar schedules' `iter_count`s are also
    /// advanced as the walk proceeds.
    schedules: CoupledSchedules,
    /// Current homotopy parameter in `[0, 1]`: 1 = entry, 0 = target.
    s: f64,
    /// Size of the next attempted decrease of `s`.
    s_step: f64,
    /// Logit trust region derived from the τ at the current/next waypoint.
    logit_tr: LogitTrustRegion,
    /// Floor used when recording bare breaches via `note_active_mass_breach`.
    mass_floor: ActiveMassFloor,
    /// Breaches recorded through `note_active_mass_breach`.
    reseed_ledger: ReseedLedger,
    /// Last successfully converged spine state; warm-starts the next leg.
    warm: Option<ContinuationState>,
    /// Running total of the evaluation budgets handed to spine solves.
    evals_budgeted: usize,
}
impl ContinuationPath {
    /// Starts a walk at `s = 1` with a full-width step, the default mass floor,
    /// an empty ledger, no warm state, and a logit trust region sized for the
    /// entry τ.
    #[must_use]
    pub fn enter(schedules: CoupledSchedules) -> Self {
        let entry_targets = schedules.scalar_targets_at(1.0);
        let logit_tr = LogitTrustRegion::for_tau(entry_targets.tau);
        Self {
            schedules,
            s: 1.0,
            s_step: 1.0 / CONTINUATION_WAYPOINTS as f64,
            logit_tr,
            mass_floor: ActiveMassFloor::default_floor(),
            reseed_ledger: ReseedLedger::new(),
            warm: None,
            evals_budgeted: 0,
        }
    }
    /// Enters with the built-in one-dimensional default schedules.
    #[must_use]
    pub fn heavy_entry() -> Self {
        Self::enter(default_coupled_schedules())
    }
    /// Enters with the default scalar schedules and a ρ leg pinned at
    /// `rho_target`. Because ρ's entry and target coincide, only τ and the
    /// isometry weight are annealed.
    ///
    /// # Panics
    /// If `rho_target` and `bounds_upper` differ in length.
    #[must_use]
    pub fn heavy_entry_for_rho(rho_target: Array1<f64>, bounds_upper: Array1<f64>) -> Self {
        assert_eq!(
            rho_target.len(),
            bounds_upper.len(),
            "ContinuationPath::heavy_entry_for_rho: ρ target/bounds dim mismatch"
        );
        let schedules = couple_schedules(
            rho_target.clone(),
            rho_target,
            bounds_upper,
            default_temperature_schedule(),
            default_isometry_schedule(),
        );
        Self::enter(schedules)
    }
    /// Regime band of the current `s`.
    #[must_use]
    pub fn enter_regime(&self) -> PathRegime {
        PathRegime::from_s(self.s)
    }
    /// Backs off to a heavier regime and returns the regime re-entered.
    /// Every reason is handled identically, and nothing is written to the
    /// reseed ledger.
    pub fn demote_with_reason(&mut self, reason: PathDemotionReason) -> PathRegime {
        match reason {
            PathDemotionReason::UniformStructural | PathDemotionReason::PrewarmStructural => {
                self.reenter_heavier();
            }
        }
        self.enter_regime()
    }
    /// Current logit trust-region radius.
    #[must_use]
    pub fn logit_step_radius(&self) -> f64 {
        self.logit_tr.radius
    }
    /// Records a mass-floor breach in the path's own ledger, backs off to a
    /// heavier regime, and returns the regime re-entered.
    ///
    /// NOTE(review): no observed mass is available here, so `observed_mean_mass`
    /// is filled with the floor itself. `ActiveMassFloor::check` would *not*
    /// flag that value, so consumers of the ledger events should not treat the
    /// recorded mass as a real observation.
    pub fn note_active_mass_breach(&mut self) -> PathRegime {
        let breach = MassFloorBreach {
            observed_mean_mass: self.mass_floor.floor,
            floor: self.mass_floor.floor,
        };
        // Move the ledger out so `note_mass_breach` can borrow `self` and the
        // ledger mutably at the same time; it is put back immediately after.
        let mut owned = std::mem::take(&mut self.reseed_ledger);
        let step = self.note_mass_breach(breach, &mut owned);
        self.reseed_ledger = owned;
        // Every outcome maps to the current regime. Matching exhaustively (no
        // `_` arm) means a new `ContinuationStep` variant forces a revisit here.
        match step {
            ContinuationStep::Reentered { .. }
            | ContinuationStep::Descended { .. }
            | ContinuationStep::Arrived { .. } => self.enter_regime(),
        }
    }
    /// Number of breaches recorded in the path's own ledger.
    #[must_use]
    pub fn reseed_count(&self) -> usize {
        self.reseed_ledger.reseed_count()
    }
    /// Current homotopy parameter.
    #[must_use]
    pub fn s(&self) -> f64 {
        self.s
    }
    /// Scalar-leg values at the current `s`.
    #[must_use]
    pub fn current_scalar_targets(&self) -> ScalarLegTargets {
        self.schedules.scalar_targets_at(self.s)
    }
    /// ρ at the current `s`.
    ///
    /// # Panics
    /// If the ρ entry and target vectors differ in length.
    #[must_use]
    pub fn current_rho_target(&self) -> Array1<f64> {
        self.schedules.rho_target_at(self.s)
    }
    /// Current logit trust region.
    #[must_use]
    pub fn logit_trust_region(&self) -> LogitTrustRegion {
        self.logit_tr
    }
    /// Active-mass floor in force.
    #[must_use]
    pub fn active_mass_floor(&self) -> ActiveMassFloor {
        self.mass_floor
    }
    /// Records `breach` at the current `s` in the *caller-supplied* `ledger`,
    /// backs off to a heavier regime, and reports the re-entry.
    pub(crate) fn note_mass_breach(
        &mut self,
        breach: MassFloorBreach,
        ledger: &mut ReseedLedger,
    ) -> ContinuationStep {
        ledger.record(self.s, breach);
        self.reenter_heavier();
        ContinuationStep::Reentered {
            s: self.s,
            reason: ReentryReason::MassFloorBreached(breach),
        }
    }
    /// Backs off to a heavier regime:
    /// - raises `s` by `REENTRY_BACKOFF` (capped at 1);
    /// - halves `s_step` (floored at `S_STEP_FLOOR`);
    /// - re-derives the logit trust region from the τ at the new `s`.
    fn reenter_heavier(&mut self) {
        self.s = (self.s + REENTRY_BACKOFF).min(1.0);
        self.s_step = (self.s_step * 0.5).max(S_STEP_FLOOR);
        self.logit_tr = LogitTrustRegion::for_tau(self.schedules.scalar_targets_at(self.s).tau);
    }
    /// `true` once `s` has reached 0.
    #[must_use]
    pub fn arrived(&self) -> bool {
        self.s <= 0.0
    }
    /// Attempts one decrease of `s` by `s_step` and solves the ρ spine there.
    ///
    /// - If `evals_budgeted` has reached `WALK_EVAL_CEILING` *and* a warm state
    ///   exists, the walk arrives immediately with that state. Without a warm
    ///   state the ceiling is not enforced.
    /// - The first (cold) leg calls `fit_with_continuation` starting from
    ///   `initial_beta` and is charged `PATH_BUDGET`. Later legs warm-start with
    ///   `continue_path_from` and are charged `WARM_LEG_EVAL_BUDGET`.
    /// - On success the step commits `s`. It then resets `s_step` to one
    ///   waypoint width, capped by the remaining distance and floored at
    ///   `S_STEP_FLOOR`. It returns `Arrived` at `s = 0`, `Descended` otherwise.
    /// - On failure it re-enters a heavier regime.
    pub(crate) fn step(
        &mut self,
        obj: &mut dyn OuterObjective,
        initial_beta: &Array1<f64>,
    ) -> ContinuationStep {
        // Eval ceiling: only enforceable when there is a converged state to
        // hand back. That state is the most recent waypoint, not a target solve.
        if self.evals_budgeted >= WALK_EVAL_CEILING {
            if let Some(state) = self.warm.clone() {
                log::warn!(
                    "[PATH] walk eval ceiling {WALK_EVAL_CEILING} reached at s={:.4}; arriving \
                     with the best converged waypoint state (scalar legs advanced to target)",
                    self.s
                );
                self.advance_scalar_legs_to(0.0);
                self.s = 0.0;
                return ContinuationStep::Arrived { state };
            }
        }
        // Defensive guard: `enter`, `reenter_heavier` and the success branch all
        // keep `s_step >= S_STEP_FLOOR`, so this fires only if `s_step` is set
        // some other way.
        if self.s_step < S_STEP_FLOOR {
            self.reenter_heavier();
            return ContinuationStep::Reentered {
                s: self.s,
                reason: ReentryReason::StepUnderflow,
            };
        }
        let s_next = (self.s - self.s_step).max(0.0);
        // NOTE(review): the scalar schedules are advanced before the solve, and
        // their `iter_count`s are not rewound if the solve fails below.
        self.advance_scalar_legs_to(s_next);
        let rho_target = self.schedules.rho_target_at(s_next);
        // The budget is charged up front, whether or not the solve succeeds.
        let spine = match self.warm.clone() {
            Some(start) => {
                self.evals_budgeted += WARM_LEG_EVAL_BUDGET;
                continue_path_from(
                    obj,
                    start,
                    &rho_target,
                    OuterEvalOrder::ValueAndGradient,
                    WARM_LEG_EVAL_BUDGET,
                )
            }
            None => {
                self.evals_budgeted += PATH_BUDGET;
                fit_with_continuation(
                    obj,
                    &rho_target,
                    &self.schedules.rho_bounds_upper,
                    initial_beta,
                    OuterEvalOrder::ValueAndGradient,
                )
            }
        };
        match spine {
            Ok(state) => {
                self.warm = Some(state.clone());
                self.s = s_next;
                // Success resets to a full waypoint step; any earlier halving from
                // re-entries is discarded.
                self.s_step = (1.0 / CONTINUATION_WAYPOINTS as f64).min(self.s.max(S_STEP_FLOOR));
                self.logit_tr =
                    LogitTrustRegion::for_tau(self.schedules.scalar_targets_at(self.s).tau);
                if self.s <= 0.0 {
                    ContinuationStep::Arrived { state }
                } else {
                    ContinuationStep::Descended { s: self.s, state }
                }
            }
            Err(failure) => {
                self.reenter_heavier();
                ContinuationStep::Reentered {
                    s: self.s,
                    reason: ReentryReason::SpineStruggled(failure),
                }
            }
        }
    }
    /// Moves both scalar schedules' iteration counters forward toward their
    /// values at `s_next`, and sizes the logit trust region for the τ at `s_next`.
    fn advance_scalar_legs_to(&mut self, s_next: f64) {
        let targets = self.schedules.scalar_targets_at(s_next);
        Self::advance_temperature_to(&mut self.schedules.temperature, targets.tau);
        Self::advance_isometry_to(&mut self.schedules.isometry, targets.isometry_weight);
        self.logit_tr = LogitTrustRegion::for_tau(targets.tau);
    }
    /// Increments `iter_count` while the schedule's τ is above `target_tau`.
    /// The scan is bounded by `temperature_scan_budget` and never rewinds.
    fn advance_temperature_to(schedule: &mut GumbelTemperatureSchedule, target_tau: f64) {
        let max_scan = temperature_scan_budget(schedule);
        let mut scanned = 0;
        while scanned < max_scan && schedule.current_tau(schedule.iter_count) > target_tau {
            schedule.iter_count += 1;
            scanned += 1;
        }
    }
    /// Increments `iter_count` while the schedule's weight is below
    /// `target_weight`. The scan is bounded by `isometry_scan_budget` and
    /// never rewinds.
    fn advance_isometry_to(schedule: &mut ScalarWeightSchedule, target_weight: f64) {
        let max_scan = isometry_scan_budget(schedule);
        let mut scanned = 0;
        while scanned < max_scan && schedule.current_weight(schedule.iter_count) < target_weight {
            schedule.iter_count += 1;
            scanned += 1;
        }
    }
}
fn temperature_scan_budget(schedule: &GumbelTemperatureSchedule) -> usize {
const GEOMETRIC_SCAN_CAP: usize = 4096;
match &schedule.decay {
ScheduleKind::Linear { steps } => *steps + 1,
ScheduleKind::Geometric { .. } | ScheduleKind::ReciprocalIter => GEOMETRIC_SCAN_CAP,
}
}
fn isometry_scan_budget(schedule: &ScalarWeightSchedule) -> usize {
const GEOMETRIC_SCAN_CAP: usize = 4096;
match &schedule.kind {
ScheduleKind::Linear { steps } => *steps + 1,
ScheduleKind::Geometric { .. } | ScheduleKind::ReciprocalIter => GEOMETRIC_SCAN_CAP,
}
}
/// Bundles the ρ endpoints, the ρ upper bounds and the two scalar schedules
/// into one [`CoupledSchedules`].
///
/// No validation happens here. Mismatched ρ entry/target lengths are caught
/// later by `CoupledSchedules::rho_target_at`, which asserts on them.
#[must_use]
pub fn couple_schedules(
    rho_entry: Array1<f64>,
    rho_target: Array1<f64>,
    rho_bounds_upper: Array1<f64>,
    temperature: GumbelTemperatureSchedule,
    isometry: ScalarWeightSchedule,
) -> CoupledSchedules {
    CoupledSchedules {
        temperature,
        isometry,
        rho_bounds_upper,
        rho_entry,
        rho_target,
    }
}
#[must_use]
fn default_coupled_schedules() -> CoupledSchedules {
const DEFAULT_ENTRY_RHO: f64 = 5.0;
const DEFAULT_RHO_UPPER: f64 = 10.0;
couple_schedules(
Array1::from_elem(1, DEFAULT_ENTRY_RHO),
Array1::zeros(1),
Array1::from_elem(1, DEFAULT_RHO_UPPER),
default_temperature_schedule(),
default_isometry_schedule(),
)
}
/// Default τ schedule: `ScheduleKind::Linear` over `CONTINUATION_WAYPOINTS`
/// steps, from τ = 4.0 at entry to τ = 0.5 at the target.
///
/// # Panics
/// Only if `GumbelTemperatureSchedule::new` rejects these constants.
#[must_use]
fn default_temperature_schedule() -> GumbelTemperatureSchedule {
    const DEFAULT_ENTRY_TAU: f64 = 4.0;
    const DEFAULT_TARGET_TAU: f64 = 0.5;
    let decay = ScheduleKind::Linear {
        steps: CONTINUATION_WAYPOINTS,
    };
    GumbelTemperatureSchedule::new(DEFAULT_ENTRY_TAU, DEFAULT_TARGET_TAU, decay)
        .expect("default continuation temperature schedule must be valid")
}
/// Default isometry-weight schedule: `ScheduleKind::Linear` over
/// `CONTINUATION_WAYPOINTS` steps, from weight 0.0 at entry to 1.0 at the
/// target.
///
/// # Panics
/// Only if `ScalarWeightSchedule::new` rejects these constants.
#[must_use]
fn default_isometry_schedule() -> ScalarWeightSchedule {
    const DEFAULT_ENTRY_ISOMETRY: f64 = 0.0;
    const DEFAULT_TARGET_ISOMETRY: f64 = 1.0;
    let kind = ScheduleKind::Linear {
        steps: CONTINUATION_WAYPOINTS,
    };
    ScalarWeightSchedule::new(DEFAULT_ENTRY_ISOMETRY, DEFAULT_TARGET_ISOMETRY, kind)
        .expect("default continuation isometry schedule must be valid")
}
/// Mean, over rows, of each row's largest entry.
///
/// - Returns `0.0` for a matrix with no rows.
/// - `f64::max` skips NaN entries whenever a non-NaN value is present.
/// - A row whose maximum is not finite still counts in the denominator but
///   contributes zero. This covers rows with no columns, all-NaN rows, and
///   rows containing `+∞`.
///
/// NOTE(review): presumably rows are samples and columns soft routing
/// weights, so values near 1 mean confident routing — confirm with callers.
#[must_use]
pub fn mean_active_mass(assignments: ArrayView2<'_, f64>) -> f64 {
    let row_count = assignments.nrows();
    if row_count == 0 {
        return 0.0;
    }
    let total = assignments
        .rows()
        .into_iter()
        .map(|row| row.iter().copied().fold(f64::NEG_INFINITY, f64::max))
        .filter(|peak| peak.is_finite())
        // Explicit fold from +0.0 (not `sum`) keeps the empty-sum result at +0.0.
        .fold(0.0, |acc, peak| acc + peak);
    total / row_count as f64
}
#[cfg(test)]
mod tests {
use super::*;
fn lin_temp() -> GumbelTemperatureSchedule {
GumbelTemperatureSchedule::new(2.0, 0.1, ScheduleKind::Linear { steps: 8 })
.expect("valid temperature schedule")
}
fn lin_iso() -> ScalarWeightSchedule {
ScalarWeightSchedule::new(0.01, 1.0, ScheduleKind::Linear { steps: 8 })
.expect("valid isometry schedule")
}
fn schedules() -> CoupledSchedules {
couple_schedules(
Array1::from_vec(vec![5.0, 5.0]),
Array1::from_vec(vec![0.0, 0.0]),
Array1::from_vec(vec![10.0, 10.0]),
lin_temp(),
lin_iso(),
)
}
#[test]
fn entry_is_the_heavy_smoothing_regime() {
let path = ContinuationPath::enter(schedules());
assert_eq!(
path.s(),
1.0,
"entry must be s = 1 (heavy-smoothing regime)"
);
let targets = path.current_scalar_targets();
assert!((targets.tau - 2.0).abs() < 1e-12, "entry τ = tau_start");
assert!(
(targets.isometry_weight - 0.01).abs() < 1e-12,
"entry isometry = w_start"
);
let rho = path.current_rho_target();
assert!((rho[0] - 5.0).abs() < 1e-12 && (rho[1] - 5.0).abs() < 1e-12);
}
#[test]
fn target_endpoint_is_the_real_objective() {
let sch = schedules();
let targets0 = sch.scalar_targets_at(0.0);
assert!(
(targets0.tau - 0.1).abs() < 1e-12,
"s=0 τ = tau_min (sharp)"
);
assert!(
(targets0.isometry_weight - 1.0).abs() < 1e-12,
"s=0 isometry = w_end (tight)"
);
let rho0 = sch.rho_target_at(0.0);
assert!(
(rho0[0]).abs() < 1e-12 && (rho0[1]).abs() < 1e-12,
"s=0 ρ = ρ*"
);
}
#[test]
fn legs_move_in_lockstep_along_s() {
let sch = schedules();
let mid = sch.scalar_targets_at(0.5);
assert!((mid.tau - (0.1 + 0.5 * (2.0 - 0.1))).abs() < 1e-12);
assert!((mid.isometry_weight - (0.01 + 0.5 * (1.0 - 0.01))).abs() < 1e-12);
let rho_mid = sch.rho_target_at(0.5);
assert!((rho_mid[0] - 2.5).abs() < 1e-12);
}
#[test]
fn logit_trust_region_tightens_as_tau_cools() {
let hot = LogitTrustRegion::for_tau(2.0);
let cold = LogitTrustRegion::for_tau(0.05);
assert!(
cold.radius < hot.radius,
"colder τ must give a tighter logit trust region"
);
assert!(matches!(
cold.cap_step(cold.radius * 0.5),
LogitStepCap::Within
));
match cold.cap_step(cold.radius * 4.0) {
LogitStepCap::Scaled { scale } => {
assert!(scale > 0.0 && scale < 1.0);
assert!((scale - 0.25).abs() < 1e-12);
}
LogitStepCap::Within => panic!("expected the over-radius step to be scaled"),
}
}
#[test]
fn active_mass_floor_breach_is_recorded_never_fatal() {
let floor = ActiveMassFloor::default_floor();
assert!(floor.check(0.9).is_none(), "healthy routing → no breach");
let breach = floor.check(0.05).expect("collapsed routing → breach");
let mut ledger = ReseedLedger::new();
ledger.record(0.3, breach);
assert_eq!(ledger.reseed_count(), 1);
assert!((ledger.events()[0].s - 0.3).abs() < 1e-12);
}
#[test]
fn note_mass_breach_reenters_heavier_and_logs() {
let mut path = ContinuationPath::enter(schedules());
path.s = 0.5;
let mut ledger = ReseedLedger::new();
let breach = MassFloorBreach {
observed_mean_mass: 0.05,
floor: ActiveMassFloor::DEFAULT_FLOOR,
};
let step = path.note_mass_breach(breach, &mut ledger);
assert!(matches!(
step,
ContinuationStep::Reentered {
reason: ReentryReason::MassFloorBreached(_),
..
}
));
assert!(
path.s() > 0.5,
"re-entry must raise s toward the entry regime"
);
assert_eq!(ledger.reseed_count(), 1);
}
#[test]
fn continuation_step_has_no_reject_arm() {
fn is_progress(step: &ContinuationStep) -> bool {
match step {
ContinuationStep::Descended { .. }
| ContinuationStep::Arrived { .. }
| ContinuationStep::Reentered { .. } => true,
}
}
let breach = MassFloorBreach {
observed_mean_mass: 0.0,
floor: 0.2,
};
assert!(is_progress(&ContinuationStep::Reentered {
s: 1.0,
reason: ReentryReason::MassFloorBreached(breach),
}));
assert!(is_progress(&ContinuationStep::Reentered {
s: 1.0,
reason: ReentryReason::StepUnderflow,
}));
}
#[test]
fn mean_active_mass_distinguishes_routed_from_saddle() {
use ndarray::array;
let routed = array![[0.95, 0.05], [0.9, 0.1]];
let saddle = array![[0.5, 0.5], [0.5, 0.5]];
assert!(mean_active_mass(routed.view()) > 0.85);
assert!((mean_active_mass(saddle.view()) - 0.5).abs() < 1e-12);
assert!(
ActiveMassFloor::default_floor()
.check(mean_active_mass(saddle.view()))
.is_none(),
"uniform 0.5 is above the floor — saddle detection is about \
collapse below the failure boundary (0.5× the planted healthy \
mass), not the healthy operating point"
);
}
#[test]
fn heavy_entry_starts_in_the_heavy_regime() {
let path = ContinuationPath::heavy_entry();
assert_eq!(path.s(), 1.0, "heavy_entry must enter at s = 1");
assert_eq!(
path.enter_regime(),
PathRegime::Heavy,
"a fresh heavy_entry is in the heavy-smoothing regime"
);
assert!(
path.logit_step_radius().is_finite() && path.logit_step_radius() > 0.0,
"logit step radius must be finite and positive at entry"
);
}
#[test]
fn demote_with_reason_reenters_heavier_never_rejects() {
let mut path = ContinuationPath::heavy_entry();
path.s = 0.3;
path.s_step = 0.1;
let before = path.s;
let regime = path.demote_with_reason(PathDemotionReason::UniformStructural);
assert!(
path.s > before,
"demotion must raise s toward the entry regime"
);
assert_eq!(regime, path.enter_regime());
let regime2 = path.demote_with_reason(PathDemotionReason::PrewarmStructural);
assert_eq!(regime2, path.enter_regime());
assert!(path.s >= before, "repeated demotions never lower s");
}
#[test]
fn bare_active_mass_breach_records_and_reenters() {
let mut path = ContinuationPath::heavy_entry();
path.s = 0.4;
assert_eq!(path.reseed_count(), 0);
let before = path.s;
let regime = path.note_active_mass_breach();
assert_eq!(
path.reseed_count(),
1,
"breach must be recorded in the path ledger"
);
assert!(path.s > before, "breach must re-enter a heavier regime");
assert_eq!(regime, path.enter_regime());
}
#[test]
fn path_regime_bands_are_monotone_in_s() {
assert_eq!(PathRegime::from_s(0.0), PathRegime::Target);
assert_eq!(PathRegime::from_s(0.2), PathRegime::Target);
assert_eq!(PathRegime::from_s(0.5), PathRegime::Annealing);
assert_eq!(PathRegime::from_s(0.9), PathRegime::Heavy);
assert_eq!(PathRegime::from_s(1.0), PathRegime::Heavy);
}
#[test]
fn reentry_floors_step_but_never_exits() {
let mut path = ContinuationPath::enter(schedules());
path.s = 0.5;
for _ in 0..50 {
path.reenter_heavier();
assert!(path.s_step >= S_STEP_FLOOR);
assert!((0.0..=1.0).contains(&path.s));
}
}
}