use std::cell::RefCell;
use std::rc::Rc;
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::iterates_vector::{IteratesVector, IteratesVectorMut};
use pounce_algorithm::kkt::pd_full_space_solver::PdFullSpaceSolver;
use pounce_common::types::{Index, Number};
use pounce_linalg::dense_vector::DenseVector;
use pounce_nlp::ipopt_nlp::IpoptNlp;
use crate::backsolver::SensBacksolver;
/// Sensitivity backsolver that reuses a primal-dual full-space KKT solver.
///
/// KKT vectors are handled as flat slices made of eight consecutive blocks in
/// the order x, s, y_c, y_d, z_l, z_u, v_l, v_u.
#[derive(Clone)]
pub struct PdSensBacksolver {
    /// Primal-dual full-space solver; mutably borrowed for every solve.
    pd: Rc<RefCell<PdFullSpaceSolver>>,
    /// Algorithm data; supplies the current iterate and KKT perturbations.
    data: IpoptDataHandle,
    /// Calculated-quantities handle forwarded to the solver.
    cq: IpoptCqHandle,
    /// The NLP; supplies scaling factors and constraint index mappings.
    nlp: Rc<RefCell<dyn IpoptNlp>>,
    /// Lengths of the eight blocks (x, s, y_c, y_d, z_l, z_u, v_l, v_u).
    dims: [usize; 8],
    /// Iterate captured at construction; used to allocate zeroed vectors.
    template: IteratesVector,
    /// Scaling conjugation factors, or `None` when the NLP is unscaled.
    conj: Option<Rc<ConjPair>>,
}
/// Element-wise conjugation factors between unscaled and scaled KKT vectors.
///
/// Both vectors carry one entry per KKT row (the sum of the eight block
/// dimensions). `e` multiplies a right-hand side before the scaled-space
/// solve; `f` multiplies the resulting solution afterwards.
struct ConjPair {
    /// Right-hand-side multipliers.
    e: Vec<Number>,
    /// Solution multipliers.
    f: Vec<Number>,
}
impl PdSensBacksolver {
    /// Creates a backsolver that reuses the primal-dual full-space solver `pd`.
    ///
    /// The eight block dimensions (x, s, y_c, y_d, z_l, z_u, v_l, v_u) and the
    /// vector template are taken from the current iterate stored in `data`.
    /// If the NLP reports objective or constraint scaling, the conjugation
    /// factors used to translate between unscaled and scaled quantities are
    /// precomputed here (see `natural_units_conj`).
    ///
    /// # Errors
    /// Returns `Err` when `data` has no current iterate, or when the NLP
    /// scaling information is invalid or inconsistent with the block sizes.
    pub fn new(
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        pd: Rc<RefCell<PdFullSpaceSolver>>,
    ) -> Result<Self, String> {
        let curr = data
            .borrow()
            .curr
            .clone()
            .ok_or_else(|| "no current iterate at convergence".to_string())?;
        let dims = [
            curr.x.dim() as usize,
            curr.s.dim() as usize,
            curr.y_c.dim() as usize,
            curr.y_d.dim() as usize,
            curr.z_l.dim() as usize,
            curr.z_u.dim() as usize,
            curr.v_l.dim() as usize,
            curr.v_u.dim() as usize,
        ];
        let conj = Self::natural_units_conj(nlp, &dims)?;
        Ok(Self {
            pd,
            data: Rc::clone(data),
            cq: Rc::clone(cq),
            nlp: Rc::clone(nlp),
            dims,
            template: curr,
            conj,
        })
    }

