use crate::r#trait::{RestorationOutcome, RestorationPhase};
use pounce_algorithm::eq_mult::least_square::LeastSquareMults;
use pounce_algorithm::eq_mult::r#trait::EqMultCalculator;
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::ipopt_nlp::IpoptNlp;
use pounce_algorithm::iterates_vector::IteratesVector;
use pounce_algorithm::kkt::aug_system_solver::AugSystemSolver;
use pounce_common::types::{Index, Number};
use pounce_linalg::dense_vector::DenseVectorSpace;
use pounce_linalg::Vector;
use std::cell::RefCell;
use std::rc::Rc;
/// Outcome of one restoration sub-solve, as returned by a [`RestoInnerSolver`]
/// and consumed by `MinC1NormRestoration::perform_restoration`.
pub struct RestoSolveResult {
/// Primal `x` installed as the `x` component of the outer trial iterate.
pub trial_x: Rc<dyn Vector>,
/// Slack `s` installed as the `s` component of the outer trial iterate.
pub trial_s: Rc<dyn Vector>,
/// Iteration counter reported by the inner solve. It is exposed through
/// `last_inner_iter_count()` and, reduced by one (never below zero), written
/// back as the outer `iter_count`.
pub iter_count: Index,
/// Copied verbatim into the outer data's `info_iters_since_header`.
pub iters_since_header: Index,
/// Copied verbatim into the outer data's `info_last_output`.
pub last_output: Number,
/// When `true`, `perform_restoration` returns
/// `RestorationOutcome::LocallyInfeasible` without installing a trial iterate.
pub locally_infeasible: bool,
}
/// Pluggable hook that runs the actual restoration sub-solve.
///
/// `perform_restoration` calls it with, in order: the outer data handle, the
/// outer calculated-quantities handle, the NLP, the original-progress callback
/// (if one was registered), the per-iteration output flag, and the optional
/// debug hook. Returning `None` makes `perform_restoration` report
/// `RestorationOutcome::Failed`.
pub type RestoInnerSolver = Box<
dyn FnMut(
&IpoptDataHandle,
&IpoptCqHandle,
&Rc<RefCell<dyn IpoptNlp>>,
Option<pounce_algorithm::restoration::OrigProgressCallback>,
// Suppress the nested IPM's `r`-suffixed per-iteration table when
// false. Outer driver forwards `print_level == 0` via
// `RestorationPhase::set_print_iter_output`.
bool,
// Shared interactive debugger, forwarded onto the inner IPM so the
// same debugger can step the restoration sub-solve.
Option<Rc<RefCell<dyn pounce_algorithm::debug::DebugHook>>>,
) -> Option<RestoSolveResult>,
>;
/// Restoration-phase driver: delegates the sub-solve to `inner_solver` and then
/// rebuilds the outer trial iterate (primal point, bound multipliers, equality
/// multipliers) from its result.
pub struct MinC1NormRestoration {
/// If the largest absolute bound multiplier after the post-restoration update
/// is strictly greater than this value, all four bound multiplier vectors are
/// reset to 1.
pub bound_mult_reset_threshold: Number,
/// When positive (and the problem is not square and has equality
/// multipliers), `y_c`/`y_d` are recomputed through `eq_mult` and zeroed if
/// their max-norm exceeds this value. When not positive they are left at zero.
pub constr_mult_reset_threshold: Number,
/// Not read anywhere in this file; presumably consumed by the code that
/// configures or drives this object — TODO confirm.
pub expect_infeasible_problem: bool,
/// Not read anywhere in this file; presumably consumed by the code that
/// configures or drives this object — TODO confirm.
pub start_with_resto: bool,
/// Calculator used to recompute the equality multipliers after restoration.
pub eq_mult: Box<dyn EqMultCalculator>,
/// Hook that performs the restoration sub-solve; the default always returns `None`.
pub inner_solver: RestoInnerSolver,
/// Callback handed to the inner solver; moved out (left as `None`) by each
/// `perform_restoration` call.
pub(crate) orig_progress: Option<pounce_algorithm::restoration::OrigProgressCallback>,
/// `iter_count` of the most recent successful inner solve; reset to 0 at the
/// start of every `perform_restoration` call.
pub(crate) last_inner_iter_count: Index,
/// Forwarded to the inner solver as its per-iteration output flag.
pub(crate) print_iter_output: bool,
/// Optional debugger hook, cloned and forwarded to the inner solver.
pub(crate) debug_hook: Option<Rc<RefCell<dyn pounce_algorithm::debug::DebugHook>>>,
}
impl Default for MinC1NormRestoration {
fn default() -> Self {
Self {
bound_mult_reset_threshold: 1e3,
constr_mult_reset_threshold: 0.0,
expect_infeasible_problem: false,
start_with_resto: false,
eq_mult: Box::new(LeastSquareMults::new()),
inner_solver: Box::new(|_, _, _, _, _, _| None),
orig_progress: None,
last_inner_iter_count: 0,
print_iter_output: true,
debug_hook: None,
}
}
}
impl MinC1NormRestoration {
    /// Creates a driver with the default settings (see the `Default` impl).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder-style setter: replaces the inner solver hook and returns the
    /// updated driver.
    pub fn with_inner_solver(mut self, hook: RestoInnerSolver) -> Self {
        self.inner_solver = hook;
        self
    }

