use std::cell::RefCell;
use std::rc::Rc;
use pounce_algorithm::application::IpoptApplication;
use pounce_common::types::{Index, Number};
use pounce_nlp::return_codes::ApplicationReturnStatus;
use pounce_nlp::tnlp::{
BoundsInfo, IndexStyle, IpoptCq, IpoptData, NlpInfo, Solution, SparsityRequest, StartingPoint,
TNLP,
};
use pounce_sensitivity::Solver;
/// Test fixture: a three-variable, four-constraint problem whose first
/// constraint (index 0) is an inequality and whose remaining constraints are
/// equalities (see the `TNLP` implementation in this file).
///
/// The two fields are the values the last two equality constraints are fixed to.
struct LeadingInequalityTNLP {
// Lower and upper bound of constraint 2 (`g[2] = x[1]`); also seeds the
// starting point for `x[1]`.
p1: Number,
// Lower and upper bound of constraint 3 (`g[3] = x[2]`); also seeds the
// starting point for `x[2]`.
p2: Number,
}
/// Solver callbacks for the fixture problem.
///
/// As coded below, the problem is:
///
/// * objective `f(x) = x0^2`
/// * `g0 = x0 + x1 + x2` with bounds `[-1e19, 1000]` (the only inequality)
/// * `g1 = x0 - x1 - x2` with bounds `[0, 0]`
/// * `g2 = x1` with bounds `[p1, p1]`
/// * `g3 = x2` with bounds `[p2, p2]`
///
/// All three variables carry bounds of `±1e19` (presumably treated as
/// "unbounded" by the solver — confirm against its infinity thresholds).
impl TNLP for LeadingInequalityTNLP {
    /// Reports the problem dimensions: 3 variables, 4 constraints, 8 Jacobian
    /// non-zeros, 1 Hessian non-zero, and `IndexStyle::C` indexing.
    fn get_nlp_info(&mut self) -> Option<NlpInfo> {
        let info = NlpInfo {
            n: 3,
            m: 4,
            nnz_jac_g: 8,
            nnz_h_lag: 1,
            index_style: IndexStyle::C,
        };
        Some(info)
    }

    /// Writes the variable and constraint bounds described on the `impl`.
    fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
        const HUGE: Number = 1.0e19;

        // Variables: the same wide bounds on all three.
        b.x_l[..3].fill(-HUGE);
        b.x_u[..3].fill(HUGE);