    /// Builds the element-wise conjugation factors `(e, f)` from the NLP scaling.
    ///
    /// With `df` = objective scaling, `dc` = equality scaling and `dd` =
    /// inequality scaling (missing vectors are treated as all ones), the
    /// factors per block are:
    ///
    /// | block      | `e`      | `f`           |
    /// |------------|----------|---------------|
    /// | x          | `df`     | `1`           |
    /// | s          | `df/dd`  | `1/dd`        |
    /// | y_c        | `dc`     | `dc/df`       |
    /// | y_d        | `dd`     | `dd/df`       |
    /// | z_l, z_u   | `df`     | `1/df`        |
    /// | v_l, v_u   | `df`     | `dd[row]/df`  |
    ///
    /// For v_l/v_u, `row` is the d-row selected by the `pd_l`/`pd_u` expansion
    /// matrices.
    ///
    /// Returns `Ok(None)` when `df == 1` and neither constraint scaling is set.
    ///
    /// # Errors
    /// Returns `Err` for any of the following:
    /// - a non-finite or zero `df`;
    /// - scale-vector lengths that disagree with the block dimensions;
    /// - non-finite or zero scale entries;
    /// - `pd_l`/`pd_u` that are not expansion matrices while `dd` is set;
    /// - expansion rows outside `dd`.
    fn natural_units_conj(
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        dims: &[usize; 8],
    ) -> Result<Option<Rc<ConjPair>>, String> {
        // Scale entries become multipliers (and, for `dd`, divisors) in the
        // conjugation vectors. A zero or non-finite entry would silently turn
        // them into 0/inf/NaN instead of producing a diagnosable error.
        fn check_scale_entries(v: &[Number], what: &str) -> Result<(), String> {
            match v.iter().position(|x| !x.is_finite() || *x == 0.0) {
                Some(i) => Err(format!("invalid {what} entry {} at index {i}", v[i])),
                None => Ok(()),
            }
        }

        let nlp_ref = nlp.borrow();
        let df = nlp_ref.obj_scaling_factor();
        let dc = nlp_ref.c_scale_vec();
        let dd = nlp_ref.d_scale_vec();
        if df == 1.0 && dc.is_none() && dd.is_none() {
            return Ok(None);
        }
        if !df.is_finite() || df == 0.0 {
            return Err(format!("invalid obj_scaling_factor {df}"));
        }
        if let Some(v) = &dc {
            if v.len() != dims[2] {
                return Err(format!("c_scale length {} != y_c dim {}", v.len(), dims[2]));
            }
            check_scale_entries(v, "c_scale")?;
        }
        if let Some(v) = &dd {
            if v.len() != dims[3] || dims[1] != dims[3] {
                return Err(format!(
                    "d_scale length {} != y_d dim {} (s dim {})",
                    v.len(),
                    dims[3],
                    dims[1]
                ));
            }
            check_scale_entries(v, "d_scale")?;
        }
        // Maps each v_l / v_u entry to the d-scale of the inequality row it
        // bounds, using the row indices of the given expansion matrix.
        let v_row_scale = |pm: Rc<dyn pounce_linalg::matrix::Matrix>,
                           n_v: usize,
                           which: &str|
         -> Result<Vec<Number>, String> {
            let Some(dd) = &dd else {
                return Ok(vec![1.0; n_v]);
            };
            if n_v == 0 {
                return Ok(Vec::new());
            }
            let Some(em) = pm
                .as_any()
                .downcast_ref::<pounce_linalg::expansion_matrix::ExpansionMatrix>()
            else {
                return Err(format!("{which} is not an ExpansionMatrix"));
            };
            let pos = em.expanded_pos_indices();
            if pos.len() != n_v {
                return Err(format!(
                    "{which} expansion length {} != {} block dim {}",
                    pos.len(),
                    which,
                    n_v
                ));
            }
            pos.iter()
                .map(|&r| {
                    dd.get(r as usize).copied().ok_or_else(|| {
                        format!(
                            "{which} expansion row {r} out of d_scale range {}",
                            dd.len()
                        )
                    })
                })
                .collect()
        };
        let vl_dd = v_row_scale(nlp_ref.pd_l(), dims[6], "pd_l")?;
        let vu_dd = v_row_scale(nlp_ref.pd_u(), dims[7], "pd_u")?;
        drop(nlp_ref);
        let total: usize = dims.iter().sum();
        let mut e = Vec::with_capacity(total);
        let mut f = Vec::with_capacity(total);
        // x block
        e.extend(std::iter::repeat_n(df, dims[0]));
        f.extend(std::iter::repeat_n(1.0, dims[0]));
        // s block
        match &dd {
            Some(v) => {
                e.extend(v.iter().map(|&ddi| df / ddi));
                f.extend(v.iter().map(|&ddi| 1.0 / ddi));
            }
            None => {
                e.extend(std::iter::repeat_n(df, dims[1]));
                f.extend(std::iter::repeat_n(1.0, dims[1]));
            }
        }
        // y_c block
        match &dc {
            Some(v) => {
                e.extend(v.iter().copied());
                f.extend(v.iter().map(|&dci| dci / df));
            }
            None => {
                e.extend(std::iter::repeat_n(1.0, dims[2]));
                f.extend(std::iter::repeat_n(1.0 / df, dims[2]));
            }
        }
        // y_d block
        match &dd {
            Some(v) => {
                e.extend(v.iter().copied());
                f.extend(v.iter().map(|&ddi| ddi / df));
            }
            None => {
                e.extend(std::iter::repeat_n(1.0, dims[3]));
                f.extend(std::iter::repeat_n(1.0 / df, dims[3]));
            }
        }
        // z_l, z_u blocks
        e.extend(std::iter::repeat_n(df, dims[4] + dims[5]));
        f.extend(std::iter::repeat_n(1.0 / df, dims[4] + dims[5]));
        // v_l, v_u blocks
        e.extend(std::iter::repeat_n(df, dims[6] + dims[7]));
        f.extend(vl_dd.iter().map(|&ddr| ddr / df));
        f.extend(vu_dd.iter().map(|&ddr| ddr / df));
        Ok(Some(Rc::new(ConjPair { e, f })))
    }

