1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix};
2
3use crate::linalg::check_limits_tol;
4use crate::types::{ExitCode, MatA, SolverStats, VecN, CONSTR_TOL};
5
6#[inline]
11fn givens(a: f32, b: f32) -> (f32, f32) {
12 let h = libm::hypotf(a, b);
13 let sigma = 1.0 / h;
14 (sigma * a, -(sigma * b))
15}
16
17#[inline]
18fn givens_left_apply<const N: usize>(
19 r: &mut OMatrix<f32, Const<N>, Const<N>>,
20 c: f32,
21 s: f32,
22 row1: usize,
23 row2: usize,
24 n_cols: usize,
25) where
26 Const<N>: DimName,
27 DefaultAllocator: Allocator<Const<N>, Const<N>>,
28{
29 for col in 0..n_cols {
30 let r1 = r[(row1, col)];
31 let r2 = r[(row2, col)];
32 r[(row1, col)] = c * r1 - s * r2;
33 r[(row2, col)] = s * r1 + c * r2;
34 }
35}
36
37#[inline]
38fn givens_right_apply_t<const R: usize, const C: usize>(
39 q: &mut OMatrix<f32, Const<R>, Const<C>>,
40 c: f32,
41 s: f32,
42 col1: usize,
43 col2: usize,
44 n_rows: usize,
45) where
46 Const<R>: DimName,
47 Const<C>: DimName,
48 DefaultAllocator: Allocator<Const<R>, Const<C>>,
49{
50 for i in 0..n_rows {
51 let q1 = q[(i, col1)];
52 let q2 = q[(i, col2)];
53 q[(i, col1)] = c * q1 - s * q2;
54 q[(i, col2)] = s * q1 + c * q2;
55 }
56}
57
58fn qr_shift<const NC: usize, const NU: usize>(
59 q: &mut OMatrix<f32, Const<NC>, Const<NU>>,
60 r: &mut OMatrix<f32, Const<NU>, Const<NU>>,
61 i: usize,
62 j: usize,
63) where
64 Const<NC>: DimName,
65 Const<NU>: DimName,
66 DefaultAllocator: Allocator<Const<NC>, Const<NU>> + Allocator<Const<NU>, Const<NU>>,
67{
68 if i == j {
69 return;
70 }
71
72 let n_givens: usize;
73 if i > j {
74 n_givens = i - j;
75 for l in 0..NU {
76 let tmp = r[(l, j)];
77 for k in j..i {
78 r[(l, k)] = r[(l, k + 1)];
79 }
80 r[(l, i)] = tmp;
81 }
82 } else {
83 n_givens = j - i;
84 for l in 0..NU {
85 let tmp = r[(l, j)];
86 for k in (i..j).rev() {
87 r[(l, k + 1)] = r[(l, k)];
88 }
89 r[(l, i)] = tmp;
90 }
91 }
92
93 for k in 0..n_givens {
94 let (j1, i1) = if j > i {
95 (j - k - 1, i)
96 } else {
97 (j + k, j + k)
98 };
99 let (c, s) = givens(r[(j1, i1)], r[(j1 + 1, i1)]);
100 givens_left_apply(r, c, s, j1, j1 + 1, NU);
101 givens_right_apply_t(q, c, s, j1, j1 + 1, NC);
102 }
103}
104
105fn backward_tri_solve<const NU: usize>(
106 r: &OMatrix<f32, Const<NU>, Const<NU>>,
107 b: &[f32; NU],
108 x: &mut [f32; NU],
109 n: usize,
110) where
111 Const<NU>: DimName,
112 DefaultAllocator: Allocator<Const<NU>, Const<NU>>,
113{
114 if n == 0 {
115 return;
116 }
117 x[n - 1] = b[n - 1] / r[(n - 1, n - 1)];
118 for i in (0..n.saturating_sub(1)).rev() {
119 let mut s = 0.0f32;
120 for j in (i + 1)..n {
121 s += r[(i, j)] * x[j];
122 }
123 x[i] = (b[i] - s) / r[(i, i)];
124 }
125}
126
127#[allow(clippy::needless_range_loop)] pub fn solve<const NU: usize, const NV: usize, const NC: usize>(
139 a: &MatA<NC, NU>,
140 b: &VecN<NC>,
141 umin: &VecN<NU>,
142 umax: &VecN<NU>,
143 us: &mut VecN<NU>,
144 ws: &mut [i8; NU],
145 imax: usize,
146) -> SolverStats
147where
148 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
149 Const<NU>: DimName,
150 Const<NV>: DimName,
151 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
152 + Allocator<Const<NC>, Const<NC>>
153 + Allocator<Const<NU>, Const<NU>>
154 + Allocator<Const<NC>>
155 + Allocator<Const<NU>>,
156{
157 debug_assert_eq!(NC, NU + NV);
158 let imax = if imax == 0 { 100 } else { imax };
159
160 for i in 0..NU {
161 if ws[i] == 0 {
162 if us[i] > umax[i] {
163 us[i] = umax[i];
164 } else if us[i] < umin[i] {
165 us[i] = umin[i];
166 }
167 } else {
168 us[i] = if ws[i] > 0 { umax[i] } else { umin[i] };
169 }
170 }
171
172 let mut perm = [0usize; NU];
174 let mut n_free: usize = 0;
175 for i in 0..NU {
176 if ws[i] == 0 {
177 perm[n_free] = i;
178 n_free += 1;
179 }
180 }
181 let mut i_bnd: usize = 0;
182 for i in 0..NU {
183 if ws[i] != 0 {
184 perm[n_free + i_bnd] = i;
185 i_bnd += 1;
186 }
187 }
188
189 let mut a_perm: MatA<NC, NU> = MatA::zeros();
191 for j in 0..NU {
192 for i in 0..NC {
193 a_perm[(i, j)] = a[(i, perm[j])];
194 }
195 }
196 let qr_decomp = a_perm.qr();
197 let mut q: OMatrix<f32, Const<NC>, Const<NU>> = qr_decomp.q();
198 let mut r: OMatrix<f32, Const<NU>, Const<NU>> = qr_decomp.r();
199
200 let mut z = [0.0f32; NU];
201 let mut exit_code = ExitCode::IterLimit;
202
203 let mut iter: usize = 0;
204 while {
205 iter += 1;
206 iter <= imax
207 } {
208 let mut c = [0.0f32; NU];
209 for i in 0..n_free {
210 let mut s = 0.0f32;
211 for j in 0..NC {
212 s += q[(j, i)] * b[j];
213 }
214 c[i] = s;
215 }
216
217 for i in 0..n_free {
218 for j in 0..(NU - n_free) {
219 let pi = perm[n_free + j];
220 let ub = if ws[pi] > 0 { umax[pi] } else { umin[pi] };
221 c[i] -= r[(i, n_free + j)] * ub;
222 }
223 }
224
225 let mut q_sol = [0.0f32; NU];
226 backward_tri_solve(&r, &c, &mut q_sol, n_free);
227
228 let mut nan_found = false;
229 for i in 0..n_free {
230 if f32::is_nan(q_sol[i]) {
231 nan_found = true;
232 break;
233 }
234 z[perm[i]] = q_sol[i];
235 }
236 if nan_found {
237 exit_code = ExitCode::NanFoundQ;
238 break;
239 }
240 for i in n_free..NU {
241 z[perm[i]] = us[perm[i]];
242 }
243
244 let mut umin_arr = [0.0f32; NU];
245 let mut umax_arr = [0.0f32; NU];
246 for i in 0..NU {
247 umin_arr[i] = umin[i];
248 umax_arr[i] = umax[i];
249 }
250 let mut w_temp = [0i8; NU];
251 let n_violated =
252 check_limits_tol(n_free, &z, &umin_arr, &umax_arr, &mut w_temp, Some(&perm));
253
254 if n_violated == 0 {
255 for i in 0..n_free {
256 us[perm[i]] = z[perm[i]];
257 }
258
259 if n_free == NU {
260 exit_code = ExitCode::Success;
261 break;
262 }
263
264 let mut d = [0.0f32; NU];
265 for i in n_free..NU {
266 let mut s = 0.0f32;
267 for j in 0..NC {
268 s += q[(j, i)] * b[j];
269 }
270 d[i] = s;
271 }
272 for i in n_free..NU {
273 for j in i..NU {
274 d[i] -= r[(i, j)] * us[perm[j]];
275 }
276 }
277
278 let mut f_free: usize = 0;
279 let mut maxlam: f32 = f32::NEG_INFINITY;
280 for i in n_free..NU {
281 let mut lam = 0.0f32;
282 for j in n_free..=i {
283 lam += r[(j, i)] * d[j];
284 }
285 lam *= -f32::from(ws[perm[i]]);
286 if lam > maxlam {
287 maxlam = lam;
288 f_free = i - n_free;
289 }
290 }
291
292 if maxlam <= CONSTR_TOL {
293 exit_code = ExitCode::Success;
294 break;
295 }
296
297 qr_shift(&mut q, &mut r, n_free, n_free + f_free);
298 ws[perm[n_free + f_free]] = 0;
299 let last_val = perm[n_free + f_free];
300 for i in (1..=f_free).rev() {
301 perm[n_free + i] = perm[n_free + i - 1];
302 }
303 perm[n_free] = last_val;
304 n_free += 1;
305 } else {
306 let mut alpha: f32 = f32::INFINITY;
307 let mut i_a: usize = 0;
308 let mut f_bound: usize = 0;
309 let mut i_s: i8 = 0;
310
311 for f in 0..n_free {
312 let ii = perm[f];
313 let (tmp, ts) = if w_temp[ii] == -1 {
314 ((us[ii] - umin[ii]) / (us[ii] - z[ii]), -1i8)
315 } else if w_temp[ii] == 1 {
316 ((umax[ii] - us[ii]) / (z[ii] - us[ii]), 1i8)
317 } else {
318 continue;
319 };
320 if tmp < alpha {
321 alpha = tmp;
322 i_a = ii;
323 f_bound = f;
324 i_s = ts;
325 }
326 }
327
328 let mut nan_found = false;
329 for i in 0..NU {
330 if i == i_a {
331 us[i] = if i_s == 1 { umax[i] } else { umin[i] };
332 } else {
333 us[i] += alpha * (z[i] - us[i]);
334 }
335 if f32::is_nan(us[i]) {
336 nan_found = true;
337 break;
338 }
339 }
340 if nan_found {
341 exit_code = ExitCode::NanFoundUs;
342 break;
343 }
344
345 qr_shift(&mut q, &mut r, n_free - 1, f_bound);
346 ws[i_a] = i_s;
347 let first_val = perm[f_bound];
348 for i in 0..(n_free - f_bound - 1) {
349 perm[f_bound + i] = perm[f_bound + i + 1];
350 }
351 n_free -= 1;
352 perm[n_free] = first_val;
353 }
354 }
355 if exit_code == ExitCode::IterLimit {
356 iter -= 1;
357 }
358 SolverStats {
359 exit_code,
360 iterations: iter,
361 n_free,
362 }
363}