datasynth_core/distributions/
amount.rs1use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AmountDistributionConfig {
17 pub min_amount: f64,
19 pub max_amount: f64,
21 pub lognormal_mu: f64,
23 pub lognormal_sigma: f64,
25 pub decimal_places: u8,
27 pub round_number_probability: f64,
29 pub nice_number_probability: f64,
31}
32
33impl Default for AmountDistributionConfig {
34 fn default() -> Self {
35 Self {
36 min_amount: 0.01,
37 max_amount: 100_000_000.0, lognormal_mu: 7.0, lognormal_sigma: 2.5, decimal_places: 2,
41 round_number_probability: 0.25, nice_number_probability: 0.15, }
44 }
45}
46
47impl AmountDistributionConfig {
48 pub fn small_transactions() -> Self {
50 Self {
51 min_amount: 0.01,
52 max_amount: 10_000.0,
53 lognormal_mu: 4.0, lognormal_sigma: 1.5,
55 decimal_places: 2,
56 round_number_probability: 0.30,
57 nice_number_probability: 0.20,
58 }
59 }
60
61 pub fn medium_transactions() -> Self {
63 Self {
64 min_amount: 100.0,
65 max_amount: 1_000_000.0,
66 lognormal_mu: 8.5, lognormal_sigma: 2.0,
68 decimal_places: 2,
69 round_number_probability: 0.20,
70 nice_number_probability: 0.15,
71 }
72 }
73
74 pub fn large_transactions() -> Self {
76 Self {
77 min_amount: 1000.0,
78 max_amount: 100_000_000.0,
79 lognormal_mu: 10.0, lognormal_sigma: 2.5,
81 decimal_places: 2,
82 round_number_probability: 0.15,
83 nice_number_probability: 0.10,
84 }
85 }
86}
87
88pub struct AmountSampler {
90 rng: ChaCha8Rng,
92 config: AmountDistributionConfig,
94 lognormal: LogNormal<f64>,
96 decimal_multiplier: f64,
98 benford_sampler: Option<BenfordSampler>,
100 fraud_generator: Option<FraudAmountGenerator>,
102 benford_enabled: bool,
104}
105
106impl AmountSampler {
107 pub fn new(seed: u64) -> Self {
109 Self::with_config(seed, AmountDistributionConfig::default())
110 }
111
112 pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
114 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
115 .expect("Invalid log-normal parameters");
116 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
117
118 Self {
119 rng: ChaCha8Rng::seed_from_u64(seed),
120 config,
121 lognormal,
122 decimal_multiplier,
123 benford_sampler: None,
124 fraud_generator: None,
125 benford_enabled: false,
126 }
127 }
128
129 pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
131 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
132 .expect("Invalid log-normal parameters");
133 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
134
135 Self {
136 rng: ChaCha8Rng::seed_from_u64(seed),
137 benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
138 fraud_generator: Some(FraudAmountGenerator::new(
139 seed + 200,
140 config.clone(),
141 ThresholdConfig::default(),
142 )),
143 config,
144 lognormal,
145 decimal_multiplier,
146 benford_enabled: true,
147 }
148 }
149
150 pub fn with_fraud_config(
152 seed: u64,
153 config: AmountDistributionConfig,
154 threshold_config: ThresholdConfig,
155 benford_enabled: bool,
156 ) -> Self {
157 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
158 .expect("Invalid log-normal parameters");
159 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
160
161 Self {
162 rng: ChaCha8Rng::seed_from_u64(seed),
163 benford_sampler: if benford_enabled {
164 Some(BenfordSampler::new(seed + 100, config.clone()))
165 } else {
166 None
167 },
168 fraud_generator: Some(FraudAmountGenerator::new(
169 seed + 200,
170 config.clone(),
171 threshold_config,
172 )),
173 config,
174 lognormal,
175 decimal_multiplier,
176 benford_enabled,
177 }
178 }
179
180 pub fn set_benford_enabled(&mut self, enabled: bool) {
182 self.benford_enabled = enabled;
183 if enabled && self.benford_sampler.is_none() {
184 let seed = self.rng.random();
186 self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
187 }
188 }
189
190 pub fn is_benford_enabled(&self) -> bool {
192 self.benford_enabled
193 }
194
195 #[inline]
200 pub fn sample(&mut self) -> Decimal {
201 if self.benford_enabled {
203 if let Some(ref mut benford) = self.benford_sampler {
204 return benford.sample();
205 }
206 }
207
208 self.sample_lognormal()
210 }
211
212 #[inline]
214 pub fn sample_lognormal(&mut self) -> Decimal {
215 let mut amount = self.lognormal.sample(&mut self.rng);
216
217 amount = amount.clamp(self.config.min_amount, self.config.max_amount);
219
220 let p: f64 = self.rng.random();
222 if p < self.config.round_number_probability {
223 amount = (amount / 100.0).round() * 100.0;
225 } else if p < self.config.round_number_probability + self.config.nice_number_probability {
226 amount = (amount / 5.0).round() * 5.0;
228 }
229
230 amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
232
233 amount = amount.max(self.config.min_amount);
235
236 let cents = (amount * 100.0).round() as i64;
240 Decimal::new(cents, 2)
241 }
242
243 pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
247 if let Some(ref mut fraud_gen) = self.fraud_generator {
248 fraud_gen.sample(pattern)
249 } else {
250 self.sample()
252 }
253 }
254
255 pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
262 use rust_decimal::prelude::ToPrimitive;
263
264 let min_amount = Decimal::new(1, 2); if count == 0 {
267 return Vec::new();
268 }
269 if count == 1 {
270 return vec![total];
271 }
272
273 let total_f64 = total.to_f64().unwrap_or(0.0);
274
275 let mut weights: Vec<f64> = (0..count)
277 .map(|_| self.rng.random::<f64>().max(0.01))
278 .collect();
279 let sum: f64 = weights.iter().sum();
280 weights.iter_mut().for_each(|w| *w /= sum);
281
282 let mut amounts: Vec<Decimal> = weights
284 .iter()
285 .map(|w| {
286 let amount = total_f64 * w;
287 let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
288 let cents = (rounded * 100.0).round() as i64;
290 Decimal::new(cents, 2)
291 })
292 .collect();
293
294 let current_sum: Decimal = amounts.iter().copied().sum();
296 let diff = total - current_sum;
297 let last_idx = amounts.len() - 1;
298 amounts[last_idx] += diff;
299
300 if amounts[last_idx] < Decimal::ZERO {
302 let mut remaining = amounts[last_idx].abs();
303 amounts[last_idx] = Decimal::ZERO;
304
305 for amt in amounts.iter_mut().take(last_idx).rev() {
307 if remaining <= Decimal::ZERO {
308 break;
309 }
310 let take = remaining.min(*amt);
311 *amt -= take;
312 remaining -= take;
313 }
314
315 if remaining > Decimal::ZERO {
317 for amt in amounts.iter_mut() {
318 if *amt > Decimal::ZERO {
319 *amt -= remaining;
320 break;
321 }
322 }
323 }
324 }
325
326 if total >= min_amount * Decimal::from(count as u32) {
330 loop {
331 let zero_idx = amounts.iter().position(|a| *a == Decimal::ZERO);
333 let Some(zi) = zero_idx else { break };
334
335 let donor = amounts
337 .iter()
338 .enumerate()
339 .filter(|&(j, _)| j != zi)
340 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
341 .map(|(j, _)| j);
342
343 if let Some(di) = donor {
344 if amounts[di] > min_amount {
345 amounts[zi] = min_amount;
346 amounts[di] -= min_amount;
347 } else {
348 break; }
350 } else {
351 break;
352 }
353 }
354 }
355
356 amounts
357 }
358
359 pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
361 let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
362 let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
363
364 let range = max_f64 - min_f64;
365 let amount = min_f64 + self.rng.random::<f64>() * range;
366
367 let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
368 Decimal::from_f64_retain(rounded).unwrap_or(min)
369 }
370
371 pub fn reset(&mut self, seed: u64) {
373 self.rng = ChaCha8Rng::seed_from_u64(seed);
374 }
375}
376
377pub struct ExchangeRateSampler {
379 rng: ChaCha8Rng,
380 base_rates: std::collections::HashMap<String, f64>,
382 volatility: f64,
384}
385
386impl ExchangeRateSampler {
387 pub fn new(seed: u64) -> Self {
389 let mut base_rates = std::collections::HashMap::new();
390 base_rates.insert("EUR".to_string(), 0.92);
392 base_rates.insert("GBP".to_string(), 0.79);
393 base_rates.insert("CHF".to_string(), 0.88);
394 base_rates.insert("JPY".to_string(), 149.0);
395 base_rates.insert("CNY".to_string(), 7.24);
396 base_rates.insert("CAD".to_string(), 1.36);
397 base_rates.insert("AUD".to_string(), 1.53);
398 base_rates.insert("INR".to_string(), 83.0);
399 base_rates.insert("USD".to_string(), 1.0);
400
401 Self {
402 rng: ChaCha8Rng::seed_from_u64(seed),
403 base_rates,
404 volatility: 0.005, }
406 }
407
408 pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
410 let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
411 let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
412
413 let base_rate = to_usd / from_usd;
415
416 let variation = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * self.volatility;
418 let rate = base_rate * variation;
419
420 let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
422 Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
423 }
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Converts a `Decimal` to `f64` via its string form for range checks.
    fn as_f64(value: Decimal) -> f64 {
        value.to_string().parse().unwrap()
    }

    /// Two samplers built from the same seed must yield the same sequence.
    #[test]
    fn test_amount_sampler_determinism() {
        let mut first = AmountSampler::new(42);
        let mut second = AmountSampler::new(42);

        for _ in 0..100 {
            assert_eq!(first.sample(), second.sample());
        }
    }

    /// Sampled amounts must stay within the configured bounds.
    #[test]
    fn test_amount_sampler_range() {
        let mut sampler = AmountSampler::with_config(
            42,
            AmountDistributionConfig {
                min_amount: 100.0,
                max_amount: 1000.0,
                ..Default::default()
            },
        );

        for _ in 0..1000 {
            let amount = sampler.sample();
            let amount_f64 = as_f64(amount);
            assert!(amount_f64 >= 100.0, "Amount {} below minimum", amount);
            assert!(amount_f64 <= 1000.0, "Amount {} above maximum", amount);
        }
    }

    /// A split must produce the requested number of parts and sum exactly to
    /// the total.
    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);

        assert_eq!(amounts.len(), 5);

        let sum: Decimal = amounts.iter().sum();
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    /// Rates must stay close to the values implied by the base-rate table.
    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);

        let eur_f64 = as_f64(sampler.get_rate("EUR", "USD"));
        assert!(
            eur_f64 > 0.8 && eur_f64 < 1.2,
            "EUR/USD rate {} out of range",
            eur_f64
        );

        let usd_f64 = as_f64(sampler.get_rate("USD", "USD"));
        assert!(
            (usd_f64 - 1.0).abs() < 0.01,
            "USD/USD rate {} should be ~1.0",
            usd_f64
        );
    }
}