use ndarray::Array1;
use crate::families::custom_family::KktRefusalDiagnosis;
use crate::families::inner_status::{InnerFailure, classify_inner_error};
use crate::solver::estimate::EstimationError;
use crate::solver::outer_strategy::{OuterEval, OuterEvalOrder, OuterObjective};
/// Upper bound on `steps_taken` within one `run_path` attempt. Both accepted steps and
/// shrink-and-retry failures count against it; the initial evaluation counts as step 1.
pub(crate) const PATH_BUDGET: usize = 64;
/// Step fraction below which `run_path` gives up with `PathOutcome::Stuck`.
pub(crate) const ALPHA_FLOOR: f64 = 1.0 / 1024.0;
/// Step fraction used for the first move from ρ₀ toward the target.
pub(crate) const ALPHA_INIT: f64 = 0.5;
/// Multiplier applied to the step fraction after each accepted step (capped at 1.0).
pub(crate) const ALPHA_EXPAND: f64 = 1.5;
/// Multiplier applied to the step fraction after a shrink-class inner failure.
pub(crate) const ALPHA_SHRINK: f64 = 0.5;
/// Initial per-coordinate offset added to the target to build ρ₀ (then clamped to the
/// upper bound). Numerically this is 5·ln 2. It is doubled on every continuation retry.
/// NOTE(review): if ρ is a log smoothing parameter, this starts 32× oversmoothed — confirm.
pub(crate) const OVERSMOOTH_OFFSET_INIT: f64 = 3.4657359027997265;
/// Number of retries (each with a doubled offset) after the first attempt, so at most
/// `OVERSMOOTH_RETRY_MAX + 1` paths are run.
pub(crate) const OVERSMOOTH_RETRY_MAX: usize = 3;
/// Per-coordinate tolerance for treating ρ₀ as equal to the target (collapse check).
/// `reached_target` uses one eighth of this value.
pub(crate) const RHO_EQUAL_TOL: f64 = 0.5;
/// Terminal failure of the oversmoothed-start continuation path.
#[derive(Debug, Clone)]
pub(crate) enum ContinuationFailure {
/// A `run_path` attempt used all `PATH_BUDGET` steps without reaching the target.
PathBudgetExhausted {
/// Most recent inner failure since the last accepted step, or a placeholder
/// `InnerFailure::Other` when none was recorded.
last: InnerFailure,
/// Steps counted by `run_path`, including the initial evaluation.
steps_taken: usize,
/// Last accepted ρ on the path.
final_rho: Array1<f64>,
},
/// The final retry ended in an ρ₀-expansion request or an α-floor stall.
PathStuck {
last: InnerFailure,
/// Offset used for the final attempt.
rho_zero_offset: f64,
/// ρ₀ rebuilt from the target at `rho_zero_offset` (not the last point reached).
final_rho: Array1<f64>,
},
/// Non-retryable failure, or a target/upper-bound length mismatch.
StructuralPropagate(InnerFailure),
/// The very first evaluation failed with a likelihood failure on the final retry.
DomainAtOversmoothedStart(InnerFailure),
}
impl ContinuationFailure {
    /// Render a one-line, human-readable description of this failure.
    ///
    /// The wrapped inner failure's message always comes first. The path-level variants
    /// append a parenthesised note describing where the continuation gave up.
    pub(crate) fn message(&self) -> String {
        match self {
            Self::StructuralPropagate(inner) | Self::DomainAtOversmoothedStart(inner) => {
                inner.message().to_string()
            }
            Self::PathStuck {
                last: inner,
                rho_zero_offset: offset,
                final_rho: rho,
            } => format!(
                "{} (continuation stuck at oversmooth offset {:.6e}, final rho dim={})",
                inner.message(),
                offset,
                rho.len()
            ),
            Self::PathBudgetExhausted {
                last: inner,
                steps_taken: steps,
                final_rho: rho,
            } => format!(
                "{} (continuation budget exhausted after {} step(s), final rho dim={})",
                inner.message(),
                steps,
                rho.len()
            ),
        }
    }
}
/// Snapshot of the most recently accepted point on the continuation path.
#[derive(Debug, Clone)]
pub(crate) struct ContinuationState {
/// Last accepted ρ.
pub last_rho: Array1<f64>,
/// Outer evaluation at `last_rho`.
pub last_eval: OuterEval,
/// Beta used to seed the inner solve for the accepted evaluation at `last_rho`.
pub last_beta: Array1<f64>,
/// Number of accepted evaluations, including the initial one at ρ₀.
pub steps_accepted: usize,
}
/// How `run_path` reacts to an inner failure after the first evaluation.
#[derive(Debug, Clone, Copy)]
enum FailureAction {
/// Shrink the step fraction and retry from the last accepted point.
ShrinkStep,
/// Like `ShrinkStep`, but a second trust-region failure without an intervening success
/// abandons the path and asks for a larger ρ₀ offset.
ShrinkOrExpand,
/// Abort the continuation without retrying.
Propagate,
/// Abandon this path and retry from a more oversmoothed ρ₀.
ExpandRhoZero,
}
/// Decide how the continuation should react to an inner failure seen mid-path.
///
/// Rank-deficient penalised Hessians ask for a more oversmoothed start. Transient
/// solver trouble (phantom multipliers, iteration budget, likelihood blow-ups) shrinks
/// the step. Trust-region floors shrink first and escalate on repetition. Structural
/// problems are propagated.
fn classify_action(failure: &InnerFailure) -> FailureAction {
    match failure {
        InnerFailure::CertRefused {
            diagnosis: KktRefusalDiagnosis::RankDeficientHPen,
            ..
        } => FailureAction::ExpandRhoZero,
        InnerFailure::CertRefused {
            diagnosis: KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH,
            ..
        }
        | InnerFailure::BudgetExhausted { .. }
        | InnerFailure::LikelihoodFailure(_) => FailureAction::ShrinkStep,
        InnerFailure::TrustRegionFloor { .. } => FailureAction::ShrinkOrExpand,
        InnerFailure::CertRefused {
            diagnosis:
                KktRefusalDiagnosis::ActiveSetIncomplete | KktRefusalDiagnosis::AliasingDetectedAtFit,
            ..
        }
        | InnerFailure::IdentifiabilityFailure { .. }
        | InnerFailure::Other(_) => FailureAction::Propagate,
    }
}
/// Build the oversmoothed starting point ρ₀ = min(target + offset, upper), coordinate-wise.
///
/// # Panics
/// Panics if `target` and `upper` have different lengths.
fn build_rho_zero(target: &Array1<f64>, upper: &Array1<f64>, offset: f64) -> Array1<f64> {
    assert_eq!(target.len(), upper.len());
    target
        .iter()
        .zip(upper.iter())
        .map(|(&goal, &cap)| (goal + offset).min(cap))
        .collect()
}
/// Report whether ρ₀ lies within `RHO_EQUAL_TOL` of the target in every coordinate,
/// meaning the oversmoothed start has effectively collapsed onto the target.
///
/// A NaN difference counts as "not equal".
///
/// # Panics
/// Panics if the two vectors have different lengths.
fn rho_zero_is_target(rho0: &Array1<f64>, target: &Array1<f64>) -> bool {
    assert_eq!(rho0.len(), target.len());
    let within_tol = |(start, goal): (&f64, &f64)| (start - goal).abs() <= RHO_EQUAL_TOL;
    rho0.iter().zip(target.iter()).all(within_tol)
}
fn step_toward(rho_k: &Array1<f64>, target: &Array1<f64>, alpha: f64) -> Array1<f64> {
assert_eq!(rho_k.len(), target.len());
let mut out = Array1::<f64>::zeros(rho_k.len());
for i in 0..rho_k.len() {
out[i] = rho_k[i] + alpha * (target[i] - rho_k[i]);
}
out
}
/// Report whether `rho` is within `RHO_EQUAL_TOL / 8` of `target` in every coordinate.
///
/// # Panics
/// Panics if the two vectors have different lengths. This matches the sibling helpers
/// and avoids `zip` silently truncating and reporting a spurious "reached".
fn reached_target(rho: &Array1<f64>, target: &Array1<f64>) -> bool {
    assert_eq!(rho.len(), target.len());
    // Tighter than the collapse tolerance: arriving needs more precision than
    // deciding that no continuation is necessary.
    let tol = RHO_EQUAL_TOL / 8.0;
    rho.iter()
        .zip(target.iter())
        .all(|(a, b)| (a - b).abs() <= tol)
}
/// Translate an outer-objective error into the continuation's failure taxonomy.
///
/// Only `RemlOptimizationFailed` carries a message that `classify_inner_error` knows
/// how to parse. Every other error is stringified into `InnerFailure::Other`.
fn inner_failure_from(err: EstimationError) -> InnerFailure {
    match err {
        EstimationError::RemlOptimizationFailed(reason) => classify_inner_error(reason),
        unclassified => {
            let rendered = unclassified.to_string();
            InnerFailure::Other(rendered)
        }
    }
}
/// Evaluate the outer objective at `rho`, optionally warm-starting the inner solve.
///
/// When `beta_seed` is non-empty, it is pushed into the inner solver via
/// `seed_inner_state` before the evaluation. An empty `beta_seed` means no warm start
/// is available (for example, the cold first point started by `prime_outer_seed`).
/// In that case seeding is skipped rather than handing the objective a zero-length
/// coefficient vector.
///
/// # Errors
/// Any `EstimationError` from seeding or evaluating is converted with
/// `inner_failure_from`.
fn eval_step(
    obj: &mut dyn OuterObjective,
    rho: &Array1<f64>,
    beta_seed: &Array1<f64>,
    order: OuterEvalOrder,
) -> Result<OuterEval, InnerFailure> {
    if !beta_seed.is_empty() {
        obj.seed_inner_state(beta_seed).map_err(inner_failure_from)?;
    }
    obj.eval_with_order(rho, order).map_err(inner_failure_from)
}
/// Outcome of a full continuation run: the final accepted state, or why it failed.
pub(crate) type ContinuationResult = Result<ContinuationState, ContinuationFailure>;
/// Summary returned by `prime_outer_seed`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PrimingSummary {
/// True when ρ₀ collapsed onto the seed, so no inner evaluation was run.
pub collapsed: bool,
/// Accepted evaluations on the continuation path (0 when `collapsed`).
pub steps_accepted: usize,
}
/// Walk the inner solver from an oversmoothed start down to `seed` before the outer
/// optimizer begins.
///
/// If the oversmoothed start ρ₀ collapses onto `seed` (for example, because `seed`
/// already sits at `bounds_upper`), no evaluation is performed and
/// `collapsed = true` is reported.
///
/// # Errors
/// Returns the `ContinuationFailure` produced by the continuation. A length mismatch
/// between `seed` and `bounds_upper` is reported as `StructuralPropagate` rather than
/// panicking.
pub(crate) fn prime_outer_seed(
    obj: &mut dyn OuterObjective,
    seed: &Array1<f64>,
    bounds_upper: &Array1<f64>,
) -> Result<PrimingSummary, ContinuationFailure> {
    // `build_rho_zero` asserts equal lengths. Only probe for the collapsed case when the
    // shapes agree, so a mismatch reaches `fit_with_continuation`'s structural error.
    if seed.len() == bounds_upper.len() {
        let rho_zero = build_rho_zero(seed, bounds_upper, OVERSMOOTH_OFFSET_INIT);
        if rho_zero_is_target(&rho_zero, seed) {
            return Ok(PrimingSummary {
                collapsed: true,
                steps_accepted: 0,
            });
        }
    }
    // The empty beta carries no warm-start information for the first inner solve.
    let empty_beta: Array1<f64> = Array1::zeros(0);
    fit_with_continuation(
        obj,
        seed,
        bounds_upper,
        &empty_beta,
        OuterEvalOrder::ValueAndGradient,
    )
    .map(|state| PrimingSummary {
        collapsed: false,
        steps_accepted: state.steps_accepted,
    })
}
/// Run the continuation path, restarting from a more oversmoothed ρ₀ when needed.
///
/// Each attempt calls `run_path` with the current offset. Retryable outcomes
/// (`ExpandRhoZero`, `Stuck`, `DomainAtStart`) double the offset, for up to
/// `OVERSMOOTH_RETRY_MAX` retries. On the last attempt they become terminal
/// failures. `Propagate` and `PathBudgetExhausted` are terminal immediately.
///
/// # Errors
/// * `StructuralPropagate` if `target` and `bounds_upper` differ in length, or for
///   non-retryable inner failures.
/// * `PathBudgetExhausted`, `PathStuck`, or `DomainAtOversmoothedStart` as described
///   on `ContinuationFailure`.
fn fit_with_continuation(
    obj: &mut dyn OuterObjective,
    target: &Array1<f64>,
    bounds_upper: &Array1<f64>,
    initial_beta: &Array1<f64>,
    order: OuterEvalOrder,
) -> ContinuationResult {
    if bounds_upper.len() != target.len() {
        let mismatch = format!(
            "continuation: target len {} != bounds_upper len {}",
            target.len(),
            bounds_upper.len()
        );
        return Err(ContinuationFailure::StructuralPropagate(
            InnerFailure::Other(mismatch),
        ));
    }
    let mut offset = OVERSMOOTH_OFFSET_INIT;
    let mut attempt: usize = 0;
    loop {
        let last_attempt = attempt == OVERSMOOTH_RETRY_MAX;
        let outcome = match run_path(obj, target, bounds_upper, initial_beta, order, offset) {
            Ok(state) => return Ok(state),
            Err(outcome) => outcome,
        };
        match outcome {
            PathOutcome::Propagate(last) => {
                return Err(ContinuationFailure::StructuralPropagate(last));
            }
            PathOutcome::PathBudgetExhausted {
                last,
                steps_taken,
                final_rho,
            } => {
                return Err(ContinuationFailure::PathBudgetExhausted {
                    last,
                    steps_taken,
                    final_rho,
                });
            }
            PathOutcome::DomainAtStart(last) if last_attempt => {
                return Err(ContinuationFailure::DomainAtOversmoothedStart(last));
            }
            PathOutcome::ExpandRhoZero(last) | PathOutcome::Stuck(last) if last_attempt => {
                return Err(ContinuationFailure::PathStuck {
                    last,
                    rho_zero_offset: offset,
                    final_rho: build_rho_zero(target, bounds_upper, offset),
                });
            }
            PathOutcome::DomainAtStart(_)
            | PathOutcome::ExpandRhoZero(_)
            | PathOutcome::Stuck(_) => {
                // Retry from a more oversmoothed start.
                offset *= 2.0;
                attempt += 1;
            }
        }
    }
}
/// Why a single `run_path` attempt ended without reaching the target.
enum PathOutcome {
/// Retry with a larger ρ₀ offset.
ExpandRhoZero(InnerFailure),
/// The step fraction fell below `ALPHA_FLOOR`. It is retried like `ExpandRhoZero`.
Stuck(InnerFailure),
/// The first evaluation failed with a likelihood failure.
DomainAtStart(InnerFailure),
/// Non-retryable failure; abort the continuation.
Propagate(InnerFailure),
/// `PATH_BUDGET` steps were used up; see `ContinuationFailure::PathBudgetExhausted`.
PathBudgetExhausted {
last: InnerFailure,
steps_taken: usize,
final_rho: Array1<f64>,
},
}
/// Run one continuation attempt from ρ₀ = min(target + offset, upper) toward `target`.
///
/// The first evaluation happens at ρ₀, or directly at `target` when ρ₀ collapses onto
/// it. Each later step moves a fraction α of the remaining distance:
/// * α grows by `ALPHA_EXPAND` (capped at 1.0) on success.
/// * α shrinks by `ALPHA_SHRINK` on shrink-class failures.
///
/// The warm start for each step is the previous evaluation's beta hint, falling back
/// to the beta that seeded the previous accepted evaluation.
///
/// Returns the final accepted state once `reached_target` holds, including when the
/// target is reached on the very last step the budget allows.
///
/// # Errors
/// Returns a `PathOutcome` telling `fit_with_continuation` whether to retry with a
/// larger offset or give up.
fn run_path(
    obj: &mut dyn OuterObjective,
    target: &Array1<f64>,
    bounds_upper: &Array1<f64>,
    initial_beta: &Array1<f64>,
    order: OuterEvalOrder,
    offset: f64,
) -> Result<ContinuationState, PathOutcome> {
    let rho0 = build_rho_zero(target, bounds_upper, offset);
    let collapsed = rho_zero_is_target(&rho0, target);
    let rho_first = if collapsed { target.clone() } else { rho0 };
    let initial_seed = initial_beta.clone();
    let eval0 =
        eval_step(obj, &rho_first, &initial_seed, order).map_err(start_failure_outcome)?;
    let mut state = ContinuationState {
        last_rho: rho_first,
        last_eval: eval0,
        last_beta: initial_seed,
        steps_accepted: 1,
    };
    if collapsed || reached_target(&state.last_rho, target) {
        return Ok(state);
    }
    let mut alpha = ALPHA_INIT;
    let mut steps_taken: usize = 1;
    // Most recent shrink-class failure since the last accepted step.
    let mut last_failure: Option<InnerFailure> = None;
    let mut consecutive_trust_floor: usize = 0;
    while steps_taken < PATH_BUDGET {
        if reached_target(&state.last_rho, target) {
            return Ok(state);
        }
        let rho_next = step_toward(&state.last_rho, target, alpha);
        let beta_seed = state
            .last_eval
            .inner_beta_hint
            .clone()
            .unwrap_or_else(|| state.last_beta.clone());
        match eval_step(obj, &rho_next, &beta_seed, order) {
            Ok(eval) => {
                state.last_rho = rho_next;
                state.last_eval = eval;
                state.last_beta = beta_seed;
                state.steps_accepted += 1;
                steps_taken += 1;
                last_failure = None;
                consecutive_trust_floor = 0;
                alpha = (alpha * ALPHA_EXPAND).min(1.0);
            }
            Err(failure) => {
                match classify_action(&failure) {
                    FailureAction::Propagate => return Err(PathOutcome::Propagate(failure)),
                    FailureAction::ExpandRhoZero => {
                        return Err(PathOutcome::ExpandRhoZero(failure));
                    }
                    FailureAction::ShrinkStep => {}
                    FailureAction::ShrinkOrExpand => {
                        // A repeated trust-region floor without progress means the path
                        // itself is too aggressive; restart from a smoother ρ₀.
                        consecutive_trust_floor += 1;
                        if consecutive_trust_floor >= 2 {
                            return Err(PathOutcome::ExpandRhoZero(failure));
                        }
                    }
                }
                alpha *= ALPHA_SHRINK;
                if alpha < ALPHA_FLOOR {
                    return Err(PathOutcome::Stuck(failure));
                }
                steps_taken += 1;
                last_failure = Some(failure);
            }
        }
    }
    // The loop exits as soon as the budget is spent, before re-checking the target.
    // An accepted final step that lands on the target is a success, not an exhaustion.
    if reached_target(&state.last_rho, target) {
        return Ok(state);
    }
    Err(PathOutcome::PathBudgetExhausted {
        last: last_failure.unwrap_or_else(|| {
            InnerFailure::Other("continuation: budget hit without recorded failure".into())
        }),
        steps_taken,
        final_rho: state.last_rho,
    })
}

