/// Outcome of a restoration-phase convergence test.
///
/// Returned by [`RestoConvCheck::check_convergence`] and
/// [`RestoFilterConvCheck::test_orig_progress`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RestoConvergenceStatus {
    /// Keep iterating inside the restoration phase.
    Continue,
    /// The restoration phase may stop and hand control back to the original problem.
    Converged,
    /// An iteration cap (outer iterations or successive restoration iterations) was hit.
    MaxIterExceeded,
    /// User-requested termination.
    ///
    /// NOTE(review): no check in this module produces this variant; presumably a caller
    /// sets it when the user asks to stop — confirm.
    UserStop,
}
/// Restoration-phase convergence test that carries per-phase state.
///
/// The public fields are tuning knobs. The private fields record whether the next
/// [`check_convergence`](RestoConvCheck::check_convergence) call is the first of the
/// current phase and how many calls have been counted so far; call
/// [`reset`](RestoConvCheck::reset) before reusing the checker for a new phase.
#[derive(Debug, Clone)]
pub struct RestoConvCheck {
    /// Required reduction factor for the original primal infeasibility; a value `<= 0.0`
    /// disables the reduction guard.
    pub kappa_resto: f64,
    /// Cap on the outer iteration counter passed to `check_convergence`.
    pub maximum_iters: i32,
    /// Cap on the number of successive restoration iterations.
    pub maximum_resto_iters: i32,
    /// Constraint-violation tolerance of the original problem.
    pub orig_constr_viol_tol: f64,
    /// `true` until the first `check_convergence` call of the current phase.
    first_resto_iter: bool,
    /// Number of `check_convergence` calls counted in the current phase.
    successive_resto_iter: i32,
}
impl Default for RestoConvCheck {
fn default() -> Self {
Self {
kappa_resto: 0.9,
maximum_iters: 3000,
maximum_resto_iters: 3000,
orig_constr_viol_tol: 1e-4,
first_resto_iter: true,
successive_resto_iter: 0,
}
}
}
impl RestoConvCheck {
    /// Creates a checker with the settings from its `Default` implementation.
    pub fn new() -> Self {
        Self::default()
    }

    /// Re-arms the checker for a fresh restoration phase.
    ///
    /// After this, the next [`check_convergence`](Self::check_convergence) call is treated
    /// as the first iteration of a new phase and the successive-iteration count restarts
    /// at zero.
    pub fn reset(&mut self) {
        self.first_resto_iter = true;
        self.successive_resto_iter = 0;
    }

