Skip to main content

pounce_algorithm/kkt/
pd_search_dir_calc.rs

1//! PD search-direction calculator — port of
2//! `Algorithm/IpPDSearchDirCalc.{hpp,cpp}`.
3//!
4//! Builds the right-hand side from the current iterate's KKT
5//! residuals (gradient of Lagrangian, constraint values, relaxed
6//! complementarities), optionally adds a Mehrotra corrector, then
7//! calls `PdFullSpaceSolver::solve` to produce the search direction
8//! `delta`.
9//!
10//! Two RHS modes:
11//! * standard: z-blocks are the relaxed complementarities
12//!   `s_L · z_L − μ`, …
13//! * Mehrotra: z-blocks include the second-order term
14//!   `(P_L^T Δx_aff) · Δz_aff_L + (s_L · z_L − μ)`.
15
16use crate::ipopt_cq::IpoptCqHandle;
17use crate::ipopt_data::IpoptDataHandle;
18use crate::ipopt_nlp::IpoptNlp;
19use crate::iterates_vector::{IteratesVector, IteratesVectorMut};
20use crate::kkt::pd_full_space_solver::PdFullSpaceSolver;
21use crate::kkt::search_dir_calc::SearchDirCalculator;
22use pounce_common::types::Number;
23use std::cell::{RefCell, RefMut};
24use std::rc::Rc;
25
26pub struct PdSearchDirCalc {
27    /// Owned via `Rc<RefCell<…>>` so external callers (e.g. the
28    /// post-converged sensitivity callback) can retain a cloned handle
29    /// past the IPM call. During the IPM loop refcount is 1 and every
30    /// internal call goes through `borrow_mut`; the runtime borrow
31    /// check costs are negligible relative to the linear solve.
32    pd_solver: Rc<RefCell<PdFullSpaceSolver>>,
33    /// Skip the residual check on the search direction. Mirrors
34    /// `fast_step_computation` (default false).
35    pub fast_step_computation: bool,
36    /// Mehrotra-style predictor-corrector step. Mirrors
37    /// `mehrotra_algorithm` (default false in v1.0; flipped on by
38    /// the adaptive-mu wiring in Phase 10).
39    pub mehrotra_algorithm: bool,
40}
41
42impl PdSearchDirCalc {
43    pub fn new(pd_solver: PdFullSpaceSolver) -> Self {
44        Self {
45            pd_solver: Rc::new(RefCell::new(pd_solver)),
46            fast_step_computation: false,
47            mehrotra_algorithm: false,
48        }
49    }
50
51    /// Clone the shared handle to the PD solver. Used by the
52    /// post-converged sensitivity callback to retain a factor handle
53    /// past the IPM call.
54    pub fn pd_solver_rc(&self) -> Rc<RefCell<PdFullSpaceSolver>> {
55        Rc::clone(&self.pd_solver)
56    }
57
58    /// Borrow the PD solver mutably. Caller is responsible for not
59    /// holding two mutable borrows at once (single-thread, single-
60    /// borrow access pattern — matches every existing call site).
61    pub fn pd_solver_mut(&self) -> RefMut<'_, PdFullSpaceSolver> {
62        self.pd_solver.borrow_mut()
63    }
64
65    /// Compute the search direction and write it back into
66    /// `data.delta`. Returns `false` if the underlying linear solve
67    /// fails. Mirrors `PDSearchDirCalculator::ComputeSearchDirection`.
68    pub fn compute_search_direction(
69        &mut self,
70        data: &IpoptDataHandle,
71        cq: &IpoptCqHandle,
72        nlp: &Rc<RefCell<dyn IpoptNlp>>,
73    ) -> bool {
74        let improve_solution = data.borrow().delta.is_some();
75
76        if improve_solution && self.fast_step_computation {
77            return true;
78        }
79
80        let curr = {
81            let d = data.borrow();
82            d.curr
83                .clone()
84                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
85        };
86
87        // Build RHS.
88        let mut rhs = curr.make_new_zeroed();
89        {
90            let cq_ref = cq.borrow();
91            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
92            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
93            rhs.y_c.copy(&*cq_ref.curr_c());
94            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
95        }
96
97        let nbounds = {
98            let n = nlp.borrow();
99            n.x_l().dim() + n.x_u().dim() + n.d_l().dim() + n.d_u().dim()
100        };
101
102        if nbounds > 0 && self.mehrotra_algorithm {
103            let delta_aff = {
104                let d = data.borrow();
105                d.delta_aff
106                    .clone()
107                    .unwrap_or_else(|| panic!("PdSearchDirCalc: delta_aff missing for Mehrotra"))
108            };
109            self.fill_mehrotra_z_blocks(&delta_aff, cq, nlp, &mut rhs);
110        } else {
111            let cq_ref = cq.borrow();
112            rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
113            rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
114            rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
115            rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
116        }
117
118        let frozen_rhs = rhs.freeze();
119
120        // Allocate the search direction. If we are improving an
121        // existing one, seed it with `−delta` (per upstream).
122        let mut delta = frozen_rhs.make_new_zeroed();
123        if improve_solution {
124            let prev = {
125                let d = data.borrow();
126                let Some(p) = d.delta.clone() else {
127                    unreachable!("PdSearchDirCalc: delta cleared between is_some() and clone()")
128                };
129                p
130            };
131            delta.add_one_vector(-1.0, &prev, 0.0);
132        }
133
134        let allow_inexact = self.fast_step_computation;
135        let ok = self.pd_solver.borrow_mut().solve(
136            data,
137            cq,
138            nlp,
139            -1.0,
140            0.0,
141            &frozen_rhs,
142            &mut delta,
143            allow_inexact,
144            improve_solution,
145        );
146
147        if ok {
148            data.borrow_mut().set_delta(delta.freeze());
149        }
150        ok
151    }
152
153    /// Affine (predictor) step — port of upstream's
154    /// `IpAdaptiveMuUpdate::ComputeMuMehrotra` predictor solve. Builds
155    /// the same RHS as [`Self::compute_search_direction`] except the
156    /// z-blocks use the *unrelaxed* complementarity `s · z`
157    /// (μ-target = 0) so the resulting step targets the affine-scaling
158    /// system. The solution is stored in `data.delta_aff` for
159    /// consumption by the Probing / Quality-Function oracles.
160    ///
161    /// Returns `false` if the linear solve fails.
162    pub fn compute_affine_step(
163        &mut self,
164        data: &IpoptDataHandle,
165        cq: &IpoptCqHandle,
166        nlp: &Rc<RefCell<dyn IpoptNlp>>,
167    ) -> bool {
168        let curr = {
169            let d = data.borrow();
170            d.curr
171                .clone()
172                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
173        };
174
175        let mut rhs = curr.make_new_zeroed();
176        {
177            let cq_ref = cq.borrow();
178            // Upstream `IpQualityFunctionMuOracle.cpp:193-200` uses the
179            // *plain* `curr_grad_lag_{x,s}` here, NOT the damped variant.
180            // The `μ·κ_d·(P_L − P_U)` damping enters the main-step RHS
181            // only — for the affine (predictor) RHS upstream wants the
182            // gradient at μ=0.
183            rhs.x.copy(&*cq_ref.curr_grad_lag_x());
184            rhs.s.copy(&*cq_ref.curr_grad_lag_s());
185            rhs.y_c.copy(&*cq_ref.curr_c());
186            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
187            // Affine RHS: complementarity blocks use `s·z` (μ=0),
188            // not `s·z − μ`.
189            rhs.z_l.copy(&*cq_ref.curr_compl_x_l());
190            rhs.z_u.copy(&*cq_ref.curr_compl_x_u());
191            rhs.v_l.copy(&*cq_ref.curr_compl_s_l());
192            rhs.v_u.copy(&*cq_ref.curr_compl_s_u());
193        }
194
195        let frozen_rhs = rhs.freeze();
196        let mut delta_aff = frozen_rhs.make_new_zeroed();
197
198        // Upstream `IpQualityFunctionMuOracle.cpp:208` passes
199        // `allow_inexact = true`. Pounce keeps full iterative
200        // refinement here (allow_inexact=false): an earlier attempt to
201        // set this to `true` regressed TRO3X3 from Solve_Succeeded to
202        // Infeasible_Problem_Detected, because pounce's IR-driven
203        // `increase_quality()` cascade produces materially different
204        // steps than upstream's single-shot MA57. Leaving as-is until
205        // the MA57 backend lands in Phase 4.
206        let ok = self.pd_solver.borrow_mut().solve(
207            data,
208            cq,
209            nlp,
210            -1.0,
211            0.0,
212            &frozen_rhs,
213            &mut delta_aff,
214            false,
215            false,
216        );
217
218        if ok {
219            data.borrow_mut().set_delta_aff(delta_aff.freeze());
220        }
221        ok
222    }
223
224    /// Pure centering step — port of upstream
225    /// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 218-247. RHS
226    /// is `(0, 0, 0, 0, μ̄·1, μ̄·1, μ̄·1, μ̄·1)` with μ̄ = `curr_avrg_compl`.
227    /// Solution stored on `data.delta_cen` for the quality-function
228    /// oracle's σ-bracket search.
229    ///
230    /// Returns `false` if the linear solve fails.
231    pub fn compute_centering_step(
232        &mut self,
233        data: &IpoptDataHandle,
234        cq: &IpoptCqHandle,
235        nlp: &Rc<RefCell<dyn IpoptNlp>>,
236    ) -> bool {
237        let curr = {
238            let d = data.borrow();
239            d.curr
240                .clone()
241                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
242        };
243        let avrg_compl = cq.borrow().curr_avrg_compl();
244
245        let mut rhs = curr.make_new_zeroed();
246        // x/s blocks: -avrg_compl · grad_kappa_times_damping_{x,s}, per
247        // upstream IpQualityFunctionMuOracle.cpp:229-230. With kappa_d=0
248        // (the default) these are zero, but kappa_d=1e-5 (default) makes
249        // them nonzero on damped components and the centering direction
250        // depends on them.
251        {
252            let cq_ref = cq.borrow();
253            rhs.x
254                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_x(), 0.0);
255            rhs.s
256                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_s(), 0.0);
257        }
258        rhs.y_c.set(0.0);
259        rhs.y_d.set(0.0);
260        rhs.z_l.set(avrg_compl);
261        rhs.z_u.set(avrg_compl);
262        rhs.v_l.set(avrg_compl);
263        rhs.v_u.set(avrg_compl);
264
265        let frozen_rhs = rhs.freeze();
266        let mut delta_cen = frozen_rhs.make_new_zeroed();
267
268        // Upstream `IpQualityFunctionMuOracle.cpp:243` passes
269        // `allow_inexact = true`. Same caveat as `compute_affine_step`
270        // — flipping this on regresses TRO3X3 because the FERAL-backed
271        // IR cascade differs from MA57's single-shot. Defer until MA57.
272        let ok = self.pd_solver.borrow_mut().solve(
273            data,
274            cq,
275            nlp,
276            1.0,
277            0.0,
278            &frozen_rhs,
279            &mut delta_cen,
280            false,
281            false,
282        );
283
284        if ok {
285            data.borrow_mut().set_delta_cen(delta_cen.freeze());
286        }
287        ok
288    }
289
290    /// Solve the second-order-correction (SOC) linear system used by
291    /// the filter line search to recover full-step acceptability when
292    /// the Newton step grows the constraint violation. Mirrors the RHS
293    /// assembly + `pd_solver_->Solve(-1.0, 0.0, ...)` block in upstream
294    /// `IpFilterLSAcceptor.cpp:577-608`.
295    ///
296    /// The caller supplies the SOC right-hand sides for the equality and
297    /// inequality blocks (`c_soc`, `dms_soc`); this method assembles the
298    /// remaining six blocks using the current iterate's KKT residuals
299    /// and returns the resulting `delta_soc`. `soc_method = 0` matches
300    /// upstream's default (gradient blocks unscaled); `soc_method = 1`
301    /// scales the gradient blocks by `alpha_primal_soc` to reuse a
302    /// previously-tried correction.
303    pub fn compute_soc_step(
304        &mut self,
305        data: &IpoptDataHandle,
306        cq: &IpoptCqHandle,
307        nlp: &Rc<RefCell<dyn IpoptNlp>>,
308        c_soc: &dyn pounce_linalg::Vector,
309        dms_soc: &dyn pounce_linalg::Vector,
310        alpha_primal_soc: Number,
311        soc_method: i32,
312    ) -> Option<IteratesVector> {
313        let curr = {
314            let d = data.borrow();
315            d.curr
316                .clone()
317                .unwrap_or_else(|| panic!("PdSearchDirCalc::compute_soc_step: curr is unset"))
318        };
319        let mut rhs = curr.make_new_zeroed();
320        {
321            let cq_ref = cq.borrow();
322            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
323            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
324            if soc_method == 1 {
325                rhs.x.scal(alpha_primal_soc);
326                rhs.s.scal(alpha_primal_soc);
327            }
328            rhs.y_c.copy(c_soc);
329            rhs.y_d.copy(dms_soc);
330            rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
331            rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
332            rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
333            rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
334        }
335        let frozen_rhs = rhs.freeze();
336        let mut delta_soc = frozen_rhs.make_new_zeroed();
337        let ok = self.pd_solver.borrow_mut().solve(
338            data,
339            cq,
340            nlp,
341            -1.0,
342            0.0,
343            &frozen_rhs,
344            &mut delta_soc,
345            false,
346            false,
347        );
348        if ok {
349            Some(delta_soc.freeze())
350        } else {
351            None
352        }
353    }
354
355    /// Mehrotra z-block:
356    ///   tmp_zL =  P_L^T · Δx_aff;  tmp_zL ⊙= Δz_aff_L;  tmp_zL += relaxed_compl_x_L
357    /// Symmetric for the U / s blocks.
358    fn fill_mehrotra_z_blocks(
359        &self,
360        delta_aff: &IteratesVector,
361        cq: &IpoptCqHandle,
362        nlp: &Rc<RefCell<dyn IpoptNlp>>,
363        rhs: &mut IteratesVectorMut,
364    ) {
365        let n = nlp.borrow();
366        let cq_ref = cq.borrow();
367
368        // z_L
369        n.px_l()
370            .trans_mult_vector(1.0, &*delta_aff.x, 0.0, &mut *rhs.z_l);
371        rhs.z_l.element_wise_multiply(&*delta_aff.z_l);
372        rhs.z_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_l());
373
374        // z_U
375        n.px_u()
376            .trans_mult_vector(-1.0, &*delta_aff.x, 0.0, &mut *rhs.z_u);
377        rhs.z_u.element_wise_multiply(&*delta_aff.z_u);
378        rhs.z_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_u());
379
380        // v_L
381        n.pd_l()
382            .trans_mult_vector(1.0, &*delta_aff.s, 0.0, &mut *rhs.v_l);
383        rhs.v_l.element_wise_multiply(&*delta_aff.v_l);
384        rhs.v_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_l());
385
386        // v_U
387        n.pd_u()
388            .trans_mult_vector(-1.0, &*delta_aff.s, 0.0, &mut *rhs.v_u);
389        rhs.v_u.element_wise_multiply(&*delta_aff.v_u);
390        rhs.v_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_u());
391    }
392}
393
394impl SearchDirCalculator for PdSearchDirCalc {}
395
396// --- per-element helpers retained from the Phase-6 stub for
397// downstream callers (CG-penalty path, restoration RHS unit tests).
398// Not used by `compute_search_direction` itself.
399
400pub fn mehrotra_corrector_lower(
401    delta_aff_x_lo: Number,
402    delta_aff_z: Number,
403    relaxed_compl: Number,
404) -> Number {
405    delta_aff_x_lo * delta_aff_z + relaxed_compl
406}
407
408pub fn mehrotra_corrector_upper(
409    delta_aff_x_up: Number,
410    delta_aff_z: Number,
411    relaxed_compl: Number,
412) -> Number {
413    -delta_aff_x_up * delta_aff_z + relaxed_compl
414}
415
416pub fn relaxed_complementarity(x: Number, z: Number, mu: Number) -> Number {
417    x * z - mu
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn relaxed_compl_at_central_path_is_zero() {
        // x·z = 2.0 · 0.5 = 1.0 = μ, so the residual vanishes.
        let residual = relaxed_complementarity(2.0, 0.5, 1.0);
        assert_eq!(residual, 0.0);
    }

    #[test]
    fn mehrotra_lower_combines_linearly() {
        // 1.0 · 2.0 + 0.5
        let corrected = mehrotra_corrector_lower(1.0, 2.0, 0.5);
        assert_eq!(corrected, 2.5);
    }

    #[test]
    fn mehrotra_upper_negates_dx() {
        // −(1.0 · 2.0) + 0.5
        let corrected = mehrotra_corrector_upper(1.0, 2.0, 0.5);
        assert_eq!(corrected, -1.5);
    }
}