// pounce_algorithm/kkt/pd_search_dir_calc.rs

//! PD search-direction calculator — port of
//! `Algorithm/IpPDSearchDirCalc.{hpp,cpp}`.
//!
//! Builds the right-hand side from the current iterate's KKT
//! residuals (gradient of the Lagrangian, constraint values, relaxed
//! complementarities), optionally adds a Mehrotra corrector, then
//! calls `PdFullSpaceSolver::solve` to produce the search direction
//! `delta`.
//!
//! Two RHS modes:
//! * standard: z-blocks are the relaxed complementarities
//!   `s_L · z_L − μ`, …
//! * Mehrotra: z-blocks include the second-order term
//!   `(P_L^T Δx_aff) · Δz_aff_L + (s_L · z_L − μ)` (with the sign of
//!   the projected step flipped on the upper-bound blocks).
//!
//! The same solver handle also backs the affine (predictor), pure
//! centering, and second-order-correction solves used by the
//! adaptive-μ oracles and the filter line search.
15
16use crate::ipopt_cq::IpoptCqHandle;
17use crate::ipopt_data::IpoptDataHandle;
18use crate::ipopt_nlp::IpoptNlp;
19use crate::iterates_vector::{IteratesVector, IteratesVectorMut};
20use crate::kkt::pd_full_space_solver::PdFullSpaceSolver;
21use crate::kkt::search_dir_calc::SearchDirCalculator;
22use pounce_common::types::Number;
23use std::cell::{RefCell, RefMut};
24use std::rc::Rc;
25
26pub struct PdSearchDirCalc {
27    /// Owned via `Rc<RefCell<…>>` so external callers (e.g. the
28    /// post-converged sensitivity callback) can retain a cloned handle
29    /// past the IPM call. During the IPM loop refcount is 1 and every
30    /// internal call goes through `borrow_mut`; the runtime borrow
31    /// check costs are negligible relative to the linear solve.
32    pd_solver: Rc<RefCell<PdFullSpaceSolver>>,
33    /// Skip the residual check on the search direction. Mirrors
34    /// `fast_step_computation` (default false).
35    pub fast_step_computation: bool,
36    /// Mehrotra-style predictor-corrector step. Mirrors
37    /// `mehrotra_algorithm` (default false in v1.0; flipped on by
38    /// the adaptive-mu wiring in Phase 10).
39    pub mehrotra_algorithm: bool,
40}
41
42impl PdSearchDirCalc {
43    pub fn new(pd_solver: PdFullSpaceSolver) -> Self {
44        Self {
45            pd_solver: Rc::new(RefCell::new(pd_solver)),
46            fast_step_computation: false,
47            mehrotra_algorithm: false,
48        }
49    }
50
51    /// Clone the shared handle to the PD solver. Used by the
52    /// post-converged sensitivity callback to retain a factor handle
53    /// past the IPM call.
54    pub fn pd_solver_rc(&self) -> Rc<RefCell<PdFullSpaceSolver>> {
55        Rc::clone(&self.pd_solver)
56    }
57
58    /// Borrow the PD solver mutably. Caller is responsible for not
59    /// holding two mutable borrows at once (single-thread, single-
60    /// borrow access pattern — matches every existing call site).
61    pub fn pd_solver_mut(&self) -> RefMut<'_, PdFullSpaceSolver> {
62        self.pd_solver.borrow_mut()
63    }
64
65    /// Compute the search direction and write it back into
66    /// `data.delta`. Returns `false` if the underlying linear solve
67    /// fails. Mirrors `PDSearchDirCalculator::ComputeSearchDirection`.
68    pub fn compute_search_direction(
69        &mut self,
70        data: &IpoptDataHandle,
71        cq: &IpoptCqHandle,
72        nlp: &Rc<RefCell<dyn IpoptNlp>>,
73    ) -> bool {
74        let improve_solution = data.borrow().delta.is_some();
75
76        if improve_solution && self.fast_step_computation {
77            return true;
78        }
79
80        let curr = {
81            let d = data.borrow();
82            d.curr
83                .clone()
84                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
85        };
86
87        // Build RHS.
88        let mut rhs = curr.make_new_zeroed();
89        {
90            let cq_ref = cq.borrow();
91            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
92            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
93            rhs.y_c.copy(&*cq_ref.curr_c());
94            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
95        }
96
97        let nbounds = {
98            let n = nlp.borrow();
99            n.x_l().dim() + n.x_u().dim() + n.d_l().dim() + n.d_u().dim()
100        };
101
102        if nbounds > 0 && self.mehrotra_algorithm {
103            let delta_aff = {
104                let d = data.borrow();
105                d.delta_aff
106                    .clone()
107                    .unwrap_or_else(|| panic!("PdSearchDirCalc: delta_aff missing for Mehrotra"))
108            };
109            self.fill_mehrotra_z_blocks(&delta_aff, cq, nlp, &mut rhs);
110        } else {
111            let cq_ref = cq.borrow();
112            rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
113            rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
114            rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
115            rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
116        }
117
118        let frozen_rhs = rhs.freeze();
119
120        // Allocate the search direction. If we are improving an
121        // existing one, seed it with `−delta` (per upstream).
122        let mut delta = frozen_rhs.make_new_zeroed();
123        if improve_solution {
124            let prev = {
125                let d = data.borrow();
126                let Some(p) = d.delta.clone() else {
127                    unreachable!("PdSearchDirCalc: delta cleared between is_some() and clone()")
128                };
129                p
130            };
131            delta.add_one_vector(-1.0, &prev, 0.0);
132        }
133
134        let allow_inexact = self.fast_step_computation;
135        let ok = self.pd_solver.borrow_mut().solve(
136            data,
137            cq,
138            nlp,
139            -1.0,
140            0.0,
141            &frozen_rhs,
142            &mut delta,
143            allow_inexact,
144            improve_solution,
145        );
146
147        if ok {
148            data.borrow_mut().set_delta(delta.freeze());
149        }
150        ok
151    }
152
153    /// Affine (predictor) step — port of upstream's
154    /// `IpAdaptiveMuUpdate::ComputeMuMehrotra` predictor solve. Builds
155    /// the same RHS as [`Self::compute_search_direction`] except the
156    /// z-blocks use the *unrelaxed* complementarity `s · z`
157    /// (μ-target = 0) so the resulting step targets the affine-scaling
158    /// system. The solution is stored in `data.delta_aff` for
159    /// consumption by the Probing / Quality-Function oracles.
160    ///
161    /// Returns `false` if the linear solve fails.
162    pub fn compute_affine_step(
163        &mut self,
164        data: &IpoptDataHandle,
165        cq: &IpoptCqHandle,
166        nlp: &Rc<RefCell<dyn IpoptNlp>>,
167    ) -> bool {
168        let curr = {
169            let d = data.borrow();
170            d.curr
171                .clone()
172                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
173        };
174
175        let mut rhs = curr.make_new_zeroed();
176        {
177            let cq_ref = cq.borrow();
178            // Upstream `IpQualityFunctionMuOracle.cpp:193-200` uses the
179            // *plain* `curr_grad_lag_{x,s}` here, NOT the damped variant.
180            // The `μ·κ_d·(P_L − P_U)` damping enters the main-step RHS
181            // only — for the affine (predictor) RHS upstream wants the
182            // gradient at μ=0.
183            rhs.x.copy(&*cq_ref.curr_grad_lag_x());
184            rhs.s.copy(&*cq_ref.curr_grad_lag_s());
185            rhs.y_c.copy(&*cq_ref.curr_c());
186            rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
187            // Affine RHS: complementarity blocks use `s·z` (μ=0),
188            // not `s·z − μ`.
189            rhs.z_l.copy(&*cq_ref.curr_compl_x_l());
190            rhs.z_u.copy(&*cq_ref.curr_compl_x_u());
191            rhs.v_l.copy(&*cq_ref.curr_compl_s_l());
192            rhs.v_u.copy(&*cq_ref.curr_compl_s_u());
193        }
194
195        let frozen_rhs = rhs.freeze();
196        let mut delta_aff = frozen_rhs.make_new_zeroed();
197
198        // Upstream `IpQualityFunctionMuOracle.cpp:208` and
199        // `IpProbingMuOracle.cpp:79` both pass `allow_inexact = true`
200        // on the affine (predictor) solve: "we allow a somewhat
201        // inexact solution here ... iterative refinement will be done
202        // after mu is known". Skipping IR on the predictor saves
203        // ~5-10x per-iter linsol work on Mehrotra runs.
204        let ok = self.pd_solver.borrow_mut().solve(
205            data,
206            cq,
207            nlp,
208            -1.0,
209            0.0,
210            &frozen_rhs,
211            &mut delta_aff,
212            true,
213            false,
214        );
215
216        if ok {
217            data.borrow_mut().set_delta_aff(delta_aff.freeze());
218        }
219        ok
220    }
221
222    /// Pure centering step — port of upstream
223    /// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 218-247. RHS
224    /// is `(0, 0, 0, 0, μ̄·1, μ̄·1, μ̄·1, μ̄·1)` with μ̄ = `curr_avrg_compl`.
225    /// Solution stored on `data.delta_cen` for the quality-function
226    /// oracle's σ-bracket search.
227    ///
228    /// Returns `false` if the linear solve fails.
229    pub fn compute_centering_step(
230        &mut self,
231        data: &IpoptDataHandle,
232        cq: &IpoptCqHandle,
233        nlp: &Rc<RefCell<dyn IpoptNlp>>,
234    ) -> bool {
235        let curr = {
236            let d = data.borrow();
237            d.curr
238                .clone()
239                .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
240        };
241        let avrg_compl = cq.borrow().curr_avrg_compl();
242
243        let mut rhs = curr.make_new_zeroed();
244        // x/s blocks: -avrg_compl · grad_kappa_times_damping_{x,s}, per
245        // upstream IpQualityFunctionMuOracle.cpp:229-230. With kappa_d=0
246        // (the default) these are zero, but kappa_d=1e-5 (default) makes
247        // them nonzero on damped components and the centering direction
248        // depends on them.
249        {
250            let cq_ref = cq.borrow();
251            rhs.x
252                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_x(), 0.0);
253            rhs.s
254                .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_s(), 0.0);
255        }
256        rhs.y_c.set(0.0);
257        rhs.y_d.set(0.0);
258        rhs.z_l.set(avrg_compl);
259        rhs.z_u.set(avrg_compl);
260        rhs.v_l.set(avrg_compl);
261        rhs.v_u.set(avrg_compl);
262
263        let frozen_rhs = rhs.freeze();
264        let mut delta_cen = frozen_rhs.make_new_zeroed();
265
266        // Match upstream `IpQualityFunctionMuOracle.cpp:243`: IR is
267        // deferred until mu is known, so we allow a somewhat inexact
268        // centering solve here.
269        let ok = self.pd_solver.borrow_mut().solve(
270            data,
271            cq,
272            nlp,
273            1.0,
274            0.0,
275            &frozen_rhs,
276            &mut delta_cen,
277            true,
278            false,
279        );
280
281        if ok {
282            data.borrow_mut().set_delta_cen(delta_cen.freeze());
283        }
284        ok
285    }
286
287    /// Solve the second-order-correction (SOC) linear system used by
288    /// the filter line search to recover full-step acceptability when
289    /// the Newton step grows the constraint violation. Mirrors the RHS
290    /// assembly + `pd_solver_->Solve(-1.0, 0.0, ...)` block in upstream
291    /// `IpFilterLSAcceptor.cpp:577-608`.
292    ///
293    /// The caller supplies the SOC right-hand sides for the equality and
294    /// inequality blocks (`c_soc`, `dms_soc`); this method assembles the
295    /// remaining six blocks using the current iterate's KKT residuals
296    /// and returns the resulting `delta_soc`. `soc_method = 0` matches
297    /// upstream's default (gradient blocks unscaled); `soc_method = 1`
298    /// scales the gradient blocks by `alpha_primal_soc` to reuse a
299    /// previously-tried correction.
300    pub fn compute_soc_step(
301        &mut self,
302        data: &IpoptDataHandle,
303        cq: &IpoptCqHandle,
304        nlp: &Rc<RefCell<dyn IpoptNlp>>,
305        c_soc: &dyn pounce_linalg::Vector,
306        dms_soc: &dyn pounce_linalg::Vector,
307        alpha_primal_soc: Number,
308        soc_method: i32,
309    ) -> Option<IteratesVector> {
310        let curr = {
311            let d = data.borrow();
312            d.curr
313                .clone()
314                .unwrap_or_else(|| panic!("PdSearchDirCalc::compute_soc_step: curr is unset"))
315        };
316        let mut rhs = curr.make_new_zeroed();
317        {
318            let cq_ref = cq.borrow();
319            rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
320            rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
321            if soc_method == 1 {
322                rhs.x.scal(alpha_primal_soc);
323                rhs.s.scal(alpha_primal_soc);
324            }
325            rhs.y_c.copy(c_soc);
326            rhs.y_d.copy(dms_soc);
327            rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
328            rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
329            rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
330            rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
331        }
332        let frozen_rhs = rhs.freeze();
333        let mut delta_soc = frozen_rhs.make_new_zeroed();
334        let ok = self.pd_solver.borrow_mut().solve(
335            data,
336            cq,
337            nlp,
338            -1.0,
339            0.0,
340            &frozen_rhs,
341            &mut delta_soc,
342            false,
343            false,
344        );
345        if ok {
346            Some(delta_soc.freeze())
347        } else {
348            None
349        }
350    }
351
352    /// Mehrotra z-block:
353    ///   tmp_zL =  P_L^T · Δx_aff;  tmp_zL ⊙= Δz_aff_L;  tmp_zL += relaxed_compl_x_L
354    /// Symmetric for the U / s blocks.
355    fn fill_mehrotra_z_blocks(
356        &self,
357        delta_aff: &IteratesVector,
358        cq: &IpoptCqHandle,
359        nlp: &Rc<RefCell<dyn IpoptNlp>>,
360        rhs: &mut IteratesVectorMut,
361    ) {
362        let n = nlp.borrow();
363        let cq_ref = cq.borrow();
364
365        // z_L
366        n.px_l()
367            .trans_mult_vector(1.0, &*delta_aff.x, 0.0, &mut *rhs.z_l);
368        rhs.z_l.element_wise_multiply(&*delta_aff.z_l);
369        rhs.z_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_l());
370
371        // z_U
372        n.px_u()
373            .trans_mult_vector(-1.0, &*delta_aff.x, 0.0, &mut *rhs.z_u);
374        rhs.z_u.element_wise_multiply(&*delta_aff.z_u);
375        rhs.z_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_u());
376
377        // v_L
378        n.pd_l()
379            .trans_mult_vector(1.0, &*delta_aff.s, 0.0, &mut *rhs.v_l);
380        rhs.v_l.element_wise_multiply(&*delta_aff.v_l);
381        rhs.v_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_l());
382
383        // v_U
384        n.pd_u()
385            .trans_mult_vector(-1.0, &*delta_aff.s, 0.0, &mut *rhs.v_u);
386        rhs.v_u.element_wise_multiply(&*delta_aff.v_u);
387        rhs.v_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_u());
388    }
389}
390
// Empty trait impl: no `SearchDirCalculator` items are overridden here
// (which compiles only because the trait has no required items). The
// search-direction logic lives in the inherent methods of
// `PdSearchDirCalc`.
impl SearchDirCalculator for PdSearchDirCalc {}
392
// --- Scalar (per-element) helpers retained from the Phase-6 stub for
// downstream callers (CG-penalty path, restoration RHS unit tests).
// `compute_search_direction` does not use them; it operates on whole
// vectors instead.
396
397pub fn mehrotra_corrector_lower(
398    delta_aff_x_lo: Number,
399    delta_aff_z: Number,
400    relaxed_compl: Number,
401) -> Number {
402    delta_aff_x_lo * delta_aff_z + relaxed_compl
403}
404
405pub fn mehrotra_corrector_upper(
406    delta_aff_x_up: Number,
407    delta_aff_z: Number,
408    relaxed_compl: Number,
409) -> Number {
410    -delta_aff_x_up * delta_aff_z + relaxed_compl
411}
412
413pub fn relaxed_complementarity(x: Number, z: Number, mu: Number) -> Number {
414    x * z - mu
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// On the central path `x·z = μ`, so the relaxed residual vanishes.
    #[test]
    fn relaxed_compl_at_central_path_is_zero() {
        let residual = relaxed_complementarity(2.0, 0.5, 1.0);
        assert_eq!(residual, 0.0);
    }

    /// Lower-bound corrector: `1.0 · 2.0 + 0.5`.
    #[test]
    fn mehrotra_lower_combines_linearly() {
        let expected = 2.5;
        assert_eq!(mehrotra_corrector_lower(1.0, 2.0, 0.5), expected);
    }

    /// Upper-bound corrector flips the step sign: `−(1.0 · 2.0) + 0.5`.
    #[test]
    fn mehrotra_upper_negates_dx() {
        let expected = -1.5;
        assert_eq!(mehrotra_corrector_upper(1.0, 2.0, 0.5), expected);
    }
}