use std::cell::{Ref, RefCell};
use std::rc::Rc;
use pounce_algorithm::application::IpoptApplication;
use pounce_common::types::{Index, Number};
use pounce_nlp::return_codes::ApplicationReturnStatus;
use pounce_nlp::TNLP;
use crate::backsolver::SensBacksolver;
use crate::schur_data::IndexSchurData;
use crate::sens_app::{SensApplication, SensOptions};
use crate::vec_util::dense_to_vec;
use crate::PdSensBacksolver;
/// Errors returned by the post-solve queries on [`Solver`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// No converged state is available: `Solver::solve` has not been run, or its
    /// converged-callback never captured a snapshot (e.g. the KKT factor could not
    /// be captured).
    NotConverged,
    /// A caller-supplied buffer or count has the wrong size.
    BadShape {
        /// Which argument was wrong (e.g. `"rhs"`, `"lhs"`, `"deltas"`).
        what: &'static str,
        /// Size that was supplied.
        got: usize,
        /// Size that was required.
        expected: usize,
    },
    /// The KKT backsolve reported failure.
    BacksolveFailed,
    /// A sensitivity or reduced-Hessian computation failed; carries a description.
    SensComputationFailed(String),
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotConverged => {
                f.write_str("no converged solution available; call solve() first")
            }
            Self::BadShape {
                what,
                got,
                expected,
            } => write!(f, "bad shape for `{what}`: got {got}, expected {expected}"),
            Self::BacksolveFailed => f.write_str("KKT backsolve failed"),
            Self::SensComputationFailed(msg) => {
                write!(f, "sensitivity computation failed: {msg}")
            }
        }
    }
}

// No variant wraps another error, so the default `source()` (None) is correct.
impl std::error::Error for SolverError {}
/// Snapshot captured when the optimizer reports convergence during `Solver::solve`.
///
/// It is built inside the converged-callback that `Solver::solve` installs. The
/// `status` field is overwritten with the application's return status once
/// `optimize_tnlp` returns.
pub struct ConvergedState {
/// Return status of the solve that produced this snapshot. Inside the callback it
/// holds a placeholder (`InternalError`); `Solver::solve` replaces it with the real
/// status after `optimize_tnlp` returns.
pub status: ApplicationReturnStatus,
/// Primal iterate copied out of the current iterate (`curr.x`) at convergence.
/// NOTE(review): whether this is in the user's or the optimizer's internal
/// (scaled / reduced) variable space depends on `curr.x` — confirm.
pub x: Vec<Number>,
/// Objective value returned by the calculated quantities' `curr_f()` at convergence.
pub obj_val: Number,
/// Backsolver built from the KKT factor captured at convergence; every post-solve
/// query on `Solver` (KKT solves, parametric steps, reduced Hessians, scalings)
/// goes through it.
backsolver: PdSensBacksolver,
}
impl ConvergedState {
    /// Sizes of the eight KKT blocks, exactly as reported by the captured backsolver.
    pub fn block_dims(&self) -> [usize; 8] {
        let captured = &self.backsolver;
        captured.block_dims()
    }

