use crate::error::decimal::DecimalError;
use crate::geometrics::HasX;
use num_traits::{FromPrimitive, ToPrimitive};
use rand::distr::Distribution;
use rand_distr::Normal;
use rust_decimal::{Decimal, MathematicalOps};
use rust_decimal_macros::dec;
/// Fixed fraction used as the length of "one day": `0.00396825397`, i.e. 1/252
/// rounded to 11 decimal places (1/252 = 0.0039682539682…).
///
/// NOTE(review): presumably one trading day expressed as a fraction of a
/// 252-trading-day year — confirm against callers before relying on that meaning.
pub const ONE_DAY: Decimal = dec!(0.00396825397);
/// Asserts that two values differ by at most `epsilon` (inclusive).
///
/// Works for any operand type where `left - right` produces a value with an
/// `abs()` method that is comparable to `epsilon` with `<=`, and where every
/// operand implements `Display` — e.g. `Decimal`, `f64` or the primitive
/// signed integers. Operands are bound by value, so they are expected to be
/// `Copy` (as `Decimal`, `f64` and references are).
///
/// Each argument is evaluated exactly once, so the values shown in the panic
/// message are the same values that were compared, even when an argument has
/// side effects (such as drawing a random sample). The expansion is a block,
/// so the macro can be used in statement or expression position.
///
/// # Panics
///
/// Panics when `|left - right| > epsilon`.
#[macro_export]
macro_rules! assert_decimal_eq {
    ($left:expr, $right:expr, $epsilon:expr) => {{
        // Bind once: the original expansion re-evaluated `$left`/`$right` when
        // formatting the failure message, which could report different values.
        let left = $left;
        let right = $right;
        let epsilon = $epsilon;
        let diff = (left - right).abs();
        assert!(
            diff <= epsilon,
            "assertion failed: `(left == right)`\n left: `{}`\n right: `{}`\n diff: `{}`\n epsilon: `{}`",
            left,
            right,
            diff,
            epsilon
        );
    }};
}
/// Summary statistics over a collection of [`Decimal`] values.
pub trait DecimalStats {
    /// Returns the mean of the values.
    ///
    /// The `Vec<Decimal>` implementation in this module returns `Decimal::ZERO`
    /// for an empty collection.
    fn mean(&self) -> Decimal;
    /// Returns the standard deviation of the values.
    ///
    /// The `Vec<Decimal>` implementation in this module returns `Decimal::ZERO`
    /// when there are fewer than two values.
    fn std_dev(&self) -> Decimal;
}
impl DecimalStats for Vec<Decimal> {
    /// Arithmetic mean of the elements.
    ///
    /// Returns `Decimal::ZERO` for an empty vector instead of dividing by zero.
    fn mean(&self) -> Decimal {
        if self.is_empty() {
            return Decimal::ZERO;
        }
        let sum: Decimal = self.iter().sum();
        sum / Decimal::from(self.len())
    }

    /// Sample standard deviation (divides the sum of squared deviations by `n - 1`).
    ///
    /// Returns `Decimal::ZERO` when the vector has fewer than two elements, or if
    /// the square root cannot be computed.
    fn std_dev(&self) -> Decimal {
        let n = self.len();
        if n < 2 {
            return Decimal::ZERO;
        }
        let mean = self.mean();
        // Square by plain multiplication; the general decimal-exponent `powd`
        // is unnecessary for an integer power of two.
        let sum_sq_dev: Decimal = self
            .iter()
            .map(|&x| {
                let dev = x - mean;
                dev * dev
            })
            .sum();
        (sum_sq_dev / Decimal::from(n - 1))
            .sqrt()
            .unwrap_or(Decimal::ZERO)
    }
}
/// Converts a [`Decimal`] into an `f64`.
///
/// # Errors
///
/// Returns [`DecimalError::ConversionError`] when `ToPrimitive::to_f64` yields
/// `None` for `value`.
pub fn decimal_to_f64(value: Decimal) -> Result<f64, DecimalError> {
    // `ok_or_else` defers building the error (and its `String` allocations)
    // to the failure path instead of paying for it on every call.
    value
        .to_f64()
        .ok_or_else(|| DecimalError::ConversionError {
            from_type: format!("Decimal: {value}"),
            to_type: "f64".to_string(),
            reason: "Failed to convert Decimal to f64".to_string(),
        })
}
/// Converts an `f64` into a [`Decimal`].
///
/// # Errors
///
/// Returns [`DecimalError::ConversionError`] when `Decimal::from_f64` yields
/// `None` for `value` (per the rust_decimal docs, e.g. for NaN or infinite inputs).
pub fn f64_to_decimal(value: f64) -> Result<Decimal, DecimalError> {
    // `ok_or_else` defers building the error (and its `String` allocations)
    // to the failure path instead of paying for it on every call.
    Decimal::from_f64(value).ok_or_else(|| DecimalError::ConversionError {
        from_type: format!("f64: {value}"),
        to_type: "Decimal".to_string(),
        reason: "Failed to convert f64 to Decimal".to_string(),
    })
}
/// Draws one value from the standard normal distribution (mean `0.0`,
/// standard deviation `1.0`) and returns it as a [`Decimal`].
///
/// The underlying `f64` sample is drawn with `rand::rng()`. If
/// `Decimal::from_f64` cannot represent the sample, `Decimal::ZERO` is
/// returned instead of an error.
pub fn decimal_normal_sample() -> Decimal {
    let draw: f64 = Normal::new(0.0, 1.0)
        .expect("standard normal distribution parameters are always valid")
        .sample(&mut rand::rng());
    Decimal::from_f64(draw).unwrap_or(Decimal::ZERO)
}
/// A bare `Decimal` acts as its own x value: `get_x` hands back a copy of `self`.
impl HasX for Decimal {
    fn get_x(&self) -> Decimal {
        // Destructure the reference; `Decimal` is `Copy`, so this copies the value.
        let &x = self;
        x
    }
}
/// Converts a `Decimal` expression to `f64` via `model::decimal::decimal_to_f64`,
/// evaluating to the raw `Result<f64, DecimalError>` (no error propagation).
#[macro_export]
macro_rules! d2fu {
    ($value:expr) => ($crate::model::decimal::decimal_to_f64($value));
}
/// Converts a `Decimal` expression to `f64` via `model::decimal::decimal_to_f64`,
/// applying `?` so a conversion failure returns early from the calling function.
/// The caller must return a `Result` whose error type accepts `DecimalError`.
#[macro_export]
macro_rules! d2f {
    ($value:expr) => ($crate::model::decimal::decimal_to_f64($value)?);
}
/// Converts an `f64` expression to `Decimal` via `model::decimal::f64_to_decimal`,
/// evaluating to the raw `Result<Decimal, DecimalError>` (no error propagation).
#[macro_export]
macro_rules! f2du {
    ($value:expr) => ($crate::model::decimal::f64_to_decimal($value));
}
/// Converts an `f64` expression to `Decimal` via `model::decimal::f64_to_decimal`,
/// applying `?` so a conversion failure returns early from the calling function.
/// The caller must return a `Result` whose error type accepts `DecimalError`.
#[macro_export]
macro_rules! f2d {
    ($value:expr) => ($crate::model::decimal::f64_to_decimal($value)?);
}
#[cfg(test)]
pub mod tests {
    // Sanity checks for the f64 <-> Decimal conversion helpers on simple inputs.
    use super::*;

