use crate::kernel::{ExprData, ExprId, ExprPool};
use rug::{ops::Pow, Float};
use std::collections::HashMap;
use std::fmt;
/// Default precision value (128).
///
/// NOTE(review): not referenced elsewhere in this file; presumably the bit
/// precision callers pass as `prec` to the ball constructors — confirm.
pub const DEFAULT_PREC: u32 = 128;
/// A real ball: a midpoint and a radius stored as `rug::Float` values,
/// together with the precision used when new values are created.
///
/// The methods `lo`, `hi` and `contains` in this file treat the ball as the
/// closed interval `[mid - rad, mid + rad]`.
#[derive(Clone, Debug)]
pub struct ArbBall {
/// Centre of the ball.
pub mid: Float,
/// Radius (half-width) of the ball.
pub rad: Float,
/// Precision handed to `Float::with_val` / `Float::new` for derived values.
pub prec: u32,
}
impl ArbBall {
pub fn new(prec: u32) -> Self {
ArbBall {
mid: Float::new(prec),
rad: Float::new(prec),
prec,
}
}
pub fn from_f64(v: f64, prec: u32) -> Self {
let mid = Float::with_val(prec, v);
let rad = Float::with_val(prec, 0.0);
ArbBall { mid, rad, prec }
}
pub fn from_midpoint_radius(mid: f64, rad: f64, prec: u32) -> Self {
ArbBall {
mid: Float::with_val(prec, mid),
rad: Float::with_val(prec, rad.abs()),
prec,
}
}
pub fn from_integer(n: &rug::Integer, prec: u32) -> Self {
ArbBall {
mid: Float::with_val(prec, n),
rad: Float::with_val(prec, 0.0),
prec,
}
}
/// Builds a ball that encloses the rational `r` at `prec` bits.
///
/// The midpoint is `r` rounded to nearest. When that rounding is exact the
/// radius is zero; otherwise the radius is `|mid| * 2^-prec`, which is at
/// least half a unit in the last place of `mid` and therefore bounds the
/// round-to-nearest error.
///
/// This replaces measuring the error against a `2 * prec`-bit approximation:
/// that reference value was itself rounded, so the resulting radius could be
/// slightly too small to contain `r`, and `prec * 2` could overflow `u32`.
pub fn from_rational(r: &rug::Rational, prec: u32) -> Self {
    let (mid, direction) = Float::with_val_round(prec, r, rug::float::Round::Nearest);
    let rad = if direction == std::cmp::Ordering::Equal {
        // `r` is exactly representable: no enclosure slack is needed.
        Float::new(prec)
    } else {
        // Shifting right by `prec` scales by exactly 2^-prec.
        Float::with_val(prec, mid.abs_ref()) >> prec
    };
    ArbBall { mid, rad, prec }
}
pub fn infinity(prec: u32) -> Self {
let inf = Float::with_val(prec, f64::INFINITY);
ArbBall {
mid: Float::new(prec),
rad: inf,
prec,
}
}
/// Returns `true` when the radius is zero.
pub fn is_exact(&self) -> bool {
    self.rad.is_zero()
}
pub fn contains(&self, v: f64) -> bool {
let v = Float::with_val(self.prec, v);
let lo = Float::with_val(self.prec, &self.mid - &self.rad);
let hi = Float::with_val(self.prec, &self.mid + &self.rad);
v >= lo && v <= hi
}
/// Lower endpoint `mid - rad`, as a `prec`-bit float.
pub fn lo(&self) -> Float {
    let lower = &self.mid - &self.rad;
    Float::with_val(self.prec, lower)
}
/// Upper endpoint `mid + rad`, as a `prec`-bit float.
pub fn hi(&self) -> Float {
    let upper = &self.mid + &self.rad;
    Float::with_val(self.prec, upper)
}
pub fn mid_f64(&self) -> f64 {
self.mid.to_f64()
}
pub fn rad_f64(&self) -> f64 {
self.rad.to_f64()
}
/// Inflates the radius by `|mid| * 2^-prec` to account for rounding of the
/// midpoint; an infinite or NaN midpoint makes the radius infinite.
///
/// The scale factor is applied with a binary shift, which is exact at any
/// precision. Building `2^-prec` as an `f64` underflows to zero once `prec`
/// exceeds the `f64` exponent range (roughly a thousand bits), which would
/// silently drop the rounding slack.
fn add_rounding_error(&mut self) {
    if self.mid.is_infinite() || self.mid.is_nan() {
        self.rad = Float::with_val(self.prec, f64::INFINITY);
        return;
    }
    let slack = Float::with_val(self.prec, self.mid.abs_ref()) >> self.prec;
    self.rad += slack;
}
}
impl fmt::Display for ArbBall {
    /// Writes the ball as `[mid ± rad]`, with both values converted to `f64`:
    /// six decimals for the midpoint and two-digit scientific notation for
    /// the radius.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let centre = self.mid.to_f64();
        let radius = self.rad.to_f64();
        f.write_fmt(format_args!("[{:.6} ± {:.2e}]", centre, radius))
    }
}
impl PartialEq for ArbBall {
    /// Two balls are equal when their midpoints and radii compare equal;
    /// the `prec` field does not take part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        if self.mid != other.mid {
            return false;
        }
        self.rad == other.rad
    }
}
impl std::ops::Add for ArbBall {
    type Output = Self;

