1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12 let h = libm::hypotf(a, b);
13 let sigma = 1.0 / h;
14 (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19 r: &mut OMatrix<f32, Const<N>, Const<N>>,
20 c: f32,
21 s: f32,
22 row1: usize,
23 row2: usize,
24 n_cols: usize,
25) where
26 Const<N>: DimName,
27 DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29 for col in 0..n_cols {
30 let r1 = r[(row1, col)];
31 let r2 = r[(row2, col)];
32 r[(row1, col)] = c * r1 - s * r2;
33 r[(row2, col)] = s * r1 + c * r2;
34 }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39 q: &mut OMatrix<f32, Const<R>, Const<C>>,
40 c: f32,
41 s: f32,
42 col1: usize,
43 col2: usize,
44 n_rows: usize,
45) where
46 Const<R>: DimName,
47 Const<C>: DimName,
48 DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50 for i in 0..n_rows {
51 let q1 = q[(i, col1)];
52 let q2 = q[(i, col2)];
53 q[(i, col1)] = c * q1 - s * q2;
54 q[(i, col2)] = s * q1 + c * q2;
55 }
56}
57
/// Moves column `j` of the triangular factor `r` to position `i` and restores the upper
/// triangular shape with Givens rotations, keeping `q` and `qtb` consistent.
///
/// The columns between `i` and `j` shift by one place to make room, and the old column `j`
/// ends up at index `i`. Every rotation `G` applied to two rows of `r` is mirrored as
/// `q <- q * G^T` and `qtb <- G * qtb`. Because `G` is orthogonal, `q * r` still equals the
/// column-permuted system matrix and `qtb` still equals `q^T b`.
///
/// * `q`   - orthogonal factor (NC x NU), updated in place.
/// * `r`   - triangular factor (NU x NU), updated in place.
/// * `qtb` - the vector `q^T b`, updated in place.
/// * `i`   - destination column index.
/// * `j`   - source column index.
///
/// Does nothing when `i == j`.
fn qr_shift<const NC: usize, const NU: usize>(
    q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
    r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
    qtb: &mut [f32; NU],
    i: usize,
    j: usize,
) where
    Const<NC>: DimName,
    Const<NU>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
{
    if i == j {
        return;
    }

    let n_givens: usize;
    if i > j {
        // Column j moves right to position i; columns j+1..=i each shift one place left.
        // Each shifted column now has one entry below the diagonal (upper Hessenberg).
        n_givens = i - j;
        for l in 0..NU {
            let tmp = r[(l, j)];
            for k in j..i {
                r[(l, k)] = r[(l, k + 1)];
            }
            r[(l, i)] = tmp;
        }
    } else {
        // Column j moves left to position i; columns i..j each shift one place right.
        // The moved column now has non-zeros down to row j, below the diagonal.
        n_givens = j - i;
        for l in 0..NU {
            let tmp = r[(l, j)];
            for k in (i..j).rev() {
                r[(l, k + 1)] = r[(l, k)];
            }
            r[(l, i)] = tmp;
        }
    }

    for k in 0..n_givens {
        // Rotate rows j1 and j1 + 1 so that r[(j1 + 1, i1)] becomes zero.
        // Moving left (j > i): clear column i bottom-up, from row j down to row i + 1.
        // Moving right (j < i): clear the sub-diagonal of columns j..i, left to right.
        let (j1, i1) = if j > i {
            (j - k - 1, i)
        } else {
            (j + k, j + k)
        };
        let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
        givens_left_apply(r, c, s, j1, j1 + 1, NU);
        // Mirror the row rotation on q's columns (q <- q * G^T) so q * r is unchanged.
        givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
        // Apply the same rotation to q^T b.
        let t1 = qtb[j1];
        let t2 = qtb[j1 + 1];
        qtb[j1] = c * t1 - s * t2;
        qtb[j1 + 1] = s * t1 + c * t2;
    }
}
110
111fn backward_tri_solve<const NU: usize>(
112 r: &OMatrix<f32, Const<NU>, Const<NU>>,
113 b: &[f32; NU],
114 x: &mut [f32; NU],
115 n: usize,
116) where
117 Const<NU>: DimName,
118 DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
119{
120 if n == 0 {
121 return;
122 }
123 x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
124 for i in (0..n.saturating_sub(1)).rev() {
125 let mut s = 0.0f32;
126 for j in (i + 1)..n {
127 s += r[(i, j)] * x[j];
128 }
129 x[i] = (b[i] - s) / r[(i, i)];
130 }
131}
132
133#[allow(clippy::needless_range_loop)] pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
147 a: &MatA<NC, NU>,
148 b: &VecN<NC>,
149 umin: &VecN<NU>,
150 umax: &VecN<NU>,
151 us: &mut VecN<NU>,
152 ws: &mut [i8; NU],
153 imax: usize,
154) -> SolverStats
155where
156 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
157 Const<NU>: DimName,
158 Const<NV>: DimName,
159 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
160 + Allocator<Const<NC>, Const<NC>>
161 + Allocator<Const<NU>, Const<NU>>
162 + Allocator<Const<NC>>
163 + Allocator<Const<NU>>,
164{
165 debug_assert_eq!(NC, NU + NV);
166 solve_cls(a, b, umin, umax, us, ws, imax)
167}
168
/// Solves the box-constrained least-squares problem
/// `min ||A u - b||_2  subject to  umin <= u <= umax`
/// with an active-set method.
///
/// The method keeps a thin QR factorisation of the column-permuted `A`. It updates that
/// factorisation with `qr_shift` each time a variable enters or leaves the working set,
/// instead of refactorising.
///
/// # Arguments
/// * `a` - system matrix (NC x NU); the `DimMin` bound requires NC >= NU.
/// * `b` - right-hand side (length NC).
/// * `umin`, `umax` - element-wise lower / upper bounds on `u`.
/// * `us` - warm-start guess on entry, solution estimate on exit.
/// * `ws` - working set, read on entry and written on exit. `0` means free, a positive value
///   means held at `umax`, a negative value means held at `umin`.
/// * `imax` - iteration limit; `0` selects the default of 100.
///
/// # Returns
/// A `SolverStats` holding:
/// * the exit code: `Success`, `IterLimit`, `NanFoundQ` (NaN in the free-variable solve) or
///   `NanFoundUs` (NaN after a step);
/// * the number of iterations performed (equal to `imax` on `IterLimit`);
/// * the final number of free variables.
#[allow(clippy::needless_range_loop)] // index loops mirror the permutation bookkeeping
pub fn solve_cls<const NU: usize, const NC: usize>(
    a: &MatA<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    // imax == 0 means "use the default limit".
    let imax = if imax == 0 { 100 } else { imax };

    // Make the starting point consistent with the working set.
    // Free variables are clamped into the box; bound variables are placed exactly on their bound.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Column permutation: perm[..n_free] lists the free variables and perm[n_free..] the
    // bound ones, each group in ascending index order.
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Thin QR of A with its columns permuted: a[:, perm] = q * r.
    let mut a_perm: MatA<NC, NU> = MatA::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // qtb = q^T b. qr_shift keeps it in sync with every later rotation.
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Plain-array copies of the bounds, passed to check_limits_tol.
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    let mut w_temp = [0i8; NU];

    let mut z = [0.0f32; NU];
    let mut exit_code = ExitCode::IterLimit;

    let mut iter: usize = 0;
    // The counter is bumped inside the condition. If the limit is exhausted it ends at
    // imax + 1, which is corrected after the loop.
    while {
        iter += 1;
        iter <= imax
    } {
        // Right-hand side of the free subproblem: the free part of q^T b minus the
        // contribution of the variables held on their bounds.
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Unconstrained optimum over the free variables: r[..n_free, ..n_free] * q_sol = c.
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        // Scatter into z in the original variable order. Bound variables keep their current value.
        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        // NOTE(review): check_limits_tol lives in crate::linalg and is not shown here.
        // From its use below, it appears to mark free entries of z that fall below umin with -1
        // and entries above umax with +1 in w_temp, and to return how many were flagged.
        // Confirm against its implementation.
        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // The full step is feasible, so accept it.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // Residual q^T b - r * u on the rows of the bound block.
            // Only columns j >= i are read, relying on r's upper-triangular structure.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // For each bound variable, take the matching component of r^T d (the negated
            // gradient of 0.5 * ||A u - b||^2) and multiply it by -ws. The result is positive
            // when, to first order, moving the variable off its bound reduces the cost.
            // Keep the largest such value.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // No bound is worth releasing, so the current point is optimal.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release that bound. Move its column to position n_free (the end of the free block)
            // and update ws, perm and n_free to match.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // The full step leaves the box. Take the largest fraction alpha of the step
            // us -> z that keeps every flagged variable within its bound.
            // The variable that limits alpha is i_a, at position f_bound, with bound side i_s.
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0;
            let mut f_bound: usize = 0;
            let mut i_s: i8 = 0;

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }

            // Step toward z. The blocking variable is snapped exactly onto its bound.
            // Bound variables have z == us, so they do not move.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Fix the blocking variable on its bound. Move its column to the last free
            // slot (n_free - 1), then shrink the free block by one.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // Undo the extra increment from the final, failed loop-condition check.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}