1use pounce_common::types::Number;
23use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
24
25use crate::fbbt::interval::Interval;
26
/// Reasons a forward interval pass over a tape can be rejected.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so the error can be
/// printed and propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardError {
    /// The tape failed validation; the payload is the slot index reported by
    /// `FbbtTape::first_invalid_slot`.
    MalformedTape(usize),
    /// A `Var` op names a variable index that is not covered by the supplied bounds;
    /// the payload is that index.
    VariableIndexOutOfRange(usize),
    /// The lower- and upper-bound slices have different lengths.
    BoundsLengthMismatch {
        /// Length of the lower-bound slice.
        lo: usize,
        /// Length of the upper-bound slice.
        hi: usize,
    },
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MalformedTape(slot) => {
                write!(f, "malformed tape: invalid slot {slot}")
            }
            Self::VariableIndexOutOfRange(index) => {
                write!(f, "variable index {index} is out of range")
            }
            Self::BoundsLengthMismatch { lo, hi } => {
                write!(
                    f,
                    "bound slices differ in length: {lo} lower bounds vs {hi} upper bounds"
                )
            }
        }
    }
}

impl std::error::Error for ForwardError {}
40
41pub fn forward_pass(
51 tape: &FbbtTape,
52 x_lo: &[Number],
53 x_hi: &[Number],
54) -> Result<Vec<Interval>, ForwardError> {
55 if x_lo.len() != x_hi.len() {
56 return Err(ForwardError::BoundsLengthMismatch {
57 lo: x_lo.len(),
58 hi: x_hi.len(),
59 });
60 }
61 if let Some(bad) = tape.first_invalid_slot() {
62 return Err(ForwardError::MalformedTape(bad));
63 }
64
65 let n_vars = x_lo.len();
66 let mut vals: Vec<Interval> = Vec::with_capacity(tape.ops.len());
67 for op in &tape.ops {
68 let v = match *op {
69 FbbtOp::Const(c) => Interval::point(c),
70 FbbtOp::Var(i) => {
71 if i >= n_vars {
72 return Err(ForwardError::VariableIndexOutOfRange(i));
73 }
74 Interval::new(x_lo[i], x_hi[i])
75 }
76 FbbtOp::Opaque => Interval::ENTIRE,
77 FbbtOp::Add(a, b) => vals[a].add(vals[b]),
78 FbbtOp::Sub(a, b) => vals[a].sub(vals[b]),
79 FbbtOp::Mul(a, b) => vals[a].mul(vals[b]),
80 FbbtOp::Div(a, b) => vals[a].div(vals[b]),
81 FbbtOp::PowInt(a, n) => vals[a].pow_uint(n),
82 FbbtOp::Neg(a) => vals[a].neg(),
83 FbbtOp::Sqrt(a) => vals[a].sqrt(),
84 FbbtOp::Exp(a) => vals[a].exp(),
85 FbbtOp::Ln(a) => vals[a].ln(),
86 FbbtOp::Abs(a) => vals[a].abs(),
87 FbbtOp::Sin(a) => vals[a].sin(),
88 FbbtOp::Cos(a) => vals[a].cos(),
89 };
90 vals.push(v);
91 }
92 Ok(vals)
93}
94
95pub fn forward_result(vals: &[Interval]) -> Interval {
99 vals.last().copied().unwrap_or(Interval::ENTIRE)
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the forward pass over `ops` and returns the interval of the final slot.
    /// Panics if the pass reports an error.
    fn eval(ops: Vec<FbbtOp>, lo: &[Number], hi: &[Number]) -> Interval {
        let tape = FbbtTape { ops };
        let slots = forward_pass(&tape, lo, hi).unwrap();
        forward_result(&slots)
    }

    /// Runs the forward pass over `ops` and returns the error it produces.
    /// Panics if the pass succeeds.
    fn eval_err(ops: Vec<FbbtOp>, lo: &[Number], hi: &[Number]) -> ForwardError {
        let tape = FbbtTape { ops };
        forward_pass(&tape, lo, hi).unwrap_err()
    }

    /// 2 * x with x in [-1, 3] must cover both -2 and 6.
    #[test]
    fn simple_linear_combination() {
        let ops = vec![FbbtOp::Const(2.0), FbbtOp::Var(0), FbbtOp::Mul(0, 1)];
        let range = eval(ops, &[-1.0], &[3.0]);
        assert!(range.contains(-2.0));
        assert!(range.contains(6.0));
    }

    /// x^2 + y^2 with x in [-2, 1] and y in [0, 3] must cover 0 and 13.
    #[test]
    fn quadratic_sum() {
        let ops = vec![
            FbbtOp::Var(0),
            FbbtOp::PowInt(0, 2),
            FbbtOp::Var(1),
            FbbtOp::PowInt(2, 2),
            FbbtOp::Add(1, 3),
        ];
        let range = eval(ops, &[-2.0, 0.0], &[1.0, 3.0]);
        assert!(range.contains(0.0), "should contain min");
        assert!(range.contains(13.0), "should contain max");
        assert!(range.lo <= 0.0);
        assert!(range.hi >= 13.0);
    }

    /// exp(x) with x in [0, 1] must cover 1 and e.
    #[test]
    fn exp_monotone() {
        let range = eval(vec![FbbtOp::Var(0), FbbtOp::Exp(0)], &[0.0], &[1.0]);
        assert!(range.contains(1.0));
        assert!(range.contains(std::f64::consts::E));
    }

    /// ln over [-1, 4]: the lower bound is expected to be -inf and the upper bound
    /// at least ln(4).
    #[test]
    fn ln_domain_clip() {
        let range = eval(vec![FbbtOp::Var(0), FbbtOp::Ln(0)], &[-1.0], &[4.0]);
        assert_eq!(range.lo, Number::NEG_INFINITY);
        assert!(range.hi >= std::f64::consts::LN_2 * 2.0);
    }

    /// ln over [-3, -1] is expected to produce an empty interval.
    #[test]
    fn ln_fully_outside_domain_is_empty() {
        let range = eval(vec![FbbtOp::Var(0), FbbtOp::Ln(0)], &[-3.0], &[-1.0]);
        assert!(range.is_empty());
    }

    /// Both operands of the product read the same slot; the result is still expected
    /// to contain -6 as well as 9.
    #[test]
    fn cse_via_tape_slot_sharing() {
        let range = eval(vec![FbbtOp::Var(0), FbbtOp::Mul(0, 0)], &[-2.0], &[3.0]);
        assert!(range.contains(-6.0));
        assert!(range.contains(9.0));
    }

    /// Adding an opaque slot to a bounded one must give an unbounded result.
    #[test]
    fn opaque_yields_entire() {
        let ops = vec![FbbtOp::Var(0), FbbtOp::Opaque, FbbtOp::Add(0, 1)];
        let range = eval(ops, &[1.0], &[2.0]);
        assert_eq!(range.lo, Number::NEG_INFINITY);
        assert_eq!(range.hi, Number::INFINITY);
    }

    /// An empty tape produces no slots, and its result is the entire interval.
    #[test]
    fn empty_tape_yields_entire() {
        let tape = FbbtTape::new();
        let slots = forward_pass(&tape, &[], &[]).unwrap();
        assert!(slots.is_empty());
        assert!(forward_result(&slots).is_entire());
    }

    /// Slot 0 refers to slots 0 and 1 before they exist; slot 0 must be reported.
    #[test]
    fn malformed_tape_rejected() {
        let ops = vec![FbbtOp::Add(0, 1), FbbtOp::Const(0.0)];
        assert_eq!(eval_err(ops, &[], &[]), ForwardError::MalformedTape(0));
    }

    /// Variable 2 with a single pair of bounds must be rejected with its index.
    #[test]
    fn out_of_range_var_rejected() {
        let failure = eval_err(vec![FbbtOp::Var(2)], &[0.0], &[1.0]);
        assert_eq!(failure, ForwardError::VariableIndexOutOfRange(2));
    }

    /// One lower bound against two upper bounds must be rejected.
    #[test]
    fn mismatched_bounds_lengths_rejected() {
        let failure = eval_err(vec![FbbtOp::Const(0.0)], &[0.0], &[1.0, 2.0]);
        assert!(matches!(failure, ForwardError::BoundsLengthMismatch { .. }));
    }

    /// Samples f(x, y) = (x - 1)(y + 2) + sqrt(x + 10) on a 5x5 grid over the box and
    /// checks every sampled value lies inside the computed enclosure.
    #[test]
    fn fuzz_soundness_pointwise() {
        let ops = vec![
            FbbtOp::Var(0),
            FbbtOp::Const(1.0),
            FbbtOp::Sub(0, 1),
            FbbtOp::Var(1),
            FbbtOp::Const(2.0),
            FbbtOp::Add(3, 4),
            FbbtOp::Mul(2, 5),
            FbbtOp::Const(10.0),
            FbbtOp::Add(0, 7),
            FbbtOp::Sqrt(8),
            FbbtOp::Add(6, 9),
        ];
        let x_lo = [-2.0, -1.0];
        let x_hi = [3.0, 5.0];
        let res = eval(ops, &x_lo, &x_hi);

        // Grid point `step` (0..=4) of variable `var`, evenly spaced across its bounds.
        let sample = |var: usize, step: i32| {
            x_lo[var] + (x_hi[var] - x_lo[var]) * (step as f64) / 4.0
        };

        for ix in 0..5 {
            let x = sample(0, ix);
            for iy in 0..5 {
                let y = sample(1, iy);
                let f = (x - 1.0) * (y + 2.0) + (x + 10.0).sqrt();
                assert!(res.contains(f), "x={x}, y={y}, f={f} not in {:?}", res);
            }
        }
    }
}