Skip to main content

wls_alloc/
solver.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6// ---------------------------------------------------------------------------
7// Givens helpers — minimal bounds (just element access, no QR/DimMin needed)
8// ---------------------------------------------------------------------------
9
10#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12    let h = libm::hypotf(a, b);
13    let sigma = 1.0 / h;
14    (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19    r: &mut OMatrix<f32, Const<N>, Const<N>>,
20    c: f32,
21    s: f32,
22    row1: usize,
23    row2: usize,
24    n_cols: usize,
25) where
26    Const<N>: DimName,
27    DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29    for col in 0..n_cols {
30        let r1 = r[(row1, col)];
31        let r2 = r[(row2, col)];
32        r[(row1, col)] = c * r1 - s * r2;
33        r[(row2, col)] = s * r1 + c * r2;
34    }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39    q: &mut OMatrix<f32, Const<R>, Const<C>>,
40    c: f32,
41    s: f32,
42    col1: usize,
43    col2: usize,
44    n_rows: usize,
45) where
46    Const<R>: DimName,
47    Const<C>: DimName,
48    DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50    for i in 0..n_rows {
51        let q1 = q[(i, col1)];
52        let q2 = q[(i, col2)];
53        q[(i, col1)] = c * q1 - s * q2;
54        q[(i, col2)] = s * q1 + c * q2;
55    }
56}
57
58fn qr_shift<const NC: usize, const NU: usize>(
59    q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
60    r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
61    qtb: &mut [f32; NU],
62    i: usize,
63    j: usize,
64) where
65    Const<NC>: DimName,
66    Const<NU>: DimName,
67    DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
68{
69    if i == j {
70        return;
71    }
72
73    let n_givens: usize;
74    if i > j {
75        n_givens = i - j;
76        for l in 0..NU {
77            let tmp = r[(l, j)];
78            for k in j..i {
79                r[(l, k)] = r[(l, k + 1)];
80            }
81            r[(l, i)] = tmp;
82        }
83    } else {
84        n_givens = j - i;
85        for l in 0..NU {
86            let tmp = r[(l, j)];
87            for k in (i..j).rev() {
88                r[(l, k + 1)] = r[(l, k)];
89            }
90            r[(l, i)] = tmp;
91        }
92    }
93
94    for k in 0..n_givens {
95        let (j1, i1) = if j > i {
96            (j - k - 1, i)
97        } else {
98            (j + k, j + k)
99        };
100        let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
101        givens_left_apply(r, c, s, j1, j1 + 1, NU);
102        givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
103        // Incrementally update cached Q^T * b
104        let t1 = qtb[j1];
105        let t2 = qtb[j1 + 1];
106        qtb[j1] = c * t1 - s * t2;
107        qtb[j1 + 1] = s * t1 + c * t2;
108    }
109}
110
111fn backward_tri_solve<const NU: usize>(
112    r: &OMatrix<f32, Const<NU>, Const<NU>>,
113    b: &[f32; NU],
114    x: &mut [f32; NU],
115    n: usize,
116) where
117    Const<NU>: DimName,
118    DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
119{
120    if n == 0 {
121        return;
122    }
123    x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
124    for i in (0..n.saturating_sub(1)).rev() {
125        let mut s = 0.0f32;
126        for j in (i + 1)..n {
127            s += r[(i, j)] * x[j];
128        }
129        x[i] = (b[i] - s) / r[(i, i)];
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Incremental active-set solver
135// ---------------------------------------------------------------------------
136
/// Active-set solver with incremental QR updates via Givens rotations.
///
/// Uses nalgebra's Householder QR for the initial factorisation, then Givens
/// column-shift updates when constraints activate/deactivate.
///
/// Translates `solveActiveSet_qr.c`.
///
/// Each iteration solves the triangular system for the currently free
/// variables while the bounded ones are held at their limits. If the result
/// respects the bounds it is accepted and the bounded variable with the
/// largest multiplier estimate (if above `CONSTR_TOL`) is released; otherwise
/// a partial step towards the result is taken and the first variable to hit
/// a bound is added to the working set.
///
/// # Arguments
///
/// * `a` - `NC`×`NU` problem matrix.
/// * `b` - Right-hand side of length `NC`.
/// * `umin`, `umax` - Lower / upper bound of each of the `NU` variables.
/// * `us` - Initial guess on entry (clamped to the bounds, or snapped to the
///   bound selected by `ws`, before iterating); holds the last iterate on
///   return.
/// * `ws` - Working set, updated in place: `0` = free, `> 0` = held at
///   `umax`, `< 0` = held at `umin`.
/// * `imax` - Maximum number of iterations; `0` selects the default of 100.
///
/// `NV` only takes part in the debug-build check `NC == NU + NV`.
///
/// # Returns
///
/// A [`SolverStats`] carrying the exit code, the iteration count and the
/// final number of free variables. Exit codes set by this function:
///
/// * `ExitCode::Success` - iterate within bounds and either every variable is
///   free or no multiplier estimate exceeds `CONSTR_TOL`.
/// * `ExitCode::NanFoundQ` - NaN in the triangular-solve result.
/// * `ExitCode::NanFoundUs` - NaN appeared in `us` while taking a step.
/// * `ExitCode::IterLimit` - `imax` iterations ran without terminating.
#[allow(clippy::needless_range_loop)] // multi-array index loops (ws, us, perm, bounds)
pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
    a: &MatA<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    Const<NV>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    // Checked in debug builds only.
    debug_assert_eq!(NC, NU + NV);
    // imax == 0 means "use the default".
    let imax = if imax == 0 { 100 } else { imax };

    // Make the initial guess consistent with the working set: free variables
    // are clamped into [umin, umax], bounded ones are put exactly on the
    // bound that ws selects.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Permutation: free first, bounded after
    // perm[k] is the original variable index stored at column k of the
    // factorisation; columns 0..n_free are the free variables.
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Permuted A → nalgebra QR → thin Q (NC×NU) and thin R (NU×NU)
    let mut a_perm: MatA<NC, NU> = MatA::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // Cache Q^T * b — updated incrementally via Givens in qr_shift
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Hoist bound arrays and scratch space outside the loop
    // (check_limits_tol is handed plain arrays rather than VecN).
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    // Per-variable violation flags filled in by check_limits_tol.
    let mut w_temp = [0i8; NU];

    // Candidate solution in original (un-permuted) variable order.
    let mut z = [0.0f32; NU];
    // Stays IterLimit unless one of the break paths below overrides it.
    let mut exit_code = ExitCode::IterLimit;

    let mut iter: usize = 0;
    // Pre-increment loop: `iter` is bumped before the test, so inside the
    // body it is the 1-based number of the current iteration.
    while {
        iter += 1;
        iter <= imax
    } {
        // Use cached Q^T * b instead of recomputing from Q
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        // Move the contribution of the bounded variables (held at their
        // active bound) to the right-hand side.
        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Solve the leading n_free×n_free triangular block for the free
        // variables (in permuted order).
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        // Scatter the result back to original indices, bailing out on NaN.
        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        // Bounded variables keep their current value.
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        // NOTE(review): check_limits_tol is defined in crate::linalg and not
        // visible here. From its use below it presumably returns the number
        // of free variables violating a bound and writes -1 / +1 into
        // w_temp[original index] for a lower / upper violation — confirm.
        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // Candidate is feasible: accept it for the free variables.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            // No active bounds left to examine.
            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // Dual variables — use cached qtb instead of recomputing Q^T * b
            // d = rows n_free.. of (Q^T b − R · us[perm]); R is upper
            // triangular, so the column sum can start at j = i.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // Multiplier estimate per bounded variable: column i of R
            // (rows n_free..=i) dotted with d, sign-adjusted by which bound
            // is active. Track the largest one and its offset f_free within
            // the bounded part of perm.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // Nothing worth releasing: done.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release the selected variable: move its column to the end of
            // the free block in the factorisation, then mirror the same move
            // in perm and grow the free block by one.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // Candidate violates a bound: find the largest step fraction
            // alpha from us towards z that keeps every flagged free variable
            // inside its bound, and remember which variable limits it.
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0; // original index of the limiting variable
            let mut f_bound: usize = 0; // its position inside perm
            let mut i_s: i8 = 0; // which bound it hits: -1 lower, +1 upper

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }

            // Take the step. The limiting variable is set exactly onto its
            // bound rather than interpolated; all others move by alpha.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Activate the bound: move the variable's column to the last
            // free slot (which becomes the first bounded slot), mirror the
            // move in perm and shrink the free block by one.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // When the loop ran out of iterations the condition left iter at
    // imax + 1; undo that so the reported count is the number actually run.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}
371}