datasynth_core/distributions/
zero_inflated.rs1use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rand_distr::{Distribution, Exp, LogNormal, Poisson};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19#[derive(Default)]
20pub enum ZeroInflatedBaseDistribution {
21 #[default]
23 LogNormal,
24 Exponential,
26 Poisson,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ZeroInflatedConfig {
33 pub zero_probability: f64,
36 pub base_distribution: ZeroInflatedBaseDistribution,
38 #[serde(default = "default_mu")]
40 pub lognormal_mu: f64,
41 #[serde(default = "default_sigma")]
43 pub lognormal_sigma: f64,
44 #[serde(default = "default_lambda")]
46 pub exponential_lambda: f64,
47 #[serde(default = "default_poisson_lambda")]
49 pub poisson_lambda: f64,
50 #[serde(default = "default_min_value")]
52 pub min_value: f64,
53 #[serde(default)]
55 pub max_value: Option<f64>,
56 #[serde(default = "default_decimal_places")]
58 pub decimal_places: u8,
59}
60
/// Serde default for `ZeroInflatedConfig::lognormal_mu`.
fn default_mu() -> f64 {
    const DEFAULT_LOGNORMAL_MU: f64 = 6.0;
    DEFAULT_LOGNORMAL_MU
}
64
/// Serde default for `ZeroInflatedConfig::lognormal_sigma`.
fn default_sigma() -> f64 {
    const DEFAULT_LOGNORMAL_SIGMA: f64 = 1.5;
    DEFAULT_LOGNORMAL_SIGMA
}
68
/// Serde default for `ZeroInflatedConfig::exponential_lambda`.
fn default_lambda() -> f64 {
    const DEFAULT_EXPONENTIAL_LAMBDA: f64 = 0.01;
    DEFAULT_EXPONENTIAL_LAMBDA
}
72
/// Serde default for `ZeroInflatedConfig::poisson_lambda`.
fn default_poisson_lambda() -> f64 {
    const DEFAULT_POISSON_LAMBDA: f64 = 3.0;
    DEFAULT_POISSON_LAMBDA
}
76
/// Serde default for `ZeroInflatedConfig::min_value`.
fn default_min_value() -> f64 {
    const DEFAULT_MIN_VALUE: f64 = 0.01;
    DEFAULT_MIN_VALUE
}
80
/// Serde default for `ZeroInflatedConfig::decimal_places`.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
84
85impl Default for ZeroInflatedConfig {
86 fn default() -> Self {
87 Self {
88 zero_probability: 0.7, base_distribution: ZeroInflatedBaseDistribution::LogNormal,
90 lognormal_mu: 6.0,
91 lognormal_sigma: 1.5,
92 exponential_lambda: 0.01,
93 poisson_lambda: 3.0,
94 min_value: 0.01,
95 max_value: None,
96 decimal_places: 2,
97 }
98 }
99}
100
101impl ZeroInflatedConfig {
102 pub fn lognormal(zero_probability: f64, mu: f64, sigma: f64) -> Self {
104 Self {
105 zero_probability,
106 base_distribution: ZeroInflatedBaseDistribution::LogNormal,
107 lognormal_mu: mu,
108 lognormal_sigma: sigma,
109 ..Default::default()
110 }
111 }
112
113 pub fn exponential(zero_probability: f64, lambda: f64) -> Self {
115 Self {
116 zero_probability,
117 base_distribution: ZeroInflatedBaseDistribution::Exponential,
118 exponential_lambda: lambda,
119 ..Default::default()
120 }
121 }
122
123 pub fn poisson(zero_probability: f64, lambda: f64) -> Self {
125 Self {
126 zero_probability,
127 base_distribution: ZeroInflatedBaseDistribution::Poisson,
128 poisson_lambda: lambda,
129 decimal_places: 0, min_value: 0.0,
131 ..Default::default()
132 }
133 }
134
135 pub fn credit_memos() -> Self {
137 Self {
138 zero_probability: 0.85, base_distribution: ZeroInflatedBaseDistribution::LogNormal,
140 lognormal_mu: 5.5, lognormal_sigma: 1.2,
142 min_value: 10.0, max_value: Some(50_000.0),
144 decimal_places: 2,
145 ..Default::default()
146 }
147 }
148
149 pub fn warranty_claims() -> Self {
151 Self {
152 zero_probability: 0.95, base_distribution: ZeroInflatedBaseDistribution::LogNormal,
154 lognormal_mu: 6.0, lognormal_sigma: 1.5,
156 min_value: 25.0,
157 max_value: Some(10_000.0),
158 decimal_places: 2,
159 ..Default::default()
160 }
161 }
162
163 pub fn late_penalties() -> Self {
165 Self {
166 zero_probability: 0.80, base_distribution: ZeroInflatedBaseDistribution::LogNormal,
168 lognormal_mu: 4.0, lognormal_sigma: 1.0,
170 min_value: 5.0,
171 max_value: Some(5_000.0),
172 decimal_places: 2,
173 ..Default::default()
174 }
175 }
176
177 pub fn adjustment_count() -> Self {
179 Self {
180 zero_probability: 0.70, base_distribution: ZeroInflatedBaseDistribution::Poisson,
182 poisson_lambda: 2.0, min_value: 0.0,
184 max_value: Some(10.0),
185 decimal_places: 0,
186 ..Default::default()
187 }
188 }
189
190 pub fn return_processing_time() -> Self {
192 Self {
193 zero_probability: 0.90, base_distribution: ZeroInflatedBaseDistribution::Exponential,
195 exponential_lambda: 0.1, min_value: 1.0,
197 max_value: Some(60.0),
198 decimal_places: 0,
199 ..Default::default()
200 }
201 }
202
203 pub fn validate(&self) -> Result<(), String> {
205 if !(0.0..=1.0).contains(&self.zero_probability) {
206 return Err("zero_probability must be between 0.0 and 1.0".to_string());
207 }
208
209 match self.base_distribution {
210 ZeroInflatedBaseDistribution::LogNormal => {
211 if self.lognormal_sigma <= 0.0 {
212 return Err("lognormal_sigma must be positive".to_string());
213 }
214 }
215 ZeroInflatedBaseDistribution::Exponential => {
216 if self.exponential_lambda <= 0.0 {
217 return Err("exponential_lambda must be positive".to_string());
218 }
219 }
220 ZeroInflatedBaseDistribution::Poisson => {
221 if self.poisson_lambda <= 0.0 {
222 return Err("poisson_lambda must be positive".to_string());
223 }
224 }
225 }
226
227 if let Some(max) = self.max_value {
228 if max <= self.min_value {
229 return Err("max_value must be greater than min_value".to_string());
230 }
231 }
232
233 Ok(())
234 }
235
236 pub fn expected_value(&self) -> f64 {
238 let non_zero_prob = 1.0 - self.zero_probability;
239
240 let non_zero_mean = match self.base_distribution {
241 ZeroInflatedBaseDistribution::LogNormal => {
242 (self.lognormal_mu + self.lognormal_sigma.powi(2) / 2.0).exp()
243 }
244 ZeroInflatedBaseDistribution::Exponential => 1.0 / self.exponential_lambda,
245 ZeroInflatedBaseDistribution::Poisson => self.poisson_lambda,
246 };
247
248 non_zero_prob * non_zero_mean.max(self.min_value)
249 }
250
251 pub fn non_zero_probability(&self) -> f64 {
253 1.0 - self.zero_probability
254 }
255}
256
257enum BaseDistributionSampler {
259 LogNormal(LogNormal<f64>),
260 Exponential(Exp<f64>),
261 Poisson(Poisson<f64>),
262}
263
264pub struct ZeroInflatedSampler {
266 rng: ChaCha8Rng,
267 config: ZeroInflatedConfig,
268 base_sampler: BaseDistributionSampler,
269 decimal_multiplier: f64,
270}
271
272impl ZeroInflatedSampler {
273 pub fn new(seed: u64, config: ZeroInflatedConfig) -> Result<Self, String> {
275 config.validate()?;
276
277 let base_sampler = match config.base_distribution {
278 ZeroInflatedBaseDistribution::LogNormal => {
279 let dist = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
280 .map_err(|e| format!("Invalid LogNormal distribution: {}", e))?;
281 BaseDistributionSampler::LogNormal(dist)
282 }
283 ZeroInflatedBaseDistribution::Exponential => {
284 let dist = Exp::new(config.exponential_lambda)
285 .map_err(|e| format!("Invalid Exponential distribution: {}", e))?;
286 BaseDistributionSampler::Exponential(dist)
287 }
288 ZeroInflatedBaseDistribution::Poisson => {
289 let dist = Poisson::new(config.poisson_lambda)
290 .map_err(|e| format!("Invalid Poisson distribution: {}", e))?;
291 BaseDistributionSampler::Poisson(dist)
292 }
293 };
294
295 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
296
297 Ok(Self {
298 rng: ChaCha8Rng::seed_from_u64(seed),
299 config,
300 base_sampler,
301 decimal_multiplier,
302 })
303 }
304
305 pub fn sample(&mut self) -> f64 {
307 let p: f64 = self.rng.gen();
309 if p < self.config.zero_probability {
310 return 0.0;
311 }
312
313 let mut value = match &self.base_sampler {
315 BaseDistributionSampler::LogNormal(dist) => dist.sample(&mut self.rng),
316 BaseDistributionSampler::Exponential(dist) => dist.sample(&mut self.rng),
317 BaseDistributionSampler::Poisson(dist) => dist.sample(&mut self.rng),
318 };
319
320 value = value.max(self.config.min_value);
322 if let Some(max) = self.config.max_value {
323 value = value.min(max);
324 }
325
326 (value * self.decimal_multiplier).round() / self.decimal_multiplier
328 }
329
330 pub fn sample_decimal(&mut self) -> Decimal {
332 let value = self.sample();
333 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
334 }
335
336 pub fn sample_with_info(&mut self) -> ZeroInflatedSample {
338 let p: f64 = self.rng.gen();
339 if p < self.config.zero_probability {
340 return ZeroInflatedSample {
341 value: 0.0,
342 is_structural_zero: true,
343 };
344 }
345
346 let mut value = match &self.base_sampler {
347 BaseDistributionSampler::LogNormal(dist) => dist.sample(&mut self.rng),
348 BaseDistributionSampler::Exponential(dist) => dist.sample(&mut self.rng),
349 BaseDistributionSampler::Poisson(dist) => dist.sample(&mut self.rng),
350 };
351
352 value = value.max(self.config.min_value);
353 if let Some(max) = self.config.max_value {
354 value = value.min(max);
355 }
356 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
357
358 ZeroInflatedSample {
359 value,
360 is_structural_zero: false,
361 }
362 }
363
364 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
366 (0..n).map(|_| self.sample()).collect()
367 }
368
369 pub fn reset(&mut self, seed: u64) {
371 self.rng = ChaCha8Rng::seed_from_u64(seed);
372 }
373
374 pub fn config(&self) -> &ZeroInflatedConfig {
376 &self.config
377 }
378}
379
/// One draw from a zero-inflated sampler together with its origin.
///
/// The type is plain data (an `f64` and a `bool`), so it is `Copy` and can
/// be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ZeroInflatedSample {
    /// The sampled value.
    pub value: f64,
    /// `true` when the value is a zero produced by the zero-inflation step
    /// rather than by the base distribution.
    pub is_structural_zero: bool,
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Seed shared by every test so results are reproducible.
    const SEED: u64 = 42;

    /// Builds a sampler with the shared seed, panicking on invalid configs.
    fn seeded(config: ZeroInflatedConfig) -> ZeroInflatedSampler {
        ZeroInflatedSampler::new(SEED, config).unwrap()
    }

    /// Number of samples that are exactly zero.
    fn count_zeros(samples: &[f64]) -> usize {
        samples.iter().filter(|&&x| x == 0.0).count()
    }

    #[test]
    fn test_zero_inflated_validation() {
        // A sensible configuration is accepted.
        assert!(ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5)
            .validate()
            .is_ok());

        // A probability above 1 is rejected.
        assert!(ZeroInflatedConfig::lognormal(1.5, 6.0, 1.5)
            .validate()
            .is_err());

        // A negative sigma is rejected.
        assert!(ZeroInflatedConfig::lognormal(0.7, 6.0, -1.0)
            .validate()
            .is_err());
    }

    #[test]
    fn test_zero_inflated_sampling() {
        let mut sampler = seeded(ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5));

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // No sample may be negative.
        assert!(samples.iter().all(|&x| x >= 0.0));

        // About 70% of the samples should be zero.
        let zero_count = count_zeros(&samples);
        assert!((601..800).contains(&zero_count));
    }

    #[test]
    fn test_zero_inflated_determinism() {
        let config = ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5);

        let mut first = seeded(config.clone());
        let mut second = seeded(config);

        // Identical seeds must give identical sequences.
        for _ in 0..100 {
            assert_eq!(first.sample(), second.sample());
        }
    }

    #[test]
    fn test_zero_inflated_exponential() {
        let mut sampler = seeded(ZeroInflatedConfig::exponential(0.5, 0.1));

        let samples = sampler.sample_n(1000);

        // About half of the samples should be zero.
        let zero_count = count_zeros(&samples);
        assert!((401..600).contains(&zero_count));

        // Every non-zero sample respects the default minimum of 0.01.
        assert!(samples.iter().filter(|&&x| x > 0.0).all(|&x| x >= 0.01));
    }

    #[test]
    fn test_zero_inflated_poisson() {
        let mut sampler = seeded(ZeroInflatedConfig::poisson(0.6, 3.0));

        let samples = sampler.sample_n(1000);

        // At least the structural 60% of the samples should be zero.
        let zero_count = count_zeros(&samples);
        assert!((501..700).contains(&zero_count));

        // Non-zero samples must be whole numbers.
        for s in samples.iter().filter(|&&x| x > 0.0) {
            assert!((s - s.round()).abs() < 0.001);
        }
    }

    #[test]
    fn test_sample_with_info() {
        let mut sampler = seeded(ZeroInflatedConfig::lognormal(0.5, 6.0, 1.5));

        let draws: Vec<ZeroInflatedSample> =
            (0..1000).map(|_| sampler.sample_with_info()).collect();
        let (structural, regular): (Vec<_>, Vec<_>) =
            draws.iter().partition(|d| d.is_structural_zero);

        // Structural zeros always carry the value 0.
        assert!(structural.iter().all(|d| d.value == 0.0));

        // Both kinds should occur roughly half of the time.
        assert!((401..600).contains(&structural.len()));
        assert!((401..600).contains(&regular.len()));
    }

    #[test]
    fn test_credit_memos_preset() {
        let config = ZeroInflatedConfig::credit_memos();
        assert!(config.validate().is_ok());

        let floor = config.min_value;
        let mut sampler = seeded(config);
        let samples = sampler.sample_n(1000);

        // The preset uses 85% zeros, so well over 750 are expected.
        assert!(count_zeros(&samples) > 750);

        // Non-zero amounts never fall below the configured minimum.
        assert!(samples.iter().filter(|&&x| x > 0.0).all(|&x| x >= floor));
    }

    #[test]
    fn test_expected_value() {
        let expected = ZeroInflatedConfig::lognormal(0.5, 6.0, 1.5).expected_value();

        // 0.5 * exp(6 + 1.5^2 / 2) is roughly 621.
        assert!(expected > 500.0 && expected < 800.0);
    }

    #[test]
    fn test_max_value_constraint() {
        let config = ZeroInflatedConfig {
            max_value: Some(1000.0),
            ..ZeroInflatedConfig::lognormal(0.3, 8.0, 2.0)
        };

        let mut sampler = seeded(config);
        let samples = sampler.sample_n(1000);

        // The cap must hold for every sample.
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}