extern crate num_traits;
use std::fmt::Debug;
use num_traits::{cast::cast, Float};
/// Strategy that decides how the relative tolerance is scaled when comparing `a` and `b`.
///
/// For finite, unequal inputs a pair is accepted when `|a - b| <= abs_tol` or when the
/// variant-specific relative test below passes. Exactly equal inputs are always close.
/// Any other pair involving NaN or an infinity is never close.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    /// `|a - b| <= |rel_tol * a|` or `|a - b| <= |rel_tol * b|`.
    /// Effectively scales by the larger magnitude.
    Weak,
    /// `|a - b| <= |rel_tol * a|` and `|a - b| <= |rel_tol * b|`.
    /// Effectively scales by the smaller magnitude.
    Strong,
    /// `|a - b| <= |rel_tol * (a + b) / 2|`: scales by the mean of the two values.
    Average,
    /// `|a - b| <= |rel_tol * b|`: `b` is the reference value, so argument order matters.
    Asymmetric,
}
impl Method {
    /// Shortcut decisions shared by every method.
    ///
    /// Returns:
    /// - `Some(true)` when `a == b` (this covers equal infinities and `0.0 == -0.0`).
    /// - `Some(false)` when either value is NaN or infinite.
    /// - `None` when both are finite and unequal, i.e. a tolerance test is required.
    #[inline]
    fn common_checks<T: Float>(a: T, b: T) -> Option<bool> {
        if a == b {
            Some(true)
        } else if !a.is_finite() || !b.is_finite() {
            Some(false)
        } else {
            None
        }
    }

