1use pounce_algorithm::alg_builder::AlgorithmBuilder;
24use pounce_algorithm::application::{
25 default_backend_factory, feral_config_from_options, IpoptApplication,
26};
27use pounce_nlp::return_codes::ApplicationReturnStatus;
28use pounce_nlp::tnlp::TNLP;
29use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
30use pounce_restoration::resto_inner_solver::{
31 make_default_restoration_factory, InnerBackendFactoryFactory,
32};
33use pounce_sensitivity::Solver as RustSolver;
34use std::cell::RefCell;
35use std::ffi::c_void;
36use std::rc::Rc;
37
38use crate::{
39 Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
40};
41
42pub struct IpoptSolverInfo {
44 session: Option<RustSolver>,
47 problem: IpoptProblemInfo,
55 m: Index,
57}
58
59pub type IpoptSolver = *mut IpoptSolverInfo;
62
63#[no_mangle]
78pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
79 if prob_handle.is_null() {
80 return std::ptr::null_mut();
81 }
82 let prob = *prob_handle;
83 if prob.is_null() {
84 return std::ptr::null_mut();
85 }
86 let problem = *Box::from_raw(prob);
88 *prob_handle = std::ptr::null_mut();
89 let m = problem.m;
90 let info = Box::new(IpoptSolverInfo {
91 session: None,
92 problem,
93 m,
94 });
95 Box::into_raw(info)
96}
97
98#[no_mangle]
106pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
107 if solver.is_null() {
108 return;
109 }
110 drop(Box::from_raw(solver));
111}
112
113#[no_mangle]
130#[allow(clippy::too_many_arguments)]
131pub unsafe extern "C" fn IpoptSolverSolve(
132 solver: IpoptSolver,
133 x: *mut Number,
134 g: *mut Number,
135 obj_val: *mut Number,
136 mult_g: *mut Number,
137 mult_x_L: *mut Number,
138 mult_x_U: *mut Number,
139 user_data: *mut c_void,
140) -> Index {
141 if solver.is_null() {
142 return ApplicationReturnStatus::InternalError as Index;
143 }
144 let info = &mut *solver;
145 let n = info.problem.n;
146 let m = info.m;
147 if n < 0 || m < 0 {
148 return ApplicationReturnStatus::InvalidProblemDefinition as Index;
149 }
150 if n > 0 && x.is_null() {
151 return ApplicationReturnStatus::InvalidProblemDefinition as Index;
152 }
153 let n_us = n as usize;
154 let m_us = m as usize;
155 let initial_x = if n_us > 0 {
156 std::slice::from_raw_parts(x, n_us).to_vec()
157 } else {
158 Vec::new()
159 };
160
161 let bridge = Rc::new(RefCell::new(CCallbackTnlp {
162 n,
163 m,
164 nele_jac: info.problem.nele_jac,
165 nele_hess: info.problem.nele_hess,
166 index_style: info.problem.index_style,
167 x_l: info.problem.x_l.clone(),
168 x_u: info.problem.x_u.clone(),
169 g_l: info.problem.g_l.clone(),
170 g_u: info.problem.g_u.clone(),
171 initial_x,
172 eval_f: info.problem.eval_f,
173 eval_grad_f: info.problem.eval_grad_f,
174 eval_g: info.problem.eval_g,
175 eval_jac_g: info.problem.eval_jac_g,
176 eval_h: info.problem.eval_h,
177 user_data,
178 intermediate_cb: info.problem.intermediate_cb,
179 user_scaling: info.problem.user_scaling.clone(),
180 final_status: None,
181 final_x: vec![0.0; n_us],
182 final_z_l: vec![0.0; n_us],
183 final_z_u: vec![0.0; n_us],
184 final_g: vec![0.0; m_us],
185 final_lambda: vec![0.0; m_us],
186 final_obj: 0.0,
187 }));
188
189 let feral_cfg = feral_config_from_options(info.problem.app.options());
192 let bff: InnerBackendFactoryFactory = Box::new(move || default_backend_factory(feral_cfg));
193 let resto_factory = make_default_restoration_factory(
194 RestoAlgorithmBuilder::new(),
195 AlgorithmBuilder::new(),
196 bff,
197 );
198 info.problem.app.set_restoration_factory(resto_factory);
199
200 let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
202 let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
203 let mut rust_solver = RustSolver::new(app, bridge_for_solver);
204 let status = rust_solver.solve();
205 info.problem.last_solve = Some(LastSolve {
206 stats: rust_solver.app().statistics(),
207 });
208
209 let bridge_ref = bridge.borrow();
210 if !x.is_null() && n_us > 0 {
211 std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
212 }
213 if !g.is_null() && m_us > 0 {
214 std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
215 }
216 if !obj_val.is_null() {
217 *obj_val = bridge_ref.final_obj;
218 }
219 if !mult_g.is_null() && m_us > 0 {
220 std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
221 }
222 if !mult_x_L.is_null() && n_us > 0 {
223 std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
224 }
225 if !mult_x_U.is_null() && n_us > 0 {
226 std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
227 }
228
229 info.session = Some(rust_solver);
230 status as Index
231}
232
233#[no_mangle]
240pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
241 if solver.is_null() {
242 return -1;
243 }
244 let info = &*solver;
245 match info.session.as_ref().and_then(|s| s.kkt_dim()) {
246 Some(d) => d as Index,
247 None => -1,
248 }
249}
250
251#[no_mangle]
263pub unsafe extern "C" fn IpoptSolverKktSolve(
264 solver: IpoptSolver,
265 rhs: *const Number,
266 lhs: *mut Number,
267) -> Bool {
268 if solver.is_null() || rhs.is_null() || lhs.is_null() {
269 return FALSE;
270 }
271 let info = &*solver;
272 let Some(s) = info.session.as_ref() else {
273 return FALSE;
274 };
275 let Some(dim) = s.kkt_dim() else {
276 return FALSE;
277 };
278 let rhs_slice = std::slice::from_raw_parts(rhs, dim);
279 let mut lhs_vec = vec![0.0; dim];
280 if s.kkt_solve(rhs_slice, &mut lhs_vec).is_err() {
281 return FALSE;
282 }
283 std::ptr::copy_nonoverlapping(lhs_vec.as_ptr(), lhs, dim);
284 TRUE
285}
286
287#[no_mangle]
301pub unsafe extern "C" fn IpoptSolverParametricStep(
302 solver: IpoptSolver,
303 n_pins: Index,
304 pin_indices: *const Index,
305 deltas: *const Number,
306 dx_out: *mut Number,
307) -> Bool {
308 if solver.is_null() || n_pins < 0 {
309 return FALSE;
310 }
311 if n_pins > 0 && (pin_indices.is_null() || deltas.is_null()) {
312 return FALSE;
313 }
314 if dx_out.is_null() {
315 return FALSE;
316 }
317 let info = &*solver;
318 let Some(s) = info.session.as_ref() else {
319 return FALSE;
320 };
321 let m = info.m;
322 let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
323 let mut pins = Vec::with_capacity(n_pins as usize);
324 for &i in pins_raw {
325 if i < 0 || i >= m {
326 return FALSE;
327 }
328 pins.push(i as pounce_common::types::Index);
329 }
330 let deltas_slice = std::slice::from_raw_parts(deltas, n_pins as usize);
331 let Ok(dx) = s.parametric_step(&pins, deltas_slice) else {
332 return FALSE;
333 };
334 std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
335 TRUE
336}
337
338#[no_mangle]
348pub unsafe extern "C" fn IpoptSolverReducedHessian(
349 solver: IpoptSolver,
350 n_pins: Index,
351 pin_indices: *const Index,
352 obj_scal: Number,
353 hr_out: *mut Number,
354) -> Bool {
355 if solver.is_null() || n_pins < 0 || hr_out.is_null() {
356 return FALSE;
357 }
358 if n_pins > 0 && pin_indices.is_null() {
359 return FALSE;
360 }
361 let info = &*solver;
362 let Some(s) = info.session.as_ref() else {
363 return FALSE;
364 };
365 let m = info.m;
366 let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
367 let mut pins = Vec::with_capacity(n_pins as usize);
368 for &i in pins_raw {
369 if i < 0 || i >= m {
370 return FALSE;
371 }
372 pins.push(i as pounce_common::types::Index);
373 }
374 let Ok(hr) = s.compute_reduced_hessian(&pins, obj_scal) else {
375 return FALSE;
376 };
377 std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
378 TRUE
379}