    /// Decides whether the restoration phase should stop.
    ///
    /// Evaluation order:
    /// 1. `MaxIterExceeded` once `iter_count` reaches `maximum_iters`, or once
    ///    `maximum_resto_iters` calls have been counted in this phase.
    /// 2. The first counted call of a phase always returns `Continue`, so that at least
    ///    one restoration step is taken.
    /// 3. A NaN `orig_trial_inf_pr` returns `Continue`.
    /// 4. For square problems, `Converged` as soon as the trial infeasibility is within
    ///    `min(orig_tol, orig_constr_viol_tol)`.
    /// 5. When `kappa_resto > 0`, `Continue` unless the trial infeasibility is at most
    ///    `max(kappa_resto * orig_curr_inf_pr, min(orig_tol, orig_constr_viol_tol))`.
    /// 6. Otherwise `acceptable_to_outer` decides between `Converged` and `Continue`.
    ///
    /// # Parameters
    /// - `iter_count`: iteration counter compared against `maximum_iters`.
    /// - `is_square_problem`: enables the square-problem fast path (step 4).
    /// - `orig_curr_inf_pr`: reference primal infeasibility of the original problem.
    /// - `orig_trial_inf_pr`: primal infeasibility of the original problem at the trial point.
    /// - `orig_tol`: tolerance of the original problem.
    /// - `acceptable_to_outer`: invoked only when every earlier check passed.
    pub fn check_convergence(
        &mut self,
        iter_count: i32,
        is_square_problem: bool,
        orig_curr_inf_pr: f64,
        orig_trial_inf_pr: f64,
        orig_tol: f64,
        acceptable_to_outer: impl FnOnce() -> bool,
    ) -> RestoConvergenceStatus {
        // `>=` on both caps matches `RestoConvCheckAdapter`, which receives the same limits
        // through `from_base`; with `>` this checker ran one extra iteration per cap.
        if iter_count >= self.maximum_iters {
            return RestoConvergenceStatus::MaxIterExceeded;
        }
        if self.successive_resto_iter >= self.maximum_resto_iters {
            return RestoConvergenceStatus::MaxIterExceeded;
        }
        self.successive_resto_iter += 1;
        if self.first_resto_iter {
            self.first_resto_iter = false;
            return RestoConvergenceStatus::Continue;
        }
        // Every comparison against NaN is false, so without this check a NaN would skip
        // the reduction guard and could be reported as converged by `acceptable_to_outer`.
        if orig_trial_inf_pr.is_nan() {
            return RestoConvergenceStatus::Continue;
        }
        let feas_tol = orig_tol.min(self.orig_constr_viol_tol);
        if is_square_problem && orig_trial_inf_pr <= feas_tol {
            return RestoConvergenceStatus::Converged;
        }
        if self.kappa_resto > 0.0 {
            // Floor the threshold at the feasibility tolerance: when the reference
            // infeasibility is already tiny, `kappa * curr` alone could be unreachable even
            // for a point that is feasible to tolerance.
            let threshold = (self.kappa_resto * orig_curr_inf_pr).max(feas_tol);
            if orig_trial_inf_pr > threshold {
                return RestoConvergenceStatus::Continue;
            }
        }
        if acceptable_to_outer() {
            RestoConvergenceStatus::Converged
        } else {
            RestoConvergenceStatus::Continue
        }
    }
}
/// Filter-based restoration convergence test.
///
/// Wraps the shared [`RestoConvCheck`] state and adds the filter-specific
/// `obj_max_inc` setting.
pub struct RestoFilterConvCheck {
    /// Shared restoration convergence state and tuning knobs.
    pub base: RestoConvCheck,
    /// Forwarded to the outer acceptor's current-iterate acceptability test.
    pub obj_max_inc: f64,
}
impl RestoFilterConvCheck {
    /// Creates a filter-based checker around a default [`RestoConvCheck`], with
    /// `obj_max_inc` set to `5.0`.
    pub fn new() -> Self {
        let base = RestoConvCheck::new();
        Self {
            base,
            obj_max_inc: 5.0,
        }
    }

    /// Asks the outer line-search acceptor whether the original-problem trial point is
    /// acceptable.
    ///
    /// The trial point must first be acceptable to the outer filter. Only then is it
    /// checked against the reference point `(reference_barr, reference_theta)`. Returns
    /// `Converged` when both accept, otherwise `Continue`.
    pub fn test_orig_progress(
        &self,
        outer: &pounce_algorithm::line_search::filter_acceptor::FilterLsAcceptor,
        orig_trial_barr: f64,
        orig_trial_theta: f64,
        reference_barr: f64,
        reference_theta: f64,
    ) -> RestoConvergenceStatus {
        // `&&` short-circuits, so the iterate test only runs when the filter accepts.
        // The trailing `true` flag: presumably "called from restoration" — confirm.
        let accepted = outer.is_acceptable_to_current_filter(orig_trial_barr, orig_trial_theta)
            && outer.is_acceptable_to_current_iterate(
                orig_trial_barr,
                orig_trial_theta,
                reference_barr,
                reference_theta,
                self.obj_max_inc,
                true,
            );
        if accepted {
            RestoConvergenceStatus::Converged
        } else {
            RestoConvergenceStatus::Continue
        }
    }
}
impl Default for RestoFilterConvCheck {
fn default() -> Self {
Self::new()
}
}
/// Penalty-based restoration convergence test.
///
/// Currently holds only the shared [`RestoConvCheck`] state; no penalty-specific logic
/// is defined alongside it here.
pub struct RestoPenaltyConvCheck {
    /// Shared restoration convergence state and tuning knobs.
    pub base: RestoConvCheck,
}
/// Adapts restoration-phase convergence logic to the generic `ConvCheck` trait.
///
/// Wraps an optimality-error check and adds iteration caps. It also adds an optional guard
/// that evaluates the original NLP at the restoration iterate and can declare convergence
/// once the original infeasibility has dropped enough.
pub struct RestoConvCheckAdapter {
    /// Optimality-error check used whenever the caps and the guard do not decide.
    inner: pounce_algorithm::conv_check::opt_error::OptErrorConvCheck,
    /// Cap on the iteration counter passed to the checks.
    maximum_iters: i32,
    /// Cap on how many times this adapter may be consulted.
    maximum_resto_iters: i32,
    /// Number of counted (non-capped) calls so far.
    successive_resto_iter: i32,
    /// Original NLP evaluated by the guard; the guard is inactive while this is `None`.
    orig_nlp: Option<std::rc::Rc<std::cell::RefCell<dyn pounce_nlp::ipopt_nlp::IpoptNlp>>>,
    /// Reference original infeasibility that the guard scales by `kappa_resto`.
    orig_curr_inf_pr: f64,
    /// Required reduction factor for the guard; values `<= 0.0` disable it.
    kappa_resto: f64,
    /// Optional veto consulted after the guard's infeasibility test passes.
    orig_progress_callback: Option<pounce_algorithm::restoration::OrigProgressCallback>,
}
impl RestoConvCheckAdapter {
    /// Builds an adapter around a freshly configured optimality-error check.
    ///
    /// `tol`, `acceptable_tol`, `acceptable_iter` and `max_iter` configure the inner check.
    /// `max_iter` also becomes this adapter's outer cap. The original-problem guard starts
    /// disabled: there is no original NLP, the reference infeasibility is `+inf`, and
    /// `kappa_resto` is `0.9`.
    pub fn new(
        tol: f64,
        acceptable_tol: f64,
        acceptable_iter: i32,
        max_iter: i32,
        maximum_resto_iters: i32,
    ) -> Self {
        let inner = {
            let mut check = pounce_algorithm::conv_check::opt_error::OptErrorConvCheck::new();
            check.tol = tol;
            check.acceptable_tol = acceptable_tol;
            check.acceptable_iter = acceptable_iter;
            check.max_iter = max_iter;
            check
        };
        Self {
            inner,
            maximum_iters: max_iter,
            maximum_resto_iters,
            successive_resto_iter: 0,
            orig_nlp: None,
            orig_curr_inf_pr: f64::INFINITY,
            kappa_resto: 0.9,
            orig_progress_callback: None,
        }
    }

