use std::cmp::Ordering;
use crate::{ApproxEq, ApproxEqZero, ApproxOrd};
/// Bit of an `f64` bit pattern that holds the sign.
const SIGN_MASK: u64 = 1_u64 << 63;
/// Bits of an `f64` bit pattern that hold the biased exponent.
const EXPONENT_MASK: u64 = 0x7ff_u64 << MANTISSA_BITS;
/// Number of explicitly stored mantissa (fraction) bits in an `f64` (52).
const MANTISSA_BITS: u32 = f64::MANTISSA_DIGITS - 1;
/// Largest meaningful `absolute` parameter (1022); larger values behave identically.
pub const MAX_ABSOLUTE: i32 = -(f64::MIN_EXP - 1);
/// Smallest meaningful `absolute` parameter (-1024); smaller values behave identically.
pub const MIN_ABSOLUTE: i32 = -f64::MAX_EXP;
/// Largest meaningful `relative` parameter (52); larger values behave identically.
pub const MAX_RELATIVE: u32 = MANTISSA_BITS;
/// Precision used for approximate floating-point comparison.
///
/// Values are grouped into buckets. Two values compare approximately equal
/// when they fall in the same bucket or in adjacent buckets. See
/// `Precision::new` for how the user-facing parameters map onto these fields.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Precision {
    /// Smallest biased `f64` exponent whose values are not put in the zero
    /// bucket. `Precision::new` always produces a value of at least 1.
    min_exponent: u32,
    /// Maximum number of mantissa bits kept when bucketing a value. Values
    /// above `MANTISSA_BITS` behave like `MANTISSA_BITS`.
    mantissa_bits: u32,
}
impl Default for Precision {
    /// Returns [`Precision::DEFAULT`].
    fn default() -> Self {
        // Reuse the associated constant so the default precision is defined in
        // exactly one place instead of repeating its parameter here.
        Self::DEFAULT
    }
}
impl Precision {
    /// The default precision, identical to `Precision::new_simple(26)`.
    pub const DEFAULT: Self = Self::new_simple(26);

    /// Builds a precision that uses `parameter` for both the absolute and the
    /// relative setting (see [`Precision::new`]).
    ///
    /// For the absolute setting, `parameter` is saturated at `i32::MAX`.
    pub const fn new_simple(parameter: u32) -> Self {
        // `i32::try_from` cannot be called in a `const fn`, so saturate by hand.
        let absolute = if parameter > i32::MAX as u32 {
            i32::MAX
        } else {
            parameter as i32
        };
        Self::new(absolute, parameter)
    }

    /// Builds a precision limited only by `absolute`. The relative setting is
    /// [`MAX_RELATIVE`].
    pub const fn absolute(absolute: i32) -> Self {
        Self::new(absolute, MAX_RELATIVE)
    }

    /// Builds a precision limited only by `relative`. The absolute setting is
    /// [`MAX_ABSOLUTE`].
    pub const fn relative(relative: u32) -> Self {
        Self::new(MAX_ABSOLUTE, relative)
    }

    /// Builds a precision from an absolute and a relative setting.
    ///
    /// Roughly, every finite value is rounded toward zero either to a multiple
    /// of `2^-absolute` or to `relative` leading mantissa bits, whichever is
    /// coarser. The result identifies the value's bucket.
    ///
    /// - `absolute`: values whose biased exponent is below
    ///   `MAX_ABSOLUTE - absolute + 1` land in the zero bucket. Values at or
    ///   above [`MAX_ABSOLUTE`], or below [`MIN_ABSOLUTE`], are clamped to that
    ///   bound.
    /// - `relative`: maximum number of mantissa bits kept. Anything above
    ///   [`MAX_RELATIVE`] behaves like `MAX_RELATIVE`.
    pub const fn new(absolute: i32, relative: u32) -> Self {
        let min_exponent = match absolute {
            ..MIN_ABSOLUTE => (MAX_ABSOLUTE - MIN_ABSOLUTE) as u32 + 1,
            MAX_ABSOLUTE.. => 1,
            _ => MAX_ABSOLUTE.saturating_sub(absolute) as u32 + 1,
        };
        Self {
            min_exponent,
            mantissa_bits: relative,
        }
    }

    /// Approximate equality for `f32`. Both values are widened (losslessly)
    /// to `f64` and then compared with [`Precision::f64_eq`].
    pub(crate) fn f32_eq(self, a: f32, b: f32) -> bool {
        self.f64_eq(f64::from(a), f64::from(b))
    }

