use crate::ipopt_cq::IpoptCqHandle;
use crate::ipopt_data::IpoptDataHandle;
use crate::iterates_vector::IteratesVector;
use crate::line_search::filter_acceptor::AcceptDecision;
use crate::line_search::ls_acceptor::BacktrackingLsAcceptor;
use pounce_common::types::Number;
use pounce_common::utils::compare_le;
use pounce_linalg::Vector;
use std::rc::Rc;
/// Backtracking line-search acceptor based on the penalty merit function
/// `barrier_obj + nu * theta`: a trial point is accepted when the actual reduction of
/// that merit function is at least `eta_penalty` times the predicted reduction.
pub struct PenaltyLsAcceptor {
/// Used by `update_nu`: the penalty threshold is divided by `(1 - rho) * reference_theta`.
pub rho: Number,
/// Margin added on top of the threshold whenever `update_nu` raises `nu`.
pub nu_inc: Number,
/// Starting value of `nu` and `last_nu`; `reset` restores both to it.
pub nu_init: Number,
/// Upper bound intended for `nu` (default 1e40).
/// NOTE(review): confirm that `update_nu` actually enforces this cap.
pub nu_max: Number,
/// Sufficient-decrease factor: a trial point is accepted when
/// `eta_penalty * pred <= ared` (up to the `compare_le` tolerance).
pub eta_penalty: Number,
/// Current penalty parameter weighting the constraint violation in the merit function.
nu: Number,
/// Value of `nu` before the most recent `update_nu` call.
last_nu: Number,
/// Reference-point data captured by `init_this_line_search`; `None` before the first
/// line search and after `reset`.
cache: Option<RefCache>,
}
/// Quantities evaluated once at the reference (current) iterate by
/// `init_this_line_search` and reused for every trial step of that line search.
struct RefCache {
/// Constraint violation at the reference point (`curr_constraint_violation`).
theta_ref: Number,
/// Barrier objective at the reference point (`curr_barrier_obj`).
barr_ref: Number,
/// Directional derivative of the barrier objective along (dx, ds) (`curr_grad_barr_t_delta`).
grad_barr_t_delta: Number,
/// Curvature term from `curr_dwd(dx, ds)` — presumably delta^T W delta; TODO confirm.
dwd: Number,
/// Equality-constraint values c(x) at the reference point.
c_ref: Rc<dyn Vector>,
/// Inequality residual d(x) - s at the reference point.
d_minus_s_ref: Rc<dyn Vector>,
/// Jacobian of c applied to dx.
jac_c_delta: Rc<dyn Vector>,
/// Jacobian of d applied to dx, minus ds (linearized change of d(x) - s along the step).
jac_d_delta_minus_ds: Rc<dyn Vector>,
}
impl Default for PenaltyLsAcceptor {
fn default() -> Self {
Self {
rho: 0.1,
nu_inc: 1e-4,
nu_init: 1e-6,
nu_max: 1e40,
eta_penalty: 1e-8,
nu: 1e-6,
last_nu: 1e-6,
cache: None,
}
}
}
impl PenaltyLsAcceptor {
    /// Creates an acceptor with the default parameters (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Current penalty parameter `nu`.
    pub fn nu(&self) -> Number {
        self.nu
    }

    /// Value of `nu` before the most recent call to [`Self::update_nu`].
    pub fn last_nu(&self) -> Number {
        self.last_nu
    }

    /// Restores `nu` and `last_nu` to `nu_init` and discards any cached reference point.
    pub fn reset(&mut self) {
        self.nu = self.nu_init;
        self.last_nu = self.nu_init;
        self.cache = None;
    }

    /// Raises the penalty parameter if the current one is too small for the step.
    ///
    /// Always records the current `nu` in `last_nu`. When `reference_theta > 0`, computes
    /// `nu_plus = (grad_barr_t_delta + 0.5 * delta_w_delta) / ((1 - rho) * reference_theta)`
    /// and, if `nu < nu_plus`, sets `nu` to `nu_plus + nu_inc`, capped at `nu_max`.
    /// The cap never lowers `nu` below its current value.
    pub fn update_nu(
        &mut self,
        grad_barr_t_delta: Number,
        delta_w_delta: Number,
        reference_theta: Number,
    ) {
        self.last_nu = self.nu;
        if reference_theta > 0.0 {
            let nu_plus =
                (grad_barr_t_delta + 0.5 * delta_w_delta) / ((1.0 - self.rho) * reference_theta);
            if self.nu < nu_plus {
                // A near-zero `reference_theta` can push `nu_plus` to infinity. An infinite
                // `nu` turns the `nu * theta` terms in `check_trial_point` into NaN
                // (inf - inf, inf * 0), which would reject every trial point, so clamp
                // to `nu_max`. `.max(self.nu)` keeps `nu` non-decreasing even if the
                // user configured `nu_init > nu_max`.
                self.nu = (nu_plus + self.nu_inc).min(self.nu_max).max(self.nu);
            }
        }
    }

