// pounce_presolve/fbbt/forward.rs

//! Forward interval pass over an [`FbbtTape`].
//!
//! Given current variable bounds `x_lo[i] ≤ x_i ≤ x_hi[i]` and a
//! constraint expression tape, computes an interval over-approximation
//! of the value at every slot of the tape. The output is a parallel
//! `Vec<Interval>` whose last entry is the over-approximation of the
//! whole constraint expression. The next phase (the reverse pass, commit
//! 3 of [#62]) consumes this buffer.
//!
//! The pass is a single linear scan over `ops`; each op consults the
//! intervals already computed for its operand slots, so the pass runs in
//! `O(n)` time in the tape length.
//!
//! [`FbbtOp::Opaque`] slots collapse to [`Interval::ENTIRE`]: we have
//! no structural information about that subexpression, so we cannot
//! tighten anything through it.
//!
//! [#62]: https://github.com/jkitchin/pounce/issues/62
//! [`FbbtTape`]: pounce_nlp::FbbtTape
//! [`FbbtOp::Opaque`]: pounce_nlp::FbbtOp::Opaque
use std::fmt;

use pounce_common::types::Number;
use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};

use crate::fbbt::interval::Interval;
/// Why a forward pass might fail.
///
/// Implements [`std::error::Error`] (with a human-readable
/// [`Display`](fmt::Display)) so callers can propagate it with `?` into
/// boxed or otherwise type-erased error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardError {
    /// Tape failed [`FbbtTape::first_invalid_slot`] validation. The
    /// payload is the slot index that referenced something it
    /// shouldn't have.
    MalformedTape(usize),
    /// Tape mentioned `Var(j)` with `j >= x_lo.len()`. The payload is
    /// the offending variable index.
    VariableIndexOutOfRange(usize),
    /// `x_lo` and `x_hi` had different lengths.
    BoundsLengthMismatch {
        /// Length of `x_lo`.
        lo: usize,
        /// Length of `x_hi`.
        hi: usize,
    },
}

impl fmt::Display for ForwardError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ForwardError::MalformedTape(slot) => {
                write!(f, "malformed FBBT tape: invalid reference at slot {slot}")
            }
            ForwardError::VariableIndexOutOfRange(var) => write!(
                f,
                "FBBT tape references variable {var}, which is outside the provided bounds"
            ),
            ForwardError::BoundsLengthMismatch { lo, hi } => write!(
                f,
                "bound slices differ in length: x_lo.len() = {lo}, x_hi.len() = {hi}"
            ),
        }
    }
}