    #[test]
    fn test_f64_to_decimal_valid() {
        let converted = f64_to_decimal(42.42).expect("42.42 should convert to Decimal");
        assert_eq!(converted, Decimal::new(4242, 2));
    }

    #[test]
    fn test_f64_to_decimal_zero() {
        let converted = f64_to_decimal(0.0).expect("0.0 should convert to Decimal");
        assert_eq!(converted, Decimal::ZERO);
    }

    #[test]
    fn test_decimal_to_f64_valid() {
        let converted =
            decimal_to_f64(Decimal::new(4242, 2)).expect("42.42 should convert to f64");
        assert_eq!(converted, 42.42);
    }

    #[test]
    fn test_decimal_to_f64_zero() {
        let converted = decimal_to_f64(Decimal::ZERO).expect("zero should convert to f64");
        assert_eq!(converted, 0.0);
    }
}
#[cfg(test)]
mod tests_random_generation {
    use super::*;
    use approx::assert_relative_eq;
    use rand::distr::Distribution;
    use std::collections::HashMap;

    /// Standard-normal samples should essentially never exceed 10 in magnitude.
    #[test]
    fn test_normal_sample_returns() {
        for _ in 0..1000 {
            let sample = decimal_normal_sample();
            assert!(sample <= Decimal::TEN);
            assert!(sample >= -Decimal::TEN);
        }
    }

    /// Empirical mean and standard deviation of many samples should be close
    /// to 0 and 1 respectively.
    #[test]
    fn test_normal_sample_distribution() {
        const NUM_SAMPLES: usize = 10000;
        let mut samples = Vec::with_capacity(NUM_SAMPLES);
        for _ in 0..NUM_SAMPLES {
            samples.push(decimal_normal_sample().to_f64().unwrap());
        }
        let sum: f64 = samples.iter().sum();
        let mean = sum / NUM_SAMPLES as f64;
        let variance_sum: f64 = samples.iter().map(|&x| (x - mean).powi(2)).sum();
        let std_dev = (variance_sum / NUM_SAMPLES as f64).sqrt();
        assert_relative_eq!(mean, 0.0, epsilon = 0.04);
        assert_relative_eq!(std_dev, 1.0, epsilon = 0.03);
    }

    /// Rounds N(-1, 0.5) samples into integer buckets clamped at 0, then checks
    /// that bucket 0 dominates while at least one positive bucket is reached.
    #[test]
    fn test_normal_distribution_transformation() {
        let mut t_rng = rand::rng();
        let normal = Normal::new(-1.0, 0.5).unwrap();
        let mut value_counts: HashMap<i32, usize> = HashMap::new();
        // A sample reaches a positive bucket only when it is >= 0.5, i.e. three
        // standard deviations above the mean (p ≈ 0.00135). With 5_000 draws the
        // chance of seeing none was ≈ 0.1%, making the test flaky; 20_000 draws
        // push that probability to roughly 2e-12.
        const SAMPLES: usize = 20_000;
        for _ in 0..SAMPLES {
            let sample: f64 = normal.sample(&mut t_rng);
            let bucket = (sample.round() as i32).max(0);
            *value_counts.entry(bucket).or_default() += 1;
        }
        assert!(value_counts.get(&0).copied().unwrap_or(0) > SAMPLES / 10);
        let max_bucket = value_counts.keys().max().copied().unwrap_or(0);
        assert!(max_bucket > 0);
    }

    /// Three consecutive samples should not all be identical.
    #[test]
    fn test_normal_sample_consistency() {
        let sample1 = decimal_normal_sample();
        let sample2 = decimal_normal_sample();
        let sample3 = decimal_normal_sample();
        assert!(sample1 != sample2 || sample2 != sample3);
    }
}