use std::cell::{Ref, RefCell};
use std::convert::TryFrom;
use std::rc::Rc;

use pounce_algorithm::application::IpoptApplication;
use pounce_common::types::{Index, Number};
use pounce_nlp::return_codes::ApplicationReturnStatus;
use pounce_nlp::TNLP;

use crate::backsolver::SensBacksolver;
use crate::schur_data::IndexSchurData;
use crate::sens_app::{SensApplication, SensOptions};
use crate::PdSensBacksolver;
/// Errors returned by the post-solve queries of `Solver`
/// (KKT backsolves, parametric steps, reduced Hessians).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// No converged state is stored, e.g. `Solver::solve` has not been run or
    /// did not capture a converged snapshot.
    NotConverged,
    /// An input slice had the wrong length.
    BadShape {
        /// Name of the offending argument.
        what: &'static str,
        /// Length that was supplied.
        got: usize,
        /// Length that was required.
        expected: usize,
    },
    /// The KKT backsolver reported failure.
    BacksolveFailed,
    /// Building or running the sensitivity computation failed; the payload
    /// describes why.
    SensComputationFailed(String),
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SolverError::NotConverged => f.write_str("no converged solution is available"),
            SolverError::BadShape {
                what,
                got,
                expected,
            } => write!(f, "{what} has length {got}, expected {expected}"),
            SolverError::BacksolveFailed => f.write_str("KKT backsolve failed"),
            SolverError::SensComputationFailed(msg) => {
                write!(f, "sensitivity computation failed: {msg}")
            }
        }
    }
}

// No variant wraps another error, so the default `source()` (None) is correct.
impl std::error::Error for SolverError {}
/// Snapshot captured by the on-converged callback installed in `Solver::solve`.
///
/// Besides the primal point and objective value it keeps the backsolver built
/// at that point, which the KKT and sensitivity queries on `Solver` reuse.
pub struct ConvergedState {
    /// Return status of the solve. It is a placeholder while the optimizer runs
    /// and is overwritten with the status returned by `optimize_tnlp`.
    pub status: ApplicationReturnStatus,
    /// Primal iterate at the converged point, copied into a plain vector.
    pub x: Vec<Number>,
    /// Objective value reported by the calculated quantities at convergence.
    pub obj_val: Number,
    /// KKT backsolver constructed at the converged point.
    backsolver: PdSensBacksolver,
}
impl ConvergedState {
    /// Total number of rows of the KKT system handled by the stored backsolver.
    pub fn kkt_dim(&self) -> usize {
        let solver = &self.backsolver;
        solver.dim()
    }

