1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12 let h = libm::hypotf(a, b);
13 let sigma = 1.0 / h;
14 (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19 r: &mut OMatrix<f32, Const<N>, Const<N>>,
20 c: f32,
21 s: f32,
22 row1: usize,
23 row2: usize,
24 n_cols: usize,
25) where
26 Const<N>: DimName,
27 DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29 for col in 0..n_cols {
30 let r1 = r[(row1, col)];
31 let r2 = r[(row2, col)];
32 r[(row1, col)] = c * r1 - s * r2;
33 r[(row2, col)] = s * r1 + c * r2;
34 }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39 q: &mut OMatrix<f32, Const<R>, Const<C>>,
40 c: f32,
41 s: f32,
42 col1: usize,
43 col2: usize,
44 n_rows: usize,
45) where
46 Const<R>: DimName,
47 Const<C>: DimName,
48 DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50 for i in 0..n_rows {
51 let q1 = q[(i, col1)];
52 let q2 = q[(i, col2)];
53 q[(i, col1)] = c * q1 - s * q2;
54 q[(i, col2)] = s * q1 + c * q2;
55 }
56}
57
58fn qr_shift<const NC: usize, const NU: usize>(
59 q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
60 r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
61 qtb: &mut [f32; NU],
62 i: usize,
63 j: usize,
64) where
65 Const<NC>: DimName,
66 Const<NU>: DimName,
67 DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
68{
69 if i == j {
70 return;
71 }
72
73 let n_givens: usize;
74 if i > j {
75 n_givens = i - j;
76 for l in 0..NU {
77 let tmp = r[(l, j)];
78 for k in j..i {
79 r[(l, k)] = r[(l, k + 1)];
80 }
81 r[(l, i)] = tmp;
82 }
83 } else {
84 n_givens = j - i;
85 for l in 0..NU {
86 let tmp = r[(l, j)];
87 for k in (i..j).rev() {
88 r[(l, k + 1)] = r[(l, k)];
89 }
90 r[(l, i)] = tmp;
91 }
92 }
93
94 for k in 0..n_givens {
95 let (j1, i1) = if j > i {
96 (j - k - 1, i)
97 } else {
98 (j + k, j + k)
99 };
100 let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
101 givens_left_apply(r, c, s, j1, j1 + 1, NU);
102 givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
103 let t1 = qtb[j1];
105 let t2 = qtb[j1 + 1];
106 qtb[j1] = c * t1 - s * t2;
107 qtb[j1 + 1] = s * t1 + c * t2;
108 }
109}
110
111fn backward_tri_solve<const NU: usize>(
112 r: &OMatrix<f32, Const<NU>, Const<NU>>,
113 b: &[f32; NU],
114 x: &mut [f32; NU],
115 n: usize,
116) where
117 Const<NU>: DimName,
118 DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
119{
120 if n == 0 {
121 return;
122 }
123 x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
124 for i in (0..n.saturating_sub(1)).rev() {
125 let mut s = 0.0f32;
126 for j in (i + 1)..n {
127 s += r[(i, j)] * x[j];
128 }
129 x[i] = (b[i] - s) / r[(i, i)];
130 }
131}
132
/// Active-set solver for the box-constrained least-squares problem
/// `min_u ||A u - b||^2` subject to `umin <= u <= umax`.
///
/// It works on a QR factorisation of the column-permuted `A`, with free
/// variables first and bound variables last. The factorisation is updated
/// incrementally with `qr_shift` whenever a variable enters or leaves the
/// working set.
///
/// * `a` - `NC x NU` system matrix.
/// * `b` - right-hand side of length `NC`.
/// * `umin`, `umax` - per-variable lower and upper bounds.
/// * `us` - on entry, the initial guess. Free variables are clamped into
///   `[umin, umax]`, and bound variables are set exactly to their bound. On
///   return, `us` holds the final iterate.
/// * `ws` - working set, updated in place. `0` means free, `> 0` means held
///   at `umax`, and `< 0` means held at `umin`.
/// * `imax` - iteration limit; `0` selects the default of 100.
///
/// `NV` is only used in the debug assertion `NC == NU + NV`.
/// NOTE(review): presumably `A` stacks an `NV`-row objective block on top of
/// an `NU`-row regularisation block; confirm with callers.
///
/// Returns the exit code, the number of iterations performed (at most
/// `imax`), and the number of free variables at exit. The exit codes are:
/// * `Success` - the candidate is feasible, and either every variable is free
///   or no multiplier exceeds `CONSTR_TOL`.
/// * `NanFoundQ` / `NanFoundUs` - a NaN appeared in the triangular solve or
///   in the step update.
/// * `IterLimit` - the iteration limit was reached.
// Most loops drive several arrays with one counter (often through `perm`),
// so clippy's `needless_range_loop` is silenced for the whole function.
#[allow(clippy::needless_range_loop)] pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
    a: &MatA<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    Const<NV>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    debug_assert_eq!(NC, NU + NV);
    // `imax == 0` selects the default limit of 100 iterations.
    let imax = if imax == 0 { 100 } else { imax };

    // Make the starting point consistent with the working set: clamp free
    // variables into their box and pin bound variables to their bound.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Column permutation: free variable indices first (perm[..n_free]),
    // then the bound ones (perm[n_free..]).
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Factorise A[:, perm] = Q R. Q is NC x NU (thin), R is NU x NU.
    let mut a_perm: MatA<NC, NU> = MatA::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // qtb = Q^T b; kept in sync with Q/R by qr_shift from here on.
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Plain-array copies of the bounds for check_limits_tol.
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    let mut w_temp = [0i8; NU];

    // Candidate point in original (unpermuted) variable order.
    let mut z = [0.0f32; NU];
    let mut exit_code = ExitCode::IterLimit;

    // `iter` is incremented before the limit check, so it overshoots by one
    // when the limit is exhausted; this is corrected after the loop.
    let mut iter: usize = 0;
    while {
        iter += 1;
        iter <= imax
    } {
        // c = (Q^T b)[..n_free] - R[..n_free, n_free..] * u_bound, where each
        // bound value is the one selected by `ws`.
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Optimal free variables given the bound ones: R11 * u_free = c.
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        // Bound variables keep their current value in the candidate.
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        // NOTE(review): check_limits_tol is assumed to count free entries of
        // `z` (through perm[..n_free]) that violate the bounds, and to mark
        // them in `w_temp` as -1 (below umin) / +1 (above umax); the step
        // computation below relies on that. Confirm against
        // linalg::check_limits_tol.
        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // Candidate is feasible: accept it.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // d = (Q^T b)[n_free..] - R22 * u_bound. R is upper triangular,
            // so only columns j >= i contribute.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // Multiplier estimate for each bound variable: (R22^T d)_i scaled
            // by -ws. A positive value indicates that moving the variable off
            // its bound can reduce the residual. Pick the largest one.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // No multiplier above tolerance: the current point is optimal.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release the selected bound variable. Move its column to the end
            // of the free block, then rotate `perm` to match.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // Candidate is infeasible. Step from `us` towards `z` only as far
            // as the first flagged variable allows (smallest ratio `alpha`),
            // and pin that blocking variable to the bound it reaches.
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0;
            let mut f_bound: usize = 0;
            let mut i_s: i8 = 0;

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }
            // NOTE(review): if no flagged entry yields `tmp < alpha` (e.g.
            // every `tmp` is NaN), `alpha` stays infinite and
            // `i_a`/`f_bound`/`i_s` keep their zero defaults. The update
            // below is then not meaningful; consider guarding this case.

            // Bound variables have z == us, so only free ones actually move.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Bind the blocking variable. Move its column to the last free
            // slot, then shrink the free block and rotate `perm` to match.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // Undo the pre-increment overshoot from the loop condition.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}