    /// Adds two balls at the larger of the two precisions.
    ///
    /// The radius is the sum of both radii plus `|mid| * 2^-prec` of slack
    /// for rounding the new midpoint. The slack is scaled with a binary
    /// shift, which is exact at any precision; an `f64` value of `2^-prec`
    /// underflows to zero for very large `prec` and would lose the slack.
    fn add(self, rhs: Self) -> Self {
        let prec = self.prec.max(rhs.prec);
        let mid = Float::with_val(prec, &self.mid + &rhs.mid);
        let mut rad = Float::with_val(prec, &self.rad + &rhs.rad);
        rad += Float::with_val(prec, mid.abs_ref()) >> prec;
        ArbBall { mid, rad, prec }
    }
}
impl std::ops::Sub for ArbBall {
    type Output = Self;

    /// Subtracts `rhs` from `self` at the larger of the two precisions.
    ///
    /// The radius is the sum of both radii plus `|mid| * 2^-prec` of slack
    /// for rounding the new midpoint. The slack is scaled with a binary
    /// shift, which is exact at any precision; an `f64` value of `2^-prec`
    /// underflows to zero for very large `prec` and would lose the slack.
    fn sub(self, rhs: Self) -> Self {
        let prec = self.prec.max(rhs.prec);
        let mid = Float::with_val(prec, &self.mid - &rhs.mid);
        let mut rad = Float::with_val(prec, &self.rad + &rhs.rad);
        rad += Float::with_val(prec, mid.abs_ref()) >> prec;
        ArbBall { mid, rad, prec }
    }
}
impl std::ops::Mul for ArbBall {
    type Output = Self;

    /// Multiplies two balls at the larger of the two precisions.
    ///
    /// With midpoints `a`, `b` and radii `r`, `s`, the new radius is
    /// `|a|*s + |b|*r + r*s` plus `|mid| * 2^-prec` of slack for rounding
    /// the new midpoint. The slack is scaled with a binary shift, which is
    /// exact at any precision; an `f64` value of `2^-prec` underflows to
    /// zero for very large `prec` and would lose the slack.
    fn mul(self, rhs: Self) -> Self {
        let prec = self.prec.max(rhs.prec);
        let mid = Float::with_val(prec, &self.mid * &rhs.mid);
        let abs_a = Float::with_val(prec, self.mid.abs_ref());
        let abs_b = Float::with_val(prec, rhs.mid.abs_ref());
        let mut rad = Float::with_val(prec, &abs_a * &rhs.rad)
            + Float::with_val(prec, &abs_b * &self.rad)
            + Float::with_val(prec, &self.rad * &rhs.rad);
        rad += Float::with_val(prec, mid.abs_ref()) >> prec;
        ArbBall { mid, rad, prec }
    }
}
impl std::ops::Neg for ArbBall {
    type Output = Self;

    /// Negates the midpoint; radius and precision are kept unchanged.
    fn neg(self) -> Self {
        let ArbBall { mid, rad, prec } = self;
        ArbBall {
            mid: -mid,
            rad,
            prec,
        }
    }
}
impl std::ops::Div for ArbBall {
    type Output = Option<Self>;

