use crate::r#trait::{RestorationOutcome, RestorationPhase};
use crate::resto_nlp::{BLOCK_N_C, BLOCK_N_D, BLOCK_P_C, BLOCK_P_D, BLOCK_X};
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::ipopt_nlp::IpoptNlp;
use pounce_algorithm::iterates_vector::IteratesVector;
use pounce_algorithm::kkt::aug_system_solver::AugSystemSolver;
use pounce_common::types::Number;
use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
use pounce_linalg::{CompoundVector, CompoundVectorSpace, Vector};
use std::cell::RefCell;
use std::rc::Rc;
/// Restoration phase that works on a restoration-problem iterate whose primal `x` is the
/// 5-component compound vector `(x, n_c, p_c, n_d, p_d)`.
///
/// When invoked, it keeps the `x` block of the current iterate and recomputes the four
/// `n`/`p` blocks from the constraint values of an attached "original" NLP.
pub struct RestoRestorationPhase {
    /// Penalty parameter `rho`; enters the `n`/`p` computation through `mu / (2 * rho)`.
    rho: Number,
    /// NLP whose `eval_c` / `eval_d` are evaluated at the `x` block. `None` until attached,
    /// in which case restoration reports failure.
    orig_nlp: Option<Rc<RefCell<dyn IpoptNlp>>>,
}
impl RestoRestorationPhase {
    /// Creates a phase with penalty parameter `rho` and no original NLP attached.
    ///
    /// An original NLP must be supplied through [`Self::with_orig_nlp`] or
    /// [`Self::set_orig_nlp`] before restoration can succeed.
    pub fn new(rho: Number) -> Self {
        let orig_nlp = None;
        Self { orig_nlp, rho }
    }

    /// Builder-style form of [`Self::set_orig_nlp`]: attaches `orig` and returns the phase.
    pub fn with_orig_nlp(mut self, orig: Rc<RefCell<dyn IpoptNlp>>) -> Self {
        self.set_orig_nlp(orig);
        self
    }

    /// Attaches `orig` as the original NLP, replacing any previously attached one.
    pub fn set_orig_nlp(&mut self, orig: Rc<RefCell<dyn IpoptNlp>>) {
        self.orig_nlp = Some(orig);
    }
}
impl Default for RestoRestorationPhase {
fn default() -> Self {
Self::new(1000.0)
}
}
impl RestorationPhase for RestoRestorationPhase {
fn perform_restoration(
&mut self,
data: &IpoptDataHandle,
_cq: &IpoptCqHandle,
_nlp: &Rc<RefCell<dyn IpoptNlp>>,
_aug_solver: &mut dyn AugSystemSolver,
) -> RestorationOutcome {
let orig_nlp = match self.orig_nlp.as_ref() {
Some(o) => o.clone(),
None => return RestorationOutcome::Failed,
};
let mu = data.borrow().curr_mu;
let curr = match data.borrow().curr.clone() {
Some(c) => c,
None => return RestorationOutcome::Failed,
};
let curr_x_cv = match curr.x.as_any().downcast_ref::<CompoundVector>() {
Some(c) if c.n_comps() == 5 => c,
_ => return RestorationOutcome::Failed,
};
let x_orig = curr_x_cv.comp(BLOCK_X);
let n_orig = x_orig.dim();
let m_eq = curr_x_cv.comp(BLOCK_N_C).dim();
let m_ineq = curr_x_cv.comp(BLOCK_N_D).dim();
let mut c_buf = DenseVectorSpace::new(m_eq).make_new_dense();
if m_eq > 0 {
c_buf.values_mut().fill(0.0);
orig_nlp.borrow_mut().eval_c(x_orig, &mut c_buf);
}
let mut d_buf = DenseVectorSpace::new(m_ineq).make_new_dense();
if m_ineq > 0 {
d_buf.values_mut().fill(0.0);
orig_nlp.borrow_mut().eval_d(x_orig, &mut d_buf);
let s_vals = expanded_dense_values(&*curr.s, m_ineq);
for (i, v) in d_buf.values_mut().iter_mut().enumerate() {
*v -= s_vals[i];
}
}
let (n_c_vals, p_c_vals) = compute_n_p(&c_buf, mu, self.rho, m_eq);
let (n_d_vals, p_d_vals) = compute_n_p(&d_buf, mu, self.rho, m_ineq);
let new_x = build_new_x(
n_orig, m_eq, m_ineq, x_orig, &n_c_vals, &p_c_vals, &n_d_vals, &p_d_vals,
);
let trial = IteratesVector::new(
Rc::new(new_x),
curr.s.clone(),
curr.y_c.clone(),
curr.y_d.clone(),
curr.z_l.clone(),
curr.z_u.clone(),
curr.v_l.clone(),
curr.v_u.clone(),
);
data.borrow_mut().set_trial(trial);
RestorationOutcome::Recovered
}
}
/// Computes the `(n, p)` slack values for the first `m` entries of the residual `c`.
///
/// For each entry `c_i`, with `h = mu / (2 * rho)`, `a = h - c_i / 2` and `b = c_i * h`,
/// the result is `n_i = a + sqrt(a^2 + b)` (the root of `v^2 - 2 a v - b = 0` taken by the
/// `+ sqrt` branch) and `p_i = c_i + n_i`, so that `p_i - n_i = c_i`. The per-entry work
/// is done by `n_p_pair`, which evaluates these expressions without cancellation.
///
/// Returns two empty vectors when `m <= 0`.
///
/// # Panics
/// Panics if `c` has fewer than `m` (expanded) entries.
fn compute_n_p(c: &DenseVector, mu: Number, rho: Number, m: i32) -> (Vec<f64>, Vec<f64>) {
    // A negative dimension cannot describe a real block; `as usize` would wrap it into a
    // huge allocation request, so treat it like an empty block instead.
    let m = usize::try_from(m).unwrap_or(0);
    if m == 0 {
        return (Vec::new(), Vec::new());
    }
    let half = mu / (2.0 * rho);
    let cvals = c.expanded_values();
    cvals[..m].iter().map(|&ci| n_p_pair(ci, half)).unzip()
}