    /// Per-block sizes of the KKT system as reported by the stored backsolver.
    ///
    /// `Solver` treats entries 0 and 1 as the `x` and `s` block sizes and
    /// entry 2 as the start of the equality-multiplier (`y_c`) rows.
    pub fn block_dims(&self) -> [usize; 8] {
        let solver = &self.backsolver;
        solver.block_dims()
    }
}
/// Owns an `IpoptApplication` together with the problem it solves, and keeps
/// the post-convergence state needed for KKT backsolves and sensitivity queries.
pub struct Solver {
    /// The optimizer application that `solve` runs.
    app: IpoptApplication,
    /// Problem handed to `optimize_tnlp` on every `solve`.
    tnlp: Rc<RefCell<dyn TNLP>>,
    /// Converged snapshot. It is shared through `Rc` with the on-converged
    /// callback so the callback can fill it in while the optimizer runs.
    /// It is `None` until a solve stores one.
    state: Rc<RefCell<Option<ConvergedState>>>,
}
impl Solver {
pub fn new(app: IpoptApplication, tnlp: Rc<RefCell<dyn TNLP>>) -> Self {
Self {
app,
tnlp,
state: Rc::new(RefCell::new(None)),
}
}
pub fn app(&self) -> &IpoptApplication {
&self.app
}
pub fn app_mut(&mut self) -> &mut IpoptApplication {
&mut self.app
}
pub fn solve(&mut self) -> ApplicationReturnStatus {
self.state.borrow_mut().take();
let state_cb = Rc::clone(&self.state);
self.app
.set_on_converged(Box::new(move |data, cq, nlp, pd| {
let curr = match data.borrow().curr.clone() {
Some(c) => c,
None => return,
};
let backsolver = match PdSensBacksolver::new(data, cq, nlp, Rc::clone(&pd)) {
Ok(b) => b,
Err(_) => return,
};
let x = dense_to_vec(&*curr.x);
let obj_val = cq.borrow_mut().curr_f();
*state_cb.borrow_mut() = Some(ConvergedState {
status: ApplicationReturnStatus::InternalError,
x,
obj_val,
backsolver,
});
}));
let status = self.app.optimize_tnlp(Rc::clone(&self.tnlp));
if let Some(s) = self.state.borrow_mut().as_mut() {
s.status = status;
}
status
}
pub fn converged(&self) -> Option<Ref<'_, ConvergedState>> {
let r = self.state.borrow();
r.as_ref()?;
Some(Ref::map(r, |o| {
o.as_ref()
.unwrap_or_else(|| unreachable!("checked is_some above"))
}))
}
pub fn kkt_dim(&self) -> Option<usize> {
self.converged().map(|c| c.kkt_dim())
}
pub fn block_dims(&self) -> Option<[usize; 8]> {
self.converged().map(|c| c.block_dims())
}
pub fn kkt_solve(&self, rhs: &[Number], lhs: &mut [Number]) -> Result<(), SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let total = state.backsolver.dim();
if rhs.len() != total {
return Err(SolverError::BadShape {
what: "rhs",
got: rhs.len(),
expected: total,
});
}
if lhs.len() != total {
return Err(SolverError::BadShape {
what: "lhs",
got: lhs.len(),
expected: total,
});
}
if state.backsolver.solve(rhs, lhs) {
Ok(())
} else {
Err(SolverError::BacksolveFailed)
}
}
pub fn parametric_step(
&self,
pin_constraint_indices: &[Index],
deltas: &[Number],
) -> Result<Vec<Number>, SolverError> {
if pin_constraint_indices.len() != deltas.len() {
return Err(SolverError::BadShape {
what: "deltas",
got: deltas.len(),
expected: pin_constraint_indices.len(),
});
}
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let dims = state.backsolver.block_dims();
let n_x = dims[0];
let n_s = dims[1];
let y_c_offset = (n_x + n_s) as Index;
let param_rows: Vec<Index> = pin_constraint_indices
.iter()
.map(|&i| y_c_offset + i)
.collect();
let signs = vec![1; pin_constraint_indices.len()];
let a_data = IndexSchurData::from_parts(param_rows, signs)
.map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
let opts = SensOptions {
run_sens: true,
..SensOptions::default()
};
let sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
let n_full = state.backsolver.dim();
let mut dx_full = vec![0.0; n_full];
if !sens_app.parametric_step(deltas, &mut dx_full) {
return Err(SolverError::SensComputationFailed(
"SensApplication::parametric_step failed".into(),
));
}
dx_full.truncate(n_x);
Ok(dx_full)
}
pub fn compute_reduced_hessian(
&self,
pin_constraint_indices: &[Index],
obj_scal: Number,
) -> Result<Vec<Number>, SolverError> {
let state = self.state.borrow();
let state = state.as_ref().ok_or(SolverError::NotConverged)?;
let n = pin_constraint_indices.len();
let dims = state.backsolver.block_dims();
let y_c_offset = (dims[0] + dims[1]) as Index;
let param_rows: Vec<Index> = pin_constraint_indices
.iter()
.map(|&i| y_c_offset + i)
.collect();
let signs = vec![1; n];
let a_data = IndexSchurData::from_parts(param_rows, signs)
.map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
let opts = SensOptions {
compute_red_hessian: true,
obj_scal,
..SensOptions::default()
};
let mut sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
let mut hr = vec![0.0; n * n];
if !sens_app.compute_reduced_hessian(&mut hr) {
return Err(SolverError::SensComputationFailed(
"SensApplication::compute_reduced_hessian failed".into(),
));
}
Ok(hr)
}
}
/// Copies the entries of `v` into a freshly allocated `Vec`.
///
/// Only `DenseVector` is understood. For any other concrete vector type the
/// result is a zero vector of length `v.dim()`.
/// NOTE(review): that fallback reports zeros silently; confirm every caller
/// only passes dense vectors.
fn dense_to_vec(v: &dyn pounce_linalg::Vector) -> Vec<Number> {
    v.as_any()
        .downcast_ref::<pounce_linalg::dense_vector::DenseVector>()
        .map(|dense| dense.values().to_vec())
        .unwrap_or_else(|| vec![0.0; v.dim() as usize])
}