    /// Divides `self` by `rhs`, returning `None` when no enclosure exists.
    ///
    /// `None` is returned when `rhs` contains zero, and also when any of the
    /// four endpoint quotients is NaN (a NaN operand, or infinity divided by
    /// infinity). The NaN case used to panic inside
    /// `partial_cmp(..).unwrap()`.
    ///
    /// Otherwise the result is the ball spanned by the smallest and largest
    /// of the four endpoint quotients, each evaluated directly at the larger
    /// of the two operand precisions.
    fn div(self, rhs: Self) -> Option<Self> {
        if rhs.contains(0.0) {
            return None;
        }
        let prec = self.prec.max(rhs.prec);
        let (num_lo, num_hi) = (self.lo(), self.hi());
        let (den_lo, den_hi) = (rhs.lo(), rhs.hi());
        let corners = [
            Float::with_val(prec, &num_lo / &den_lo),
            Float::with_val(prec, &num_lo / &den_hi),
            Float::with_val(prec, &num_hi / &den_lo),
            Float::with_val(prec, &num_hi / &den_hi),
        ];
        // NaN corners cannot be ordered, so there is nothing to enclose.
        if corners.iter().any(|c| c.is_nan()) {
            return None;
        }
        let mut min = &corners[0];
        let mut max = &corners[0];
        for corner in &corners[1..] {
            if corner < min {
                min = corner;
            }
            if corner > max {
                max = corner;
            }
        }
        let mid = Float::with_val(prec, min + max) / 2_f64;
        let rad = Float::with_val(prec, max - min) / 2_f64;
        Some(ArbBall { mid, rad, prec })
    }
}
impl ArbBall {
/// Raises the ball to the integer power `n` by binary exponentiation.
///
/// * `n == 0` yields the exact ball `1`.
/// * `n < 0` computes the positive power and takes its reciprocal; if that
///   power contains zero, the infinite ball is returned.
///
/// The exponent magnitude comes from `unsigned_abs`, so `i64::MIN` is
/// handled. Negating it as an `i64` overflows (a panic in debug builds,
/// unbounded recursion in release builds).
pub fn powi(&self, n: i64) -> Self {
    let one = ArbBall::from_f64(1.0, self.prec);
    if n == 0 {
        return one;
    }
    let mut remaining = n.unsigned_abs();
    let mut acc = one.clone();
    let mut base = self.clone();
    loop {
        if remaining & 1 == 1 {
            acc = acc * base.clone();
        }
        remaining >>= 1;
        if remaining == 0 {
            break;
        }
        // Only square while further bits remain; the final squaring would
        // never be used.
        base = base.clone() * base.clone();
    }
    if n < 0 {
        (one / acc).unwrap_or_else(|| ArbBall::infinity(self.prec))
    } else {
        acc
    }
}
/// Raises the ball to a ball-valued exponent, at `self.prec` bits.
///
/// The enclosure is spanned by the four endpoint powers. Cases where that
/// is not valid are handled first:
///
/// * base reaches below zero and the exponent is not exact: infinite ball;
/// * base straddles zero and the exponent is an exact integer that fits in
///   an `i64`: delegated to `powi`, because the endpoint powers miss the
///   interior extremum at zero (e.g. `[-1, 2]^2`);
/// * base straddles zero with any other exact exponent: infinite ball;
/// * any endpoint power is NaN (e.g. a negative base with an exact
///   non-integer exponent): infinite ball. This used to panic inside
///   `partial_cmp(..).unwrap()`.
pub fn pow_f(&self, exp: &ArbBall) -> Self {
    let prec = self.prec;
    let lo = self.lo();
    let hi = self.hi();
    if lo < 0 && !exp.is_exact() {
        return ArbBall::infinity(prec);
    }
    if lo < 0 && hi > 0 {
        // The exponent is exact here, so its value is `exp.mid`.
        if exp.mid.is_integer() {
            if let Some(n) = exp.mid.to_integer().and_then(|i| i.to_i64()) {
                return self.powi(n);
            }
        }
        return ArbBall::infinity(prec);
    }
    let exp_lo = exp.lo();
    let exp_hi = exp.hi();
    let corners = [
        Float::with_val(prec, lo.clone().pow(exp_lo.clone())),
        Float::with_val(prec, lo.pow(exp_hi.clone())),
        Float::with_val(prec, hi.clone().pow(exp_lo)),
        Float::with_val(prec, hi.pow(exp_hi)),
    ];
    // NaN corners cannot be ordered; report "no finite enclosure".
    if corners.iter().any(|c| c.is_nan()) {
        return ArbBall::infinity(prec);
    }
    let mut min = &corners[0];
    let mut max = &corners[0];
    for corner in &corners[1..] {
        if corner < min {
            min = corner;
        }
        if corner > max {
            max = corner;
        }
    }
    let mid = Float::with_val(prec, min + max) / 2_f64;
    let rad = Float::with_val(prec, max - min) / 2_f64;
    ArbBall { mid, rad, prec }
}
/// Sine of the ball: the midpoint is `sin(mid)`, the input radius is
/// carried over unchanged, and `add_rounding_error` is applied to the
/// result.
pub fn sin(&self) -> Self {
    let mut out = ArbBall {
        mid: Float::with_val(self.prec, self.mid.clone().sin()),
        rad: self.rad.clone(),
        prec: self.prec,
    };
    out.add_rounding_error();
    out
}
/// Cosine of the ball: the midpoint is `cos(mid)`, the input radius is
/// carried over unchanged, and `add_rounding_error` is applied to the
/// result.
pub fn cos(&self) -> Self {
    let mut out = ArbBall {
        mid: Float::with_val(self.prec, self.mid.clone().cos()),
        rad: self.rad.clone(),
        prec: self.prec,
    };
    out.add_rounding_error();
    out
}
/// Exponential of the ball, spanned by `exp(lo)` and `exp(hi)`.
///
/// No extra rounding slack is added to the radius.
pub fn exp(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().exp()),
        Float::with_val(prec, self.hi().exp()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Natural logarithm of the ball, spanned by `ln(lo)` and `ln(hi)`.
///
/// Returns `None` when the lower endpoint is zero or negative.
pub fn log(&self) -> Option<Self> {
    if self.lo() <= 0 {
        return None;
    }
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().ln()),
        Float::with_val(prec, self.hi().ln()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    Some(ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    })
}
/// Square root of the ball, spanned by `sqrt(lo)` and `sqrt(hi)`.
///
/// Returns `None` when the lower endpoint is negative.
pub fn sqrt(&self) -> Option<Self> {
    if self.lo() < 0 {
        return None;
    }
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().sqrt()),
        Float::with_val(prec, self.hi().sqrt()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    Some(ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    })
}
pub fn tan(&self) -> Option<Self> {
let prec = self.prec;
let _pi_half = Float::with_val(prec, rug::float::Constant::Pi) / 2_f64;
let lo = self.lo();
let hi = self.hi();
let lo_f = lo.to_f64();
let hi_f = hi.to_f64();
let pi_f: f64 = std::f64::consts::PI;
let near_pole = |v: f64| ((v % pi_f).abs() - pi_f / 2.0).abs() < 1e-9;
if near_pole(lo_f) || near_pole(hi_f) {
return None;
}
let lo_tan = Float::with_val(prec, lo.tan());
let hi_tan = Float::with_val(prec, hi.tan());
if lo_tan > hi_tan {
return None;
}
let sum = Float::with_val(prec, &lo_tan + &hi_tan);
let diff = Float::with_val(prec, &hi_tan - &lo_tan);
Some(ArbBall {
mid: sum / 2_f64,
rad: diff / 2_f64,
prec,
})
}
/// Hyperbolic sine of the ball, spanned by `sinh(lo)` and `sinh(hi)`.
pub fn sinh(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().sinh()),
        Float::with_val(prec, self.hi().sinh()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Hyperbolic cosine of the ball.
///
/// When the interval contains zero the lower bound is `1` and the upper
/// bound is the larger of the two endpoint values; otherwise the two
/// endpoint values, put in order, span the result.
pub fn cosh(&self) -> Self {
    let prec = self.prec;
    let at_lo = Float::with_val(prec, self.lo().cosh());
    let at_hi = Float::with_val(prec, self.hi().cosh());
    let straddles_zero = self.lo() <= 0 && self.hi() >= 0;
    let (smallest, largest) = if straddles_zero {
        let peak = if at_lo > at_hi { at_lo } else { at_hi };
        (Float::with_val(prec, 1_f64), peak)
    } else if at_lo < at_hi {
        (at_lo, at_hi)
    } else {
        (at_hi, at_lo)
    };
    let centre = Float::with_val(prec, &smallest + &largest) / 2_f64;
    let half_width = Float::with_val(prec, &largest - &smallest) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Hyperbolic tangent of the ball, spanned by `tanh(lo)` and `tanh(hi)`.
pub fn tanh(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().tanh()),
        Float::with_val(prec, self.hi().tanh()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Arcsine of the ball, spanned by `asin(lo)` and `asin(hi)`.
///
/// Returns `None` when the interval reaches outside `[-1, 1]`.
pub fn asin(&self) -> Option<Self> {
    if self.lo() < -1 || self.hi() > 1 {
        return None;
    }
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().asin()),
        Float::with_val(prec, self.hi().asin()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    Some(ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    })
}
/// Arccosine of the ball, spanned by `acos(lo)` and `acos(hi)`.
///
/// Returns `None` when the interval reaches outside `[-1, 1]`. Arccosine
/// is decreasing, so the width is taken as `acos(lo) - acos(hi)`.
pub fn acos(&self) -> Option<Self> {
    if self.lo() < -1 || self.hi() > 1 {
        return None;
    }
    let prec = self.prec;
    let (at_lo, at_hi) = (
        Float::with_val(prec, self.lo().acos()),
        Float::with_val(prec, self.hi().acos()),
    );
    let centre = Float::with_val(prec, &at_lo + &at_hi) / 2_f64;
    let half_width = Float::with_val(prec, &at_lo - &at_hi) / 2_f64;
    Some(ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    })
}
/// Arctangent of the ball, spanned by `atan(lo)` and `atan(hi)`.
pub fn atan(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().atan()),
        Float::with_val(prec, self.hi().atan()),
    );
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Error function of the ball.
///
/// The midpoint is `erf(mid)`; the radius is the input radius scaled by
/// `2 / sqrt(pi)` (evaluated in `f64`), and `add_rounding_error` is
/// applied to the result.
pub fn erf(&self) -> Self {
    let prec = self.prec;
    let slope_bound = Float::with_val(prec, 2.0_f64 / std::f64::consts::PI.sqrt());
    let mut out = ArbBall {
        mid: Float::with_val(prec, self.mid.clone().erf()),
        rad: Float::with_val(prec, &self.rad * &slope_bound),
        prec,
    };
    out.add_rounding_error();
    out
}
/// Complementary error function of the ball.
///
/// The midpoint is `erfc(mid)`; the radius is the input radius scaled by
/// `2 / sqrt(pi)` (evaluated in `f64`), and `add_rounding_error` is
/// applied to the result.
pub fn erfc(&self) -> Self {
    let prec = self.prec;
    let slope_bound = Float::with_val(prec, 2.0_f64 / std::f64::consts::PI.sqrt());
    let mut out = ArbBall {
        mid: Float::with_val(prec, self.mid.clone().erfc()),
        rad: Float::with_val(prec, &self.rad * &slope_bound),
        prec,
    };
    out.add_rounding_error();
    out
}
/// Absolute value of the ball.
///
/// When the interval contains zero the result is `[0, max(|lo|, |hi|)]`.
/// Otherwise the midpoint is `|mid|`, the radius is kept, and
/// `add_rounding_error` is applied.
pub fn abs_ball(&self) -> Self {
    let prec = self.prec;
    let straddles_zero = self.lo() <= 0 && self.hi() >= 0;
    if !straddles_zero {
        let mut out = ArbBall {
            mid: Float::with_val(prec, self.mid.clone().abs()),
            rad: self.rad.clone(),
            prec,
        };
        out.add_rounding_error();
        return out;
    }
    let lo_mag = self.lo().abs();
    let hi_mag = self.hi().abs();
    let half = lo_mag.max(&hi_mag) / 2_f64;
    ArbBall {
        mid: half.clone(),
        rad: half,
        prec,
    }
}
/// Floor of the ball, spanned by `floor(lo)` and `floor(hi)`.
pub fn floor_ball(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().floor()),
        Float::with_val(prec, self.hi().floor()),
    );
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
/// Ceiling of the ball, spanned by `ceil(lo)` and `ceil(hi)`.
pub fn ceil_ball(&self) -> Self {
    let prec = self.prec;
    let (lower, upper) = (
        Float::with_val(prec, self.lo().ceil()),
        Float::with_val(prec, self.hi().ceil()),
    );
    let half_width = Float::with_val(prec, &upper - &lower) / 2_f64;
    let centre = Float::with_val(prec, &lower + &upper) / 2_f64;
    ArbBall {
        mid: centre,
        rad: half_width,
        prec,
    }
}
}
/// A complex value represented by two real balls, one for the real part and
/// one for the imaginary part.
#[derive(Clone, Debug)]
pub struct AcbBall {
/// Real part.
pub re: ArbBall,
/// Imaginary part.
pub im: ArbBall,
}
impl AcbBall {
    /// Wraps a real ball as a complex value whose imaginary part is the zero
    /// ball at the same precision.
    pub fn from_real(re: ArbBall) -> Self {
        let im = ArbBall::new(re.prec);
        Self { re, im }
    }