    /// Returns `true` when `a` and `b` fall into the same bucket or into
    /// adjacent buckets.
    pub(crate) fn f64_eq(self, a: f64, b: f64) -> bool {
        if a == b {
            return true;
        }
        let a_exp = f64_exponent(a);
        let b_exp = f64_exponent(b);
        // Both values are below the absolute threshold, so both are in the zero bucket.
        if a_exp < self.min_exponent && b_exp < self.min_exponent {
            return true;
        }
        // Shortcut: a bucket never straddles exponents. A bucket above the
        // threshold exponent therefore only neighbours buckets whose exponent
        // differs by at most one, so values further apart cannot match.
        let either_above = a_exp > self.min_exponent || b_exp > self.min_exponent;
        if either_above && a_exp.abs_diff(b_exp) > 1 {
            return false;
        }
        let b_bucket = self.bucket(b);
        if self.bucket(a) == b_bucket {
            return true;
        }
        // The middle element of `nearby_buckets(a)` is `bucket(a)`, already checked above.
        let (a_lo, _, a_hi) = self.nearby_buckets(a);
        a_lo == Some(b_bucket) || a_hi == Some(b_bucket)
    }

    /// Returns `true` when `n` is approximately zero, i.e. its bucket is the
    /// zero bucket or one of the two buckets adjacent to it.
    pub(crate) fn f64_eq_zero(self, n: f64) -> bool {
        let exp = f64_exponent(n);
        exp < self.min_exponent
            || (exp == self.min_exponent && self.buckets_near_zero().contains(&self.bucket(n)))
    }

    /// Mask applied to `f.to_bits()` to obtain its bucket.
    ///
    /// - NaN and infinities: all ones, so each bit pattern is its own bucket.
    /// - Zeros and subnormals: zero, so they all share the zero bucket.
    /// - Normal values below the threshold exponent: zero.
    /// - Other normal values: sign and exponent bits are kept, plus as many
    ///   leading mantissa bits as both the threshold and `mantissa_bits` allow.
    fn bucket_mask(self, f: f64) -> u64 {
        match f.classify() {
            std::num::FpCategory::Nan | std::num::FpCategory::Infinite => u64::MAX,
            std::num::FpCategory::Zero | std::num::FpCategory::Subnormal => 0,
            std::num::FpCategory::Normal => {
                let exponent = f64_exponent(f);
                let Some(spare_mantissa_bits) = exponent.checked_sub(self.min_exponent) else {
                    // Below the absolute threshold: rounds to the zero bucket.
                    return 0;
                };
                let mantissa_bits_to_keep = MANTISSA_BITS
                    .min(spare_mantissa_bits)
                    .min(self.mantissa_bits);
                let mantissa_bits_to_drop = MANTISSA_BITS - mantissa_bits_to_keep;
                u64::MAX << mantissa_bits_to_drop
            }
        }
    }

    /// Returns the bucket containing `f`. The bucket is the bit pattern of `f`
    /// with the bits discarded by [`Precision::bucket_mask`] cleared.
    pub(crate) fn bucket(self, f: f64) -> u64 {
        f.to_bits() & self.bucket_mask(f)
    }

    /// Returns `(below, containing, above)`: the bucket containing `f` and its
    /// neighbouring buckets in numeric order.
    ///
    /// A neighbour is `None` only where none exists: on both sides for NaN,
    /// above `+inf`, and below `-inf`.
    pub(crate) fn nearby_buckets(self, f: f64) -> (Option<u64>, u64, Option<u64>) {
        let bucket_mask = self.bucket_mask(f);
        let bucket_containing = f.to_bits() & bucket_mask;
        if f.is_nan() {
            (None, bucket_containing, None)
        } else if f.is_infinite() {
            if f.is_sign_positive() {
                (Some(self.bucket(f64::MAX)), bucket_containing, None)
            } else {
                (None, bucket_containing, Some(self.bucket(f64::MIN)))
            }
        } else if bucket_mask == 0 {
            let [lo, mid, hi] = self.buckets_near_zero();
            (Some(lo), mid, Some(hi))
        } else {
            // The bucket spans from its representative (dropped bits cleared)
            // to the same pattern with every dropped bit set.
            let closest_to_zero = f64::from_bits(bucket_containing);
            let farthest_from_zero = f64::from_bits(bucket_containing | !bucket_mask);
            let [lowest_in_bucket, highest_in_bucket] = if f.is_sign_positive() {
                [closest_to_zero, farthest_from_zero]
            } else {
                [farthest_from_zero, closest_to_zero]
            };
            // Step one ULP past either end of the bucket to land in the neighbours.
            (
                Some(self.bucket(lowest_in_bucket.next_down())),
                bucket_containing,
                Some(self.bucket(highest_in_bucket.next_up())),
            )
        }
    }

