use std::cell::RefCell;
use std::rc::Rc;
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::iterates_vector::{IteratesVector, IteratesVectorMut};
use pounce_algorithm::kkt::pd_full_space_solver::PdFullSpaceSolver;
use pounce_common::types::Number;
use pounce_linalg::dense_vector::DenseVector;
use pounce_nlp::ipopt_nlp::IpoptNlp;
use crate::backsolver::SensBacksolver;
/// [`SensBacksolver`] implementation backed by a shared [`PdFullSpaceSolver`].
///
/// Flat `Number` slices are split into eight blocks, in the order
/// x, s, y_c, y_d, z_l, z_u, v_l, v_u, copied into iterate vectors that the
/// primal-dual solver consumes, and the solver's result is flattened back the
/// same way.
#[derive(Clone)]
pub struct PdSensBacksolver {
    // Shared KKT solver; mutably borrowed (`RefCell::borrow_mut`) for each solve call.
    pd: Rc<RefCell<PdFullSpaceSolver>>,
    // Algorithm data handle; its current iterate seeds `dims`/`template` in `new`,
    // and it is forwarded to every solver call.
    data: IpoptDataHandle,
    // Handle (presumably Ipopt's calculated quantities) forwarded to every solver call.
    cq: IpoptCqHandle,
    // NLP forwarded to every solver call.
    nlp: Rc<RefCell<dyn IpoptNlp>>,
    // Length of each of the eight blocks, captured from the current iterate in `new`.
    dims: [usize; 8],
    // Iterate captured in `new`; used only to allocate zeroed vectors with the
    // same block structure (`make_new_zeroed`).
    template: IteratesVector,
}
impl PdSensBacksolver {
pub fn new(
data: &IpoptDataHandle,
cq: &IpoptCqHandle,
nlp: &Rc<RefCell<dyn IpoptNlp>>,
pd: Rc<RefCell<PdFullSpaceSolver>>,
) -> Result<Self, ()> {
let curr = data.borrow().curr.clone().ok_or(())?;
let dims = [
curr.x.dim() as usize,
curr.s.dim() as usize,
curr.y_c.dim() as usize,
curr.y_d.dim() as usize,
curr.z_l.dim() as usize,
curr.z_u.dim() as usize,
curr.v_l.dim() as usize,
curr.v_u.dim() as usize,
];
Ok(Self {
pd,
data: Rc::clone(data),
cq: Rc::clone(cq),
nlp: Rc::clone(nlp),
dims,
template: curr,
})
}
pub fn block_dims(&self) -> [usize; 8] {
self.dims
}
fn offsets(&self) -> [usize; 9] {
let mut o = [0usize; 9];
for i in 0..8 {
o[i + 1] = o[i] + self.dims[i];
}
o
}
fn pack(&self, flat: &[Number]) -> Result<IteratesVectorMut, ()> {
let mut out = self.template.make_new_zeroed();
let off = self.offsets();
let blocks: [&mut Box<dyn pounce_linalg::vector::Vector>; 8] = [
&mut out.x,
&mut out.s,
&mut out.y_c,
&mut out.y_d,
&mut out.z_l,
&mut out.z_u,
&mut out.v_l,
&mut out.v_u,
];
for (i, blk) in blocks.into_iter().enumerate() {
let slice = &flat[off[i]..off[i + 1]];
let dv = blk.as_any_mut().downcast_mut::<DenseVector>().ok_or(())?;
dv.set_values(slice);
}
Ok(out)
}
fn unpack(&self, iv: &IteratesVectorMut, out: &mut [Number]) -> Result<(), ()> {
let off = self.offsets();
let blocks: [&Box<dyn pounce_linalg::vector::Vector>; 8] = [
&iv.x, &iv.s, &iv.y_c, &iv.y_d, &iv.z_l, &iv.z_u, &iv.v_l, &iv.v_u,
];
for (i, blk) in blocks.into_iter().enumerate() {
let dst = &mut out[off[i]..off[i + 1]];
if dst.is_empty() {
continue;
}
let dv = (**blk).as_any().downcast_ref::<DenseVector>().ok_or(())?;
let ev = dv.expanded_values();
dst.copy_from_slice(&ev);
}
Ok(())
}
}
impl PdSensBacksolver {
    /// Solves the KKT system for `n_rhs` right-hand sides in one call.
    ///
    /// `rhs_flat` and `lhs_flat` each hold `n_rhs` consecutive rows of length
    /// `self.dim()`. Every row is the concatenation of the eight blocks in the order
    /// x, s, y_c, y_d, z_l, z_u, v_l, v_u (see [`Self::block_dims`]). Row `k` of
    /// `lhs_flat` receives the solution for row `k` of `rhs_flat`.
    ///
    /// Three strategies are tried in order. A strategy that returns `None` defers
    /// to the next one:
    /// 1. `solve_many_cached_flat`, directly on the flat buffers;
    /// 2. `solve_many_cached`, with per-row copy-in/copy-out callbacks;
    /// 3. one `solve` call per right-hand side.
    ///
    /// Returns `true` on success (trivially when `n_rhs == 0`). Returns `false` in
    /// three cases:
    /// - a buffer length differs from `n_rhs * self.dim()`;
    /// - a non-empty block cannot be copied in or out;
    /// - the solver reports failure.
    ///
    /// When `false` is returned, `lhs_flat` may be partially written.
    ///
    /// # Panics
    /// Panics if `self.pd` is already borrowed elsewhere (`RefCell::borrow_mut`).
    pub fn solve_many(&self, rhs_flat: &[Number], lhs_flat: &mut [Number], n_rhs: usize) -> bool {
        use std::sync::atomic::{AtomicBool, Ordering};

        let total = self.dim();
        // Reject, rather than wrap, a row count whose total length overflows usize.
        let Some(expected_len) = n_rhs.checked_mul(total) else {
            return false;
        };
        if rhs_flat.len() != expected_len || lhs_flat.len() != expected_len {
            return false;
        }
        if n_rhs == 0 {
            return true;
        }
        let off = self.offsets();

        // Strategy 1: the solver consumes the flat buffers directly.
        {
            let mut pd_ref = self.pd.borrow_mut();
            let fast_flat = pd_ref.solve_many_cached_flat(
                &self.data, &self.cq, &self.nlp, n_rhs, rhs_flat, lhs_flat, self.dims,
            );
            if let Some(ok) = fast_flat {
                return ok;
            }
        }

        // Strategy 2: cached multi-rhs solve driven by copy-in/copy-out callbacks.
        {
            // The callbacks return `()`, so any failed block copy is recorded in this
            // flag. A set flag turns an otherwise successful result into `false`,
            // instead of reporting success with a wrong rhs or an incomplete lhs.
            // An atomic works whatever `Fn`/`FnMut`/`Sync` bounds the callbacks carry.
            let copy_failed = AtomicBool::new(false);
            let mut pd_ref = self.pd.borrow_mut();
            let fast = pd_ref.solve_many_cached(
                &self.data,
                &self.cq,
                &self.nlp,
                n_rhs,
                |k, iv| {
                    let row = &rhs_flat[k * total..(k + 1) * total];
                    let blocks = [
                        &mut iv.x,
                        &mut iv.s,
                        &mut iv.y_c,
                        &mut iv.y_d,
                        &mut iv.z_l,
                        &mut iv.z_u,
                        &mut iv.v_l,
                        &mut iv.v_u,
                    ];
                    let ok = blocks
                        .into_iter()
                        .zip(off.windows(2))
                        .all(|(blk, w)| write_rhs_box(blk, &row[w[0]..w[1]]));
                    if !ok {
                        copy_failed.store(true, Ordering::SeqCst);
                    }
                },
                |k, iv| {
                    let row = &mut lhs_flat[k * total..(k + 1) * total];
                    let blocks = [
                        &*iv.x, &*iv.s, &*iv.y_c, &*iv.y_d, &*iv.z_l, &*iv.z_u, &*iv.v_l,
                        &*iv.v_u,
                    ];
                    let ok = blocks
                        .into_iter()
                        .zip(off.windows(2))
                        .all(|(blk, w)| read_res_block(blk, &mut row[w[0]..w[1]]));
                    if !ok {
                        copy_failed.store(true, Ordering::SeqCst);
                    }
                },
            );
            if let Some(ok) = fast {
                return ok && !copy_failed.load(Ordering::SeqCst);
            }
        }

        // Strategy 3: one solve per right-hand side, reusing the rhs/result buffers.
        let mut rhs_iv = self.template.make_new_zeroed().freeze();
        let mut res_iv = self.template.make_new_zeroed();
        let mut pd_ref = self.pd.borrow_mut();
        for k in 0..n_rhs {
            let rhs_row = &rhs_flat[k * total..(k + 1) * total];
            let lhs_row = &mut lhs_flat[k * total..(k + 1) * total];

            // NOTE(review): `write_rhs_block` needs unique ownership of each `Rc`
            // block. If `solve` ever retains a clone of `rhs_iv`, later rows fail here.
            let rhs_blocks = [
                &mut rhs_iv.x,
                &mut rhs_iv.s,
                &mut rhs_iv.y_c,
                &mut rhs_iv.y_d,
                &mut rhs_iv.z_l,
                &mut rhs_iv.z_u,
                &mut rhs_iv.v_l,
                &mut rhs_iv.v_u,
            ];
            let written = rhs_blocks
                .into_iter()
                .zip(off.windows(2))
                .all(|(blk, w)| write_rhs_block(blk, &rhs_row[w[0]..w[1]]));
            if !written {
                return false;
            }

            let ok = pd_ref.solve(
                &self.data,
                &self.cq,
                &self.nlp,
                1.0,
                0.0,
                &rhs_iv,
                &mut res_iv,
                true,
                false,
            );
            if !ok {
                return false;
            }

            let res_blocks = [
                &*res_iv.x,
                &*res_iv.s,
                &*res_iv.y_c,
                &*res_iv.y_d,
                &*res_iv.z_l,
                &*res_iv.z_u,
                &*res_iv.v_l,
                &*res_iv.v_u,
            ];
            let read = res_blocks
                .into_iter()
                .zip(off.windows(2))
                .all(|(blk, w)| read_res_block(blk, &mut lhs_row[w[0]..w[1]]));
            if !read {
                return false;
            }
        }
        true
    }
}
/// Copies `slice` into the dense storage behind a boxed vector block.
///
/// An empty `slice` is accepted without inspecting `b`. Returns `false` only when
/// a non-empty `slice` targets a block that is not a `DenseVector`; nothing is
/// written in that case.
fn write_rhs_box(b: &mut Box<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
    if !slice.is_empty() {
        match b.as_any_mut().downcast_mut::<DenseVector>() {
            Some(dense) => {
                dense.set_values(slice);
            }
            None => return false,
        }
    }
    true
}
/// Copies `slice` into the dense storage behind a reference-counted vector block.
///
/// An empty `slice` is accepted without inspecting `rc`. Otherwise the block must
/// be uniquely owned (`Rc::get_mut` succeeds) and must be a `DenseVector`. If
/// either condition fails, nothing is written and `false` is returned.
fn write_rhs_block(rc: &mut Rc<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
    if slice.is_empty() {
        return true;
    }
    match Rc::get_mut(rc) {
        Some(vector) => match vector.as_any_mut().downcast_mut::<DenseVector>() {
            Some(dense) => {
                dense.set_values(slice);
                true
            }
            None => false,
        },
        None => false,
    }
}
/// Copies one solution block into `dst`.
///
/// An empty `dst` is accepted without inspecting `blk`. If the block reports
/// `is_homogeneous()`, its `scalar()` is written into every slot of `dst`.
/// Otherwise its stored `values()` are copied.
///
/// Returns `false`, leaving `dst` untouched, in two cases:
/// - `blk` is not a `DenseVector`;
/// - its stored values do not have exactly `dst.len()` entries. Without this
///   check, `copy_from_slice` would panic.
fn read_res_block(blk: &dyn pounce_linalg::vector::Vector, dst: &mut [Number]) -> bool {
    if dst.is_empty() {
        return true;
    }
    let Some(dv) = blk.as_any().downcast_ref::<DenseVector>() else {
        return false;
    };
    if dv.is_homogeneous() {
        dst.fill(dv.scalar());
        return true;
    }
    let values = dv.values();
    if values.len() != dst.len() {
        return false;
    }
    dst.copy_from_slice(values);
    true
}
impl SensBacksolver for PdSensBacksolver {
    /// Length of a flat KKT vector: the sum of the eight block dimensions.
    fn dim(&self) -> usize {
        self.dims.iter().fold(0, |acc, d| acc + d)
    }

    /// Solves the KKT system for a single flat right-hand side.
    ///
    /// `rhs` is packed into a fresh iterate and passed to the primal-dual solver.
    /// The solution is then flattened into `lhs`. Returns `false` in three cases:
    /// - either slice length differs from `self.dim()`;
    /// - packing fails;
    /// - the solve or the unpacking fails.
    fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
        let expected = self.dim();
        if rhs.len() != expected || lhs.len() != expected {
            return false;
        }
        let Ok(packed) = self.pack(rhs) else {
            return false;
        };
        let frozen_rhs = packed.freeze();
        let mut solution = self.template.make_new_zeroed();
        // The `RefMut` temporary is released at the end of this statement,
        // before `unpack` runs.
        let solved = self.pd.borrow_mut().solve(
            &self.data,
            &self.cq,
            &self.nlp,
            1.0,
            0.0,
            &frozen_rhs,
            &mut solution,
            true,
            false,
        );
        solved && self.unpack(&solution, lhs).is_ok()
    }
}