    /// Builds a complex value from two `f64` components, each stored as a
    /// zero-radius ball at `prec` bits.
    pub fn from_f64(re: f64, im: f64, prec: u32) -> Self {
        let real = ArbBall::from_f64(re, prec);
        let imag = ArbBall::from_f64(im, prec);
        Self { re: real, im: imag }
    }

    /// Modulus `sqrt(re^2 + im^2)` computed with ball arithmetic.
    ///
    /// When the square root is unavailable (the sum of squares reaches below
    /// zero), the infinite ball at the real part's precision is returned.
    pub fn modulus(&self) -> ArbBall {
        let square = |b: &ArbBall| b.clone() * b.clone();
        let norm_sq = square(&self.re) + square(&self.im);
        match norm_sq.sqrt() {
            Some(root) => root,
            None => ArbBall::infinity(self.re.prec),
        }
    }
}
impl fmt::Display for AcbBall {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} + {}·i", self.re, self.im)
}
}
/// Evaluates expressions from an `ExprPool` to `ArbBall` enclosures, given
/// ball values bound to expression ids.
pub struct IntervalEval {
// Balls bound to expression ids; looked up when a symbol node is evaluated.
bindings: HashMap<ExprId, ArbBall>,
/// Precision used when constants are converted to balls.
pub prec: u32,
}
impl IntervalEval {
    /// Creates an evaluator with no bindings that converts constants at
    /// `prec` bits.
    pub fn new(prec: u32) -> Self {
        Self {
            bindings: HashMap::new(),
            prec,
        }
    }

