use crate::error::EmlError;
use crate::eval::EvalCtx;
use crate::tree::{EmlNode, EmlTree};
/// A constraint over EML expression trees.
///
/// Leaf variants compare the value of a tree against zero; `And` / `Or`
/// combine sub-constraints. Under `check_constraint`, an empty `And` holds and
/// an empty `Or` does not.
#[derive(Clone, Debug)]
pub enum EmlConstraint {
    /// The tree's value must be zero (checked numerically as `|v| < 1e-8`).
    EqZero(EmlTree),
    /// The tree's value must be strictly positive.
    GtZero(EmlTree),
    /// The tree's value must be non-negative (checked numerically as `v >= -1e-12`).
    GeZero(EmlTree),
    /// Every sub-constraint must hold.
    And(Vec<EmlConstraint>),
    /// At least one sub-constraint must hold.
    Or(Vec<EmlConstraint>),
}
/// A satisfying assignment produced by a solver.
#[derive(Clone, Debug)]
pub struct EmlSolution {
    /// Value of variable `i` at index `i` (empty for variable-free constraints).
    pub assignments: Vec<f64>,
    /// Whether the producing solver considers the witness exact. `EmlNraSolver`
    /// sets it for variable-free constraints and when every variable's interval
    /// had shrunk below its tolerance at the time the witness was found.
    pub is_exact: bool,
}
/// A closed interval `[lo, hi]` over `f64`, used as an enclosure for the
/// possible values of a variable or sub-expression.
///
/// An interval with `lo > hi`, or with a NaN endpoint, is empty. Endpoints may
/// be infinite.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Interval {
    /// Lower endpoint (inclusive).
    pub lo: f64,
    /// Upper endpoint (inclusive).
    pub hi: f64,
}
impl Interval {
    /// Builds `[lo, hi]` as given; no normalisation is performed.
    pub fn new(lo: f64, hi: f64) -> Self {
        Self { lo, hi }
    }
    /// The canonical empty interval `[+inf, -inf]`.
    fn empty() -> Self {
        Self::new(f64::INFINITY, f64::NEG_INFINITY)
    }
    /// `true` when `lo > hi` or either endpoint is NaN.
    pub fn is_empty(&self) -> bool {
        self.lo > self.hi || self.lo.is_nan() || self.hi.is_nan()
    }
    /// `hi - lo` (may be infinite).
    pub fn width(&self) -> f64 {
        self.hi - self.lo
    }
    /// Midpoint of the interval.
    ///
    /// Uses `(lo + hi) / 2` and falls back to `lo / 2 + hi / 2` only when the
    /// sum overflows, so the midpoint of a huge finite interval stays finite and
    /// inside it.
    pub fn midpoint(&self) -> f64 {
        let mid = (self.lo + self.hi) / 2.0;
        if mid.is_finite() {
            mid
        } else {
            self.lo / 2.0 + self.hi / 2.0
        }
    }
    /// `true` when `lo <= x <= hi`.
    pub fn contains(&self, x: f64) -> bool {
        x >= self.lo && x <= self.hi
    }
    /// Splits at the midpoint into `[lo, mid]` and `[mid, hi]`.
    pub fn split(&self) -> (Self, Self) {
        let mid = self.midpoint();
        (Self::new(self.lo, mid), Self::new(mid, self.hi))
    }
    /// Intersection; empty when the intervals are disjoint.
    pub fn intersect(&self, other: &Self) -> Self {
        Self::new(self.lo.max(other.lo), self.hi.min(other.hi))
    }
    /// Smallest interval containing both operands.
    pub fn hull(&self, other: &Self) -> Self {
        Self::new(self.lo.min(other.lo), self.hi.max(other.hi))
    }
    /// Image under `exp` (monotone increasing, so endpoint-wise).
    pub fn exp(&self) -> Self {
        Self::new(self.lo.exp(), self.hi.exp())
    }
    /// Image under the natural logarithm.
    ///
    /// Returns the empty interval unless the whole interval lies in the domain
    /// of `ln` (i.e. `lo > 0`). Unbounded intervals are supported:
    /// `ln(+inf) = +inf`, so `[a, +inf]` with `a > 0` maps to `[ln a, +inf]`
    /// rather than being rejected.
    pub fn ln(&self) -> Self {
        if self.is_empty() || self.lo <= 0.0 {
            Self::empty()
        } else {
            Self::new(self.lo.ln(), self.hi.ln())
        }
    }
}
/// Outcome of interval propagation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PropResult {
    /// A change was reported. No rule in this module currently produces it;
    /// it is only forwarded from sub-constraints.
    Changed,
    /// No change and no conflict.
    Stable,
    /// Propagation rejected the current intervals: some leaf's value enclosure
    /// was empty or had the wrong sign.
    Conflict,
}
/// A box of intervals, one per problem variable, searched by propagation and
/// bisection.
#[derive(Clone, Debug)]
pub struct IntervalDomain {
    /// Interval for variable `i` at index `i`.
    pub vars: Vec<Interval>,
}
impl IntervalDomain {
    /// Builds a domain of `num_vars` intervals.
    ///
    /// Variable `i` uses `bounds[i]` when present; a variable without an
    /// explicit bound gets `[-10, 10]`. Extra entries in `bounds` are ignored.
    pub fn new(bounds: &[(f64, f64)], num_vars: usize) -> Self {
        let default_bound = Interval::new(-10.0, 10.0);
        let vars = (0..num_vars)
            .map(|i| {
                bounds
                    .get(i)
                    .map_or(default_bound, |&(lo, hi)| Interval::new(lo, hi))
            })
            .collect();
        Self { vars }
    }
    /// `true` if any variable's interval is empty.
    pub fn is_empty(&self) -> bool {
        self.vars.iter().any(|iv| iv.is_empty())
    }
    /// Repeats single propagation passes until one reports `Stable`, one
    /// reports `Conflict`, or 20 passes have run.
    ///
    /// Returns `Conflict` as soon as it is seen, `Changed` if any pass reported
    /// a change, and `Stable` otherwise.
    pub fn propagate(&mut self, c: &EmlConstraint) -> PropResult {
        const MAX_ITERATIONS: usize = 20;
        let mut outcome = PropResult::Stable;
        for _ in 0..MAX_ITERATIONS {
            match propagate_once(&mut self.vars, c) {
                PropResult::Stable => break,
                PropResult::Changed => outcome = PropResult::Changed,
                PropResult::Conflict => return PropResult::Conflict,
            }
        }
        outcome
    }
}
fn eval_interval(node: &EmlNode, vars: &[Interval]) -> Interval {
match node {
EmlNode::One => Interval::new(1.0, 1.0),
EmlNode::Var(i) => vars
.get(*i)
.copied()
.unwrap_or_else(|| Interval::new(-10.0, 10.0)),
EmlNode::Eml { left, right } => {
let l = eval_interval(left, vars);
let r = eval_interval(right, vars);
if l.is_empty() || r.is_empty() {
return Interval::new(f64::INFINITY, f64::NEG_INFINITY);
}
let exp_l = l.exp();
let ln_r = r.ln();
if ln_r.is_empty() {
return Interval::new(f64::INFINITY, f64::NEG_INFINITY);
}
Interval::new(exp_l.lo - ln_r.hi, exp_l.hi - ln_r.lo)
}
}
}
fn propagate_once(vars: &mut [Interval], c: &EmlConstraint) -> PropResult {
match c {
EmlConstraint::EqZero(tree) => {
let v = eval_interval(&tree.root, vars);
if v.is_empty() {
return PropResult::Conflict;
}
if v.lo > 0.0 || v.hi < 0.0 {
return PropResult::Conflict;
}
PropResult::Stable
}
EmlConstraint::GtZero(tree) => {
let v = eval_interval(&tree.root, vars);
if v.is_empty() {
return PropResult::Conflict;
}
if v.hi <= 0.0 {
return PropResult::Conflict;
}
PropResult::Stable
}
EmlConstraint::GeZero(tree) => {
let v = eval_interval(&tree.root, vars);
if v.is_empty() {
return PropResult::Conflict;
}
if v.hi < 0.0 {
return PropResult::Conflict;
}
PropResult::Stable
}
EmlConstraint::And(constraints) => {
let mut any_changed = false;
for inner in constraints {
match propagate_once(vars, inner) {
PropResult::Conflict => return PropResult::Conflict,
PropResult::Changed => any_changed = true,
PropResult::Stable => {}
}
}
if any_changed {
PropResult::Changed
} else {
PropResult::Stable
}
}
EmlConstraint::Or(constraints) => {
if constraints.is_empty() {
return PropResult::Conflict;
}
let mut all_conflict = true;
for inner in constraints {
if !matches!(propagate_once(vars, inner), PropResult::Conflict) {
all_conflict = false;
break;
}
}
if all_conflict {
PropResult::Conflict
} else {
PropResult::Stable
}
}
}
}
/// Propagation + bisection solver for nonlinear real constraints over EML trees.
pub struct EmlNraSolver {
    /// Upper bound on bisection iterations performed by `solve`.
    pub max_iterations: usize,
    /// Interval width below which a variable counts as converged.
    pub tolerance: f64,
    /// Per-variable `(lo, hi)` search bounds. Variables without an entry use
    /// `[-10, 10]`.
    pub initial_bounds: Vec<(f64, f64)>,
}
impl Default for EmlNraSolver {
    /// 10 000 iterations, tolerance `1e-8`, and a single `[-10, 10]` bound.
    fn default() -> Self {
        let initial_bounds = Vec::from([(-10.0, 10.0)]);
        Self {
            initial_bounds,
            tolerance: 1e-8,
            max_iterations: 10_000,
        }
    }
}
impl EmlNraSolver {
pub fn new(initial_bounds: Vec<(f64, f64)>) -> Self {
Self {
initial_bounds,
..Default::default()
}
}
pub fn solve(&self, constraint: &EmlConstraint) -> Result<EmlSolution, EmlError> {
let num_vars = count_constraint_vars(constraint);
if num_vars == 0 {
let ctx = EvalCtx::new(&[]);
if check_constraint(constraint, &ctx) {
return Ok(EmlSolution {
assignments: vec![],
is_exact: true,
});
}
return Err(EmlError::ConvergenceFailed {
best_mse: f64::INFINITY,
iterations: 0,
});
}
let mut domain = IntervalDomain::new(&self.initial_bounds, num_vars);
if matches!(domain.propagate(constraint), PropResult::Conflict) {
return Err(EmlError::ConvergenceFailed {
best_mse: f64::INFINITY,
iterations: 0,
});
}
for _iter in 0..self.max_iterations {
let midpoints: Vec<f64> = domain.vars.iter().map(Interval::midpoint).collect();
let ctx = EvalCtx::new(&midpoints);
if check_constraint(constraint, &ctx) {
return Ok(EmlSolution {
assignments: midpoints,
is_exact: domain.vars.iter().all(|iv| iv.width() < self.tolerance),
});
}
let widest_idx = domain
.vars
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.width()
.partial_cmp(&b.width())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i);
if let Some(widest) = widest_idx {
let (lo_half, hi_half) = domain.vars[widest].split();
let mut lo_point = midpoints.clone();
lo_point[widest] = lo_half.midpoint();
let lo_resid = evaluate_constraint_residual(constraint, &lo_point);
let mut hi_point = midpoints;
hi_point[widest] = hi_half.midpoint();
let hi_resid = evaluate_constraint_residual(constraint, &hi_point);
domain.vars[widest] = if lo_resid.abs() < hi_resid.abs() {
lo_half
} else {
hi_half
};
if matches!(domain.propagate(constraint), PropResult::Conflict) {
return Err(EmlError::ConvergenceFailed {
best_mse: f64::INFINITY,
iterations: 0,
});
}
if domain.vars.iter().all(|iv| iv.width() < self.tolerance) {
break;
}
}
}
let midpoints: Vec<f64> = domain.vars.iter().map(Interval::midpoint).collect();
let ctx = EvalCtx::new(&midpoints);
if check_constraint(constraint, &ctx) {
Ok(EmlSolution {
assignments: midpoints,
is_exact: false,
})
} else {
Err(EmlError::ConvergenceFailed {
best_mse: evaluate_constraint_residual(constraint, &midpoints).abs(),
iterations: self.max_iterations,
})
}
}
}
/// Result of `EmlSmtSolver::check_sat`.
#[cfg(feature = "smt")]
#[derive(Clone, Debug)]
pub enum SmtResult {
    /// A witness was found. It has been checked by direct evaluation of the constraint.
    Sat(EmlSolution),
    /// Interval propagation or the oxiz relaxation reported the constraint
    /// unsatisfiable within the bounds.
    Unsat,
    /// The relaxation did not refute the constraint and bisection found no witness.
    Unknown,
}
/// Satisfiability checker combining interval propagation, an oxiz linear
/// relaxation, and bisection.
#[cfg(feature = "smt")]
pub struct EmlSmtSolver {
    /// Per-variable `(lo, hi)` bounds. Variables without an entry use `[-10, 10]`.
    pub bounds: Vec<(f64, f64)>,
    /// Tangent sample points per `exp` / `ln` node in the relaxation (at least one is used).
    pub relaxation_samples: usize,
}
#[cfg(feature = "smt")]
impl Default for EmlSmtSolver {
    /// A single `[-10, 10]` bound and 3 relaxation samples.
    fn default() -> Self {
        Self {
            relaxation_samples: 3,
            bounds: Vec::from([(-10.0, 10.0)]),
        }
    }
}
#[cfg(feature = "smt")]
impl EmlSmtSolver {
    /// Creates a checker with the given bounds and the default sample count.
    pub fn new(bounds: Vec<(f64, f64)>) -> Self {
        Self {
            bounds,
            ..Default::default()
        }
    }