/// Map a failure of the very first evaluation (at ρ₀ or the collapsed target) to a
/// path outcome.
///
/// * Likelihood failures mark the start as outside the domain.
/// * Structural refusals and identifiability failures propagate.
/// * Anything else asks for a more oversmoothed start.
fn start_failure_outcome(failure: InnerFailure) -> PathOutcome {
    match failure {
        InnerFailure::LikelihoodFailure(_) => PathOutcome::DomainAtStart(failure),
        InnerFailure::CertRefused {
            diagnosis:
                KktRefusalDiagnosis::ActiveSetIncomplete | KktRefusalDiagnosis::AliasingDetectedAtFit,
            ..
        }
        | InnerFailure::IdentifiabilityFailure { .. } => PathOutcome::Propagate(failure),
        _ => PathOutcome::ExpandRhoZero(failure),
    }
}
// Unit tests for the pure path helpers, plus scripted end-to-end runs of
// `prime_outer_seed` against a fake objective.
#[cfg(test)]
mod tests {
use super::*;
// An offset-free target sitting on the upper bound must collapse ρ₀ onto the target.
#[test]
fn rho_zero_collapses_when_target_at_upper_bound() {
let target = Array1::from_vec(vec![5.0, 5.0]);
let upper = Array1::from_vec(vec![5.0, 5.0]);
let rho0 = build_rho_zero(&target, &upper, OVERSMOOTH_OFFSET_INIT);
assert_eq!(rho0, target);
assert!(rho_zero_is_target(&rho0, &target));
}
// With headroom below the upper bound, ρ₀ is exactly target + offset.
#[test]
fn rho_zero_offsets_above_target_when_room() {
let target = Array1::from_vec(vec![0.0, -2.0]);
let upper = Array1::from_vec(vec![10.0, 10.0]);
let rho0 = build_rho_zero(&target, &upper, OVERSMOOTH_OFFSET_INIT);
assert!((rho0[0] - OVERSMOOTH_OFFSET_INIT).abs() < 1e-12);
assert!((rho0[1] - (-2.0 + OVERSMOOTH_OFFSET_INIT)).abs() < 1e-12);
assert!(!rho_zero_is_target(&rho0, &target));
}
// alpha = 0.5 gives the midpoint; alpha = 1.0 lands on the target.
#[test]
fn step_toward_is_convex_combination() {
let a = Array1::from_vec(vec![0.0, 0.0]);
let b = Array1::from_vec(vec![4.0, -8.0]);
let mid = step_toward(&a, &b, 0.5);
assert!((mid[0] - 2.0).abs() < 1e-12);
assert!((mid[1] - (-4.0)).abs() < 1e-12);
let full = step_toward(&a, &b, 1.0);
assert!((full[0] - 4.0).abs() < 1e-12);
assert!((full[1] - (-8.0)).abs() < 1e-12);
}
// Pins the failure -> action routing table used mid-path.
#[test]
fn classify_action_routes_diagnoses_correctly() {
let rank_def = InnerFailure::CertRefused {
diagnosis: KktRefusalDiagnosis::RankDeficientHPen,
carrying_block: None,
message: "".into(),
};
assert!(matches!(
classify_action(&rank_def),
FailureAction::ExpandRhoZero
));
let active_incomp = InnerFailure::CertRefused {
diagnosis: KktRefusalDiagnosis::ActiveSetIncomplete,
carrying_block: None,
message: "".into(),
};
assert!(matches!(
classify_action(&active_incomp),
FailureAction::Propagate
));
let phantom = InnerFailure::CertRefused {
diagnosis: KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH,
carrying_block: None,
message: "".into(),
};
assert!(matches!(
classify_action(&phantom),
FailureAction::ShrinkStep
));
assert!(matches!(
classify_action(&InnerFailure::BudgetExhausted { message: "".into() }),
FailureAction::ShrinkStep
));
assert!(matches!(
classify_action(&InnerFailure::TrustRegionFloor { message: "".into() }),
FailureAction::ShrinkOrExpand
));
assert!(matches!(
classify_action(&InnerFailure::LikelihoodFailure("".into())),
FailureAction::ShrinkStep
));
assert!(matches!(
classify_action(&InnerFailure::Other("".into())),
FailureAction::Propagate
));
}
use crate::solver::outer_strategy::{
DeclaredHessianForm, Derivative, HessianResult, OuterCapability,
};
// One scripted reply per objective evaluation.
#[derive(Clone)]
enum ScriptedResponse {
Ok,
Fail(&'static str),
}
// Fake objective. Each `eval_cost` call consumes the next scripted response
// (defaulting to `Ok` once the queue runs out) and records the ρ it was called with.
// It returns cost ρ·ρ. `eval` reports a zero gradient and no inner beta hint.
struct ScriptedObjective {
n_params: usize,
queue: Vec<ScriptedResponse>,
idx: usize,
rho_history: Vec<Array1<f64>>,
seed_calls: usize,
last_seeded_beta_len: Option<usize>,
}
impl ScriptedObjective {
fn new(n_params: usize, queue: Vec<ScriptedResponse>) -> Self {
Self {
n_params,
queue,
idx: 0,
rho_history: Vec::new(),
seed_calls: 0,
last_seeded_beta_len: None,
}
}
// Pop the next scripted response; past the end of the queue, every call succeeds.
fn next_response(&mut self) -> ScriptedResponse {
let r = self
.queue
.get(self.idx)
.cloned()
.unwrap_or(ScriptedResponse::Ok);
self.idx += 1;
r
}
}
impl OuterObjective for ScriptedObjective {
fn capability(&self) -> OuterCapability {
OuterCapability {
gradient: Derivative::Analytic,
hessian: DeclaredHessianForm::Unavailable,
n_params: self.n_params,
psi_dim: 0,
fixed_point_available: false,
barrier_config: None,
prefer_gradient_only: false,
disable_fixed_point: false,
}
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
self.rho_history.push(rho.clone());
match self.next_response() {
ScriptedResponse::Ok => Ok(rho.dot(rho)),
ScriptedResponse::Fail(msg) => {
Err(EstimationError::RemlOptimizationFailed(msg.to_string()))
}
}
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
let cost = self.eval_cost(rho)?;
Ok(OuterEval {
cost,
gradient: Array1::zeros(self.n_params),
hessian: HessianResult::Unavailable,
inner_beta_hint: None,
})
}
fn reset(&mut self) {
self.idx = 0;
self.rho_history.clear();
self.seed_calls = 0;
self.last_seeded_beta_len = None;
}
// NOTE(review): this asserts the seed has exactly `n_params` entries.
// `prime_outer_seed` starts from a length-0 beta, so with n_params >= 1 these tests
// require that an empty beta never reaches this method — confirm against `eval_step`.
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<(), EstimationError> {
assert_eq!(beta.len(), self.n_params);
self.seed_calls += 1;
self.last_seeded_beta_len = Some(beta.len());
Ok(())
}
}
fn rho(values: &[f64]) -> Array1<f64> {
Array1::from_vec(values.to_vec())
}
// Target on the upper bound: ρ₀ collapses, so the objective is never evaluated or seeded.
#[test]
fn degenerates_to_cold_start_on_easy_fits() {
let target = rho(&[5.0, 5.0]);
let upper = rho(&[5.0, 5.0]);
let mut obj = ScriptedObjective::new(2, Vec::new());
let summary = prime_outer_seed(&mut obj, &target, &upper).expect("collapse path");
assert!(summary.collapsed, "must report collapsed=true on easy fits");
assert_eq!(summary.steps_accepted, 0);
assert_eq!(obj.rho_history.len(), 0, "no inner calls on collapse");
assert_eq!(obj.seed_calls, 0);
}
// Two inner-iteration-budget failures after the first point are expected (via
// `classify_inner_error`, not shown) to classify as shrink-class failures, so the
// path still completes. Also checks the first evaluation is at ρ₀ = ρ* + offset.
#[test]
fn budget_exhausted_warmstart_completes_path() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut obj = ScriptedObjective::new(
1,
vec![
ScriptedResponse::Ok, ScriptedResponse::Fail("inner_max_cycles reached"), ScriptedResponse::Fail("inner_max_cycles reached"), ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
],
);
prime_outer_seed(&mut obj, &target, &upper).expect("path completes via shrink-on-budget");
assert!(obj.rho_history.len() >= 3, "must have walked a path");
let rho0 = &obj.rho_history[0];
assert!(
(rho0[0] - (target[0] + OVERSMOOTH_OFFSET_INIT)).abs() < 1e-9,
"first call is at ρ₀ = ρ*+offset",
);
}
// A single trust-region-floor failure only shrinks α; expanding ρ₀ needs two in a row.
#[test]
fn trust_region_floor_alpha_shrink_then_recovers() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut obj = ScriptedObjective::new(
1,
vec![
ScriptedResponse::Ok, ScriptedResponse::Fail("trust-region floor reached"), ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
],
);
prime_outer_seed(&mut obj, &target, &upper)
.expect("path completes after single TR-floor shrink");
assert!(obj.rho_history.len() >= 3);
}
// A mid-path likelihood failure shrinks α rather than aborting.
#[test]
fn likelihood_failure_alpha_shrink_then_recovers() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut obj = ScriptedObjective::new(
1,
vec![
ScriptedResponse::Ok, ScriptedResponse::Fail("likelihood evaluation failed: NaN"),
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
ScriptedResponse::Ok,
],
);
let outcome = prime_outer_seed(&mut obj, &target, &upper);
assert!(
outcome.is_ok(),
"path completes after likelihood shrink, got {:?}",
outcome.err(),
);
assert!(obj.rho_history.len() >= 3);
}
// A mid-path active_set_incomplete refusal surfaces as StructuralPropagate with the
// parsed diagnosis intact.
#[test]
fn active_set_incomplete_propagates_structurally() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut obj = ScriptedObjective::new(
1,
vec![
ScriptedResponse::Ok, ScriptedResponse::Fail(
"cycle=3 cert REFUSED: residual=1.0e+02 > tol=1.0e+00; \
carrying-block: time_surface (idx=0); \
diagnosis: active_set_incomplete",
),
],
);
let err = prime_outer_seed(&mut obj, &target, &upper)
.expect_err("structural failure must propagate");
assert!(
matches!(err, ContinuationFailure::StructuralPropagate(_)),
"got {err:?}",
);
match err {
ContinuationFailure::StructuralPropagate(InnerFailure::CertRefused {
diagnosis,
..
}) => assert_eq!(diagnosis, KktRefusalDiagnosis::ActiveSetIncomplete),
other => panic!("expected CertRefused, got {other:?}"),
}
}
// NOTE(review): despite its name, this test expects `PathStuck`, not
// `PathBudgetExhausted`. The repeated failures shrink α below ALPHA_FLOOR, and the
// scripted failures that follow also fail each retry's first evaluation. The test
// expects `classify_inner_error` (not shown) to map the message to a
// PhantomMultiplierWithWellConditionedH refusal.
#[test]
fn path_budget_exhausted_surfaces_last_inner_failure() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut responses: Vec<ScriptedResponse> = Vec::new();
for _ in 0..32 {
responses.push(ScriptedResponse::Ok);
for _ in 0..20 {
responses.push(ScriptedResponse::Fail(
"coupled exact-joint inner solve exited the joint Newton path \
before convergence — block 'time_surface' carries the dominant \
unresolved KKT gradient (|g_block|∞ = 5.000e+05); \
|∇L − Sβ|∞ = 5.000e+05",
));
}
}
let mut obj = ScriptedObjective::new(1, responses);
let err = prime_outer_seed(&mut obj, &target, &upper).expect_err("schedule must fail");
match err {
ContinuationFailure::PathStuck { last, .. } => match last {
InnerFailure::CertRefused { diagnosis, .. } => assert_eq!(
diagnosis,
KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH
),
other => panic!("expected CertRefused, got {other:?}"),
},
other => panic!("expected PathStuck, got {other:?}"),
}
}
// The very first evaluation at ρ₀ is refused; the rendered message must keep the
// underlying diagnosis text.
#[test]
fn pre_warm_failure_carries_underlying_message_for_seed_rejection() {
let target = rho(&[0.0]);
let upper = rho(&[10.0]);
let mut obj = ScriptedObjective::new(
1,
vec![ScriptedResponse::Fail(
"cycle=3 cert REFUSED: residual=1.0e+02 > tol=1.0e+00; \
diagnosis: active_set_incomplete",
)],
);
let err = prime_outer_seed(&mut obj, &target, &upper).expect_err("propagation expected");
let msg = err.message();
assert!(msg.contains("active_set_incomplete"), "msg='{msg}'");
}
}