Skip to main content

wls_alloc/
linalg.rs

1//! Shared linear algebra utilities: constraint checking, Householder QR, and
2//! triangular back-substitution.
3//!
4//! These operate on raw `[[f32; R]; C]` arrays (column-major: `arr[col][row]`)
5//! because the naive solver's subproblem has a runtime column count that
6//! nalgebra's static QR cannot express without `alloc`.
7
8use crate::types::CONSTR_TOL;
9
10/// Check which elements of `x` violate bounds, with relative + absolute tolerance.
11///
12/// For each `i` in `0..n_check`, examines `x[perm[i]]` (or `x[i]` when `perm`
13/// is `None`) against `xmin` / `xmax`. Returns the number of violations, and
14/// writes `+1` (upper), `-1` (lower), or `0` (feasible) into `output`.
15pub fn check_limits_tol<const N: usize>(
16    n_check: usize,
17    x: &[f32; N],
18    xmin: &[f32; N],
19    xmax: &[f32; N],
20    output: &mut [i8; N],
21    perm: Option<&[usize; N]>,
22) -> usize {
23    let tol = CONSTR_TOL;
24    let mut count = 0usize;
25    for i in 0..n_check {
26        let ind = match perm {
27            Some(p) => p[i],
28            None => i,
29        };
30        let sign_max: f32 = if xmax[ind] > 0.0 { 1.0 } else { -1.0 };
31        let upper = xmax[ind] * (1.0 + sign_max * tol) + tol;
32        let sign_min: f32 = if xmin[ind] < 0.0 { 1.0 } else { -1.0 };
33        let lower = xmin[ind] * (1.0 + sign_min * tol) - tol;
34
35        if x[ind] >= upper {
36            output[ind] = 1;
37            count += 1;
38        } else if x[ind] <= lower {
39            output[ind] = -1;
40            count += 1;
41        } else {
42            output[ind] = 0;
43        }
44    }
45    count
46}
47
48// ---------------------------------------------------------------------------
49// Householder QR with explicit Q recovery
50// ---------------------------------------------------------------------------
51
52/// Compute full Householder QR factorisation: A (m x n) = Q (m x m) * R (m x n).
53///
54/// Arrays are column-major: `mat[col][row]`.
55/// `work` is the input; `q` and `r` are written on output.
56#[allow(clippy::needless_range_loop)] // cross-column 2D array access prevents iterator use
57pub fn householder_qr<const M: usize, const N: usize>(
58    work: &[[f32; M]; N],
59    q: &mut [[f32; M]; M],
60    r: &mut [[f32; M]; N],
61    m: usize,
62    n: usize,
63) {
64    for j in 0..n {
65        for i in 0..m {
66            r[j][i] = work[j][i];
67        }
68    }
69    for j in n..N {
70        for i in 0..M {
71            r[j][i] = 0.0;
72        }
73    }
74
75    let mut tau = [0.0f32; N];
76    let kmax = if m < n { m } else { n };
77
78    for k in 0..kmax {
79        let mut nu = 0.0f32;
80        for i in (k + 1)..m {
81            nu += r[k][i] * r[k][i];
82        }
83        nu = libm::sqrtf(nu);
84
85        if nu < 1e-12 {
86            tau[k] = 0.0;
87            continue;
88        }
89
90        let beta = (if r[k][k] >= 0.0 { -1.0f32 } else { 1.0 }) * libm::hypotf(r[k][k], nu);
91        tau[k] = (beta - r[k][k]) / beta;
92        let scale = 1.0 / (r[k][k] - beta);
93        for i in (k + 1)..m {
94            r[k][i] *= scale;
95        }
96        r[k][k] = beta;
97
98        for j in (k + 1)..n {
99            let mut w = r[j][k];
100            for i in (k + 1)..m {
101                w += r[k][i] * r[j][i];
102            }
103            r[j][k] -= tau[k] * w;
104            for i in (k + 1)..m {
105                r[j][i] -= tau[k] * r[k][i] * w;
106            }
107        }
108    }
109
110    // Recover explicit Q from stored Householder vectors (reverse order)
111    for j in 0..M {
112        for i in 0..M {
113            q[j][i] = if i == j { 1.0 } else { 0.0 };
114        }
115    }
116    for k in (0..kmax).rev() {
117        if tau[k] == 0.0 {
118            continue;
119        }
120        for j in (k..m).rev() {
121            let mut w = q[j][k];
122            for i in (k + 1)..m {
123                w += r[k][i] * q[j][i];
124            }
125            q[j][k] -= tau[k] * w;
126            for i in (k + 1)..m {
127                q[j][i] -= tau[k] * r[k][i] * w;
128            }
129        }
130    }
131
132    // Zero lower triangle of R (was Householder vector storage)
133    for j in 0..n {
134        for i in (j + 1)..m {
135            r[j][i] = 0.0;
136        }
137    }
138}
139
/// Solve the upper-triangular system `R x = b` by back-substitution.
///
/// Only the leading `n x n` block of `r` (column-major, `r[col][row]`) and the
/// first `n` entries of `b` are read; only the first `n` entries of `x` are
/// written. With `n == 0` the call does nothing. No check is made for zero
/// diagonal entries, so a singular `R` yields non-finite values in `x`.
pub fn backward_tri_solve<const M: usize, const N: usize>(
    r: &[[f32; M]; N],
    b: &[f32; N],
    x: &mut [f32; N],
    n: usize,
) {
    // Walk the rows from the bottom up; each row only needs the unknowns
    // that were already resolved below it.
    for row in (0..n).rev() {
        // Explicit fold from 0.0 keeps the left-to-right summation order.
        let known = ((row + 1)..n).fold(0.0f32, |acc, col| acc + r[col][row] * x[col]);
        x[row] = (b[row] - known) / r[row][row];
    }
}