    /// Objective scaling factor reported by the NLP.
    pub fn obj_scaling_factor(&self) -> Number {
        self.nlp.borrow().obj_scaling_factor()
    }

    /// Returns `(obj_scaling_factor, c_scale_vec, d_scale_vec)` from the NLP.
    pub fn nlp_scaling(&self) -> (Number, Option<Vec<Number>>, Option<Vec<Number>>) {
        let n = self.nlp.borrow();
        (n.obj_scaling_factor(), n.c_scale_vec(), n.d_scale_vec())
    }

    /// Current KKT perturbations as `[delta_x, delta_s, delta_c, delta_d]`.
    pub fn kkt_perturbations(&self) -> [Number; 4] {
        let p = self.data.borrow().perturbations;
        [p.delta_x, p.delta_s, p.delta_c, p.delta_d]
    }

    /// Maps full-constraint indices of parameter pins to KKT rows in the y_c
    /// block, together with the equality scale of each pinned row.
    ///
    /// The scale is `1.0` when the NLP has no equality scaling.
    ///
    /// # Errors
    /// Returns `Err` in any of these cases:
    /// - an index is negative or `>= n_full_g` (the upper bound is checked only
    ///   when `n_full_g > 0`);
    /// - an index refers to an inequality;
    /// - the NLP maps an index outside the y_c block or the `c_scale` vector.
    pub fn pin_rows_and_c_scales(
        &self,
        pin_g_indices: &[Index],
    ) -> Result<(Vec<Index>, Vec<Number>), String> {
        let y_c_offset = (self.dims[0] + self.dims[1]) as Index;
        let n_c = self.dims[2];
        let nlp = self.nlp.borrow();
        let dc = nlp.c_scale_vec();
        let n_full_g = nlp.n_full_g();
        let mut rows = Vec::with_capacity(pin_g_indices.len());
        let mut scales = Vec::with_capacity(pin_g_indices.len());
        for &gi in pin_g_indices {
            if gi < 0 || (n_full_g > 0 && gi >= n_full_g) {
                return Err(format!(
                    "pin constraint index {gi} out of range [0, m={n_full_g})"
                ));
            }
            let Some(ci) = nlp.full_g_to_c_block(gi) else {
                return Err(format!(
                    "pin constraint index {gi} is an inequality (not an equality row); \
                     parameter pins must be exact equalities"
                ));
            };
            // An inconsistent mapping would otherwise address a row outside the
            // y_c block, or panic when indexing the c-scale vector.
            if ci < 0 || ci as usize >= n_c {
                return Err(format!(
                    "pin constraint index {gi} maps to c-block row {ci} outside [0, {n_c})"
                ));
            }
            let scale = match &dc {
                Some(v) => v.get(ci as usize).copied().ok_or_else(|| {
                    format!(
                        "c_scale length {} too short for c-block row {ci}",
                        v.len()
                    )
                })?,
                None => 1.0,
            };
            rows.push(y_c_offset + ci);
            scales.push(scale);
        }
        Ok((rows, scales))
    }

