use crate::error::decimal::DecimalError;
use crate::geometrics::HasX;
use num_traits::{FromPrimitive, ToPrimitive};
use rand::distr::Distribution;
use rand_distr::Normal;
use rust_decimal::{Decimal, MathematicalOps, RoundingStrategy};
use rust_decimal_macros::dec;
/// Approximately `1 / 252` (`0.0039682539682…` rounded to 11 decimal places).
///
/// NOTE(review): presumably one trading day expressed as a fraction of a 252-trading-day
/// year — confirm against callers before relying on that interpretation.
pub const ONE_DAY: Decimal = dec!(0.00396825397);
/// Asserts that two values differ by at most `epsilon`, i.e. `|left - right| <= epsilon`.
///
/// Each argument is evaluated exactly once, so the values shown in a failure message are
/// exactly the values that were compared. The macro expands to a block, so it can be used
/// in expression position (e.g. as a `match` arm), and a trailing comma is accepted.
///
/// Works with any operand type providing `-`, an `abs()` method, `<=` against the epsilon
/// type, and `Display` (e.g. `Decimal` or `f64`).
///
/// # Panics
///
/// Panics when the absolute difference exceeds `epsilon`. The message lists both operands,
/// their difference, and the tolerance.
#[macro_export]
macro_rules! assert_decimal_eq {
    ($left:expr, $right:expr, $epsilon:expr $(,)?) => {{
        // Bind once: the panic message below must not re-run the operand expressions.
        let left = $left;
        let right = $right;
        let epsilon = $epsilon;
        let diff = (left - right).abs();
        assert!(
            diff <= epsilon,
            "assertion failed: `(left == right)`\n left: `{}`\n right: `{}`\n diff: `{}`\n epsilon: `{}`",
            left,
            right,
            diff,
            epsilon
        );
    }};
}
/// Basic descriptive statistics over a collection of [`Decimal`] values.
pub trait DecimalStats {
    /// Arithmetic mean of the values.
    fn mean(&self) -> Decimal;

    /// Standard deviation of the values (the `Vec<Decimal>` implementation in this module
    /// uses the sample estimator, dividing by `n - 1`).
    fn std_dev(&self) -> Decimal;
}
impl DecimalStats for Vec<Decimal> {
    /// Arithmetic mean of the elements, or `Decimal::ZERO` for an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if the running sum overflows, because `Decimal`'s `Sum`/`Add` impls panic on
    /// overflow.
    fn mean(&self) -> Decimal {
        match self.len() {
            0 => Decimal::ZERO,
            count => self.iter().sum::<Decimal>() / Decimal::from(count),
        }
    }

