Skip to main content

pounce_presolve/
reduction_frame.rs

1//! Postsolve frame stack for the auxiliary-equality preprocessing
2//! pass.
3//!
4//! PR 7 of the auxiliary-presolve port (issue #53). A
5//! [`ReductionFrame`] captures one layer of variable + row
6//! elimination:
7//!
8//! - `fixed_vars` — block variables fixed by the block solve.
9//! - `fixed_values` — their values at the fixed point.
10//! - `dropped_rows` — equality rows used to determine them.
//! - `var_map` / `row_map` — index maps between full and reduced space.
//!
//! The headline method is
//! [`ReductionFrame::recover_dropped_multipliers`], which solves the
//! full-space KKT stationarity equations `∇f(x) − Jᵀλ = 0` (restricted
//! to the fixed variables) for the missing multipliers. Assumption
//! (matching ripopt v1): fixed variables are interior to their original
//! bounds at the optimum, so `z_l = z_u = 0` for them.
19//!
20//! ripopt anchor: `src/reduction_frame.rs:101-231`.
21
22use pounce_common::types::Number;
23
24use crate::block_solve::{lu_factor_partial_pivot, lu_solve, BlockSolveError};
25
26/// One layer of the postsolve stack. Built once per accepted block
27/// elimination by PR 8's orchestrator.
28#[derive(Debug, Default, Clone)]
29pub struct ReductionFrame {
30    /// Inner-variable indices fixed by this layer, in ascending order.
31    pub fixed_vars: Vec<usize>,
32    /// Their values at the block-solve fixed point.
33    pub fixed_values: Vec<Number>,
34    /// Inner equality-row indices dropped by this layer, in
35    /// ascending order. `dropped_rows.len() == fixed_vars.len()`.
36    pub dropped_rows: Vec<usize>,
37    /// `var_map[i] = Some(reduced_idx)` if `i` survives this layer,
38    /// `None` if `i` is in `fixed_vars`.
39    pub var_map: Vec<Option<usize>>,
40    /// Same for rows.
41    pub row_map: Vec<Option<usize>>,
42}
43
44impl ReductionFrame {
45    /// Build a frame from the (sorted) lists of fixed variables /
46    /// values / dropped rows and the **full-space** problem shape.
47    pub fn new(
48        n_vars: usize,
49        n_rows: usize,
50        fixed_vars: Vec<usize>,
51        fixed_values: Vec<Number>,
52        dropped_rows: Vec<usize>,
53    ) -> Self {
54        assert_eq!(
55            fixed_vars.len(),
56            fixed_values.len(),
57            "fixed_vars and fixed_values must be the same length"
58        );
59        assert_eq!(
60            fixed_vars.len(),
61            dropped_rows.len(),
62            "fixed_vars and dropped_rows must be the same length (square block)"
63        );
64
65        // Mark fixed positions on flat `bool` vectors (O(1) lookup);
66        // BTreeSet would cost O(log k) per probe. PR review #60.
67        let mut is_fixed_var = vec![false; n_vars];
68        for &i in &fixed_vars {
69            is_fixed_var[i] = true;
70        }
71        let mut is_dropped_row = vec![false; n_rows];
72        for &i in &dropped_rows {
73            is_dropped_row[i] = true;
74        }
75
76        let mut var_map = vec![None; n_vars];
77        let mut next_reduced = 0;
78        for (i, slot) in var_map.iter_mut().enumerate().take(n_vars) {
79            if is_fixed_var[i] {
80                continue;
81            }
82            *slot = Some(next_reduced);
83            next_reduced += 1;
84        }
85
86        let mut row_map = vec![None; n_rows];
87        let mut next_reduced_row = 0;
88        for (i, slot) in row_map.iter_mut().enumerate().take(n_rows) {
89            if is_dropped_row[i] {
90                continue;
91            }
92            *slot = Some(next_reduced_row);
93            next_reduced_row += 1;
94        }
95
96        Self {
97            fixed_vars,
98            fixed_values,
99            dropped_rows,
100            var_map,
101            row_map,
102        }
103    }
104
105    pub fn n_full_vars(&self) -> usize {
106        self.var_map.len()
107    }
108
109    pub fn n_full_rows(&self) -> usize {
110        self.row_map.len()
111    }
112
113    pub fn n_reduced_vars(&self) -> usize {
114        self.n_full_vars() - self.fixed_vars.len()
115    }
116
117    pub fn n_reduced_rows(&self) -> usize {
118        self.n_full_rows() - self.dropped_rows.len()
119    }
120
121    /// Project a full-space `x` vector into reduced space (drop the
122    /// fixed entries).
123    pub fn project_x(&self, x_full: &[Number]) -> Vec<Number> {
124        assert_eq!(x_full.len(), self.n_full_vars());
125        self.var_map
126            .iter()
127            .zip(x_full.iter())
128            .filter_map(|(slot, &v)| slot.map(|_| v))
129            .collect()
130    }
131
132    /// Lift a reduced `x` back to full space, splicing the fixed
133    /// values back into their original positions.
134    pub fn lift_x(&self, x_reduced: &[Number]) -> Vec<Number> {
135        assert_eq!(x_reduced.len(), self.n_reduced_vars());
136        let mut out = vec![0.0; self.n_full_vars()];
137        for (i, slot) in self.var_map.iter().enumerate() {
138            if let Some(r) = slot {
139                out[i] = x_reduced[*r];
140            }
141        }
142        for (k, &i) in self.fixed_vars.iter().enumerate() {
143            out[i] = self.fixed_values[k];
144        }
145        out
146    }
147
148    /// Project a full-space λ vector into reduced space.
149    pub fn project_lambda(&self, lambda_full: &[Number]) -> Vec<Number> {
150        assert_eq!(lambda_full.len(), self.n_full_rows());
151        self.row_map
152            .iter()
153            .zip(lambda_full.iter())
154            .filter_map(|(slot, &v)| slot.map(|_| v))
155            .collect()
156    }
157
158    /// Lift a reduced λ back to full space, with zeros at dropped
159    /// row indices. (Real values for dropped rows come from
160    /// [`Self::recover_dropped_multipliers`].)
161    pub fn lift_lambda(&self, lambda_reduced: &[Number]) -> Vec<Number> {
162        assert_eq!(lambda_reduced.len(), self.n_reduced_rows());
163        let mut out = vec![0.0; self.n_full_rows()];
164        for (i, slot) in self.row_map.iter().enumerate() {
165            if let Some(r) = slot {
166                out[i] = lambda_reduced[*r];
167            }
168        }
169        out
170    }
171
172    /// Recover the `k = fixed_vars.len()` dropped-row multipliers
173    /// via dense LU on the full-space KKT stationarity equations at
174    /// the fixed variables. Returns one entry per `self.dropped_rows`
175    /// (in the same order).
176    ///
177    /// Assumption: fixed variables are interior to their original
178    /// bounds at the optimum (so `z_l = z_u = 0` for them).
179    ///
180    /// # Inputs
181    ///
182    /// - `grad_f` — objective gradient at the full-space optimum
183    ///   (length `n_full_vars`).
184    /// - `jac_full_row_major` — dense full-space Jacobian
185    ///   `(n_full_rows × n_full_vars)` at the optimum.
186    /// - `lambda_full` — multipliers for kept rows; entries at
187    ///   dropped-row positions are ignored.
188    ///
189    /// # Example
190    ///
191    /// ```
192    /// use pounce_presolve::reduction_frame::ReductionFrame;
193    ///
194    /// // 1 var, 1 row, dropped:  c(x) = x - 3 = 0, obj f = 4 x.
195    /// // Stationarity:  4 - 1 * λ = 0  →  λ = 4.
196    /// let frame = ReductionFrame::new(1, 1, vec![0], vec![3.0], vec![0]);
197    /// let grad_f = [4.0];
198    /// let jac = [1.0];
199    /// let lambda_full = [0.0]; // dropped, ignored
200    /// let lam = frame
201    ///     .recover_dropped_multipliers(&grad_f, &jac, &lambda_full)
202    ///     .unwrap();
203    /// assert!((lam[0] - 4.0).abs() < 1e-12);
204    /// ```
205    pub fn recover_dropped_multipliers(
206        &self,
207        grad_f: &[Number],
208        jac_full_row_major: &[Number],
209        lambda_full: &[Number],
210    ) -> Result<Vec<Number>, BlockSolveError> {
211        let n_vars = self.n_full_vars();
212        let n_rows = self.n_full_rows();
213        let k = self.fixed_vars.len();
214        assert_eq!(grad_f.len(), n_vars, "grad_f length mismatch");
215        assert_eq!(
216            jac_full_row_major.len(),
217            n_rows * n_vars,
218            "jac_full_row_major length mismatch"
219        );
220        assert_eq!(lambda_full.len(), n_rows, "lambda_full length mismatch");
221
222        if k == 0 {
223            return Ok(Vec::new());
224        }
225
226        // Use `row_map` for O(1) "is row r dropped?" — set in
227        // `new()`, no BTreeSet needed (PR review #60).
228        // Build the k×k system M λ_dropped = rhs.
229        //   M[i_idx][j_idx] = J[dropped_rows[j_idx]][fixed_vars[i_idx]]
230        //   rhs[i_idx] = grad_f[fixed_vars[i_idx]]
231        //              - Σ_{r kept} J[r][fixed_vars[i_idx]] * lambda_full[r]
232        let mut matrix = vec![0.0; k * k];
233        for (i_idx, &i) in self.fixed_vars.iter().enumerate() {
234            for (j_idx, &dr) in self.dropped_rows.iter().enumerate() {
235                matrix[i_idx * k + j_idx] = jac_full_row_major[dr * n_vars + i];
236            }
237        }
238
239        let mut rhs = vec![0.0; k];
240        for (i_idx, &i) in self.fixed_vars.iter().enumerate() {
241            let mut sum = 0.0;
242            for r in 0..n_rows {
243                if self.row_map[r].is_none() {
244                    // Row was dropped.
245                    continue;
246                }
247                sum += jac_full_row_major[r * n_vars + i] * lambda_full[r];
248            }
249            rhs[i_idx] = grad_f[i] - sum;
250        }
251
252        let piv = lu_factor_partial_pivot(&mut matrix, k).map_err(|_| BlockSolveError::Singular)?;
253        lu_solve(&matrix, &piv, &mut rhs, k);
254        Ok(rhs)
255    }
256}
257
258/// LIFO stack of `ReductionFrame`s. Bottom-most frame represents the
259/// first elimination layer applied; top-most is the most recent.
260/// `finalize_solution` lifts from top to bottom.
261#[derive(Debug, Default, Clone)]
262pub struct ReductionStack {
263    frames: Vec<ReductionFrame>,
264}
265
266impl ReductionStack {
267    /// True when no reduction has been pushed (the no-op fast path).
268    pub fn is_empty(&self) -> bool {
269        self.frames.is_empty()
270    }
271
272    /// Number of layers currently on the stack.
273    pub fn len(&self) -> usize {
274        self.frames.len()
275    }
276
277    /// Push a frame onto the stack (most-recent end).
278    pub fn push(&mut self, frame: ReductionFrame) {
279        self.frames.push(frame);
280    }
281
282    /// Reference to the most-recently-pushed frame, if any.
283    pub fn top(&self) -> Option<&ReductionFrame> {
284        self.frames.last()
285    }
286
287    /// Iterate frames from top (most recent) to bottom (first). PR 8
288    /// uses this order when lifting a reduced solution back to the
289    /// original full space.
290    pub fn iter_top_down(&self) -> impl Iterator<Item = &ReductionFrame> {
291        self.frames.iter().rev()
292    }
293
294    /// Iterate frames in push order (bottom to top). Useful when
295    /// projecting full → reduced through the layers in the same
296    /// order they were applied.
297    pub fn iter_bottom_up(&self) -> impl Iterator<Item = &ReductionFrame> {
298        self.frames.iter()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_new_builds_maps_correctly() {
        // 4 vars, 3 rows. fixed_vars=[1], dropped_rows=[0].
        let frame = ReductionFrame::new(4, 3, vec![1], vec![42.0], vec![0]);
        // var_map: [Some(0), None, Some(1), Some(2)]
        assert_eq!(frame.var_map, vec![Some(0), None, Some(1), Some(2)]);
        // row_map: [None, Some(0), Some(1)]
        assert_eq!(frame.row_map, vec![None, Some(0), Some(1)]);
        assert_eq!(frame.n_reduced_vars(), 3);
        assert_eq!(frame.n_reduced_rows(), 2);
    }

    #[test]
    fn frame_project_x_drops_fixed() {
        let frame = ReductionFrame::new(3, 1, vec![1], vec![20.0], vec![0]);
        let x_full = [10.0, 20.0, 30.0];
        assert_eq!(frame.project_x(&x_full), vec![10.0, 30.0]);
    }

    #[test]
    fn frame_lift_x_splices_fixed_values() {
        let frame = ReductionFrame::new(3, 1, vec![1], vec![20.0], vec![0]);
        let x_reduced = [10.0, 30.0];
        assert_eq!(frame.lift_x(&x_reduced), vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn frame_project_lift_x_roundtrip() {
        let frame = ReductionFrame::new(4, 2, vec![0, 2], vec![1.0, 9.0], vec![0, 1]);
        let x_full = [1.0, 5.0, 9.0, 7.0];
        let reduced = frame.project_x(&x_full);
        let lifted = frame.lift_x(&reduced);
        assert_eq!(lifted, x_full);
    }

    #[test]
    fn frame_project_lambda_drops_dropped() {
        let frame = ReductionFrame::new(3, 3, vec![1], vec![20.0], vec![0]);
        let lambda_full = [1.0, 2.0, 3.0];
        assert_eq!(frame.project_lambda(&lambda_full), vec![2.0, 3.0]);
    }

    #[test]
    fn frame_lift_lambda_zeros_dropped() {
        let frame = ReductionFrame::new(3, 3, vec![1], vec![20.0], vec![0]);
        let lambda_reduced = [2.0, 3.0];
        assert_eq!(frame.lift_lambda(&lambda_reduced), vec![0.0, 2.0, 3.0]);
    }

    #[test]
    fn recover_multipliers_singleton_linear() {
        // 1 var, 1 row. c(x) = x - 3 = 0, f = 4 x.
        // Stationarity: 4 - 1 * λ = 0 → λ = 4.
        let frame = ReductionFrame::new(1, 1, vec![0], vec![3.0], vec![0]);
        let lam = frame
            .recover_dropped_multipliers(&[4.0], &[1.0], &[0.0])
            .unwrap();
        assert_eq!(lam.len(), 1);
        assert!((lam[0] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn recover_multipliers_2x2_linear() {
        // 2 vars, 2 rows, both dropped.
        // J = [[1, 0], [1, 1]]
        // grad_f = [2, 5]
        // Stationarity (per fixed var i):
        //   i=0: 2 - 1*λ0 - 1*λ1 = 0
        //   i=1: 5 - 0*λ0 - 1*λ1 = 0
        // → λ1 = 5, then λ0 = 2 - 5 = -3.
        //
        // Note our system is M λ = rhs with
        //   M[i][j] = J[dropped[j]][fixed[i]]
        //   M = [[1, 1], [0, 1]]
        //   rhs = grad_f - 0 (no kept rows) = [2, 5]
        // Solving M λ = rhs:
        //   row 0: λ0 + λ1 = 2
        //   row 1:         λ1 = 5
        //   → λ1 = 5, λ0 = -3. ✓
        let frame = ReductionFrame::new(2, 2, vec![0, 1], vec![1.0, 2.0], vec![0, 1]);
        let jac = [1.0, 0.0, 1.0, 1.0]; // row-major
        let grad_f = [2.0, 5.0];
        let lam = frame
            .recover_dropped_multipliers(&grad_f, &jac, &[0.0, 0.0])
            .unwrap();
        assert!((lam[0] - (-3.0)).abs() < 1e-12, "λ0 was {}", lam[0]);
        assert!((lam[1] - 5.0).abs() < 1e-12, "λ1 was {}", lam[1]);
    }

    #[test]
    fn recover_multipliers_with_kept_rows() {
        // 2 vars, 2 rows. Row 0 dropped, row 1 kept.
        // fixed_vars = [0]. Only x_0 is fixed.
        // J = [[2,  3],   ← dropped (touches fixed var x_0 with J=2)
        //      [4,  5]]   ← kept (touches x_0 with J=4, λ_kept = 0.5)
        // grad_f[0] = 10.
        // Stationarity at x_0: 10 - 2 * λ_dropped - 4 * 0.5 = 0
        //   → 10 - 2 λ_dropped - 2 = 0 → λ_dropped = 4.
        let frame = ReductionFrame::new(2, 2, vec![0], vec![1.0], vec![0]);
        let jac = [2.0, 3.0, 4.0, 5.0];
        let grad_f = [10.0, 0.0];
        let lambda_full = [0.0, 0.5]; // entry 0 ignored
        let lam = frame
            .recover_dropped_multipliers(&grad_f, &jac, &lambda_full)
            .unwrap();
        assert_eq!(lam.len(), 1);
        assert!((lam[0] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn recover_multipliers_singular_block_jacobian() {
        // 2x2 with rank-1 block Jacobian.
        let frame = ReductionFrame::new(2, 2, vec![0, 1], vec![0.0, 0.0], vec![0, 1]);
        let jac = [1.0, 2.0, 2.0, 4.0]; // rank-1
        let grad_f = [1.0, 2.0];
        let err = frame
            .recover_dropped_multipliers(&grad_f, &jac, &[0.0, 0.0])
            .unwrap_err();
        assert_eq!(err, BlockSolveError::Singular);
    }

    #[test]
    fn recover_multipliers_empty_frame() {
        let frame = ReductionFrame::new(2, 2, vec![], vec![], vec![]);
        let lam = frame
            .recover_dropped_multipliers(&[0.0; 2], &[0.0; 4], &[0.0; 2])
            .unwrap();
        assert!(lam.is_empty());
    }

    #[test]
    fn kkt_residual_after_recovery_to_1e_minus_12() {
        // 3 vars (b1, b2, y), 3 rows.
        //   row 0 (dropped):     2 b1 + b2     - 3       = 0  → at fixed point.
        //   row 1 (dropped):     b1   - b2     + 1       = 0  → at fixed point.
        //   row 2 (kept):        b1   + b2 + y - 5       = 0
        // Solving the two dropped rows: b1 = 2/3, b2 = 5/3.
        // Then row 2: y = 5 - 7/3 = 8/3 ≈ 2.667.
        // Objective f = 10 b1 + 4 b2 + y².
        // grad_f at (b1, b2, y) = (10, 4, 2y).
        //
        // The IPM-style reduced problem keeps row 2 active and var y
        // free. We need to recover λ_0, λ_1 (for the dropped rows)
        // and verify full-space stationarity at b1, b2 (and y holds
        // automatically from the reduced KKT).
        let frame = ReductionFrame::new(3, 3, vec![0, 1], vec![2.0 / 3.0, 5.0 / 3.0], vec![0, 1]);
        // Build the full row-major Jacobian at the optimum.
        let jac = [
            2.0, 1.0, 0.0, // row 0
            1.0, -1.0, 0.0, // row 1
            1.0, 1.0, 1.0, // row 2
        ];
        // Objective gradient at the optimum.
        let y_star = 8.0 / 3.0;
        let grad_f = [10.0, 4.0, 2.0 * y_star];
        // Reduced problem's kept-row multipliers: at the optimum,
        // stationarity at y is 2y - λ_2 = 0 → λ_2 = 2y = 16/3.
        let lambda_kept_2 = 2.0 * y_star;
        let lambda_full = [0.0, 0.0, lambda_kept_2];

        let lam_dropped = frame
            .recover_dropped_multipliers(&grad_f, &jac, &lambda_full)
            .unwrap();
        // Reconstruct the full λ.
        let mut lambda_recovered = lambda_full;
        for (k, &r) in frame.dropped_rows.iter().enumerate() {
            lambda_recovered[r] = lam_dropped[k];
        }
        // Verify stationarity at b1, b2 to high precision.
        for &i in &frame.fixed_vars {
            let mut s = grad_f[i];
            for r in 0..3 {
                s -= jac[r * 3 + i] * lambda_recovered[r];
            }
            assert!(s.abs() < 1e-12, "stationarity at var {i} = {s}");
        }
    }

    /// Tiny deterministic LCG (Knuth's MMIX multiplier/increment) so the
    /// fuzz test is reproducible without an external RNG crate.
    /// `next_u64` returns the high 32 bits of the 64-bit state.
    struct FuzzRng(u64);
    impl FuzzRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        fn next_u64(&mut self) -> u64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0 >> 32
        }
        /// Value in `[-1, 1)`: 30 random bits scaled by `2^-29`, shifted by -1.
        fn unit(&mut self) -> Number {
            let raw = (self.next_u64() & 0x3fff_ffff) as Number;
            raw / (1u64 << 29) as Number - 1.0
        }
    }

    /// Fuzz: build a synthetic full-space KKT solution `(x*, λ*)`,
    /// declare a random subset of variables "fixed" and the
    /// matching subset of rows "dropped", then verify the multiplier
    /// recovery reproduces the original λ at the dropped indices to
    /// within 1e-10.
    #[test]
    fn frame_fuzz_recover_reproduces_synthetic_lambda() {
        let mut rng = FuzzRng::new(0xface_b00c_baad_f00d);

        for trial in 0..30 {
            let n_vars = 2 + (rng.next_u64() % 3) as usize; // 2..=4
            let n_rows = n_vars;
            let k = 1 + (rng.next_u64() % n_vars as u64) as usize;

            let mut perm_v: Vec<usize> = (0..n_vars).collect();
            for i in (1..n_vars).rev() {
                let j = (rng.next_u64() as usize) % (i + 1);
                perm_v.swap(i, j);
            }
            let mut fixed_vars: Vec<usize> = perm_v[..k].to_vec();
            fixed_vars.sort_unstable();

            let mut perm_r: Vec<usize> = (0..n_rows).collect();
            for i in (1..n_rows).rev() {
                let j = (rng.next_u64() as usize) % (i + 1);
                perm_r.swap(i, j);
            }
            let mut dropped_rows: Vec<usize> = perm_r[..k].to_vec();
            dropped_rows.sort_unstable();

            let mut jac = vec![0.0; n_rows * n_vars];
            for r in 0..n_rows {
                for c in 0..n_vars {
                    jac[r * n_vars + c] = 0.2 * rng.unit();
                }
            }
            // Boost the paired (dropped_rows[t], fixed_vars[t]) entries so
            // the k×k block is diagonally dominant, hence nonsingular.
            for (&r, &c) in dropped_rows.iter().zip(fixed_vars.iter()) {
                jac[r * n_vars + c] += 2.5;
            }

            let lambda_star: Vec<Number> = (0..n_rows).map(|_| rng.unit()).collect();
            let mut grad_f = vec![0.0; n_vars];
            let fixed_set: std::collections::BTreeSet<usize> = fixed_vars.iter().copied().collect();
            for i in 0..n_vars {
                if fixed_set.contains(&i) {
                    let mut s = 0.0;
                    for r in 0..n_rows {
                        s += jac[r * n_vars + i] * lambda_star[r];
                    }
                    grad_f[i] = s;
                } else {
                    grad_f[i] = rng.unit();
                }
            }

            let dropped_set: std::collections::BTreeSet<usize> =
                dropped_rows.iter().copied().collect();
            let mut lambda_given = vec![0.0; n_rows];
            for r in 0..n_rows {
                if !dropped_set.contains(&r) {
                    lambda_given[r] = lambda_star[r];
                }
            }

            let frame = ReductionFrame::new(
                n_vars,
                n_rows,
                fixed_vars.clone(),
                vec![0.0; k],
                dropped_rows.clone(),
            );

            let lam_dropped = frame
                .recover_dropped_multipliers(&grad_f, &jac, &lambda_given)
                .unwrap_or_else(|e| panic!("trial {trial}: {e:?}"));

            for (idx, &r) in dropped_rows.iter().enumerate() {
                let expected = lambda_star[r];
                let got = lam_dropped[idx];
                assert!(
                    (expected - got).abs() < 1e-10,
                    "trial {trial}: λ[{r}] expected {expected:.6}, got {got:.6}"
                );
            }
        }
    }

    #[test]
    fn reduction_stack_push_top_iter() {
        let mut stack = ReductionStack::default();
        assert!(stack.is_empty());
        let f1 = ReductionFrame::new(2, 2, vec![0], vec![1.0], vec![0]);
        let f2 = ReductionFrame::new(2, 2, vec![1], vec![2.0], vec![1]);
        stack.push(f1.clone());
        stack.push(f2.clone());
        assert_eq!(stack.len(), 2);
        let top = stack.top().expect("non-empty");
        assert_eq!(top.fixed_vars, f2.fixed_vars);
        // top-down: f2, then f1.
        let order: Vec<_> = stack.iter_top_down().map(|f| f.fixed_vars[0]).collect();
        assert_eq!(order, vec![1, 0]);
        let order_up: Vec<_> = stack.iter_bottom_up().map(|f| f.fixed_vars[0]).collect();
        assert_eq!(order_up, vec![0, 1]);
    }

    /// PR #60 review nit: there was no paired test for `project_lambda`
    /// + `lift_lambda` (only the directional tests). Confirm
    /// project(lift(x)) is the identity on the reduced shape AND
    /// lift(project(x)) zeroes the dropped indices but preserves
    /// kept ones.
    #[test]
    fn frame_project_lift_lambda_roundtrip() {
        let frame = ReductionFrame::new(4, 3, vec![0, 2], vec![1.0, 9.0], vec![0, 1]);
        // Full lambda with arbitrary values; project then lift.
        let lambda_full = [4.0, 5.0, 6.0];
        let reduced = frame.project_lambda(&lambda_full);
        // Reduced has one row (the only kept row, row 2).
        assert_eq!(reduced, vec![6.0]);
        // Lifting back zeroes the dropped row entries.
        let lifted = frame.lift_lambda(&reduced);
        assert_eq!(lifted, vec![0.0, 0.0, 6.0]);
        // Now the other direction: project the lifted lambda back
        // to reduced — should be the identity on reduced shape.
        let reduced_again = frame.project_lambda(&lifted);
        assert_eq!(reduced_again, reduced);
    }

    /// Multi-frame `ReductionStack` round-trip with genuinely stacked
    /// shapes: frame 2 is built in frame 1's *reduced* space. Project
    /// through the stack bottom-up, lift back top-down, and check that
    /// x is recovered exactly (fixed entries come from the frames) and
    /// λ is recovered with zeros at every dropped row.
    #[test]
    fn reduction_stack_multi_frame_roundtrip() {
        // Full shape: 4 vars, 4 rows.
        // Frame 1 (bottom): fixes full var 0 (= 10), drops full row 0.
        //   Reduced-1 space: vars [1, 2, 3] → [0, 1, 2], rows likewise.
        // Frame 2 (top), in reduced-1 space (3 vars, 3 rows): fixes
        //   reduced-1 var 1 (= full var 2, value 30) and drops reduced-1
        //   row 1 (= full row 2).
        let f1 = ReductionFrame::new(4, 4, vec![0], vec![10.0], vec![0]);
        let f2 = ReductionFrame::new(3, 3, vec![1], vec![30.0], vec![1]);
        // The layers must chain: frame 2's full shape is frame 1's reduced shape.
        assert_eq!(f1.n_reduced_vars(), f2.n_full_vars());
        assert_eq!(f1.n_reduced_rows(), f2.n_full_rows());

        let mut stack = ReductionStack::default();
        stack.push(f1.clone());
        stack.push(f2.clone());

        // Survivors after both layers: vars 1, 3 and rows 1, 3.
        let x_full = vec![10.0, 7.0, 30.0, 5.0];
        let reduced_x = stack
            .iter_bottom_up()
            .fold(x_full.clone(), |x, frame| frame.project_x(&x));
        assert_eq!(reduced_x, vec![7.0, 5.0]);
        let lifted_x = stack
            .iter_top_down()
            .fold(reduced_x, |x, frame| frame.lift_x(&x));
        assert_eq!(lifted_x, x_full);

        // Nonzero values at the dropped rows (0, 2) must be discarded by
        // projection and come back as zeros after lifting.
        let lambda_full = vec![4.0, 8.0, 9.0, 6.0];
        let reduced_l = stack
            .iter_bottom_up()
            .fold(lambda_full, |l, frame| frame.project_lambda(&l));
        assert_eq!(reduced_l, vec![8.0, 6.0]);
        let lifted_l = stack
            .iter_top_down()
            .fold(reduced_l, |l, frame| frame.lift_lambda(&l));
        assert_eq!(lifted_l, vec![0.0, 8.0, 0.0, 6.0]);
    }
}