    /// Returns `[below, zero, above]`: the zero bucket and the buckets on
    /// either side of it. Those neighbours are the bit patterns of ±(the
    /// smallest value whose biased exponent is `min_exponent`).
    fn buckets_near_zero(self) -> [u64; 3] {
        let bucket_above = (self.min_exponent as u64) << MANTISSA_BITS;
        let bucket_below = SIGN_MASK | bucket_above;
        [bucket_below, 0, bucket_above]
    }

    /// Returns `true` when `a` and `b` are approximately equal at this precision.
    pub fn eq<T: ApproxEq>(self, a: T, b: T) -> bool {
        a.approx_eq(&b, self)
    }

    /// Approximately compares `a` with `b` at this precision.
    pub fn cmp<T: ApproxOrd>(self, a: T, b: T) -> Ordering {
        a.approx_cmp(&b, self)
    }

    /// Returns `true` when `a` is approximately less than `b`.
    pub fn lt<T: ApproxOrd>(self, a: T, b: T) -> bool {
        self.cmp(a, b).is_lt()
    }

    /// Returns `true` when `a` is approximately greater than `b`.
    pub fn gt<T: ApproxOrd>(self, a: T, b: T) -> bool {
        self.cmp(a, b).is_gt()
    }

    /// Returns `true` when `a` is approximately less than or equal to `b`.
    pub fn lt_eq<T: ApproxOrd>(self, a: T, b: T) -> bool {
        self.cmp(a, b).is_le()
    }

    /// Returns `true` when `a` is approximately greater than or equal to `b`.
    pub fn gt_eq<T: ApproxOrd>(self, a: T, b: T) -> bool {
        self.cmp(a, b).is_ge()
    }