    /// Sample standard deviation (divides the squared deviations by `n - 1`).
    ///
    /// Returns `Decimal::ZERO` when the vector holds fewer than two elements.
    fn std_dev(&self) -> Decimal {
        let count = self.len();
        if count < 2 {
            return Decimal::ZERO;
        }
        let centre = self.mean();
        let squared_deviations: Decimal = self
            .iter()
            .map(|value| (value - centre).powd(Decimal::TWO))
            .sum();
        let variance = squared_deviations / Decimal::from(count - 1);
        // `sqrt` only fails for negative input, which a sum of squares cannot produce;
        // the zero fallback is purely defensive.
        variance.sqrt().unwrap_or(Decimal::ZERO)
    }
}
/// Converts a [`Decimal`] into an `f64`.
///
/// # Errors
///
/// Returns `DecimalError::ConversionError` when `ToPrimitive::to_f64` yields `None` for
/// `value`.
pub fn decimal_to_f64(value: Decimal) -> Result<f64, DecimalError> {
    // `ok_or_else` builds the error (a `format!` plus two `String`s) only on failure. The
    // previous eager `ok_or` paid those allocations on every successful conversion, and this
    // helper sits behind the `d2f!`/`d2fu!` macros.
    value.to_f64().ok_or_else(|| DecimalError::ConversionError {
        from_type: format!("Decimal: {value}"),
        to_type: "f64".to_string(),
        reason: "Failed to convert Decimal to f64".to_string(),
    })
}
/// Converts an `f64` into a [`Decimal`].
///
/// # Errors
///
/// Returns `DecimalError::ConversionError` when `Decimal::from_f64` yields `None`, e.g. for
/// NaN, ±infinity, or a magnitude outside `Decimal`'s range.
pub fn f64_to_decimal(value: f64) -> Result<Decimal, DecimalError> {
    // Lazily construct the error so successful conversions (the `f2d!`/`f2du!` path) do not
    // allocate the three diagnostic strings.
    Decimal::from_f64(value).ok_or_else(|| DecimalError::ConversionError {
        from_type: format!("f64: {value}"),
        to_type: "Decimal".to_string(),
        reason: "Failed to convert f64 to Decimal".to_string(),
    })
}
/// Converts `value` to a [`Decimal`] only when it is finite.
///
/// Returns `None` for NaN or ±infinity, and for any finite value that
/// `Decimal::from_f64` itself rejects.
#[must_use]
#[inline]
pub(crate) fn finite_decimal(value: f64) -> Option<Decimal> {
    Some(value)
        .filter(|candidate| candidate.is_finite())
        .and_then(Decimal::from_f64)
}
/// Draws one sample from the standard normal distribution `N(0, 1)` using the thread-local
/// RNG and returns it as a [`Decimal`].
///
/// Falls back to `Decimal::ZERO` if the drawn `f64` cannot be converted.
#[must_use]
pub fn decimal_normal_sample() -> Decimal {
    let standard_normal = Normal::new(0.0, 1.0)
        .unwrap_or_else(|_| unreachable!("standard normal parameters are always valid"));
    let draw: f64 = standard_normal.sample(&mut rand::rng());
    Decimal::from_f64(draw).unwrap_or(Decimal::ZERO)
}
/// Lets a bare [`Decimal`] serve as its own x-coordinate wherever [`HasX`] is required.
impl HasX for Decimal {
    /// Returns a copy of the value itself.
    fn get_x(&self) -> Decimal {
        *self
    }
}
/// Number of decimal places `d_div` rounds its quotient to, using banker's rounding
/// (`RoundingStrategy::MidpointNearestEven`).
///
/// NOTE(review): 28 is also `rust_decimal`'s maximum scale, so at this value the rounding in
/// `d_div` presumably never changes a quotient — confirm before relying on it for precision
/// control.
pub(crate) const DIV_DEFAULT_SCALE: u32 = 28;
/// Checked addition `lhs + rhs`.
///
/// # Errors
///
/// Returns `DecimalError::overflow(op, lhs, rhs)` when the sum does not fit in a
/// [`Decimal`]. `op` is a caller-chosen tag identifying the operation.
#[inline]
pub(crate) fn d_add(lhs: Decimal, rhs: Decimal, op: &'static str) -> Result<Decimal, DecimalError> {
    match lhs.checked_add(rhs) {
        Some(total) => Ok(total),
        None => Err(DecimalError::overflow(op, lhs, rhs)),
    }
}
/// Overflow-checked sum of a slice; an empty slice sums to `Decimal::ZERO`.
///
/// # Errors
///
/// Returns the overflow error from `d_sum_iter` (tagged with `op`) on the first addition
/// that overflows.
#[inline]
pub(crate) fn d_sum(values: &[Decimal], op: &'static str) -> Result<Decimal, DecimalError> {
    let items = values.iter().copied();
    d_sum_iter(items, op)
}
/// Overflow-checked sum of any iterable of [`Decimal`]s, starting from `Decimal::ZERO`.
///
/// # Errors
///
/// Stops at the first overflowing addition and returns `DecimalError::overflow(op, ..)`
/// carrying the running total and the value that could not be added.
#[inline]
pub(crate) fn d_sum_iter<I>(iter: I, op: &'static str) -> Result<Decimal, DecimalError>
where
    I: IntoIterator<Item = Decimal>,
{
    iter.into_iter().try_fold(Decimal::ZERO, |running, next| {
        running
            .checked_add(next)
            .ok_or_else(|| DecimalError::overflow(op, running, next))
    })
}
/// Checked subtraction `lhs - rhs`.
///
/// # Errors
///
/// Returns `DecimalError::overflow(op, lhs, rhs)` when the difference does not fit in a
/// [`Decimal`].
#[inline]
pub(crate) fn d_sub(lhs: Decimal, rhs: Decimal, op: &'static str) -> Result<Decimal, DecimalError> {
    if let Some(difference) = lhs.checked_sub(rhs) {
        Ok(difference)
    } else {
        Err(DecimalError::overflow(op, lhs, rhs))
    }
}
/// Checked multiplication `lhs * rhs`.
///
/// # Errors
///
/// Returns `DecimalError::overflow(op, lhs, rhs)` when the product does not fit in a
/// [`Decimal`].
#[inline]
pub(crate) fn d_mul(lhs: Decimal, rhs: Decimal, op: &'static str) -> Result<Decimal, DecimalError> {
    let product = lhs.checked_mul(rhs);
    match product {
        Some(value) => Ok(value),
        None => Err(DecimalError::overflow(op, lhs, rhs)),
    }
}
/// Checked division `lhs / rhs`.
///
/// The quotient is rounded to `DIV_DEFAULT_SCALE` decimal places using banker's rounding
/// (`RoundingStrategy::MidpointNearestEven`).
///
/// # Errors
///
/// * `DecimalError::arithmetic_error(op, "division by zero")` when `rhs` is zero.
/// * `DecimalError::overflow(op, lhs, rhs)` when the quotient does not fit in a [`Decimal`].
#[inline]
pub(crate) fn d_div(lhs: Decimal, rhs: Decimal, op: &'static str) -> Result<Decimal, DecimalError> {
    // `checked_div` returns `None` both for a zero divisor and for overflow. Test zero first
    // so it is reported as an arithmetic error instead of an overflow.
    if rhs.is_zero() {
        return Err(DecimalError::arithmetic_error(op, "division by zero"));
    }
    match lhs.checked_div(rhs) {
        Some(quotient) => Ok(quotient.round_dp_with_strategy(
            DIV_DEFAULT_SCALE,
            RoundingStrategy::MidpointNearestEven,
        )),
        None => Err(DecimalError::overflow(op, lhs, rhs)),
    }
}
/// Converts a `Decimal` expression to `f64` via `model::decimal::decimal_to_f64`.
///
/// Evaluates to the `Result<f64, DecimalError>` itself; the error is not propagated.
#[macro_export]
macro_rules! d2fu {
    ($value:expr) => {
        $crate::model::decimal::decimal_to_f64($value)
    };
}
/// Converts a `Decimal` expression to `f64` via `model::decimal::decimal_to_f64`,
/// propagating any conversion error with `?`.
///
/// Usable only inside functions whose error type can absorb a `DecimalError` through `?`.
#[macro_export]
macro_rules! d2f {
    ($value:expr) => {
        $crate::model::decimal::decimal_to_f64($value)?
    };
}
#[macro_export]
macro_rules! nz {
($val:expr) => {{
::std::num::NonZeroUsize::new($val)
.unwrap_or_else(|| panic!("nz!({}) must be non-zero", stringify!($val)))
}};
}
/// Converts an `f64` expression to `Decimal` via `model::decimal::f64_to_decimal`.
///
/// Evaluates to the `Result<Decimal, DecimalError>` itself; the error is not propagated.
#[macro_export]
macro_rules! f2du {
    ($value:expr) => {
        $crate::model::decimal::f64_to_decimal($value)
    };
}
/// Converts an `f64` expression to `Decimal` via `model::decimal::f64_to_decimal`,
/// propagating any conversion error with `?`.
///
/// Usable only inside functions whose error type can absorb a `DecimalError` through `?`.
#[macro_export]
macro_rules! f2d {
    ($value:expr) => {
        $crate::model::decimal::f64_to_decimal($value)?
    };
}
#[cfg(test)]
pub mod tests {
    //! Conversion checks for `f64_to_decimal` and `decimal_to_f64`.
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_f64_to_decimal_valid() {
        let converted = f64_to_decimal(42.42);
        assert!(converted.is_ok());
        let expected = Decimal::from_str("42.42").unwrap();
        assert_eq!(converted.unwrap(), expected);
    }

    #[test]
    fn test_f64_to_decimal_zero() {
        let converted = f64_to_decimal(0.0);
        assert!(converted.is_ok());
        let expected = Decimal::from_str("0").unwrap();
        assert_eq!(converted.unwrap(), expected);
    }

    #[test]
    fn test_decimal_to_f64_valid() {
        let source = Decimal::from_str("42.42").unwrap();
        let converted = decimal_to_f64(source);
        assert!(converted.is_ok());
        assert_eq!(converted.unwrap(), 42.42);
    }

    #[test]
    fn test_decimal_to_f64_zero() {
        let source = Decimal::from_str("0").unwrap();
        let converted = decimal_to_f64(source);
        assert!(converted.is_ok());
        assert_eq!(converted.unwrap(), 0.0);
    }
}
#[cfg(test)]
mod tests_random_generation {
    //! Statistical sanity checks for `decimal_normal_sample` and normal-distribution sampling.
    use super::*;
    use approx::assert_relative_eq;
    use rand::distr::Distribution;
    use std::collections::HashMap;

    #[test]
    fn test_normal_sample_returns() {
        // |z| > 10 has probability ~1e-23 for a standard normal. The bound is therefore
        // effectively deterministic, yet it still catches grossly wrong scaling.
        for _ in 0..1000 {
            let sample = decimal_normal_sample();
            assert!(sample <= Decimal::TEN);
            assert!(sample >= -Decimal::TEN);
        }
    }

    #[test]
    fn test_normal_sample_distribution() {
        const NUM_SAMPLES: usize = 10000;
        let mut samples = Vec::with_capacity(NUM_SAMPLES);
        for _ in 0..NUM_SAMPLES {
            samples.push(decimal_normal_sample().to_f64().unwrap());
        }
        let sum: f64 = samples.iter().sum();
        let mean = sum / NUM_SAMPLES as f64;
        let variance_sum: f64 = samples.iter().map(|&x| (x - mean).powi(2)).sum();
        let std_dev = (variance_sum / NUM_SAMPLES as f64).sqrt();
        // Each tolerance is roughly 4 standard errors at n = 10_000
        // (SE(mean) = 0.01, SE(std_dev) ≈ 0.007).
        assert_relative_eq!(mean, 0.0, epsilon = 0.04);
        assert_relative_eq!(std_dev, 1.0, epsilon = 0.03);
    }

    #[test]
    fn test_normal_distribution_transformation() {
        // NOTE(review): this exercises `rand_distr::Normal` plus the clamp below directly,
        // not a function from this module.
        let mut t_rng = rand::rng();
        let normal = Normal::new(-1.0, 0.5).unwrap();
        let mut value_counts: HashMap<i32, usize> = HashMap::new();
        const SAMPLES: usize = 5000;
        for _ in 0..SAMPLES {
            let raw_sample: f64 = normal.sample(&mut t_rng);
            // Negative draws are clamped into bucket 0.
            let bucket = (raw_sample.round() as i32).max(0);
            *value_counts.entry(bucket).or_insert(0) += 1;
        }
        // The clamp must never yield a negative bucket, and no draw may be lost.
        assert!(value_counts.keys().all(|&bucket| bucket >= 0));
        assert_eq!(value_counts.values().sum::<usize>(), SAMPLES);
        // With mean -1 and sd 0.5, P(x < 0.5) ≈ 0.9987, so bucket 0 must dominate.
        assert!(value_counts.get(&0).unwrap_or(&0) > &(SAMPLES / 10));
        // Deliberately no assertion that some bucket is > 0. That would need a draw >= 0.5,
        // which is 3 sd above the mean (p ≈ 0.00135 per draw). All 5000 draws miss it in
        // ~0.1% of runs, which made such an assertion flaky.
    }

    #[test]
    fn test_normal_sample_consistency() {
        // Fails only if three consecutive continuous draws are all identical.
        let sample1 = decimal_normal_sample();
        let sample2 = decimal_normal_sample();
        let sample3 = decimal_normal_sample();
        assert!(sample1 != sample2 || sample2 != sample3);
    }
}
// Unwrapping is intentional here: a failed unwrap is exactly the test failure we want.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod checked_helpers_tests {
    //! Unit tests for the checked arithmetic helpers (`d_add`, `d_sub`, `d_mul`, `d_div`,
    //! `d_sum`, `d_sum_iter`): happy paths plus overflow / zero-division error tagging.
    use super::*;

    #[test]
    fn d_add_happy_path() {
        let total = d_add(dec!(1.25), dec!(2.50), "test::add").unwrap();
        assert_eq!(total, dec!(3.75));
    }

    #[test]
    fn d_add_overflow_on_max_plus_max() {
        match d_add(Decimal::MAX, Decimal::MAX, "test::add").unwrap_err() {
            DecimalError::Overflow { operation, .. } => assert_eq!(operation, "test::add"),
            other => panic!("expected Overflow, got {other:?}"),
        }
    }

    #[test]
    fn d_sub_happy_path() {
        let difference = d_sub(dec!(10), dec!(3.5), "test::sub").unwrap();
        assert_eq!(difference, dec!(6.5));
    }

    #[test]
    fn d_sub_overflow_on_min_minus_max() {
        let outcome = d_sub(Decimal::MIN, Decimal::MAX, "test::sub");
        let failure = outcome.unwrap_err();
        assert!(
            matches!(failure, DecimalError::Overflow { operation, .. } if operation == "test::sub")
        );
    }

    #[test]
    fn d_mul_happy_path() {
        let product = d_mul(dec!(2.5), dec!(4), "test::mul").unwrap();
        assert_eq!(product, dec!(10.0));
    }

    #[test]
    fn d_mul_overflow_on_max_times_two() {
        let outcome = d_mul(Decimal::MAX, dec!(2), "test::mul");
        let failure = outcome.unwrap_err();
        assert!(
            matches!(failure, DecimalError::Overflow { operation, .. } if operation == "test::mul")
        );
    }

    #[test]
    fn d_div_happy_path_exact() {
        let quotient = d_div(dec!(10), dec!(4), "test::div").unwrap();
        assert_eq!(quotient, dec!(2.5));
    }

    #[test]
    fn d_div_applies_banker_rounding() {
        // 1/3 has no midpoint digit at the 28th place. This test therefore pins the
        // 28-digit quotient rather than distinguishing rounding strategies.
        let expected = dec!(0.3333333333333333333333333333);
        assert_eq!(d_div(dec!(1), dec!(3), "test::div").unwrap(), expected);
    }

    #[test]
    fn d_div_zero_denominator_returns_arithmetic_error() {
        let failure = d_div(dec!(1), Decimal::ZERO, "test::div").unwrap_err();
        assert!(matches!(failure, DecimalError::ArithmeticError { .. }));
    }

    #[test]
    fn d_div_tag_is_preserved_on_overflow() {
        // MIN / 0.5 == 2 * MIN, which is out of range.
        match d_div(Decimal::MIN, dec!(0.5), "test::div").unwrap_err() {
            DecimalError::Overflow { operation, .. } => assert_eq!(operation, "test::div"),
            other => panic!("expected Overflow, got {other:?}"),
        }
    }

    #[test]
    fn d_sum_empty_returns_zero() {
        let total = d_sum(&[], "test::sum").unwrap();
        assert_eq!(total, Decimal::ZERO);
    }

    #[test]
    fn d_sum_happy_path() {
        let inputs = [dec!(1.5), dec!(2.25), dec!(-0.75), dec!(10)];
        let total = d_sum(&inputs, "test::sum").unwrap();
        assert_eq!(total, dec!(13));
    }

    #[test]
    fn d_sum_overflow_returns_tagged_error() {
        let failure = d_sum(&[Decimal::MAX, Decimal::MAX], "test::sum").unwrap_err();
        assert!(
            matches!(failure, DecimalError::Overflow { operation, .. } if operation == "test::sum")
        );
    }

    #[test]
    fn d_sum_iter_empty_returns_zero() {
        let total = d_sum_iter(std::iter::empty::<Decimal>(), "test::sum_iter").unwrap();
        assert_eq!(total, Decimal::ZERO);
    }

    #[test]
    fn d_sum_iter_happy_path_matches_d_sum() {
        let inputs = [dec!(1.5), dec!(2.25), dec!(-0.75), dec!(10)];
        let from_iter = d_sum_iter(inputs.iter().copied(), "test::sum_iter").unwrap();
        let from_slice = d_sum(&inputs, "test::sum_iter").unwrap();
        assert_eq!(from_iter, dec!(13));
        assert_eq!(from_iter, from_slice);
    }

    #[test]
    fn d_sum_iter_accepts_lazy_map() {
        let lazy = (1i64..=4).map(Decimal::from);
        let total = d_sum_iter(lazy, "test::sum_iter_lazy").unwrap();
        assert_eq!(total, dec!(10));
    }

    #[test]
    fn d_sum_iter_overflow_returns_tagged_error() {
        let failure = d_sum_iter([Decimal::MAX, Decimal::MAX], "test::sum_iter").unwrap_err();
        assert!(
            matches!(failure, DecimalError::Overflow { operation, .. } if operation == "test::sum_iter")
        );
    }
}