    /// KKT row indices of the given pin constraints. See [`Self::pin_rows_and_c_scales`].
    pub fn map_pin_g_to_kkt_rows(&self, pin_g_indices: &[Index]) -> Result<Vec<Index>, String> {
        Ok(self.pin_rows_and_c_scales(pin_g_indices)?.0)
    }

    /// Equality scales of the given pin constraints. See [`Self::pin_rows_and_c_scales`].
    pub fn pin_c_scales(&self, pin_g_indices: &[Index]) -> Result<Vec<Number>, String> {
        Ok(self.pin_rows_and_c_scales(pin_g_indices)?.1)
    }

    /// Block lengths in the order x, s, y_c, y_d, z_l, z_u, v_l, v_u.
    pub fn block_dims(&self) -> [usize; 8] {
        self.dims
    }

    /// Forwards to the NLP's full-constraint to equality-block index mapping.
    pub fn full_g_to_c_block(&self, full_idx: Index) -> Option<Index> {
        self.nlp.borrow().full_g_to_c_block(full_idx)
    }

    /// Prefix sums of the block dimensions: block `i` spans `o[i]..o[i + 1]`,
    /// and `o[8]` is the total length.
    fn offsets(&self) -> [usize; 9] {
        let mut o = [0usize; 9];
        for i in 0..8 {
            o[i + 1] = o[i] + self.dims[i];
        }
        o
    }

    /// Copies a flat KKT vector into a freshly zeroed iterate-shaped vector.
    ///
    /// Empty blocks are left untouched. Returns `Err(())` if `flat` has the
    /// wrong length or a non-empty block is not a [`DenseVector`].
    fn pack(&self, flat: &[Number]) -> Result<IteratesVectorMut, ()> {
        let off = self.offsets();
        if flat.len() != off[8] {
            return Err(());
        }
        let mut out = self.template.make_new_zeroed();
        let blocks: [&mut Box<dyn pounce_linalg::vector::Vector>; 8] = [
            &mut out.x,
            &mut out.s,
            &mut out.y_c,
            &mut out.y_d,
            &mut out.z_l,
            &mut out.z_u,
            &mut out.v_l,
            &mut out.v_u,
        ];
        for (i, blk) in blocks.into_iter().enumerate() {
            let slice = &flat[off[i]..off[i + 1]];
            // Skip empty blocks, matching `unpack` and `write_rhs_box`: a
            // zero-length block need not be a DenseVector.
            if slice.is_empty() {
                continue;
            }
            let dv = blk.as_any_mut().downcast_mut::<DenseVector>().ok_or(())?;
            dv.set_values(slice);
        }
        Ok(out)
    }

