use oxieml::{Canonical, EmlConstraint, EmlSmtSolver, EmlTree, SmtResult};
use scirs2_core::ndarray::Array2;
/// Absolute tolerance used by `prove_equivalent`: a difference whose range from
/// `crate::analyze::certified_range` lies within `[-EQ_TOL, EQ_TOL]` is accepted
/// as equal, and a witness must differ by more than `EQ_TOL` in absolute value
/// to count as a counterexample.
const EQ_TOL: f64 = 1e-9;
/// Outcome of one of the `prove_*` checks in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Verdict {
/// The range test succeeded (or, for `prove_equivalent`, the two roots compared equal).
Proven,
/// A solver-proposed point was re-evaluated and found to violate the property.
Counterexample,
/// Neither a proof nor a verified counterexample was obtained.
Unknown,
}
/// Evaluates `tree` at the single point `xs`.
///
/// The coordinates are packed into a `1 x xs.len()` array and handed to
/// `crate::forest::eval_tree`; the first element of its output is returned.
/// Yields `None` when the array cannot be built or the evaluation reports an error.
fn eval_point(tree: &EmlTree, xs: &[f64]) -> Option<f64> {
    let Ok(single_row) = Array2::from_shape_vec((1, xs.len()), xs.to_vec()) else {
        return None;
    };
    match crate::forest::eval_tree(tree, &single_row) {
        Ok(values) => Some(values[0]),
        Err(_) => None,
    }
}
/// Asks the SMT solver for a point satisfying `constraint` and confirms it numerically.
///
/// The solver's answer is never trusted on its own: each coordinate is clamped
/// into its interval (falling back to the interval midpoint when the solver
/// supplied no value for that index), `tree` is evaluated at the resulting
/// point, and the witness is accepted only if `violates` holds for that value.
///
/// Returns `false` when
/// * any interval in `bounds` is malformed (`lo > hi` or a NaN endpoint),
/// * the solver does not answer `Sat` (including solver errors),
/// * evaluation of `tree` fails, or
/// * the evaluated value does not satisfy `violates`.
fn verified_witness(
    constraint: &EmlConstraint,
    tree: &EmlTree,
    bounds: &[(f64, f64)],
    violates: impl Fn(f64) -> bool,
) -> bool {
    // `f64::clamp` panics when `min > max` or when either end is NaN. Such an
    // interval cannot contain a witness anyway, so report "none found" instead
    // of crashing the caller.
    let malformed = bounds
        .iter()
        .any(|&(lo, hi)| lo.is_nan() || hi.is_nan() || lo > hi);
    if malformed {
        return false;
    }

    let solver = EmlSmtSolver::new(bounds.to_vec());
    let Ok(SmtResult::Sat(sol)) = solver.check_sat(constraint) else {
        return false;
    };

    // Build the candidate point, forcing every coordinate back inside its box
    // so the verified witness always lies within the requested domain.
    let xs: Vec<f64> = bounds
        .iter()
        .enumerate()
        .map(|(i, &(lo, hi))| {
            sol.assignments
                .get(i)
                .copied()
                .unwrap_or((lo + hi) * 0.5)
                .clamp(lo, hi)
        })
        .collect();

    eval_point(tree, &xs).is_some_and(violates)
}
/// Attempts to establish `tree(x) >= c` for every `x` in `bounds`.
///
/// * [`Verdict::Proven`] when the lower end reported by
///   `crate::analyze::certified_range` is at least `c`.
/// * [`Verdict::Counterexample`] when a solver-proposed point, re-evaluated
///   on `tree`, yields a value strictly below `c`.
/// * [`Verdict::Unknown`] otherwise.
#[must_use]
pub fn prove_lower_bound(tree: &EmlTree, c: f64, bounds: &[(f64, f64)]) -> Verdict {
    let range_min = crate::analyze::certified_range(tree, bounds).0;
    if range_min >= c {
        return Verdict::Proven;
    }

    let dips_below = EmlConstraint::lt(tree.clone(), EmlTree::const_val(c));
    if verified_witness(&dips_below, tree, bounds, |value| value < c) {
        Verdict::Counterexample
    } else {
        Verdict::Unknown
    }
}
/// Attempts to establish `tree(x) <= c` for every `x` in `bounds`.
///
/// * [`Verdict::Proven`] when the upper end reported by
///   `crate::analyze::certified_range` is at most `c`.
/// * [`Verdict::Counterexample`] when a solver-proposed point, re-evaluated
///   on `tree`, yields a value strictly above `c`.
/// * [`Verdict::Unknown`] otherwise.
#[must_use]
pub fn prove_upper_bound(tree: &EmlTree, c: f64, bounds: &[(f64, f64)]) -> Verdict {
    let range_max = crate::analyze::certified_range(tree, bounds).1;
    if range_max <= c {
        return Verdict::Proven;
    }

    let rises_above = EmlConstraint::gt(tree.clone(), EmlTree::const_val(c));
    if verified_witness(&rises_above, tree, bounds, |value| value > c) {
        Verdict::Counterexample
    } else {
        Verdict::Unknown
    }
}
/// Attempts to establish `tree(x) > 0` for every `x` in `bounds`.
///
/// * [`Verdict::Proven`] when the lower end reported by
///   `crate::analyze::certified_range` is strictly positive.
/// * [`Verdict::Counterexample`] when a solver-proposed point, re-evaluated
///   on `tree`, yields a value `<= 0`.
/// * [`Verdict::Unknown`] otherwise.
#[must_use]
pub fn prove_positive(tree: &EmlTree, bounds: &[(f64, f64)]) -> Verdict {
    let range_min = crate::analyze::certified_range(tree, bounds).0;
    if range_min > 0.0 {
        return Verdict::Proven;
    }

    let non_positive = EmlConstraint::le(tree.clone(), EmlTree::const_val(0.0));
    if verified_witness(&non_positive, tree, bounds, |value| value <= 0.0) {
        Verdict::Counterexample
    } else {
        Verdict::Unknown
    }
}
/// Attempts to establish `tree(x) < 0` for every `x` in `bounds`.
///
/// * [`Verdict::Proven`] when the upper end reported by
///   `crate::analyze::certified_range` is strictly negative.
/// * [`Verdict::Counterexample`] when a solver-proposed point, re-evaluated
///   on `tree`, yields a value `>= 0`.
/// * [`Verdict::Unknown`] otherwise.
#[must_use]
pub fn prove_negative(tree: &EmlTree, bounds: &[(f64, f64)]) -> Verdict {
    let range_max = crate::analyze::certified_range(tree, bounds).1;
    if range_max < 0.0 {
        return Verdict::Proven;
    }

    let non_negative = EmlConstraint::ge(tree.clone(), EmlTree::const_val(0.0));
    if verified_witness(&non_negative, tree, bounds, |value| value >= 0.0) {
        Verdict::Counterexample
    } else {
        Verdict::Unknown
    }
}
/// Attempts to establish that `tree` has no zero inside `bounds`.
///
/// Returns [`Verdict::Proven`] when the range reported by
/// `crate::analyze::certified_range` lies strictly on one side of zero, and
/// [`Verdict::Unknown`] otherwise. No counterexample search is performed, so
/// this function never returns [`Verdict::Counterexample`].
#[must_use]
pub fn prove_no_root(tree: &EmlTree, bounds: &[(f64, f64)]) -> Verdict {
    match crate::analyze::certified_range(tree, bounds) {
        // Entire range above zero.
        (range_min, _) if range_min > 0.0 => Verdict::Proven,
        // Entire range below zero.
        (_, range_max) if range_max < 0.0 => Verdict::Proven,
        // Range touches or straddles zero (or is not comparable): inconclusive.
        _ => Verdict::Unknown,
    }
}
/// Attempts to establish that `a` and `b` agree on `bounds` up to `EQ_TOL`.
///
/// * [`Verdict::Proven`] when the two `root` fields compare equal, or when the
///   range that `crate::analyze::certified_range` reports for
///   `Canonical::sub(a, b)` lies within `[-EQ_TOL, EQ_TOL]`.
/// * [`Verdict::Counterexample`] when a solver-proposed point (first for
///   `a > b`, then for `a < b`) makes the difference tree evaluate to a value
///   whose magnitude exceeds `EQ_TOL`.
/// * [`Verdict::Unknown`] otherwise.
#[must_use]
pub fn prove_equivalent(a: &EmlTree, b: &EmlTree, bounds: &[(f64, f64)]) -> Verdict {
    /// Witness test shared by both solver queries.
    fn differs(delta: f64) -> bool {
        delta.abs() > EQ_TOL
    }

    // Equal roots need no numeric argument.
    if a.root == b.root {
        return Verdict::Proven;
    }

    let difference = Canonical::sub(a, b);
    let (diff_min, diff_max) = crate::analyze::certified_range(&difference, bounds);
    if diff_min >= -EQ_TOL && diff_max <= EQ_TOL {
        return Verdict::Proven;
    }

    // The second query is only issued when the first one finds nothing.
    let separated = verified_witness(
        &EmlConstraint::gt(a.clone(), b.clone()),
        &difference,
        bounds,
        differs,
    ) || verified_witness(
        &EmlConstraint::lt(a.clone(), b.clone()),
        &difference,
        bounds,
        differs,
    );

    if separated {
        Verdict::Counterexample
    } else {
        Verdict::Unknown
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tree the tests treat as `exp(x0)` (built as `eml(x0, 1)`).
    fn exp_x0() -> EmlTree {
        let x0 = EmlTree::var(0);
        let one = EmlTree::one();
        EmlTree::eml(&x0, &one)
    }

    /// Tree the tests treat as `exp(x0) - 1` (built as `eml(x0, e)`).
    fn exp_x0_minus_1() -> EmlTree {
        let x0 = EmlTree::var(0);
        let e = EmlTree::const_val(std::f64::consts::E);
        EmlTree::eml(&x0, &e)
    }

    #[test]
    fn proves_bounds_via_interval() {
        let f = exp_x0();
        let unit: [(f64, f64); 1] = [(0.0, 1.0)];

        assert_eq!(prove_lower_bound(&f, 1.0, &unit), Verdict::Proven);
        assert_eq!(prove_upper_bound(&f, 3.0, &unit), Verdict::Proven);
        assert_eq!(prove_positive(&f, &[(-2.0, 2.0)]), Verdict::Proven);
    }

    #[test]
    fn finds_verified_counterexamples() {
        let f = exp_x0();
        let unit: [(f64, f64); 1] = [(0.0, 1.0)];

        // The same constant fails as a lower bound and as an upper bound on this interval.
        assert_eq!(prove_lower_bound(&f, 2.0, &unit), Verdict::Counterexample);
        assert_eq!(prove_upper_bound(&f, 2.0, &unit), Verdict::Counterexample);
    }

    #[test]
    fn no_root_is_sound_despite_const_operands() {
        let shifted = exp_x0_minus_1();

        assert_eq!(prove_no_root(&exp_x0(), &[(-3.0, 3.0)]), Verdict::Proven);
        assert_eq!(prove_no_root(&shifted, &[(0.5, 1.0)]), Verdict::Proven);
        // Must not be claimed root-free on an interval containing x0 = 0.
        assert_ne!(prove_no_root(&shifted, &[(-1.0, 1.0)]), Verdict::Proven);
    }

    #[test]
    fn proves_structural_equivalence() {
        let f = exp_x0();
        let verdict = prove_equivalent(&f, &f, &[(0.0, 1.0)]);
        assert_eq!(verdict, Verdict::Proven, "a law is equal to itself");
    }
}