use pounce_common::types::Number;
use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
use crate::fbbt::interval::Interval;
/// Outcome of [`reverse_pass`]: the per-slot intervals plus an infeasibility flag.
#[derive(Debug, Clone, PartialEq)]
pub struct ReverseResult {
/// One interval per tape slot, indexed like `tape.ops`. `reverse_pass` starts from a copy
/// of the forward intervals and tightens it in place; when `infeasible` is `true` this
/// holds whatever state had been reached when the pass stopped (empty for an empty tape).
pub slots: Vec<Interval>,
/// `true` when `reverse_pass` found an empty interval (at the root after intersecting with
/// the constraint bound, or at any slot it visited); for an empty tape it mirrors
/// `con_bound.is_empty()`.
pub infeasible: bool,
}
/// Runs the reverse (top-down) propagation pass over `tape`.
///
/// The last tape slot is treated as the root: its forward interval is intersected with
/// `con_bound`, and the tape is then walked from the last operation to the first, letting
/// each operation tighten the slots of its operands via `apply_inverse`.
///
/// Returns the tightened slots together with an `infeasible` flag. The flag is set as soon
/// as the root or any visited slot is empty; the walk stops at that point and the slots are
/// returned as they are. For an empty tape the result has no slots and is infeasible exactly
/// when `con_bound` is empty.
///
/// # Panics
///
/// Panics if `forward.len() != tape.ops.len()`.
pub fn reverse_pass(tape: &FbbtTape, forward: &[Interval], con_bound: Interval) -> ReverseResult {
    assert_eq!(
        forward.len(),
        tape.ops.len(),
        "forward bag length must match tape"
    );

    if tape.ops.is_empty() {
        return ReverseResult {
            slots: Vec::new(),
            infeasible: con_bound.is_empty(),
        };
    }

    let mut slots = forward.to_vec();

    // Seed the root (last slot) with the constraint bound.
    let root_idx = slots.len() - 1;
    let seeded_root = slots[root_idx].intersect(con_bound);
    let mut infeasible = seeded_root.is_empty();

    if !infeasible {
        slots[root_idx] = seeded_root;

        // Visit operations from the root towards the leaves so every operation sees the
        // already-tightened interval of its own slot.
        for (idx, op) in tape.ops.iter().enumerate().rev() {
            let current = slots[idx];
            if current.is_empty() {
                infeasible = true;
                break;
            }
            apply_inverse(op, current, &mut slots);
        }
    }

    ReverseResult { slots, infeasible }
}
/// Pushes the interval of one operation's result back onto its operand slots.
///
/// `parent` is the current interval of the operation's own slot. Every operand slot the
/// operation refers to is replaced by its intersection with the set the inverse of the
/// operation allows; slots are only ever intersected, never widened. Leaves (`Const`,
/// `Var`), `Opaque`, `Sin` and `Cos` have no inverse rule here and leave `slots` untouched.
fn apply_inverse(op: &FbbtOp, parent: Interval, slots: &mut [Interval]) {
    // The half-line [0, +inf), used to discard the negative part of a result that can
    // only be non-negative (sqrt, exp, abs).
    let non_negative = || Interval::new(0.0, Number::INFINITY);

    match *op {
        FbbtOp::Const(_)
        | FbbtOp::Var(_)
        | FbbtOp::Opaque
        | FbbtOp::Sin(_)
        | FbbtOp::Cos(_) => {}

        // z = l + r  =>  l in z - r,  r in z - l
        FbbtOp::Add(lhs, rhs) => {
            let (old_lhs, old_rhs) = (slots[lhs], slots[rhs]);
            let new_lhs = old_lhs.intersect(parent.sub(old_rhs));
            slots[lhs] = new_lhs;
            slots[rhs] = old_rhs.intersect(parent.sub(new_lhs));
        }

        // z = l - r  =>  l in z + r,  r in l - z
        FbbtOp::Sub(lhs, rhs) => {
            let (old_lhs, old_rhs) = (slots[lhs], slots[rhs]);
            let new_lhs = old_lhs.intersect(parent.add(old_rhs));
            slots[lhs] = new_lhs;
            slots[rhs] = old_rhs.intersect(new_lhs.sub(parent));
        }

        // z = l * r  =>  l in z / r and r in z / l, each only when the divisor excludes 0
        FbbtOp::Mul(lhs, rhs) => {
            let (old_lhs, old_rhs) = (slots[lhs], slots[rhs]);
            let new_lhs = if old_rhs.contains_zero() {
                old_lhs
            } else {
                old_lhs.intersect(parent.div(old_rhs))
            };
            slots[lhs] = new_lhs;
            if !new_lhs.contains_zero() {
                slots[rhs] = old_rhs.intersect(parent.div(new_lhs));
            }
        }

        // z = n / d  =>  n in z * d,  d in n / z (only when z excludes 0)
        FbbtOp::Div(num, den) => {
            let (old_num, old_den) = (slots[num], slots[den]);
            let new_num = old_num.intersect(parent.mul(old_den));
            slots[num] = new_num;
            if !parent.contains_zero() {
                slots[den] = old_den.intersect(new_num.div(parent));
            }
        }

        // z = -a  =>  a in -z
        FbbtOp::Neg(arg) => {
            let prior = slots[arg];
            slots[arg] = prior.intersect(parent.neg());
        }

        // z = sqrt(a)  =>  a in (z ∩ [0, inf))^2
        FbbtOp::Sqrt(arg) => {
            let prior = slots[arg];
            let result = parent.intersect(non_negative());
            slots[arg] = if result.is_empty() {
                Interval::EMPTY
            } else {
                prior.intersect(result.pow_uint(2))
            };
        }

        // z = exp(a)  =>  a in ln(z ∩ [0, inf)); a result with no positive part is infeasible
        FbbtOp::Exp(arg) => {
            let prior = slots[arg];
            let result = parent.intersect(non_negative());
            slots[arg] = if result.is_empty() || result.hi <= 0.0 {
                Interval::EMPTY
            } else {
                prior.intersect(result.ln())
            };
        }

        // z = ln(a)  =>  a in exp(z)
        FbbtOp::Ln(arg) => {
            let prior = slots[arg];
            slots[arg] = prior.intersect(parent.exp());
        }

        // z = |a|  =>  a in [-z.hi, z.hi] (symmetric envelope; z.lo is not used)
        FbbtOp::Abs(arg) => {
            let prior = slots[arg];
            let result = parent.intersect(non_negative());
            slots[arg] = if result.is_empty() {
                Interval::EMPTY
            } else {
                prior.intersect(Interval::new(-result.hi, result.hi))
            };
        }

        // z = a^n  =>  a in the n-th-root preimage, disambiguated by a's prior interval
        FbbtOp::PowInt(arg, exponent) => {
            let prior = slots[arg];
            slots[arg] = prior.intersect(inverse_powint(parent, exponent, prior));
        }
    }
}
/// Returns an interval enclosing every `a` with `a^n` in `z`, using `prior_a` (the current
/// interval of `a`) to pick the sign branch for even exponents.
///
/// * Empty `z` yields `Interval::EMPTY`.
/// * `n == 0`: the result carries no information about `a`, so `Interval::ENTIRE` is returned.
/// * `n == 1`: the identity, `z` is returned unchanged.
/// * Odd `n`: the power is monotone, so the endpoints are mapped through the signed root.
/// * Even `n`: only the non-negative part of `z` is reachable. The preimage is the positive
///   branch `[lo^(1/n), hi^(1/n)]`, its mirror image, or both; the branch(es) that intersect
///   `prior_a` are kept. If both are hit the hull `[-hi^(1/n), hi^(1/n)]` is returned, if
///   neither is hit the result is empty.
///
/// All roots go through `signed_nth_root` so the odd and even branches share a single root
/// implementation (the even branch only ever passes non-negative endpoints).
fn inverse_powint(z: Interval, n: u32, prior_a: Interval) -> Interval {
    if z.is_empty() {
        return Interval::EMPTY;
    }
    match n {
        0 => return Interval::ENTIRE,
        1 => return z,
        _ => {}
    }

    if n % 2 == 1 {
        let lo = signed_nth_root(z.lo, n);
        let hi = signed_nth_root(z.hi, n);
        return Interval::new(lo, hi);
    }

    // Even power: a^n >= 0, so any negative part of z is unreachable.
    let z_pos = z.intersect(Interval::new(0.0, Number::INFINITY));
    if z_pos.is_empty() {
        return Interval::EMPTY;
    }

    let abs_lo = signed_nth_root(z_pos.lo, n);
    let abs_hi = signed_nth_root(z_pos.hi, n);
    let pos_branch = Interval::new(abs_lo, abs_hi);
    let neg_branch = Interval::new(-abs_hi, -abs_lo);

    let pos_hit = !prior_a.intersect(pos_branch).is_empty();
    let neg_hit = !prior_a.intersect(neg_branch).is_empty();
    match (pos_hit, neg_hit) {
        (true, false) => pos_branch,
        (false, true) => neg_branch,
        // Both signs remain possible: return the hull of the two branches.
        (true, true) => Interval::new(-abs_hi, abs_hi),
        (false, false) => Interval::EMPTY,
    }
}
/// Returns the real `n`-th root of `x` carrying the sign of `x`, i.e. `|x|^(1/n)` for
/// `x >= 0` and `-|x|^(1/n)` for `x < 0`.
///
/// Non-finite inputs (infinities and NaN) are returned unchanged.
///
/// The exponent `1.0 / n` is generally not representable exactly, so `powf` alone can return
/// a perfect power's root one ulp short (the cube root of 64 evaluates to
/// 3.9999999999999996). Used as an interval endpoint, such a value would exclude the exact
/// root. To avoid that, the integer nearest to the `powf` result is raised back to the
/// `n`-th power and returned when it reproduces `|x|` exactly.
///
/// NOTE(review): roots that are not integers are still rounded to nearest, not outward;
/// callers needing rigorous enclosures should confirm that is acceptable.
fn signed_nth_root(x: Number, n: u32) -> Number {
    if !x.is_finite() {
        return x;
    }
    let magnitude = x.abs();
    let approx = magnitude.powf(1.0 / n as f64);

    // Snap to the exact integer root when there is one. `powi` takes an `i32`, so the check
    // is skipped for exponents that do not fit.
    let candidate = approx.round();
    let is_exact_root = n <= i32::MAX as u32 && candidate.powi(n as i32) == magnitude;
    let root = if is_exact_root { candidate } else { approx };

    if x < 0.0 {
        -root
    } else {
        root
    }
}
// Unit tests for the reverse pass. Each test builds a small tape by hand, supplies the
// forward intervals for every slot, and checks the slots returned by `reverse_pass`.
#[cfg(test)]
mod tests {
use super::*;
// Thin wrapper so every test calls `reverse_pass` through one place.
fn run(tape: &FbbtTape, forward: &[Interval], bound: Interval) -> ReverseResult {
reverse_pass(tape, forward, bound)
}
// x + 1 in [2, 4] with x in [-10, 10]: x must end up within [1, 3].
#[test]
fn add_constant_tightens() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Const(1.0), FbbtOp::Add(0, 1)],
};
let forward = vec![
Interval::new(-10.0, 10.0),
Interval::point(1.0),
Interval::new(-9.0, 11.0),
];
let bound = Interval::new(2.0, 4.0);
let r = run(&tape, &forward, bound);
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= 1.0 - 1e-12, "v0.lo = {}", v0.lo);
assert!(v0.hi <= 3.0 + 1e-12, "v0.hi = {}", v0.hi);
}
// 2 * x in [4, 10] with x in [-100, 100]: x must end up within [2, 5].
#[test]
fn mul_constant_tightens() {
let tape = FbbtTape {
ops: vec![FbbtOp::Const(2.0), FbbtOp::Var(0), FbbtOp::Mul(0, 1)],
};
let forward = vec![
Interval::point(2.0),
Interval::new(-100.0, 100.0),
Interval::new(-200.0, 200.0),
];
let bound = Interval::new(4.0, 10.0);
let r = run(&tape, &forward, bound);
assert!(!r.infeasible);
let v1 = r.slots[1];
assert!(v1.lo >= 2.0 - 1e-12);
assert!(v1.hi <= 5.0 + 1e-12);
}
// x^2 in [4, 9] with x in [-10, 0]: only the negative root branch [-3, -2] is compatible
// with the prior, so x must end up within it.
#[test]
fn even_pow_picks_negative_branch() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::PowInt(0, 2)],
};
let forward = vec![Interval::new(-10.0, 0.0), Interval::new(0.0, 100.0)];
let r = run(&tape, &forward, Interval::new(4.0, 9.0));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= -3.0 - 1e-9, "got {}", v0.lo);
assert!(v0.hi <= -2.0 + 1e-9, "got {}", v0.hi);
}
// x^3 in [-8, 27]: the odd power is inverted endpoint-wise, giving x within [-2, 3].
#[test]
fn odd_pow_inverts_monotonically() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::PowInt(0, 3)],
};
let forward = vec![Interval::new(-100.0, 100.0), Interval::new(-1e6, 1e6)];
let r = run(&tape, &forward, Interval::new(-8.0, 27.0));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= -2.0 - 1e-9, "got {}", v0.lo);
assert!(v0.hi <= 3.0 + 1e-9, "got {}", v0.hi);
}
// sqrt(x) in [1, 2]: x must end up within [1, 4].
#[test]
fn sqrt_inverse() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Sqrt(0)],
};
let forward = vec![Interval::new(-10.0, 100.0), Interval::new(0.0, 10.0)];
let r = run(&tape, &forward, Interval::new(1.0, 2.0));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= 1.0 - 1e-12);
assert!(v0.hi <= 4.0 + 1e-12);
}
// exp(x) in [1, e]: x must end up within [0, 1].
#[test]
fn exp_inverse() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Exp(0)],
};
let forward = vec![Interval::new(-10.0, 10.0), Interval::new(0.0, 1.0e5)];
let r = run(&tape, &forward, Interval::new(1.0, std::f64::consts::E));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= 0.0 - 1e-12);
assert!(v0.hi <= 1.0 + 1e-12);
}
// ln(x) in [0, 1]: x must end up within [1, e].
#[test]
fn ln_inverse() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Ln(0)],
};
let forward = vec![Interval::new(0.5, 100.0), Interval::new(-1.0, 5.0)];
let r = run(&tape, &forward, Interval::new(0.0, 1.0));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= 1.0 - 1e-12);
assert!(v0.hi <= std::f64::consts::E + 1e-12);
}
// |x| in [0, 2]: x must end up within the symmetric envelope [-2, 2].
#[test]
fn abs_inverse_envelope() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Abs(0)],
};
let forward = vec![Interval::new(-10.0, 10.0), Interval::new(0.0, 10.0)];
let r = run(&tape, &forward, Interval::new(0.0, 2.0));
assert!(!r.infeasible);
let v0 = r.slots[0];
assert!(v0.lo >= -2.0 - 1e-12);
assert!(v0.hi <= 2.0 + 1e-12);
}
// x + y = 1 with x, y in [0, 1]: the pass may not widen either operand beyond its prior.
#[test]
fn add_already_tight_does_not_widen() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Var(1), FbbtOp::Add(0, 1)],
};
let forward = vec![
Interval::new(0.0, 1.0),
Interval::new(0.0, 1.0),
Interval::new(0.0, 2.0),
];
let r = run(&tape, &forward, Interval::point(1.0));
assert!(!r.infeasible);
assert!(r.slots[0].lo >= 0.0 && r.slots[0].hi <= 1.0);
assert!(r.slots[1].lo >= 0.0 && r.slots[1].hi <= 1.0);
}
// A root in [10, 20] cannot meet the bound [1, 5]: the pass must report infeasibility.
#[test]
fn root_disjoint_from_bound_is_infeasible() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0)],
};
let forward = vec![Interval::new(10.0, 20.0)];
let r = run(&tape, &forward, Interval::new(1.0, 5.0));
assert!(r.infeasible);
}
// x + opaque = 5 where the opaque slot is unbounded (ENTIRE): nothing can be learned
// about x, so its interval must stay exactly [0, 10].
#[test]
fn opaque_does_not_propagate() {
let tape = FbbtTape {
ops: vec![FbbtOp::Var(0), FbbtOp::Opaque, FbbtOp::Add(0, 1)],
};
let forward = vec![Interval::new(0.0, 10.0), Interval::ENTIRE, Interval::ENTIRE];
let r = run(&tape, &forward, Interval::new(5.0, 5.0));
assert!(!r.infeasible);
assert_eq!(r.slots[0], Interval::new(0.0, 10.0));
}
// Soundness check on x^2 + y^2 = 5: sample 36 points on the circle of radius sqrt(5) and
// require every sampled coordinate to remain inside the tightened variable intervals
// (slots 0 and 2 are the two `Var` ops), i.e. no feasible point is cut off.
#[test]
fn fuzz_no_overtightening_quadratic_sum() {
let tape = FbbtTape {
ops: vec![
FbbtOp::Var(0),
FbbtOp::PowInt(0, 2),
FbbtOp::Var(1),
FbbtOp::PowInt(2, 2),
FbbtOp::Add(1, 3),
],
};
// Forward intervals come from the real forward pass; the two slices are presumably the
// lower and upper variable bounds (both variables in [-3, 3]) — confirm against
// `forward_pass`.
let forward =
crate::fbbt::forward::forward_pass(&tape, &[-3.0, -3.0], &[3.0, 3.0]).unwrap();
let r = run(&tape, &forward, Interval::point(5.0));
assert!(!r.infeasible);
let var0 = r.slots[0];
let var1 = r.slots[2];
let n_samples = 36;
for k in 0..n_samples {
let theta = (k as Number) * std::f64::consts::TAU / (n_samples as Number);
let x = (5.0_f64).sqrt() * theta.cos();
let y = (5.0_f64).sqrt() * theta.sin();
assert!(
var0.contains(x),
"x={x:.3} not in {:?} (theta={theta})",
var0
);
assert!(
var1.contains(y),
"y={y:.3} not in {:?} (theta={theta})",
var1
);
}
}
}