    /// Predicted reduction of the penalty merit function for step length `alpha`.
    ///
    /// `pred = -alpha * grad_barr_t_delta - 0.5 * alpha^2 * dwd + nu * (theta_ref - theta_lin)`,
    /// where `theta_lin` is the `asum` (presumably l1 norm) of the linearized constraints
    /// `c + alpha * J_c dx` and `(d - s) + alpha * (J_d dx - ds)`.
    /// Negative predictions are clamped to zero.
    ///
    /// # Panics
    /// Panics if called before `init_this_line_search` has populated the cache.
    fn calc_pred(&self, alpha: Number) -> Number {
        let cache = self
            .cache
            .as_ref()
            .expect("calc_pred called before init_this_line_search");
        // Zero first so the `c = 0.0` blend in `add_two_vectors` never touches stale
        // storage. Assumes Ipopt semantics: self = a*v1 + b*v2 + c*self.
        let mut tmp_c = cache.c_ref.make_new();
        tmp_c.set(0.0);
        tmp_c.add_two_vectors(1.0, &*cache.c_ref, alpha, &*cache.jac_c_delta, 0.0);
        let mut tmp_d = cache.d_minus_s_ref.make_new();
        tmp_d.set(0.0);
        tmp_d.add_two_vectors(
            1.0,
            &*cache.d_minus_s_ref,
            alpha,
            &*cache.jac_d_delta_minus_ds,
            0.0,
        );
        let theta_2 = tmp_c.asum() + tmp_d.asum();
        let pred = -alpha * cache.grad_barr_t_delta - 0.5 * alpha * alpha * cache.dwd
            + self.nu * (cache.theta_ref - theta_2);
        if pred < 0.0 {
            0.0
        } else {
            pred
        }
    }
}
impl BacktrackingLsAcceptor for PenaltyLsAcceptor {
    /// Delegates to the inherent `PenaltyLsAcceptor::reset`.
    fn reset(&mut self) {
        PenaltyLsAcceptor::reset(self)
    }

    /// Captures the reference-point quantities for a new line search along `delta`,
    /// then updates the penalty parameter from them.
    fn init_this_line_search(
        &mut self,
        _data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        delta: &IteratesVector,
    ) {
        let q = cq.borrow();
        // Field initializers run top to bottom, so the CQ is queried in a fixed order.
        let snapshot = RefCache {
            theta_ref: q.curr_constraint_violation(),
            barr_ref: q.curr_barrier_obj(),
            grad_barr_t_delta: q.curr_grad_barr_t_delta(&*delta.x, &*delta.s),
            dwd: q.curr_dwd(&*delta.x, &*delta.s),
            c_ref: q.curr_c(),
            d_minus_s_ref: q.curr_d_minus_s(),
            jac_c_delta: q.curr_jac_c_times_vec(&*delta.x),
            jac_d_delta_minus_ds: {
                // Linearized change of d(x) - s along the step: J_d * dx - ds.
                let jd_dx = q.curr_jac_d_times_vec(&*delta.x);
                let mut diff = jd_dx.make_new();
                diff.set(0.0);
                diff.add_two_vectors(1.0, &*jd_dx, -1.0, &*delta.s, 0.0);
                Rc::from(diff)
            },
        };
        // Release the RefCell borrow on the CQ before mutating `self`.
        drop(q);
        let (slope, curvature, theta) =
            (snapshot.grad_barr_t_delta, snapshot.dwd, snapshot.theta_ref);
        self.cache = Some(snapshot);
        self.update_nu(slope, curvature, theta);
    }

