datasynth_eval/statistical/
amount_distribution.rs1use crate::error::{EvalError, EvalResult};
7use rust_decimal::prelude::*;
8use rust_decimal::Decimal;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AmountDistributionAnalysis {
14 pub sample_size: usize,
16 pub mean: Decimal,
18 pub median: Decimal,
20 pub std_dev: Decimal,
22 pub min: Decimal,
24 pub max: Decimal,
26 pub percentile_1: Decimal,
28 pub percentile_99: Decimal,
30 pub skewness: f64,
32 pub kurtosis: f64,
34 pub lognormal_ks_stat: Option<f64>,
36 pub lognormal_ks_pvalue: Option<f64>,
38 pub fitted_mu: Option<f64>,
40 pub fitted_sigma: Option<f64>,
42 pub round_number_ratio: f64,
44 pub nice_number_ratio: f64,
46 pub passes: bool,
48}
49
/// Configurable analyzer that summarizes the distribution of monetary amounts.
#[derive(Debug, Clone)]
pub struct AmountDistributionAnalyzer {
    /// Expected log-normal location parameter, if the caller supplied one.
    expected_mu: Option<f64>,
    /// Expected log-normal scale parameter, if the caller supplied one.
    expected_sigma: Option<f64>,
    /// Threshold that a test p-value must reach for the analysis to pass.
    significance_level: f64,
}
59
60impl AmountDistributionAnalyzer {
61 pub fn new() -> Self {
63 Self {
64 expected_mu: None,
65 expected_sigma: None,
66 significance_level: 0.05,
67 }
68 }
69
70 pub fn with_expected_lognormal(mut self, mu: f64, sigma: f64) -> Self {
72 self.expected_mu = Some(mu);
73 self.expected_sigma = Some(sigma);
74 self
75 }
76
77 pub fn with_significance_level(mut self, level: f64) -> Self {
79 self.significance_level = level;
80 self
81 }
82
83 pub fn analyze(&self, amounts: &[Decimal]) -> EvalResult<AmountDistributionAnalysis> {
85 let n = amounts.len();
86 if n < 2 {
87 return Err(EvalError::InsufficientData {
88 required: 2,
89 actual: n,
90 });
91 }
92
93 let positive_amounts: Vec<Decimal> = amounts
95 .iter()
96 .filter(|a| **a > Decimal::ZERO)
97 .copied()
98 .collect();
99
100 let mut sorted = amounts.to_vec();
102 sorted.sort();
103
104 let sum: Decimal = amounts.iter().sum();
106 let mean = sum / Decimal::from(n);
107 let median = sorted[n / 2];
108 let min = sorted[0];
109 let max = sorted[n - 1];
110
111 let percentile_1 = sorted[(n as f64 * 0.01) as usize];
113 let percentile_99 = sorted[((n as f64 * 0.99) as usize).min(n - 1)];
114
115 let variance: Decimal = amounts
117 .iter()
118 .map(|a| (*a - mean) * (*a - mean))
119 .sum::<Decimal>()
120 / Decimal::from(n - 1);
121 let std_dev = decimal_sqrt(variance);
122
123 let amounts_f64: Vec<f64> = amounts
125 .iter()
126 .filter_map(rust_decimal::prelude::ToPrimitive::to_f64)
127 .collect();
128 let mean_f64 = amounts_f64.iter().sum::<f64>() / amounts_f64.len() as f64;
129 let std_f64 = (amounts_f64
130 .iter()
131 .map(|a| (a - mean_f64).powi(2))
132 .sum::<f64>()
133 / (amounts_f64.len() - 1) as f64)
134 .sqrt();
135
136 let skewness = if std_f64 > 0.0 {
138 let n_f64 = amounts_f64.len() as f64;
139 let m3 = amounts_f64
140 .iter()
141 .map(|a| ((a - mean_f64) / std_f64).powi(3))
142 .sum::<f64>()
143 / n_f64;
144 m3 * (n_f64 * (n_f64 - 1.0)).sqrt() / (n_f64 - 2.0)
145 } else {
146 0.0
147 };
148
149 let kurtosis = if std_f64 > 0.0 {
151 let n_f64 = amounts_f64.len() as f64;
152 let m4 = amounts_f64
153 .iter()
154 .map(|a| ((a - mean_f64) / std_f64).powi(4))
155 .sum::<f64>()
156 / n_f64;
157 m4 - 3.0 } else {
159 0.0
160 };
161
162 let (lognormal_ks_stat, lognormal_ks_pvalue, fitted_mu, fitted_sigma) =
164 if positive_amounts.len() >= 10 {
165 self.lognormal_ks_test(&positive_amounts)
166 } else {
167 (None, None, None, None)
168 };
169
170 let round_count = amounts
172 .iter()
173 .filter(|a| {
174 let frac = a.fract();
175 frac.is_zero()
176 })
177 .count();
178 let round_number_ratio = round_count as f64 / n as f64;
179
180 let nice_count = amounts
182 .iter()
183 .filter(|a| {
184 let cents = (a.fract() * Decimal::ONE_HUNDRED).abs();
185 let last_digit = (cents.to_i64().unwrap_or(0) % 10) as u8;
186 last_digit == 0 || last_digit == 5
187 })
188 .count();
189 let nice_number_ratio = nice_count as f64 / n as f64;
190
191 let passes = lognormal_ks_pvalue.is_none_or(|p| p >= self.significance_level);
193
194 Ok(AmountDistributionAnalysis {
195 sample_size: n,
196 mean,
197 median,
198 std_dev,
199 min,
200 max,
201 percentile_1,
202 percentile_99,
203 skewness,
204 kurtosis,
205 lognormal_ks_stat,
206 lognormal_ks_pvalue,
207 fitted_mu,
208 fitted_sigma,
209 round_number_ratio,
210 nice_number_ratio,
211 passes,
212 })
213 }
214
215 fn lognormal_ks_test(
217 &self,
218 amounts: &[Decimal],
219 ) -> (Option<f64>, Option<f64>, Option<f64>, Option<f64>) {
220 let log_amounts: Vec<f64> = amounts
222 .iter()
223 .filter_map(rust_decimal::prelude::ToPrimitive::to_f64)
224 .filter(|a| *a > 0.0)
225 .map(f64::ln)
226 .collect();
227
228 if log_amounts.len() < 10 {
229 return (None, None, None, None);
230 }
231
232 let n = log_amounts.len() as f64;
234 let mu: f64 = log_amounts.iter().sum::<f64>() / n;
235 let sigma: f64 =
236 (log_amounts.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
237
238 if sigma <= 0.0 {
239 return (None, None, Some(mu), None);
240 }
241
242 let mut sorted_log = log_amounts.clone();
244 sorted_log.sort_by(f64::total_cmp);
245
246 let n_usize = sorted_log.len();
248 let mut d_max = 0.0f64;
249
250 for (i, &x) in sorted_log.iter().enumerate() {
251 let f_n = (i + 1) as f64 / n_usize as f64;
252 let f_x = normal_cdf((x - mu) / sigma);
253 let d_plus = (f_n - f_x).abs();
254 let d_minus = (f_x - i as f64 / n_usize as f64).abs();
255 d_max = d_max.max(d_plus).max(d_minus);
256 }
257
258 let sqrt_n = (n_usize as f64).sqrt();
261 let lambda = (sqrt_n + 0.12 + 0.11 / sqrt_n) * d_max;
262 let p_value = kolmogorov_pvalue(lambda);
263
264 (Some(d_max), Some(p_value), Some(mu), Some(sigma))
265 }
266}
267
268impl Default for AmountDistributionAnalyzer {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
274fn normal_cdf(x: f64) -> f64 {
276 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
277}
278
/// Polynomial approximation of the error function.
///
/// The approximation is evaluated on `|x|` and the sign of `x` is restored
/// afterwards, so the result is odd in `x`.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients, highest order first, for Horner evaluation.
    const COEFFICIENTS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner's scheme: ((((a5*t + a4)*t + a3)*t + a2)*t + a1).
    let polynomial = COEFFICIENTS
        .iter()
        .fold(0.0_f64, |acc, &coefficient| acc * t + coefficient);
    let approximation = 1.0 - polynomial * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -approximation
    } else {
        approximation
    }
}
297
/// Survival function of the Kolmogorov distribution,
/// `Q(lambda) = 2 * sum_{k>=1} (-1)^(k-1) * exp(-2 * k^2 * lambda^2)`,
/// clamped to `[0, 1]`.
///
/// For `lambda` below `0.2` the true value differs from `1` by less than
/// `1e-12`, while the alternating series converges so slowly that a truncated
/// sum can be far from the limit; `1.0` is therefore returned directly. This
/// also covers zero and negative inputs.
fn kolmogorov_pvalue(lambda: f64) -> f64 {
    /// Below this value the p-value is indistinguishable from 1.
    const SMALL_LAMBDA: f64 = 0.2;
    /// Upper bound on the number of series terms evaluated.
    const MAX_TERMS: u32 = 100;
    /// Terms smaller than this no longer change the result meaningfully.
    const TERM_TOLERANCE: f64 = 1e-10;

    if lambda < SMALL_LAMBDA {
        return 1.0;
    }

    let lambda_sq = lambda * lambda;
    let mut sum = 0.0;
    let mut sign = 1.0;

    for k in 1..=MAX_TERMS {
        let k_f64 = f64::from(k);
        let term = sign * (-2.0 * k_f64 * k_f64 * lambda_sq).exp();
        sum += term;
        if term.abs() < TERM_TOLERANCE {
            break;
        }
        sign = -sign;
    }

    (2.0 * sum).clamp(0.0, 1.0)
}
320
321fn decimal_sqrt(value: Decimal) -> Decimal {
323 if value <= Decimal::ZERO {
324 return Decimal::ZERO;
325 }
326
327 let mut guess = value / Decimal::TWO;
329 for _ in 0..20 {
330 let new_guess = (guess + value / guess) / Decimal::TWO;
331 if (new_guess - guess).abs() < Decimal::new(1, 10) {
332 return new_guess;
333 }
334 guess = new_guess;
335 }
336 guess
337}
338
#[cfg(test)]
// Unwrapping is acceptable here: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Runs a default-configured analyzer over `amounts`.
    fn run_default(amounts: &[Decimal]) -> EvalResult<AmountDistributionAnalysis> {
        AmountDistributionAnalyzer::new().analyze(amounts)
    }

    #[test]
    fn test_basic_statistics() {
        let report = run_default(&[
            dec!(100.00),
            dec!(200.00),
            dec!(300.00),
            dec!(400.00),
            dec!(500.00),
        ])
        .unwrap();

        assert_eq!(report.sample_size, 5);
        assert_eq!(report.mean, dec!(300.00));
        assert_eq!(report.min, dec!(100.00));
        assert_eq!(report.max, dec!(500.00));
    }

    #[test]
    fn test_round_number_detection() {
        // Three of the five amounts have no fractional part.
        let report = run_default(&[
            dec!(100.00),
            dec!(200.50),
            dec!(300.00),
            dec!(400.25),
            dec!(500.00),
        ])
        .unwrap();

        let deviation = (report.round_number_ratio - 0.6).abs();
        assert!(deviation < 0.01);
    }

    #[test]
    fn test_insufficient_data() {
        let outcome = run_default(&[dec!(100.00)]);
        assert!(matches!(outcome, Err(EvalError::InsufficientData { .. })));
    }

    #[test]
    fn test_normal_cdf() {
        // (input, expected CDF value, allowed absolute error)
        let cases = [(0.0, 0.5, 0.001), (1.96, 0.975, 0.01), (-1.96, 0.025, 0.01)];
        for (input, expected, tolerance) in cases {
            assert!((normal_cdf(input) - expected).abs() < tolerance);
        }
    }
}