// No underlying cause to expose: every variant is a leaf error.
impl std::error::Error for ForwardError {}
41/// Compute the per-slot interval bag for `tape`, given current
42/// variable bounds `x_lo` and `x_hi` (parallel arrays). Returns a
43/// `Vec<Interval>` of the same length as `tape.ops`.
44///
45/// On a well-formed tape this never panics — domain violations
46/// (e.g. `ln` of a fully-negative interval) produce
47/// [`Interval::EMPTY`] in the corresponding slot, which the
48/// orchestrator interprets as "infeasibility candidate; skip
49/// tightening from this constraint."
50pub fn forward_pass(
51    tape: &FbbtTape,
52    x_lo: &[Number],
53    x_hi: &[Number],
54) -> Result<Vec<Interval>, ForwardError> {
55    if x_lo.len() != x_hi.len() {
56        return Err(ForwardError::BoundsLengthMismatch {
57            lo: x_lo.len(),
58            hi: x_hi.len(),
59        });
60    }
61    if let Some(bad) = tape.first_invalid_slot() {
62        return Err(ForwardError::MalformedTape(bad));
63    }
64
65    let n_vars = x_lo.len();
66    let mut vals: Vec<Interval> = Vec::with_capacity(tape.ops.len());
67    for op in &tape.ops {
68        let v = match *op {
69            FbbtOp::Const(c) => Interval::point(c),
70            FbbtOp::Var(i) => {
71                if i >= n_vars {
72                    return Err(ForwardError::VariableIndexOutOfRange(i));
73                }
74                Interval::new(x_lo[i], x_hi[i])
75            }
76            FbbtOp::Opaque => Interval::ENTIRE,
77            FbbtOp::Add(a, b) => vals[a].add(vals[b]),
78            FbbtOp::Sub(a, b) => vals[a].sub(vals[b]),
79            FbbtOp::Mul(a, b) => vals[a].mul(vals[b]),
80            FbbtOp::Div(a, b) => vals[a].div(vals[b]),
81            FbbtOp::PowInt(a, n) => vals[a].pow_uint(n),
82            FbbtOp::Neg(a) => vals[a].neg(),
83            FbbtOp::Sqrt(a) => vals[a].sqrt(),
84            FbbtOp::Exp(a) => vals[a].exp(),
85            FbbtOp::Ln(a) => vals[a].ln(),
86            FbbtOp::Abs(a) => vals[a].abs(),
87            FbbtOp::Sin(a) => vals[a].sin(),
88            FbbtOp::Cos(a) => vals[a].cos(),
89        };
90        vals.push(v);
91    }
92    Ok(vals)
93}
94
95/// Final result of [`forward_pass`] — the interval enclosing the
96/// whole constraint expression. Empty tape returns
97/// [`Interval::ENTIRE`].
98pub fn forward_result(vals: &[Interval]) -> Interval {
99    vals.last().copied().unwrap_or(Interval::ENTIRE)
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tape from `ops`, run [`forward_pass`] over the box
    /// `[lo, hi]`, and return the whole-expression enclosure. Panics if
    /// the pass reports an error.
    fn enclose(ops: Vec<FbbtOp>, lo: &[Number], hi: &[Number]) -> Interval {
        let tape = FbbtTape { ops };
        forward_result(&forward_pass(&tape, lo, hi).unwrap())
    }

    /// `2 * x` with `x ∈ [-1, 3]` must enclose `[-2, 6]`.
    #[test]
    fn simple_linear_combination() {
        let ops = vec![FbbtOp::Const(2.0), FbbtOp::Var(0), FbbtOp::Mul(0, 1)];
        let r = enclose(ops, &[-1.0], &[3.0]);
        assert!(r.contains(-2.0));
        assert!(r.contains(6.0));
    }

    /// `x^2 + y^2` with `x ∈ [-2, 1]` and `y ∈ [0, 3]`: `x^2 ∈ [0, 4]`
    /// and `y^2 ∈ [0, 9]`, so the sum lies in `[0, 13]`.
    #[test]
    fn quadratic_sum() {
        let ops = vec![
            FbbtOp::Var(0),       // x
            FbbtOp::PowInt(0, 2), // x^2
            FbbtOp::Var(1),       // y
            FbbtOp::PowInt(2, 2), // y^2
            FbbtOp::Add(1, 3),    // x^2 + y^2
        ];
        let r = enclose(ops, &[-2.0, 0.0], &[1.0, 3.0]);
        assert!(r.contains(0.0), "should contain min");
        assert!(r.contains(13.0), "should contain max");
        // Outward rounding may widen the enclosure, never shrink it.
        assert!(r.lo <= 0.0);
        assert!(r.hi >= 13.0);
    }

    /// `exp(x)` with `x ∈ [0, 1]` must enclose `[1, e]`.
    #[test]
    fn exp_monotone() {
        let r = enclose(vec![FbbtOp::Var(0), FbbtOp::Exp(0)], &[0.0], &[1.0]);
        assert!(r.contains(1.0));
        assert!(r.contains(std::f64::consts::E));
    }

    /// `ln(x)` with `x ∈ [-1, 4]`: the domain straddles zero. `ln` clips
    /// the lower end to 0, so the result is `[-∞, ln(4)]`.
    #[test]
    fn ln_domain_clip() {
        let r = enclose(vec![FbbtOp::Var(0), FbbtOp::Ln(0)], &[-1.0], &[4.0]);
        assert_eq!(r.lo, Number::NEG_INFINITY);
        // ln(4) = 2 ln(2).
        assert!(r.hi >= std::f64::consts::LN_2 * 2.0);
    }

    /// `ln(x)` with `x ∈ [-3, -1]` lies entirely outside the domain. The
    /// result is EMPTY, signalling infeasibility from this branch.
    #[test]
    fn ln_fully_outside_domain_is_empty() {
        let r = enclose(vec![FbbtOp::Var(0), FbbtOp::Ln(0)], &[-3.0], &[-1.0]);
        assert!(r.is_empty());
    }

    /// Sharing through the tape is structural: both operands of
    /// `Mul(0, 0)` read the same `[x_lo, x_hi]` slot, but interval
    /// multiplication does not know they are the same quantity.
    #[test]
    fn cse_via_tape_slot_sharing() {
        // x * x with x ∈ [-2, 3] gives [-6, 9] through Mul, which is
        // looser than the [0, 9] that PowInt(2) would produce. This is
        // the expected over-approximation.
        let r = enclose(vec![FbbtOp::Var(0), FbbtOp::Mul(0, 0)], &[-2.0], &[3.0]);
        assert!(r.contains(-6.0));
        assert!(r.contains(9.0));
    }

    /// An Opaque slot evaluates to ENTIRE and absorbs anything added to it.
    #[test]
    fn opaque_yields_entire() {
        let ops = vec![FbbtOp::Var(0), FbbtOp::Opaque, FbbtOp::Add(0, 1)];
        let r = enclose(ops, &[1.0], &[2.0]);
        // [1, 2] + ENTIRE = ENTIRE.
        assert_eq!(r.lo, Number::NEG_INFINITY);
        assert_eq!(r.hi, Number::INFINITY);
    }

    #[test]
    fn empty_tape_yields_entire() {
        let slots = forward_pass(&FbbtTape::new(), &[], &[]).unwrap();
        assert!(slots.is_empty());
        assert!(forward_result(&slots).is_entire());
    }

    #[test]
    fn malformed_tape_rejected() {
        // Slot 0 refers to itself and to slot 1, which comes after it.
        let tape = FbbtTape {
            ops: vec![FbbtOp::Add(0, 1), FbbtOp::Const(0.0)],
        };
        let err = forward_pass(&tape, &[], &[]).unwrap_err();
        assert_eq!(err, ForwardError::MalformedTape(0));
    }

    #[test]
    fn out_of_range_var_rejected() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(2)],
        };
        let err = forward_pass(&tape, &[0.0], &[1.0]).unwrap_err();
        assert_eq!(err, ForwardError::VariableIndexOutOfRange(2));
    }

    #[test]
    fn mismatched_bounds_lengths_rejected() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Const(0.0)],
        };
        let err = forward_pass(&tape, &[0.0], &[1.0, 2.0]).unwrap_err();
        assert!(matches!(err, ForwardError::BoundsLengthMismatch { .. }));
    }

    /// Soundness: at every sample point inside the variable box, the
    /// true constraint value must lie in the forward-pass enclosure.
    #[test]
    fn fuzz_soundness_pointwise() {
        // f(x, y) = (x - 1) * (y + 2) + sqrt(x + 10)
        let ops = vec![
            FbbtOp::Var(0), // x
            FbbtOp::Const(1.0),
            FbbtOp::Sub(0, 1), // x - 1
            FbbtOp::Var(1),    // y
            FbbtOp::Const(2.0),
            FbbtOp::Add(3, 4), // y + 2
            FbbtOp::Mul(2, 5), // (x-1)*(y+2)
            FbbtOp::Const(10.0),
            FbbtOp::Add(0, 7), // x + 10
            FbbtOp::Sqrt(8),   // sqrt(x+10)
            FbbtOp::Add(6, 9), // f
        ];
        let (x_lo, x_hi) = ([-2.0, -1.0], [3.0, 5.0]);
        let r = enclose(ops, &x_lo, &x_hi);

        // Coordinate `k` (0..=4) along `axis` of a 5x5 grid, corners included.
        let grid = |k: usize, axis: usize| {
            x_lo[axis] + (x_hi[axis] - x_lo[axis]) * (k as f64) / 4.0
        };
        for (ix, iy) in (0..5).flat_map(|i| (0..5).map(move |j| (i, j))) {
            let (x, y) = (grid(ix, 0), grid(iy, 1));
            let f = (x - 1.0) * (y + 2.0) + (x + 10.0).sqrt();
            assert!(r.contains(f), "x={x}, y={y}, f={f} not in {:?}", r);
        }
    }
}