    /// Returns `true` when `a` is approximately zero at this precision.
    pub fn eq_zero<T: ApproxEqZero>(self, a: T) -> bool {
        a.approx_eq_zero(self)
    }
}
/// Returns the raw biased exponent field of `f`. This is 0 for zeros and
/// subnormals and 0x7ff (2047) for infinities and NaNs.
fn f64_exponent(f: f64) -> u32 {
    let bits = f.to_bits();
    // Shift the exponent field down first, then mask with the field width.
    let biased = (bits >> MANTISSA_BITS) & (EXPONENT_MASK >> MANTISSA_BITS);
    biased as u32
}
// Unit and property tests for `Precision` bucketing and approximate equality.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // Generates precisions only within the meaningful parameter ranges
    // (`MIN_ABSOLUTE..=MAX_ABSOLUTE` and `0..=MAX_RELATIVE`).
    // NOTE(review): this `#[cfg(test)]` is redundant; the enclosing module is
    // already `#[cfg(test)]`.
    #[cfg(test)]
    impl Arbitrary for Precision {
        type Parameters = ();
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (MIN_ABSOLUTE..=MAX_ABSOLUTE, 0..=MAX_RELATIVE)
                .prop_map(|(abs, rel)| Precision::new(abs, rel))
                .boxed()
        }
        type Strategy = BoxedStrategy<Self>;
    }

    // Re-bucketing a bucket's own representative value yields the same bucket.
    #[proptest_macro::property_test]
    fn proptest_bucketing_idempotent(f: f64, absolute: i32, relative: u32) {
        let prec = Precision::new(absolute, relative);
        let b = prec.bucket(f);
        assert_eq!(b, prec.bucket(f64::from_bits(b)));
    }

    // Checks that `nearby_buckets` returns neighbours distinct from the
    // containing bucket, with `None` only for NaN and the outer side of the
    // infinities. The neighbour relation must be mutual, and a value midway
    // between two neighbouring buckets must land in one of them.
    #[proptest_macro::property_test]
    fn proptest_nearby_buckets(f: f64, absolute: i32, relative: u32) {
        let prec = Precision::new(absolute, relative);
        let (lo, mid, hi) = prec.nearby_buckets(f);
        assert_ne!(lo, Some(mid));
        assert_ne!(hi, Some(mid));
        if f.is_nan() {
            assert_eq!(lo, None);
            assert_eq!(hi, None);
        } else if f.is_infinite() {
            // Exactly one side has a neighbour.
            assert!(Option::xor(lo, hi).is_some());
        } else {
            assert!(lo.is_some());
            assert!(hi.is_some());
        }
        if let Some(lo) = lo {
            // The lower neighbour's upper neighbour must be this bucket.
            let (_, should_be_lo, should_be_mid) = prec.nearby_buckets(f64::from_bits(lo));
            assert_eq!(should_be_lo, lo);
            assert_eq!(should_be_mid, Some(mid));
            let between_lo_and_mid =
                prec.bucket(f64::midpoint(f64::from_bits(lo), f64::from_bits(mid)));
            assert!(between_lo_and_mid == lo || between_lo_and_mid == mid);
        }
        if let Some(hi) = hi {
            // The upper neighbour's lower neighbour must be this bucket.
            let (should_be_mid, should_be_hi, _) = prec.nearby_buckets(f64::from_bits(hi));
            assert_eq!(should_be_hi, hi);
            assert_eq!(should_be_mid, Some(mid));
            let between_hi_and_mid =
                prec.bucket(f64::midpoint(f64::from_bits(hi), f64::from_bits(mid)));
            assert!(between_hi_and_mid == hi || between_hi_and_mid == mid);
        }
    }

    // Compares `bucket` against a reference model. Finite values are
    // truncated toward zero to a multiple of the larger of 2^-absolute and
    // 2^(floor(log2|f|) - relative). Infinities and NaNs map to their own
    // bit pattern.
    #[proptest_macro::property_test]
    fn proptest_buckets(f: f64, absolute: i32, relative: u32) {
        let prec = Precision::new(absolute, relative);
        if f.is_infinite() {
            let inf = match f.is_sign_positive() {
                true => f64::INFINITY,
                false => f64::NEG_INFINITY,
            };
            assert_eq!(prec.bucket(f), inf.to_bits());
        } else if f.is_nan() {
            assert_eq!(prec.bucket(f), f.to_bits());
        } else {
            let absolute_bucket_log = -(absolute.clamp(MIN_ABSOLUTE, MAX_ABSOLUTE) as f64);
            let relative_bucket_log =
                f.abs().log2().floor() - relative.clamp(0, MAX_RELATIVE) as f64;
            let bucket_size = f64::max(absolute_bucket_log, relative_bucket_log).exp2();
            let mut expected_bucket = (f / bucket_size).trunc() * bucket_size;
            // Degenerate sizes and results that round to zero (including -0.0)
            // all map to the positive-zero bucket.
            if bucket_size == 0.0 || bucket_size.is_infinite() || expected_bucket == 0.0 {
                expected_bucket = 0.0
            }
            assert_eq!(prec.bucket(f), expected_bucket.to_bits());
        }
    }

    // With `absolute == 0` and full relative precision, values truncate
    // toward zero to an integer.
    #[test]
    fn test_round_to_integer() {
        let prec = Precision::new(0, MAX_RELATIVE);
        assert_eq!(prec.bucket(5.99), 5.0_f64.to_bits());
        assert_eq!(prec.bucket(-5.99), (-5.0_f64).to_bits());
        assert_eq!(prec.bucket(999.99), 999.0_f64.to_bits());
        assert_eq!(prec.bucket(-999.99), (-999.0_f64).to_bits());
        // NOTE(review): 9007199254740992.0 is 2^53, the point beyond which
        // consecutive integers are no longer all representable. It is not the
        // "second largest" integer-valued f64, so the name is misleading.
        let second_largest_f64_integer = 9007199254740992.0;
        assert_eq!(
            prec.bucket(second_largest_f64_integer),
            second_largest_f64_integer.to_bits(),
        );
    }

    // At the finest precision, subnormals fall into the zero bucket. Values
    // whose biased exponent exceeds the mantissa width keep their exact bit
    // pattern.
    #[proptest_macro::property_test]
    fn proptest_every_float_gets_its_own_bucket_proptest(f: f64) {
        let prec = Precision::new(MAX_ABSOLUTE, MAX_RELATIVE);
        if f.is_subnormal() {
            assert_eq!(prec.bucket(f), 0);
        } else if f64_exponent(f) > MANTISSA_BITS {
            assert_eq!(prec.bucket(f), f.to_bits())
        }
    }

    // Negating a value flips only the sign bit of a nonzero bucket, and that
    // bit matches the value's sign. The zero bucket is shared by both signs.
    #[proptest_macro::property_test]
    fn proptest_sign_symmetry(f: f64, prec: Precision) {
        let bucket = prec.bucket(f);
        if bucket == 0 {
            assert_eq!(bucket, prec.bucket(-f));
        } else {
            assert_eq!(bucket ^ SIGN_MASK, prec.bucket(-f));
            assert_eq!((bucket & SIGN_MASK != 0), f.is_sign_negative());
        }
    }

    // At `MIN_ABSOLUTE` (or below) even `f64::MAX` rounds to zero. One step
    // above that, it does not.
    #[test]
    fn test_min_absolute_limit() {
        let prec = Precision::new(MIN_ABSOLUTE + 1, MAX_RELATIVE);
        assert_ne!(prec.bucket(f64::MAX), 0);
        assert_ne!(prec.bucket(-f64::MAX), 0);
        let prec = Precision::new(MIN_ABSOLUTE, MAX_RELATIVE);
        assert_eq!(prec.bucket(f64::MAX), 0);
        assert_eq!(prec.bucket(-f64::MAX), 0);
        let prec = Precision::new(MIN_ABSOLUTE - 1, MAX_RELATIVE);
        assert_eq!(prec.bucket(f64::MAX), 0);
        assert_eq!(prec.bucket(-f64::MAX), 0);
    }

    // At `MAX_ABSOLUTE` the smallest normal value has its own nonzero bucket.
    // One step coarser, it falls into the zero bucket, and the buckets next to
    // zero move out to ±2 * MIN_POSITIVE.
    #[test]
    fn test_max_absolute_limit() {
        let f = f64::MIN_POSITIVE;
        let prec = Precision::new(MAX_ABSOLUTE, MAX_RELATIVE);
        assert_ne!(0, prec.bucket(f));
        assert_eq!(prec.bucket(f), prec.bucket(f * 1.5));
        assert_ne!(prec.bucket(f), prec.bucket(f * 2.0));
        assert_eq!(prec.buckets_near_zero(), [(-f).to_bits(), 0, f.to_bits()]);
        let prec = Precision::new(MAX_ABSOLUTE - 1, MAX_RELATIVE);
        assert_eq!(0, prec.bucket(f));
        assert_eq!(prec.bucket(f), prec.bucket(f * 1.5));
        assert_ne!(prec.bucket(f), prec.bucket(f * 2.0));
        assert_eq!(
            prec.buckets_near_zero(),
            [(f * -2.0).to_bits(), 0, (f * 2.0).to_bits()],
        );
    }

    // An `absolute` setting above `MAX_ABSOLUTE` behaves like `MAX_ABSOLUTE`.
    #[proptest_macro::property_test]
    fn proptest_test_max_absolute_limit(f: f64, relative: u32) {
        let prec1 = Precision::new(MAX_ABSOLUTE, relative);
        let prec2 = Precision::new(MAX_ABSOLUTE + 1, relative);
        assert_eq!(prec1.bucket(f), prec2.bucket(f));
    }

    // One fewer relative bit than `MAX_RELATIVE` is observable: the largest
    // value below 1.0 needs every mantissa bit.
    #[test]
    fn test_max_relative_limit() {
        let prec1 = Precision::new(MAX_ABSOLUTE, MAX_RELATIVE - 1);
        let prec2 = Precision::new(MAX_ABSOLUTE, MAX_RELATIVE);
        let f = 1.0_f64.next_down();
        assert_ne!(prec1.bucket(f), prec2.bucket(f));
    }

    // A `relative` setting above `MAX_RELATIVE` behaves like `MAX_RELATIVE`.
    #[proptest_macro::property_test]
    fn proptest_max_relative_limit(f: f64, absolute: i32) {
        let prec1 = Precision::new(absolute, MAX_RELATIVE);
        let prec2 = Precision::new(absolute, MAX_RELATIVE + 1);
        assert_eq!(prec1.bucket(f), prec2.bucket(f));
    }

    // Approximate equality is symmetric.
    #[proptest_macro::property_test]
    fn proptest_symmetric_eq(a: f64, b: f64, prec: Precision) {
        assert_eq!(prec.f64_eq(a, b), prec.f64_eq(b, a))
    }

    // `f64_eq`'s early-exit shortcuts agree with the plain neighbour test.
    #[proptest_macro::property_test]
    fn proptest_eq_optimizations(a: f64, b: f64, prec: Precision) {
        let b_bucket = prec.bucket(b);
        let (a_lo, a_mid, a_hi) = prec.nearby_buckets(a);
        let expected_eq = a_mid == b_bucket || a_lo == Some(b_bucket) || a_hi == Some(b_bucket);
        assert_eq!(prec.f64_eq(a, b), expected_eq);
    }

    // `f64_eq_zero` agrees with comparing against 0.0 via `f64_eq`.
    #[proptest_macro::property_test]
    fn proptest_eq_zero(f: f64, prec: Precision) {
        assert_eq!(prec.f64_eq_zero(f), prec.f64_eq(0.0, f));
    }
}