use crate::ipopt_cq::IpoptCqHandle;
use crate::ipopt_data::IpoptDataHandle;
use crate::ipopt_nlp::IpoptNlp;
use crate::iterates_vector::{IteratesVector, IteratesVectorMut};
use crate::kkt::pd_full_space_solver::PdFullSpaceSolver;
use crate::kkt::search_dir_calc::SearchDirCalculator;
use pounce_common::types::Number;
use std::cell::{RefCell, RefMut};
use std::rc::Rc;
/// Search-direction calculator built on a shared [`PdFullSpaceSolver`].
///
/// Each `compute_*` method assembles a right-hand side for the primal-dual
/// system from `IpoptData` / `IpoptCq` quantities and solves it with the
/// wrapped solver.
pub struct PdSearchDirCalc {
/// Shared handle to the linear solver; `pd_solver_rc` hands out clones of it.
pd_solver: Rc<RefCell<PdFullSpaceSolver>>,
/// If set, `compute_search_direction` keeps an already-stored direction
/// (returning `true` without solving); it is also forwarded to the solver as
/// `allow_inexact` for the main search direction.
pub fast_step_computation: bool,
/// If set and the problem has bounds, `compute_search_direction` fills the
/// bound rows with the Mehrotra corrector built from the stored affine step.
pub mehrotra_algorithm: bool,
}
impl PdSearchDirCalc {
    /// Wraps `pd_solver` in a shared handle.
    ///
    /// Both `fast_step_computation` and `mehrotra_algorithm` start disabled.
    pub fn new(pd_solver: PdFullSpaceSolver) -> Self {
        Self {
            pd_solver: Rc::new(RefCell::new(pd_solver)),
            fast_step_computation: false,
            mehrotra_algorithm: false,
        }
    }

    /// Returns another shared handle to the underlying primal-dual solver.
    pub fn pd_solver_rc(&self) -> Rc<RefCell<PdFullSpaceSolver>> {
        Rc::clone(&self.pd_solver)
    }