    /// Total dimension of the KKT system handled by the captured backsolver.
    pub fn kkt_dim(&self) -> usize {
        let captured = &self.backsolver;
        captured.dim()
    }
}
/// Owns an `IpoptApplication` together with the TNLP it solves. It also keeps the
/// state captured at convergence, so the KKT factor can be reused after the solve
/// for linear solves, parametric steps and reduced-Hessian computations.
pub struct Solver {
/// The wrapped application that runs the optimization.
app: IpoptApplication,
/// Problem handed to `optimize_tnlp` on each `solve`.
tnlp: Rc<RefCell<dyn TNLP>>,
/// Converged-state slot. It is shared (via `Rc`) with the converged-callback
/// installed by `solve`, so the callback can fill it while the optimizer runs.
/// It stays `None` until a solve captures a converged state.
state: Rc<RefCell<Option<ConvergedState>>>,
}
impl Solver {
pub fn new(app: IpoptApplication, tnlp: Rc<RefCell<dyn TNLP>>) -> Self {
Self {
app,
tnlp,
state: Rc::new(RefCell::new(None)),
}
}
pub fn app(&self) -> &IpoptApplication {
&self.app
}
pub fn app_mut(&mut self) -> &mut IpoptApplication {
&mut self.app
}
pub fn solve(&mut self) -> ApplicationReturnStatus {
self.state.borrow_mut().take();
let state_cb = Rc::clone(&self.state);
self.app
.set_on_converged(Box::new(move |data, cq, nlp, pd| {
let curr = match data.borrow().curr.clone() {
Some(c) => c,
None => return,
};
let backsolver = match PdSensBacksolver::new(data, cq, nlp, Rc::clone(&pd)) {
Ok(b) => b,
Err(e) => {
eprintln!("pounce: Solver could not capture the KKT factor: {e}");
return;
}
};
let x = dense_to_vec(&*curr.x);
let obj_val = cq.borrow_mut().curr_f();
*state_cb.borrow_mut() = Some(ConvergedState {
status: ApplicationReturnStatus::InternalError,
x,
obj_val,
backsolver,
});
}));
let status = self.app.optimize_tnlp(Rc::clone(&self.tnlp));
if let Some(s) = self.state.borrow_mut().as_mut() {
s.status = status;
}
status
}
pub fn converged(&self) -> Option<Ref<'_, ConvergedState>> {
let r = self.state.borrow();
r.as_ref()?;
Some(Ref::map(r, |o| {
o.as_ref()
.unwrap_or_else(|| unreachable!("checked is_some above"))
}))
}
pub fn kkt_dim(&self) -> Option<usize> {
self.converged().map(|c| c.kkt_dim())
}
pub fn block_dims(&self) -> Option<[usize; 8]> {
self.converged().map(|c| c.block_dims())
}
pub fn kkt_solve(&self, rhs: &[Number], lhs: &mut [Number]) -> Result<(), SolverError> {
self.kkt_solve_impl(rhs, lhs, false)
}
pub fn kkt_solve_scaled(&self, rhs: &[Number], lhs: &mut [Number]) -> Result<(), SolverError> {
self.kkt_solve_impl(rhs, lhs, true)
}
fn kkt_solve_impl(
&self,
rhs: &[Number],
lhs: &mut [Number],
scaled: bool,
) -> Result<(), SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let total = state.backsolver.dim();
if rhs.len() != total {
return Err(SolverError::BadShape {
what: "rhs",
got: rhs.len(),
expected: total,
});
}
if lhs.len() != total {
return Err(SolverError::BadShape {
what: "lhs",
got: lhs.len(),
expected: total,
});
}
let ok = if scaled {
state.backsolver.solve_scaled_space(rhs, lhs)
} else {
state.backsolver.solve(rhs, lhs)
};
if ok {
Ok(())
} else {
Err(SolverError::BacksolveFailed)
}
}
pub fn kkt_solve_many(
&self,
rhs_flat: &[Number],
lhs_flat: &mut [Number],
n_rhs: usize,
) -> Result<(), SolverError> {
self.kkt_solve_many_impl(rhs_flat, lhs_flat, n_rhs, false)
}
pub fn kkt_solve_many_scaled(
&self,
rhs_flat: &[Number],
lhs_flat: &mut [Number],
n_rhs: usize,
) -> Result<(), SolverError> {
self.kkt_solve_many_impl(rhs_flat, lhs_flat, n_rhs, true)
}
fn kkt_solve_many_impl(
&self,
rhs_flat: &[Number],
lhs_flat: &mut [Number],
n_rhs: usize,
scaled: bool,
) -> Result<(), SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let total = state.backsolver.dim();
let expected = n_rhs * total;
if rhs_flat.len() != expected {
return Err(SolverError::BadShape {
what: "rhs",
got: rhs_flat.len(),
expected,
});
}
if lhs_flat.len() != expected {
return Err(SolverError::BadShape {
what: "lhs",
got: lhs_flat.len(),
expected,
});
}
let ok = if scaled {
state
.backsolver
.solve_many_scaled_space(rhs_flat, lhs_flat, n_rhs)
} else {
state.backsolver.solve_many(rhs_flat, lhs_flat, n_rhs)
};
if ok {
Ok(())
} else {
Err(SolverError::BacksolveFailed)
}
}
pub fn parametric_step(
&self,
pin_constraint_indices: &[Index],
deltas: &[Number],
) -> Result<Vec<Number>, SolverError> {
if pin_constraint_indices.len() != deltas.len() {
return Err(SolverError::BadShape {
what: "deltas",
got: deltas.len(),
expected: pin_constraint_indices.len(),
});
}
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let dims = state.backsolver.block_dims();
let n_x = dims[0];
let param_rows = state
.backsolver
.map_pin_g_to_kkt_rows(pin_constraint_indices)
.map_err(SolverError::SensComputationFailed)?;
let signs = vec![1; pin_constraint_indices.len()];
let a_data = IndexSchurData::from_parts(param_rows, signs)
.map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
let opts = SensOptions {
run_sens: true,
..SensOptions::default()
};
let sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
let n_full = state.backsolver.dim();
let mut dx_full = vec![0.0; n_full];
if !sens_app.parametric_step(deltas, &mut dx_full) {
return Err(SolverError::SensComputationFailed(
"SensApplication::parametric_step failed".into(),
));
}
dx_full.truncate(n_x);
Ok(dx_full)
}
pub fn compute_reduced_hessian(
&self,
pin_constraint_indices: &[Index],
obj_scal: Number,
) -> Result<Vec<Number>, SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let n = pin_constraint_indices.len();
let param_rows = state
.backsolver
.map_pin_g_to_kkt_rows(pin_constraint_indices)
.map_err(SolverError::SensComputationFailed)?;
let signs = vec![1; n];
let a_data = IndexSchurData::from_parts(param_rows, signs)
.map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
let opts = SensOptions {
compute_red_hessian: true,
obj_scal,
..SensOptions::default()
};
let mut sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
let mut hr = vec![0.0; n * n];
if !sens_app.compute_reduced_hessian(&mut hr) {
return Err(SolverError::SensComputationFailed(
"SensApplication::compute_reduced_hessian failed".into(),
));
}
Ok(hr)
}
pub fn compute_reduced_hessian_scaled(
&self,
pin_constraint_indices: &[Index],
obj_scal: Number,
) -> Result<Vec<Number>, SolverError> {
let mut hr = self.compute_reduced_hessian(pin_constraint_indices, obj_scal)?;
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let df = state.backsolver.obj_scaling_factor();
let dc = state
.backsolver
.pin_c_scales(pin_constraint_indices)
.map_err(SolverError::SensComputationFailed)?;
crate::reduced_hessian::scale_to_solver_space(&mut hr, df, &dc);
Ok(hr)
}
pub fn nlp_scaling(
&self,
) -> Result<(Number, Option<Vec<Number>>, Option<Vec<Number>>), SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
Ok(state.backsolver.nlp_scaling())
}
pub fn kkt_perturbations(&self) -> Result<[Number; 4], SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
Ok(state.backsolver.kkt_perturbations())
}
pub fn pin_g_scaling(
&self,
pin_constraint_indices: &[Index],
) -> Result<Vec<Number>, SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
state
.backsolver
.pin_c_scales(pin_constraint_indices)
.map_err(SolverError::SensComputationFailed)
}
}