use std::cell::RefCell;
use std::rc::Rc;
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::iterates_vector::{IteratesVector, IteratesVectorMut};
use pounce_algorithm::kkt::pd_full_space_solver::PdFullSpaceSolver;
use pounce_common::types::Number;
use pounce_linalg::dense_vector::DenseVector;
use pounce_nlp::ipopt_nlp::IpoptNlp;
use crate::backsolver::SensBacksolver;
/// Sensitivity backsolver that forwards flat right-hand sides to a shared
/// [`PdFullSpaceSolver`].
///
/// A flat vector is split into the eight primal-dual blocks of an iterate, in
/// the order `x, s, y_c, y_d, z_l, z_u, v_l, v_u`. It is then solved with the
/// full-space solver, and the result is flattened back in the same order.
///
/// The solver, data, cq and nlp handles are `Rc`s, so a clone shares them with
/// the original. `template` is cloned through `IteratesVector::clone`.
#[derive(Clone)]
pub struct PdSensBacksolver {
/// Primal-dual full-space solver used for every backsolve.
pd: Rc<RefCell<PdFullSpaceSolver>>,
/// Algorithm data handle passed through to the full-space solver.
data: IpoptDataHandle,
/// Calculated-quantities handle passed through to the full-space solver.
cq: IpoptCqHandle,
/// NLP handle passed through to the full-space solver.
nlp: Rc<RefCell<dyn IpoptNlp>>,
/// Per-block dimensions in the order `x, s, y_c, y_d, z_l, z_u, v_l, v_u`.
dims: [usize; 8],
/// Iterate captured at construction; used only to allocate zeroed vectors
/// with the right block structure.
template: IteratesVector,
}
impl PdSensBacksolver {
pub fn new(
data: &IpoptDataHandle,
cq: &IpoptCqHandle,
nlp: &Rc<RefCell<dyn IpoptNlp>>,
pd: Rc<RefCell<PdFullSpaceSolver>>,
) -> Result<Self, ()> {
let curr = data.borrow().curr.clone().ok_or(())?;
let dims = [
curr.x.dim() as usize,
curr.s.dim() as usize,
curr.y_c.dim() as usize,
curr.y_d.dim() as usize,
curr.z_l.dim() as usize,
curr.z_u.dim() as usize,
curr.v_l.dim() as usize,
curr.v_u.dim() as usize,
];
Ok(Self {
pd,
data: Rc::clone(data),
cq: Rc::clone(cq),
nlp: Rc::clone(nlp),
dims,
template: curr,
})
}
pub fn block_dims(&self) -> [usize; 8] {
self.dims
}
fn offsets(&self) -> [usize; 9] {
let mut o = [0usize; 9];
for i in 0..8 {
o[i + 1] = o[i] + self.dims[i];
}
o
}
fn pack(&self, flat: &[Number]) -> Result<IteratesVectorMut, ()> {
let mut out = self.template.make_new_zeroed();
let off = self.offsets();
let blocks: [&mut Box<dyn pounce_linalg::vector::Vector>; 8] = [
&mut out.x,
&mut out.s,
&mut out.y_c,
&mut out.y_d,
&mut out.z_l,
&mut out.z_u,
&mut out.v_l,
&mut out.v_u,
];
for (i, blk) in blocks.into_iter().enumerate() {
let slice = &flat[off[i]..off[i + 1]];
let dv = blk.as_any_mut().downcast_mut::<DenseVector>().ok_or(())?;
dv.set_values(slice);
}
Ok(out)
}
fn unpack(&self, iv: &IteratesVectorMut, out: &mut [Number]) -> Result<(), ()> {
let off = self.offsets();
let blocks: [&Box<dyn pounce_linalg::vector::Vector>; 8] = [
&iv.x, &iv.s, &iv.y_c, &iv.y_d, &iv.z_l, &iv.z_u, &iv.v_l, &iv.v_u,
];
for (i, blk) in blocks.into_iter().enumerate() {
let dst = &mut out[off[i]..off[i + 1]];
if dst.is_empty() {
continue;
}
let dv = (**blk).as_any().downcast_ref::<DenseVector>().ok_or(())?;
let ev = dv.expanded_values();
dst.copy_from_slice(&ev);
}
Ok(())
}
}
impl SensBacksolver for PdSensBacksolver {
    /// Total length of a flat vector: the sum of all eight block dimensions.
    fn dim(&self) -> usize {
        self.dims.iter().fold(0, |acc, &d| acc + d)
    }

    /// Solves the primal-dual system for the flat right-hand side `rhs` and writes the
    /// flattened solution into `lhs`.
    ///
    /// Returns `false` in any of these cases:
    /// - either slice length differs from [`Self::dim`];
    /// - `rhs` cannot be packed into an iterate;
    /// - the full-space solver reports failure;
    /// - the result cannot be unpacked.
    fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
        let expected = self.dim();
        let sizes_match = rhs.len() == expected && lhs.len() == expected;
        if !sizes_match {
            return false;
        }
        let packed_rhs = match self.pack(rhs) {
            Ok(packed) => packed.freeze(),
            Err(()) => return false,
        };
        let mut solution = self.template.make_new_zeroed();
        // The mutable borrow of `pd` is a temporary released at the end of this statement.
        // NOTE(review): the scalar arguments are presumably alpha = 1.0 and beta = 0.0. The
        // two `false` flags are presumably allow_inexact / improve_solution. Confirm both
        // against `PdFullSpaceSolver::solve`.
        let solved = self.pd.borrow_mut().solve(
            &self.data,
            &self.cq,
            &self.nlp,
            1.0,
            0.0,
            &packed_rhs,
            &mut solution,
            false,
            false,
        );
        solved && self.unpack(&solution, lhs).is_ok()
    }
}