Skip to main content

wls_alloc/
solver.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6// ---------------------------------------------------------------------------
7// Givens helpers — minimal bounds (just element access, no QR/DimMin needed)
8// ---------------------------------------------------------------------------
9
10#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12    let h = libm::hypotf(a, b);
13    let sigma = 1.0 / h;
14    (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19    r: &mut OMatrix<f32, Const<N>, Const<N>>,
20    c: f32,
21    s: f32,
22    row1: usize,
23    row2: usize,
24    n_cols: usize,
25) where
26    Const<N>: DimName,
27    DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29    for col in 0..n_cols {
30        let r1 = r[(row1, col)];
31        let r2 = r[(row2, col)];
32        r[(row1, col)] = c * r1 - s * r2;
33        r[(row2, col)] = s * r1 + c * r2;
34    }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39    q: &mut OMatrix<f32, Const<R>, Const<C>>,
40    c: f32,
41    s: f32,
42    col1: usize,
43    col2: usize,
44    n_rows: usize,
45) where
46    Const<R>: DimName,
47    Const<C>: DimName,
48    DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50    for i in 0..n_rows {
51        let q1 = q[(i, col1)];
52        let q2 = q[(i, col2)];
53        q[(i, col1)] = c * q1 - s * q2;
54        q[(i, col2)] = s * q1 + c * q2;
55    }
56}
57
58fn qr_shift<const NC: usize, const NU: usize>(
59    q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
60    r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
61    qtb: &mut [f32; NU],
62    i: usize,
63    j: usize,
64) where
65    Const<NC>: DimName,
66    Const<NU>: DimName,
67    DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
68{
69    if i == j {
70        return;
71    }
72
73    let n_givens: usize;
74    if i > j {
75        n_givens = i - j;
76        for l in 0..NU {
77            let tmp = r[(l, j)];
78            for k in j..i {
79                r[(l, k)] = r[(l, k + 1)];
80            }
81            r[(l, i)] = tmp;
82        }
83    } else {
84        n_givens = j - i;
85        for l in 0..NU {
86            let tmp = r[(l, j)];
87            for k in (i..j).rev() {
88                r[(l, k + 1)] = r[(l, k)];
89            }
90            r[(l, i)] = tmp;
91        }
92    }
93
94    for k in 0..n_givens {
95        let (j1, i1) = if j > i {
96            (j - k - 1, i)
97        } else {
98            (j + k, j + k)
99        };
100        let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
101        givens_left_apply(r, c, s, j1, j1 + 1, NU);
102        givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
103        // Incrementally update cached Q^T * b
104        let t1 = qtb[j1];
105        let t2 = qtb[j1 + 1];
106        qtb[j1] = c * t1 - s * t2;
107        qtb[j1 + 1] = s * t1 + c * t2;
108    }
109}
110
111fn backward_tri_solve<const NU: usize>(
112    r: &OMatrix<f32, Const<NU>, Const<NU>>,
113    b: &[f32; NU],
114    x: &mut [f32; NU],
115    n: usize,
116) where
117    Const<NU>: DimName,
118    DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
119{
120    if n == 0 {
121        return;
122    }
123    x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
124    for i in (0..n.saturating_sub(1)).rev() {
125        let mut s = 0.0f32;
126        for j in (i + 1)..n {
127            s += r[(i, j)] * x[j];
128        }
129        x[i] = (b[i] - s) / r[(i, i)];
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Incremental active-set solver
135// ---------------------------------------------------------------------------
136
137/// Active-set solver for the regularised WLS problem.
138///
139/// This is a convenience wrapper around [`solve_cls`] that enforces
140/// `NC == NU + NV` (the augmented system produced by [`setup_a`] /
141/// [`setup_b`]).
142///
143/// [`setup_a`]: crate::setup_a
144/// [`setup_b`]: crate::setup_b
145#[allow(clippy::needless_range_loop)] // forwarding wrapper
146pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
147    a: &MatA<NC, NU>,
148    b: &VecN<NC>,
149    umin: &VecN<NU>,
150    umax: &VecN<NU>,
151    us: &mut VecN<NU>,
152    ws: &mut [i8; NU],
153    imax: usize,
154) -> SolverStats
155where
156    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
157    Const<NU>: DimName,
158    Const<NV>: DimName,
159    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
160        + Allocator<Const<NC>, Const<NC>>
161        + Allocator<Const<NU>, Const<NU>>
162        + Allocator<Const<NC>>
163        + Allocator<Const<NU>>,
164{
165    debug_assert_eq!(NC, NU + NV);
166    solve_cls(a, b, umin, umax, us, ws, imax)
167}
168
/// General box-constrained least-squares solver.
///
/// Solves `min ‖Au − b‖²` subject to `umin ≤ u ≤ umax` using an active-set
/// method with incremental QR updates (Givens rotations).
///
/// Unlike [`solve`], this function does **not** require `NC == NU + NV` and
/// accepts any `NC ≥ NU`. Use it with the unregularised setup functions
/// ([`setup_a_unreg`] / [`setup_b_unreg`]) or with a custom `A` / `b`.
///
/// # Arguments
///
/// * `a`, `b` — system matrix (`NC × NU`) and right-hand side (`NC`).
/// * `umin`, `umax` — per-variable lower / upper bounds.
/// * `us` — initial guess on entry, solution on exit (updated in place).
///   Before iterating, free entries are clamped into the box, and entries
///   marked as bounded in `ws` are placed exactly on that bound.
/// * `ws` — working set, updated in place: `0` = free, `> 0` = at `umax`,
///   `< 0` = at `umin`. Its value on entry seeds the initial active set, so
///   the exit value can be passed back in to warm-start a later call.
/// * `imax` — maximum number of active-set iterations; `0` selects 100.
///
/// # Returns
///
/// [`SolverStats`] containing:
/// * the exit code — `Success`, `IterLimit`, `NanFoundQ` (NaN in the
///   free-variable solve) or `NanFoundUs` (NaN after a step);
/// * the number of iterations performed;
/// * the final number of free variables.
///
/// NOTE(review): a zero pivot in `R` (rank-deficient `A`) yields `inf`
/// rather than NaN in the free solve, and `inf` is not caught by the NaN
/// checks — confirm that this is acceptable for the unregularised use case.
///
/// [`setup_a_unreg`]: crate::setup_a_unreg
/// [`setup_b_unreg`]: crate::setup_b_unreg
#[allow(clippy::needless_range_loop)] // multi-array index loops (ws, us, perm, bounds)
pub fn solve_cls<const NU: usize, const NC: usize>(
    a: &MatA<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    // `imax == 0` means "use the default iteration limit".
    let imax = if imax == 0 { 100 } else { imax };

    // Make the starting point consistent with the working set: free variables
    // are clamped into the box, bounded ones are placed exactly on their bound.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Permutation: free first, bounded after.
    // `perm[k]` is the original index of the variable held in column `k` of the
    // permuted factorisation; `perm[..n_free]` are the free variables.
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Permuted A → nalgebra QR → thin Q (NC×NU) and thin R (NU×NU)
    let mut a_perm: MatA<NC, NU> = MatA::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // Cache Q^T * b — updated incrementally via Givens in qr_shift
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Hoist the bound arrays and the violation-flag buffer outside the loop.
    // `umin_arr` / `umax_arr` are plain-array copies of the bounds for
    // `check_limits_tol`. That call also writes `w_temp`; the step below reads
    // -1 as "violates umin" and +1 as "violates umax".
    // (`c`, `q_sol` and `d` are still allocated per iteration.)
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    let mut w_temp = [0i8; NU];

    // Candidate point of the current iteration, indexed by original variable.
    let mut z = [0.0f32; NU];
    let mut exit_code = ExitCode::IterLimit;

    // `iter` is incremented before the limit test, so on a `break` it equals
    // the number of iterations run. When the limit is hit it overshoots by one
    // and is corrected after the loop.
    let mut iter: usize = 0;
    while {
        iter += 1;
        iter <= imax
    } {
        // Use cached Q^T * b instead of recomputing from Q.
        // Right-hand side of the free subproblem:
        // c = (Q^T b)[free] − R[free, bounded] · u[bounded].
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Unconstrained optimum over the free variables: R[free, free] · q_sol = c.
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        // Scatter back to original indexing, aborting on the first NaN.
        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        // Bounded variables keep their current (on-bound) value.
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // Candidate is feasible: accept it for the free variables.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            // Nothing is bounded, so there are no multipliers to check.
            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // Dual variables — use cached qtb instead of recomputing Q^T * b.
            // Only the bounded rows of d = Q^T b − R · u (permuted u) are formed.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // Multiplier of each bounded variable: lam_i = (R^T d)_i, multiplied
            // by -ws so that a positive value means (to first order) that moving
            // the variable off its bound would decrease the objective.
            // Track the largest.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // No multiplier exceeds the tolerance: the current point is optimal.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release the bound with the largest multiplier. Its column moves to
            // position `n_free` in the factorisation, and `perm` is rotated the
            // same way so that columns and variables stay aligned.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // Candidate is infeasible: step from `us` towards `z`. The step
            // length `alpha` is the smallest bound-distance ratio over the free
            // variables flagged in `w_temp`; the variable attaining it becomes
            // bounded on side `i_s`.
            // NOTE(review): if no free variable is flagged, alpha stays at inf
            // and the update below produces NaN (reported as NanFoundUs).
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0;
            let mut f_bound: usize = 0;
            let mut i_s: i8 = 0;

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }

            // Take the step. The blocking variable is set exactly to its bound
            // rather than relying on the computed step. Bounded entries have
            // z == us, so for finite alpha they do not move.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Move the blocking variable's column to the end of the free block
            // (position `n_free - 1`), mirror that in `perm`, and shrink the
            // free set by one.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // The loop condition incremented `iter` one past `imax` before failing.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}