1use pounce_common::types::Number;
23
24use crate::block_solve::{lu_factor_partial_pivot, lu_solve, BlockSolveError};
25
26#[derive(Debug, Default, Clone)]
29pub struct ReductionFrame {
30 pub fixed_vars: Vec<usize>,
32 pub fixed_values: Vec<Number>,
34 pub dropped_rows: Vec<usize>,
37 pub var_map: Vec<Option<usize>>,
40 pub row_map: Vec<Option<usize>>,
42}
43
44impl ReductionFrame {
45 pub fn new(
48 n_vars: usize,
49 n_rows: usize,
50 fixed_vars: Vec<usize>,
51 fixed_values: Vec<Number>,
52 dropped_rows: Vec<usize>,
53 ) -> Self {
54 assert_eq!(
55 fixed_vars.len(),
56 fixed_values.len(),
57 "fixed_vars and fixed_values must be the same length"
58 );
59 assert_eq!(
60 fixed_vars.len(),
61 dropped_rows.len(),
62 "fixed_vars and dropped_rows must be the same length (square block)"
63 );
64
65 let mut is_fixed_var = vec![false; n_vars];
68 for &i in &fixed_vars {
69 is_fixed_var[i] = true;
70 }
71 let mut is_dropped_row = vec![false; n_rows];
72 for &i in &dropped_rows {
73 is_dropped_row[i] = true;
74 }
75
76 let mut var_map = vec![None; n_vars];
77 let mut next_reduced = 0;
78 for (i, slot) in var_map.iter_mut().enumerate().take(n_vars) {
79 if is_fixed_var[i] {
80 continue;
81 }
82 *slot = Some(next_reduced);
83 next_reduced += 1;
84 }
85
86 let mut row_map = vec![None; n_rows];
87 let mut next_reduced_row = 0;
88 for (i, slot) in row_map.iter_mut().enumerate().take(n_rows) {
89 if is_dropped_row[i] {
90 continue;
91 }
92 *slot = Some(next_reduced_row);
93 next_reduced_row += 1;
94 }
95
96 Self {
97 fixed_vars,
98 fixed_values,
99 dropped_rows,
100 var_map,
101 row_map,
102 }
103 }
104
105 pub fn n_full_vars(&self) -> usize {
106 self.var_map.len()
107 }
108
109 pub fn n_full_rows(&self) -> usize {
110 self.row_map.len()
111 }
112
113 pub fn n_reduced_vars(&self) -> usize {
114 self.n_full_vars() - self.fixed_vars.len()
115 }
116
117 pub fn n_reduced_rows(&self) -> usize {
118 self.n_full_rows() - self.dropped_rows.len()
119 }
120
121 pub fn project_x(&self, x_full: &[Number]) -> Vec<Number> {
124 assert_eq!(x_full.len(), self.n_full_vars());
125 self.var_map
126 .iter()
127 .zip(x_full.iter())
128 .filter_map(|(slot, &v)| slot.map(|_| v))
129 .collect()
130 }
131
132 pub fn lift_x(&self, x_reduced: &[Number]) -> Vec<Number> {
135 assert_eq!(x_reduced.len(), self.n_reduced_vars());
136 let mut out = vec![0.0; self.n_full_vars()];
137 for (i, slot) in self.var_map.iter().enumerate() {
138 if let Some(r) = slot {
139 out[i] = x_reduced[*r];
140 }
141 }
142 for (k, &i) in self.fixed_vars.iter().enumerate() {
143 out[i] = self.fixed_values[k];
144 }
145 out
146 }
147
148 pub fn project_lambda(&self, lambda_full: &[Number]) -> Vec<Number> {
150 assert_eq!(lambda_full.len(), self.n_full_rows());
151 self.row_map
152 .iter()
153 .zip(lambda_full.iter())
154 .filter_map(|(slot, &v)| slot.map(|_| v))
155 .collect()
156 }
157
158 pub fn lift_lambda(&self, lambda_reduced: &[Number]) -> Vec<Number> {
162 assert_eq!(lambda_reduced.len(), self.n_reduced_rows());
163 let mut out = vec![0.0; self.n_full_rows()];
164 for (i, slot) in self.row_map.iter().enumerate() {
165 if let Some(r) = slot {
166 out[i] = lambda_reduced[*r];
167 }
168 }
169 out
170 }
171
172 pub fn recover_dropped_multipliers(
206 &self,
207 grad_f: &[Number],
208 jac_full_row_major: &[Number],
209 lambda_full: &[Number],
210 ) -> Result<Vec<Number>, BlockSolveError> {
211 let n_vars = self.n_full_vars();
212 let n_rows = self.n_full_rows();
213 let k = self.fixed_vars.len();
214 assert_eq!(grad_f.len(), n_vars, "grad_f length mismatch");
215 assert_eq!(
216 jac_full_row_major.len(),
217 n_rows * n_vars,
218 "jac_full_row_major length mismatch"
219 );
220 assert_eq!(lambda_full.len(), n_rows, "lambda_full length mismatch");
221
222 if k == 0 {
223 return Ok(Vec::new());
224 }
225
226 let mut matrix = vec![0.0; k * k];
233 for (i_idx, &i) in self.fixed_vars.iter().enumerate() {
234 for (j_idx, &dr) in self.dropped_rows.iter().enumerate() {
235 matrix[i_idx * k + j_idx] = jac_full_row_major[dr * n_vars + i];
236 }
237 }
238
239 let mut rhs = vec![0.0; k];
240 for (i_idx, &i) in self.fixed_vars.iter().enumerate() {
241 let mut sum = 0.0;
242 for r in 0..n_rows {
243 if self.row_map[r].is_none() {
244 continue;
246 }
247 sum += jac_full_row_major[r * n_vars + i] * lambda_full[r];
248 }
249 rhs[i_idx] = grad_f[i] - sum;
250 }
251
252 let piv = lu_factor_partial_pivot(&mut matrix, k).map_err(|_| BlockSolveError::Singular)?;
253 lu_solve(&matrix, &piv, &mut rhs, k);
254 Ok(rhs)
255 }
256}
257
258#[derive(Debug, Default, Clone)]
262pub struct ReductionStack {
263 frames: Vec<ReductionFrame>,
264}
265
266impl ReductionStack {
267 pub fn is_empty(&self) -> bool {
269 self.frames.is_empty()
270 }
271
272 pub fn len(&self) -> usize {
274 self.frames.len()
275 }
276
277 pub fn push(&mut self, frame: ReductionFrame) {
279 self.frames.push(frame);
280 }
281
282 pub fn top(&self) -> Option<&ReductionFrame> {
284 self.frames.last()
285 }
286
287 pub fn iter_top_down(&self) -> impl Iterator<Item = &ReductionFrame> {
291 self.frames.iter().rev()
292 }
293
294 pub fn iter_bottom_up(&self) -> impl Iterator<Item = &ReductionFrame> {
298 self.frames.iter()
299 }
300}
301
// Unit tests for ReductionFrame index bookkeeping, multiplier recovery and
// ReductionStack ordering.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixing var 1 and dropping row 0 leaves vars {0, 2, 3} -> {0, 1, 2}
    // and rows {1, 2} -> {0, 1}.
    #[test]
    fn frame_new_builds_maps_correctly() {
        let frame = ReductionFrame::new(4, 3, vec![1], vec![42.0], vec![0]);
        assert_eq!(frame.var_map, vec![Some(0), None, Some(1), Some(2)]);
        assert_eq!(frame.row_map, vec![None, Some(0), Some(1)]);
        assert_eq!(frame.n_reduced_vars(), 3);
        assert_eq!(frame.n_reduced_rows(), 2);
    }

    #[test]
    fn frame_project_x_drops_fixed() {
        let frame = ReductionFrame::new(3, 1, vec![1], vec![20.0], vec![0]);
        let x_full = [10.0, 20.0, 30.0];
        assert_eq!(frame.project_x(&x_full), vec![10.0, 30.0]);
    }

    // lift_x must write the stored fixed value (20.0) into the fixed slot.
    #[test]
    fn frame_lift_x_splices_fixed_values() {
        let frame = ReductionFrame::new(3, 1, vec![1], vec![20.0], vec![0]);
        let x_reduced = [10.0, 30.0];
        assert_eq!(frame.lift_x(&x_reduced), vec![10.0, 20.0, 30.0]);
    }

    // Round trip is exact because x_full's fixed entries equal fixed_values.
    #[test]
    fn frame_project_lift_x_roundtrip() {
        let frame = ReductionFrame::new(4, 2, vec![0, 2], vec![1.0, 9.0], vec![0, 1]);
        let x_full = [1.0, 5.0, 9.0, 7.0];
        let reduced = frame.project_x(&x_full);
        let lifted = frame.lift_x(&reduced);
        assert_eq!(lifted, x_full);
    }

    #[test]
    fn frame_project_lambda_drops_dropped() {
        let frame = ReductionFrame::new(3, 3, vec![1], vec![20.0], vec![0]);
        let lambda_full = [1.0, 2.0, 3.0];
        assert_eq!(frame.project_lambda(&lambda_full), vec![2.0, 3.0]);
    }

    #[test]
    fn frame_lift_lambda_zeros_dropped() {
        let frame = ReductionFrame::new(3, 3, vec![1], vec![20.0], vec![0]);
        let lambda_reduced = [2.0, 3.0];
        assert_eq!(frame.lift_lambda(&lambda_reduced), vec![0.0, 2.0, 3.0]);
    }

    // 1 var, 1 row, J = [1]: stationarity 4 = 1 * lambda gives lambda = 4.
    #[test]
    fn recover_multipliers_singleton_linear() {
        let frame = ReductionFrame::new(1, 1, vec![0], vec![3.0], vec![0]);
        let lam = frame
            .recover_dropped_multipliers(&[4.0], &[1.0], &[0.0])
            .unwrap();
        assert_eq!(lam.len(), 1);
        assert!((lam[0] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn recover_multipliers_2x2_linear() {
        let frame = ReductionFrame::new(2, 2, vec![0, 1], vec![1.0, 2.0], vec![0, 1]);
        // J (row-major): row 0 = [1, 0], row 1 = [1, 1].
        let jac = [1.0, 0.0, 1.0, 1.0];
        // Stationarity: var 0: l0 + l1 = 2; var 1: l1 = 5  =>  l = (-3, 5).
        let grad_f = [2.0, 5.0];
        let lam = frame
            .recover_dropped_multipliers(&grad_f, &jac, &[0.0, 0.0])
            .unwrap();
        assert!((lam[0] - (-3.0)).abs() < 1e-12, "λ0 was {}", lam[0]);
        assert!((lam[1] - 5.0).abs() < 1e-12, "λ1 was {}", lam[1]);
    }

    // Kept row 1 (lambda = 0.5) contributes to var 0's stationarity:
    // 2 * l0 = 10 - 4 * 0.5 = 8  =>  l0 = 4.
    #[test]
    fn recover_multipliers_with_kept_rows() {
        let frame = ReductionFrame::new(2, 2, vec![0], vec![1.0], vec![0]);
        let jac = [2.0, 3.0, 4.0, 5.0];
        let grad_f = [10.0, 0.0];
        let lambda_full = [0.0, 0.5];
        let lam = frame
            .recover_dropped_multipliers(&grad_f, &jac, &lambda_full)
            .unwrap();
        assert_eq!(lam.len(), 1);
        assert!((lam[0] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn recover_multipliers_singular_block_jacobian() {
        let frame = ReductionFrame::new(2, 2, vec![0, 1], vec![0.0, 0.0], vec![0, 1]);
        // Row 1 is 2 * row 0, so the 2x2 block is rank 1.
        let jac = [1.0, 2.0, 2.0, 4.0];
        let grad_f = [1.0, 2.0];
        let err = frame
            .recover_dropped_multipliers(&grad_f, &jac, &[0.0, 0.0])
            .unwrap_err();
        assert_eq!(err, BlockSolveError::Singular);
    }

    // Nothing fixed or dropped: recovery short-circuits to an empty vector.
    #[test]
    fn recover_multipliers_empty_frame() {
        let frame = ReductionFrame::new(2, 2, vec![], vec![], vec![]);
        let lam = frame
            .recover_dropped_multipliers(&[0.0; 2], &[0.0; 4], &[0.0; 2])
            .unwrap();
        assert!(lam.is_empty());
    }

    // After recovery, the full stationarity residual grad_f[i] - sum_r J[r,i] * lambda[r]
    // must vanish for every fixed variable.
    #[test]
    fn kkt_residual_after_recovery_to_1e_minus_12() {
        let frame = ReductionFrame::new(3, 3, vec![0, 1], vec![2.0 / 3.0, 5.0 / 3.0], vec![0, 1]);
        // Row-major 3x3 J. The dropped rows 0 and 1 have a zero in column 2,
        // so only x0 and x1 appear in them.
        let jac = [
            2.0, 1.0, 0.0, //
            1.0, -1.0, 0.0, //
            1.0, 1.0, 1.0, //
        ];
        let y_star = 8.0 / 3.0;
        // grad_f[2] = 2 * y_star, and the kept row 2 gets lambda = 2 * y_star,
        // so the free variable x2 is stationary (J[2][2] = 1).
        let grad_f = [10.0, 4.0, 2.0 * y_star];
        let lambda_kept_2 = 2.0 * y_star;
        let lambda_full = [0.0, 0.0, lambda_kept_2];

        let lam_dropped = frame
            .recover_dropped_multipliers(&grad_f, &jac, &lambda_full)
            .unwrap();
        // Splice the recovered multipliers back into their full-space rows.
        let mut lambda_recovered = lambda_full;
        for (k, &r) in frame.dropped_rows.iter().enumerate() {
            lambda_recovered[r] = lam_dropped[k];
        }
        for &i in &frame.fixed_vars {
            let mut s = grad_f[i];
            for r in 0..3 {
                s -= jac[r * 3 + i] * lambda_recovered[r];
            }
            assert!(s.abs() < 1e-12, "stationarity at var {i} = {s}");
        }
    }

    // Deterministic 64-bit LCG (Knuth's MMIX multiplier/increment). It keeps the
    // fuzz test reproducible without an external RNG crate.
    struct FuzzRng(u64);
    impl FuzzRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        // Advances the state and returns its upper 32 bits (always < 2^32).
        fn next_u64(&mut self) -> u64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0 >> 32
        }
        // Uses the low 30 bits of a draw, scaled to the half-open range [-1, 1).
        fn unit(&mut self) -> Number {
            let raw = (self.next_u64() & 0x3fff_ffff) as Number;
            raw / (1u64 << 29) as Number - 1.0
        }
    }

    // Builds random square problems with a known lambda_star. It hides the
    // dropped-row multipliers and checks that recovery reproduces them.
    // NOTE: the order of RNG draws below defines the test cases. Do not reorder.
    #[test]
    fn frame_fuzz_recover_reproduces_synthetic_lambda() {
        let mut rng = FuzzRng::new(0xface_b00c_baad_f00d);

        for trial in 0..30 {
            // 2..=4 variables, the same number of rows, and 1..=n_vars fixed/dropped.
            let n_vars = 2 + (rng.next_u64() % 3) as usize;
            let n_rows = n_vars;
            let k = 1 + (rng.next_u64() % n_vars as u64) as usize;

            // Fisher-Yates shuffle, then take the first k as the fixed variables.
            let mut perm_v: Vec<usize> = (0..n_vars).collect();
            for i in (1..n_vars).rev() {
                let j = (rng.next_u64() as usize) % (i + 1);
                perm_v.swap(i, j);
            }
            let mut fixed_vars: Vec<usize> = perm_v[..k].to_vec();
            fixed_vars.sort_unstable();

            // Same for the dropped rows.
            let mut perm_r: Vec<usize> = (0..n_rows).collect();
            for i in (1..n_rows).rev() {
                let j = (rng.next_u64() as usize) % (i + 1);
                perm_r.swap(i, j);
            }
            let mut dropped_rows: Vec<usize> = perm_r[..k].to_vec();
            dropped_rows.sort_unstable();

            // Small entries in [-0.2, 0.2) plus 2.5 on the paired (dropped row,
            // fixed var) entries. These land on the diagonal of the k x k recovery
            // block, which makes it strictly diagonally dominant (at most 3
            // off-diagonals, each < 0.2) and therefore nonsingular.
            let mut jac = vec![0.0; n_rows * n_vars];
            for r in 0..n_rows {
                for c in 0..n_vars {
                    jac[r * n_vars + c] = 0.2 * rng.unit();
                }
            }
            for (&r, &c) in dropped_rows.iter().zip(fixed_vars.iter()) {
                jac[r * n_vars + c] += 2.5;
            }

            // Fixed-variable gradients are J^T lambda_star, so lambda_star satisfies
            // their stationarity exactly. Free-variable gradients are random; recovery
            // does not read them.
            let lambda_star: Vec<Number> = (0..n_rows).map(|_| rng.unit()).collect();
            let mut grad_f = vec![0.0; n_vars];
            let fixed_set: std::collections::BTreeSet<usize> = fixed_vars.iter().copied().collect();
            for i in 0..n_vars {
                if fixed_set.contains(&i) {
                    let mut s = 0.0;
                    for r in 0..n_rows {
                        s += jac[r * n_vars + i] * lambda_star[r];
                    }
                    grad_f[i] = s;
                } else {
                    grad_f[i] = rng.unit();
                }
            }

            // Hide the dropped-row multipliers (zeroed) so they must be recovered.
            let dropped_set: std::collections::BTreeSet<usize> =
                dropped_rows.iter().copied().collect();
            let mut lambda_given = vec![0.0; n_rows];
            for r in 0..n_rows {
                if !dropped_set.contains(&r) {
                    lambda_given[r] = lambda_star[r];
                }
            }

            let frame = ReductionFrame::new(
                n_vars,
                n_rows,
                fixed_vars.clone(),
                vec![0.0; k],
                dropped_rows.clone(),
            );

            let lam_dropped = frame
                .recover_dropped_multipliers(&grad_f, &jac, &lambda_given)
                .unwrap_or_else(|e| panic!("trial {trial}: {e:?}"));

            for (idx, &r) in dropped_rows.iter().enumerate() {
                let expected = lambda_star[r];
                let got = lam_dropped[idx];
                assert!(
                    (expected - got).abs() < 1e-10,
                    "trial {trial}: λ[{r}] expected {expected:.6}, got {got:.6}"
                );
            }
        }
    }

    #[test]
    fn reduction_stack_push_top_iter() {
        let mut stack = ReductionStack::default();
        assert!(stack.is_empty());
        let f1 = ReductionFrame::new(2, 2, vec![0], vec![1.0], vec![0]);
        let f2 = ReductionFrame::new(2, 2, vec![1], vec![2.0], vec![1]);
        stack.push(f1.clone());
        stack.push(f2.clone());
        assert_eq!(stack.len(), 2);
        let top = stack.top().expect("non-empty");
        assert_eq!(top.fixed_vars, f2.fixed_vars);
        // Top-down visits f2 (fixed var 1) before f1 (fixed var 0).
        let order: Vec<_> = stack.iter_top_down().map(|f| f.fixed_vars[0]).collect();
        assert_eq!(order, vec![1, 0]);
        let order_up: Vec<_> = stack.iter_bottom_up().map(|f| f.fixed_vars[0]).collect();
        assert_eq!(order_up, vec![0, 1]);
    }

    // Lifting zeroes dropped rows, so the lambda round trip is the identity only
    // on the reduced side (project . lift), not on the full side.
    #[test]
    fn frame_project_lift_lambda_roundtrip() {
        let frame = ReductionFrame::new(4, 3, vec![0, 2], vec![1.0, 9.0], vec![0, 1]);
        let lambda_full = [4.0, 5.0, 6.0];
        let reduced = frame.project_lambda(&lambda_full);
        assert_eq!(reduced, vec![6.0]);
        let lifted = frame.lift_lambda(&reduced);
        assert_eq!(lifted, vec![0.0, 0.0, 6.0]);
        let reduced_again = frame.project_lambda(&lifted);
        assert_eq!(reduced_again, reduced);
    }

    // Each frame is applied independently to the same full-space vectors. The
    // fixed entries of x_full_expected match each frame's fixed values, so x
    // round-trips exactly; lambda round-trips except for dropped rows, which lift to 0.
    #[test]
    fn reduction_stack_multi_frame_roundtrip() {
        let f1 = ReductionFrame::new(4, 4, vec![0], vec![10.0], vec![0]);
        let f2 = ReductionFrame::new(4, 4, vec![2], vec![30.0], vec![2]);
        let mut stack = ReductionStack::default();
        stack.push(f1.clone());
        stack.push(f2.clone());

        let x_full_expected = vec![10.0, 7.0, 30.0, 5.0];
        let lambda_full_expected = vec![0.0, 8.0, 0.0, 6.0];

        for frame in stack.iter_top_down() {
            let reduced_x = frame.project_x(&x_full_expected);
            let lifted_x = frame.lift_x(&reduced_x);
            assert_eq!(lifted_x, x_full_expected);
            let reduced_l = frame.project_lambda(&lambda_full_expected);
            let lifted_l = frame.lift_lambda(&reduced_l);
            for r in 0..4 {
                if frame.row_map[r].is_some() {
                    assert_eq!(lifted_l[r], lambda_full_expected[r]);
                } else {
                    assert_eq!(lifted_l[r], 0.0);
                }
            }
        }
    }
}