use crate::statistic::{algorithm::welford_online, summary::dataset::dispersion::Dispersion};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
pub mod dispersion;
/// Running summary statistics over a stream of `Decimal` values.
///
/// Start from `DataSetSummary::default()` and feed values one at a time with
/// [`DataSetSummary::update`]; every field is kept current after each call.
///
/// Note: the derived `PartialOrd` compares fields in declaration order
/// (`count`, then `sum`, `mean`, `dispersion`), so reordering the fields below
/// changes ordering semantics (and the serialized field order).
#[derive(Debug, Clone, PartialEq, PartialOrd, Default, Deserialize, Serialize)]
pub struct DataSetSummary {
/// Number of values folded in so far (incremented by one per `update`).
pub count: Decimal,
/// Sum of every value folded in so far.
pub sum: Decimal,
/// Running mean, maintained via `welford_online::calculate_mean`.
pub mean: Decimal,
/// Spread statistics (range, Welford recurrence term, variance, std dev),
/// updated alongside the mean on every `update`.
pub dispersion: Dispersion,
}
impl DataSetSummary {
    /// Folds `next_value` into the running summary.
    ///
    /// Increments `count`, adds the value to `sum`, recomputes `mean` with the
    /// Welford online mean helper, and finally updates `dispersion`. The
    /// dispersion update receives both the mean *before* and *after* this value,
    /// so the prior mean is captured before `self.mean` is overwritten.
    pub fn update(&mut self, next_value: Decimal) {
        let mean_before = self.mean;

        self.count += Decimal::ONE;
        self.sum += next_value;

        let mean_after = welford_online::calculate_mean(mean_before, next_value, self.count);
        self.mean = mean_after;

        self.dispersion
            .update(mean_before, mean_after, next_value, self.count);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistic::summary::dataset::dispersion::Range;
    use rust_decimal_macros::dec;

    /// Asserts that `actual` lies within `tolerance` of `expected`.
    ///
    /// Derived dispersion statistics (variance, standard deviation) involve
    /// non-terminating decimal division and square roots, so they are compared
    /// approximately instead of with `assert_eq!`. `label` names the statistic
    /// and `index` identifies the failing test case in the panic message.
    fn assert_within_tolerance(
        label: &str,
        actual: Decimal,
        expected: Decimal,
        tolerance: Decimal,
        index: usize,
    ) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= tolerance,
            "{} difference {} exceeds tolerance {}, Input: {:?}",
            label,
            diff,
            tolerance,
            index
        );
    }

    #[test]
    fn update_data_summary_with_position() {
        struct TestCase {
            input_next_value: Decimal,
            expected_summary: DataSetSummary,
        }

        // Absolute tolerance for the approximately-compared dispersion fields.
        // Hoisted out of the loop: it is the same for every case.
        let tolerance = dec!(0.000000000001);

        let mut data_summary = DataSetSummary::default();

        // Feeds 1.1, 1.2, 1.3 in sequence; each case holds the *cumulative*
        // summary expected after that value. The expected variances equal
        // recurrence_relation_m / count (e.g. 0.02 / 3 for the third case).
        //
        // Expected irrational/repeating values carry many more digits than the
        // tolerance resolves, so rounding of the expectation itself does not
        // consume the tolerance budget: sqrt(0.02 / 3) = sqrt(6) / 30
        // = 0.08164965809277260327...
        let test_cases = vec![
            TestCase {
                input_next_value: dec!(1.1),
                expected_summary: DataSetSummary {
                    count: dec!(1),
                    sum: dec!(1.1),
                    mean: dec!(1.1),
                    dispersion: Dispersion {
                        range: Range {
                            activated: true,
                            high: dec!(1.1),
                            low: dec!(1.1),
                        },
                        recurrence_relation_m: dec!(0.0),
                        variance: dec!(0.0),
                        std_dev: dec!(0.0),
                    },
                },
            },
            TestCase {
                input_next_value: dec!(1.2),
                expected_summary: DataSetSummary {
                    count: dec!(2),
                    sum: dec!(2.3),
                    mean: dec!(1.15),
                    dispersion: Dispersion {
                        range: Range {
                            activated: true,
                            high: dec!(1.2),
                            low: dec!(1.1),
                        },
                        recurrence_relation_m: dec!(0.005),
                        variance: dec!(0.0025),
                        std_dev: dec!(0.05),
                    },
                },
            },
            TestCase {
                input_next_value: dec!(1.3),
                expected_summary: DataSetSummary {
                    count: dec!(3),
                    sum: dec!(3.6),
                    mean: dec!(1.2),
                    dispersion: Dispersion {
                        range: Range {
                            activated: true,
                            high: dec!(1.3),
                            low: dec!(1.1),
                        },
                        recurrence_relation_m: dec!(0.02),
                        variance: dec!(0.006666666666666667),
                        std_dev: dec!(0.0816496580927726),
                    },
                },
            },
        ];

        for (index, test) in test_cases.into_iter().enumerate() {
            data_summary.update(test.input_next_value);
            let expected = &test.expected_summary;

            // Exact fields: count, sum, mean and range are all exactly
            // representable for these inputs.
            assert_eq!(data_summary.count, expected.count, "Count Input: {:?}", index);
            assert_eq!(data_summary.sum, expected.sum, "Sum Input: {:?}", index);
            assert_eq!(data_summary.mean, expected.mean, "Mean Input: {:?}", index);
            assert_eq!(
                data_summary.dispersion.range, expected.dispersion.range,
                "Range Input: {:?}",
                index
            );

            // Approximate fields.
            assert_within_tolerance(
                "Recurrence",
                data_summary.dispersion.recurrence_relation_m,
                expected.dispersion.recurrence_relation_m,
                tolerance,
                index,
            );
            assert_within_tolerance(
                "Variance",
                data_summary.dispersion.variance,
                expected.dispersion.variance,
                tolerance,
                index,
            );
            assert_within_tolerance(
                "Std Dev",
                data_summary.dispersion.std_dev,
                expected.dispersion.std_dev,
                tolerance,
                index,
            );
        }
    }
}