    /// Builds the boxed predicate `(a, b) -> bool` implementing this method.
    ///
    /// Every predicate first applies `common_checks`. For finite, unequal inputs it accepts
    /// when `|a - b| <= abs_tol` or when the method-specific relative test passes.
    /// `rel_tol` and `abs_tol` are used as given; the builder setters store absolute values.
    fn method<'a, T: Float + 'a>(&self, rel_tol: T, abs_tol: T) -> Box<dyn Fn(T, T) -> bool + 'a> {
        match self {
            Method::Asymmetric => Box::new(move |a, b| match Self::common_checks(a, b) {
                Some(result) => result,
                None => {
                    let diff = Float::abs(a - b);
                    (diff <= Float::abs(rel_tol * b)) || (diff <= abs_tol)
                }
            }),
            Method::Average => {
                // Convert the constant once when the predicate is built, not on every call.
                let two: T = cast(2.0).unwrap();
                Box::new(move |a, b| match Self::common_checks(a, b) {
                    Some(result) => result,
                    None => {
                        let diff = Float::abs(a - b);
                        let sum = a + b;
                        // `a + b` overflows to +/-inf for large same-signed finite inputs.
                        // That would make the tolerance infinite and accept any such pair,
                        // so halve each term before adding in that case only.
                        let scaled = if sum.is_finite() {
                            rel_tol * sum / two
                        } else {
                            rel_tol * (a / two + b / two)
                        };
                        diff <= scaled.abs() || (diff <= abs_tol)
                    }
                })
            }
            Method::Strong => Box::new(move |a, b| match Self::common_checks(a, b) {
                Some(result) => result,
                None => {
                    let diff = Float::abs(a - b);
                    ((diff <= Float::abs(rel_tol * b)) && (diff <= Float::abs(rel_tol * a)))
                        || (diff <= abs_tol)
                }
            }),
            Method::Weak => Box::new(move |a, b| match Self::common_checks(a, b) {
                Some(result) => result,
                None => {
                    let diff = Float::abs(a - b);
                    ((diff <= Float::abs(rel_tol * b)) || (diff <= Float::abs(rel_tol * a)))
                        || (diff <= abs_tol)
                }
            }),
        }
    }
}
/// Shorthand for [`Method::Weak`], the builder's default method.
pub const WEAK: Method = Method::Weak;
/// Shorthand for [`Method::Strong`].
pub const STRONG: Method = Method::Strong;
/// Shorthand for [`Method::Average`].
pub const AVERAGE: Method = Method::Average;
/// Shorthand for [`Method::Asymmetric`].
pub const ASYMMETRIC: Method = Method::Asymmetric;
impl From<&str> for Method {
fn from(s: &str) -> Self {
match s.to_lowercase().as_ref() {
"asymmetric" => Self::Asymmetric,
"average" => Self::Average,
"strong" => Self::Strong,
"weak" => Self::Weak,
_ => panic!("unknown method {:?}", s),
}
}
}
/// Relative tolerance used by a freshly created builder (10^-8).
pub const DEFAULT_REL_TOL: f64 = 0.000_000_01;
/// Absolute tolerance used by a freshly created builder (disabled: zero).
pub const DEFAULT_ABS_TOL: f64 = 0.0;
/// A compiled closeness predicate, produced by `ComparatorBuilder::compile`.
///
/// The tolerances and method are captured when it is built. Reusing one `Comparator`
/// avoids rebuilding the boxed closure for every comparison.
pub struct Comparator<'a, T>
where
    T: Float,
{
    /// The boxed `(a, b) -> bool` test built from the chosen method and tolerances.
    is_close: Box<dyn Fn(T, T) -> bool + 'a>,
}
impl<T: Float> Comparator<'_, T> {
    /// Returns `true` if `a` and `b` are close according to the compiled predicate.
    pub fn is_close(&self, a: T, b: T) -> bool {
        let predicate = &self.is_close;
        predicate(a, b)
    }

    /// Returns `true` if every element pair `(a[i], b[i])` is close.
    ///
    /// The inputs are paired with `zip`, so iteration stops at the end of the shorter
    /// input and extra elements of the longer one are ignored. Empty input yields `true`.
    pub fn all_close<I, J>(&self, a: I, b: J) -> bool
    where
        I: IntoIterator<Item = T>,
        J: IntoIterator<Item = T>,
    {
        for (x, y) in a.into_iter().zip(b) {
            if !self.is_close(x, y) {
                return false;
            }
        }
        true
    }

    /// Returns `true` if at least one element pair `(a[i], b[i])` is close.
    ///
    /// Pairing follows the same `zip` rule as [`Comparator::all_close`]. Empty input yields `false`.
    pub fn any_close<I, J>(&self, a: I, b: J) -> bool
    where
        I: IntoIterator<Item = T>,
        J: IntoIterator<Item = T>,
    {
        for (x, y) in a.into_iter().zip(b) {
            if self.is_close(x, y) {
                return true;
            }
        }
        false
    }
}
/// Configures tolerances and a [`Method`], then compiles them into a [`Comparator`].
///
/// Field order is significant: the derived `Debug` output lists them in this order.
#[derive(Clone, Debug)]
pub struct ComparatorBuilder<T>
where
    T: Float,
{
    /// Relative tolerance; the setter stores its absolute value.
    rel_tol: T,
    /// Absolute tolerance; the setter stores its absolute value.
    abs_tol: T,
    /// How the relative tolerance is scaled.
    method: Method,
}
impl<T: Float> Default for ComparatorBuilder<T> {
fn default() -> Self {
ComparatorBuilder {
rel_tol: cast(DEFAULT_REL_TOL).unwrap(),
abs_tol: cast(DEFAULT_ABS_TOL).unwrap(),
method: WEAK,
}
}
}
impl<T: Float> ComparatorBuilder<T> {
pub fn rel_tol(&mut self, value: T) -> &mut Self {
self.rel_tol = value.abs();
self
}
pub fn abs_tol(&mut self, value: T) -> &mut Self {
self.abs_tol = value.abs();
self
}
pub fn method<M: Into<Method>>(&mut self, method: M) -> &mut Self {
self.method = method.into();
self
}
}
impl<'a, T: Float + 'a> ComparatorBuilder<T> {
    /// Freezes the current settings into a reusable [`Comparator`].
    ///
    /// Prefer this when comparing many values. The convenience methods below build,
    /// and box, a fresh predicate on every call.
    pub fn compile(&self) -> Comparator<'a, T> {
        let predicate = self.method.method(self.rel_tol, self.abs_tol);
        Comparator {
            is_close: predicate,
        }
    }

    /// One-shot [`Comparator::is_close`] using the current settings.
    pub fn is_close(&self, a: T, b: T) -> bool {
        let comparator = self.compile();
        comparator.is_close(a, b)
    }

    /// One-shot [`Comparator::all_close`] using the current settings.
    pub fn all_close<I, J>(&self, a: I, b: J) -> bool
    where
        I: IntoIterator<Item = T>,
        J: IntoIterator<Item = T>,
    {
        let comparator = self.compile();
        comparator.all_close(a, b)
    }

    /// One-shot [`Comparator::any_close`] using the current settings.
    pub fn any_close<I, J>(&self, a: I, b: J) -> bool
    where
        I: IntoIterator<Item = T>,
        J: IntoIterator<Item = T>,
    {
        let comparator = self.compile();
        comparator.any_close(a, b)
    }
}
/// Returns a [`ComparatorBuilder`] with default settings.
///
/// Equivalent to `ComparatorBuilder::default()`; this is the entry point used by the macros.
pub fn default<T: Float>() -> ComparatorBuilder<T> {
    <ComparatorBuilder<T> as Default>::default()
}
/// Tests two values for closeness with a default builder.
///
/// Optional `setter = value` overrides are applied in order. For example,
/// `is_close!(a, b, rel_tol = 1e-6, method = "strong")` behaves like
/// `default().rel_tol(1e-6).method("strong").is_close(a, b)`.
#[macro_export]
macro_rules! is_close {
    ($a:expr, $b:expr $(, $set:ident = $val:expr)*) => {{
        let builder = &mut $crate::default();
        $( builder.$set($val); )*
        builder.is_close($a, $b)
    }};
}
/// Tests whether every element pair of two iterables is close, using a default builder.
///
/// Optional `setter = value` overrides (e.g. `rel_tol = 2e-1`, `method = "strong"`) are
/// applied in order before comparing.
#[macro_export]
macro_rules! all_close {
    ($a:expr, $b:expr $(, $set:ident = $val:expr)*) => {{
        let builder = &mut $crate::default();
        $( builder.$set($val); )*
        builder.all_close($a, $b)
    }};
}
/// Tests whether at least one element pair of two iterables is close, using a default builder.
///
/// Optional `setter = value` overrides (e.g. `rel_tol = 1e-1`, `method = STRONG`) are
/// applied in order before comparing.
#[macro_export]
macro_rules! any_close {
    ($a:expr, $b:expr $(, $set:ident = $val:expr)*) => {{
        let builder = &mut $crate::default();
        $( builder.$set($val); )*
        builder.any_close($a, $b)
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_debug() {
        // Build the expectation from the constants' own `Debug` output. How `f64` `Debug`
        // renders 1e-8 differs between toolchains: older Rust prints `0.00000001`, newer
        // Rust prints `1e-8`. A hard-coded literal makes this test toolchain-dependent.
        let expected = format!(
            "ComparatorBuilder {{ rel_tol: {:?}, abs_tol: {:?}, method: Weak }}",
            DEFAULT_REL_TOL, DEFAULT_ABS_TOL
        );
        assert_eq!(expected, format!("{:?}", default::<f64>()))
    }

    #[test]
    fn test_exact() {
        // Exactly equal values (including 0.0 vs -0.0) are close even with zero tolerances.
        for (a, b) in &[
            (2.0, 2.0),
            (0.1e200, 0.1e200),
            (1.123e-300, 1.123e-300),
            (0.0, -0.0),
        ] {
            assert!(default().rel_tol(0.0).abs_tol(0.0).is_close(*a, *b));
            assert!(is_close!(*a, *b, abs_tol = 0.0));
        }
    }

    #[test]
    fn test_relative() {
        for (a, b) in &[
            (1e8, 1e8 + 1.),
            (-1e-8, -1.000000009e-8),
            (1.12345678, 1.12345679),
        ] {
            assert!(default().rel_tol(1e-8).is_close(*a, *b));
            assert!(is_close!(*a, *b, rel_tol = 1e-8));
            assert!(!default().rel_tol(1e-9).is_close(*a, *b));
            assert!(!is_close!(*a, *b, rel_tol = 1e-9));
        }
    }

    #[test]
    fn test_zero() {
        // Near zero only an absolute tolerance helps; a relative one scales to ~0.
        for (a, b) in &[(1e-9, 0.0), (-1e-9, 0.0), (-1e-150, 0.0)] {
            assert!(default().abs_tol(1e-8).is_close(*a, *b));
            assert!(is_close!(*a, *b, abs_tol = 1e-8));
            assert!(!default().rel_tol(0.9).is_close(*a, *b));
            assert!(!is_close!(*a, *b, rel_tol = 0.9));
        }
    }

    #[test]
    fn test_non_finite() {
        for (a, b) in &[
            (f64::INFINITY, f64::INFINITY),
            (f64::NEG_INFINITY, f64::NEG_INFINITY),
        ] {
            assert!(default().abs_tol(0.999999999999999).is_close(*a, *b));
            assert!(is_close!(*a, *b, abs_tol = 0.999999999999999));
        }
        for (a, b) in &[
            (f64::NAN, f64::NAN),
            (f64::NAN, 1e-100),
            (1e-100, f64::NAN),
            (f64::INFINITY, f64::NAN),
            (f64::NAN, f64::INFINITY),
            (f64::INFINITY, f64::NEG_INFINITY),
            (f64::INFINITY, 1.0),
            (1.0, f64::INFINITY),
        ] {
            assert!(!default().abs_tol(0.999999999999999).is_close(*a, *b));
            assert!(!is_close!(*a, *b, abs_tol = 0.999999999999999));
        }
    }

    #[test]
    fn test_other_methods() {
        assert!(default().method("weak").rel_tol(1e-1).is_close(9.0, 10.0));
        assert!(default().method("weak").rel_tol(1e-1).is_close(10.0, 9.0));
        assert!(!default().method(WEAK).rel_tol(1e-2).is_close(9.0, 10.0));
        assert!(!default().method(WEAK).rel_tol(1e-2).is_close(10.0, 9.0));
        assert!(all_close!(
            vec![9.0, 10.0],
            vec![10.0, 9.0],
            rel_tol = 2e-1,
            method = "STRONG"
        ));
        assert!(!any_close!(
            vec![9.0, 10.0],
            vec![10.0, 9.0],
            rel_tol = 1e-1,
            method = STRONG
        ));
        assert!(is_close!(9.0, 10.0, rel_tol = 2e-1, method = "average"));
        assert!(is_close!(10.0, 9.0, rel_tol = 2e-1, method = "average"));
        assert!(!is_close!(9.0, 10.0, rel_tol = 1e-1, method = AVERAGE));
        assert!(!is_close!(10.0, 9.0, rel_tol = 1e-1, method = AVERAGE));
        // Asymmetric scales by the second argument only, so order matters.
        let ic = default().method(ASYMMETRIC).rel_tol(1e-1).compile();
        assert!(ic.is_close(9.0, 10.0));
        assert!(!ic.is_close(10.0, 9.0));
    }

    #[test]
    #[should_panic(expected = "unknown method \"fnord\"")]
    fn test_unknown_method() {
        default::<f64>().method("fnord");
    }

    #[test]
    fn test_all_close() {
        assert!(default().all_close(vec![0.0, 1.0, 2.0], (0..3).map(|i| i as f64)));
        assert!(all_close!(vec![0.0, 1.0, 2.0], (0..3).map(|i| i as f64)));
        assert!(!default().all_close(vec![0.0, 1.0, 3.0], (0..3).map(|i| i as f64)));
        assert!(!all_close!(vec![0.0, 1.0, 3.0], (0..3).map(|i| i as f64)));
    }
}