/// Evaluates `(n, p)` for one residual entry `ci` with `half = mu / (2 * rho)`.
///
/// Algebraically `a^2 + b == half^2 + (ci/2)^2`, so `r = hypot(half, ci/2)` equals
/// `sqrt(a^2 + b)` without risking overflow for huge `|ci|`. Hence
/// `n = (half - ci/2) + r` and `p = (half + ci/2) + r`. The naive forms cancel
/// catastrophically:
/// - `n` when `ci >> half`; with tiny `mu` it can round to exactly 0;
/// - `p` when `ci << -half`.
///
/// The member whose terms do not cancel is therefore computed directly. The other is
/// recovered from the exact identity `n * p = 2 * half * (half + r)`.
fn n_p_pair(ci: f64, half: f64) -> (f64, f64) {
    let r = half.hypot(0.5 * ci);
    let prod = 2.0 * half * (half + r);
    if ci >= 0.0 {
        let p = (half + 0.5 * ci) + r;
        // p == 0 only when half == ci == 0, where the naive form is exact (n = 0).
        let n = if p > 0.0 { prod / p } else { (half - 0.5 * ci) + r };
        (n, p)
    } else {
        let n = (half - 0.5 * ci) + r;
        let p = if n > 0.0 { prod / n } else { ci + n };
        (n, p)
    }
}
/// Assembles the 5-component primal vector `(x, n_c, p_c, n_d, p_d)` from block values.
///
/// A fresh `CompoundVectorSpace` of total dimension `n_orig + 2 * m_eq + 2 * m_ineq` is
/// created with one dense sub-space per block:
/// - `n_orig` for `x`;
/// - `m_eq` shared by `n_c` and `p_c`;
/// - `m_ineq` shared by `n_d` and `p_d`.
///
/// The `x` block is filled from `x_orig` via `expanded_dense_values`, which yields zeros
/// when `x_orig` is not a `DenseVector`. The other blocks are filled from the given slices;
/// an empty slice leaves its block untouched.
///
/// # Panics
/// Panics (inside `set_block`) if a created block is not a `DenseVector`.
fn build_new_x(
    n_orig: i32,
    m_eq: i32,
    m_ineq: i32,
    x_orig: &dyn Vector,
    n_c: &[f64],
    p_c: &[f64],
    n_d: &[f64],
    p_d: &[f64],
) -> CompoundVector {
    let compound_space = CompoundVectorSpace::new(5, n_orig + 2 * m_eq + 2 * m_ineq);
    let x_space = DenseVectorSpace::new(n_orig);
    let eq_space = DenseVectorSpace::new(m_eq);
    let ineq_space = DenseVectorSpace::new(m_ineq);

    // Register each block with its dimension and a factory producing dense vectors in the
    // matching sub-space; the equality pair and the inequality pair share one space each.
    let layout = [
        (BLOCK_X, n_orig, &x_space),
        (BLOCK_N_C, m_eq, &eq_space),
        (BLOCK_P_C, m_eq, &eq_space),
        (BLOCK_N_D, m_ineq, &ineq_space),
        (BLOCK_P_D, m_ineq, &ineq_space),
    ];
    for (block, dim, block_space) in layout {
        let factory_space = Rc::clone(block_space);
        compound_space.set_comp(block, dim, move || {
            Box::new(DenseVector::new(Rc::clone(&factory_space)))
        });
    }

    let mut assembled = CompoundVector::new(compound_space);
    let x_values = expanded_dense_values(x_orig, n_orig);
    let contents = [
        (BLOCK_X, x_values.as_slice()),
        (BLOCK_N_C, n_c),
        (BLOCK_P_C, p_c),
        (BLOCK_N_D, n_d),
        (BLOCK_P_D, p_d),
    ];
    for (block, values) in contents {
        set_block(&mut assembled, block, values);
    }
    assembled
}
/// Overwrites block `idx` of `cv` with `vals`. An empty `vals` leaves the block unchanged.
///
/// # Panics
/// Panics if block `idx` is not a `DenseVector`. The check happens before `vals` is
/// inspected, so it also fires when `vals` is empty.
fn set_block(cv: &mut CompoundVector, idx: i32, vals: &[f64]) {
    let block = cv.comp_mut(idx);
    let target = match block.as_any_mut().downcast_mut::<DenseVector>() {
        Some(dense) => dense,
        None => panic!("RestoRestorationPhase: compound block must be DenseVector"),
    };
    if vals.is_empty() {
        return;
    }
    target.set_values(vals);
}
/// Returns the expanded values of `v` when it is a `DenseVector`; otherwise returns a
/// vector of `fallback_dim` zeros.
fn expanded_dense_values(v: &dyn Vector, fallback_dim: i32) -> Vec<f64> {
    match v.as_any().downcast_ref::<DenseVector>() {
        Some(dense) => dense.expanded_values(),
        None => vec![0.0; fallback_dim as usize],
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With `c = 0`, both values collapse to `n = p = 2 * mu / (2 * rho)`.
    #[test]
    fn solve_quadratic_matches_upstream_formula() {
        let s = DenseVectorSpace::new(2);
        let mut c = s.make_new_dense();
        c.set_values(&[0.0, 0.0]);
        let (n, p) = compute_n_p(&c, 1.0, 1000.0, 2);
        let half = 1.0 / 2000.0;
        for i in 0..2 {
            assert!((n[i] - 2.0 * half).abs() < 1e-15);
            assert!((p[i] - 2.0 * half).abs() < 1e-15);
        }
    }

    /// `p - n` reproduces the residual and both slacks are non-negative.
    #[test]
    fn solve_quadratic_satisfies_feasibility_identity() {
        let s = DenseVectorSpace::new(4);
        let mut c = s.make_new_dense();
        c.set_values(&[1.0, -2.0, 0.5, -0.1]);
        let (n, p) = compute_n_p(&c, 0.1, 1000.0, 4);
        let cvals = c.expanded_values();
        for i in 0..4 {
            assert!((p[i] - n[i] - cvals[i]).abs() < 1e-12);
            assert!(n[i] >= 0.0, "n must be non-negative");
            assert!(p[i] >= 0.0, "p must be non-negative");
        }
    }

    /// `n` is a root of `v^2 - 2 a v - b = 0` with `a = h - c/2`, `b = c h`.
    #[test]
    fn quadratic_root_satisfies_v2_plus_2av_minus_b_zero() {
        let s = DenseVectorSpace::new(3);
        let mut c = s.make_new_dense();
        c.set_values(&[3.0, -1.0, 0.7]);
        let mu = 0.5;
        let rho = 1000.0;
        let (n, _p) = compute_n_p(&c, mu, rho, 3);
        let half = mu / (2.0 * rho);
        let cvals = c.expanded_values();
        for i in 0..3 {
            let a = half - 0.5 * cvals[i];
            let b = cvals[i] * half;
            let v = n[i];
            let residual = v * v - 2.0 * a * v - b;
            assert!(
                residual.abs() < 1e-10 * (1.0 + a.abs() + b.abs()),
                "residual={residual}, v={v}, a={a}, b={b}"
            );
        }
    }

    /// An empty residual block yields empty `n` / `p` vectors.
    #[test]
    fn compute_n_p_with_no_rows_returns_empty() {
        let c = DenseVectorSpace::new(0).make_new_dense();
        let (n, p) = compute_n_p(&c, 0.1, 1000.0, 0);
        assert!(n.is_empty());
        assert!(p.is_empty());
    }

    /// A freshly constructed phase has no original NLP attached, which is the state in
    /// which `perform_restoration` returns `RestorationOutcome::Failed`. Calling
    /// `perform_restoration` itself needs solver data/CQ handles not built in this module,
    /// so only the precondition is checked here.
    #[test]
    fn new_phase_has_no_orig_nlp() {
        let p = RestoRestorationPhase::new(1000.0);
        assert!(p.orig_nlp.is_none());
    }

    /// `Default` uses `rho = 1000.0`.
    #[test]
    fn default_uses_upstream_rho_default() {
        let p = RestoRestorationPhase::default();
        assert_eq!(p.rho, 1000.0);
    }
}