    /// Returns `true` when `bound_mult_max` is strictly greater than
    /// `bound_mult_reset_threshold`.
    pub fn should_reset_bound_mults(&self, bound_mult_max: Number) -> bool {
        let limit = self.bound_mult_reset_threshold;
        limit < bound_mult_max
    }
}
/// `RestorationPhase` implementation: runs the inner solver hook and, on
/// success, installs a new outer trial iterate built from its result.
impl RestorationPhase for MinC1NormRestoration {
/// Stores (or clears, with `None`) the callback passed to the inner solver on
/// the next `perform_restoration` call.
fn set_orig_progress_check(
&mut self,
cb: Option<pounce_algorithm::restoration::OrigProgressCallback>,
) {
self.orig_progress = cb;
}
/// Iteration count reported by the most recent inner solve; 0 if the latest
/// `perform_restoration` call did not get a result from the inner solver.
fn last_inner_iter_count(&self) -> Index {
self.last_inner_iter_count
}
/// Sets the flag forwarded to the inner solver as its per-iteration output switch.
fn set_print_iter_output(&mut self, enabled: bool) {
self.print_iter_output = enabled;
}
/// Stores (or clears, with `None`) the debug hook forwarded to the inner solver.
fn set_debug_hook(
&mut self,
hook: Option<Rc<RefCell<dyn pounce_algorithm::debug::DebugHook>>>,
) {
self.debug_hook = hook;
}
/// Runs the restoration sub-solve and prepares the outer trial iterate.
///
/// Returns `Failed` when the inner solver yields `None` or when `data` has no
/// current iterate, `LocallyInfeasible` when the inner solver flags it, and
/// `Recovered` once the trial iterate and the output bookkeeping fields on
/// `data` have been written.
///
/// # Panics
/// Panics via `expect("just set above")` if `data.trial` is `None` right after
/// `set_trial` was called.
fn perform_restoration(
&mut self,
data: &IpoptDataHandle,
cq: &IpoptCqHandle,
nlp: &Rc<RefCell<dyn IpoptNlp>>,
aug_solver: &mut dyn AugSystemSolver,
) -> RestorationOutcome {
// The entered span guard lives until this function returns.
let _resto_span = tracing::info_span!("restoration").entered();
// `take()` moves the callback out and leaves `None` behind, so it is handed
// to the inner solver at most once per `set_orig_progress_check` call.
let cb = self.orig_progress.take();
// Stays 0 unless the inner solver returns a result below.
self.last_inner_iter_count = 0;
// Snapshot `curr_mu` / `curr_tau` so they can be written back after the
// inner solve (presumably the inner solve changes them on the shared handle).
let saved_mu = data.borrow().curr_mu;
let saved_tau = data.borrow().curr_tau;
let Some(result) = (self.inner_solver)(
data,
cq,
nlp,
cb,
self.print_iter_output,
self.debug_hook.clone(),
) else {
// NOTE(review): this early return does not write `saved_mu` / `saved_tau`
// back into `data` — confirm callers do not rely on them after `Failed`.
return RestorationOutcome::Failed;
};
self.last_inner_iter_count = result.iter_count;
{
let mut d = data.borrow_mut();
d.curr_mu = saved_mu;
d.curr_tau = saved_tau;
}
if result.locally_infeasible {
return RestorationOutcome::LocallyInfeasible;
}
let Some(curr) = data.borrow().curr.clone() else {
return RestorationOutcome::Failed;
};
let mu = data.borrow().curr_mu;
let tau = data.borrow().curr_tau;
// Stage 1: trial = restored primal point (x, s) + the current multipliers.
// Installed before the `trial_slack_*` queries below (presumably so they are
// evaluated at the restored point).
let new_trial = IteratesVector::new(
result.trial_x.clone(),
result.trial_s.clone(),
curr.y_c.clone(),
curr.y_d.clone(),
curr.z_l.clone(),
curr.z_u.clone(),
curr.v_l.clone(),
curr.v_u.clone(),
);
data.borrow_mut().set_trial(new_trial);
let cq_ref = cq.borrow();
let curr_slack_x_l = cq_ref.curr_slack_x_l();
let curr_slack_x_u = cq_ref.curr_slack_x_u();
let curr_slack_s_l = cq_ref.curr_slack_s_l();
let curr_slack_s_u = cq_ref.curr_slack_s_u();
let trial_slack_x_l = cq_ref.trial_slack_x_l();
let trial_slack_x_u = cq_ref.trial_slack_x_u();
let trial_slack_s_l = cq_ref.trial_slack_s_l();
let trial_slack_s_u = cq_ref.trial_slack_s_u();
// Release the `cq` borrow before `cq` is handed to `calculate_y_eq` further down.
drop(cq_ref);
// One step vector per bound multiplier, each computed from the matching
// current multiplier and the current/trial slack pair.
let mut delta_z_l = make_zeroed_like(&*curr.z_l);
let mut delta_z_u = make_zeroed_like(&*curr.z_u);
let mut delta_v_l = make_zeroed_like(&*curr.v_l);
let mut delta_v_u = make_zeroed_like(&*curr.v_u);
compute_bound_multiplier_step(
&mut *delta_z_l,
&*curr.z_l,
&*curr_slack_x_l,
&*trial_slack_x_l,
mu,
);
compute_bound_multiplier_step(
&mut *delta_z_u,
&*curr.z_u,
&*curr_slack_x_u,
&*trial_slack_x_u,
mu,
);
compute_bound_multiplier_step(
&mut *delta_v_l,
&*curr.v_l,
&*curr_slack_s_l,
&*trial_slack_s_l,
mu,
);
compute_bound_multiplier_step(
&mut *delta_v_u,
&*curr.v_u,
&*curr_slack_s_u,
&*trial_slack_s_u,
mu,
);
// A single dual step length: the smallest `frac_to_bound` value over the
// four multiplier/step pairs.
let alpha_dual = curr
.z_l
.frac_to_bound(&*delta_z_l, tau)
.min(curr.z_u.frac_to_bound(&*delta_z_u, tau))
.min(curr.v_l.frac_to_bound(&*delta_v_l, tau))
.min(curr.v_u.frac_to_bound(&*delta_v_u, tau));
// new = curr + alpha_dual * delta, on owned copies so `curr` is untouched.
let mut new_z_l = clone_to_owned(&*curr.z_l);
let mut new_z_u = clone_to_owned(&*curr.z_u);
let mut new_v_l = clone_to_owned(&*curr.v_l);
let mut new_v_u = clone_to_owned(&*curr.v_u);
new_z_l.axpy(alpha_dual, &*delta_z_l);
new_z_u.axpy(alpha_dual, &*delta_z_u);
new_v_l.axpy(alpha_dual, &*delta_v_l);
new_v_u.axpy(alpha_dual, &*delta_v_u);
// If any updated bound multiplier exceeds the threshold, all four vectors
// are reset to 1.
let bound_max = bound_mult_amax(&*new_z_l, &*new_z_u, &*new_v_l, &*new_v_u);
if self.should_reset_bound_mults(bound_max) {
reset_bound_multipliers_to_one(
&mut *new_z_l,
&mut *new_z_u,
&mut *new_v_l,
&mut *new_v_u,
);
}
// Stage 2: same primal point, updated bound multipliers, equality
// multipliers still taken from `curr`.
let trial_with_bound_mults = IteratesVector::new(
result.trial_x.clone(),
result.trial_s.clone(),
curr.y_c.clone(),
curr.y_d.clone(),
Rc::from(new_z_l),
Rc::from(new_z_u),
Rc::from(new_v_l),
Rc::from(new_v_u),
);
data.borrow_mut().set_trial(trial_with_bound_mults);
// Equality multipliers start at zero and stay zero unless the
// recomputation below runs and passes the norm check.
let mut new_y_c = make_zeroed_like(&*curr.y_c);
let mut new_y_d = make_zeroed_like(&*curr.y_d);
let total_eq_dim = new_y_c.dim() + new_y_d.dim();
// "Square" here means dim(y_c) == dim(x).
let square = new_y_c.dim() == result.trial_x.dim();
if !square && self.constr_mult_reset_threshold > 0.0 && total_eq_dim > 0 {
// The staged trial is copied into `data.curr` before the multiplier
// calculation (presumably because the calculator evaluates at the current
// iterate). NOTE(review): this overwrite of `data.curr` is not undone.
let recovered = data
.borrow()
.trial
.as_ref()
.expect("just set above")
.clone();
data.borrow_mut().curr = Some(recovered);
let lsm_ok = self.eq_mult.calculate_y_eq(
data,
cq,
nlp,
aug_solver,
&mut *new_y_c,
&mut *new_y_d,
);
if lsm_ok {
// Discard the computed multipliers when their max-norm is too large.
let yinitnrm = new_y_c.amax().max(new_y_d.amax());
if yinitnrm > self.constr_mult_reset_threshold {
new_y_c.set(0.0);
new_y_d.set(0.0);
}
} else {
// Calculation failed: make sure no partial result leaks through.
new_y_c.set(0.0);
new_y_d.set(0.0);
}
}
// Stage 3 (final): staged primal point and bound multipliers plus the new
// equality multipliers.
let staged = data
.borrow()
.trial
.as_ref()
.expect("just set above")
.clone();
let final_trial = IteratesVector::new(
staged.x.clone(),
staged.s.clone(),
Rc::from(new_y_c),
Rc::from(new_y_d),
staged.z_l.clone(),
staged.z_u.clone(),
staged.v_l.clone(),
staged.v_u.clone(),
);
data.borrow_mut().set_trial(final_trial);
{
// Output bookkeeping: outer iteration counter becomes the inner count
// minus one (clamped at zero); header/output state is taken over from the
// inner solve and the step is tagged with 'R'.
let mut d = data.borrow_mut();
d.iter_count = result.iter_count.saturating_sub(1).max(0);
d.info_skip_output = true;
d.info_iters_since_header = result.iters_since_header;
d.info_last_output = result.last_output;
d.info_alpha_primal_char = 'R';
}
RestorationOutcome::Recovered
}
}
/// Allocates a fresh dense vector with the same dimension as `template`, with
/// every entry set to `0.0`.
fn make_zeroed_like(template: &dyn Vector) -> Box<dyn Vector> {
    let space = DenseVectorSpace::new(template.dim());
    let mut zeros = space.make_new_dense();
    zeros.set(0.0);
    Box::new(zeros)
}
/// Allocates a fresh dense vector with the same dimension as `template` and
/// copies `template`'s contents into it, so the caller can mutate the result
/// without touching the original.
fn clone_to_owned(template: &dyn Vector) -> Box<dyn Vector> {
    let space = DenseVectorSpace::new(template.dim());
    let mut duplicate = space.make_new_dense();
    duplicate.copy(template);
    Box::new(duplicate)
}
/// Scalar form of the bound-multiplier step:
///
/// `dz = (curr_z * (curr_slack - trial_slack) + mu) / curr_slack - curr_z`
///
/// There is no guard on `curr_slack`: a zero slack yields an infinite or NaN
/// result according to IEEE-754 division.
pub fn compute_bound_multiplier_step_elem(
    curr_z: f64,
    curr_slack: f64,
    trial_slack: f64,
    mu: f64,
) -> f64 {
    let slack_change = curr_slack - trial_slack;
    let numerator = curr_z * slack_change + mu;
    let target_z = numerator / curr_slack;
    target_z - curr_z
}
/// Vector form of `compute_bound_multiplier_step_elem`: overwrites `delta_z`
/// with the bound-multiplier step built from `curr_z`, `curr_slack`,
/// `trial_slack` and `mu`.
///
/// All four vectors must have the same dimension; this is only checked in
/// debug builds. As in the scalar form, entries of `curr_slack` are used as
/// divisors without any zero check.
pub fn compute_bound_multiplier_step(
delta_z: &mut dyn Vector,
curr_z: &dyn Vector,
curr_slack: &dyn Vector,
trial_slack: &dyn Vector,
mu: Number,
) {
debug_assert_eq!(delta_z.dim(), curr_z.dim());
debug_assert_eq!(delta_z.dim(), curr_slack.dim());
debug_assert_eq!(delta_z.dim(), trial_slack.dim());
// The running value of `delta_z` is built in place, mirroring the scalar
// formula term by term (intended value after each call shown on the right).
// curr_slack
delta_z.copy(curr_slack);
// curr_slack - trial_slack
delta_z.axpy(-1.0, trial_slack);
// curr_z * (curr_slack - trial_slack)
delta_z.element_wise_multiply(curr_z);
// ... + mu
delta_z.add_scalar(mu);
// ... / curr_slack
delta_z.element_wise_divide(curr_slack);
// ... - curr_z
delta_z.axpy(-1.0, curr_z);
}
/// Returns the largest `amax()` value among the four bound multiplier vectors.
pub fn bound_mult_amax(
    z_l: &dyn Vector,
    z_u: &dyn Vector,
    v_l: &dyn Vector,
    v_u: &dyn Vector,
) -> Number {
    // Fold left to right starting from `z_l`, i.e. ((z_l ∨ z_u) ∨ v_l) ∨ v_u.
    let remaining: [&dyn Vector; 3] = [z_u, v_l, v_u];
    remaining
        .into_iter()
        .fold(z_l.amax(), |largest, mult| largest.max(mult.amax()))
}
/// Overwrites every entry of all four bound multiplier vectors with `1.0`.
pub fn reset_bound_multipliers_to_one(
    z_l: &mut dyn Vector,
    z_u: &mut dyn Vector,
    v_l: &mut dyn Vector,
    v_u: &mut dyn Vector,
) {
    let fill_with_one = |mult: &mut dyn Vector| {
        mult.set(1.0);
    };
    fill_with_one(z_l);
    fill_with_one(z_u);
    fill_with_one(v_l);
    fill_with_one(v_u);
}
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};

    /// Builds a dense vector holding exactly `values`.
    fn dv(values: &[f64]) -> DenseVector {
        let space = DenseVectorSpace::new(values.len() as i32);
        let mut out = space.make_new_dense();
        out.values_mut().copy_from_slice(values);
        out
    }

    /// The default threshold is 1e3: values below it do not trigger a reset,
    /// values above it do.
    #[test]
    fn no_reset_when_below_threshold() {
        let resto = MinC1NormRestoration::new();
        assert!(!resto.should_reset_bound_mults(999.0));
        assert!(resto.should_reset_bound_mults(1001.0));
    }

    /// With z * s == mu and an unchanged slack the scalar step vanishes.
    #[test]
    fn bound_step_zero_when_slacks_unchanged_and_complementary() {
        let mu = 0.1;
        let s = 0.5;
        let z = mu / s;
        let dz = compute_bound_multiplier_step_elem(z, s, s, mu);
        assert!(dz.abs() < 1e-15, "dz = {}", dz);
    }

    /// The scalar step agrees with the formula written out by hand.
    #[test]
    fn bound_step_handles_slack_decrease() {
        let z = 2.0;
        let s_curr = 1.0;
        let s_trial = 0.5;
        let mu = 0.1;
        let dz = compute_bound_multiplier_step_elem(z, s_curr, s_trial, mu);
        let expected = (z * (s_curr - s_trial) + mu) / s_curr - z;
        let gap = (dz - expected).abs();
        assert!(gap < 1e-15);
    }

    /// The vector routine reproduces the scalar routine entry by entry.
    #[test]
    fn vector_bound_step_matches_scalar_per_element() {
        let mu = 0.1;
        let curr_z = dv(&[2.0, 0.5, 4.0]);
        let curr_s = dv(&[1.0, 0.8, 2.0]);
        let trial_s = dv(&[0.5, 0.9, 1.5]);
        let mut delta = dv(&[0.0; 3]);

        compute_bound_multiplier_step(&mut delta, &curr_z, &curr_s, &trial_s, mu);

        for i in 0..3 {
            let expected = compute_bound_multiplier_step_elem(
                curr_z.values()[i],
                curr_s.values()[i],
                trial_s.values()[i],
                mu,
            );
            let got = delta.values()[i];
            assert!(
                (got - expected).abs() < 1e-14,
                "i={i}: got {} vs {}",
                got,
                expected
            );
        }
    }

    /// With z_i * s_i == mu and unchanged slacks every entry of the step vanishes.
    #[test]
    fn vector_bound_step_zero_when_slacks_unchanged_and_complementary() {
        let mu = 0.2;
        let s = [0.5, 0.4, 1.0];
        let z: Vec<f64> = s.iter().map(|&si| mu / si).collect();
        let curr_z = dv(&z);
        let curr_s = dv(&s);
        let trial_s = dv(&s);
        let mut delta = dv(&[0.0; 3]);

        compute_bound_multiplier_step(&mut delta, &curr_z, &curr_s, &trial_s, mu);

        delta
            .values()
            .iter()
            .for_each(|v| assert!(v.abs() < 1e-14, "expected ~0, got {v}"));
    }

    /// The maximum is taken over absolute values across all four vectors.
    #[test]
    fn bound_mult_amax_takes_global_max_across_four_vectors() {
        let z_l = dv(&[1.0, -2.5]);
        let z_u = dv(&[3.0, 0.0]);
        let v_l = dv(&[-7.0]);
        let v_u = dv(&[4.5, 4.4]);
        let largest = bound_mult_amax(&z_l, &z_u, &v_l, &v_u);
        assert_eq!(largest, 7.0);
    }

    /// Every entry of every vector ends up equal to one.
    #[test]
    fn reset_bound_multipliers_writes_one_to_all_four() {
        let mut z_l = dv(&[5.0, 6.0, 7.0]);
        let mut z_u = dv(&[8.0]);
        let mut v_l = dv(&[9.0, 10.0]);
        let mut v_u = dv(&[11.0]);

        reset_bound_multipliers_to_one(&mut z_l, &mut z_u, &mut v_l, &mut v_u);

        assert_eq!(z_l.expanded_values(), vec![1.0, 1.0, 1.0]);
        assert_eq!(z_u.expanded_values(), vec![1.0]);
        assert_eq!(v_l.expanded_values(), vec![1.0, 1.0]);
        assert_eq!(v_u.expanded_values(), vec![1.0]);
    }

    /// Construction smoke test: only the default thresholds are checked here;
    /// the inner solver itself is not invoked.
    #[test]
    fn driver_constructs_with_default_inner_solver_that_short_circuits_to_failed() {
        let driver = MinC1NormRestoration::new();
        assert_eq!(driver.bound_mult_reset_threshold, 1e3);
        assert_eq!(driver.constr_mult_reset_threshold, 0.0);
    }
}