    /// Copies an iterate-shaped vector into the flat slice `out`.
    ///
    /// Returns `Err(())` if `out` has the wrong length, if a non-empty block is
    /// not a [`DenseVector`], or if a block's expanded length does not match.
    fn unpack(&self, iv: &IteratesVectorMut, out: &mut [Number]) -> Result<(), ()> {
        let off = self.offsets();
        if out.len() != off[8] {
            return Err(());
        }
        let blocks: [&Box<dyn pounce_linalg::vector::Vector>; 8] = [
            &iv.x, &iv.s, &iv.y_c, &iv.y_d, &iv.z_l, &iv.z_u, &iv.v_l, &iv.v_u,
        ];
        for (i, blk) in blocks.into_iter().enumerate() {
            let dst = &mut out[off[i]..off[i + 1]];
            if dst.is_empty() {
                continue;
            }
            let dv = (**blk).as_any().downcast_ref::<DenseVector>().ok_or(())?;
            let ev = dv.expanded_values();
            // Report a mismatch instead of panicking inside copy_from_slice.
            if ev.len() != dst.len() {
                return Err(());
            }
            dst.copy_from_slice(&ev);
        }
        Ok(())
    }
}
impl PdSensBacksolver {
    /// Solves `n_rhs` KKT systems whose right-hand sides and solutions are in
    /// unscaled units.
    ///
    /// `rhs_flat` and `lhs_flat` each hold `n_rhs` consecutive vectors of length
    /// `self.dim()`. Without NLP scaling this forwards directly to
    /// [`Self::solve_many_scaled_space`]. Otherwise every right-hand side is
    /// multiplied element-wise by the conjugation vector `e` before the
    /// scaled-space solve, and every solution by `f` afterwards.
    ///
    /// Returns `false` on a buffer-length mismatch or when the scaled-space
    /// solve fails; `lhs_flat` may then be partially written.
    pub fn solve_many(&self, rhs_flat: &[Number], lhs_flat: &mut [Number], n_rhs: usize) -> bool {
        let Some(c) = &self.conj else {
            return self.solve_many_scaled_space(rhs_flat, lhs_flat, n_rhs);
        };
        let total = self.dim();
        if rhs_flat.len() != n_rhs * total || lhs_flat.len() != n_rhs * total {
            return false;
        }
        // `c.e` and `c.f` have exactly `total` entries, so cycling them pairs
        // each factor with the same position in every stacked vector. Unlike
        // `chunks_mut(total)`, this does not panic when `total == 0`.
        let rhs_scaled: Vec<Number> = rhs_flat
            .iter()
            .zip(c.e.iter().cycle())
            .map(|(&r, &ei)| r * ei)
            .collect();
        if !self.solve_many_scaled_space(&rhs_scaled, lhs_flat, n_rhs) {
            return false;
        }
        for (l, &fi) in lhs_flat.iter_mut().zip(c.f.iter().cycle()) {
            *l *= fi;
        }
        true
    }

