Skip to main content

wls_alloc/
solver.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6// ---------------------------------------------------------------------------
7// Givens helpers — minimal bounds (just element access, no QR/DimMin needed)
8// ---------------------------------------------------------------------------
9
10#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12    let h = libm::hypotf(a, b);
13    let sigma = 1.0 / h;
14    (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19    r: &mut OMatrix<f32, Const<N>, Const<N>>,
20    c: f32,
21    s: f32,
22    row1: usize,
23    row2: usize,
24    n_cols: usize,
25) where
26    Const<N>: DimName,
27    DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29    for col in 0..n_cols {
30        let r1 = r[(row1, col)];
31        let r2 = r[(row2, col)];
32        r[(row1, col)] = c * r1 - s * r2;
33        r[(row2, col)] = s * r1 + c * r2;
34    }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39    q: &mut OMatrix<f32, Const<R>, Const<C>>,
40    c: f32,
41    s: f32,
42    col1: usize,
43    col2: usize,
44    n_rows: usize,
45) where
46    Const<R>: DimName,
47    Const<C>: DimName,
48    DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50    for i in 0..n_rows {
51        let q1 = q[(i, col1)];
52        let q2 = q[(i, col2)];
53        q[(i, col1)] = c * q1 - s * q2;
54        q[(i, col2)] = s * q1 + c * q2;
55    }
56}
57
58fn qr_shift<const NC: usize, const NU: usize>(
59    q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
60    r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
61    i: usize,
62    j: usize,
63) where
64    Const<NC>: DimName,
65    Const<NU>: DimName,
66    DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
67{
68    if i == j {
69        return;
70    }
71
72    let n_givens: usize;
73    if i > j {
74        n_givens = i - j;
75        for l in 0..NU {
76            let tmp = r[(l, j)];
77            for k in j..i {
78                r[(l, k)] = r[(l, k + 1)];
79            }
80            r[(l, i)] = tmp;
81        }
82    } else {
83        n_givens = j - i;
84        for l in 0..NU {
85            let tmp = r[(l, j)];
86            for k in (i..j).rev() {
87                r[(l, k + 1)] = r[(l, k)];
88            }
89            r[(l, i)] = tmp;
90        }
91    }
92
93    for k in 0..n_givens {
94        let (j1, i1) = if j > i {
95            (j - k - 1, i)
96        } else {
97            (j + k, j + k)
98        };
99        let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
100        givens_left_apply(r, c, s, j1, j1 + 1, NU);
101        givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
102    }
103}
104
105fn backward_tri_solve<const NU: usize>(
106    r: &OMatrix<f32, Const<NU>, Const<NU>>,
107    b: &[f32; NU],
108    x: &mut [f32; NU],
109    n: usize,
110) where
111    Const<NU>: DimName,
112    DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
113{
114    if n == 0 {
115        return;
116    }
117    x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
118    for i in (0..n.saturating_sub(1)).rev() {
119        let mut s = 0.0f32;
120        for j in (i + 1)..n {
121            s += r[(i, j)] * x[j];
122        }
123        x[i] = (b[i] - s) / r[(i, i)];
124    }
125}
126
127// ---------------------------------------------------------------------------
128// Incremental active-set solver
129// ---------------------------------------------------------------------------
130
131/// Active-set solver with incremental QR updates via Givens rotations.
132///
133/// Uses nalgebra's Householder QR for the initial factorisation, then Givens
134/// column-shift updates when constraints activate/deactivate.
135///
136/// Translates `solveActiveSet_qr.c`.
137#[allow(clippy::needless_range_loop)] // multi-array index loops (ws, us, perm, bounds)
138pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
139    a: &MatA<NC, NU>,
140    b: &VecN<NC>,
141    umin: &VecN<NU>,
142    umax: &VecN<NU>,
143    us: &mut VecN<NU>,
144    ws: &mut [i8; NU],
145    imax: usize,
146) -> SolverStats
147where
148    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
149    Const<NU>: DimName,
150    Const<NV>: DimName,
151    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
152        + Allocator<Const<NC>, Const<NC>>
153        + Allocator<Const<NU>, Const<NU>>
154        + Allocator<Const<NC>>
155        + Allocator<Const<NU>>,
156{
157    debug_assert_eq!(NC, NU + NV);
158    let imax = if imax == 0 { 100 } else { imax };
159
160    for i in 0..NU {
161        if ws[i] == 0 {
162            if us[i] > umax[i] {
163                us[i] = umax[i];
164            } else if us[i] < umin[i] {
165                us[i] = umin[i];
166            }
167        } else {
168            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
169        }
170    }
171
172    // Permutation: free first, bounded after
173    let mut perm = [0usize; NU];
174    let mut n_free: usize = 0;
175    for i in 0..NU {
176        if ws[i] == 0 {
177            perm[n_free] = i;
178            n_free += 1;
179        }
180    }
181    let mut i_bnd: usize = 0;
182    for i in 0..NU {
183        if ws[i] != 0 {
184            perm[n_free + i_bnd] = i;
185            i_bnd += 1;
186        }
187    }
188
189    // Permuted A → nalgebra QR → thin Q (NC×NU) and thin R (NU×NU)
190    let mut a_perm: MatA<NC, NU> = MatA::zeros();
191    for j in 0..NU {
192        for i in 0..NC {
193            a_perm[(i, j)] = a[(i, perm[j])];
194        }
195    }
196    let qr_decomp = a_perm.qr();
197    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
198    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();
199
200    let mut z = [0.0f32; NU];
201    let mut exit_code = ExitCode::IterLimit;
202
203    let mut iter: usize = 0;
204    while {
205        iter += 1;
206        iter <= imax
207    } {
208        let mut c = [0.0f32; NU];
209        for i in 0..n_free {
210            let mut s = 0.0f32;
211            for j in 0..NC {
212                s += q[(j, i)] * b[j];
213            }
214            c[i] = s;
215        }
216
217        for i in 0..n_free {
218            for j in 0..(NU - n_free) {
219                let pi = perm[n_free + j];
220                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
221                c[i] -= r[(i, n_free + j)] * ub;
222            }
223        }
224
225        let mut q_sol = [0.0f32; NU];
226        backward_tri_solve(&r, &c, &mut q_sol, n_free);
227
228        let mut nan_found = false;
229        for i in 0..n_free {
230            if f32::is_nan(q_sol[i]) {
231                nan_found = true;
232                break;
233            }
234            z[perm[i]] = q_sol[i];
235        }
236        if nan_found {
237            exit_code = ExitCode::NanFoundQ;
238            break;
239        }
240        for i in n_free..NU {
241            z[perm[i]] = us[perm[i]];
242        }
243
244        let mut umin_arr = [0.0f32; NU];
245        let mut umax_arr = [0.0f32; NU];
246        for i in 0..NU {
247            umin_arr[i] = umin[i];
248            umax_arr[i] = umax[i];
249        }
250        let mut w_temp = [0i8; NU];
251        let n_violated =
252            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));
253
254        if n_violated == 0 {
255            for i in 0..n_free {
256                us[perm[i]] = z[perm[i]];
257            }
258
259            if n_free == NU {
260                exit_code = ExitCode::Success;
261                break;
262            }
263
264            let mut d = [0.0f32; NU];
265            for i in n_free..NU {
266                let mut s = 0.0f32;
267                for j in 0..NC {
268                    s += q[(j, i)] * b[j];
269                }
270                d[i] = s;
271            }
272            for i in n_free..NU {
273                for j in i..NU {
274                    d[i] -= r[(i, j)] * us[perm[j]];
275                }
276            }
277
278            let mut f_free: usize = 0;
279            let mut maxlam: f32 = f32::NEG_INFINITY;
280            for i in n_free..NU {
281                let mut lam = 0.0f32;
282                for j in n_free..=i {
283                    lam += r[(j, i)] * d[j];
284                }
285                lam *= -f32::from(ws[perm[i]]);
286                if lam > maxlam {
287                    maxlam = lam;
288                    f_free = i - n_free;
289                }
290            }
291
292            if maxlam <= CONSTR_TOL {
293                exit_code = ExitCode::Success;
294                break;
295            }
296
297            qr_shift(&mut q, &mut r, n_free, n_free + f_free);
298            ws[perm[n_free + f_free]] = 0;
299            let last_val = perm[n_free + f_free];
300            for i in (1..=f_free).rev() {
301                perm[n_free + i] = perm[n_free + i - 1];
302            }
303            perm[n_free] = last_val;
304            n_free += 1;
305        } else {
306            let mut alpha: f32 = f32::INFINITY;
307            let mut i_a: usize = 0;
308            let mut f_bound: usize = 0;
309            let mut i_s: i8 = 0;
310
311            for f in 0..n_free {
312                let ii = perm[f];
313                let (tmp, ts) = if w_temp[ii] == -1 {
314                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
315                } else if w_temp[ii] == 1 {
316                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
317                } else {
318                    continue;
319                };
320                if tmp < alpha {
321                    alpha = tmp;
322                    i_a = ii;
323                    f_bound = f;
324                    i_s = ts;
325                }
326            }
327
328            let mut nan_found = false;
329            for i in 0..NU {
330                if i == i_a {
331                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
332                } else {
333                    us[i] += alpha * (z[i] - us[i]);
334                }
335                if f32::is_nan(us[i]) {
336                    nan_found = true;
337                    break;
338                }
339            }
340            if nan_found {
341                exit_code = ExitCode::NanFoundUs;
342                break;
343            }
344
345            qr_shift(&mut q, &mut r, n_free - 1, f_bound);
346            ws[i_a] = i_s;
347            let first_val = perm[f_bound];
348            for i in 0..(n_free - f_bound - 1) {
349                perm[f_bound + i] = perm[f_bound + i + 1];
350            }
351            n_free -= 1;
352            perm[n_free] = first_val;
353        }
354    }
355    if exit_code == ExitCode::IterLimit {
356        iter -= 1;
357    }
358    SolverStats {
359        exit_code,
360        iterations: iter,
361        n_free,
362    }
363}