use std::cmp::Ordering;
use std::fmt;
#[derive(Debug, Clone, Copy)]
pub enum NumberValue {
Integer(i64),
Float(f64),
}
impl NumberValue {
#[inline]
pub fn from_i64(value: i64) -> Self {
NumberValue::Integer(value)
}
#[inline]
pub fn from_f64(value: f64) -> Self {
if value.fract() == 0.0
&& !value.is_nan()
&& !value.is_infinite()
&& value >= i64::MIN as f64
&& value <= i64::MAX as f64
{
NumberValue::Integer(value as i64)
} else {
NumberValue::Float(value)
}
}
#[inline]
pub fn is_integer(&self) -> bool {
matches!(self, NumberValue::Integer(_))
}
#[inline]
pub fn as_i64(&self) -> Option<i64> {
match *self {
NumberValue::Integer(i) => Some(i),
NumberValue::Float(f) => {
if f.fract() == 0.0
&& !f.is_nan()
&& !f.is_infinite()
&& f >= i64::MIN as f64
&& f <= i64::MAX as f64
{
Some(f as i64)
} else {
None
}
}
}
}
#[inline]
pub fn as_f64(&self) -> f64 {
match *self {
NumberValue::Integer(i) => i as f64,
NumberValue::Float(f) => f,
}
}
#[inline]
pub fn is_zero(&self) -> bool {
match *self {
NumberValue::Integer(i) => i == 0,
NumberValue::Float(f) => f == 0.0,
}
}
#[inline]
pub fn is_nan(&self) -> bool {
matches!(*self, NumberValue::Float(f) if f.is_nan())
}
pub fn add(&self, other: &NumberValue) -> NumberValue {
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => match a.checked_add(b) {
Some(r) => NumberValue::Integer(r),
None => NumberValue::Float(a as f64 + b as f64),
},
_ => NumberValue::from_f64(self.as_f64() + other.as_f64()),
}
}
pub fn sub(&self, other: &NumberValue) -> NumberValue {
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => match a.checked_sub(b) {
Some(r) => NumberValue::Integer(r),
None => NumberValue::Float(a as f64 - b as f64),
},
_ => NumberValue::from_f64(self.as_f64() - other.as_f64()),
}
}
pub fn mul(&self, other: &NumberValue) -> NumberValue {
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => match a.checked_mul(b) {
Some(r) => NumberValue::Integer(r),
None => NumberValue::Float(a as f64 * b as f64),
},
_ => NumberValue::from_f64(self.as_f64() * other.as_f64()),
}
}
pub fn div(&self, other: &NumberValue) -> Option<NumberValue> {
if other.is_zero() {
return None;
}
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => {
if a == i64::MIN && b == -1 {
return Some(NumberValue::Float(-(i64::MIN as f64)));
}
if a % b == 0 {
Some(NumberValue::Integer(a / b))
} else {
Some(NumberValue::Float(a as f64 / b as f64))
}
}
_ => Some(NumberValue::from_f64(self.as_f64() / other.as_f64())),
}
}
pub fn rem(&self, other: &NumberValue) -> Option<NumberValue> {
if other.is_zero() {
return None;
}
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => {
if a == i64::MIN && b == -1 {
return Some(NumberValue::Integer(0));
}
Some(NumberValue::Integer(a % b))
}
_ => Some(NumberValue::from_f64(self.as_f64() % other.as_f64())),
}
}
pub fn neg(&self) -> NumberValue {
match *self {
NumberValue::Integer(i) => match i.checked_neg() {
Some(r) => NumberValue::Integer(r),
None => NumberValue::Float(-(i as f64)),
},
NumberValue::Float(f) => NumberValue::Float(-f),
}
}
pub fn abs(&self) -> NumberValue {
match *self {
NumberValue::Integer(i) => match i.checked_abs() {
Some(r) => NumberValue::Integer(r),
None => NumberValue::Float((i as f64).abs()),
},
NumberValue::Float(f) => NumberValue::Float(f.abs()),
}
}
}
impl PartialEq for NumberValue {
#[inline]
fn eq(&self, other: &Self) -> bool {
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => a == b,
(NumberValue::Float(a), NumberValue::Float(b)) => a == b,
(NumberValue::Integer(a), NumberValue::Float(b)) => (a as f64) == b,
(NumberValue::Float(a), NumberValue::Integer(b)) => a == (b as f64),
}
}
}
impl PartialOrd for NumberValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (*self, *other) {
(NumberValue::Integer(a), NumberValue::Integer(b)) => Some(a.cmp(&b)),
(NumberValue::Float(a), NumberValue::Float(b)) => a.partial_cmp(&b),
(NumberValue::Integer(a), NumberValue::Float(b)) => (a as f64).partial_cmp(&b),
(NumberValue::Float(a), NumberValue::Integer(b)) => a.partial_cmp(&(b as f64)),
}
}
}
impl fmt::Display for NumberValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
NumberValue::Integer(i) => write!(f, "{}", i),
NumberValue::Float(fl) => {
if fl.is_nan() || fl.is_infinite() {
write!(f, "null")
} else if fl.fract() == 0.0 {
write!(f, "{}.0", fl as i64)
} else {
write!(f, "{}", fl)
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn from_f64_collapses_whole() {
assert!(matches!(
NumberValue::from_f64(42.0),
NumberValue::Integer(42)
));
assert!(matches!(
NumberValue::from_f64(-3.0),
NumberValue::Integer(-3)
));
assert!(matches!(NumberValue::from_f64(1.5), NumberValue::Float(_)));
}
#[test]
fn from_f64_rejects_nan_inf_for_int_path() {
assert!(matches!(
NumberValue::from_f64(f64::NAN),
NumberValue::Float(_)
));
assert!(matches!(
NumberValue::from_f64(f64::INFINITY),
NumberValue::Float(_)
));
}
#[test]
fn add_overflow_falls_to_float() {
let a = NumberValue::Integer(i64::MAX);
let b = NumberValue::Integer(1);
assert!(matches!(a.add(&b), NumberValue::Float(_)));
}
#[test]
fn add_no_overflow_stays_int() {
let a = NumberValue::Integer(2);
let b = NumberValue::Integer(3);
assert!(matches!(a.add(&b), NumberValue::Integer(5)));
}
#[test]
fn div_zero_returns_none() {
let a = NumberValue::Integer(1);
let z = NumberValue::Integer(0);
assert!(a.div(&z).is_none());
let zf = NumberValue::Float(0.0);
assert!(a.div(&zf).is_none());
}
#[test]
fn div_int_int_exact_stays_int() {
let a = NumberValue::Integer(10);
let b = NumberValue::Integer(2);
assert!(matches!(a.div(&b).unwrap(), NumberValue::Integer(5)));
}
#[test]
fn div_int_int_inexact_promotes_float() {
let a = NumberValue::Integer(7);
let b = NumberValue::Integer(2);
assert!(matches!(a.div(&b).unwrap(), NumberValue::Float(_)));
}
#[test]
fn cross_type_eq_and_ord() {
let i = NumberValue::Integer(5);
let f = NumberValue::Float(5.0);
assert_eq!(i, f);
assert_eq!(i.partial_cmp(&f), Some(Ordering::Equal));
let f2 = NumberValue::Float(5.5);
assert_eq!(i.partial_cmp(&f2), Some(Ordering::Less));
}
#[test]
fn neg_overflow_falls_to_float() {
let a = NumberValue::Integer(i64::MIN);
assert!(matches!(a.neg(), NumberValue::Float(_)));
}
}