Skip to main content

pounce_algorithm/eq_mult/
least_square.rs

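//! Least-squares multiplier estimate — port of
//! `Algorithm/IpLeastSquareMults.{hpp,cpp}`. Solves the augmented system
//! with a zero Hessian block (W = 0) to obtain initial equality-constraint
//! multipliers `y_c`/`y_d`.
//!
//! The system, with `delta_x = delta_s = 1.0` and all other perturbations
//! and weights set to zero (matching upstream `IpLeastSquareMults.cpp:60`):
//!
//! ```text
//!   [ I    0   J_c^T  J_d^T ] [dx ]   [ −∇f + Pₗ z_L − Pᵤ z_U ]
//!   [ 0    I    0      −I   ] [ds ] = [    Pₗ v_L − Pᵤ v_U    ]
//!   [ J_c  0    0       0   ] [dyc]   [          0            ]
//!   [ J_d −I    0       0   ] [dyd]   [          0            ]
//! ```
//!
//! Writing `r_x`, `r_s` for the first two right-hand-side blocks and
//! eliminating `dx = r_x − J_cᵀ dyc − J_dᵀ dyd` and `ds = r_s + dyd`, the
//! remaining two rows are the normal equations of the linear least-squares
//! problem `min ‖r_x − J_cᵀ dyc − J_dᵀ dyd‖² + ‖r_s + dyd‖²`.
//!
//! The sign convention follows `IpLeastSquareMults.cpp:54-61`. `dyc` and
//! `dyd` are the least-squares estimates we keep as `y_c` and `y_d`; `dx`
//! and `ds` are discarded.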
18
19use crate::eq_mult::r#trait::EqMultCalculator;
20use crate::ipopt_cq::IpoptCqHandle;
21use crate::ipopt_data::IpoptDataHandle;
22use crate::ipopt_nlp::IpoptNlp;
23use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
24use pounce_linalg::Vector;
25use pounce_linsol::ESymSolverStatus;
26use std::cell::RefCell;
27use std::rc::Rc;
28
/// Least-squares estimator for the equality-constraint multipliers
/// `y_c` / `y_d`.
///
/// The type carries no state: every input is read from the handles passed
/// to [`EqMultCalculator::calculate_y_eq`].
#[derive(Debug, Clone, Copy)]
pub struct LeastSquareMults;

impl LeastSquareMults {
    /// Creates the (stateless) least-squares multiplier calculator.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for LeastSquareMults {
    /// Equivalent to [`LeastSquareMults::new`].
    fn default() -> Self {
        Self::new()
    }
}
42
43impl EqMultCalculator for LeastSquareMults {
44    fn calculate_y_eq(
45        &mut self,
46        data: &IpoptDataHandle,
47        cq: &IpoptCqHandle,
48        nlp: &Rc<RefCell<dyn IpoptNlp>>,
49        aug_solver: &mut dyn AugSystemSolver,
50        y_c: &mut dyn Vector,
51        y_d: &mut dyn Vector,
52    ) -> bool {
53        let curr = match data.borrow().curr.clone() {
54            Some(c) => c,
55            None => return false,
56        };
57
58        // Pull NLP-evaluated quantities first so the `nlp.borrow_mut()`
59        // inside CQ's eval helpers can complete before we take the
60        // shared `nlp.borrow()` for the bound-selection matrices.
61        let cq_ref = cq.borrow();
62        let grad_f = cq_ref.curr_grad_f();
63        let j_c = cq_ref.curr_jac_c();
64        let j_d = cq_ref.curr_jac_d();
65        // Upstream `IpLeastSquareMults.cpp:80` passes a `zeroW` SymMatrix
66        // (same sparsity as the real Hessian) with `W_factor=0.0`. This
67        // ensures `StdAugSystemSolver` pins its triplet structure with
68        // the W slots present, so subsequent calls (with the actual
69        // Hessian) write into those slots rather than skipping them.
70        let zero_w = cq_ref.curr_exact_hessian();
71        drop(cq_ref);
72
73        let nlp_ref = nlp.borrow();
74
75        // rhs_x = −∇f + Pₗ z_L − Pᵤ z_U  (mirrors
76        // `IpLeastSquareMults.cpp:54-57` exactly).
77        let mut rhs_x = grad_f.make_new();
78        rhs_x.copy(&*grad_f);
79        nlp_ref
80            .px_l()
81            .mult_vector(1.0, &*curr.z_l, -1.0, &mut *rhs_x);
82        nlp_ref
83            .px_u()
84            .mult_vector(-1.0, &*curr.z_u, 1.0, &mut *rhs_x);
85
86        // rhs_s = Pₗ v_L − Pᵤ v_U  (zero-init then mult; mirrors
87        // `IpLeastSquareMults.cpp:60-61`).
88        let mut rhs_s = curr.s.make_new();
89        nlp_ref
90            .pd_l()
91            .mult_vector(1.0, &*curr.v_l, 0.0, &mut *rhs_s);
92        nlp_ref
93            .pd_u()
94            .mult_vector(-1.0, &*curr.v_u, 1.0, &mut *rhs_s);
95
96        // rhs_c = 0, rhs_d = 0.
97        let mut rhs_c = curr.y_c.make_new();
98        rhs_c.set(0.0);
99        let mut rhs_d = curr.y_d.make_new();
100        rhs_d.set(0.0);
101
102        // sol_x, sol_s scratch (discarded after solve).
103        let mut sol_x = rhs_x.make_new();
104        let mut sol_s = rhs_s.make_new();
105
106        let coeffs = AugSysCoeffs {
107            w: Some(&*zero_w),
108            w_factor: 0.0,
109            d_x: None,
110            delta_x: 1.0,
111            d_s: None,
112            delta_s: 1.0,
113            j_c: &*j_c,
114            d_c: None,
115            delta_c: 0.0,
116            j_d: &*j_d,
117            d_d: None,
118            delta_d: 0.0,
119        };
120        let aug_rhs = AugSysRhs {
121            rhs_x: &*rhs_x,
122            rhs_s: &*rhs_s,
123            rhs_c: &*rhs_c,
124            rhs_d: &*rhs_d,
125        };
126        let mut sol = AugSysSol {
127            sol_x: &mut *sol_x,
128            sol_s: &mut *sol_s,
129            sol_c: y_c,
130            sol_d: y_d,
131        };
132
133        let num_eq = aug_rhs.rhs_c.dim() + aug_rhs.rhs_d.dim();
134        let check_neg = aug_solver.provides_inertia();
135        let status = aug_solver.solve(&coeffs, &aug_rhs, &mut sol, check_neg, num_eq);
136        matches!(status, ESymSolverStatus::Success)
137    }
138}