use rust_decimal::Decimal;
use rust_decimal::prelude::*;
#[cfg(feature = "with_serde")]
use serde_derive::{Serialize, Deserialize};
use std::{
convert::From,
ops::{Add, Sub}
};
/// A numeric value held in one of four representations.
///
/// The derived `PartialEq` compares the variant as well as the payload, so
/// values of different variants are never equal (e.g. `Integer(1)` is not
/// equal to `Double(1.0)`).
///
/// With the `with_serde` feature enabled the type also derives `Serialize` and
/// `Deserialize`, and each variant uses the lowercase name given in its
/// `serde(rename = ...)` attribute.
/// NOTE(review): serde-derived enums are encoded by variant index in
/// non-self-describing formats, so reordering variants may break stored data.
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "with_serde", derive(Serialize, Deserialize))]
pub enum NumericUnion {
/// A `rust_decimal::Decimal` value; serde name `"decimal"`.
#[cfg_attr(feature = "with_serde", serde(rename = "decimal"))]
Decimal(Decimal),
/// A 64-bit float (`f64`); serde name `"double"`.
#[cfg_attr(feature = "with_serde", serde(rename = "double"))]
Double(f64),
/// A 32-bit float (`f32`); serde name `"float"`.
#[cfg_attr(feature = "with_serde", serde(rename = "float"))]
Float(f32),
/// A signed 64-bit integer (`i64`); serde name `"integer"`.
#[cfg_attr(feature = "with_serde", serde(rename = "integer"))]
Integer(i64),
}
/// Maps an operator name onto the matching overflow-checked method.
///
/// Only `add` and `sub` are supported. `math_op!` forwards its `$op` as an
/// `ident` fragment, which macro_rules can still match against these literal
/// tokens. Expands to an `Option` that is `None` when the operation overflows.
macro_rules! checked_op {
    (add, $lhs:expr, $rhs:expr) => {
        $lhs.checked_add($rhs)
    };
    (sub, $lhs:expr, $rhs:expr) => {
        $lhs.checked_sub($rhs)
    };
}

/// Applies the binary operator `$op` (`add` or `sub`) to two owned
/// `NumericUnion` bindings and evaluates to `Ok(NumericUnion)`.
///
/// The right operand `$val2` is first converted into the representation of
/// the left operand `$val1`; the result always has the left operand's variant
/// (e.g. `Integer` op `Double` yields an `Integer`, with the `Double`
/// truncated toward zero).
///
/// Failures are propagated with `?`, so the expansion site must be a function
/// returning `Result<_, String>`. An `Err` is produced when:
/// - the right operand cannot be represented in the left operand's type
///   (e.g. a NaN/infinite or out-of-range float converted to `Decimal` or
///   `i64`, or a `Decimal` that does not fit in `f32`/`f64`/`i64`);
/// - `Integer` or `Decimal` arithmetic overflows. Checked methods are used
///   because plain `+`/`-` panics on overflow (and, for `i64` in release
///   builds, silently wraps), which contradicts the fallible signature.
///
/// `Double` and `Float` arithmetic follows IEEE 754 and never fails.
macro_rules! math_op {
    ($op:ident, $val1:ident, $val2:ident) => {
        Ok(match $val1 {
            NumericUnion::Decimal(lhs) => {
                let rhs = match $val2 {
                    NumericUnion::Decimal(rhs) => rhs,
                    NumericUnion::Double(rhs) => Decimal::from_f64(rhs)
                        .ok_or_else(|| format!("error converting f64 to Decimal: {}", rhs))?,
                    NumericUnion::Float(rhs) => Decimal::from_f32(rhs)
                        .ok_or_else(|| format!("error converting f32 to Decimal: {}", rhs))?,
                    NumericUnion::Integer(rhs) => Decimal::from_i64(rhs)
                        .ok_or_else(|| format!("error converting i64 to Decimal: {}", rhs))?,
                };
                let value = checked_op!($op, lhs, rhs).ok_or_else(|| {
                    format!("Decimal overflow in {}: {} and {}", stringify!($op), lhs, rhs)
                })?;
                NumericUnion::Decimal(value)
            }
            NumericUnion::Double(lhs) => {
                let rhs = match $val2 {
                    NumericUnion::Decimal(rhs) => rhs
                        .to_f64()
                        .ok_or_else(|| String::from("error converting Decimal to f64"))?,
                    NumericUnion::Double(rhs) => rhs,
                    NumericUnion::Float(rhs) => f64::from(rhs),
                    NumericUnion::Integer(rhs) => rhs as f64,
                };
                NumericUnion::Double(lhs.$op(rhs))
            }
            NumericUnion::Float(lhs) => {
                let rhs = match $val2 {
                    NumericUnion::Decimal(rhs) => rhs
                        .to_f32()
                        .ok_or_else(|| String::from("error converting Decimal to f32"))?,
                    NumericUnion::Double(rhs) => rhs as f32,
                    NumericUnion::Float(rhs) => rhs,
                    NumericUnion::Integer(rhs) => rhs as f32,
                };
                NumericUnion::Float(lhs.$op(rhs))
            }
            NumericUnion::Integer(lhs) => {
                let rhs = match $val2 {
                    NumericUnion::Decimal(rhs) => rhs
                        .to_i64()
                        .ok_or_else(|| String::from("error converting Decimal to i64"))?,
                    // `to_i64` truncates finite floats toward zero just like
                    // `as`, but reports NaN, infinities and out-of-range values
                    // instead of silently mapping them to 0 / i64::MIN / i64::MAX.
                    NumericUnion::Double(rhs) => rhs
                        .to_i64()
                        .ok_or_else(|| format!("error converting f64 to i64: {}", rhs))?,
                    NumericUnion::Float(rhs) => rhs
                        .to_i64()
                        .ok_or_else(|| format!("error converting f32 to i64: {}", rhs))?,
                    NumericUnion::Integer(rhs) => rhs,
                };
                let value = checked_op!($op, lhs, rhs).ok_or_else(|| {
                    format!("i64 overflow in {}: {} and {}", stringify!($op), lhs, rhs)
                })?;
                NumericUnion::Integer(value)
            }
        })
    };
}
impl NumericUnion {
    /// Returns `self + val2` as a value of the same variant as `self`.
    ///
    /// `val2` is converted into `self`'s representation before adding.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a descriptive message when the `math_op!` expansion
    /// fails, for example because `val2` cannot be converted into `self`'s type.
    pub fn add(self, val2: Self) -> Result<Self, String> {
        math_op!(add, self, val2)
    }

