Skip to main content

flight_solver/cls/
solver.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use super::linalg::check_limits_tol;
4use super::types::{ExitCode, Mat, SolverStats, VecN, CONSTR_TOL};
5use crate::givens::{givens, givens_left_apply, givens_right_apply_t};
6
7fn qr_shift<const NC: usize, const NU: usize>(
8    q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
9    r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
10    qtb: &mut [f32; NU],
11    i: usize,
12    j: usize,
13) where
14    Const<NC>: DimName,
15    Const<NU>: DimName,
16    DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
17{
18    if i == j {
19        return;
20    }
21
22    let n_givens: usize;
23    if i > j {
24        n_givens = i - j;
25        for l in 0..NU {
26            let tmp = r[(l, j)];
27            for k in j..i {
28                r[(l, k)] = r[(l, k + 1)];
29            }
30            r[(l, i)] = tmp;
31        }
32    } else {
33        n_givens = j - i;
34        for l in 0..NU {
35            let tmp = r[(l, j)];
36            for k in (i..j).rev() {
37                r[(l, k + 1)] = r[(l, k)];
38            }
39            r[(l, i)] = tmp;
40        }
41    }
42
43    for k in 0..n_givens {
44        let (j1, i1) = if j > i {
45            (j - k - 1, i)
46        } else {
47            (j + k, j + k)
48        };
49        let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
50        givens_left_apply(r, c, s, j1, j1 + 1, NU);
51        givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
52        // Incrementally update cached Q^T * b
53        let t1 = qtb[j1];
54        let t2 = qtb[j1 + 1];
55        qtb[j1] = c * t1 - s * t2;
56        qtb[j1 + 1] = s * t1 + c * t2;
57    }
58}
59
60fn backward_tri_solve<const NU: usize>(
61    r: &OMatrix<f32, Const<NU>, Const<NU>>,
62    b: &[f32; NU],
63    x: &mut [f32; NU],
64    n: usize,
65) where
66    Const<NU>: DimName,
67    DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
68{
69    if n == 0 {
70        return;
71    }
72    x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
73    for i in (0..n.saturating_sub(1)).rev() {
74        let mut s = 0.0f32;
75        for j in (i + 1)..n {
76            s += r[(i, j)] * x[j];
77        }
78        x[i] = (b[i] - s) / r[(i, i)];
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Incremental active-set solver
84// ---------------------------------------------------------------------------
85
86/// Active-set solver for the regularised WLS problem.
87///
88/// This is a convenience wrapper around [`solve_cls`] that enforces
89/// `NC == NU + NV` (the augmented system produced by
90/// [`setup::wls::setup_a`](crate::cls::setup::wls::setup_a) /
91/// [`setup::wls::setup_b`](crate::cls::setup::wls::setup_b)).
92#[allow(clippy::needless_range_loop)] // forwarding wrapper
93pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
94    a: &Mat<NC, NU>,
95    b: &VecN<NC>,
96    umin: &VecN<NU>,
97    umax: &VecN<NU>,
98    us: &mut VecN<NU>,
99    ws: &mut [i8; NU],
100    imax: usize,
101) -> SolverStats
102where
103    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
104    Const<NU>: DimName,
105    Const<NV>: DimName,
106    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
107        + Allocator<Const<NC>, Const<NC>>
108        + Allocator<Const<NU>, Const<NU>>
109        + Allocator<Const<NC>>
110        + Allocator<Const<NU>>,
111{
112    debug_assert_eq!(NC, NU + NV);
113    solve_cls(a, b, umin, umax, us, ws, imax)
114}
115
/// General box-constrained least-squares solver.
///
/// Solves `min ‖Au − b‖²` subject to `umin ≤ u ≤ umax` using an active-set
/// method with incremental QR updates (Givens rotations).
///
/// Unlike [`solve`], this function does **not** require `NC == NU + NV` and
/// accepts any `NC ≥ NU`. Use it with the unregularised setup functions
/// ([`setup::ls`](crate::cls::setup::ls)) or with a custom `A` / `b`.
///
/// # Arguments
///
/// * `a` — system matrix, indexed as `NC` rows by `NU` columns.
/// * `b` — right-hand side with `NC` entries.
/// * `umin`, `umax` — per-variable lower and upper bounds.
/// * `us` — in/out. On entry the initial guess; before iterating, free
///   entries are clamped into `[umin, umax]` and entries with a non-zero
///   working-set flag are set to the corresponding bound. Updated in place
///   during the iterations.
/// * `ws` — in/out working set, one flag per variable: `0` = free,
///   `> 0` = held at `umax`, `< 0` = held at `umin`.
/// * `imax` — maximum number of iterations; `0` selects a default of 100.
///
/// # Returns
///
/// A [`SolverStats`] holding the exit code ([`ExitCode::Success`],
/// [`ExitCode::IterLimit`], [`ExitCode::NanFoundQ`] or
/// [`ExitCode::NanFoundUs`]), the iteration count and the number of free
/// variables at exit.
#[allow(clippy::needless_range_loop)] // multi-array index loops (ws, us, perm, bounds)
pub fn solve_cls<const NU: usize, const NC: usize>(
    a: &Mat<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    // An iteration limit of 0 means "use the default".
    let imax = if imax == 0 { 100 } else { imax };

    // Make the initial guess consistent with the bounds and the working set:
    // free variables are clamped, bounded variables are put on their bound.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Permutation: free first, bounded after.
    // perm[k] is the original variable index stored at permuted position k;
    // positions 0..n_free are free, positions n_free..NU are bounded.
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Permuted A → nalgebra QR → thin Q (NC×NU) and thin R (NU×NU)
    let mut a_perm: Mat<NC, NU> = Mat::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // Cache Q^T * b — updated incrementally via Givens in qr_shift
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Hoist bound arrays and scratch space outside the loop
    // (plain-array copies of the bounds, in the form passed to check_limits_tol).
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    // Per-variable violation flags written by check_limits_tol.
    let mut w_temp = [0i8; NU];

    // Candidate solution of the current sub-problem, in original variable order.
    let mut z = [0.0f32; NU];
    // Stays IterLimit unless one of the `break`s below sets another code.
    let mut exit_code = ExitCode::IterLimit;

    let mut iter: usize = 0;
    // The block in the loop condition increments first, so `iter` is 1-based
    // inside the body and equals `imax + 1` if the loop runs out.
    while {
        iter += 1;
        iter <= imax
    } {
        // Use cached Q^T * b instead of recomputing from Q
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        // Move the contribution of the bounded variables (fixed at their
        // bound) to the right-hand side of the free sub-system.
        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Solve the leading n_free × n_free triangular system for the free variables.
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        // Scatter the free solution back to original ordering; abort on NaN.
        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        // Bounded variables keep their current value.
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        // NOTE(review): check_limits_tol is defined elsewhere. From its use
        // below it presumably returns the number of free variables in `z`
        // outside their bounds and marks them in `w_temp` (by original
        // variable index) with -1 (below umin) or +1 (above umax) — confirm.
        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // Candidate is within bounds: accept it for the free variables.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            // Nothing is bounded, so there are no multipliers to check.
            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // Dual variables — use cached qtb instead of recomputing Q^T * b
            // d = (Q^T b − R u) restricted to the bounded rows; only the
            // entries of R on and right of the diagonal are read.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // For each bounded variable form lam = −ws · (R^T d) over the
            // bounded rows and remember the largest one; f_free is its
            // offset within the bounded part of the permutation.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // No multiplier exceeds the tolerance: the working set is accepted.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release the bounded variable with the largest multiplier:
            // move its column to the end of the free block, clear its flag
            // and rotate `perm` the same way the columns were moved.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // Candidate leaves the box: find the largest step fraction
            // `alpha` from `us` towards `z` before a flagged variable hits
            // its bound, and which variable/bound limits it.
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0; // original index of the limiting variable
            let mut f_bound: usize = 0; // its position within the free block
            let mut i_s: i8 = 0; // which bound it hits: -1 lower, +1 upper

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }

            // Take the step. The limiting variable is set exactly onto its
            // bound rather than interpolated; abort on NaN.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Add the limiting variable to the working set: move its column
            // to the last free position (which becomes the first bounded
            // one), set its flag and rotate `perm` accordingly.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // The loop condition over-counts by one when the limit is reached.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}