1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use super::linalg::check_limits_tol;
4use super::types::{ExitCode, Mat, SolverStats, VecN, CONSTR_TOL};
5use crate::givens::{givens, givens_left_apply, givens_right_apply_t};
6
7fn qr_shift<const NC: usize, const NU: usize>(
8 q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
9 r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
10 qtb: &mut [f32; NU],
11 i: usize,
12 j: usize,
13) where
14 Const<NC>: DimName,
15 Const<NU>: DimName,
16 DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
17{
18 if i == j {
19 return;
20 }
21
22 let n_givens: usize;
23 if i > j {
24 n_givens = i - j;
25 for l in 0..NU {
26 let tmp = r[(l, j)];
27 for k in j..i {
28 r[(l, k)] = r[(l, k + 1)];
29 }
30 r[(l, i)] = tmp;
31 }
32 } else {
33 n_givens = j - i;
34 for l in 0..NU {
35 let tmp = r[(l, j)];
36 for k in (i..j).rev() {
37 r[(l, k + 1)] = r[(l, k)];
38 }
39 r[(l, i)] = tmp;
40 }
41 }
42
43 for k in 0..n_givens {
44 let (j1, i1) = if j > i {
45 (j - k - 1, i)
46 } else {
47 (j + k, j + k)
48 };
49 let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
50 givens_left_apply(r, c, s, j1, j1 + 1, NU);
51 givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
52 let t1 = qtb[j1];
54 let t2 = qtb[j1 + 1];
55 qtb[j1] = c * t1 - s * t2;
56 qtb[j1 + 1] = s * t1 + c * t2;
57 }
58}
59
60fn backward_tri_solve<const NU: usize>(
61 r: &OMatrix<f32, Const<NU>, Const<NU>>,
62 b: &[f32; NU],
63 x: &mut [f32; NU],
64 n: usize,
65) where
66 Const<NU>: DimName,
67 DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
68{
69 if n == 0 {
70 return;
71 }
72 x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
73 for i in (0..n.saturating_sub(1)).rev() {
74 let mut s = 0.0f32;
75 for j in (i + 1)..n {
76 s += r[(i, j)] * x[j];
77 }
78 x[i] = (b[i] - s) / r[(i, i)];
79 }
80}
81
82#[allow(clippy::needless_range_loop)] pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
94 a: &Mat<NC, NU>,
95 b: &VecN<NC>,
96 umin: &VecN<NU>,
97 umax: &VecN<NU>,
98 us: &mut VecN<NU>,
99 ws: &mut [i8; NU],
100 imax: usize,
101) -> SolverStats
102where
103 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
104 Const<NU>: DimName,
105 Const<NV>: DimName,
106 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
107 + Allocator<Const<NC>, Const<NC>>
108 + Allocator<Const<NU>, Const<NU>>
109 + Allocator<Const<NC>>
110 + Allocator<Const<NU>>,
111{
112 debug_assert_eq!(NC, NU + NV);
113 solve_cls(a, b, umin, umax, us, ws, imax)
114}
115
/// Active-set solver for a bound-constrained linear least-squares problem:
/// fit `a · u ≈ b` subject to `umin ≤ u ≤ umax`.
///
/// The columns of `a` are permuted so that free variables come first, and a
/// thin QR factorisation (`q`: `NC × NU`, `r`: `NU × NU`) of the permuted
/// matrix is formed once. Each time a variable enters or leaves the working
/// set, the factorisation and `qtb = qᵀ·b` are updated with `qr_shift`
/// instead of being refactorised.
///
/// Each iteration does the following:
/// 1. Solve the least-squares subproblem in the free variables, with bounded
///    variables held at their bound.
/// 2. If that solution is feasible, accept it. Then either finish (all
///    multiplier estimates `<= CONSTR_TOL`) or release the bounded variable
///    with the largest multiplier.
/// 3. If the solution is infeasible, step from `us` towards it as far as the
///    bounds allow and pin the blocking variable to the bound it hit.
///
/// # Arguments
/// * `a` – `NC × NU` system matrix.
/// * `b` – right-hand side of length `NC`.
/// * `umin`, `umax` – per-variable lower and upper bounds.
/// * `us` – starting point on entry. Free entries are clamped into
///   `[umin, umax]` and bounded entries are snapped to their bound. Holds the
///   last iterate on return.
/// * `ws` – working set, updated in place: `0` = free, `> 0` = held at
///   `umax`, `< 0` = held at `umin`.
/// * `imax` – iteration limit; `0` selects the default of 100.
///
/// # Returns
/// [`SolverStats`] containing:
/// * `exit_code`:
///   * `Success`.
///   * `IterLimit` when `imax` iterations ran without converging.
///   * `NanFoundQ` when the free-variable back substitution produced NaN.
///   * `NanFoundUs` when a step update of `us` produced NaN.
/// * `iterations`: the number of iterations performed.
/// * `n_free`: the number of free variables at exit.
// NOTE(review): index loops are presumably kept because most of them address
// several parallel arrays through `perm`; hence the clippy allow below.
#[allow(clippy::needless_range_loop)] pub fn solve_cls<const NU: usize, const NC: usize>(
    a: &Mat<NC, NU>,
    b: &VecN<NC>,
    umin: &VecN<NU>,
    umax: &VecN<NU>,
    us: &mut VecN<NU>,
    ws: &mut [i8; NU],
    imax: usize,
) -> SolverStats
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>,
{
    // An iteration limit of 0 means "use the default".
    let imax = if imax == 0 { 100 } else { imax };

    // Make the starting point consistent with the working set: clamp free
    // variables into their box, snap bounded ones onto their active bound.
    for i in 0..NU {
        if ws[i] == 0 {
            if us[i] > umax[i] {
                us[i] = umax[i];
            } else if us[i] < umin[i] {
                us[i] = umin[i];
            }
        } else {
            us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
        }
    }

    // Build the column permutation: free variables in perm[..n_free],
    // bounded variables in perm[n_free..], each group in ascending index order.
    let mut perm = [0usize; NU];
    let mut n_free: usize = 0;
    for i in 0..NU {
        if ws[i] == 0 {
            perm[n_free] = i;
            n_free += 1;
        }
    }
    let mut i_bnd: usize = 0;
    for i in 0..NU {
        if ws[i] != 0 {
            perm[n_free + i_bnd] = i;
            i_bnd += 1;
        }
    }

    // Factorise the column-permuted matrix once; later working-set changes
    // update q/r/qtb incrementally via `qr_shift`.
    let mut a_perm: Mat<NC, NU> = Mat::zeros();
    for j in 0..NU {
        for i in 0..NC {
            a_perm[(i, j)] = a[(i, perm[j])];
        }
    }
    let qr_decomp = a_perm.qr();
    // NOTE(review): after `qtb` is formed below, `q` is only passed to
    // `qr_shift` (which updates it) and is never read again in this function.
    // Keeping it in sync may be unnecessary work. Confirm before removing.
    let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
    let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();

    // qtb = qᵀ · b (length NU, in permuted variable order).
    let mut qtb = [0.0f32; NU];
    for i in 0..NU {
        let mut s = 0.0f32;
        for j in 0..NC {
            s += q[(j, i)] * b[j];
        }
        qtb[i] = s;
    }

    // Plain-array copies of the bounds for `check_limits_tol`.
    let mut umin_arr = [0.0f32; NU];
    let mut umax_arr = [0.0f32; NU];
    for i in 0..NU {
        umin_arr[i] = umin[i];
        umax_arr[i] = umax[i];
    }
    // Per-variable violation flags filled by `check_limits_tol`. From the usage
    // below: -1 appears to mean "below umin" and 1 "above umax" (TODO confirm).
    let mut w_temp = [0i8; NU];

    // Candidate point: subproblem solution for free variables, current
    // (bound) values for the rest. Indexed by original variable index.
    let mut z = [0.0f32; NU];
    // Stays `IterLimit` unless a `break` below sets another code.
    let mut exit_code = ExitCode::IterLimit;

    // `iter` is incremented before the limit test, so it overshoots by one when
    // the loop exhausts; this is corrected after the loop.
    let mut iter: usize = 0;
    while {
        iter += 1;
        iter <= imax
    } {
        // Right-hand side of the free subproblem: qtb restricted to the free
        // rows, minus the contribution of bounded columns held at their bound.
        let mut c = [0.0f32; NU];
        c[..n_free].copy_from_slice(&qtb[..n_free]);

        for i in 0..n_free {
            for j in 0..(NU - n_free) {
                let pi = perm[n_free + j];
                let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
                c[i] -= r[(i, n_free + j)] * ub;
            }
        }

        // Solve r[..n_free, ..n_free] · q_sol = c for the free variables.
        let mut q_sol = [0.0f32; NU];
        backward_tri_solve(&r, &c, &mut q_sol, n_free);

        // Scatter the solution back to original indices, aborting on NaN
        // (e.g. from a zero pivot).
        let mut nan_found = false;
        for i in 0..n_free {
            if f32::is_nan(q_sol[i]) {
                nan_found = true;
                break;
            }
            z[perm[i]] = q_sol[i];
        }
        if nan_found {
            exit_code = ExitCode::NanFoundQ;
            break;
        }
        for i in n_free..NU {
            z[perm[i]] = us[perm[i]];
        }

        let n_violated =
            check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));

        if n_violated == 0 {
            // Feasible: accept the subproblem solution.
            for i in 0..n_free {
                us[perm[i]] = z[perm[i]];
            }

            // No bounded variables left, so nothing can be released: done.
            if n_free == NU {
                exit_code = ExitCode::Success;
                break;
            }

            // d = qtb − r · u on the bounded rows (permuted order). Only j >= i
            // is summed, relying on r being upper triangular.
            let mut d = [0.0f32; NU];
            d[n_free..NU].copy_from_slice(&qtb[n_free..NU]);
            for i in n_free..NU {
                for j in i..NU {
                    d[i] -= r[(i, j)] * us[perm[j]];
                }
            }

            // Signed multiplier estimate for each bounded variable:
            // lam_i = -ws_i · (rᵀ d)_i. The largest one selects the candidate to
            // release; `f_free` is its offset past `n_free` in `perm`.
            let mut f_free: usize = 0;
            let mut maxlam: f32 = f32::NEG_INFINITY;
            for i in n_free..NU {
                let mut lam = 0.0f32;
                for j in n_free..=i {
                    lam += r[(j, i)] * d[j];
                }
                lam *= -f32::from(ws[perm[i]]);
                if lam > maxlam {
                    maxlam = lam;
                    f_free = i - n_free;
                }
            }

            // No multiplier exceeds the tolerance: the current point is accepted.
            if maxlam <= CONSTR_TOL {
                exit_code = ExitCode::Success;
                break;
            }

            // Release the selected variable: move its column to the end of the
            // free block, mark it free, and mirror the move in `perm`.
            qr_shift(&mut q, &mut r, &mut qtb, n_free, n_free + f_free);
            ws[perm[n_free + f_free]] = 0;
            let last_val = perm[n_free + f_free];
            for i in (1..=f_free).rev() {
                perm[n_free + i] = perm[n_free + i - 1];
            }
            perm[n_free] = last_val;
            n_free += 1;
        } else {
            // Infeasible: find the largest fraction `alpha` of the step
            // us -> z that keeps every free variable inside its bounds. `i_a`
            // is the blocking variable, `f_bound` its position in `perm`, and
            // `i_s` the side (-1 lower / 1 upper) it hits.
            let mut alpha: f32 = f32::INFINITY;
            let mut i_a: usize = 0;
            let mut f_bound: usize = 0;
            let mut i_s: i8 = 0;

            for f in 0..n_free {
                let ii = perm[f];
                let (tmp, ts) = if w_temp[ii] == -1 {
                    ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
                } else if w_temp[ii] == 1 {
                    ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
                } else {
                    continue;
                };
                if tmp < alpha {
                    alpha = tmp;
                    i_a = ii;
                    f_bound = f;
                    i_s = ts;
                }
            }

            // Take the step. The blocking variable is set exactly to its bound
            // rather than by interpolation, to avoid rounding off the bound.
            // Stops at the first NaN, leaving later entries un-stepped.
            let mut nan_found = false;
            for i in 0..NU {
                if i == i_a {
                    us[i] = if i_s == 1 { umax[i] } else { umin[i] };
                } else {
                    us[i] += alpha * (z[i] - us[i]);
                }
                if f32::is_nan(us[i]) {
                    nan_found = true;
                    break;
                }
            }
            if nan_found {
                exit_code = ExitCode::NanFoundUs;
                break;
            }

            // Bind the blocking variable: move its column to the last free
            // slot, record its side in `ws`, and shrink the free block so that
            // it becomes the first bounded entry of `perm`.
            qr_shift(&mut q, &mut r, &mut qtb, n_free - 1, f_bound);
            ws[i_a] = i_s;
            let first_val = perm[f_bound];
            for i in 0..(n_free - f_bound - 1) {
                perm[f_bound + i] = perm[f_bound + i + 1];
            }
            n_free -= 1;
            perm[n_free] = first_val;
        }
    }
    // Undo the overshoot from the final failing loop test so that `iterations`
    // equals `imax` on an iteration-limit exit.
    if exit_code == ExitCode::IterLimit {
        iter -= 1;
    }
    SolverStats {
        exit_code,
        iterations: iter,
        n_free,
    }
}