    /// Sufficient-decrease test on the penalty merit function `barr + nu * theta`.
    ///
    /// Accepts unconditionally when no reference point has been captured. Otherwise
    /// accepts iff `eta_penalty * pred <= ared`, compared with `compare_le` relative to
    /// the magnitude of the reference merit value.
    fn check_trial_point(
        &mut self,
        alpha_primal: Number,
        _theta: Number,
        _phi: Number,
        _d_phi: Number,
        theta_trial: Number,
        phi_trial: Number,
    ) -> AcceptDecision {
        let (barr_ref, theta_ref) = match self.cache.as_ref() {
            None => return AcceptDecision::Accept,
            Some(reference) => (reference.barr_ref, reference.theta_ref),
        };
        let predicted = self.calc_pred(alpha_primal);
        let reference_merit = barr_ref + self.nu * theta_ref;
        let actual = reference_merit - (phi_trial + self.nu * theta_trial);
        let sufficient = compare_le(self.eta_penalty * predicted, actual, reference_merit.abs());
        if sufficient {
            AcceptDecision::Accept
        } else {
            AcceptDecision::Reject
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_bump_when_theta_zero() {
        let mut a = PenaltyLsAcceptor::new();
        let nu0 = a.nu();
        a.update_nu(10.0, 5.0, 0.0);
        assert_eq!(a.nu(), nu0);
        assert_eq!(a.last_nu(), nu0);
    }

    #[test]
    fn bump_when_nu_plus_exceeds_current() {
        let mut a = PenaltyLsAcceptor {
            rho: 0.1,
            nu_inc: 1e-4,
            nu: 0.0,
            last_nu: 0.0,
            ..Default::default()
        };
        a.update_nu(1.0, 0.0, 1.0);
        assert_eq!(a.last_nu(), 0.0);
        let expected = 1.0 / 0.9 + 1e-4;
        assert!((a.nu() - expected).abs() < 1e-12);
    }

    #[test]
    fn no_bump_when_already_above_nu_plus() {
        let mut a = PenaltyLsAcceptor {
            rho: 0.1,
            nu_inc: 1e-4,
            nu: 1e6,
            last_nu: 1e6,
            ..Default::default()
        };
        a.update_nu(1.0, 0.0, 1.0);
        assert_eq!(a.nu(), 1e6);
    }

    #[test]
    fn reset_restores_init() {
        let mut a = PenaltyLsAcceptor::new();
        // Two bumps so that `last_nu` also moves away from `nu_init`.
        a.update_nu(10.0, 0.0, 1.0);
        a.update_nu(100.0, 0.0, 1.0);
        assert!(a.nu() > a.nu_init);
        assert!(a.last_nu() > a.nu_init);
        a.cache = Some(aligned_cache());
        PenaltyLsAcceptor::reset(&mut a);
        assert_eq!(a.nu(), a.nu_init);
        assert_eq!(a.last_nu(), a.nu_init);
        assert!(a.cache.is_none());
    }

    #[test]
    fn check_trial_point_without_cache_accepts() {
        let mut a = PenaltyLsAcceptor::new();
        assert_eq!(
            a.check_trial_point(1.0, 1.0, 10.0, -1.0, 0.5, 8.0),
            AcceptDecision::Accept
        );
    }

    /// Builds a `RefCache` whose four vectors are dense vectors holding the given values.
    // The parameter list mirrors `RefCache`'s eight fields one-to-one, hence the allow.
    #[allow(clippy::too_many_arguments)]
    fn cache_for_test(
        theta_ref: Number,
        barr_ref: Number,
        grad_barr_t_delta: Number,
        dwd: Number,
        c_ref: Vec<Number>,
        d_minus_s_ref: Vec<Number>,
        jac_c_delta: Vec<Number>,
        jac_d_delta_minus_ds: Vec<Number>,
    ) -> RefCache {
        use pounce_linalg::dense_vector::DenseVectorSpace;
        use pounce_linalg::Vector;
        let mkr = |v: Vec<Number>| -> Rc<dyn Vector> {
            let mut x = DenseVectorSpace::new(v.len() as i32).make_new_dense();
            x.values_mut().copy_from_slice(&v);
            Rc::new(x)
        };
        RefCache {
            theta_ref,
            barr_ref,
            grad_barr_t_delta,
            dwd,
            c_ref: mkr(c_ref),
            d_minus_s_ref: mkr(d_minus_s_ref),
            jac_c_delta: mkr(jac_c_delta),
            jac_d_delta_minus_ds: mkr(jac_d_delta_minus_ds),
        }
    }

    /// Reference point shared by several tests: `theta_ref = 3`, `barr_ref = 0`,
    /// `grad_barr_t_delta = -2`, `dwd = 0`. The linearized constraints vanish at
    /// `alpha = 1`, so `pred(1) = 2 + 3 * nu`.
    fn aligned_cache() -> RefCache {
        cache_for_test(
            3.0,
            0.0,
            -2.0,
            0.0,
            vec![1.0, 2.0],
            vec![0.0],
            vec![-1.0, -2.0],
            vec![0.0],
        )
    }

    /// With these values the raw prediction is -2.5, so the result is clamped to zero.
    #[test]
    fn calc_pred_matches_closed_form() {
        let mut a = PenaltyLsAcceptor::new();
        a.nu = 0.5;
        a.cache = Some(cache_for_test(
            3.0,
            0.0,
            2.0,
            4.0,
            vec![1.0, 2.0],
            vec![4.0],
            vec![-1.0, -1.0],
            vec![-2.0],
        ));
        assert!((a.calc_pred(0.5) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn calc_pred_positive_when_directions_align() {
        let mut a = PenaltyLsAcceptor::new();
        a.nu = 1.0;
        a.cache = Some(aligned_cache());
        assert!((a.calc_pred(1.0) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn check_trial_point_accepts_when_ared_meets_pred() {
        let mut a = PenaltyLsAcceptor::new();
        a.nu = 1.0;
        a.eta_penalty = 0.5;
        a.cache = Some(aligned_cache());
        assert_eq!(
            a.check_trial_point(1.0, 3.0, 0.0, -2.0, 0.0, -3.0),
            AcceptDecision::Accept
        );
    }

    #[test]
    fn check_trial_point_rejects_insufficient_decrease() {
        let mut a = PenaltyLsAcceptor::new();
        a.nu = 1.0;
        a.eta_penalty = 0.5;
        a.cache = Some(aligned_cache());
        assert_eq!(
            a.check_trial_point(1.0, 3.0, 0.0, -2.0, 2.999, 0.0),
            AcceptDecision::Reject
        );
    }
}