use pounce_common::types::Number;
use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
use crate::fbbt::interval::Interval;
/// Errors reported by [`forward_pass`] before or during evaluation of a tape.
#[derive(Debug, Clone, PartialEq)]
pub enum ForwardError {
    /// `FbbtTape::first_invalid_slot` reported this slot index as invalid.
    MalformedTape(usize),
    /// A `Var(i)` op referenced a variable index not covered by the bound slices.
    VariableIndexOutOfRange(usize),
    /// The lower- and upper-bound slices have different lengths.
    BoundsLengthMismatch {
        /// Length of the lower-bound slice.
        lo: usize,
        /// Length of the upper-bound slice.
        hi: usize,
    },
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ForwardError::MalformedTape(slot) => {
                write!(f, "malformed FBBT tape: slot {slot} is invalid")
            }
            ForwardError::VariableIndexOutOfRange(i) => {
                write!(f, "variable index {i} is out of range for the supplied bounds")
            }
            ForwardError::BoundsLengthMismatch { lo, hi } => write!(
                f,
                "bound slices differ in length: {lo} lower bounds vs {hi} upper bounds"
            ),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` and other error-reporting types.
impl std::error::Error for ForwardError {}
/// Evaluates every op of `tape` in interval arithmetic (the forward sweep of FBBT).
///
/// Entry `k` of the returned vector is the enclosure computed for `tape.ops[k]`. Operand
/// indices inside an op refer to earlier entries of that same vector. `Var(i)` is seeded with
/// `Interval::new(x_lo[i], x_hi[i])`, `Const(c)` becomes a point interval, and `Opaque`
/// becomes [`Interval::ENTIRE`].
///
/// # Errors
///
/// * [`ForwardError::BoundsLengthMismatch`] if `x_lo` and `x_hi` differ in length.
/// * [`ForwardError::MalformedTape`] carrying the slot returned by
///   `FbbtTape::first_invalid_slot`. This check runs before any op is evaluated.
/// * [`ForwardError::VariableIndexOutOfRange`] for the first `Var(i)` encountered with
///   `i >= x_lo.len()`.
pub fn forward_pass(
    tape: &FbbtTape,
    x_lo: &[Number],
    x_hi: &[Number],
) -> Result<Vec<Interval>, ForwardError> {
    let (lo_len, hi_len) = (x_lo.len(), x_hi.len());
    if lo_len != hi_len {
        return Err(ForwardError::BoundsLengthMismatch {
            lo: lo_len,
            hi: hi_len,
        });
    }
    // NOTE(review): the operand indexing below presumably relies on this validation
    // guaranteeing that every operand refers to an earlier slot — confirm in FbbtTape.
    if let Some(slot) = tape.first_invalid_slot() {
        return Err(ForwardError::MalformedTape(slot));
    }

    let mut enclosures: Vec<Interval> = Vec::with_capacity(tape.ops.len());
    for op in tape.ops.iter() {
        let next = {
            // Read-only view of the slots evaluated so far; the borrow ends before the push.
            let prev = enclosures.as_slice();
            match *op {
                FbbtOp::Const(c) => Interval::point(c),
                FbbtOp::Var(i) => {
                    // Both slices have the same length, so this is `None` exactly when
                    // `i` is out of range.
                    let (&lo, &hi) = x_lo
                        .get(i)
                        .zip(x_hi.get(i))
                        .ok_or(ForwardError::VariableIndexOutOfRange(i))?;
                    Interval::new(lo, hi)
                }
                FbbtOp::Opaque => Interval::ENTIRE,
                FbbtOp::Add(a, b) => prev[a].add(prev[b]),
                FbbtOp::Sub(a, b) => prev[a].sub(prev[b]),
                FbbtOp::Mul(a, b) => prev[a].mul(prev[b]),
                FbbtOp::Div(a, b) => prev[a].div(prev[b]),
                FbbtOp::PowInt(a, n) => prev[a].pow_uint(n),
                FbbtOp::Neg(a) => prev[a].neg(),
                FbbtOp::Sqrt(a) => prev[a].sqrt(),
                FbbtOp::Exp(a) => prev[a].exp(),
                FbbtOp::Ln(a) => prev[a].ln(),
                FbbtOp::Abs(a) => prev[a].abs(),
                FbbtOp::Sin(a) => prev[a].sin(),
                FbbtOp::Cos(a) => prev[a].cos(),
            }
        };
        enclosures.push(next);
    }
    Ok(enclosures)
}
/// Returns the enclosure of the last evaluated slot in `vals`.
///
/// The module's tests treat this slot as the value of the whole expression. For an empty
/// slice nothing is known, so [`Interval::ENTIRE`] is returned.
pub fn forward_result(vals: &[Interval]) -> Interval {
    match vals.last() {
        Some(&root) => root,
        None => Interval::ENTIRE,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_linear_combination() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Const(2.0), FbbtOp::Var(0), FbbtOp::Mul(0, 1)],
        };
        let vals = forward_pass(&tape, &[-1.0], &[3.0]).unwrap();
        let res = forward_result(&vals);
        assert!(res.contains(-2.0));
        assert!(res.contains(6.0));
    }

    #[test]
    fn quadratic_sum() {
        let tape = FbbtTape {
            ops: vec![
                FbbtOp::Var(0),
                FbbtOp::PowInt(0, 2),
                FbbtOp::Var(1),
                FbbtOp::PowInt(2, 2),
                FbbtOp::Add(1, 3),
            ],
        };
        let vals = forward_pass(&tape, &[-2.0, 0.0], &[1.0, 3.0]).unwrap();
        let res = forward_result(&vals);
        assert!(res.contains(0.0), "should contain min");
        assert!(res.contains(13.0), "should contain max");
        assert!(res.lo <= 0.0);
        assert!(res.hi >= 13.0);
    }

    #[test]
    fn exp_monotone() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(0), FbbtOp::Exp(0)],
        };
        let vals = forward_pass(&tape, &[0.0], &[1.0]).unwrap();
        let res = forward_result(&vals);
        assert!(res.contains(1.0));
        assert!(res.contains(std::f64::consts::E));
    }

    #[test]
    fn ln_domain_clip() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(0), FbbtOp::Ln(0)],
        };
        let vals = forward_pass(&tape, &[-1.0], &[4.0]).unwrap();
        let res = forward_result(&vals);
        assert_eq!(res.lo, Number::NEG_INFINITY);
        assert!(res.hi >= std::f64::consts::LN_2 * 2.0);
    }

    #[test]
    fn ln_fully_outside_domain_is_empty() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(0), FbbtOp::Ln(0)],
        };
        let vals = forward_pass(&tape, &[-3.0], &[-1.0]).unwrap();
        let res = forward_result(&vals);
        assert!(res.is_empty());
    }

    #[test]
    fn cse_via_tape_slot_sharing() {
        // x * x with both operands sharing slot 0. Over x in [-2, 3] the true range is
        // [0, 9], and a sound enclosure must contain all of it. Values outside that range
        // (e.g. -6, which naive interval multiplication produces) are deliberately NOT
        // asserted. Asserting them would pin the dependency-problem overestimate and fail if
        // the pass ever tightens shared-slot multiplication to a square.
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(0), FbbtOp::Mul(0, 0)],
        };
        let vals = forward_pass(&tape, &[-2.0], &[3.0]).unwrap();
        let res = forward_result(&vals);
        for x in [-2.0, -1.0, 0.0, 1.5, 3.0] {
            let f: f64 = x * x;
            assert!(res.contains(f), "x={x}, f={f} not in {:?}", res);
        }
        assert!(res.lo <= 0.0);
        assert!(res.hi >= 9.0);
    }

    #[test]
    fn opaque_yields_entire() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(0), FbbtOp::Opaque, FbbtOp::Add(0, 1)],
        };
        let vals = forward_pass(&tape, &[1.0], &[2.0]).unwrap();
        let res = forward_result(&vals);
        assert_eq!(res.lo, Number::NEG_INFINITY);
        assert_eq!(res.hi, Number::INFINITY);
    }

    #[test]
    fn empty_tape_yields_entire() {
        let tape = FbbtTape::new();
        let vals = forward_pass(&tape, &[], &[]).unwrap();
        assert!(vals.is_empty());
        let res = forward_result(&vals);
        assert!(res.is_entire());
    }

    #[test]
    fn malformed_tape_rejected() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Add(0, 1), FbbtOp::Const(0.0)],
        };
        let err = forward_pass(&tape, &[], &[]).unwrap_err();
        assert_eq!(err, ForwardError::MalformedTape(0));
    }

    #[test]
    fn out_of_range_var_rejected() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(2)],
        };
        let err = forward_pass(&tape, &[0.0], &[1.0]).unwrap_err();
        assert_eq!(err, ForwardError::VariableIndexOutOfRange(2));
    }

    #[test]
    fn mismatched_bounds_lengths_rejected() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Const(0.0)],
        };
        let err = forward_pass(&tape, &[0.0], &[1.0, 2.0]).unwrap_err();
        // Pin the reported lengths, not just the variant.
        assert_eq!(err, ForwardError::BoundsLengthMismatch { lo: 1, hi: 2 });
    }

    #[test]
    fn fuzz_soundness_pointwise() {
        // f(x, y) = (x - 1) * (y + 2) + sqrt(x + 10)
        let tape = FbbtTape {
            ops: vec![
                FbbtOp::Var(0),
                FbbtOp::Const(1.0),
                FbbtOp::Sub(0, 1),
                FbbtOp::Var(1),
                FbbtOp::Const(2.0),
                FbbtOp::Add(3, 4),
                FbbtOp::Mul(2, 5),
                FbbtOp::Const(10.0),
                FbbtOp::Add(0, 7),
                FbbtOp::Sqrt(8),
                FbbtOp::Add(6, 9),
            ],
        };
        let x_lo = [-2.0, -1.0];
        let x_hi = [3.0, 5.0];
        let res = forward_result(&forward_pass(&tape, &x_lo, &x_hi).unwrap());
        // Deterministic 5x5 grid over the box; every sampled value must lie in the enclosure.
        for ix in 0..5 {
            for iy in 0..5 {
                let x = x_lo[0] + (x_hi[0] - x_lo[0]) * (ix as f64) / 4.0;
                let y = x_lo[1] + (x_hi[1] - x_lo[1]) * (iy as f64) / 4.0;
                let f = (x - 1.0) * (y + 2.0) + (x + 10.0).sqrt();
                assert!(res.contains(f), "x={x}, y={y}, f={f} not in {:?}", res);
            }
        }
    }
}