        // Constraints: row 0 is one-sided (an inequality); rows 1..=3 have
        // equal lower and upper bounds (equalities).
        let lower = [-HUGE, 0.0, self.p1, self.p2];
        let upper = [1000.0, 0.0, self.p1, self.p2];
        b.g_l[..4].copy_from_slice(&lower);
        b.g_u[..4].copy_from_slice(&upper);
        true
    }

    /// Starts at `(p1 + p2, p1, p2)`, which satisfies every equality above.
    fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
        let (p1, p2) = (self.p1, self.p2);
        sp.x[..3].copy_from_slice(&[p1 + p2, p1, p2]);
        true
    }

    /// Objective value: `x0^2`.
    fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
        let x0 = x[0];
        Some(x0 * x0)
    }

    /// Objective gradient: `(2 * x0, 0, 0)`.
    fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
        g[..3].copy_from_slice(&[2.0 * x[0], 0.0, 0.0]);
        true
    }

    /// Constraint values `g0..g3` as listed on the `impl`.
    fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
        let (x0, x1, x2) = (x[0], x[1], x[2]);
        g[..4].copy_from_slice(&[x0 + x1 + x2, x0 - x1 - x2, x1, x2]);
        true
    }

    /// Constraint Jacobian: either its sparsity pattern or its (constant)
    /// values, depending on `mode`. Both use the same entry ordering.
    fn eval_jac_g(
        &mut self,
        _x: Option<&[Number]>,
        _new_x: bool,
        mode: SparsityRequest<'_>,
    ) -> bool {
        // (row, column) of each non-zero: rows 0 and 1 are dense, rows 2 and 3
        // each touch a single variable.
        const ENTRIES: [(Index, Index); 8] = [
            (0, 0),
            (0, 1),
            (0, 2),
            (1, 0),
            (1, 1),
            (1, 2),
            (2, 1),
            (3, 2),
        ];
        match mode {
            SparsityRequest::Structure { irow, jcol } => {
                irow.copy_from_slice(&ENTRIES.map(|(row, _)| row));
                jcol.copy_from_slice(&ENTRIES.map(|(_, col)| col));
            }
            SparsityRequest::Values { values } => {
                // All constraints are linear, so the Jacobian does not depend on x.
                values.copy_from_slice(&[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0]);
            }
        }
        true
    }

    /// Hessian of the Lagrangian: a single entry at `(0, 0)` with value
    /// `2 * obj_factor`. The multipliers are unused because every constraint
    /// is linear.
    fn eval_h(
        &mut self,
        _x: Option<&[Number]>,
        _new_x: bool,
        obj_factor: Number,
        _lambda: Option<&[Number]>,
        _new_lambda: bool,
        mode: SparsityRequest<'_>,
    ) -> bool {
        match mode {
            SparsityRequest::Structure { irow, jcol } => {
                irow[0] = 0;
                jcol[0] = 0;
            }
            SparsityRequest::Values { values } => {
                let second_derivative = 2.0 * obj_factor;
                values[0] = second_derivative;
            }
        }
        true
    }

    /// Intentionally empty: the tests read results through the solver rather
    /// than through this callback.
    fn finalize_solution(&mut self, _sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {}
}
/// Builds an initialized `IpoptApplication` for the tests in this file.
///
/// Sets `print_level = 0` and `sb = "yes"` before calling `initialize()`
/// (presumably to silence solver output and the startup banner, as in Ipopt —
/// confirm against the option registry).
///
/// # Panics
///
/// Panics with a message naming the failing step if either option cannot be
/// set or if initialization fails.
fn make_app() -> IpoptApplication {
    let mut app = IpoptApplication::new();
    app.options_mut()
        .set_integer_value("print_level", 0, true, false)
        .expect("setting integer option `print_level` should succeed");
    app.options_mut()
        .set_string_value("sb", "yes", true, false)
        .expect("setting string option `sb` should succeed");
    app.initialize()
        .expect("IpoptApplication::initialize should succeed");
    app
}
/// Pins constraint index 2 (`g2 = x1`) and checks the resulting step.
///
/// Because constraint 0 of the fixture is an inequality, the pinned index in
/// the full constraint list differs from its position among the equalities.
/// Per the assertion messages below, the pre-fix behavior pinned `x2`'s
/// constraint instead. Moving `x1` by `delta` must also move `x0` by `delta`
/// (through `x0 - x1 - x2 = 0`) and leave `x2` unchanged.
#[test]
fn parametric_step_translates_pin_index_through_cd_split() {
    let problem = LeadingInequalityTNLP { p1: 1.0, p2: 1.0 };
    let tnlp: Rc<RefCell<dyn TNLP>> = Rc::new(RefCell::new(problem));
    let mut solver = Solver::new(make_app(), tnlp);

    // The sensitivity step is only meaningful at a converged point.
    let status = solver.solve();
    let converged = matches!(
        status,
        ApplicationReturnStatus::SolveSucceeded | ApplicationReturnStatus::SolvedToAcceptableLevel
    );
    assert!(converged, "solve failed: {status:?}");

    let delta = 0.1;
    let tol = 1e-7;
    let dx = solver
        .parametric_step(&[2], &[delta])
        .expect("parametric_step ok");
    assert_eq!(dx.len(), 3);

    let err0 = (dx[0] - delta).abs();
    assert!(err0 < tol, "dx[0] = {} expected ≈ {delta}", dx[0]);

    let err1 = (dx[1] - delta).abs();
    assert!(
        err1 < tol,
        "dx[1] = {} expected ≈ {delta} (pinning x1's constraint must move x1; \
         pre-fix bug pins x2's constraint and leaves this 0)",
        dx[1]
    );

    let err2 = dx[2].abs();
    assert!(
        err2 < tol,
        "dx[2] = {} expected ≈ 0 (x2's constraint was NOT pinned; \
         pre-fix bug pins it and moves this by {delta})",
        dx[2]
    );
}
/// Requesting a parametric step that pins constraint 0 — the fixture's only
/// inequality (`g0 <= 1000`) — must return an error.
#[test]
fn parametric_step_errors_on_pinned_inequality() {
    let tnlp: Rc<RefCell<dyn TNLP>> =
        Rc::new(RefCell::new(LeadingInequalityTNLP { p1: 1.0, p2: 1.0 }));
    let mut solver = Solver::new(make_app(), tnlp);

    // Check the solve itself first, so the error asserted below cannot be the
    // by-product of a failed solve and the test cannot pass vacuously.
    let status = solver.solve();
    assert!(
        matches!(
            status,
            ApplicationReturnStatus::SolveSucceeded
                | ApplicationReturnStatus::SolvedToAcceptableLevel
        ),
        "solve failed: {status:?}"
    );

    let res = solver.parametric_step(&[0], &[0.1]);
    assert!(
        res.is_err(),
        "pinning an inequality constraint must error, got {res:?}"
    );
}
/// Requesting a reduced Hessian over constraint 0 — the fixture's only
/// inequality (`g0 <= 1000`) — must return an error.
#[test]
fn reduced_hessian_errors_on_pinned_inequality() {
    let tnlp: Rc<RefCell<dyn TNLP>> =
        Rc::new(RefCell::new(LeadingInequalityTNLP { p1: 1.0, p2: 1.0 }));
    let mut solver = Solver::new(make_app(), tnlp);

    // Check the solve itself first, so the error asserted below cannot be the
    // by-product of a failed solve and the test cannot pass vacuously.
    let status = solver.solve();
    assert!(
        matches!(
            status,
            ApplicationReturnStatus::SolveSucceeded
                | ApplicationReturnStatus::SolvedToAcceptableLevel
        ),
        "solve failed: {status:?}"
    );

    let res = solver.compute_reduced_hessian(&[0], 1.0);
    assert!(
        res.is_err(),
        "reduced Hessian over a pinned inequality must error, got {res:?}"
    );
}