//! Statistical analysis of amount distributions
//! (`datasynth_eval/statistical/amount_distribution.rs`).

use crate::error::{EvalError, EvalResult};
use rust_decimal::prelude::*;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AmountDistributionAnalysis {
14 pub sample_size: usize,
16 pub mean: Decimal,
18 pub median: Decimal,
20 pub std_dev: Decimal,
22 pub min: Decimal,
24 pub max: Decimal,
26 pub percentile_1: Decimal,
28 pub percentile_99: Decimal,
30 pub skewness: f64,
32 pub kurtosis: f64,
34 pub lognormal_ks_stat: Option<f64>,
36 pub lognormal_ks_pvalue: Option<f64>,
38 pub fitted_mu: Option<f64>,
40 pub fitted_sigma: Option<f64>,
42 pub round_number_ratio: f64,
44 pub nice_number_ratio: f64,
46 pub passes: bool,
48}
49
/// Analyzer for the distribution of monetary amounts.
///
/// Configure it with the `with_*` builder methods, then call `analyze`.
#[derive(Debug, Clone)]
pub struct AmountDistributionAnalyzer {
    /// Expected mean of `ln(amount)`, set via `with_expected_lognormal`.
    expected_mu: Option<f64>,
    /// Expected standard deviation of `ln(amount)`, set via `with_expected_lognormal`.
    expected_sigma: Option<f64>,
    /// KS p-value threshold below which a sample does not pass.
    significance_level: f64,
}
59
60impl AmountDistributionAnalyzer {
61 pub fn new() -> Self {
63 Self {
64 expected_mu: None,
65 expected_sigma: None,
66 significance_level: 0.05,
67 }
68 }
69
70 pub fn with_expected_lognormal(mut self, mu: f64, sigma: f64) -> Self {
72 self.expected_mu = Some(mu);
73 self.expected_sigma = Some(sigma);
74 self
75 }
76
77 pub fn with_significance_level(mut self, level: f64) -> Self {
79 self.significance_level = level;
80 self
81 }
82
83 pub fn analyze(&self, amounts: &[Decimal]) -> EvalResult<AmountDistributionAnalysis> {
85 let n = amounts.len();
86 if n < 2 {
87 return Err(EvalError::InsufficientData {
88 required: 2,
89 actual: n,
90 });
91 }
92
93 let positive_amounts: Vec<Decimal> = amounts
95 .iter()
96 .filter(|a| **a > Decimal::ZERO)
97 .copied()
98 .collect();
99
100 let mut sorted = amounts.to_vec();
102 sorted.sort();
103
104 let sum: Decimal = amounts.iter().sum();
106 let mean = sum / Decimal::from(n);
107 let median = sorted[n / 2];
108 let min = sorted[0];
109 let max = sorted[n - 1];
110
111 let percentile_1 = sorted[(n as f64 * 0.01) as usize];
113 let percentile_99 = sorted[((n as f64 * 0.99) as usize).min(n - 1)];
114
115 let variance: Decimal = amounts
117 .iter()
118 .map(|a| (*a - mean) * (*a - mean))
119 .sum::<Decimal>()
120 / Decimal::from(n - 1);
121 let std_dev = decimal_sqrt(variance);
122
123 let amounts_f64: Vec<f64> = amounts.iter().filter_map(|a| a.to_f64()).collect();
125 let mean_f64 = amounts_f64.iter().sum::<f64>() / amounts_f64.len() as f64;
126 let std_f64 = (amounts_f64
127 .iter()
128 .map(|a| (a - mean_f64).powi(2))
129 .sum::<f64>()
130 / (amounts_f64.len() - 1) as f64)
131 .sqrt();
132
133 let skewness = if std_f64 > 0.0 {
135 let n_f64 = amounts_f64.len() as f64;
136 let m3 = amounts_f64
137 .iter()
138 .map(|a| ((a - mean_f64) / std_f64).powi(3))
139 .sum::<f64>()
140 / n_f64;
141 m3 * (n_f64 * (n_f64 - 1.0)).sqrt() / (n_f64 - 2.0)
142 } else {
143 0.0
144 };
145
146 let kurtosis = if std_f64 > 0.0 {
148 let n_f64 = amounts_f64.len() as f64;
149 let m4 = amounts_f64
150 .iter()
151 .map(|a| ((a - mean_f64) / std_f64).powi(4))
152 .sum::<f64>()
153 / n_f64;
154 m4 - 3.0 } else {
156 0.0
157 };
158
159 let (lognormal_ks_stat, lognormal_ks_pvalue, fitted_mu, fitted_sigma) =
161 if positive_amounts.len() >= 10 {
162 self.lognormal_ks_test(&positive_amounts)
163 } else {
164 (None, None, None, None)
165 };
166
167 let round_count = amounts
169 .iter()
170 .filter(|a| {
171 let frac = a.fract();
172 frac.is_zero()
173 })
174 .count();
175 let round_number_ratio = round_count as f64 / n as f64;
176
177 let nice_count = amounts
179 .iter()
180 .filter(|a| {
181 let cents = (a.fract() * Decimal::ONE_HUNDRED).abs();
182 let last_digit = (cents.to_i64().unwrap_or(0) % 10) as u8;
183 last_digit == 0 || last_digit == 5
184 })
185 .count();
186 let nice_number_ratio = nice_count as f64 / n as f64;
187
188 let passes = lognormal_ks_pvalue.is_none_or(|p| p >= self.significance_level);
190
191 Ok(AmountDistributionAnalysis {
192 sample_size: n,
193 mean,
194 median,
195 std_dev,
196 min,
197 max,
198 percentile_1,
199 percentile_99,
200 skewness,
201 kurtosis,
202 lognormal_ks_stat,
203 lognormal_ks_pvalue,
204 fitted_mu,
205 fitted_sigma,
206 round_number_ratio,
207 nice_number_ratio,
208 passes,
209 })
210 }
211
212 fn lognormal_ks_test(
214 &self,
215 amounts: &[Decimal],
216 ) -> (Option<f64>, Option<f64>, Option<f64>, Option<f64>) {
217 let log_amounts: Vec<f64> = amounts
219 .iter()
220 .filter_map(|a| a.to_f64())
221 .filter(|a| *a > 0.0)
222 .map(|a| a.ln())
223 .collect();
224
225 if log_amounts.len() < 10 {
226 return (None, None, None, None);
227 }
228
229 let n = log_amounts.len() as f64;
231 let mu: f64 = log_amounts.iter().sum::<f64>() / n;
232 let sigma: f64 =
233 (log_amounts.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
234
235 if sigma <= 0.0 {
236 return (None, None, Some(mu), None);
237 }
238
239 let mut sorted_log = log_amounts.clone();
241 sorted_log.sort_by(|a, b| a.partial_cmp(b).unwrap());
242
243 let n_usize = sorted_log.len();
245 let mut d_max = 0.0f64;
246
247 for (i, &x) in sorted_log.iter().enumerate() {
248 let f_n = (i + 1) as f64 / n_usize as f64;
249 let f_x = normal_cdf((x - mu) / sigma);
250 let d_plus = (f_n - f_x).abs();
251 let d_minus = (f_x - i as f64 / n_usize as f64).abs();
252 d_max = d_max.max(d_plus).max(d_minus);
253 }
254
255 let sqrt_n = (n_usize as f64).sqrt();
258 let lambda = (sqrt_n + 0.12 + 0.11 / sqrt_n) * d_max;
259 let p_value = kolmogorov_pvalue(lambda);
260
261 (Some(d_max), Some(p_value), Some(mu), Some(sigma))
262 }
263}
264
265impl Default for AmountDistributionAnalyzer {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271fn normal_cdf(x: f64) -> f64 {
273 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
274}
275
/// Error function approximation (Abramowitz & Stegun, formula 7.1.26).
///
/// The approximation's maximum absolute error is about 1.5e-7. The result is odd-symmetric:
/// `erf(-x) == -erf(x)`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // a1..a5 of the rational approximation, in ascending power of t.
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner evaluation of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5.
    let poly = COEFFS.iter().rev().fold(0.0, |acc, &a| acc * t + a) * t;
    let value = 1.0 - poly * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -value
    } else {
        value
    }
}
294
/// Survival function of the Kolmogorov distribution, `Q(λ) = P(K > λ)`.
///
/// This is the asymptotic p-value of a KS statistic scaled to `λ`. Two series are used so the
/// result converges for every `λ`:
///
/// * `λ < 1.18`: `P(K ≤ λ) = (√(2π) / λ) Σ_{k≥1} exp(−(2k−1)² π² / (8λ²))`, returned as `1 − P`;
/// * otherwise: `Q(λ) = 2 Σ_{k≥1} (−1)^{k−1} exp(−2k²λ²)`.
///
/// The alternating series alone converges very slowly for small `λ`. Truncated at 100 terms,
/// it understated p-values that should be ≈ 1. Non-positive `λ` yields `1.0`; the result is
/// clamped to `[0, 1]`.
fn kolmogorov_pvalue(lambda: f64) -> f64 {
    if lambda <= 0.0 {
        return 1.0;
    }

    let lambda_sq = lambda * lambda;

    if lambda < 1.18 {
        let scale = std::f64::consts::PI * std::f64::consts::PI / (8.0 * lambda_sq);
        let mut cdf_sum = 0.0;
        for k in 1..=100 {
            let odd = (2 * k - 1) as f64;
            let term = (-odd * odd * scale).exp();
            cdf_sum += term;
            // Terms shrink super-exponentially; stop once they no longer matter.
            if term < 1e-16 {
                break;
            }
        }
        let cdf = (2.0 * std::f64::consts::PI).sqrt() / lambda * cdf_sum;
        return (1.0 - cdf).clamp(0.0, 1.0);
    }

    let mut sum = 0.0;
    for k in 1..=100 {
        let k_f64 = f64::from(k);
        let sign = if k % 2 == 1 { 1.0 } else { -1.0 };
        let term = sign * (-2.0 * k_f64 * k_f64 * lambda_sq).exp();
        sum += term;
        if term.abs() < 1e-10 {
            break;
        }
    }

    (2.0 * sum).clamp(0.0, 1.0)
}
317
318fn decimal_sqrt(value: Decimal) -> Decimal {
320 if value <= Decimal::ZERO {
321 return Decimal::ZERO;
322 }
323
324 let mut guess = value / Decimal::TWO;
326 for _ in 0..20 {
327 let new_guess = (guess + value / guess) / Decimal::TWO;
328 if (new_guess - guess).abs() < Decimal::new(1, 10) {
329 return new_guess;
330 }
331 guess = new_guess;
332 }
333 guess
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Runs a default analyzer over `amounts`, panicking if the sample is rejected.
    fn analyze_ok(amounts: &[Decimal]) -> AmountDistributionAnalysis {
        AmountDistributionAnalyzer::new()
            .analyze(amounts)
            .expect("sample should be large enough to analyze")
    }

    #[test]
    fn test_basic_statistics() {
        let result = analyze_ok(&[
            dec!(100.00),
            dec!(200.00),
            dec!(300.00),
            dec!(400.00),
            dec!(500.00),
        ]);

        assert_eq!(result.sample_size, 5);
        assert_eq!(result.mean, dec!(300.00));
        assert_eq!(result.min, dec!(100.00));
        assert_eq!(result.max, dec!(500.00));
    }

    #[test]
    fn test_round_number_detection() {
        // 100, 300 and 500 have no cents: 3 of 5 amounts are round.
        let result = analyze_ok(&[
            dec!(100.00),
            dec!(200.50),
            dec!(300.00),
            dec!(400.25),
            dec!(500.00),
        ]);

        assert!((result.round_number_ratio - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_insufficient_data() {
        let outcome = AmountDistributionAnalyzer::new().analyze(&[dec!(100.00)]);
        assert!(matches!(outcome, Err(EvalError::InsufficientData { .. })));
    }

    #[test]
    fn test_normal_cdf() {
        let cases = [(0.0, 0.5, 0.001), (1.96, 0.975, 0.01), (-1.96, 0.025, 0.01)];
        for (x, expected, tolerance) in cases {
            assert!((normal_cdf(x) - expected).abs() < tolerance, "normal_cdf({x})");
        }
    }
}