    /// Binds `ball` to the expression id `var`, replacing any earlier
    /// binding for the same id.
    pub fn bind(&mut self, var: ExprId, ball: ArbBall) {
        let _replaced = self.bindings.insert(var, ball);
    }

    /// Evaluates `expr` to a ball, or `None` when some part of it cannot be
    /// evaluated (unbound symbol, unsupported node, undefined operation).
    pub fn eval(&self, expr: ExprId, pool: &ExprPool) -> Option<ArbBall> {
        self.eval_node(expr, pool)
    }

    /// Recursive worker behind `eval`.
    ///
    /// Handles integer, rational and float constants, bound symbols, sums,
    /// products, powers (integer-literal exponents go through `powi`, the
    /// rest through `pow_f`) and the one-argument functions `sin`, `cos`,
    /// `exp`, `log` and `sqrt`. Everything else yields `None`.
    fn eval_node(&self, expr: ExprId, pool: &ExprPool) -> Option<ArbBall> {
        let prec = self.prec;
        match pool.get(expr) {
            ExprData::Integer(n) => Some(ArbBall::from_integer(&n.0, prec)),
            ExprData::Rational(r) => Some(ArbBall::from_rational(&r.0, prec)),
            ExprData::Float(f) => Some(ArbBall::from_f64(f.inner.to_f64(), prec)),
            ExprData::Symbol { .. } => self.bindings.get(&expr).cloned(),
            ExprData::Add(args) => {
                let zero = ArbBall::from_f64(0.0, prec);
                (&args).into_iter().try_fold(zero, |sum, &term| {
                    self.eval_node(term, pool).map(|value| sum + value)
                })
            }
            ExprData::Mul(args) => {
                let one = ArbBall::from_f64(1.0, prec);
                (&args).into_iter().try_fold(one, |product, &factor| {
                    self.eval_node(factor, pool).map(|value| product * value)
                })
            }
            ExprData::Pow { base, exp } => {
                let base_ball = self.eval_node(base, pool)?;
                let exp_ball = self.eval_node(exp, pool)?;
                match pool.get(exp) {
                    // An integer exponent too large for `i64` yields `None`.
                    ExprData::Integer(n) => Some(base_ball.powi(n.0.to_i64()?)),
                    _ => Some(base_ball.pow_f(&exp_ball)),
                }
            }
            ExprData::Func { name, args } if args.len() == 1 => {
                let arg = self.eval_node(args[0], pool)?;
                match name.as_str() {
                    "sin" => Some(arg.sin()),
                    "cos" => Some(arg.cos()),
                    "exp" => Some(arg.exp()),
                    "log" => arg.log(),
                    "sqrt" => arg.sqrt(),
                    _ => None,
                }
            }
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for the evaluator tests.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn ball_contains_midpoint() {
        let ball = ArbBall::from_midpoint_radius(3.0, 0.5, 64);
        // Midpoint and both endpoints are inside; 4.0 is outside.
        assert!(ball.contains(3.0));
        assert!(ball.contains(2.5));
        assert!(ball.contains(3.5));
        assert!(!ball.contains(4.0));
    }

    #[test]
    fn ball_add_enclosure() {
        let lhs = ArbBall::from_midpoint_radius(1.0, 0.1, 64);
        let rhs = ArbBall::from_midpoint_radius(2.0, 0.2, 64);
        let sum = lhs + rhs;
        assert!(sum.contains(2.7));
        assert!(sum.contains(3.0));
        assert!(sum.contains(3.3));
    }

    #[test]
    fn ball_mul_enclosure() {
        let lhs = ArbBall::from_midpoint_radius(2.0, 0.5, 64);
        let rhs = ArbBall::from_midpoint_radius(3.0, 0.5, 64);
        let product = lhs * rhs;
        assert!(product.contains(4.0));
        assert!(product.contains(8.0));
    }

    #[test]
    fn ball_powi_exact() {
        let three = ArbBall::from_f64(3.0, 128);
        let cubed = three.powi(3);
        assert!(cubed.contains(27.0));
        assert!(!cubed.contains(26.0));
    }

    #[test]
    fn ball_sin_enclosure() {
        let half_pi = std::f64::consts::PI / 2.0;
        let ball = ArbBall::from_midpoint_radius(half_pi, 0.01, 128);
        let sine = ball.sin();
        assert!(sine.contains(1.0));
    }

    #[test]
    fn ball_exp_enclosure() {
        let ball = ArbBall::from_midpoint_radius(0.0, 0.1, 128);
        let image = ball.exp();
        assert!(image.contains(0.905));
        assert!(image.contains(1.0));
        assert!(image.contains(1.105));
    }

    #[test]
    fn ball_log_enclosure() {
        let ball = ArbBall::from_midpoint_radius(2.0, 0.5, 128);
        let image = ball.log().unwrap();
        assert!(image.contains(0.41));
        assert!(image.contains(0.91));
        assert!(image.contains(2_f64.ln()));
    }

    #[test]
    fn ball_log_fails_at_nonpositive() {
        let ball = ArbBall::from_midpoint_radius(0.0, 0.5, 128);
        assert!(ball.log().is_none());
    }

    #[test]
    fn interval_eval_constant() {
        let pool = p();
        let five = pool.integer(5_i32);
        let evaluator = IntervalEval::new(128);
        let result = evaluator.eval(five, &pool).unwrap();
        assert!(result.contains(5.0));
    }

    #[test]
    fn interval_eval_polynomial() {
        // x^2 + 1 evaluated on x = 3 ± 0.1.
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, one]);
        let x_ball = ArbBall::from_midpoint_radius(3.0, 0.1, 128);
        let mut evaluator = IntervalEval::new(128);
        evaluator.bind(x, x_ball);
        let result = evaluator.eval(expr, &pool).unwrap();
        assert!(result.contains(9.5));
        assert!(result.contains(10.0));
        assert!(result.contains(10.5));
    }

    #[test]
    fn interval_eval_unbound_is_none() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let evaluator = IntervalEval::new(128);
        assert!(evaluator.eval(x, &pool).is_none());
    }

    #[test]
    fn interval_eval_rational() {
        let pool = p();
        let third = pool.rational(1, 3);
        let evaluator = IntervalEval::new(128);
        let result = evaluator.eval(third, &pool).unwrap();
        let mid = result.mid_f64();
        assert!((mid - 1.0 / 3.0).abs() < 1e-15, "mid={mid}");
        assert!(result.rad_f64() < 1e-30, "rad={}", result.rad_f64());
    }

    #[test]
    fn acb_modulus() {
        let z = AcbBall::from_f64(3.0, 4.0, 128);
        let modulus = z.modulus();
        assert!(modulus.contains(5.0));
    }
}