    /// Mutably borrows the underlying primal-dual solver.
    ///
    /// # Panics
    ///
    /// Panics if the solver is already borrowed (standard `RefCell` rules).
    pub fn pd_solver_mut(&self) -> RefMut<'_, PdFullSpaceSolver> {
        self.pd_solver.borrow_mut()
    }

    /// Computes the primal-dual search direction and, on success, stores it
    /// with `IpoptData::set_delta`.
    ///
    /// The right-hand side holds:
    /// - the damped Lagrangian gradients (`x`, `s` rows);
    /// - the constraint residuals `c` and `d - s` (`y_c`, `y_d` rows);
    /// - in the bound rows, either the relaxed complementarity residuals or,
    ///   when `mehrotra_algorithm` is set and the problem has bounds, the
    ///   Mehrotra corrector built from `IpoptData::delta_aff`.
    ///
    /// If `IpoptData::delta` is already set, the call refines it instead. The
    /// output vector is seeded with the negated previous direction, and the
    /// solver is called with `improve_solution = true`. With
    /// `fast_step_computation` enabled, the stored direction is kept unchanged
    /// and `true` is returned without solving.
    ///
    /// Returns whether the linear solve succeeded.
    ///
    /// # Panics
    ///
    /// Panics if `IpoptData::curr` is unset. Also panics if the Mehrotra path
    /// is taken while `IpoptData::delta_aff` is unset.
    pub fn compute_search_direction(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
    ) -> bool {
        // Snapshot the previous direction once, so the "improve" flag and the
        // initial guess below can never disagree.
        let prev_delta = {
            let d = data.borrow();
            if self.fast_step_computation && d.delta.is_some() {
                return true;
            }
            d.delta.clone()
        };
        let improve_solution = prev_delta.is_some();

        let curr = data
            .borrow()
            .curr
            .clone()
            .expect("PdSearchDirCalc: IpoptData::curr is unset");
        let mut rhs = curr.make_new_zeroed();
        {
            let cq_ref = cq.borrow();
            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
            rhs.y_c.copy(&*cq_ref.curr_c());
            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
        }

        // The Mehrotra corrector only matters when there are bound rows at all.
        let nbounds = {
            let n = nlp.borrow();
            n.x_l().dim() + n.x_u().dim() + n.d_l().dim() + n.d_u().dim()
        };
        if nbounds > 0 && self.mehrotra_algorithm {
            let delta_aff = data
                .borrow()
                .delta_aff
                .clone()
                .expect("PdSearchDirCalc: delta_aff missing for Mehrotra");
            self.fill_mehrotra_z_blocks(&delta_aff, cq, nlp, &mut rhs);
        } else {
            Self::fill_relaxed_compl_blocks(cq, &mut rhs);
        }

        let frozen_rhs = rhs.freeze();
        let mut delta = frozen_rhs.make_new_zeroed();
        if let Some(prev) = &prev_delta {
            delta.add_one_vector(-1.0, prev, 0.0);
        }

        let allow_inexact = self.fast_step_computation;
        let ok = self.pd_solver.borrow_mut().solve(
            data,
            cq,
            nlp,
            -1.0,
            0.0,
            &frozen_rhs,
            &mut delta,
            allow_inexact,
            improve_solution,
        );
        if ok {
            data.borrow_mut().set_delta(delta.freeze());
        }
        ok
    }

    /// Computes the affine-scaling step and, on success, stores it with
    /// `IpoptData::set_delta_aff`.
    ///
    /// Unlike `compute_search_direction`, the right-hand side uses the
    /// *undamped* Lagrangian gradients (`curr_grad_lag_*`) and the unrelaxed
    /// complementarity products (`curr_compl_*`). The solve always allows an
    /// inexact solution.
    ///
    /// Returns whether the linear solve succeeded.
    ///
    /// # Panics
    ///
    /// Panics if `IpoptData::curr` is unset.
    pub fn compute_affine_step(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
    ) -> bool {
        let curr = data
            .borrow()
            .curr
            .clone()
            .expect("PdSearchDirCalc: IpoptData::curr is unset");
        let mut rhs = curr.make_new_zeroed();
        {
            let cq_ref = cq.borrow();
            // NOTE(review): Ipopt's mu oracles use the damped gradient here.
            // The centering step adds `-avrg_compl * grad_kappa_times_damping_*`,
            // which is presumably why damping is omitted in this rhs — confirm.
            rhs.x.copy(&*cq_ref.curr_grad_lag_x());
            rhs.s.copy(&*cq_ref.curr_grad_lag_s());
            rhs.y_c.copy(&*cq_ref.curr_c());
            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
            rhs.z_l.copy(&*cq_ref.curr_compl_x_l());
            rhs.z_u.copy(&*cq_ref.curr_compl_x_u());
            rhs.v_l.copy(&*cq_ref.curr_compl_s_l());
            rhs.v_u.copy(&*cq_ref.curr_compl_s_u());
        }
        let frozen_rhs = rhs.freeze();
        let mut delta_aff = frozen_rhs.make_new_zeroed();
        let ok = self.pd_solver.borrow_mut().solve(
            data,
            cq,
            nlp,
            -1.0,
            0.0,
            &frozen_rhs,
            &mut delta_aff,
            true,
            false,
        );
        if ok {
            data.borrow_mut().set_delta_aff(delta_aff.freeze());
        }
        ok
    }

    /// Computes the centering step and, on success, stores it with
    /// `IpoptData::set_delta_cen`.
    ///
    /// The `x` / `s` rows are `-avrg_compl * grad_kappa_times_damping_*`. The
    /// constraint rows are zero, and every bound row is the constant
    /// `avrg_compl` (the current average complementarity). The system is
    /// solved with `alpha = 1.0`, and an inexact solution is allowed.
    ///
    /// Returns whether the linear solve succeeded.
    ///
    /// # Panics
    ///
    /// Panics if `IpoptData::curr` is unset.
    pub fn compute_centering_step(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
    ) -> bool {
        let curr = data
            .borrow()
            .curr
            .clone()
            .expect("PdSearchDirCalc: IpoptData::curr is unset");
        let avrg_compl = cq.borrow().curr_avrg_compl();
        let mut rhs = curr.make_new_zeroed();
        {
            let cq_ref = cq.borrow();
            rhs.x
                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_x(), 0.0);
            rhs.s
                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_s(), 0.0);
        }
        rhs.y_c.set(0.0);
        rhs.y_d.set(0.0);
        rhs.z_l.set(avrg_compl);
        rhs.z_u.set(avrg_compl);
        rhs.v_l.set(avrg_compl);
        rhs.v_u.set(avrg_compl);
        let frozen_rhs = rhs.freeze();
        let mut delta_cen = frozen_rhs.make_new_zeroed();
        let ok = self.pd_solver.borrow_mut().solve(
            data,
            cq,
            nlp,
            1.0,
            0.0,
            &frozen_rhs,
            &mut delta_cen,
            true,
            false,
        );
        if ok {
            data.borrow_mut().set_delta_cen(delta_cen.freeze());
        }
        ok
    }

    /// Computes a second-order-correction step. The result is returned
    /// rather than stored in `IpoptData`.
    ///
    /// The caller supplies the corrected constraint residuals `c_soc` and
    /// `dms_soc` for the `y_c` and `y_d` rows. The `x` / `s` rows are the
    /// damped Lagrangian gradients; when `soc_method == 1`, they are
    /// additionally scaled by `alpha_primal_soc`. The bound rows hold the
    /// relaxed complementarity residuals. The solve does not allow an inexact
    /// solution.
    ///
    /// Returns `Some(step)` if the linear solve succeeded, `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `IpoptData::curr` is unset.
    pub fn compute_soc_step(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        c_soc: &dyn pounce_linalg::Vector,
        dms_soc: &dyn pounce_linalg::Vector,
        alpha_primal_soc: Number,
        soc_method: i32,
    ) -> Option<IteratesVector> {
        let curr = data
            .borrow()
            .curr
            .clone()
            .expect("PdSearchDirCalc::compute_soc_step: curr is unset");
        let mut rhs = curr.make_new_zeroed();
        {
            let cq_ref = cq.borrow();
            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
        }
        if soc_method == 1 {
            rhs.x.scal(alpha_primal_soc);
            rhs.s.scal(alpha_primal_soc);
        }
        rhs.y_c.copy(c_soc);
        rhs.y_d.copy(dms_soc);
        Self::fill_relaxed_compl_blocks(cq, &mut rhs);

        let frozen_rhs = rhs.freeze();
        let mut delta_soc = frozen_rhs.make_new_zeroed();
        let ok = self.pd_solver.borrow_mut().solve(
            data,
            cq,
            nlp,
            -1.0,
            0.0,
            &frozen_rhs,
            &mut delta_soc,
            false,
            false,
        );
        ok.then(|| delta_soc.freeze())
    }

    /// Copies the relaxed complementarity residuals (`curr_relaxed_compl_*`)
    /// into the four bound rows (`z_l`, `z_u`, `v_l`, `v_u`) of `rhs`.
    fn fill_relaxed_compl_blocks(cq: &IpoptCqHandle, rhs: &mut IteratesVectorMut) {
        let cq_ref = cq.borrow();
        rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
        rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
        rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
        rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
    }

    /// Fills the four bound rows of `rhs` with the Mehrotra corrector terms.
    ///
    /// Each block follows the same steps:
    /// - map the affine primal change through the transposed bound selector
    ///   (`px_l` / `px_u` for `x`, `pd_l` / `pd_u` for `s`), negated for the
    ///   upper blocks;
    /// - multiply element-wise by the matching affine dual change;
    /// - add the relaxed complementarity residual.
    fn fill_mehrotra_z_blocks(
        &self,
        delta_aff: &IteratesVector,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        rhs: &mut IteratesVectorMut,
    ) {
        let n = nlp.borrow();
        let cq_ref = cq.borrow();
        n.px_l()
            .trans_mult_vector(1.0, &*delta_aff.x, 0.0, &mut *rhs.z_l);
        rhs.z_l.element_wise_multiply(&*delta_aff.z_l);
        rhs.z_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_l());
        n.px_u()
            .trans_mult_vector(-1.0, &*delta_aff.x, 0.0, &mut *rhs.z_u);
        rhs.z_u.element_wise_multiply(&*delta_aff.z_u);
        rhs.z_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_u());
        n.pd_l()
            .trans_mult_vector(1.0, &*delta_aff.s, 0.0, &mut *rhs.v_l);
        rhs.v_l.element_wise_multiply(&*delta_aff.v_l);
        rhs.v_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_l());
        n.pd_u()
            .trans_mult_vector(-1.0, &*delta_aff.s, 0.0, &mut *rhs.v_u);
        rhs.v_u.element_wise_multiply(&*delta_aff.v_u);
        rhs.v_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_u());
    }
}
// Empty impl: any `SearchDirCalculator` items come from the trait's default
// implementations (the trait definition is not visible here).
impl SearchDirCalculator for PdSearchDirCalc {}
/// Scalar Mehrotra corrector entry for a lower-bound row.
///
/// Returns the product of the affine primal change on the bound and the
/// affine dual change, plus the relaxed complementarity residual.
pub fn mehrotra_corrector_lower(
    delta_aff_x_lo: Number,
    delta_aff_z: Number,
    relaxed_compl: Number,
) -> Number {
    let second_order = delta_aff_x_lo * delta_aff_z;
    relaxed_compl + second_order
}
/// Scalar Mehrotra corrector entry for an upper-bound row.
///
/// The primal term enters with a negative sign, mirroring the `-1.0` factor
/// used for the upper blocks when the Mehrotra right-hand side is assembled.
pub fn mehrotra_corrector_upper(
    delta_aff_x_up: Number,
    delta_aff_z: Number,
    relaxed_compl: Number,
) -> Number {
    let product = delta_aff_x_up * delta_aff_z;
    relaxed_compl - product
}
/// Relaxed complementarity residual `x * z - mu`.
///
/// The residual is zero exactly when the product `x * z` equals `mu`.
pub fn relaxed_complementarity(x: Number, z: Number, mu: Number) -> Number {
    let compl = x * z;
    compl - mu
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With `x * z == mu` the relaxed residual must vanish.
    #[test]
    fn relaxed_compl_at_central_path_is_zero() {
        let residual = relaxed_complementarity(2.0, 0.5, 1.0);
        assert_eq!(residual, 0.0);
    }

    /// Lower corrector: `dx * dz + relaxed` = `1 * 2 + 0.5`.
    #[test]
    fn mehrotra_lower_combines_linearly() {
        let (dx, dz, relaxed) = (1.0, 2.0, 0.5);
        let corrector = mehrotra_corrector_lower(dx, dz, relaxed);
        assert_eq!(corrector, 2.5);
    }

    /// Upper corrector: `-dx * dz + relaxed` = `-(1 * 2) + 0.5`.
    #[test]
    fn mehrotra_upper_negates_dx() {
        let (dx, dz, relaxed) = (1.0, 2.0, 0.5);
        let corrector = mehrotra_corrector_upper(dx, dz, relaxed);
        assert_eq!(corrector, -1.5);
    }
}