    /// Decides satisfiability of `c` within `self.bounds`.
    ///
    /// Variable-free constraints are evaluated directly. Otherwise the pipeline is:
    ///
    /// 1. Interval propagation; a conflict gives `Unsat`.
    /// 2. The oxiz linear relaxation; an `Unsat` verdict gives `Unsat`.
    /// 3. Otherwise `EmlNraSolver` bisection searches the box for a witness.
    ///    Success gives `Sat` and failure gives `Unknown`.
    ///
    /// # Errors
    ///
    /// Every path currently returns `Ok`; the `Result` leaves room for solver errors.
    pub fn check_sat(&self, c: &EmlConstraint) -> Result<SmtResult, EmlError> {
        let num_vars = count_constraint_vars(c);
        if num_vars == 0 {
            let holds = check_constraint(c, &EvalCtx::new(&[]));
            let verdict = if holds {
                SmtResult::Sat(EmlSolution {
                    assignments: Vec::new(),
                    is_exact: true,
                })
            } else {
                SmtResult::Unsat
            };
            return Ok(verdict);
        }
        let mut domain = IntervalDomain::new(&self.bounds, num_vars);
        if domain.propagate(c) == PropResult::Conflict {
            return Ok(SmtResult::Unsat);
        }
        match oxiz_check(c, &domain, self.relaxation_samples) {
            OxizVerdict::Unsat => return Ok(SmtResult::Unsat),
            // A satisfiable relaxation is not a witness; search for a real one.
            OxizVerdict::Sat | OxizVerdict::Unknown => {}
        }
        let tight_bounds = domain.vars.iter().map(|iv| (iv.lo, iv.hi)).collect();
        let outcome = EmlNraSolver::new(tight_bounds)
            .solve(c)
            .map_or(SmtResult::Unknown, SmtResult::Sat);
        Ok(outcome)
    }
}
/// Verdict of the linear relaxation checked by `oxiz_check`.
#[cfg(feature = "smt")]
#[derive(Clone, Copy)]
enum OxizVerdict {
    /// The relaxation is satisfiable. This does not prove the original
    /// constraint is; `check_sat` treats it like `Unknown`.
    Sat,
    /// The relaxation is unsatisfiable.
    Unsat,
    /// oxiz could not decide, or the problem could not be encoded.
    Unknown,
}
/// Encodes `c` over the box `domain` as a linear real-arithmetic problem for
/// oxiz and checks it.
///
/// Each problem variable becomes a real variable `x<i>` bounded by its
/// interval. Returns `Unknown` when a bound or sub-term cannot be encoded.
#[cfg(feature = "smt")]
fn oxiz_check(c: &EmlConstraint, domain: &IntervalDomain, samples: usize) -> OxizVerdict {
    use oxiz::{Solver, SolverResult, TermId, TermManager};
    let mut tm = TermManager::new();
    let mut solver = Solver::new();
    let real_sort = tm.sorts.real_sort;
    let var_terms: Vec<TermId> = (0..domain.vars.len())
        .map(|i| tm.mk_var(&format!("x{i}"), real_sort))
        .collect();
    // Box constraints lo_i <= x_i <= hi_i.
    for (&var, iv) in var_terms.iter().zip(&domain.vars) {
        let Some(lo) = float_to_term(&mut tm, iv.lo) else {
            return OxizVerdict::Unknown;
        };
        let Some(hi) = float_to_term(&mut tm, iv.hi) else {
            return OxizVerdict::Unknown;
        };
        let ge = tm.mk_ge(var, lo);
        let le = tm.mk_le(var, hi);
        solver.assert(ge, &mut tm);
        solver.assert(le, &mut tm);
    }
    let mut counter: usize = 0;
    let Some(encoded) = encode_constraint(
        c,
        &var_terms,
        domain,
        samples,
        &mut counter,
        &mut tm,
        &mut solver,
    ) else {
        return OxizVerdict::Unknown;
    };
    solver.assert(encoded, &mut tm);
    match solver.check(&mut tm) {
        SolverResult::Sat => OxizVerdict::Sat,
        SolverResult::Unsat => OxizVerdict::Unsat,
        SolverResult::Unknown => OxizVerdict::Unknown,
    }
}
/// Converts `v` into an oxiz rational constant `round(v * 1e6) / 1e6`.
///
/// The encoded constant may differ from `v` by up to `5e-7`. Returns `None`
/// for non-finite values and for `|v| > 1e12`.
#[cfg(feature = "smt")]
fn float_to_term(tm: &mut oxiz::TermManager, v: f64) -> Option<oxiz::TermId> {
    use num_rational::Rational64;
    const DENOM_CAP: i64 = 1_000_000;
    const VALUE_CAP: f64 = 1.0e12;
    let in_range = v.is_finite() && v.abs() <= VALUE_CAP;
    if !in_range {
        return None;
    }
    let scaled = (v * DENOM_CAP as f64).round();
    // Keep well inside i64 so the cast below cannot saturate.
    let representable = scaled.is_finite() && scaled.abs() <= (i64::MAX as f64) / 4.0;
    if !representable {
        return None;
    }
    Some(tm.mk_real(Rational64::new(scaled as i64, DENOM_CAP)))
}
/// Encodes constraint `c` as an oxiz boolean term.
///
/// Leaves become `term = 0`, `term > 0`, or `term >= 0` over the relaxed tree
/// term from `encode_tree`. An empty `And` is `true` and an empty `Or` is
/// `false`. Returns `None` as soon as any sub-term cannot be encoded.
#[cfg(feature = "smt")]
fn encode_constraint(
    c: &EmlConstraint,
    var_terms: &[oxiz::TermId],
    domain: &IntervalDomain,
    samples: usize,
    counter: &mut usize,
    tm: &mut oxiz::TermManager,
    solver: &mut oxiz::Solver,
) -> Option<oxiz::TermId> {
    /// Which comparison against zero a leaf constraint requires.
    #[derive(Clone, Copy)]
    enum Sign {
        Eq,
        Gt,
        Ge,
    }
    let (tree, sign) = match c {
        EmlConstraint::EqZero(tree) => (tree, Sign::Eq),
        EmlConstraint::GtZero(tree) => (tree, Sign::Gt),
        EmlConstraint::GeZero(tree) => (tree, Sign::Ge),
        EmlConstraint::And(children) | EmlConstraint::Or(children) => {
            let conjunctive = matches!(c, EmlConstraint::And(_));
            if children.is_empty() {
                return Some(if conjunctive { tm.mk_true() } else { tm.mk_false() });
            }
            let parts = children
                .iter()
                .map(|child| {
                    encode_constraint(child, var_terms, domain, samples, counter, tm, solver)
                })
                .collect::<Option<Vec<oxiz::TermId>>>()?;
            return Some(if conjunctive {
                tm.mk_and(parts)
            } else {
                tm.mk_or(parts)
            });
        }
    };
    let (term, _) = encode_tree(&tree.root, var_terms, domain, samples, counter, tm, solver)?;
    let zero = float_to_term(tm, 0.0)?;
    Some(match sign {
        Sign::Eq => tm.mk_eq(term, zero),
        Sign::Gt => tm.mk_gt(term, zero),
        Sign::Ge => tm.mk_ge(term, zero),
    })
}
/// Encodes `node` as an oxiz real term and returns it together with an
/// interval enclosure of its value over `domain`.
///
/// `One` becomes the constant `1`. `Var(i)` becomes `var_terms[i]` with
/// interval `domain.vars[i]`.
///
/// For `Eml { left, right }` (value `exp(left) - ln(right)`), two fresh real
/// variables `__ex_<k>` and `__ln_<k>` are created (numbered through
/// `counter`) and linearly constrained on `solver`:
///
/// - `__ex` is bounded by `[exp(l.lo), exp(l.hi)]`. When the left interval has
///   positive width, it also lies below the secant of `exp` and above its
///   tangents at the sample points (`exp` is convex).
/// - `__ln` is bounded by `[ln(r.lo), ln(r.hi)]`. When the right interval has
///   positive width, it also lies above the secant of `ln` and below its
///   tangents at the sample points (`ln` is concave).
///
/// The returned term is `__ex - __ln`.
///
/// Returns `None` in any of these cases; `oxiz_check` then reports `Unknown`:
///
/// - a child interval is unbounded or empty;
/// - the right interval is not strictly positive;
/// - a variable index is out of range;
/// - a constant cannot be encoded by `float_to_term`.
///
/// NOTE(review): `float_to_term` rounds every constant to the nearest `1e-6`,
/// so these bounds are not outward-rounded. The relaxation may exclude points
/// within rounding error of a bound. Confirm this is acceptable before relying
/// on an `Unsat` verdict near a boundary.
#[cfg(feature = "smt")]
fn encode_tree(
    node: &EmlNode,
    var_terms: &[oxiz::TermId],
    domain: &IntervalDomain,
    samples: usize,
    counter: &mut usize,
    tm: &mut oxiz::TermManager,
    solver: &mut oxiz::Solver,
) -> Option<(oxiz::TermId, Interval)> {
    match node {
        EmlNode::One => {
            let one = float_to_term(tm, 1.0)?;
            Some((one, Interval::new(1.0, 1.0)))
        }
        EmlNode::Var(i) => {
            // Out-of-range indices yield `None` rather than a default interval.
            let iv = *domain.vars.get(*i)?;
            let term = *var_terms.get(*i)?;
            Some((term, iv))
        }
        EmlNode::Eml { left, right } => {
            let (l_term, l_iv) =
                encode_tree(left, var_terms, domain, samples, counter, tm, solver)?;
            let (r_term, r_iv) =
                encode_tree(right, var_terms, domain, samples, counter, tm, solver)?;
            // The linearisation needs finite intervals, and `ln` needs a strictly
            // positive right interval.
            if !l_iv.lo.is_finite() || !l_iv.hi.is_finite() {
                return None;
            }
            if !r_iv.lo.is_finite() || !r_iv.hi.is_finite() || r_iv.lo <= 0.0 {
                return None;
            }
            if l_iv.is_empty() || r_iv.is_empty() {
                return None;
            }
            let real_sort = tm.sorts.real_sort;
            // Fresh auxiliaries: `ex` stands for exp(left), `ln_term` for ln(right).
            *counter += 1;
            let ex = tm.mk_var(&format!("__ex_{counter}"), real_sort);
            *counter += 1;
            let ln_term = tm.mk_var(&format!("__ln_{counter}"), real_sort);
            // Range bounds: exp(l.lo) <= ex <= exp(l.hi).
            let exp_lo = float_to_term(tm, l_iv.lo.exp())?;
            let exp_hi = float_to_term(tm, l_iv.hi.exp())?;
            let t1 = tm.mk_ge(ex, exp_lo);
            let t2 = tm.mk_le(ex, exp_hi);
            solver.assert(t1, tm);
            solver.assert(t2, tm);
            if l_iv.width() > 0.0 {
                // Secant upper bound: exp is convex, so it lies below the chord.
                let slope = (l_iv.hi.exp() - l_iv.lo.exp()) / l_iv.width();
                let slope_term = float_to_term(tm, slope)?;
                let lo_const = float_to_term(tm, l_iv.lo)?;
                let diff = tm.mk_sub(l_term, lo_const);
                let prod = tm.mk_mul([slope_term, diff]);
                let rhs = tm.mk_add([exp_lo, prod]);
                let sec_ub = tm.mk_le(ex, rhs);
                solver.assert(sec_ub, tm);
                // Tangent lower bounds at max(samples, 1) points: the midpoint
                // when there is one point, else evenly spaced from lo to hi.
                let n = samples.max(1);
                for k in 0..n {
                    let t = if n == 1 {
                        l_iv.midpoint()
                    } else {
                        l_iv.lo + (l_iv.width() * k as f64) / ((n - 1) as f64)
                    };
                    let exp_t = t.exp();
                    let exp_t_term = float_to_term(tm, exp_t)?;
                    let t_const = float_to_term(tm, t)?;
                    let l_minus_t = tm.mk_sub(l_term, t_const);
                    let slope_tangent = tm.mk_mul([exp_t_term, l_minus_t]);
                    let rhs = tm.mk_add([exp_t_term, slope_tangent]);
                    let tan_lb = tm.mk_ge(ex, rhs);
                    solver.assert(tan_lb, tm);
                }
            }
            // Range bounds: ln(r.lo) <= ln_term <= ln(r.hi).
            let ln_lo = float_to_term(tm, r_iv.lo.ln())?;
            let ln_hi = float_to_term(tm, r_iv.hi.ln())?;
            let t3 = tm.mk_ge(ln_term, ln_lo);
            let t4 = tm.mk_le(ln_term, ln_hi);
            solver.assert(t3, tm);
            solver.assert(t4, tm);
            if r_iv.width() > 0.0 {
                // Secant lower bound: ln is concave, so it lies above the chord.
                let slope = (r_iv.hi.ln() - r_iv.lo.ln()) / r_iv.width();
                let slope_term = float_to_term(tm, slope)?;
                let lo_const = float_to_term(tm, r_iv.lo)?;
                let diff = tm.mk_sub(r_term, lo_const);
                let prod = tm.mk_mul([slope_term, diff]);
                let rhs = tm.mk_add([ln_lo, prod]);
                let sec_lb = tm.mk_ge(ln_term, rhs);
                solver.assert(sec_lb, tm);
                // Tangent upper bounds ln(t) + (r - t) / t at the sample points.
                let n = samples.max(1);
                for k in 0..n {
                    let t = if n == 1 {
                        r_iv.midpoint()
                    } else {
                        r_iv.lo + (r_iv.width() * k as f64) / ((n - 1) as f64)
                    };
                    // Defensive: r_iv.lo > 0 was checked above, so t > 0 here.
                    if t <= 0.0 {
                        continue;
                    }
                    let inv_t = 1.0 / t;
                    let ln_t = t.ln();
                    let ln_t_term = float_to_term(tm, ln_t)?;
                    let inv_t_term = float_to_term(tm, inv_t)?;
                    let t_const = float_to_term(tm, t)?;
                    let r_minus_t = tm.mk_sub(r_term, t_const);
                    let slope_tangent = tm.mk_mul([inv_t_term, r_minus_t]);
                    let rhs = tm.mk_add([ln_t_term, slope_tangent]);
                    let tan_ub = tm.mk_le(ln_term, rhs);
                    solver.assert(tan_ub, tm);
                }
            }
            // eml(l, r) = exp(l) - ln(r), with a matching interval enclosure.
            let result = tm.mk_sub(ex, ln_term);
            let result_iv =
                Interval::new(l_iv.lo.exp() - r_iv.hi.ln(), l_iv.hi.exp() - r_iv.lo.ln());
            Some((result, result_iv))
        }
    }
}
/// Evaluates `constraint` at the point described by `ctx`.
///
/// A leaf whose tree fails to evaluate is treated as violated. Tolerances:
/// `EqZero` accepts `|v| < 1e-8`, `GtZero` requires `v > 0`, and `GeZero`
/// accepts `v >= -1e-12`. An empty `And` holds; an empty `Or` does not.
fn check_constraint(constraint: &EmlConstraint, ctx: &EvalCtx) -> bool {
    match constraint {
        EmlConstraint::EqZero(tree) => matches!(tree.eval_real(ctx), Ok(v) if v.abs() < 1e-8),
        EmlConstraint::GtZero(tree) => matches!(tree.eval_real(ctx), Ok(v) if v > 0.0),
        EmlConstraint::GeZero(tree) => matches!(tree.eval_real(ctx), Ok(v) if v >= -1e-12),
        EmlConstraint::And(children) => children.iter().all(|child| check_constraint(child, ctx)),
        EmlConstraint::Or(children) => children.iter().any(|child| check_constraint(child, ctx)),
    }
}
/// Heuristic residual of `constraint` at `vars`, used to steer bisection.
///
/// - `EqZero`: the tree's signed value, or `+inf` if evaluation fails.
/// - `GtZero` / `GeZero`: `0` when satisfied, otherwise `-value`; this is
///   `+inf` when evaluation fails.
/// - `And`: sum of the children's absolute residuals.
/// - `Or`: smallest child absolute residual (`+inf` when there are no children).
fn evaluate_constraint_residual(constraint: &EmlConstraint, vars: &[f64]) -> f64 {
    let ctx = EvalCtx::new(vars);
    // How far `value` falls short of a satisfied sign constraint.
    let shortfall = |value: f64, satisfied: bool| if satisfied { 0.0 } else { -value };
    match constraint {
        EmlConstraint::EqZero(tree) => tree.eval_real(&ctx).unwrap_or(f64::INFINITY),
        EmlConstraint::GtZero(tree) => {
            let value = tree.eval_real(&ctx).unwrap_or(f64::NEG_INFINITY);
            shortfall(value, value > 0.0)
        }
        EmlConstraint::GeZero(tree) => {
            let value = tree.eval_real(&ctx).unwrap_or(f64::NEG_INFINITY);
            shortfall(value, value >= 0.0)
        }
        EmlConstraint::And(children) => children
            .iter()
            .map(|child| evaluate_constraint_residual(child, vars).abs())
            .sum(),
        EmlConstraint::Or(children) => children
            .iter()
            .map(|child| evaluate_constraint_residual(child, vars).abs())
            .fold(f64::INFINITY, f64::min),
    }
}
/// Number of variables `constraint` refers to: the largest `num_vars()` of any
/// leaf tree, or 0 when there are no leaves.
fn count_constraint_vars(constraint: &EmlConstraint) -> usize {
    match constraint {
        EmlConstraint::And(children) | EmlConstraint::Or(children) => children
            .iter()
            .fold(0, |widest, child| widest.max(count_constraint_vars(child))),
        EmlConstraint::EqZero(tree) | EmlConstraint::GtZero(tree) | EmlConstraint::GeZero(tree) => {
            tree.num_vars()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::canonical::Canonical;

    /// `ln(1) = 0` has no variables, so `solve` decides it by direct evaluation.
    #[test]
    fn test_eq_zero_trivial() {
        let one = EmlTree::one();
        let ln_one = Canonical::ln(&one);
        let constraint = EmlConstraint::EqZero(ln_one);
        let solver = EmlNraSolver::default();
        assert!(solver.solve(&constraint).is_ok());
    }

    /// `exp(x) > 0` holds on the whole box.
    #[test]
    fn test_gt_zero() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        let constraint = EmlConstraint::GtZero(exp_x);
        let solver = EmlNraSolver::new(vec![(-5.0, 5.0)]);
        assert!(solver.solve(&constraint).is_ok());
    }

    /// `x > 0 && e - x > 0`, i.e. `0 < x < e`.
    #[test]
    fn test_and_constraint() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        // eml(1, e^x) = e - ln(e^x) = e - x
        let e_minus_x = EmlTree::eml(&one, &exp_x);
        let constraint = EmlConstraint::And(vec![
            EmlConstraint::GtZero(x),
            EmlConstraint::GtZero(e_minus_x),
        ]);
        let solver = EmlNraSolver::new(vec![(0.1, 2.5)]);
        let result = solver.solve(&constraint).expect("expected Sat solution");
        let x_val = result.assignments[0];
        assert!(x_val > 0.0 && x_val < std::f64::consts::E);
    }

    #[test]
    fn test_interval_exp_forward() {
        let iv = Interval::new(0.0, 2.0);
        let exp_iv = iv.exp();
        assert!((exp_iv.lo - 1.0).abs() < 1e-12);
        assert!((exp_iv.hi - 2.0_f64.exp()).abs() < 1e-12);
    }

    #[test]
    fn test_interval_ln_forward() {
        let iv = Interval::new(1.0, std::f64::consts::E);
        let ln_iv = iv.ln();
        assert!(ln_iv.lo.abs() < 1e-12);
        assert!((ln_iv.hi - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_interval_ln_negative_empty() {
        let iv = Interval::new(-1.0, 1.0);
        let ln_iv = iv.ln();
        assert!(ln_iv.is_empty());
    }

    #[test]
    fn test_interval_intersect_and_hull() {
        let a = Interval::new(0.0, 2.0);
        let b = Interval::new(1.0, 3.0);
        let inter = a.intersect(&b);
        assert!((inter.lo - 1.0).abs() < 1e-12);
        assert!((inter.hi - 2.0).abs() < 1e-12);
        let hull = a.hull(&b);
        assert!((hull.lo - 0.0).abs() < 1e-12);
        assert!((hull.hi - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_interval_disjoint_intersect_empty() {
        let a = Interval::new(0.0, 1.0);
        let b = Interval::new(2.0, 3.0);
        assert!(a.intersect(&b).is_empty());
    }

    #[test]
    fn test_interval_split_halves() {
        let (lo, hi) = Interval::new(0.0, 2.0).split();
        assert_eq!(lo, Interval::new(0.0, 1.0));
        assert_eq!(hi, Interval::new(1.0, 2.0));
    }

    /// Variables without an explicit bound fall back to `[-10, 10]`.
    #[test]
    fn test_domain_defaults_missing_bounds() {
        let domain = IntervalDomain::new(&[(0.0, 1.0)], 2);
        assert_eq!(domain.vars[0], Interval::new(0.0, 1.0));
        assert_eq!(domain.vars[1], Interval::new(-10.0, 10.0));
    }

    /// An empty disjunction can never hold.
    #[test]
    fn test_empty_or_is_conflict() {
        let mut domain = IntervalDomain::new(&[], 1);
        assert_eq!(
            domain.propagate(&EmlConstraint::Or(vec![])),
            PropResult::Conflict
        );
    }

    /// `exp(x) = 0` is impossible: the enclosure of `exp(x)` is strictly positive.
    #[test]
    fn test_propagate_exp_positivity_conflict() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        let c = EmlConstraint::EqZero(exp_x);
        let mut domain = IntervalDomain::new(&[(-5.0, 5.0)], 1);
        assert_eq!(domain.propagate(&c), PropResult::Conflict);
    }
}
#[cfg(all(test, feature = "smt"))]
mod smt_tests {
    use super::*;
    use crate::canonical::Canonical;

    /// `exp(x0)`, written as `eml(x0, 1)`.
    fn exp_x0() -> EmlTree {
        EmlTree::eml(&EmlTree::var(0), &EmlTree::one())
    }

    /// `ln(x0)` built by the canonical constructor.
    fn ln_x0() -> EmlTree {
        Canonical::ln(&EmlTree::var(0))
    }

    /// Runs `check_sat` on `c` within `bounds`, failing the test on error.
    fn run(bounds: Vec<(f64, f64)>, c: &EmlConstraint) -> SmtResult {
        EmlSmtSolver::new(bounds)
            .check_sat(c)
            .expect("check_sat error")
    }

    #[test]
    fn test_smt_sat_exp_positive() {
        let c = EmlConstraint::GtZero(exp_x0());
        match run(vec![(-3.0, 3.0)], &c) {
            SmtResult::Sat(_) => {}
            other => panic!("expected Sat, got {other:?}"),
        }
    }

    #[test]
    fn test_smt_ln_bracket() {
        // ln(x) > 0 for every x in [1.1, 5].
        let gt = EmlConstraint::GtZero(ln_x0());
        assert!(matches!(run(vec![(1.1, 5.0)], &gt), SmtResult::Sat(_)));
    }

    #[test]
    fn test_smt_unsat_ln_of_negative() {
        // ln(x) is undefined on [-2, -1], so no witness can exist.
        let c = EmlConstraint::GtZero(ln_x0());
        assert!(matches!(run(vec![(-2.0, -1.0)], &c), SmtResult::Unsat));
    }

    #[test]
    fn test_smt_witness_verifies() {
        let c = EmlConstraint::GtZero(exp_x0());
        match run(vec![(-1.0, 1.0)], &c) {
            SmtResult::Sat(sol) => {
                let ctx = crate::eval::EvalCtx::new(&sol.assignments);
                assert!(check_constraint(&c, &ctx));
            }
            other => panic!("expected Sat, got {other:?}"),
        }
    }

    #[test]
    fn test_smt_constant_true() {
        // ln(1) = 0 is variable-free and decided by direct evaluation.
        let c = EmlConstraint::EqZero(Canonical::ln(&EmlTree::one()));
        let result = EmlSmtSolver::default()
            .check_sat(&c)
            .expect("check_sat error");
        assert!(matches!(result, SmtResult::Sat(_)));
    }
}