Skip to main content

pounce_cinterface/
solver.rs

1//! Session-style C ABI built on [`pounce_sensitivity::Solver`].
2//!
3//! Adds an opaque [`IpoptSolver`] handle that captures the converged
4//! KKT factor between calls, so C consumers can issue many cheap
5//! operations (KKT back-solves, parametric steps, reduced Hessians)
6//! against the same factorization without re-running the IPM.
7//!
8//! ```c
9//! IpoptProblem prob = CreateIpoptProblem(...);
10//! AddIpoptStrOption(prob, "linear_solver", "feral");
11//! IpoptSolver sol = IpoptCreateSolver(&prob);   // consumes prob
12//! IpoptSolverSolve(sol, x, NULL, NULL, NULL, NULL, NULL, user_data);
13//! IpoptSolverParametricStep(sol, 2, pin_indices, deltas, dx_out);
14//! IpoptSolverReducedHessian(sol, 2, pin_indices, 1.0, hr_out);
15//! IpoptFreeSolver(sol);
16//! ```
17//!
18//! Ownership: [`IpoptCreateSolver`] takes the IpoptProblem by **pointer
19//! to the handle** and nulls it out on success — the IpoptSolver
20//! becomes the sole owner. Calling [`crate::FreeIpoptProblem`] on the
21//! now-null handle is safe (it null-checks).
22
23use pounce_algorithm::alg_builder::AlgorithmBuilder;
24use pounce_algorithm::application::{
25    default_backend_factory, feral_config_from_options, IpoptApplication,
26};
27use pounce_nlp::return_codes::ApplicationReturnStatus;
28use pounce_nlp::tnlp::TNLP;
29use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
30use pounce_restoration::resto_inner_solver::{
31    make_default_restoration_factory, InnerBackendFactoryFactory,
32};
33use pounce_sensitivity::Solver as RustSolver;
34use std::cell::RefCell;
35use std::ffi::c_void;
36use std::rc::Rc;
37
38use crate::{
39    Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
40};
41
42/// Internal owned state for the session-style C handle.
43pub struct IpoptSolverInfo {
44    /// The session. `None` before the first solve or after a solve
45    /// that didn't converge.
46    session: Option<RustSolver>,
47    /// All the problem state: callbacks, dims, bounds, options. The
48    /// IpoptApplication inside is moved out into the `session` on each
49    /// successful solve, then restored on the next solve via the
50    /// `app_template` field below.
51    ///
52    /// (Stored as `Option` so `IpoptSolverSolve` can `.take()` the app
53    /// to move into the session, then put it back on next call.)
54    problem: IpoptProblemInfo,
55    /// Number of constraints — cached for cheap shape checks.
56    m: Index,
57}
58
59/// Opaque session-style handle. Construction via
60/// [`IpoptCreateSolver`]; release via [`IpoptFreeSolver`].
61pub type IpoptSolver = *mut IpoptSolverInfo;
62
63/// Build an [`IpoptSolver`] session from a configured
64/// [`IpoptProblem`]. **Consumes the IpoptProblem** on success: the
65/// pointer at `*prob_handle` is set to NULL and ownership transfers
66/// to the returned IpoptSolver. The user should not use the original
67/// handle again, though calling [`crate::FreeIpoptProblem`] on the
68/// now-null pointer is harmless (it null-checks).
69///
70/// Returns NULL if `prob_handle` is NULL, `*prob_handle` is NULL, or
71/// the IpoptProblem hasn't been fully initialized.
72///
73/// # Safety
74///
75/// `prob_handle` must be a valid pointer to an [`IpoptProblem`]
76/// previously returned by [`crate::CreateIpoptProblem`] (or NULL).
77#[no_mangle]
78pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
79    if prob_handle.is_null() {
80        return std::ptr::null_mut();
81    }
82    let prob = *prob_handle;
83    if prob.is_null() {
84        return std::ptr::null_mut();
85    }
86    // Take ownership of the Box and null out the caller's handle.
87    let problem = *Box::from_raw(prob);
88    *prob_handle = std::ptr::null_mut();
89    let m = problem.m;
90    let info = Box::new(IpoptSolverInfo {
91        session: None,
92        problem,
93        m,
94    });
95    Box::into_raw(info)
96}
97
98/// Release an [`IpoptSolver`] and all owned resources, including the
99/// IpoptProblem state that was consumed by [`IpoptCreateSolver`].
100///
101/// # Safety
102///
103/// `solver` must be a pointer returned by [`IpoptCreateSolver`] and
104/// not yet freed, or NULL.
105#[no_mangle]
106pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
107    if solver.is_null() {
108        return;
109    }
110    drop(Box::from_raw(solver));
111}
112
113/// Run the IPM. Same output buffer contract as [`crate::IpoptSolve`]:
114/// `x` is in/out (initial guess in, solution out); `g`, `obj_val`,
115/// `mult_g`, `mult_x_L`, `mult_x_U` are out-only and may be NULL.
116/// `user_data` is threaded into the C callbacks unchanged.
117///
118/// Returns the same `Index`-cast [`ApplicationReturnStatus`] code as
119/// [`crate::IpoptSolve`]. On a converged status the session retains
120/// the KKT factor for subsequent [`IpoptSolverKktSolve`],
121/// [`IpoptSolverParametricStep`], and [`IpoptSolverReducedHessian`]
122/// calls.
123///
124/// # Safety
125///
126/// All non-NULL output pointers must be valid for the appropriate
127/// length; the C callbacks stored on the underlying IpoptProblem must
128/// remain valid through the solve.
129#[no_mangle]
130#[allow(clippy::too_many_arguments)]
131pub unsafe extern "C" fn IpoptSolverSolve(
132    solver: IpoptSolver,
133    x: *mut Number,
134    g: *mut Number,
135    obj_val: *mut Number,
136    mult_g: *mut Number,
137    mult_x_L: *mut Number,
138    mult_x_U: *mut Number,
139    user_data: *mut c_void,
140) -> Index {
141    if solver.is_null() {
142        return ApplicationReturnStatus::InternalError as Index;
143    }
144    let info = &mut *solver;
145    let n = info.problem.n;
146    let m = info.m;
147    if n < 0 || m < 0 {
148        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
149    }
150    if n > 0 && x.is_null() {
151        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
152    }
153    let n_us = n as usize;
154    let m_us = m as usize;
155    let initial_x = if n_us > 0 {
156        std::slice::from_raw_parts(x, n_us).to_vec()
157    } else {
158        Vec::new()
159    };
160
161    let bridge = Rc::new(RefCell::new(CCallbackTnlp {
162        n,
163        m,
164        nele_jac: info.problem.nele_jac,
165        nele_hess: info.problem.nele_hess,
166        index_style: info.problem.index_style,
167        x_l: info.problem.x_l.clone(),
168        x_u: info.problem.x_u.clone(),
169        g_l: info.problem.g_l.clone(),
170        g_u: info.problem.g_u.clone(),
171        initial_x,
172        eval_f: info.problem.eval_f,
173        eval_grad_f: info.problem.eval_grad_f,
174        eval_g: info.problem.eval_g,
175        eval_jac_g: info.problem.eval_jac_g,
176        eval_h: info.problem.eval_h,
177        user_data,
178        intermediate_cb: info.problem.intermediate_cb,
179        user_scaling: info.problem.user_scaling.clone(),
180        final_status: None,
181        final_x: vec![0.0; n_us],
182        final_z_l: vec![0.0; n_us],
183        final_z_u: vec![0.0; n_us],
184        final_g: vec![0.0; m_us],
185        final_lambda: vec![0.0; m_us],
186        final_obj: 0.0,
187    }));
188
189    // Re-wire restoration fresh for this solve (same pattern as
190    // IpoptSolve).
191    let feral_cfg = feral_config_from_options(info.problem.app.options());
192    let bff: InnerBackendFactoryFactory = Box::new(move || default_backend_factory(feral_cfg));
193    let resto_factory = make_default_restoration_factory(
194        RestoAlgorithmBuilder::new(),
195        AlgorithmBuilder::new(),
196        bff,
197    );
198    info.problem.app.set_restoration_factory(resto_factory);
199
200    // Move the app out of the problem and into a fresh RustSolver.
201    let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
202    let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
203    let mut rust_solver = RustSolver::new(app, bridge_for_solver);
204    let status = rust_solver.solve();
205    info.problem.last_solve = Some(LastSolve {
206        stats: rust_solver.app().statistics(),
207    });
208
209    let bridge_ref = bridge.borrow();
210    if !x.is_null() && n_us > 0 {
211        std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
212    }
213    if !g.is_null() && m_us > 0 {
214        std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
215    }
216    if !obj_val.is_null() {
217        *obj_val = bridge_ref.final_obj;
218    }
219    if !mult_g.is_null() && m_us > 0 {
220        std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
221    }
222    if !mult_x_L.is_null() && n_us > 0 {
223        std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
224    }
225    if !mult_x_U.is_null() && n_us > 0 {
226        std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
227    }
228
229    info.session = Some(rust_solver);
230    status as Index
231}
232
233/// Total compound-KKT vector dimension. Returns -1 if no converged
234/// factor is held.
235///
236/// # Safety
237///
238/// `solver` must be a valid [`IpoptSolver`] or NULL.
239#[no_mangle]
240pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
241    if solver.is_null() {
242        return -1;
243    }
244    let info = &*solver;
245    match info.session.as_ref().and_then(|s| s.kkt_dim()) {
246        Some(d) => d as Index,
247        None => -1,
248    }
249}
250
251/// Solve `K · lhs = rhs` against the converged KKT factor. Both
252/// `rhs` and `lhs` are flat buffers of length [`IpoptSolverGetKktDim`]
253/// in the `x || s || y_c || y_d || z_l || z_u || v_l || v_u` packing.
254///
255/// Returns `TRUE` on success, `FALSE` if no factor is held or the
256/// back-solve fails.
257///
258/// # Safety
259///
260/// `rhs` and `lhs` must point to buffers at least
261/// [`IpoptSolverGetKktDim`] doubles long.
262#[no_mangle]
263pub unsafe extern "C" fn IpoptSolverKktSolve(
264    solver: IpoptSolver,
265    rhs: *const Number,
266    lhs: *mut Number,
267) -> Bool {
268    if solver.is_null() || rhs.is_null() || lhs.is_null() {
269        return FALSE;
270    }
271    let info = &*solver;
272    let Some(s) = info.session.as_ref() else {
273        return FALSE;
274    };
275    let Some(dim) = s.kkt_dim() else {
276        return FALSE;
277    };
278    let rhs_slice = std::slice::from_raw_parts(rhs, dim);
279    let mut lhs_vec = vec![0.0; dim];
280    if s.kkt_solve(rhs_slice, &mut lhs_vec).is_err() {
281        return FALSE;
282    }
283    std::ptr::copy_nonoverlapping(lhs_vec.as_ptr(), lhs, dim);
284    TRUE
285}
286
287/// First-order parametric step `Δx ≈ ∂x*/∂p · Δp`. `pin_indices` is
288/// `n_pins` `Index` values (0-based indices into `g(x)`); `deltas` is
289/// the parameter perturbation `Δp` of the same length; `dx_out` is the
290/// `n`-long primal step output (length matches the problem's `n`).
291///
292/// Returns `TRUE` on success, `FALSE` if no converged factor, invalid
293/// indices, or the sensitivity computation fails.
294///
295/// # Safety
296///
297/// `pin_indices` and `deltas` must point to `n_pins` valid elements;
298/// `dx_out` must point to at least `n` `Number` slots (`n` from the
299/// underlying IpoptProblem).
300#[no_mangle]
301pub unsafe extern "C" fn IpoptSolverParametricStep(
302    solver: IpoptSolver,
303    n_pins: Index,
304    pin_indices: *const Index,
305    deltas: *const Number,
306    dx_out: *mut Number,
307) -> Bool {
308    if solver.is_null() || n_pins < 0 {
309        return FALSE;
310    }
311    if n_pins > 0 && (pin_indices.is_null() || deltas.is_null()) {
312        return FALSE;
313    }
314    if dx_out.is_null() {
315        return FALSE;
316    }
317    let info = &*solver;
318    let Some(s) = info.session.as_ref() else {
319        return FALSE;
320    };
321    let m = info.m;
322    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
323    let mut pins = Vec::with_capacity(n_pins as usize);
324    for &i in pins_raw {
325        if i < 0 || i >= m {
326            return FALSE;
327        }
328        pins.push(i as pounce_common::types::Index);
329    }
330    let deltas_slice = std::slice::from_raw_parts(deltas, n_pins as usize);
331    let Ok(dx) = s.parametric_step(&pins, deltas_slice) else {
332        return FALSE;
333    };
334    std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
335    TRUE
336}
337
338/// Reduced Hessian `H_R = obj_scal · B K⁻¹ Bᵀ` over the pinned rows.
339/// `hr_out` receives an `n_pins²`-long column-major dense matrix.
340///
341/// Returns `TRUE` on success, `FALSE` otherwise.
342///
343/// # Safety
344///
345/// `pin_indices` must point to `n_pins` valid elements; `hr_out` must
346/// point to at least `n_pins²` `Number` slots.
347#[no_mangle]
348pub unsafe extern "C" fn IpoptSolverReducedHessian(
349    solver: IpoptSolver,
350    n_pins: Index,
351    pin_indices: *const Index,
352    obj_scal: Number,
353    hr_out: *mut Number,
354) -> Bool {
355    if solver.is_null() || n_pins < 0 || hr_out.is_null() {
356        return FALSE;
357    }
358    if n_pins > 0 && pin_indices.is_null() {
359        return FALSE;
360    }
361    let info = &*solver;
362    let Some(s) = info.session.as_ref() else {
363        return FALSE;
364    };
365    let m = info.m;
366    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
367    let mut pins = Vec::with_capacity(n_pins as usize);
368    for &i in pins_raw {
369        if i < 0 || i >= m {
370            return FALSE;
371        }
372        pins.push(i as pounce_common::types::Index);
373    }
374    let Ok(hr) = s.compute_reduced_hessian(&pins, obj_scal) else {
375        return FALSE;
376    };
377    std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
378    TRUE
379}