    /// Returns `self - val2` as a value of the same variant as `self`.
    ///
    /// `val2` is converted into `self`'s representation before subtracting.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a descriptive message when the `math_op!` expansion
    /// fails, for example because `val2` cannot be converted into `self`'s type.
    pub fn sub(self, val2: Self) -> Result<Self, String> {
        math_op!(sub, self, val2)
    }

    /// Returns `true` when the wrapped value compares equal to zero.
    ///
    /// The floating-point variants use IEEE equality, so `-0.0` counts as zero
    /// and `NaN` does not.
    pub fn is_zero(&self) -> bool {
        match self {
            Self::Decimal(d) => *d == Decimal::zero(),
            Self::Double(d) => *d == 0.0,
            Self::Float(f) => *f == 0.0,
            Self::Integer(i) => *i == 0,
        }
    }

    /// Returns `true` when the wrapped value is strictly less than zero.
    ///
    /// For the floating-point variants, `-0.0` and `NaN` are not negative.
    pub fn is_negative(&self) -> bool {
        match self {
            Self::Decimal(d) => *d < Decimal::zero(),
            Self::Double(d) => *d < 0.0,
            Self::Float(f) => *f < 0.0,
            Self::Integer(i) => *i < 0,
        }
    }
}
impl From<Decimal> for NumericUnion {
fn from(val: Decimal) -> Self {
NumericUnion::Decimal(val)
}
}
impl From<f64> for NumericUnion {
fn from(val: f64) -> Self {
NumericUnion::Double(val)
}
}
impl From<f32> for NumericUnion {
fn from(val: f32) -> Self {
NumericUnion::Float(val)
}
}
impl From<i64> for NumericUnion {
fn from(val: i64) -> Self {
NumericUnion::Integer(val)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NumericUnion::Decimal` holding `mantissa * 10^-scale`.
    fn dec(mantissa: i64, scale: u32) -> NumericUnion {
        NumericUnion::Decimal(Decimal::new(mantissa, scale))
    }

    #[test]
    fn numeric_union_add() {
        // (left, right, expected left + right); the sum keeps the left operand's variant.
        let cases = vec![
            (NumericUnion::Integer(4), dec(32, 0), NumericUnion::Integer(36)),
            (dec(32, 0), NumericUnion::Integer(4), dec(360, 1)),
            (NumericUnion::Double(56.2213), NumericUnion::Integer(42), NumericUnion::Double(98.2213)),
            (NumericUnion::Integer(42), NumericUnion::Double(56.2213), NumericUnion::Integer(98)),
            (NumericUnion::Double(56.2213), dec(1245, 2), NumericUnion::Double(68.6713)),
            (dec(1245, 2), NumericUnion::Double(56.2213), dec(686713, 4)),
        ];
        for (lhs, rhs, expected) in cases {
            let label = format!("{:?} + {:?}", lhs, rhs);
            assert_eq!(lhs.add(rhs).unwrap(), expected, "{}", label);
        }
    }

    #[test]
    fn numeric_union_sub() {
        // (left, right, expected left - right); the difference keeps the left operand's variant.
        let cases = vec![
            (NumericUnion::Integer(4), dec(32, 0), NumericUnion::Integer(-28)),
            (dec(32, 0), NumericUnion::Integer(4), dec(280, 1)),
            (NumericUnion::Double(56.2213), NumericUnion::Integer(42), NumericUnion::Double(14.2213)),
            (NumericUnion::Integer(42), NumericUnion::Double(56.2213), NumericUnion::Integer(-14)),
            (NumericUnion::Double(56.2213), dec(1245, 2), NumericUnion::Double(43.7713)),
            (dec(1245, 2), NumericUnion::Double(56.2213), dec(-437713, 4)),
        ];
        for (lhs, rhs, expected) in cases {
            let label = format!("{:?} - {:?}", lhs, rhs);
            assert_eq!(lhs.sub(rhs).unwrap(), expected, "{}", label);
        }
    }

    #[test]
    fn numeric_union_is_zero() {
        let non_zero = [NumericUnion::Integer(4), dec(32, 0)];
        let zero = [
            NumericUnion::Float(0.0),
            NumericUnion::Integer(0),
            dec(555, 0).sub(NumericUnion::Integer(555)).unwrap(),
        ];
        for n in &non_zero {
            assert!(!n.is_zero(), "{:?} should not be zero", n);
        }
        for n in &zero {
            assert!(n.is_zero(), "{:?} should be zero", n);
        }
    }

    #[test]
    fn numeric_union_is_negative() {
        let four = NumericUnion::Integer(4);
        let thirty_two = dec(32, 0);
        let difference = four.clone().sub(thirty_two.clone()).unwrap();
        let reversed = thirty_two.sub(four).unwrap();
        assert!(difference.is_negative());
        assert!(!reversed.is_negative());
    }
}