Skip to main content

datasynth_core/distributions/
zero_inflated.rs

//! Zero-inflated distributions for data with excess zeros.
//!
//! Zero-inflated distributions model scenarios where zeros occur more
//! frequently than a standard distribution would predict, such as:
//! - Credit memos and returns (most transactions have no credits)
//! - Warranty claims (most products have no claims)
//! - Late payment penalties (most payments have no penalties)
//! - Adjustment entries (most periods have no adjustments)
//!
//! Each draw is a structural zero with a configured probability; otherwise it is
//! drawn from a base distribution (log-normal, exponential or Poisson).

10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rand_distr::{Distribution, Exp, LogNormal, Poisson};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15
16/// Type of base distribution for the non-zero values.
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19#[derive(Default)]
20pub enum ZeroInflatedBaseDistribution {
21    /// Log-normal distribution (positive amounts)
22    #[default]
23    LogNormal,
24    /// Exponential distribution (time-based or decay patterns)
25    Exponential,
26    /// Poisson distribution (count data)
27    Poisson,
28}
29
30/// Configuration for zero-inflated distribution.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ZeroInflatedConfig {
33    /// Probability of a structural zero (0.0-1.0).
34    /// Higher values = more zeros.
35    pub zero_probability: f64,
36    /// Type of base distribution for non-zero values.
37    pub base_distribution: ZeroInflatedBaseDistribution,
38    /// Mu parameter for log-normal base distribution.
39    #[serde(default = "default_mu")]
40    pub lognormal_mu: f64,
41    /// Sigma parameter for log-normal base distribution.
42    #[serde(default = "default_sigma")]
43    pub lognormal_sigma: f64,
44    /// Lambda parameter for exponential base distribution.
45    #[serde(default = "default_lambda")]
46    pub exponential_lambda: f64,
47    /// Lambda parameter for Poisson base distribution.
48    #[serde(default = "default_poisson_lambda")]
49    pub poisson_lambda: f64,
50    /// Minimum non-zero value.
51    #[serde(default = "default_min_value")]
52    pub min_value: f64,
53    /// Maximum value (clamps output).
54    #[serde(default)]
55    pub max_value: Option<f64>,
56    /// Number of decimal places for rounding.
57    #[serde(default = "default_decimal_places")]
58    pub decimal_places: u8,
59}
60
61fn default_mu() -> f64 {
62    6.0
63}
64
65fn default_sigma() -> f64 {
66    1.5
67}
68
69fn default_lambda() -> f64 {
70    0.01
71}
72
73fn default_poisson_lambda() -> f64 {
74    3.0
75}
76
77fn default_min_value() -> f64 {
78    0.01
79}
80
81fn default_decimal_places() -> u8 {
82    2
83}
84
85impl Default for ZeroInflatedConfig {
86    fn default() -> Self {
87        Self {
88            zero_probability: 0.7, // 70% zeros
89            base_distribution: ZeroInflatedBaseDistribution::LogNormal,
90            lognormal_mu: 6.0,
91            lognormal_sigma: 1.5,
92            exponential_lambda: 0.01,
93            poisson_lambda: 3.0,
94            min_value: 0.01,
95            max_value: None,
96            decimal_places: 2,
97        }
98    }
99}
100
101impl ZeroInflatedConfig {
102    /// Create a new zero-inflated configuration with log-normal base.
103    pub fn lognormal(zero_probability: f64, mu: f64, sigma: f64) -> Self {
104        Self {
105            zero_probability,
106            base_distribution: ZeroInflatedBaseDistribution::LogNormal,
107            lognormal_mu: mu,
108            lognormal_sigma: sigma,
109            ..Default::default()
110        }
111    }
112
113    /// Create a new zero-inflated configuration with exponential base.
114    pub fn exponential(zero_probability: f64, lambda: f64) -> Self {
115        Self {
116            zero_probability,
117            base_distribution: ZeroInflatedBaseDistribution::Exponential,
118            exponential_lambda: lambda,
119            ..Default::default()
120        }
121    }
122
123    /// Create a new zero-inflated configuration with Poisson base.
124    pub fn poisson(zero_probability: f64, lambda: f64) -> Self {
125        Self {
126            zero_probability,
127            base_distribution: ZeroInflatedBaseDistribution::Poisson,
128            poisson_lambda: lambda,
129            decimal_places: 0, // Poisson is discrete
130            min_value: 0.0,
131            ..Default::default()
132        }
133    }
134
135    /// Create a configuration for credit memos/returns.
136    pub fn credit_memos() -> Self {
137        Self {
138            zero_probability: 0.85, // 85% have no credits
139            base_distribution: ZeroInflatedBaseDistribution::LogNormal,
140            lognormal_mu: 5.5, // ~$245 median credit
141            lognormal_sigma: 1.2,
142            min_value: 10.0, // Minimum $10 credit
143            max_value: Some(50_000.0),
144            decimal_places: 2,
145            ..Default::default()
146        }
147    }
148
149    /// Create a configuration for warranty claims.
150    pub fn warranty_claims() -> Self {
151        Self {
152            zero_probability: 0.95, // 95% have no claims
153            base_distribution: ZeroInflatedBaseDistribution::LogNormal,
154            lognormal_mu: 6.0, // ~$403 median claim
155            lognormal_sigma: 1.5,
156            min_value: 25.0,
157            max_value: Some(10_000.0),
158            decimal_places: 2,
159            ..Default::default()
160        }
161    }
162
163    /// Create a configuration for late payment penalties.
164    pub fn late_penalties() -> Self {
165        Self {
166            zero_probability: 0.80, // 80% pay on time
167            base_distribution: ZeroInflatedBaseDistribution::LogNormal,
168            lognormal_mu: 4.0, // ~$55 median penalty
169            lognormal_sigma: 1.0,
170            min_value: 5.0,
171            max_value: Some(5_000.0),
172            decimal_places: 2,
173            ..Default::default()
174        }
175    }
176
177    /// Create a configuration for adjustment entries (count-based).
178    pub fn adjustment_count() -> Self {
179        Self {
180            zero_probability: 0.70, // 70% have no adjustments
181            base_distribution: ZeroInflatedBaseDistribution::Poisson,
182            poisson_lambda: 2.0, // Average 2 adjustments when they occur
183            min_value: 0.0,
184            max_value: Some(10.0),
185            decimal_places: 0,
186            ..Default::default()
187        }
188    }
189
190    /// Create a configuration for returns processing time.
191    pub fn return_processing_time() -> Self {
192        Self {
193            zero_probability: 0.90, // 90% have no returns
194            base_distribution: ZeroInflatedBaseDistribution::Exponential,
195            exponential_lambda: 0.1, // Average 10 days processing
196            min_value: 1.0,
197            max_value: Some(60.0),
198            decimal_places: 0,
199            ..Default::default()
200        }
201    }
202
203    /// Validate the configuration.
204    pub fn validate(&self) -> Result<(), String> {
205        if !(0.0..=1.0).contains(&self.zero_probability) {
206            return Err("zero_probability must be between 0.0 and 1.0".to_string());
207        }
208
209        match self.base_distribution {
210            ZeroInflatedBaseDistribution::LogNormal => {
211                if self.lognormal_sigma <= 0.0 {
212                    return Err("lognormal_sigma must be positive".to_string());
213                }
214            }
215            ZeroInflatedBaseDistribution::Exponential => {
216                if self.exponential_lambda <= 0.0 {
217                    return Err("exponential_lambda must be positive".to_string());
218                }
219            }
220            ZeroInflatedBaseDistribution::Poisson => {
221                if self.poisson_lambda <= 0.0 {
222                    return Err("poisson_lambda must be positive".to_string());
223                }
224            }
225        }
226
227        if let Some(max) = self.max_value {
228            if max <= self.min_value {
229                return Err("max_value must be greater than min_value".to_string());
230            }
231        }
232
233        Ok(())
234    }
235
236    /// Get the expected value (mean) including zeros.
237    pub fn expected_value(&self) -> f64 {
238        let non_zero_prob = 1.0 - self.zero_probability;
239
240        let non_zero_mean = match self.base_distribution {
241            ZeroInflatedBaseDistribution::LogNormal => {
242                (self.lognormal_mu + self.lognormal_sigma.powi(2) / 2.0).exp()
243            }
244            ZeroInflatedBaseDistribution::Exponential => 1.0 / self.exponential_lambda,
245            ZeroInflatedBaseDistribution::Poisson => self.poisson_lambda,
246        };
247
248        non_zero_prob * non_zero_mean.max(self.min_value)
249    }
250
251    /// Get the probability of non-zero value.
252    pub fn non_zero_probability(&self) -> f64 {
253        1.0 - self.zero_probability
254    }
255}
256
257/// Internal enum for holding the base distribution sampler.
258enum BaseDistributionSampler {
259    LogNormal(LogNormal<f64>),
260    Exponential(Exp<f64>),
261    Poisson(Poisson<f64>),
262}
263
264/// Zero-inflated distribution sampler.
265pub struct ZeroInflatedSampler {
266    rng: ChaCha8Rng,
267    config: ZeroInflatedConfig,
268    base_sampler: BaseDistributionSampler,
269    decimal_multiplier: f64,
270}
271
272impl ZeroInflatedSampler {
273    /// Create a new zero-inflated sampler.
274    pub fn new(seed: u64, config: ZeroInflatedConfig) -> Result<Self, String> {
275        config.validate()?;
276
277        let base_sampler = match config.base_distribution {
278            ZeroInflatedBaseDistribution::LogNormal => {
279                let dist = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
280                    .map_err(|e| format!("Invalid LogNormal distribution: {}", e))?;
281                BaseDistributionSampler::LogNormal(dist)
282            }
283            ZeroInflatedBaseDistribution::Exponential => {
284                let dist = Exp::new(config.exponential_lambda)
285                    .map_err(|e| format!("Invalid Exponential distribution: {}", e))?;
286                BaseDistributionSampler::Exponential(dist)
287            }
288            ZeroInflatedBaseDistribution::Poisson => {
289                let dist = Poisson::new(config.poisson_lambda)
290                    .map_err(|e| format!("Invalid Poisson distribution: {}", e))?;
291                BaseDistributionSampler::Poisson(dist)
292            }
293        };
294
295        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
296
297        Ok(Self {
298            rng: ChaCha8Rng::seed_from_u64(seed),
299            config,
300            base_sampler,
301            decimal_multiplier,
302        })
303    }
304
305    /// Sample a value from the distribution.
306    pub fn sample(&mut self) -> f64 {
307        // First, determine if this is a structural zero
308        let p: f64 = self.rng.gen();
309        if p < self.config.zero_probability {
310            return 0.0;
311        }
312
313        // Sample from base distribution
314        let mut value = match &self.base_sampler {
315            BaseDistributionSampler::LogNormal(dist) => dist.sample(&mut self.rng),
316            BaseDistributionSampler::Exponential(dist) => dist.sample(&mut self.rng),
317            BaseDistributionSampler::Poisson(dist) => dist.sample(&mut self.rng),
318        };
319
320        // Apply constraints
321        value = value.max(self.config.min_value);
322        if let Some(max) = self.config.max_value {
323            value = value.min(max);
324        }
325
326        // Round to decimal places
327        (value * self.decimal_multiplier).round() / self.decimal_multiplier
328    }
329
330    /// Sample a value as Decimal.
331    pub fn sample_decimal(&mut self) -> Decimal {
332        let value = self.sample();
333        Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
334    }
335
336    /// Sample with information about whether it's a structural zero.
337    pub fn sample_with_info(&mut self) -> ZeroInflatedSample {
338        let p: f64 = self.rng.gen();
339        if p < self.config.zero_probability {
340            return ZeroInflatedSample {
341                value: 0.0,
342                is_structural_zero: true,
343            };
344        }
345
346        let mut value = match &self.base_sampler {
347            BaseDistributionSampler::LogNormal(dist) => dist.sample(&mut self.rng),
348            BaseDistributionSampler::Exponential(dist) => dist.sample(&mut self.rng),
349            BaseDistributionSampler::Poisson(dist) => dist.sample(&mut self.rng),
350        };
351
352        value = value.max(self.config.min_value);
353        if let Some(max) = self.config.max_value {
354            value = value.min(max);
355        }
356        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
357
358        ZeroInflatedSample {
359            value,
360            is_structural_zero: false,
361        }
362    }
363
364    /// Sample multiple values.
365    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
366        (0..n).map(|_| self.sample()).collect()
367    }
368
369    /// Reset the sampler with a new seed.
370    pub fn reset(&mut self, seed: u64) {
371        self.rng = ChaCha8Rng::seed_from_u64(seed);
372    }
373
374    /// Get the configuration.
375    pub fn config(&self) -> &ZeroInflatedConfig {
376        &self.config
377    }
378}
379
/// Result of a draw that also reports whether it was a structural zero.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ZeroInflatedSample {
    /// The sampled value (`0.0` for a structural zero).
    pub value: f64,
    /// `true` when the zero-inflation step produced this draw.
    ///
    /// A draw with `false` can still have `value == 0.0`; this is a sampling zero, e.g. a
    /// Poisson draw of 0 with `min_value: 0.0`.
    pub is_structural_zero: bool,
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Seed shared by every test so results are reproducible.
    const SEED: u64 = 42;

    /// Build a sampler with the shared seed, panicking on an invalid configuration.
    fn seeded(config: ZeroInflatedConfig) -> ZeroInflatedSampler {
        ZeroInflatedSampler::new(SEED, config).unwrap()
    }

    /// Number of exact zeros in `samples`.
    fn zeros(samples: &[f64]) -> usize {
        samples.iter().filter(|&&x| x == 0.0).count()
    }

    /// Assert `low < count < high` (both bounds exclusive).
    fn assert_strictly_between(count: usize, low: usize, high: usize) {
        assert!(
            low < count && count < high,
            "expected {} < count < {}, got {}",
            low,
            high,
            count
        );
    }

    #[test]
    fn test_zero_inflated_validation() {
        assert!(ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5).validate().is_ok());
        // Probability outside [0, 1].
        assert!(ZeroInflatedConfig::lognormal(1.5, 6.0, 1.5).validate().is_err());
        // Non-positive sigma.
        assert!(ZeroInflatedConfig::lognormal(0.7, 6.0, -1.0).validate().is_err());
    }

    #[test]
    fn test_zero_inflated_sampling() {
        let samples = seeded(ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5)).sample_n(1000);
        assert_eq!(samples.len(), 1000);

        assert!(samples.iter().all(|&x| x >= 0.0), "samples must be non-negative");

        // Roughly 70% structural zeros.
        assert_strictly_between(zeros(&samples), 600, 800);
    }

    #[test]
    fn test_zero_inflated_determinism() {
        let config = ZeroInflatedConfig::lognormal(0.7, 6.0, 1.5);
        let mut first = seeded(config.clone());
        let mut second = seeded(config);

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_zero_inflated_exponential() {
        let samples = seeded(ZeroInflatedConfig::exponential(0.5, 0.1)).sample_n(1000);

        // Roughly 50% structural zeros.
        assert_strictly_between(zeros(&samples), 400, 600);

        // Positive draws respect the default min_value floor of 0.01.
        assert!(samples.iter().filter(|&&x| x > 0.0).all(|&x| x >= 0.01));
    }

    #[test]
    fn test_zero_inflated_poisson() {
        let samples = seeded(ZeroInflatedConfig::poisson(0.6, 3.0)).sample_n(1000);

        // Roughly 60% structural zeros.
        assert_strictly_between(zeros(&samples), 500, 700);

        // decimal_places = 0, so positive draws are whole numbers.
        assert!(samples
            .iter()
            .filter(|&&x| x > 0.0)
            .all(|x| (x - x.round()).abs() < 0.001));
    }

    #[test]
    fn test_sample_with_info() {
        let mut sampler = seeded(ZeroInflatedConfig::lognormal(0.5, 6.0, 1.5));
        let draws: Vec<ZeroInflatedSample> =
            (0..1000).map(|_| sampler.sample_with_info()).collect();

        // Every structural zero must carry the value 0.0.
        assert!(draws
            .iter()
            .filter(|d| d.is_structural_zero)
            .all(|d| d.value == 0.0));

        let structural_zeros = draws.iter().filter(|d| d.is_structural_zero).count();
        let non_structural = draws.len() - structural_zeros;

        // Roughly a 50/50 split.
        assert_strictly_between(structural_zeros, 400, 600);
        assert_strictly_between(non_structural, 400, 600);
    }

    #[test]
    fn test_credit_memos_preset() {
        let config = ZeroInflatedConfig::credit_memos();
        assert!(config.validate().is_ok());

        let floor = config.min_value;
        let samples = seeded(config).sample_n(1000);

        // High zero rate (~85%).
        assert!(zeros(&samples) > 750, "expected roughly 85% zeros");

        // Positive draws respect the preset's min_value.
        assert!(samples.iter().filter(|&&x| x > 0.0).all(|&x| x >= floor));
    }

    #[test]
    fn test_expected_value() {
        // E[X] = (1 - p) * exp(mu + sigma^2 / 2) = 0.5 * exp(7.125) ≈ 621
        let expected = ZeroInflatedConfig::lognormal(0.5, 6.0, 1.5).expected_value();
        assert!(expected > 500.0 && expected < 800.0);
    }

    #[test]
    fn test_max_value_constraint() {
        let config = ZeroInflatedConfig {
            max_value: Some(1000.0),
            ..ZeroInflatedConfig::lognormal(0.3, 8.0, 2.0)
        };

        let samples = seeded(config).sample_n(1000);

        // Every draw is capped at max_value.
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}