datasynth_core/distributions/
amount.rs1use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AmountDistributionConfig {
17 pub min_amount: f64,
19 pub max_amount: f64,
21 pub lognormal_mu: f64,
23 pub lognormal_sigma: f64,
25 pub decimal_places: u8,
27 pub round_number_probability: f64,
29 pub nice_number_probability: f64,
31}
32
33impl Default for AmountDistributionConfig {
34 fn default() -> Self {
35 Self {
36 min_amount: 0.01,
37 max_amount: 100_000_000.0, lognormal_mu: 7.0, lognormal_sigma: 2.5, decimal_places: 2,
41 round_number_probability: 0.25, nice_number_probability: 0.15, }
44 }
45}
46
47impl AmountDistributionConfig {
48 pub fn small_transactions() -> Self {
50 Self {
51 min_amount: 0.01,
52 max_amount: 10_000.0,
53 lognormal_mu: 4.0, lognormal_sigma: 1.5,
55 decimal_places: 2,
56 round_number_probability: 0.30,
57 nice_number_probability: 0.20,
58 }
59 }
60
61 pub fn medium_transactions() -> Self {
63 Self {
64 min_amount: 100.0,
65 max_amount: 1_000_000.0,
66 lognormal_mu: 8.5, lognormal_sigma: 2.0,
68 decimal_places: 2,
69 round_number_probability: 0.20,
70 nice_number_probability: 0.15,
71 }
72 }
73
74 pub fn large_transactions() -> Self {
76 Self {
77 min_amount: 1000.0,
78 max_amount: 100_000_000.0,
79 lognormal_mu: 10.0, lognormal_sigma: 2.5,
81 decimal_places: 2,
82 round_number_probability: 0.15,
83 nice_number_probability: 0.10,
84 }
85 }
86}
87
88pub struct AmountSampler {
90 rng: ChaCha8Rng,
92 config: AmountDistributionConfig,
94 lognormal: LogNormal<f64>,
96 decimal_multiplier: f64,
98 benford_sampler: Option<BenfordSampler>,
100 fraud_generator: Option<FraudAmountGenerator>,
102 benford_enabled: bool,
104}
105
106impl AmountSampler {
107 pub fn new(seed: u64) -> Self {
109 Self::with_config(seed, AmountDistributionConfig::default())
110 }
111
112 pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
114 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
115 .expect("Invalid log-normal parameters");
116 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
117
118 Self {
119 rng: ChaCha8Rng::seed_from_u64(seed),
120 config,
121 lognormal,
122 decimal_multiplier,
123 benford_sampler: None,
124 fraud_generator: None,
125 benford_enabled: false,
126 }
127 }
128
129 pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
131 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
132 .expect("Invalid log-normal parameters");
133 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
134
135 Self {
136 rng: ChaCha8Rng::seed_from_u64(seed),
137 benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
138 fraud_generator: Some(FraudAmountGenerator::new(
139 seed + 200,
140 config.clone(),
141 ThresholdConfig::default(),
142 )),
143 config,
144 lognormal,
145 decimal_multiplier,
146 benford_enabled: true,
147 }
148 }
149
150 pub fn with_fraud_config(
152 seed: u64,
153 config: AmountDistributionConfig,
154 threshold_config: ThresholdConfig,
155 benford_enabled: bool,
156 ) -> Self {
157 let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
158 .expect("Invalid log-normal parameters");
159 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
160
161 Self {
162 rng: ChaCha8Rng::seed_from_u64(seed),
163 benford_sampler: if benford_enabled {
164 Some(BenfordSampler::new(seed + 100, config.clone()))
165 } else {
166 None
167 },
168 fraud_generator: Some(FraudAmountGenerator::new(
169 seed + 200,
170 config.clone(),
171 threshold_config,
172 )),
173 config,
174 lognormal,
175 decimal_multiplier,
176 benford_enabled,
177 }
178 }
179
180 pub fn set_benford_enabled(&mut self, enabled: bool) {
182 self.benford_enabled = enabled;
183 if enabled && self.benford_sampler.is_none() {
184 let seed = self.rng.gen();
186 self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
187 }
188 }
189
190 pub fn is_benford_enabled(&self) -> bool {
192 self.benford_enabled
193 }
194
195 pub fn sample(&mut self) -> Decimal {
200 if self.benford_enabled {
202 if let Some(ref mut benford) = self.benford_sampler {
203 return benford.sample();
204 }
205 }
206
207 self.sample_lognormal()
209 }
210
211 pub fn sample_lognormal(&mut self) -> Decimal {
213 let mut amount = self.lognormal.sample(&mut self.rng);
214
215 amount = amount.clamp(self.config.min_amount, self.config.max_amount);
217
218 let p: f64 = self.rng.gen();
220 if p < self.config.round_number_probability {
221 amount = (amount / 100.0).round() * 100.0;
223 } else if p < self.config.round_number_probability + self.config.nice_number_probability {
224 amount = (amount / 5.0).round() * 5.0;
226 }
227
228 amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
230
231 amount = amount.max(self.config.min_amount);
233
234 let amount_str = format!("{:.2}", amount);
236 amount_str.parse::<Decimal>().unwrap_or(Decimal::ONE)
237 }
238
239 pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
243 if let Some(ref mut fraud_gen) = self.fraud_generator {
244 fraud_gen.sample(pattern)
245 } else {
246 self.sample()
248 }
249 }
250
251 pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
255 use rust_decimal::prelude::ToPrimitive;
256
257 if count == 0 {
258 return Vec::new();
259 }
260 if count == 1 {
261 return vec![total];
262 }
263
264 let total_f64 = total.to_f64().unwrap_or(0.0);
265
266 let mut weights: Vec<f64> = (0..count)
268 .map(|_| self.rng.gen::<f64>().max(0.01))
269 .collect();
270 let sum: f64 = weights.iter().sum();
271 weights.iter_mut().for_each(|w| *w /= sum);
272
273 let mut amounts: Vec<Decimal> = weights
275 .iter()
276 .map(|w| {
277 let amount = total_f64 * w;
278 let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
279 let amount_str = format!("{:.2}", rounded);
281 amount_str.parse::<Decimal>().unwrap_or(Decimal::ZERO)
282 })
283 .collect();
284
285 let current_sum: Decimal = amounts.iter().copied().sum();
287 let diff = total - current_sum;
288 let last_idx = amounts.len() - 1;
289 amounts[last_idx] += diff;
290
291 if amounts[last_idx] < Decimal::ZERO {
293 let mut remaining = amounts[last_idx].abs();
294 amounts[last_idx] = Decimal::ZERO;
295
296 for amt in amounts.iter_mut().take(last_idx).rev() {
298 if remaining <= Decimal::ZERO {
299 break;
300 }
301 let take = remaining.min(*amt);
302 *amt -= take;
303 remaining -= take;
304 }
305
306 if remaining > Decimal::ZERO {
309 for amt in amounts.iter_mut() {
311 if *amt > Decimal::ZERO {
312 *amt -= remaining;
313 break;
314 }
315 }
316 }
317 }
318
319 amounts
320 }
321
322 pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
324 let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
325 let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
326
327 let range = max_f64 - min_f64;
328 let amount = min_f64 + self.rng.gen::<f64>() * range;
329
330 let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
331 Decimal::from_f64_retain(rounded).unwrap_or(min)
332 }
333
334 pub fn reset(&mut self, seed: u64) {
336 self.rng = ChaCha8Rng::seed_from_u64(seed);
337 }
338}
339
340pub struct ExchangeRateSampler {
342 rng: ChaCha8Rng,
343 base_rates: std::collections::HashMap<String, f64>,
345 volatility: f64,
347}
348
349impl ExchangeRateSampler {
350 pub fn new(seed: u64) -> Self {
352 let mut base_rates = std::collections::HashMap::new();
353 base_rates.insert("EUR".to_string(), 0.92);
355 base_rates.insert("GBP".to_string(), 0.79);
356 base_rates.insert("CHF".to_string(), 0.88);
357 base_rates.insert("JPY".to_string(), 149.0);
358 base_rates.insert("CNY".to_string(), 7.24);
359 base_rates.insert("CAD".to_string(), 1.36);
360 base_rates.insert("AUD".to_string(), 1.53);
361 base_rates.insert("INR".to_string(), 83.0);
362 base_rates.insert("USD".to_string(), 1.0);
363
364 Self {
365 rng: ChaCha8Rng::seed_from_u64(seed),
366 base_rates,
367 volatility: 0.005, }
369 }
370
371 pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
373 let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
374 let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
375
376 let base_rate = to_usd / from_usd;
378
379 let variation = 1.0 + (self.rng.gen::<f64>() - 0.5) * 2.0 * self.volatility;
381 let rate = base_rate * variation;
382
383 let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
385 Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Two samplers built from the same seed must produce the same sequence.
    #[test]
    fn test_amount_sampler_determinism() {
        let mut first = AmountSampler::new(42);
        let mut second = AmountSampler::new(42);

        for _ in 0..100 {
            let from_first = first.sample();
            let from_second = second.sample();
            assert_eq!(from_first, from_second);
        }
    }

    /// Every sampled amount stays inside the configured bounds.
    #[test]
    fn test_amount_sampler_range() {
        let mut sampler = AmountSampler::with_config(
            42,
            AmountDistributionConfig {
                min_amount: 100.0,
                max_amount: 1000.0,
                ..Default::default()
            },
        );

        for _ in 0..1000 {
            let amount = sampler.sample();
            let amount_f64: f64 = amount.to_string().parse().unwrap();
            assert!(amount_f64 >= 100.0, "Amount {} below minimum", amount);
            assert!(amount_f64 <= 1000.0, "Amount {} above maximum", amount);
        }
    }

    /// The parts returned by `sample_summing_to` add up to the requested total.
    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);

        assert_eq!(amounts.len(), 5);

        let sum: Decimal = amounts.iter().sum();
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    /// Rates stay near their base value; converting a currency to itself is ~1.0.
    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);

        let eur_f64: f64 = sampler
            .get_rate("EUR", "USD")
            .to_string()
            .parse()
            .unwrap();
        assert!(
            eur_f64 > 0.8 && eur_f64 < 1.2,
            "EUR/USD rate {} out of range",
            eur_f64
        );

        let usd_f64: f64 = sampler
            .get_rate("USD", "USD")
            .to_string()
            .parse()
            .unwrap();
        assert!(
            (usd_f64 - 1.0).abs() < 0.01,
            "USD/USD rate {} should be ~1.0",
            usd_f64
        );
    }
}