datasynth_eval/statistical/
amount_distribution.rs1use crate::error::{EvalError, EvalResult};
7use rust_decimal::prelude::*;
8use rust_decimal::Decimal;
9use serde::{Deserialize, Serialize};
10
/// Summary statistics and distribution-fit results for a sample of monetary amounts.
///
/// Produced by `AmountDistributionAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AmountDistributionAnalysis {
    /// Number of amounts analyzed (all amounts, including zero and negative ones).
    pub sample_size: usize,
    /// Arithmetic mean of all amounts, computed in `Decimal`.
    pub mean: Decimal,
    /// Median of all amounts (see `AmountDistributionAnalyzer::analyze` for the
    /// convention used with an even number of amounts).
    pub median: Decimal,
    /// Sample standard deviation (n − 1 denominator), computed in `f64` and
    /// converted back to `Decimal` (zero if the conversion fails).
    pub std_dev: Decimal,
    /// Smallest amount.
    pub min: Decimal,
    /// Largest amount.
    pub max: Decimal,
    /// Approximate 1st percentile: the sorted amount at index ⌊0.01 · n⌋.
    pub percentile_1: Decimal,
    /// Approximate 99th percentile: the sorted amount at index ⌊0.99 · n⌋,
    /// clamped to the last element.
    pub percentile_99: Decimal,
    /// Skewness of the amounts; 0.0 when all amounts are equal.
    pub skewness: f64,
    /// Excess kurtosis (0 for a normal distribution); 0.0 when all amounts are equal.
    pub kurtosis: f64,
    /// Kolmogorov–Smirnov statistic D of the positive amounts against a lognormal
    /// distribution; `None` when the test was not run (e.g. fewer than 10 positive amounts).
    pub lognormal_ks_stat: Option<f64>,
    /// Asymptotic p-value for `lognormal_ks_stat`; `None` when the test was not run.
    pub lognormal_ks_pvalue: Option<f64>,
    /// Mean of `ln(amount)` over the positive amounts, if computed.
    pub fitted_mu: Option<f64>,
    /// Sample standard deviation of `ln(amount)` over the positive amounts, if
    /// computed and strictly positive.
    pub fitted_sigma: Option<f64>,
    /// Fraction of all amounts whose fractional part is zero.
    pub round_number_ratio: f64,
    /// Fraction of all amounts whose cents value ends in 0 or 5.
    pub nice_number_ratio: f64,
    /// `true` when the KS p-value is at least the analyzer's significance level,
    /// or when the KS test was not run.
    pub passes: bool,
}
49
/// Analyzes the distribution of monetary amounts: summary statistics, shape
/// (skewness/kurtosis), round-number ratios and a lognormal goodness-of-fit test.
///
/// Construct with `AmountDistributionAnalyzer::new` (or `Default`) and configure
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct AmountDistributionAnalyzer {
    /// Expected lognormal location μ (mean of `ln(amount)`), set by `with_expected_lognormal`.
    expected_mu: Option<f64>,
    /// Expected lognormal scale σ (std dev of `ln(amount)`), set by `with_expected_lognormal`.
    expected_sigma: Option<f64>,
    /// p-value threshold: the lognormal fit passes when its KS p-value is at least this.
    significance_level: f64,
}
59
60impl AmountDistributionAnalyzer {
61 pub fn new() -> Self {
63 Self {
64 expected_mu: None,
65 expected_sigma: None,
66 significance_level: 0.05,
67 }
68 }
69
70 pub fn with_expected_lognormal(mut self, mu: f64, sigma: f64) -> Self {
72 self.expected_mu = Some(mu);
73 self.expected_sigma = Some(sigma);
74 self
75 }
76
77 pub fn with_significance_level(mut self, level: f64) -> Self {
79 self.significance_level = level;
80 self
81 }
82
83 pub fn analyze(&self, amounts: &[Decimal]) -> EvalResult<AmountDistributionAnalysis> {
85 let n = amounts.len();
86 if n < 2 {
87 return Err(EvalError::InsufficientData {
88 required: 2,
89 actual: n,
90 });
91 }
92
93 let positive_amounts: Vec<Decimal> = amounts
95 .iter()
96 .filter(|a| **a > Decimal::ZERO)
97 .copied()
98 .collect();
99
100 let mut sorted = amounts.to_vec();
102 sorted.sort();
103
104 let sum: Decimal = amounts.iter().sum();
106 let mean = sum / Decimal::from(n);
107 let median = sorted[n / 2];
108 let min = sorted[0];
109 let max = sorted[n - 1];
110
111 let percentile_1 = sorted[(n as f64 * 0.01) as usize];
113 let percentile_99 = sorted[((n as f64 * 0.99) as usize).min(n - 1)];
114
115 let amounts_f64: Vec<f64> = amounts
126 .iter()
127 .filter_map(rust_decimal::prelude::ToPrimitive::to_f64)
128 .collect();
129 let mean_f64 = amounts_f64.iter().sum::<f64>() / amounts_f64.len() as f64;
130 let std_f64 = (amounts_f64
131 .iter()
132 .map(|a| (a - mean_f64).powi(2))
133 .sum::<f64>()
134 / (amounts_f64.len() - 1) as f64)
135 .sqrt();
136 let std_dev = rust_decimal::Decimal::from_f64_retain(std_f64).unwrap_or(Decimal::ZERO);
137
138 let skewness = if std_f64 > 0.0 {
140 let n_f64 = amounts_f64.len() as f64;
141 let m3 = amounts_f64
142 .iter()
143 .map(|a| ((a - mean_f64) / std_f64).powi(3))
144 .sum::<f64>()
145 / n_f64;
146 m3 * (n_f64 * (n_f64 - 1.0)).sqrt() / (n_f64 - 2.0)
147 } else {
148 0.0
149 };
150
151 let kurtosis = if std_f64 > 0.0 {
153 let n_f64 = amounts_f64.len() as f64;
154 let m4 = amounts_f64
155 .iter()
156 .map(|a| ((a - mean_f64) / std_f64).powi(4))
157 .sum::<f64>()
158 / n_f64;
159 m4 - 3.0 } else {
161 0.0
162 };
163
164 let (lognormal_ks_stat, lognormal_ks_pvalue, fitted_mu, fitted_sigma) =
166 if positive_amounts.len() >= 10 {
167 self.lognormal_ks_test(&positive_amounts)
168 } else {
169 (None, None, None, None)
170 };
171
172 let round_count = amounts
174 .iter()
175 .filter(|a| {
176 let frac = a.fract();
177 frac.is_zero()
178 })
179 .count();
180 let round_number_ratio = round_count as f64 / n as f64;
181
182 let nice_count = amounts
184 .iter()
185 .filter(|a| {
186 let cents = (a.fract() * Decimal::ONE_HUNDRED).abs();
187 let last_digit = (cents.to_i64().unwrap_or(0) % 10) as u8;
188 last_digit == 0 || last_digit == 5
189 })
190 .count();
191 let nice_number_ratio = nice_count as f64 / n as f64;
192
193 let passes = lognormal_ks_pvalue.is_none_or(|p| p >= self.significance_level);
195
196 Ok(AmountDistributionAnalysis {
197 sample_size: n,
198 mean,
199 median,
200 std_dev,
201 min,
202 max,
203 percentile_1,
204 percentile_99,
205 skewness,
206 kurtosis,
207 lognormal_ks_stat,
208 lognormal_ks_pvalue,
209 fitted_mu,
210 fitted_sigma,
211 round_number_ratio,
212 nice_number_ratio,
213 passes,
214 })
215 }
216
217 fn lognormal_ks_test(
219 &self,
220 amounts: &[Decimal],
221 ) -> (Option<f64>, Option<f64>, Option<f64>, Option<f64>) {
222 let log_amounts: Vec<f64> = amounts
224 .iter()
225 .filter_map(rust_decimal::prelude::ToPrimitive::to_f64)
226 .filter(|a| *a > 0.0)
227 .map(f64::ln)
228 .collect();
229
230 if log_amounts.len() < 10 {
231 return (None, None, None, None);
232 }
233
234 let n = log_amounts.len() as f64;
236 let mu: f64 = log_amounts.iter().sum::<f64>() / n;
237 let sigma: f64 =
238 (log_amounts.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
239
240 if sigma <= 0.0 {
241 return (None, None, Some(mu), None);
242 }
243
244 let mut sorted_log = log_amounts.clone();
246 sorted_log.sort_by(f64::total_cmp);
247
248 let n_usize = sorted_log.len();
250 let mut d_max = 0.0f64;
251
252 for (i, &x) in sorted_log.iter().enumerate() {
253 let f_n = (i + 1) as f64 / n_usize as f64;
254 let f_x = normal_cdf((x - mu) / sigma);
255 let d_plus = (f_n - f_x).abs();
256 let d_minus = (f_x - i as f64 / n_usize as f64).abs();
257 d_max = d_max.max(d_plus).max(d_minus);
258 }
259
260 let sqrt_n = (n_usize as f64).sqrt();
263 let lambda = (sqrt_n + 0.12 + 0.11 / sqrt_n) * d_max;
264 let p_value = kolmogorov_pvalue(lambda);
265
266 (Some(d_max), Some(p_value), Some(mu), Some(sigma))
267 }
268}
269
270impl Default for AmountDistributionAnalyzer {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276fn normal_cdf(x: f64) -> f64 {
278 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
279}
280
/// Error function approximation (Abramowitz & Stegun formula 7.1.26).
///
/// The rational approximation is evaluated on |x|. Odd symmetry,
/// erf(−x) = −erf(x), is then applied for negative inputs.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients a5, a4, a3, a2, a1, in Horner evaluation order.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &coeff| acc * t + coeff);
    let value = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -value
    } else {
        value
    }
}
299
/// Survival function of the Kolmogorov distribution, `Q_KS(λ) = P(K > λ)`.
///
/// Used as the asymptotic p-value of a (scaled) KS statistic. Two series are used:
///
/// * `λ < 1.18`: `Q = 1 − (√(2π)/λ) · Σ_{j≥1} exp(−(2j−1)² π² / (8λ²))`. This converges
///   in a few terms for small λ. The alternating series below converges slowly
///   there; truncated at 100 terms it under-reports the p-value (≈0.87 rather
///   than ≈1 at λ = 0.01).
/// * `λ ≥ 1.18`: `Q = 2 · Σ_{k≥1} (−1)^{k−1} exp(−2k²λ²)`.
///
/// Returns 1.0 for `λ <= 0`. Finite results are clamped to `[0, 1]`.
fn kolmogorov_pvalue(lambda: f64) -> f64 {
    if lambda <= 0.0 {
        return 1.0;
    }

    if lambda < 1.18 {
        let y = (-std::f64::consts::PI.powi(2) / (8.0 * lambda * lambda)).exp();
        // Terms j = 1..=4 are y, y^9, y^25, y^49. Here y < 0.42, so later terms are negligible.
        let series = y + y.powi(9) + y.powi(25) + y.powi(49);
        let cdf = (2.0 * std::f64::consts::PI).sqrt() / lambda * series;
        return (1.0 - cdf).clamp(0.0, 1.0);
    }

    let lambda_sq = lambda * lambda;
    let mut sum = 0.0;
    for k in 1..=100 {
        let k_f64 = f64::from(k);
        let term = (-1.0f64).powi(k - 1) * (-2.0 * k_f64 * k_f64 * lambda_sq).exp();
        sum += term;
        if term.abs() < 1e-10 {
            break;
        }
    }

    (2.0 * sum).clamp(0.0, 1.0)
}
322
323#[allow(dead_code)]
329fn decimal_sqrt(value: Decimal) -> Decimal {
330 if value <= Decimal::ZERO {
331 return Decimal::ZERO;
332 }
333
334 let mut guess = value / Decimal::TWO;
336 for _ in 0..20 {
337 let new_guess = (guess + value / guess) / Decimal::TWO;
338 if (new_guess - guess).abs() < Decimal::new(1, 10) {
339 return new_guess;
340 }
341 guess = new_guess;
342 }
343 guess
344}
345
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Analyzes `amounts` with a default analyzer, panicking if analysis fails.
    fn analyze_default(amounts: &[Decimal]) -> AmountDistributionAnalysis {
        AmountDistributionAnalyzer::new().analyze(amounts).unwrap()
    }

    #[test]
    fn test_basic_statistics() {
        let result = analyze_default(&[
            dec!(100.00),
            dec!(200.00),
            dec!(300.00),
            dec!(400.00),
            dec!(500.00),
        ]);

        assert_eq!(result.sample_size, 5);
        assert_eq!(result.mean, dec!(300.00));
        assert_eq!(result.min, dec!(100.00));
        assert_eq!(result.max, dec!(500.00));
    }

    #[test]
    fn test_round_number_detection() {
        // Three of the five amounts have no fractional part.
        let result = analyze_default(&[
            dec!(100.00),
            dec!(200.50),
            dec!(300.00),
            dec!(400.25),
            dec!(500.00),
        ]);

        assert!((result.round_number_ratio - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_insufficient_data() {
        let outcome = AmountDistributionAnalyzer::new().analyze(&[dec!(100.00)]);
        assert!(matches!(outcome, Err(EvalError::InsufficientData { .. })));
    }

    #[test]
    fn test_normal_cdf() {
        // (input, expected CDF, tolerance)
        let cases = [(0.0, 0.5, 0.001), (1.96, 0.975, 0.01), (-1.96, 0.025, 0.01)];
        for (x, expected, tolerance) in cases {
            assert!((normal_cdf(x) - expected).abs() < tolerance);
        }
    }
}