    /// Solves `n_rhs` KKT systems directly in the solver's scaled space.
    ///
    /// `rhs_flat` and `lhs_flat` each hold `n_rhs` consecutive vectors of length
    /// `self.dim()`. Three strategies are tried in order:
    /// 1. `solve_many_cached_flat` on the flat buffers;
    /// 2. `solve_many_cached` with per-system callbacks that fill and read
    ///    iterate-shaped vectors;
    /// 3. one `solve` call per right-hand side.
    ///
    /// A strategy that returns `None` falls through to the next one.
    ///
    /// Returns `false` in any of these cases:
    /// - a buffer-length mismatch;
    /// - a block that cannot be written or read as a [`DenseVector`];
    /// - a solver failure.
    pub fn solve_many_scaled_space(
        &self,
        rhs_flat: &[Number],
        lhs_flat: &mut [Number],
        n_rhs: usize,
    ) -> bool {
        use std::sync::atomic::{AtomicBool, Ordering};

        let total = self.dim();
        if rhs_flat.len() != n_rhs * total || lhs_flat.len() != n_rhs * total {
            return false;
        }
        if n_rhs == 0 {
            return true;
        }
        let off = self.offsets();
        {
            let mut pd_ref = self.pd.borrow_mut();
            let fast_flat = pd_ref.solve_many_cached_flat(
                &self.data, &self.cq, &self.nlp, n_rhs, rhs_flat, lhs_flat, self.dims,
            );
            if let Some(ok) = fast_flat {
                return ok;
            }
        }
        {
            // The callbacks cannot report errors through their return value,
            // so any block write/read failure is recorded here and turned into
            // `false` afterwards, instead of silently returning a wrong result.
            // An atomic is used so the closures stay `Fn` and `Send`-compatible.
            let io_failed = AtomicBool::new(false);
            let mut pd_ref = self.pd.borrow_mut();
            let fast = pd_ref.solve_many_cached(
                &self.data,
                &self.cq,
                &self.nlp,
                n_rhs,
                |k, iv| {
                    let row = &rhs_flat[k * total..(k + 1) * total];
                    let ok = write_rhs_box(&mut iv.x, &row[off[0]..off[1]])
                        && write_rhs_box(&mut iv.s, &row[off[1]..off[2]])
                        && write_rhs_box(&mut iv.y_c, &row[off[2]..off[3]])
                        && write_rhs_box(&mut iv.y_d, &row[off[3]..off[4]])
                        && write_rhs_box(&mut iv.z_l, &row[off[4]..off[5]])
                        && write_rhs_box(&mut iv.z_u, &row[off[5]..off[6]])
                        && write_rhs_box(&mut iv.v_l, &row[off[6]..off[7]])
                        && write_rhs_box(&mut iv.v_u, &row[off[7]..off[8]]);
                    if !ok {
                        io_failed.store(true, Ordering::Relaxed);
                    }
                },
                |k, iv| {
                    let row = &mut lhs_flat[k * total..(k + 1) * total];
                    let ok = read_res_block(&*iv.x, &mut row[off[0]..off[1]])
                        && read_res_block(&*iv.s, &mut row[off[1]..off[2]])
                        && read_res_block(&*iv.y_c, &mut row[off[2]..off[3]])
                        && read_res_block(&*iv.y_d, &mut row[off[3]..off[4]])
                        && read_res_block(&*iv.z_l, &mut row[off[4]..off[5]])
                        && read_res_block(&*iv.z_u, &mut row[off[5]..off[6]])
                        && read_res_block(&*iv.v_l, &mut row[off[6]..off[7]])
                        && read_res_block(&*iv.v_u, &mut row[off[7]..off[8]]);
                    if !ok {
                        io_failed.store(true, Ordering::Relaxed);
                    }
                },
            );
            if let Some(ok) = fast {
                return ok && !io_failed.load(Ordering::Relaxed);
            }
        }
        // Fallback: one solve per right-hand side, reusing the same rhs/result
        // vectors. `write_rhs_block` needs unique ownership of each Rc block,
        // so it reports failure if anything else still holds a clone.
        let rhs_mut0 = self.template.make_new_zeroed();
        let mut rhs_iv = rhs_mut0.freeze();
        let mut res_iv = self.template.make_new_zeroed();
        let mut pd_ref = self.pd.borrow_mut();
        for k in 0..n_rhs {
            let rhs_row = &rhs_flat[k * total..(k + 1) * total];
            let lhs_row = &mut lhs_flat[k * total..(k + 1) * total];
            if !write_rhs_block(&mut rhs_iv.x, &rhs_row[off[0]..off[1]])
                || !write_rhs_block(&mut rhs_iv.s, &rhs_row[off[1]..off[2]])
                || !write_rhs_block(&mut rhs_iv.y_c, &rhs_row[off[2]..off[3]])
                || !write_rhs_block(&mut rhs_iv.y_d, &rhs_row[off[3]..off[4]])
                || !write_rhs_block(&mut rhs_iv.z_l, &rhs_row[off[4]..off[5]])
                || !write_rhs_block(&mut rhs_iv.z_u, &rhs_row[off[5]..off[6]])
                || !write_rhs_block(&mut rhs_iv.v_l, &rhs_row[off[6]..off[7]])
                || !write_rhs_block(&mut rhs_iv.v_u, &rhs_row[off[7]..off[8]])
            {
                return false;
            }
            let ok = pd_ref.solve(
                &self.data,
                &self.cq,
                &self.nlp,
                1.0,
                0.0,
                &rhs_iv,
                &mut res_iv,
                true,
                false,
            );
            if !ok {
                return false;
            }
            if !read_res_block(&*res_iv.x, &mut lhs_row[off[0]..off[1]])
                || !read_res_block(&*res_iv.s, &mut lhs_row[off[1]..off[2]])
                || !read_res_block(&*res_iv.y_c, &mut lhs_row[off[2]..off[3]])
                || !read_res_block(&*res_iv.y_d, &mut lhs_row[off[3]..off[4]])
                || !read_res_block(&*res_iv.z_l, &mut lhs_row[off[4]..off[5]])
                || !read_res_block(&*res_iv.z_u, &mut lhs_row[off[5]..off[6]])
                || !read_res_block(&*res_iv.v_l, &mut lhs_row[off[6]..off[7]])
                || !read_res_block(&*res_iv.v_u, &mut lhs_row[off[7]..off[8]])
            {
                return false;
            }
        }
        true
    }
}
/// Copies `slice` into a boxed block, which must be a [`DenseVector`].
///
/// An empty slice is accepted without inspecting the block. Returns `false`
/// when a non-empty slice targets a block of any other concrete type.
fn write_rhs_box(b: &mut Box<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
    if !slice.is_empty() {
        match b.as_any_mut().downcast_mut::<DenseVector>() {
            Some(dense) => {
                dense.set_values(slice);
            }
            None => return false,
        }
    }
    true
}
/// Copies `slice` into a shared block in place.
///
/// An empty slice is accepted without inspecting the block. Otherwise the
/// `Rc` must be uniquely owned (`Rc::get_mut`) and hold a [`DenseVector`];
/// if either condition fails, nothing is written and `false` is returned.
fn write_rhs_block(rc: &mut Rc<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
    if slice.is_empty() {
        return true;
    }
    let dense = Rc::get_mut(rc)
        .and_then(|inner| inner.as_any_mut().downcast_mut::<DenseVector>());
    match dense {
        Some(target) => {
            target.set_values(slice);
            true
        }
        None => false,
    }
}
/// Copies a solution block into `dst`.
///
/// An empty `dst` is accepted without inspecting the block. A homogeneous
/// [`DenseVector`] is broadcast as its scalar value; otherwise its stored values
/// are copied. Returns `false` if the block is not a [`DenseVector`] or if its
/// stored values do not match `dst` in length (instead of panicking).
fn read_res_block(blk: &dyn pounce_linalg::vector::Vector, dst: &mut [Number]) -> bool {
    if dst.is_empty() {
        return true;
    }
    let Some(dv) = blk.as_any().downcast_ref::<DenseVector>() else {
        return false;
    };
    if dv.is_homogeneous() {
        dst.fill(dv.scalar());
    } else {
        let vals = dv.values();
        if vals.len() != dst.len() {
            return false;
        }
        dst.copy_from_slice(vals);
    }
    true
}
impl PdSensBacksolver {
    /// Solves a single KKT system entirely in the solver's scaled space.
    ///
    /// `rhs` is packed into a fresh iterate-shaped vector, passed to the
    /// primal-dual solver, and the solver's result is unpacked into `lhs`.
    /// Returns `false` on a length mismatch, on a pack/unpack failure, or when
    /// the solver reports failure.
    pub fn solve_scaled_space(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
        let n = self.dim();
        if rhs.len() != n || lhs.len() != n {
            return false;
        }
        let Ok(packed) = self.pack(rhs) else {
            return false;
        };
        let frozen_rhs = packed.freeze();
        let mut solution = self.template.make_new_zeroed();
        // The RefMut temporary is released at the end of this statement,
        // before the result is unpacked.
        let solved = self.pd.borrow_mut().solve(
            &self.data,
            &self.cq,
            &self.nlp,
            1.0,
            0.0,
            &frozen_rhs,
            &mut solution,
            true,
            false,
        );
        solved && self.unpack(&solution, lhs).is_ok()
    }
}
impl SensBacksolver for PdSensBacksolver {
    /// Total KKT vector length: the sum of all eight block dimensions.
    fn dim(&self) -> usize {
        self.dims.iter().fold(0, |acc, &d| acc + d)
    }

    /// Solves one KKT system whose right-hand side and solution are in
    /// unscaled units.
    ///
    /// Without NLP scaling this is [`PdSensBacksolver::solve_scaled_space`].
    /// Otherwise `rhs` is multiplied element-wise by the conjugation vector `e`
    /// before the scaled-space solve and the solution by `f` afterwards.
    fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
        let Some(conj) = self.conj.as_deref() else {
            return self.solve_scaled_space(rhs, lhs);
        };
        let n = self.dim();
        if rhs.len() != n || lhs.len() != n {
            return false;
        }
        let mut scaled = Vec::with_capacity(n);
        scaled.extend(rhs.iter().zip(&conj.e).map(|(&r, &e)| r * e));
        if !self.solve_scaled_space(&scaled, lhs) {
            return false;
        }
        lhs.iter_mut().zip(&conj.f).for_each(|(l, &f)| *l *= f);
        true
    }
}