    /// Installs a callback that can veto guard-based convergence.
    ///
    /// The callback receives `(orig_trial_f, orig_trial_inf_pr)` and must return `true` to
    /// allow convergence.
    pub fn with_orig_progress_callback(
        self,
        cb: pounce_algorithm::restoration::OrigProgressCallback,
    ) -> Self {
        Self {
            orig_progress_callback: Some(cb),
            ..self
        }
    }

    /// Enables the original-problem guard.
    ///
    /// `orig` is evaluated at each restoration iterate. Convergence is then allowed once
    /// its infeasibility is at most `kappa_resto * orig_curr_inf_pr`.
    pub fn with_orig_progress_guard(
        self,
        orig: std::rc::Rc<std::cell::RefCell<dyn pounce_nlp::ipopt_nlp::IpoptNlp>>,
        orig_curr_inf_pr: f64,
        kappa_resto: f64,
    ) -> Self {
        Self {
            orig_nlp: Some(orig),
            orig_curr_inf_pr,
            kappa_resto,
            ..self
        }
    }

    /// Builds an adapter that reuses `base`'s iteration caps.
    ///
    /// Only `maximum_iters` and `maximum_resto_iters` are taken from `base`; its
    /// `kappa_resto` is not copied. `acceptable_iter` is fixed at 15.
    /// NOTE(review): 15 presumably mirrors Ipopt's `acceptable_iter` default — confirm.
    pub fn from_base(base: &RestoConvCheck, tol: f64, acceptable_tol: f64) -> Self {
        let acceptable_iter = 15;
        Self::new(
            tol,
            acceptable_tol,
            acceptable_iter,
            base.maximum_iters,
            base.maximum_resto_iters,
        )
    }
}
impl pounce_algorithm::conv_check::r#trait::ConvCheck for RestoConvCheckAdapter {
    /// Applies the iteration caps, then defers to the wrapped optimality-error check.
    ///
    /// Returns `MaxIterExceeded` once `iter_count` reaches `maximum_iters`, or once
    /// `maximum_resto_iters` calls have been counted. Otherwise it counts this call and
    /// returns the inner check's verdict.
    fn check_convergence(
        &mut self,
        nlp_err: pounce_common::types::Number,
        iter_count: pounce_common::types::Index,
    ) -> pounce_algorithm::conv_check::r#trait::ConvergenceStatus {
        use pounce_algorithm::conv_check::r#trait::ConvergenceStatus as Status;
        let outer_cap_hit = iter_count >= self.maximum_iters;
        let resto_cap_hit = self.successive_resto_iter >= self.maximum_resto_iters;
        if outer_cap_hit || resto_cap_hit {
            return Status::MaxIterExceeded;
        }
        self.successive_resto_iter += 1;
        self.inner.check_convergence(nlp_err, iter_count)
    }

    /// Like `check_convergence`, but first tries the original-problem guard.
    ///
    /// The guard runs only when `iter_count > 0`, `kappa_resto > 0` and an original NLP is
    /// installed. It evaluates the original NLP at the current restoration iterate and
    /// returns `Converged` when the original infeasibility is at most
    /// `kappa_resto * orig_curr_inf_pr` and the optional progress callback (if any)
    /// agrees. In every other case the inner check decides.
    fn check_convergence_with_state(
        &mut self,
        nlp_err: pounce_common::types::Number,
        iter_count: pounce_common::types::Index,
        data: &pounce_algorithm::ipopt_data::IpoptDataHandle,
        _cq: &pounce_algorithm::ipopt_cq::IpoptCqHandle,
    ) -> pounce_algorithm::conv_check::r#trait::ConvergenceStatus {
        use pounce_algorithm::conv_check::r#trait::ConvergenceStatus as Status;
        let outer_cap_hit = iter_count >= self.maximum_iters;
        let resto_cap_hit = self.successive_resto_iter >= self.maximum_resto_iters;
        if outer_cap_hit || resto_cap_hit {
            return Status::MaxIterExceeded;
        }
        self.successive_resto_iter += 1;

        // Evaluate the original problem only when the guard is active and configured.
        let orig_probe = if iter_count > 0 && self.kappa_resto > 0.0 {
            self.orig_nlp
                .as_ref()
                .and_then(|orig_rc| eval_orig_inf_pr_and_f(data, orig_rc))
        } else {
            None
        };

        if let Some((orig_trial_inf_pr, orig_trial_f)) = orig_probe {
            let threshold = self.kappa_resto * self.orig_curr_inf_pr;
            let guard_passes = orig_trial_inf_pr <= threshold;
            if std::env::var_os("POUNCE_DBG_RESTO_KAPPA").is_some() {
                tracing::debug!(target: "pounce::restoration",
                    "[PN_RESTO_KAPPA] iter={} orig_trial_inf_pr={:.6e} orig_curr_inf_pr={:.6e} kappa_resto={:.3e} threshold={:.6e} guard_passes={}",
                    iter_count,
                    orig_trial_inf_pr,
                    self.orig_curr_inf_pr,
                    self.kappa_resto,
                    threshold,
                    guard_passes
                );
            }
            if guard_passes {
                let outer_accepts = if let Some(cb) = &self.orig_progress_callback {
                    cb(orig_trial_f, orig_trial_inf_pr)
                } else {
                    true
                };
                if outer_accepts {
                    return Status::Converged;
                }
            }
        }

        self.inner.check_convergence(nlp_err, iter_count)
    }
}
/// Evaluates the original NLP at the restoration iterate's original-`x` block.
///
/// Returns `Some((inf_pr, f))`, where `inf_pr = max(amax(c(x)), amax(d(x) - s))`. Either
/// term is `0.0` when the original problem has no equality or inequality constraints,
/// respectively. `f` is the original objective. If either norm is NaN, `inf_pr` is NaN.
///
/// Returns `None` when `data` has no current iterate or its `x` is not a `CompoundVector`.
///
/// Mutably borrows `orig_rc` for the duration of the evaluation.
fn eval_orig_inf_pr_and_f(
    data: &pounce_algorithm::ipopt_data::IpoptDataHandle,
    orig_rc: &std::rc::Rc<std::cell::RefCell<dyn pounce_nlp::ipopt_nlp::IpoptNlp>>,
) -> Option<(f64, f64)> {
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::{CompoundVector, Vector};
    let curr = data.borrow().curr.clone()?;
    let xc = curr.x.as_any().downcast_ref::<CompoundVector>()?;
    let x_orig = xc.comp(crate::resto_nlp::BLOCK_X);
    // NOTE(review): `s` is used whole (not as a compound block), which assumes the
    // restoration slack vector has the original problem's `m_ineq` dimension — confirm.
    let s_inner = &*curr.s;
    let mut orig = orig_rc.borrow_mut();
    let m_eq = orig.m_eq();
    let m_ineq = orig.m_ineq();
    let c_amax = if m_eq > 0 {
        let mut c_buf = DenseVectorSpace::new(m_eq).make_new_dense();
        orig.eval_c(x_orig, &mut c_buf);
        c_buf.amax()
    } else {
        0.0
    };
    let d_minus_s_amax = if m_ineq > 0 {
        let mut d_buf = DenseVectorSpace::new(m_ineq).make_new_dense();
        orig.eval_d(x_orig, &mut d_buf);
        d_buf.axpy(-1.0, s_inner);
        d_buf.amax()
    } else {
        0.0
    };
    let f = orig.eval_f(x_orig);
    // `f64::max` returns the non-NaN operand, so a NaN norm would otherwise be replaced by
    // the other norm (often 0.0) and the point would look feasible. Propagating NaN makes
    // callers' `<=` threshold comparisons reject it.
    let inf_pr = if c_amax.is_nan() || d_minus_s_amax.is_nan() {
        f64::NAN
    } else {
        c_amax.max(d_minus_s_amax)
    };
    Some((inf_pr, f))
}
impl RestoPenaltyConvCheck {
pub fn new() -> Self {
Self {
base: RestoConvCheck::new(),
}
}
}
impl Default for RestoPenaltyConvCheck {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn first_iteration_always_continues() {
        let mut cc = RestoConvCheck::new();
        let s = cc.check_convergence(0, false, 1.0, 0.5, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn outer_iter_cap_triggers_max() {
        let mut cc = RestoConvCheck::new();
        cc.maximum_iters = 5;
        let s = cc.check_convergence(6, false, 1.0, 0.5, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::MaxIterExceeded);
    }

    #[test]
    fn successive_resto_cap_triggers_max() {
        let mut cc = RestoConvCheck::new();
        cc.maximum_resto_iters = 2;
        cc.check_convergence(0, false, 1.0, 0.9, 1e-8, || false);
        cc.check_convergence(1, false, 1.0, 0.8, 1e-8, || false);
        cc.check_convergence(2, false, 1.0, 0.7, 1e-8, || false);
        let s = cc.check_convergence(3, false, 1.0, 0.6, 1e-8, || false);
        assert_eq!(s, RestoConvergenceStatus::MaxIterExceeded);
    }

    #[test]
    fn square_problem_fast_path_converges() {
        let mut cc = RestoConvCheck::new();
        cc.orig_constr_viol_tol = 1e-4;
        cc.check_convergence(0, true, 1.0, 0.5, 1e-8, || false);
        let s = cc.check_convergence(1, true, 0.5, 1e-10, 1e-8, || false);
        assert_eq!(s, RestoConvergenceStatus::Converged);
    }

    #[test]
    fn insufficient_reduction_keeps_going() {
        let mut cc = RestoConvCheck::new();
        cc.kappa_resto = 0.9;
        cc.check_convergence(0, false, 1.0, 0.95, 1e-8, || true);
        let s = cc.check_convergence(1, false, 1.0, 0.95, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn sufficient_reduction_plus_filter_accept_converges() {
        let mut cc = RestoConvCheck::new();
        cc.kappa_resto = 0.9;
        cc.check_convergence(0, false, 1.0, 0.5, 1e-8, || true);
        let s = cc.check_convergence(1, false, 1.0, 0.5, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::Converged);
    }

    #[test]
    fn sufficient_reduction_but_filter_rejects_continues() {
        let mut cc = RestoConvCheck::new();
        cc.kappa_resto = 0.9;
        cc.check_convergence(0, false, 1.0, 0.5, 1e-8, || false);
        let s = cc.check_convergence(1, false, 1.0, 0.5, 1e-8, || false);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn kappa_zero_disables_reduction_guard() {
        let mut cc = RestoConvCheck::new();
        cc.kappa_resto = 0.0;
        cc.check_convergence(0, false, 1.0, 1.5, 1e-8, || true);
        let s = cc.check_convergence(1, false, 1.0, 1.5, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::Converged);
    }

    #[test]
    fn reset_clears_state() {
        let mut cc = RestoConvCheck::new();
        cc.check_convergence(0, false, 1.0, 0.5, 1e-8, || true);
        cc.check_convergence(1, false, 1.0, 0.5, 1e-8, || true);
        cc.reset();
        assert!(cc.first_resto_iter);
        assert_eq!(cc.successive_resto_iter, 0);
    }

    /// After `reset`, the next call must again be treated as the first iteration of a
    /// phase and therefore continue unconditionally.
    #[test]
    fn reset_rearms_first_iteration() {
        let mut cc = RestoConvCheck::new();
        cc.check_convergence(0, false, 1.0, 0.5, 1e-8, || true);
        cc.reset();
        let s = cc.check_convergence(1, false, 1.0, 0.5, 1e-8, || true);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn test_orig_progress_converges_when_filter_and_iterate_accept() {
        use pounce_algorithm::line_search::filter_acceptor::FilterLsAcceptor;
        let outer = FilterLsAcceptor::new();
        let cc = RestoFilterConvCheck::new();
        let s = cc.test_orig_progress(&outer, 0.5, 0.1, 1.0, 1.0);
        assert_eq!(s, RestoConvergenceStatus::Converged);
    }

    #[test]
    fn test_orig_progress_continues_when_filter_dominates() {
        use pounce_algorithm::line_search::filter_acceptor::FilterLsAcceptor;
        let mut outer = FilterLsAcceptor::new();
        outer.filter.add(0.05, 0.4, 0);
        let cc = RestoFilterConvCheck::new();
        let s = cc.test_orig_progress(&outer, 0.5, 0.1, 1.0, 1.0);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn test_orig_progress_continues_when_iterate_rejects() {
        use pounce_algorithm::line_search::filter_acceptor::FilterLsAcceptor;
        let outer = FilterLsAcceptor::new();
        let cc = RestoFilterConvCheck::new();
        let s = cc.test_orig_progress(&outer, 2.0, 1.0, 1.0, 1.0);
        assert_eq!(s, RestoConvergenceStatus::Continue);
    }

    #[test]
    fn adapter_converges_at_inner_stationarity_tol() {
        use pounce_algorithm::conv_check::r#trait::{ConvCheck, ConvergenceStatus};
        let mut a = RestoConvCheckAdapter::new(1e-8, 1e-6, 15, 3000, 3000);
        assert_eq!(a.check_convergence(1e-12, 0), ConvergenceStatus::Converged);
    }

    #[test]
    fn adapter_caps_at_maximum_resto_iters() {
        use pounce_algorithm::conv_check::r#trait::{ConvCheck, ConvergenceStatus};
        let mut a = RestoConvCheckAdapter::new(1e-8, 1e-6, 15, 1000, 2);
        assert_eq!(a.check_convergence(1.0, 0), ConvergenceStatus::Continue);
        assert_eq!(a.check_convergence(1.0, 1), ConvergenceStatus::Continue);
        assert_eq!(
            a.check_convergence(1.0, 2),
            ConvergenceStatus::MaxIterExceeded
        );
    }

    #[test]
    fn adapter_caps_at_outer_max_iter() {
        use pounce_algorithm::conv_check::r#trait::{ConvCheck, ConvergenceStatus};
        let mut a = RestoConvCheckAdapter::new(1e-8, 1e-6, 15, 5, 3000);
        assert_eq!(
            a.check_convergence(1.0, 5),
            ConvergenceStatus::MaxIterExceeded
        );
    }

    #[test]
    fn with_orig_progress_callback_records_callback() {
        let cb: pounce_algorithm::restoration::OrigProgressCallback =
            Box::new(|_barr: f64, _theta: f64| true);
        let a =
            RestoConvCheckAdapter::new(1e-8, 1e-6, 15, 3000, 3000).with_orig_progress_callback(cb);
        assert!(a.orig_progress_callback.is_some());
    }

    /// A freshly built adapter has the original-problem guard unconfigured: no original
    /// NLP, an infinite reference infeasibility, and the default `kappa_resto`.
    /// (Previously named `with_orig_progress_guard_stores_reference_and_kappa`, but it never
    /// called `with_orig_progress_guard`.)
    #[test]
    fn new_adapter_has_unconfigured_orig_guard() {
        let a = RestoConvCheckAdapter::new(1e-8, 1e-6, 15, 3000, 3000);
        assert!(a.orig_nlp.is_none());
        assert!(a.orig_curr_inf_pr.is_infinite());
        assert_eq!(a.kappa_resto, 0.9);
    }
}