Skip to main content

pounce_cinterface/
solver.rs

1//! Session-style C ABI built on [`pounce_sensitivity::Solver`].
2//!
3//! Adds an opaque [`IpoptSolver`] handle that captures the converged
4//! KKT factor between calls, so C consumers can issue many cheap
5//! operations (KKT back-solves, parametric steps, reduced Hessians)
6//! against the same factorization without re-running the IPM.
7//!
8//! ```c
9//! IpoptProblem prob = CreateIpoptProblem(...);
10//! AddIpoptStrOption(prob, "linear_solver", "feral");
11//! IpoptSolver sol = IpoptCreateSolver(&prob);   // consumes prob
12//! IpoptSolverSolve(sol, x, NULL, NULL, NULL, NULL, NULL, user_data);
13//! IpoptSolverParametricStep(sol, 2, pin_indices, deltas, dx_out);
14//! IpoptSolverReducedHessian(sol, 2, pin_indices, 1.0, hr_out);
15//! IpoptFreeSolver(sol);
16//! ```
17//!
18//! Ownership: [`IpoptCreateSolver`] takes the IpoptProblem by **pointer
19//! to the handle** and nulls it out on success — the IpoptSolver
20//! becomes the sole owner. Calling [`crate::FreeIpoptProblem`] on the
21//! now-null handle is safe (it null-checks).
22
23use pounce_algorithm::application::{
24    default_backend_factory, feral_config_from_options, IpoptApplication,
25};
26use pounce_nlp::return_codes::ApplicationReturnStatus;
27use pounce_nlp::tnlp::TNLP;
28use pounce_restoration::resto_alg_builder::RestoAlgorithmBuilder;
29use pounce_restoration::resto_inner_solver::{
30    make_default_restoration_factory_provider, InnerBackendFactoryFactory,
31};
32use pounce_sensitivity::Solver as RustSolver;
33use std::cell::RefCell;
34use std::ffi::c_void;
35use std::rc::Rc;
36
37use crate::{
38    Bool, CCallbackTnlp, Index, IpoptProblem, IpoptProblemInfo, LastSolve, Number, FALSE, TRUE,
39};
40
41/// Internal owned state for the session-style C handle.
42pub struct IpoptSolverInfo {
43    /// The session. `None` before the first solve or after a solve
44    /// that didn't converge.
45    session: Option<RustSolver>,
46    /// All the problem state: callbacks, dims, bounds, options. The
47    /// IpoptApplication inside is moved out into the `session` on each
48    /// successful solve, then restored on the next solve via the
49    /// `app_template` field below.
50    ///
51    /// (Stored as `Option` so `IpoptSolverSolve` can `.take()` the app
52    /// to move into the session, then put it back on next call.)
53    problem: IpoptProblemInfo,
54    /// Number of constraints — cached for cheap shape checks.
55    m: Index,
56}
57
58/// Opaque session-style handle. Construction via
59/// [`IpoptCreateSolver`]; release via [`IpoptFreeSolver`].
60pub type IpoptSolver = *mut IpoptSolverInfo;
61
62/// Build an [`IpoptSolver`] session from a configured
63/// [`IpoptProblem`]. **Consumes the IpoptProblem** on success: the
64/// pointer at `*prob_handle` is set to NULL and ownership transfers
65/// to the returned IpoptSolver. The user should not use the original
66/// handle again, though calling [`crate::FreeIpoptProblem`] on the
67/// now-null pointer is harmless (it null-checks).
68///
69/// Returns NULL if `prob_handle` is NULL, `*prob_handle` is NULL, or
70/// the IpoptProblem hasn't been fully initialized.
71///
72/// # Safety
73///
74/// `prob_handle` must be a valid pointer to an [`IpoptProblem`]
75/// previously returned by [`crate::CreateIpoptProblem`] (or NULL).
76#[no_mangle]
77pub unsafe extern "C" fn IpoptCreateSolver(prob_handle: *mut IpoptProblem) -> IpoptSolver {
78    if prob_handle.is_null() {
79        return std::ptr::null_mut();
80    }
81    let prob = *prob_handle;
82    if prob.is_null() {
83        return std::ptr::null_mut();
84    }
85    // Take ownership of the Box and null out the caller's handle.
86    let problem = *Box::from_raw(prob);
87    *prob_handle = std::ptr::null_mut();
88    let m = problem.m;
89    let info = Box::new(IpoptSolverInfo {
90        session: None,
91        problem,
92        m,
93    });
94    Box::into_raw(info)
95}
96
97/// Release an [`IpoptSolver`] and all owned resources, including the
98/// IpoptProblem state that was consumed by [`IpoptCreateSolver`].
99///
100/// # Safety
101///
102/// `solver` must be a pointer returned by [`IpoptCreateSolver`] and
103/// not yet freed, or NULL.
104#[no_mangle]
105pub unsafe extern "C" fn IpoptFreeSolver(solver: IpoptSolver) {
106    if solver.is_null() {
107        return;
108    }
109    drop(Box::from_raw(solver));
110}
111
112/// Run the IPM. Same output buffer contract as [`crate::IpoptSolve`]:
113/// `x` is in/out (initial guess in, solution out); `g`, `obj_val`,
114/// `mult_g`, `mult_x_L`, `mult_x_U` are out-only and may be NULL.
115/// `user_data` is threaded into the C callbacks unchanged.
116///
117/// Returns the same `Index`-cast [`ApplicationReturnStatus`] code as
118/// [`crate::IpoptSolve`]. On a converged status the session retains
119/// the KKT factor for subsequent [`IpoptSolverKktSolve`],
120/// [`IpoptSolverParametricStep`], and [`IpoptSolverReducedHessian`]
121/// calls.
122///
123/// # Safety
124///
125/// All non-NULL output pointers must be valid for the appropriate
126/// length; the C callbacks stored on the underlying IpoptProblem must
127/// remain valid through the solve.
128#[no_mangle]
129#[allow(clippy::too_many_arguments)]
130pub unsafe extern "C" fn IpoptSolverSolve(
131    solver: IpoptSolver,
132    x: *mut Number,
133    g: *mut Number,
134    obj_val: *mut Number,
135    mult_g: *mut Number,
136    mult_x_L: *mut Number,
137    mult_x_U: *mut Number,
138    user_data: *mut c_void,
139) -> Index {
140    if solver.is_null() {
141        return ApplicationReturnStatus::InternalError as Index;
142    }
143    let info = &mut *solver;
144    let n = info.problem.n;
145    let m = info.m;
146    if n < 0 || m < 0 {
147        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
148    }
149    if n > 0 && x.is_null() {
150        return ApplicationReturnStatus::InvalidProblemDefinition as Index;
151    }
152    let n_us = n as usize;
153    let m_us = m as usize;
154    let initial_x = if n_us > 0 {
155        std::slice::from_raw_parts(x, n_us).to_vec()
156    } else {
157        Vec::new()
158    };
159
160    let bridge = Rc::new(RefCell::new(CCallbackTnlp {
161        n,
162        m,
163        nele_jac: info.problem.nele_jac,
164        nele_hess: info.problem.nele_hess,
165        index_style: info.problem.index_style,
166        x_l: info.problem.x_l.clone(),
167        x_u: info.problem.x_u.clone(),
168        g_l: info.problem.g_l.clone(),
169        g_u: info.problem.g_u.clone(),
170        initial_x,
171        eval_f: info.problem.eval_f,
172        eval_grad_f: info.problem.eval_grad_f,
173        eval_g: info.problem.eval_g,
174        eval_jac_g: info.problem.eval_jac_g,
175        eval_h: info.problem.eval_h,
176        user_data,
177        intermediate_cb: info.problem.intermediate_cb,
178        user_scaling: info.problem.user_scaling.clone(),
179        final_status: None,
180        final_x: vec![0.0; n_us],
181        final_z_l: vec![0.0; n_us],
182        final_z_u: vec![0.0; n_us],
183        final_g: vec![0.0; m_us],
184        final_lambda: vec![0.0; m_us],
185        final_obj: 0.0,
186    }));
187
188    // Re-wire restoration fresh for this solve (same pattern as
189    // IpoptSolve). Multi-pass provider so the ℓ₁ wrapper / auto-fallback
190    // don't panic on the second inner solve (pounce#10 / pounce#24).
191    let feral_cfg = feral_config_from_options(info.problem.app.options());
192    let bff_mint = move || -> InnerBackendFactoryFactory {
193        Box::new(move || default_backend_factory(feral_cfg))
194    };
195    let resto_provider = make_default_restoration_factory_provider(
196        RestoAlgorithmBuilder::new(),
197        info.problem.app.algorithm_builder_from_options(),
198        bff_mint,
199    );
200    info.problem
201        .app
202        .set_restoration_factory_provider(resto_provider);
203
204    // Move the app out of the problem and into a fresh RustSolver.
205    let app = std::mem::replace(&mut info.problem.app, IpoptApplication::new());
206    let bridge_for_solver: Rc<RefCell<dyn TNLP>> = bridge.clone();
207    let mut rust_solver = RustSolver::new(app, bridge_for_solver);
208    let status = rust_solver.solve();
209    let bridge_ref = bridge.borrow();
210    info.problem.last_solve = Some(LastSolve {
211        stats: rust_solver.app().statistics(),
212        status,
213        linear_solver: rust_solver.app().linear_solver_summary(),
214        final_x: bridge_ref.final_x.clone(),
215        final_lambda: bridge_ref.final_lambda.clone(),
216        final_obj: bridge_ref.final_obj,
217    });
218    if !x.is_null() && n_us > 0 {
219        std::ptr::copy_nonoverlapping(bridge_ref.final_x.as_ptr(), x, n_us);
220    }
221    if !g.is_null() && m_us > 0 {
222        std::ptr::copy_nonoverlapping(bridge_ref.final_g.as_ptr(), g, m_us);
223    }
224    if !obj_val.is_null() {
225        *obj_val = bridge_ref.final_obj;
226    }
227    if !mult_g.is_null() && m_us > 0 {
228        std::ptr::copy_nonoverlapping(bridge_ref.final_lambda.as_ptr(), mult_g, m_us);
229    }
230    if !mult_x_L.is_null() && n_us > 0 {
231        std::ptr::copy_nonoverlapping(bridge_ref.final_z_l.as_ptr(), mult_x_L, n_us);
232    }
233    if !mult_x_U.is_null() && n_us > 0 {
234        std::ptr::copy_nonoverlapping(bridge_ref.final_z_u.as_ptr(), mult_x_U, n_us);
235    }
236
237    info.session = Some(rust_solver);
238    status as Index
239}
240
241/// Total compound-KKT vector dimension. Returns -1 if no converged
242/// factor is held.
243///
244/// # Safety
245///
246/// `solver` must be a valid [`IpoptSolver`] or NULL.
247#[no_mangle]
248pub unsafe extern "C" fn IpoptSolverGetKktDim(solver: IpoptSolver) -> Index {
249    if solver.is_null() {
250        return -1;
251    }
252    let info = &*solver;
253    match info.session.as_ref().and_then(|s| s.kkt_dim()) {
254        Some(d) => d as Index,
255        None => -1,
256    }
257}
258
259/// Solve `K · lhs = rhs` against the converged KKT factor. Both
260/// `rhs` and `lhs` are flat buffers of length [`IpoptSolverGetKktDim`]
261/// in the `x || s || y_c || y_d || z_l || z_u || v_l || v_u` packing.
262///
263/// Returns `TRUE` on success, `FALSE` if no factor is held or the
264/// back-solve fails.
265///
266/// # Safety
267///
268/// `rhs` and `lhs` must point to buffers at least
269/// [`IpoptSolverGetKktDim`] doubles long.
270#[no_mangle]
271pub unsafe extern "C" fn IpoptSolverKktSolve(
272    solver: IpoptSolver,
273    rhs: *const Number,
274    lhs: *mut Number,
275) -> Bool {
276    if solver.is_null() || rhs.is_null() || lhs.is_null() {
277        return FALSE;
278    }
279    let info = &*solver;
280    let Some(s) = info.session.as_ref() else {
281        return FALSE;
282    };
283    let Some(dim) = s.kkt_dim() else {
284        return FALSE;
285    };
286    let rhs_slice = std::slice::from_raw_parts(rhs, dim);
287    let mut lhs_vec = vec![0.0; dim];
288    if s.kkt_solve(rhs_slice, &mut lhs_vec).is_err() {
289        return FALSE;
290    }
291    std::ptr::copy_nonoverlapping(lhs_vec.as_ptr(), lhs, dim);
292    TRUE
293}
294
295/// First-order parametric step `Δx ≈ ∂x*/∂p · Δp`. `pin_indices` is
296/// `n_pins` `Index` values (0-based indices into `g(x)`); `deltas` is
297/// the parameter perturbation `Δp` of the same length; `dx_out` is the
298/// `n`-long primal step output (length matches the problem's `n`).
299///
300/// Returns `TRUE` on success, `FALSE` if no converged factor, invalid
301/// indices, or the sensitivity computation fails.
302///
303/// # Safety
304///
305/// `pin_indices` and `deltas` must point to `n_pins` valid elements;
306/// `dx_out` must point to at least `n` `Number` slots (`n` from the
307/// underlying IpoptProblem).
308#[no_mangle]
309pub unsafe extern "C" fn IpoptSolverParametricStep(
310    solver: IpoptSolver,
311    n_pins: Index,
312    pin_indices: *const Index,
313    deltas: *const Number,
314    dx_out: *mut Number,
315) -> Bool {
316    if solver.is_null() || n_pins < 0 {
317        return FALSE;
318    }
319    if n_pins > 0 && (pin_indices.is_null() || deltas.is_null()) {
320        return FALSE;
321    }
322    if dx_out.is_null() {
323        return FALSE;
324    }
325    let info = &*solver;
326    let Some(s) = info.session.as_ref() else {
327        return FALSE;
328    };
329    let m = info.m;
330    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
331    let mut pins = Vec::with_capacity(n_pins as usize);
332    for &i in pins_raw {
333        if i < 0 || i >= m {
334            return FALSE;
335        }
336        pins.push(i as pounce_common::types::Index);
337    }
338    let deltas_slice = std::slice::from_raw_parts(deltas, n_pins as usize);
339    let Ok(dx) = s.parametric_step(&pins, deltas_slice) else {
340        return FALSE;
341    };
342    std::ptr::copy_nonoverlapping(dx.as_ptr(), dx_out, dx.len());
343    TRUE
344}
345
346/// Reduced Hessian `H_R = obj_scal · B K⁻¹ Bᵀ` over the pinned rows.
347/// `hr_out` receives an `n_pins²`-long column-major dense matrix.
348///
349/// Returns `TRUE` on success, `FALSE` otherwise.
350///
351/// # Safety
352///
353/// `pin_indices` must point to `n_pins` valid elements; `hr_out` must
354/// point to at least `n_pins²` `Number` slots.
355#[no_mangle]
356pub unsafe extern "C" fn IpoptSolverReducedHessian(
357    solver: IpoptSolver,
358    n_pins: Index,
359    pin_indices: *const Index,
360    obj_scal: Number,
361    hr_out: *mut Number,
362) -> Bool {
363    if solver.is_null() || n_pins < 0 || hr_out.is_null() {
364        return FALSE;
365    }
366    if n_pins > 0 && pin_indices.is_null() {
367        return FALSE;
368    }
369    let info = &*solver;
370    let Some(s) = info.session.as_ref() else {
371        return FALSE;
372    };
373    let m = info.m;
374    let pins_raw = std::slice::from_raw_parts(pin_indices, n_pins as usize);
375    let mut pins = Vec::with_capacity(n_pins as usize);
376    for &i in pins_raw {
377        if i < 0 || i >= m {
378            return FALSE;
379        }
380        pins.push(i as pounce_common::types::Index);
381    }
382    let Ok(hr) = s.compute_reduced_hessian(&pins, obj_scal) else {
383        return FALSE;
384    };
385    std::ptr::copy_nonoverlapping(hr.as_ptr(), hr_out, hr.len());
386    TRUE
387}