Skip to main content

pounce_algorithm/sqp/
bfgs.rs

1//! Powell-damped BFGS Hessian approximation for SQP (Powell
2//! 1978, *Numerical Analysis Dundee 1977*). Used when
3//! `SqpOptions::hessian = DampedBfgs` — the QP subproblem's
4//! Hessian comes from this rank-2-updated matrix instead of
5//! `nlp.eval_hess_lag`.
6//!
7//! Powell's damping rule guarantees positive-definiteness of
8//! every iterate, so the QP solver doesn't have to engage
9//! inertia control to keep `∇²L`-quadratic models PD. The
10//! damping factor `θ ∈ [0, 1]` interpolates between the raw
11//! BFGS `y = ∇L_new − ∇L_old` and the conservative `B·s`:
12//!
13//! ```text
14//!     if sᵀy ≥ 0.2 · sᵀ B s :  θ = 1            (standard BFGS)
15//!     else                  :  θ = 0.8 · sᵀ B s / (sᵀ B s − sᵀy)
16//!     y_damp = θ y + (1 − θ) B s
17//!     B_new = B − (Bs · sᵀB) / (sᵀ B s)
18//!                  + (y_damp · y_dampᵀ) / (sᵀ y_damp)
19//! ```
20//!
//! Storage is dense `n × n`, keeping only the lower triangle in
//! row-major packed form; it is exposed to `pounce-qp` as a
//! fully-populated [`Triplet`] over that same triangle, i.e. every
//! entry with row ≥ col (1-based row/col).
24
25use crate::sqp::qp_assembly::Triplet;
26use pounce_common::types::{Index, Number};
27
28pub struct DampedBfgs {
29    n: usize,
30    /// Lower-triangle row-major storage:
31    /// `b[i*(i+1)/2 + j] = B[i, j]` for `i ≥ j`.
32    b: Vec<Number>,
33    /// Previous `x` and ∇L; updated at the end of each `update` call.
34    prev_x: Option<Vec<Number>>,
35    prev_grad_lag: Option<Vec<Number>>,
36    /// Pre-computed sparsity pattern for `as_triplet`. Fixed:
37    /// every (i, j) with `i ≥ j`. 1-based.
38    h_irow: Vec<Index>,
39    h_jcol: Vec<Index>,
40}
41
42impl DampedBfgs {
43    pub fn new(n: usize) -> Self {
44        let nz = n * (n + 1) / 2;
45        let mut b = vec![0.0; nz];
46        let mut h_irow = Vec::with_capacity(nz);
47        let mut h_jcol = Vec::with_capacity(nz);
48        for i in 0..n {
49            for j in 0..=i {
50                if i == j {
51                    b[i * (i + 1) / 2 + j] = 1.0;
52                }
53                h_irow.push((i + 1) as Index);
54                h_jcol.push((j + 1) as Index);
55            }
56        }
57        Self {
58            n,
59            b,
60            prev_x: None,
61            prev_grad_lag: None,
62            h_irow,
63            h_jcol,
64        }
65    }
66
67    /// Have we recorded a previous `(x, ∇L)`? `false` until the
68    /// first call to [`Self::update`].
69    pub fn has_prev(&self) -> bool {
70        self.prev_x.is_some()
71    }
72
73    fn idx(&self, i: usize, j: usize) -> usize {
74        debug_assert!(i < self.n && j < self.n);
75        let (lo, hi) = if i >= j { (j, i) } else { (i, j) };
76        hi * (hi + 1) / 2 + lo
77    }
78
79    fn get(&self, i: usize, j: usize) -> Number {
80        self.b[self.idx(i, j)]
81    }
82
83    fn set(&mut self, i: usize, j: usize, v: Number) {
84        let k = self.idx(i, j);
85        self.b[k] = v;
86    }
87
88    /// Apply the Powell-damped BFGS update from the previous
89    /// `(x_old, ∇L_old)` to the supplied `(x_new, ∇L_new)`. The
90    /// first call just stores the pair; subsequent calls also
91    /// modify `B`.
92    pub fn update(&mut self, x_new: &[Number], grad_lag_new: &[Number]) {
93        // Hard assert (PR #50 review S5): a length mismatch here
94        // would silently mis-compute the rank-2 update in release
95        // builds with debug_assert.
96        assert_eq!(x_new.len(), self.n, "BFGS::update: x_new.len() != n");
97        assert_eq!(
98            grad_lag_new.len(),
99            self.n,
100            "BFGS::update: grad_lag_new.len() != n"
101        );
102
103        if let (Some(prev_x), Some(prev_grad_lag)) = (self.prev_x.take(), self.prev_grad_lag.take())
104        {
105            let s: Vec<Number> = x_new
106                .iter()
107                .zip(prev_x.iter())
108                .map(|(a, b)| a - b)
109                .collect();
110            let y: Vec<Number> = grad_lag_new
111                .iter()
112                .zip(prev_grad_lag.iter())
113                .map(|(a, b)| a - b)
114                .collect();
115
116            // bs = B · s
117            let bs: Vec<Number> = (0..self.n)
118                .map(|i| (0..self.n).map(|j| self.get(i, j) * s[j]).sum())
119                .collect();
120
121            let s_bs: Number = s.iter().zip(bs.iter()).map(|(a, b)| a * b).sum();
122            let s_y: Number = s.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
123
124            // Powell damping.
125            let theta = if s_y >= 0.2 * s_bs {
126                1.0
127            } else if s_bs - s_y > 1e-14 {
128                0.8 * s_bs / (s_bs - s_y)
129            } else {
130                // Pathological — fall back to the unmodified
131                // identity update (no harm done).
132                1.0
133            };
134            let y_damp: Vec<Number> = y
135                .iter()
136                .zip(bs.iter())
137                .map(|(yi, bsi)| theta * yi + (1.0 - theta) * bsi)
138                .collect();
139            let s_y_damp: Number = s.iter().zip(y_damp.iter()).map(|(a, b)| a * b).sum();
140
141            if s_bs > 1e-14 && s_y_damp > 1e-14 {
142                for i in 0..self.n {
143                    for j in 0..=i {
144                        let new_val = self.get(i, j) - (bs[i] * bs[j]) / s_bs
145                            + (y_damp[i] * y_damp[j]) / s_y_damp;
146                        self.set(i, j, new_val);
147                    }
148                }
149            }
150        }
151
152        self.prev_x = Some(x_new.to_vec());
153        self.prev_grad_lag = Some(grad_lag_new.to_vec());
154    }
155
156    /// Produce the current B as a `Triplet` over the upper
157    /// triangle (1-based), ready to feed into `SqpQpData::build`.
158    pub fn as_triplet(&self) -> Triplet {
159        let mut vals = Vec::with_capacity(self.h_irow.len());
160        for i in 0..self.n {
161            for j in 0..=i {
162                vals.push(self.get(i, j));
163            }
164        }
165        Triplet {
166            n_rows: self.n,
167            n_cols: self.n,
168            irow: self.h_irow.clone(),
169            jcol: self.h_jcol.clone(),
170            vals,
171        }
172    }
173}