Skip to main content

pounce_cinterface/
solver.rs

1//! Session-style C ABI built on [`pounce_sensitivity::Solver`].
2//!
3//! Adds an opaque [`IpoptSolver`] handle that captures the converged
4//! KKT factor between calls, so C consumers can issue many cheap
5//! operations (KKT back-solves, parametric steps, reduced Hessians)
6//! against the same factorization without re-running the IPM.
7//!
8//! ```c
9//! IpoptProblem prob = CreateIpoptProblem(...);
10//! AddIpoptStrOption(prob, "linear_solver", "feral");
11//! IpoptSolver sol = IpoptCreateSolver(&prob);   // consumes prob
12//! IpoptSolverSolve(sol, x, NULL, NULL, NULL, NULL, NULL, user_data);
13//! IpoptSolverParametricStep(sol, 2, pin_indices, deltas, dx_out);
14//! IpoptSolverReducedHessian(sol, 2, pin_indices, 1.0, hr_out);
15//! IpoptFreeSolver(sol);
16//! ```
17//!
18//! Ownership: [`IpoptCreateSolver`] takes the IpoptProblem by **pointer
19//! to the handle** and nulls it out on success — the IpoptSolver
20//! becomes the sole owner. Calling [`crate::FreeIpoptProblem`] on the
21//! now-null handle is safe (it null-checks).
22
23use pounce_algorithm::application::{
24    default_backend_factory, feral_config_from_options, IpoptApplication,
25};
26use pounce_nlp::return_codes::ApplicationReturnStatus;
27use pounce_nlp::tnlp::TNLP;
28use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
29use pounce_restoration::resto_inner_solver::{
30    make_default_restoration_factory_provider, InnerBackendFactoryFactory,
31};
32use pounce_sensitivity::Solver as RustSolver;
33use std::cell::RefCell;
34use std::ffi::c_void;
35use std::rc::Rc;
36
37use crate::{
38    Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
39};
40
41/// Internal owned state for the session-style C handle.
42pub struct IpoptSolverInfo {
43    /// The session. `None` before the first solve or after a solve
44    /// that didn't converge.
45    session: Option<RustSolver>,
46    /// All the problem state: callbacks, dims, bounds, options. The
47    /// IpoptApplication inside is moved out into the `session` on each
48    /// successful solve, then restored on the next solve via the
49    /// `app_template` field below.
50    ///
51    /// (Stored as `Option` so `IpoptSolverSolve` can `.take()` the app
52    /// to move into the session, then put it back on next call.)
53    problem: IpoptProblemInfo,
54    /// Number of constraints — cached for cheap shape checks.
55    m: Index,
56}
57
58/// Opaque session-style handle. Construction via
59/// [`IpoptCreateSolver`]; release via [`IpoptFreeSolver`].
60pub type IpoptSolver = *mut IpoptSolverInfo;
61
62/// Build an [`IpoptSolver`] session from a configured
63/// [`IpoptProblem`]. **Consumes the IpoptProblem** on success: the
64/// pointer at `*prob_handle` is set to NULL and ownership transfers
65/// to the returned IpoptSolver. The user should not use the original
66/// handle again, though calling [`crate::FreeIpoptProblem`] on the
67/// now-null pointer is harmless (it null-checks).
68///
69/// Returns NULL if `prob_handle` is NULL, `*prob_handle` is NULL, or
70/// the IpoptProblem hasn't been fully initialized.
71///
72/// # Safety
73///
74/// `prob_handle` must be a valid pointer to an [`IpoptProblem`]
75/// previously returned by [`crate::CreateIpoptProblem`] (or NULL).
76#[no_mangle]
77pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
78    if prob_handle.is_null() {
79        return std::ptr::null_mut();
80    }
81    let prob = *prob_handle;
82    if prob.is_null() {
83        return std::ptr::null_mut();
84    }
85    // Take ownership of the Box and null out the caller's handle.
86    let problem = *Box::from_raw(prob);
87    *prob_handle = std::ptr::null_mut();
88    let m = problem.m;
89    let info = Box::new(IpoptSolverInfo {
90        session: None,
91        problem,
92        m,
93    });
94    Box::into_raw(info)
95}
96
97/// Release an [`IpoptSolver`] and all owned resources, including the
98/// IpoptProblem state that was consumed by [`IpoptCreateSolver`].
99///
100/// # Safety
101///
102/// `solver` must be a pointer returned by [`IpoptCreateSolver`] and
103/// not yet freed, or NULL.
104#[no_mangle]
105pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
106    if solver.is_null() {
107        return;
108    }
109    drop(Box::from_raw(solver));
110}
111
112/// Run the IPM. Same output buffer contract as [`crate::IpoptSolve`]:
113/// `x` is in/out (initial guess in, solution out); `g`, `obj_val`,
114/// `mult_g`, `mult_x_L`, `mult_x_U` are out-only and may be NULL.
115/// `user_data` is threaded into the C callbacks unchanged.
116///
117/// Returns the same `Index`-cast [`ApplicationReturnStatus`] code as
118/// [`crate::IpoptSolve`]. On a converged status the session retains
119/// the KKT factor for subsequent [`IpoptSolverKktSolve`],
120/// [`IpoptSolverParametricStep`], and [`IpoptSolverReducedHessian`]
121/// calls.
122///
123/// # Safety
124///
125/// All non-NULL output pointers must be valid for the appropriate
126/// length; the C callbacks stored on the underlying IpoptProblem must
127/// remain valid through the solve.
128#[no_mangle]
129#[allow(clippy::too_many_arguments)]
130pub unsafe extern "C" fn IpoptSolverSolve(
131    solver: IpoptSolver,
132    x: *mut Number,
133    g: *mut Number,
134    obj_val: *mut Number,
135    mult_g: *mut Number,
136    mult_x_L: *mut Number,
137    mult_x_U: *mut Number,
138    user_data: *mut c_void,
139) -> Index {
140    if solver.is_null() {
141        return ApplicationReturnStatus::InternalError as Index;
142    }
143    let info = &mut *solver;
144    let n = info.problem.n;
145    let m = info.m;
146    if n < 0 || m < 0 {
147        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
148    }
149    if n > 0 && x.is_null() {
150        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
151    }
152    let n_us = n as usize;
153    let m_us = m as usize;
154    let initial_x = if n_us > 0 {
155        std::slice::from_raw_parts(x, n_us).to_vec()
156    } else {
157        Vec::new()
158    };
159
160    let bridge = Rc::new(RefCell::new(CCallbackTnlp {
161        n,
162        m,
163        nele_jac: info.problem.nele_jac,
164        nele_hess: info.problem.nele_hess,
165        index_style: info.problem.index_style,
166        x_l: info.problem.x_l.clone(),
167        x_u: info.problem.x_u.clone(),
168        g_l: info.problem.g_l.clone(),
169        g_u: info.problem.g_u.clone(),
170        initial_x,
171        eval_f: info.problem.eval_f,
172        eval_grad_f: info.problem.eval_grad_f,
173        eval_g: info.problem.eval_g,
174        eval_jac_g: info.problem.eval_jac_g,
175        eval_h: info.problem.eval_h,
176        user_data,
177        intermediate_cb: info.problem.intermediate_cb,
178        user_scaling: info.problem.user_scaling.clone(),
179        final_status: None,
180        final_x: vec![0.0; n_us],
181        final_z_l: vec![0.0; n_us],
182        final_z_u: vec![0.0; n_us],
183        final_g: vec![0.0; m_us],
184        final_lambda: vec![0.0; m_us],
185        final_obj: 0.0,
186    }));
187
188    // Re-wire restoration fresh for this solve (same pattern as
189    // IpoptSolve). Multi-pass provider so the ℓ₁ wrapper / auto-fallback
190    // don't panic on the second inner solve (pounce#10 / pounce#24).
191    let feral_cfg = feral_config_from_options(info.problem.app.options());
192    let bff_mint = move || -> InnerBackendFactoryFactory {
193        let feral_cfg = feral_cfg.clone();
194        Box::new(move || default_backend_factory(feral_cfg.clone()))
195    };
196    let resto_provider = make_default_restoration_factory_provider(
197        RestoAlgorithmBuilder::new(),
198        info.problem.app.algorithm_builder_from_options(),
199        bff_mint,
200    );
201    info.problem
202        .app
203        .set_restoration_factory_provider(resto_provider);
204
205    // Move the app out of the problem and into a fresh RustSolver.
206    let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
207    let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
208    let mut rust_solver = RustSolver::new(app, bridge_for_solver);
209    let status = rust_solver.solve();
210    let bridge_ref = bridge.borrow();
211    info.problem.last_solve = Some(LastSolve {
212        stats: rust_solver.app().statistics(),
213        status,
214        linear_solver: rust_solver.app().linear_solver_summary(),
215        final_x: bridge_ref.final_x.clone(),
216        final_lambda: bridge_ref.final_lambda.clone(),
217        final_obj: bridge_ref.final_obj,
218    });
219    if !x.is_null() && n_us > 0 {
220        std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
221    }
222    if !g.is_null() && m_us > 0 {
223        std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
224    }
225    if !obj_val.is_null() {
226        *obj_val = bridge_ref.final_obj;
227    }
228    if !mult_g.is_null() && m_us > 0 {
229        std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
230    }
231    if !mult_x_L.is_null() && n_us > 0 {
232        std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
233    }
234    if !mult_x_U.is_null() && n_us > 0 {
235        std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
236    }
237
238    info.session = Some(rust_solver);
239    status as Index
240}
241
242/// Total compound-KKT vector dimension. Returns -1 if no converged
243/// factor is held.
244///
245/// # Safety
246///
247/// `solver` must be a valid [`IpoptSolver`] or NULL.
248#[no_mangle]
249pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
250    if solver.is_null() {
251        return -1;
252    }
253    let info = &*solver;
254    match info.session.as_ref().and_then(|s| s.kkt_dim()) {
255        Some(d) => d as Index,
256        None => -1,
257    }
258}
259
260/// Solve `K · lhs = rhs` against the converged KKT factor. Both
261/// `rhs` and `lhs` are flat buffers of length [`IpoptSolverGetKktDim`]
262/// in the `x || s || y_c || y_d || z_l || z_u || v_l || v_u` packing.
263///
264/// Returns `TRUE` on success, `FALSE` if no factor is held or the
265/// back-solve fails.
266///
267/// # Safety
268///
269/// `rhs` and `lhs` must point to buffers at least
270/// [`IpoptSolverGetKktDim`] doubles long.
271#[no_mangle]
272pub unsafe extern "C" fn IpoptSolverKktSolve(
273    solver: IpoptSolver,
274    rhs: *const Number,
275    lhs: *mut Number,
276) -> Bool {
277    if solver.is_null() || rhs.is_null() || lhs.is_null() {
278        return FALSE;
279    }
280    let info = &*solver;
281    let Some(s) = info.session.as_ref() else {
282        return FALSE;
283    };
284    let Some(dim) = s.kkt_dim() else {
285        return FALSE;
286    };
287    let rhs_slice = std::slice::from_raw_parts(rhs, dim);
288    let mut lhs_vec = vec![0.0; dim];
289    if s.kkt_solve(rhs_slice, &mut lhs_vec).is_err() {
290        return FALSE;
291    }
292    std::ptr::copy_nonoverlapping(lhs_vec.as_ptr(), lhs, dim);
293    TRUE
294}
295
296/// First-order parametric step `Δx ≈ ∂x*/∂p · Δp`. `pin_indices` is
297/// `n_pins` `Index` values (0-based indices into `g(x)`); `deltas` is
298/// the parameter perturbation `Δp` of the same length; `dx_out` is the
299/// `n`-long primal step output (length matches the problem's `n`).
300///
301/// Returns `TRUE` on success, `FALSE` if no converged factor, invalid
302/// indices, or the sensitivity computation fails.
303///
304/// # Safety
305///
306/// `pin_indices` and `deltas` must point to `n_pins` valid elements;
307/// `dx_out` must point to at least `n` `Number` slots (`n` from the
308/// underlying IpoptProblem).
309#[no_mangle]
310pub unsafe extern "C" fn IpoptSolverParametricStep(
311    solver: IpoptSolver,
312    n_pins: Index,
313    pin_indices: *const Index,
314    deltas: *const Number,
315    dx_out: *mut Number,
316) -> Bool {
317    if solver.is_null() || n_pins < 0 {
318        return FALSE;
319    }
320    if n_pins > 0 && (pin_indices.is_null() || deltas.is_null()) {
321        return FALSE;
322    }
323    if dx_out.is_null() {
324        return FALSE;
325    }
326    let info = &*solver;
327    let Some(s) = info.session.as_ref() else {
328        return FALSE;
329    };
330    let m = info.m;
331    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
332    let mut pins = Vec::with_capacity(n_pins as usize);
333    for &i in pins_raw {
334        if i < 0 || i >= m {
335            return FALSE;
336        }
337        pins.push(i as pounce_common::types::Index);
338    }
339    let deltas_slice = std::slice::from_raw_parts(deltas, n_pins as usize);
340    let Ok(dx) = s.parametric_step(&pins, deltas_slice) else {
341        return FALSE;
342    };
343    std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
344    TRUE
345}
346
347/// Reduced Hessian `H_R = obj_scal · B K⁻¹ Bᵀ` over the pinned rows.
348/// `hr_out` receives an `n_pins²`-long column-major dense matrix.
349///
350/// Returns `TRUE` on success, `FALSE` otherwise.
351///
352/// # Safety
353///
354/// `pin_indices` must point to `n_pins` valid elements; `hr_out` must
355/// point to at least `n_pins²` `Number` slots.
356#[no_mangle]
357pub unsafe extern "C" fn IpoptSolverReducedHessian(
358    solver: IpoptSolver,
359    n_pins: Index,
360    pin_indices: *const Index,
361    obj_scal: Number,
362    hr_out: *mut Number,
363) -> Bool {
364    if solver.is_null() || n_pins < 0 || hr_out.is_null() {
365        return FALSE;
366    }
367    if n_pins > 0 && pin_indices.is_null() {
368        return FALSE;
369    }
370    let info = &*solver;
371    let Some(s) = info.session.as_ref() else {
372        return FALSE;
373    };
374    let m = info.m;
375    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
376    let mut pins = Vec::with_capacity(n_pins as usize);
377    for &i in pins_raw {
378        if i < 0 || i >= m {
379            return FALSE;
380        }
381        pins.push(i as pounce_common::types::Index);
382    }
383    let Ok(hr) = s.compute_reduced_hessian(&pins, obj_scal) else {
384        return FALSE;
385    };
386    std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
387    TRUE
388}