1use pounce_algorithm::application::{
24 default_backend_factory, feral_config_from_options, IpoptApplication,
25};
26use pounce_nlp::return_codes::ApplicationReturnStatus;
27use pounce_nlp::tnlp::TNLP;
28use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
29use pounce_restoration::resto_inner_solver::{
30 make_default_restoration_factory_provider, InnerBackendFactoryFactory,
31};
32use pounce_sensitivity::Solver as RustSolver;
33use std::cell::RefCell;
34use std::ffi::c_void;
35use std::rc::Rc;
36
37use crate::{
38 Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
39};
40
/// State behind an `IpoptSolver` handle: the problem taken over from the
/// caller's `IpoptProblem`, plus the solver session retained by the most
/// recent `IpoptSolverSolve` call.
///
/// Fields are dropped in declaration order, so `session` is released before
/// `problem`.
pub struct IpoptSolverInfo {
    /// Session stored at the end of each `IpoptSolverSolve` that gets past
    /// input validation; `None` until then. The KKT and sensitivity entry
    /// points read it.
    session: Option<RustSolver>,
    /// Problem definition moved out of the caller's `IpoptProblem` by
    /// `IpoptCreateSolver`.
    problem: IpoptProblemInfo,
    /// Constraint count, copied from `problem.m` at creation. Used to size the
    /// constraint-side buffers in `IpoptSolverSolve` and to range-check pin
    /// indices in the sensitivity entry points.
    m: Index,
}
57
/// Opaque C handle to an [`IpoptSolverInfo`]. It is produced by
/// `IpoptCreateSolver` (via `Box::into_raw`) and released with `IpoptFreeSolver`.
pub type IpoptSolver = *mut IpoptSolverInfo;
61
62#[no_mangle]
77pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
78 if prob_handle.is_null() {
79 return std::ptr::null_mut();
80 }
81 let prob = *prob_handle;
82 if prob.is_null() {
83 return std::ptr::null_mut();
84 }
85 let problem = *Box::from_raw(prob);
87 *prob_handle = std::ptr::null_mut();
88 let m = problem.m;
89 let info = Box::new(IpoptSolverInfo {
90 session: None,
91 problem,
92 m,
93 });
94 Box::into_raw(info)
95}
96
97#[no_mangle]
105pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
106 if solver.is_null() {
107 return;
108 }
109 drop(Box::from_raw(solver));
110}
111
112#[no_mangle]
129#[allow(clippy::too_many_arguments)]
130pub unsafe extern "C" fn IpoptSolverSolve(
131 solver: IpoptSolver,
132 x: *mut Number,
133 g: *mut Number,
134 obj_val: *mut Number,
135 mult_g: *mut Number,
136 mult_x_L: *mut Number,
137 mult_x_U: *mut Number,
138 user_data: *mut c_void,
139) -> Index {
140 if solver.is_null() {
141 return ApplicationReturnStatus::InternalError as Index;
142 }
143 let info = &mut *solver;
144 let n = info.problem.n;
145 let m = info.m;
146 if n < 0 || m < 0 {
147 return ApplicationReturnStatus::InvalidProblemDefinition as Index;
148 }
149 if n > 0 && x.is_null() {
150 return ApplicationReturnStatus::InvalidProblemDefinition as Index;
151 }
152 let n_us = n as usize;
153 let m_us = m as usize;
154 let initial_x = if n_us > 0 {
155 std::slice::from_raw_parts(x, n_us).to_vec()
156 } else {
157 Vec::new()
158 };
159
160 let bridge = Rc::new(RefCell::new(CCallbackTnlp {
161 n,
162 m,
163 nele_jac: info.problem.nele_jac,
164 nele_hess: info.problem.nele_hess,
165 index_style: info.problem.index_style,
166 x_l: info.problem.x_l.clone(),
167 x_u: info.problem.x_u.clone(),
168 g_l: info.problem.g_l.clone(),
169 g_u: info.problem.g_u.clone(),
170 initial_x,
171 eval_f: info.problem.eval_f,
172 eval_grad_f: info.problem.eval_grad_f,
173 eval_g: info.problem.eval_g,
174 eval_jac_g: info.problem.eval_jac_g,
175 eval_h: info.problem.eval_h,
176 user_data,
177 intermediate_cb: info.problem.intermediate_cb,
178 user_scaling: info.problem.user_scaling.clone(),
179 final_status: None,
180 final_x: vec![0.0; n_us],
181 final_z_l: vec![0.0; n_us],
182 final_z_u: vec![0.0; n_us],
183 final_g: vec![0.0; m_us],
184 final_lambda: vec![0.0; m_us],
185 final_obj: 0.0,
186 }));
187
188 let feral_cfg = feral_config_from_options(info.problem.app.options());
192 let bff_mint = move || -> InnerBackendFactoryFactory {
193 Box::new(move || default_backend_factory(feral_cfg))
194 };
195 let resto_provider = make_default_restoration_factory_provider(
196 RestoAlgorithmBuilder::new(),
197 info.problem.app.algorithm_builder_from_options(),
198 bff_mint,
199 );
200 info.problem
201 .app
202 .set_restoration_factory_provider(resto_provider);
203
204 let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
206 let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
207 let mut rust_solver = RustSolver::new(app, bridge_for_solver);
208 let status = rust_solver.solve();
209 let bridge_ref = bridge.borrow();
210 info.problem.last_solve = Some(LastSolve {
211 stats: rust_solver.app().statistics(),
212 status,
213 linear_solver: rust_solver.app().linear_solver_summary(),
214 final_x: bridge_ref.final_x.clone(),
215 final_lambda: bridge_ref.final_lambda.clone(),
216 final_obj: bridge_ref.final_obj,
217 });
218 if !x.is_null() && n_us > 0 {
219 std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
220 }
221 if !g.is_null() && m_us > 0 {
222 std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
223 }
224 if !obj_val.is_null() {
225 *obj_val = bridge_ref.final_obj;
226 }
227 if !mult_g.is_null() && m_us > 0 {
228 std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
229 }
230 if !mult_x_L.is_null() && n_us > 0 {
231 std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
232 }
233 if !mult_x_U.is_null() && n_us > 0 {
234 std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
235 }
236
237 info.session = Some(rust_solver);
238 status as Index
239}
240
241#[no_mangle]
248pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
249 if solver.is_null() {
250 return -1;
251 }
252 let info = &*solver;
253 match info.session.as_ref().and_then(|s| s.kkt_dim()) {
254 Some(d) => d as Index,
255 None => -1,
256 }
257}
258
259#[no_mangle]
271pub unsafe extern "C" fn IpoptSolverKktSolve(
272 solver: IpoptSolver,
273 rhs: *const Number,
274 lhs: *mut Number,
275) -> Bool {
276 if solver.is_null() || rhs.is_null() || lhs.is_null() {
277 return FALSE;
278 }
279 let info = &*solver;
280 let Some(s) = info.session.as_ref() else {
281 return FALSE;
282 };
283 let Some(dim) = s.kkt_dim() else {
284 return FALSE;
285 };
286 let rhs_slice = std::slice::from_raw_parts(rhs, dim);
287 let mut lhs_vec = vec![0.0; dim];
288 if s.kkt_solve(rhs_slice, &mut lhs_vec).is_err() {
289 return FALSE;
290 }
291 std::ptr::copy_nonoverlapping(lhs_vec.as_ptr(), lhs, dim);
292 TRUE
293}
294
295#[no_mangle]
309pub unsafe extern "C" fn IpoptSolverParametricStep(
310 solver: IpoptSolver,
311 n_pins: Index,
312 pin_indices: *const Index,
313 deltas: *const Number,
314 dx_out: *mut Number,
315) -> Bool {
316 if solver.is_null() || n_pins < 0 {
317 return FALSE;
318 }
319 if n_pins > 0 && (pin_indices.is_null() || deltas.is_null()) {
320 return FALSE;
321 }
322 if dx_out.is_null() {
323 return FALSE;
324 }
325 let info = &*solver;
326 let Some(s) = info.session.as_ref() else {
327 return FALSE;
328 };
329 let m = info.m;
330 let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
331 let mut pins = Vec::with_capacity(n_pins as usize);
332 for &i in pins_raw {
333 if i < 0 || i >= m {
334 return FALSE;
335 }
336 pins.push(i as pounce_common::types::Index);
337 }
338 let deltas_slice = std::slice::from_raw_parts(deltas, n_pins as usize);
339 let Ok(dx) = s.parametric_step(&pins, deltas_slice) else {
340 return FALSE;
341 };
342 std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
343 TRUE
344}
345
346#[no_mangle]
356pub unsafe extern "C" fn IpoptSolverReducedHessian(
357 solver: IpoptSolver,
358 n_pins: Index,
359 pin_indices: *const Index,
360 obj_scal: Number,
361 hr_out: *mut Number,
362) -> Bool {
363 if solver.is_null() || n_pins < 0 || hr_out.is_null() {
364 return FALSE;
365 }
366 if n_pins > 0 && pin_indices.is_null() {
367 return FALSE;
368 }
369 let info = &*solver;
370 let Some(s) = info.session.as_ref() else {
371 return FALSE;
372 };
373 let m = info.m;
374 let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
375 let mut pins = Vec::with_capacity(n_pins as usize);
376 for &i in pins_raw {
377 if i < 0 || i >= m {
378 return FALSE;
379 }
380 pins.push(i as pounce_common::types::Index);
381 }
382 let Ok(hr) = s.compute_reduced_hessian(&pins, obj_scal) else {
383 return FALSE;
384 };
385 std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
386 TRUE
387}