use pounce_algorithm::application::{
default_backend_factory, feral_config_from_options, IpoptApplication,
};
use pounce_nlp::return_codes::ApplicationReturnStatus;
use pounce_nlp::tnlp::TNLP;
use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
use pounce_restoration::resto_inner_solver::{
make_default_restoration_factory_provider, InnerBackendFactoryFactory,
};
use pounce_sensitivity::Solver as RustSolver;
use std::cell::RefCell;
use std::ffi::c_void;
use std::rc::Rc;
use crate::{
Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
};
/// State behind an [`IpoptSolver`] handle: the problem definition taken over from an
/// `IpoptProblem` by [`IpoptCreateSolver`], plus the solver session retained by the most
/// recent [`IpoptSolverSolve`] so the sensitivity entry points can query it.
pub struct IpoptSolverInfo {
// Solver retained after the last `IpoptSolverSolve`. `None` before the first solve, and
// after a re-solve that bails out early (the solve entry point clears it up front).
session: Option<RustSolver>,
// Problem definition moved out of the caller's `IpoptProblem` handle at creation.
problem: IpoptProblemInfo,
// Constraint count copied from `problem.m` at creation; used to validate sizes in the
// solve entry point and as the exclusive upper bound for pin indices.
m: Index,
}
/// Opaque C handle to an [`IpoptSolverInfo`]; [`IpoptCreateSolver`] returns null on failure.
pub type IpoptSolver = *mut IpoptSolverInfo;
/// Creates a persistent solver that takes ownership of the problem behind `prob_handle`.
///
/// On success, the caller's `*prob_handle` is set to null and a new solver handle is
/// returned. The problem now belongs to the solver and is released by
/// [`IpoptFreeSolver`]. Returns null, leaving `*prob_handle` untouched, when either
/// `prob_handle` or `*prob_handle` is null.
///
/// # Safety
///
/// `prob_handle` must be null or point to a readable and writable `IpoptProblem`, which
/// must itself be null or a live problem handle that has not been freed.
#[no_mangle]
pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
    // `||` short-circuits, so `*prob_handle` is only read once it is known to be non-null.
    if prob_handle.is_null() || (*prob_handle).is_null() {
        return std::ptr::null_mut();
    }
    // Take the problem pointer out of the caller's slot, leaving null behind.
    let raw_problem = std::mem::replace(&mut *prob_handle, std::ptr::null_mut());
    let problem = *Box::from_raw(raw_problem);
    Box::into_raw(Box::new(IpoptSolverInfo {
        session: None,
        m: problem.m,
        problem,
    }))
}
#[no_mangle]
pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
if solver.is_null() {
return;
}
drop(Box::from_raw(solver));
}
/// Runs a full solve of the problem owned by `solver` and keeps the resulting solver
/// session for the sensitivity entry points (`IpoptSolverGetKktDim`,
/// `IpoptSolverKktSolve`, `IpoptSolverKktSolveScaled`, `IpoptSolverParametricStep`,
/// `IpoptSolverReducedHessian`).
///
/// `x` supplies the starting point and receives the final primal point. It may be null
/// only when `n == 0`. `g`, `obj_val`, `mult_g`, `mult_x_L` and `mult_x_U` are optional
/// outputs: a null pointer skips that copy. `user_data` is stored in the callback bridge
/// and forwarded to the C callbacks.
///
/// Returns an `ApplicationReturnStatus` cast to `Index`:
/// - `InternalError` for a null `solver`. It is also the default handed to `ffi_guard`.
/// - `InvalidProblemDefinition` when `n`/`m` is negative, or when `x` is null with `n > 0`.
/// - Otherwise, the status reported by the solver.
///
/// # Safety
///
/// `solver` must be null or a live handle from `IpoptCreateSolver`. When `n > 0`, `x` must
/// be valid for `n` reads and writes. Non-null `g` and `mult_g` must be valid for `m`
/// writes, non-null `mult_x_L` and `mult_x_U` for `n` writes, and non-null `obj_val` for
/// one write.
#[no_mangle]
#[allow(clippy::too_many_arguments)]
pub unsafe extern "C" fn IpoptSolverSolve(
solver: IpoptSolver,
x: *mut Number,
g: *mut Number,
obj_val: *mut Number,
mult_g: *mut Number,
mult_x_L: *mut Number,
mult_x_U: *mut Number,
user_data: *mut c_void,
) -> Index {
if solver.is_null() {
return ApplicationReturnStatus::InternalError as Index;
}
// Discard the previous session and statistics before any validation (and outside the
// panic guard). A call that bails out below then cannot leave stale results queryable.
{
let info = &mut *solver;
info.session = None;
info.problem.last_solve = None;
}
crate::ffi_guard(ApplicationReturnStatus::InternalError as Index, || unsafe {
let info = &mut *solver;
let n = info.problem.n;
let m = info.m;
// Negative sizes would wrap in the `as usize` casts below.
if n < 0 || m < 0 {
return ApplicationReturnStatus::InvalidProblemDefinition as Index;
}
if n > 0 && x.is_null() {
return ApplicationReturnStatus::InvalidProblemDefinition as Index;
}
let n_us = n as usize;
let m_us = m as usize;
// Snapshot the caller's starting point; `x` is overwritten with the result further down.
let initial_x = if n_us > 0 {
std::slice::from_raw_parts(x, n_us).to_vec()
} else {
Vec::new()
};
// Bridge that exposes the C callbacks through the `TNLP` trait. The `final_*` buffers
// start zeroed and are read back after the solve.
// NOTE(review): they are copied into the caller's buffers below regardless of status.
// Confirm that `CCallbackTnlp` always fills them. Otherwise an early solver failure
// would hand back zeros in `x`/`obj_val`.
// NOTE(review): the bridge (and hence `user_data`) stays alive inside the retained
// session. If later sensitivity calls re-enter the callbacks, `user_data` must still
// be valid at that point.
let bridge = Rc::new(RefCell::new(CCallbackTnlp {
n,
m,
nele_jac: info.problem.nele_jac,
nele_hess: info.problem.nele_hess,
index_style: info.problem.index_style,
x_l: info.problem.x_l.clone(),
x_u: info.problem.x_u.clone(),
g_l: info.problem.g_l.clone(),
g_u: info.problem.g_u.clone(),
initial_x,
eval_f: info.problem.eval_f,
eval_grad_f: info.problem.eval_grad_f,
eval_g: info.problem.eval_g,
eval_jac_g: info.problem.eval_jac_g,
eval_h: info.problem.eval_h,
user_data,
intermediate_cb: info.problem.intermediate_cb,
user_scaling: info.problem.user_scaling.clone(),
final_status: None,
final_x: vec![0.0; n_us],
final_z_l: vec![0.0; n_us],
final_z_u: vec![0.0; n_us],
final_g: vec![0.0; m_us],
final_lambda: vec![0.0; m_us],
final_obj: 0.0,
}));
// Install a restoration-phase factory provider. Its inner backends come from
// `default_backend_factory`, configured from the same options as this solve.
let feral_cfg = feral_config_from_options(info.problem.app.options());
let bff_mint = move || -> InnerBackendFactoryFactory {
let feral_cfg = feral_cfg.clone();
Box::new(move || default_backend_factory(feral_cfg.clone()))
};
let resto_provider = make_default_restoration_factory_provider(
RestoAlgorithmBuilder::new(),
info.problem.app.algorithm_builder_from_options(),
bff_mint,
);
info.problem
.app
.set_restoration_factory_provider(resto_provider);
// `RustSolver::new` takes the application by value. Swap in a fresh one carrying the
// same options, so options set before this call survive for the next solve.
let saved_options = info.problem.app.options().clone();
let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
*info.problem.app.options_mut() = saved_options;
// Keep our own `Rc` to the bridge so its results can be read after the solve.
let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
let mut rust_solver = RustSolver::new(app, bridge_for_solver);
let status = rust_solver.solve();
let bridge_ref = bridge.borrow();
// Record statistics and final values for later queries, whatever the status.
info.problem.last_solve = Some(LastSolve {
stats: rust_solver.app().statistics(),
status,
linear_solver: rust_solver.app().linear_solver_summary(),
final_x: bridge_ref.final_x.clone(),
final_lambda: bridge_ref.final_lambda.clone(),
final_obj: bridge_ref.final_obj,
});
// Copy results into whichever caller buffers were provided.
if !x.is_null() && n_us > 0 {
std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
}
if !g.is_null() && m_us > 0 {
std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
}
if !obj_val.is_null() {
*obj_val = bridge_ref.final_obj;
}
if !mult_g.is_null() && m_us > 0 {
std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
}
if !mult_x_L.is_null() && n_us > 0 {
std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
}
if !mult_x_U.is_null() && n_us > 0 {
std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
}
// Retain the session (whatever `status` is) for the sensitivity entry points.
info.session = Some(rust_solver);
status as Index
})
}
#[no_mangle]
pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
if solver.is_null() {
return -1;
}
let info = &*solver;
match info.session.as_ref().and_then(|s| s.kkt_dim()) {
Some(d) => d as Index,
None => -1,
}
}
/// Solves against the KKT system held by the session of the most recent
/// [`IpoptSolverSolve`], using the session's unscaled `kkt_solve`.
///
/// Reads [`IpoptSolverGetKktDim`] values from `rhs` and, on success, writes as many to
/// `lhs`. Returns `TRUE` on success and `FALSE` in any of these cases: a null argument,
/// no session, no KKT dimension, or a failed solve.
///
/// # Safety
///
/// `solver` must be null or a live solver handle. Non-null `rhs`/`lhs` must be valid for
/// KKT-dimension reads/writes respectively.
#[no_mangle]
pub unsafe extern "C" fn IpoptSolverKktSolve(
    solver: IpoptSolver,
    rhs: *const Number,
    lhs: *mut Number,
) -> Bool {
    const SCALED: bool = false;
    kkt_solve_impl(solver, rhs, lhs, SCALED)
}
/// Same contract as [`IpoptSolverKktSolve`], but dispatches to the session's
/// `kkt_solve_scaled` instead of `kkt_solve`.
///
/// Returns `TRUE` on success and `FALSE` on any failure.
///
/// # Safety
///
/// `solver` must be null or a live solver handle. Non-null `rhs`/`lhs` must be valid for
/// KKT-dimension reads/writes respectively.
#[no_mangle]
pub unsafe extern "C" fn IpoptSolverKktSolveScaled(
    solver: IpoptSolver,
    rhs: *const Number,
    lhs: *mut Number,
) -> Bool {
    const SCALED: bool = true;
    kkt_solve_impl(solver, rhs, lhs, SCALED)
}
/// Shared body of [`IpoptSolverKktSolve`] and [`IpoptSolverKktSolveScaled`].
///
/// Reads `kkt_dim` values from `rhs` and solves with the retained session, choosing
/// `kkt_solve_scaled` when `scaled` is true and `kkt_solve` otherwise. The solution is
/// staged in an owned buffer and copied to `lhs` only when the solve succeeds, so `lhs`
/// is left untouched on failure.
///
/// Returns `FALSE` in any of these cases: a null pointer, no session, a session without
/// a KKT dimension, or a solve error.
///
/// # Safety
///
/// `solver` must be null or a live solver handle. Non-null `rhs`/`lhs` must be valid for
/// `kkt_dim` reads/writes respectively.
unsafe fn kkt_solve_impl(
    solver: IpoptSolver,
    rhs: *const Number,
    lhs: *mut Number,
    scaled: bool,
) -> Bool {
    crate::ffi_guard(FALSE, || unsafe {
        let any_null = solver.is_null() || rhs.is_null() || lhs.is_null();
        if any_null {
            return FALSE;
        }
        let session = match (*solver).session.as_ref() {
            Some(session) => session,
            None => return FALSE,
        };
        let dim = match session.kkt_dim() {
            Some(dim) => dim,
            None => return FALSE,
        };
        let input = std::slice::from_raw_parts(rhs, dim);
        let mut solution = vec![0.0; dim];
        let outcome = match scaled {
            true => session.kkt_solve_scaled(input, &mut solution),
            false => session.kkt_solve(input, &mut solution),
        };
        match outcome {
            Ok(_) => {
                std::ptr::copy_nonoverlapping(solution.as_ptr(), lhs, dim);
                TRUE
            }
            Err(_) => FALSE,
        }
    })
}
/// Views `len` elements starting at `ptr` as a slice.
///
/// When `len == 0`, `ptr` is never touched, so it may be null. Plain
/// `std::slice::from_raw_parts` forbids a null pointer even for an empty slice.
///
/// # Safety
///
/// When `len > 0`, `ptr` must be non-null, aligned, and valid for reads of `len`
/// initialised `T`s. Those elements must stay unmodified for the caller-chosen
/// lifetime `'a`.
unsafe fn slice_or_empty<'a, T>(ptr: *const T, len: usize) -> &'a [T] {
    match len {
        0 => &[],
        _ => std::slice::from_raw_parts(ptr, len),
    }
}
/// Computes a parametric step using the session of the most recent [`IpoptSolverSolve`].
///
/// The `n_pins` entries of `pin_indices` (each must lie in `0..m`) and the matching
/// `n_pins` entries of `deltas` are forwarded to the session's `parametric_step`. When
/// `n_pins == 0`, `pin_indices` and `deltas` may be null. On success, every element of
/// the returned step is copied into `dx_out` and `TRUE` is returned.
///
/// Returns `FALSE` in any of these cases: a null `solver` or `dx_out`, a negative
/// `n_pins`, missing pin/delta arrays, no session, an out-of-range pin, or a failed step.
///
/// # Safety
///
/// `solver` must be null or a live solver handle. When `n_pins > 0`, `pin_indices` and
/// `deltas` must each be valid for `n_pins` reads. This entry point takes no output
/// length, so `dx_out` must be writable for as many elements as the step the session
/// returns (NOTE(review): presumably `n` — confirm against `parametric_step`).
#[no_mangle]
pub unsafe extern "C" fn IpoptSolverParametricStep(
    solver: IpoptSolver,
    n_pins: Index,
    pin_indices: *const Index,
    deltas: *const Number,
    dx_out: *mut Number,
) -> Bool {
    crate::ffi_guard(FALSE, || unsafe {
        let bad_args = solver.is_null()
            || n_pins < 0
            || (n_pins > 0 && (pin_indices.is_null() || deltas.is_null()))
            || dx_out.is_null();
        if bad_args {
            return FALSE;
        }
        let info = &*solver;
        let session = match info.session.as_ref() {
            Some(session) => session,
            None => return FALSE,
        };
        let count = n_pins as usize;
        let bound = info.m;
        // `None` as soon as any pin falls outside `0..m`, rejecting the whole request.
        let pins: Option<Vec<pounce_common::types::Index>> = slice_or_empty(pin_indices, count)
            .iter()
            .map(|&pin| (0..bound).contains(&pin).then_some(pin as pounce_common::types::Index))
            .collect();
        let Some(pins) = pins else {
            return FALSE;
        };
        match session.parametric_step(&pins, slice_or_empty(deltas, count)) {
            Ok(dx) => {
                std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
                TRUE
            }
            Err(_) => FALSE,
        }
    })
}
/// Computes a reduced Hessian using the session of the most recent [`IpoptSolverSolve`].
///
/// The `n_pins` entries of `pin_indices` (each must lie in `0..m`) and `obj_scal` are
/// forwarded to the session's `compute_reduced_hessian`. When `n_pins == 0`,
/// `pin_indices` may be null. On success, every element of the returned matrix is copied
/// into `hr_out` and `TRUE` is returned.
///
/// Returns `FALSE` in any of these cases: a null `solver` or `hr_out`, a negative
/// `n_pins`, a missing pin array, no session, an out-of-range pin, or a failed
/// computation.
///
/// # Safety
///
/// `solver` must be null or a live solver handle. When `n_pins > 0`, `pin_indices` must
/// be valid for `n_pins` reads. This entry point takes no output length, so `hr_out` must
/// be writable for as many elements as the session returns (NOTE(review): presumably
/// `n_pins * n_pins` — confirm against `compute_reduced_hessian`).
#[no_mangle]
pub unsafe extern "C" fn IpoptSolverReducedHessian(
    solver: IpoptSolver,
    n_pins: Index,
    pin_indices: *const Index,
    obj_scal: Number,
    hr_out: *mut Number,
) -> Bool {
    crate::ffi_guard(FALSE, || unsafe {
        let pins_unreadable = n_pins > 0 && pin_indices.is_null();
        if solver.is_null() || hr_out.is_null() || n_pins < 0 || pins_unreadable {
            return FALSE;
        }
        let info = &*solver;
        let Some(session) = info.session.as_ref() else {
            return FALSE;
        };
        let raw_pins = slice_or_empty(pin_indices, n_pins as usize);
        // Validate every pin up front, then convert them all.
        if raw_pins.iter().any(|&pin| pin < 0 || pin >= info.m) {
            return FALSE;
        }
        let pins: Vec<pounce_common::types::Index> = raw_pins
            .iter()
            .map(|&pin| pin as pounce_common::types::Index)
            .collect();
        match session.compute_reduced_hessian(&pins, obj_scal) {
            Ok(hr) => {
                std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
                TRUE
            }
            Err(_) => FALSE,
        }
    })
}
// Unit tests for the persistent-solver C API. Every test drives a one-variable,
// unconstrained quadratic built through `CreateIpoptProblem`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{AddIpoptIntOption, CreateIpoptProblem, FreeIpoptProblem};
use std::ffi::CString;
// Objective callback: f(x) = (x[0] - 2)^2.
unsafe extern "C" fn quad_eval_f(
_n: Index,
x: *const Number,
_new_x: Bool,
obj_value: *mut Number,
_user_data: *mut c_void,
) -> Bool {
let v = *x.offset(0);
*obj_value = (v - 2.0) * (v - 2.0);
TRUE
}
// Gradient callback: df/dx0 = 2 (x[0] - 2).
unsafe extern "C" fn quad_eval_grad_f(
_n: Index,
x: *const Number,
_new_x: Bool,
grad: *mut Number,
_user_data: *mut c_void,
) -> Bool {
let v = *x.offset(0);
*grad.offset(0) = 2.0 * (v - 2.0);
TRUE
}
// Hessian callback with a single nonzero at (0, 0), in two modes:
// - `irow`/`jcol` set and `values` null: reports the structure.
// - Only `values` set: writes 2 * obj_factor (the objective's constant second
//   derivative scaled by the objective factor).
// Any other pointer combination is rejected with FALSE.
unsafe extern "C" fn quad_eval_h(
_n: Index,
_x: *const Number,
_new_x: Bool,
obj_factor: Number,
_m: Index,
_lambda: *const Number,
_new_lambda: Bool,
_nele_hess: Index,
irow: *mut Index,
jcol: *mut Index,
values: *mut Number,
_user_data: *mut c_void,
) -> Bool {
if !irow.is_null() && !jcol.is_null() && values.is_null() {
*irow.offset(0) = 0;
*jcol.offset(0) = 0;
} else if irow.is_null() && jcol.is_null() && !values.is_null() {
*values.offset(0) = 2.0 * obj_factor;
} else {
return FALSE;
}
TRUE
}
// Builds the quadratic test problem. The positional arguments presumably follow Ipopt's
// `CreateIpoptProblem` order:
// - n = 1, with bounds of -1e20/+1e20 (presumably treated as infinite — confirm).
// - m = 0 with null constraint bounds, 0 Jacobian entries, 1 Hessian entry.
// - Index style 0.
// There are no constraint or Jacobian callbacks.
fn create_quad() -> IpoptProblem {
let xl = [-1.0e20];
let xu = [1.0e20];
unsafe {
CreateIpoptProblem(
1,
xl.as_ptr(),
xu.as_ptr(),
0,
std::ptr::null(),
std::ptr::null(),
0,
1,
0,
Some(quad_eval_f),
None,
Some(quad_eval_grad_f),
None,
Some(quad_eval_h),
)
}
}
// Regression (H13): `IpoptSolverSolve` moves the application into the session, so an
// option added on the problem must still be set after each of two consecutive solves.
#[test]
fn options_survive_repeated_session_solves() {
let mut prob = create_quad();
let key = CString::new("max_iter").unwrap();
assert_eq!(unsafe { AddIpoptIntOption(prob, key.as_ptr(), 7) }, TRUE);
let solver = unsafe { IpoptCreateSolver(&mut prob) };
assert!(!solver.is_null());
assert!(prob.is_null(), "create must null the caller's handle");
// Reads `max_iter` back from the application owned by the solver. `None` covers the
// case where the lookup fails or the returned flag is false (assumed to mean "not
// set by the user" — confirm).
let read_max_iter = |solver: IpoptSolver| -> Option<i32> {
let info = unsafe { &*solver };
match info.problem.app.options().get_integer_value("max_iter", "") {
Ok((v, true)) => Some(v),
_ => None,
}
};
assert_eq!(read_max_iter(solver), Some(7), "option set pre-solve");
let mut x = [0.0_f64];
let mut obj = 0.0_f64;
// Solve with only `x` and `obj_val` as outputs; the solve status is not asserted here.
let solve = |solver: IpoptSolver, x: &mut [f64], obj: &mut f64| unsafe {
IpoptSolverSolve(
solver,
x.as_mut_ptr(),
std::ptr::null_mut(),
obj as *mut f64,
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
};
let _ = solve(solver, &mut x, &mut obj);
assert_eq!(
read_max_iter(solver),
Some(7),
"max_iter must survive the first session solve (H13)"
);
let _ = solve(solver, &mut x, &mut obj);
assert_eq!(
read_max_iter(solver),
Some(7),
"max_iter must survive a second session solve (H13)"
);
unsafe { IpoptFreeSolver(solver) };
// `prob` is null here (ownership moved to the solver); presumably `FreeIpoptProblem`
// treats null as a no-op — confirm.
unsafe { FreeIpoptProblem(prob) };
}
// With zero pins, null pin/delta pointers must be accepted as an empty request
// (routed through `slice_or_empty`) instead of forming a slice from a null pointer.
#[test]
fn zero_pins_with_null_pointers_is_not_ub() {
let mut prob = create_quad();
let solver = unsafe { IpoptCreateSolver(&mut prob) };
assert!(!solver.is_null());
let mut x = [0.0_f64];
let mut obj = 0.0_f64;
let status = unsafe {
IpoptSolverSolve(
solver,
x.as_mut_ptr(),
std::ptr::null_mut(),
&mut obj as *mut f64,
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
};
assert_eq!(status, ApplicationReturnStatus::SolveSucceeded as Index);
let mut dx_out = [0.0_f64];
let mut hr_out = [0.0_f64];
let step = unsafe {
IpoptSolverParametricStep(
solver,
0,
std::ptr::null(),
std::ptr::null(),
dx_out.as_mut_ptr(),
)
};
assert_eq!(step, TRUE, "empty parametric step is a defined no-op");
let rh = unsafe {
IpoptSolverReducedHessian(solver, 0, std::ptr::null(), 1.0, hr_out.as_mut_ptr())
};
assert_eq!(rh, TRUE, "empty reduced Hessian is a defined no-op");
unsafe { IpoptFreeSolver(solver) };
unsafe { FreeIpoptProblem(prob) };
}
// Regression (F5): a re-solve that fails validation must drop the session and stats
// kept from the previous successful solve. Setting `m = -1` forces the early
// `InvalidProblemDefinition` return.
#[test]
fn stale_session_state_cleared_when_resolve_bails() {
let mut prob = create_quad();
let solver = unsafe { IpoptCreateSolver(&mut prob) };
assert!(!solver.is_null());
let mut x = [0.0_f64];
let mut obj = 0.0_f64;
let solve = |solver: IpoptSolver, x: &mut [f64], obj: &mut f64| unsafe {
IpoptSolverSolve(
solver,
x.as_mut_ptr(),
std::ptr::null_mut(),
obj as *mut f64,
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
};
let rc = solve(solver, &mut x, &mut obj);
assert_eq!(rc, ApplicationReturnStatus::SolveSucceeded as Index);
{
let info = unsafe { &*solver };
assert!(
info.session.is_some(),
"converged solve should hold a session factor"
);
assert!(
info.problem.last_solve.is_some(),
"converged solve should record stats"
);
}
assert!(
unsafe { IpoptSolverGetKktDim(solver) } >= 0,
"a held factor reports a non-negative KKT dim"
);
// Corrupt the cached constraint count so the next solve bails during validation.
unsafe { (*solver).m = -1 };
let mut x2 = [0.0_f64];
let mut obj2 = 0.0_f64;
let rc2 = solve(solver, &mut x2, &mut obj2);
assert_eq!(
rc2,
ApplicationReturnStatus::InvalidProblemDefinition as Index
);
{
let info = unsafe { &*solver };
assert!(
info.session.is_none(),
"bailed solve must drop the stale session factor (F5)"
);
assert!(
info.problem.last_solve.is_none(),
"bailed solve must clear stale stats (F5)"
);
}
assert_eq!(
unsafe { IpoptSolverGetKktDim(solver) },
-1,
"no factor is held after a bailed re-solve (F5)"
);
unsafe { IpoptFreeSolver(solver) };
unsafe { FreeIpoptProblem(prob) };
}
}