Skip to main content

datasynth_core/distributions/
mixture.rs

//! Mixture model distributions for multi-modal data generation.
//!
//! Provides Gaussian and Log-Normal mixture models that can generate
//! realistic multi-modal distributions commonly observed in accounting data
//! (e.g., routine vs. significant vs. major transactions).
6
7use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13/// Configuration for a single Gaussian component in a mixture.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
17    pub weight: f64,
18    /// Mean (mu) of the Gaussian distribution
19    pub mu: f64,
20    /// Standard deviation (sigma) of the Gaussian distribution
21    pub sigma: f64,
22    /// Optional label for this component (e.g., "routine", "significant")
23    #[serde(default)]
24    pub label: Option<String>,
25}
26
27impl GaussianComponent {
28    /// Create a new Gaussian component.
29    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30        Self {
31            weight,
32            mu,
33            sigma,
34            label: None,
35        }
36    }
37
38    /// Create a labeled Gaussian component.
39    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40        Self {
41            weight,
42            mu,
43            sigma,
44            label: Some(label.into()),
45        }
46    }
47}
48
49/// Configuration for a Gaussian Mixture Model.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52    /// Components of the mixture
53    pub components: Vec<GaussianComponent>,
54    /// Whether to allow negative values (default: true)
55    #[serde(default = "default_true")]
56    pub allow_negative: bool,
57    /// Minimum value (if not allowing negative)
58    #[serde(default)]
59    pub min_value: Option<f64>,
60    /// Maximum value (clamps output)
61    #[serde(default)]
62    pub max_value: Option<f64>,
63}
64
/// Serde default helper that yields `true`; used for `allow_negative`.
fn default_true() -> bool {
    true
}
68
69impl Default for GaussianMixtureConfig {
70    fn default() -> Self {
71        Self {
72            components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73            allow_negative: true,
74            min_value: None,
75            max_value: None,
76        }
77    }
78}
79
80impl GaussianMixtureConfig {
81    /// Create a new Gaussian mixture configuration.
82    pub fn new(components: Vec<GaussianComponent>) -> Self {
83        Self {
84            components,
85            ..Default::default()
86        }
87    }
88
89    /// Validate the configuration.
90    pub fn validate(&self) -> Result<(), String> {
91        if self.components.is_empty() {
92            return Err("At least one component is required".to_string());
93        }
94
95        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96        if (weight_sum - 1.0).abs() > 0.01 {
97            return Err(format!(
98                "Component weights must sum to 1.0, got {}",
99                weight_sum
100            ));
101        }
102
103        for (i, component) in self.components.iter().enumerate() {
104            if component.weight < 0.0 || component.weight > 1.0 {
105                return Err(format!(
106                    "Component {} weight must be between 0.0 and 1.0, got {}",
107                    i, component.weight
108                ));
109            }
110            if component.sigma <= 0.0 {
111                return Err(format!(
112                    "Component {} sigma must be positive, got {}",
113                    i, component.sigma
114                ));
115            }
116        }
117
118        Ok(())
119    }
120}
121
122/// Configuration for a single Log-Normal component in a mixture.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct LogNormalComponent {
125    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
126    pub weight: f64,
127    /// Mu parameter (location) of the log-normal distribution
128    pub mu: f64,
129    /// Sigma parameter (scale) of the log-normal distribution
130    pub sigma: f64,
131    /// Optional label for this component
132    #[serde(default)]
133    pub label: Option<String>,
134}
135
136impl LogNormalComponent {
137    /// Create a new Log-Normal component.
138    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
139        Self {
140            weight,
141            mu,
142            sigma,
143            label: None,
144        }
145    }
146
147    /// Create a labeled Log-Normal component.
148    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
149        Self {
150            weight,
151            mu,
152            sigma,
153            label: Some(label.into()),
154        }
155    }
156
157    /// Get the expected value (mean) of this component.
158    pub fn expected_value(&self) -> f64 {
159        (self.mu + self.sigma.powi(2) / 2.0).exp()
160    }
161
162    /// Get the median of this component.
163    pub fn median(&self) -> f64 {
164        self.mu.exp()
165    }
166}
167
168/// Configuration for a Log-Normal Mixture Model.
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct LogNormalMixtureConfig {
171    /// Components of the mixture
172    pub components: Vec<LogNormalComponent>,
173    /// Minimum value (default: 0.01)
174    #[serde(default = "default_min_value")]
175    pub min_value: f64,
176    /// Maximum value (clamps output)
177    #[serde(default)]
178    pub max_value: Option<f64>,
179    /// Number of decimal places for rounding
180    #[serde(default = "default_decimal_places")]
181    pub decimal_places: u8,
182}
183
/// Serde default for `LogNormalMixtureConfig::min_value`.
fn default_min_value() -> f64 {
    0.01
}
187
/// Serde default for `LogNormalMixtureConfig::decimal_places`.
fn default_decimal_places() -> u8 {
    2
}
191
192impl Default for LogNormalMixtureConfig {
193    fn default() -> Self {
194        Self {
195            components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
196            min_value: 0.01,
197            max_value: None,
198            decimal_places: 2,
199        }
200    }
201}
202
203impl LogNormalMixtureConfig {
204    /// Create a new Log-Normal mixture configuration.
205    pub fn new(components: Vec<LogNormalComponent>) -> Self {
206        Self {
207            components,
208            ..Default::default()
209        }
210    }
211
212    /// Create a typical transaction amount mixture (routine/significant/major).
213    pub fn typical_transactions() -> Self {
214        Self {
215            components: vec![
216                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
217                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
218                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
219            ],
220            min_value: 0.01,
221            max_value: Some(100_000_000.0),
222            decimal_places: 2,
223        }
224    }
225
226    /// Validate the configuration.
227    pub fn validate(&self) -> Result<(), String> {
228        if self.components.is_empty() {
229            return Err("At least one component is required".to_string());
230        }
231
232        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
233        if (weight_sum - 1.0).abs() > 0.01 {
234            return Err(format!(
235                "Component weights must sum to 1.0, got {}",
236                weight_sum
237            ));
238        }
239
240        for (i, component) in self.components.iter().enumerate() {
241            if component.weight < 0.0 || component.weight > 1.0 {
242                return Err(format!(
243                    "Component {} weight must be between 0.0 and 1.0, got {}",
244                    i, component.weight
245                ));
246            }
247            if component.sigma <= 0.0 {
248                return Err(format!(
249                    "Component {} sigma must be positive, got {}",
250                    i, component.sigma
251                ));
252            }
253        }
254
255        if self.min_value < 0.0 {
256            return Err("min_value must be non-negative".to_string());
257        }
258
259        Ok(())
260    }
261}
262
/// A sampled value together with the mixture component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value, after the sampler's constraints were applied.
    pub value: f64,
    /// Position of the generating component in the config's `components` list.
    pub component_index: usize,
    /// Copy of the generating component's label, if it has one.
    pub component_label: Option<String>,
}
273
274/// Gaussian Mixture Model sampler.
275pub struct GaussianMixtureSampler {
276    rng: ChaCha8Rng,
277    config: GaussianMixtureConfig,
278    /// Pre-computed cumulative weights for O(log n) component selection
279    cumulative_weights: Vec<f64>,
280    /// Normal distributions for each component
281    distributions: Vec<Normal<f64>>,
282}
283
284impl GaussianMixtureSampler {
285    /// Create a new Gaussian mixture sampler.
286    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
287        config.validate()?;
288
289        // Pre-compute cumulative weights
290        let mut cumulative_weights = Vec::with_capacity(config.components.len());
291        let mut cumulative = 0.0;
292        for component in &config.components {
293            cumulative += component.weight;
294            cumulative_weights.push(cumulative);
295        }
296
297        // Create distributions
298        let distributions: Result<Vec<_>, _> = config
299            .components
300            .iter()
301            .map(|c| {
302                Normal::new(c.mu, c.sigma)
303                    .map_err(|e| format!("Invalid normal distribution: {}", e))
304            })
305            .collect();
306
307        Ok(Self {
308            rng: ChaCha8Rng::seed_from_u64(seed),
309            config,
310            cumulative_weights,
311            distributions: distributions?,
312        })
313    }
314
315    /// Select a component using binary search on cumulative weights.
316    fn select_component(&mut self) -> usize {
317        let p: f64 = self.rng.random();
318        match self.cumulative_weights.binary_search_by(|w| {
319            w.partial_cmp(&p).unwrap_or_else(|| {
320                tracing::debug!("NaN detected in mixture weight comparison");
321                std::cmp::Ordering::Less
322            })
323        }) {
324            Ok(i) => i,
325            Err(i) => i.min(self.distributions.len() - 1),
326        }
327    }
328
329    /// Sample a value from the mixture.
330    pub fn sample(&mut self) -> f64 {
331        let component_idx = self.select_component();
332        let mut value = self.distributions[component_idx].sample(&mut self.rng);
333
334        // Apply constraints
335        if !self.config.allow_negative {
336            value = value.abs();
337        }
338        if let Some(min) = self.config.min_value {
339            value = value.max(min);
340        }
341        if let Some(max) = self.config.max_value {
342            value = value.min(max);
343        }
344
345        value
346    }
347
348    /// Sample a value with component information.
349    pub fn sample_with_component(&mut self) -> SampleWithComponent {
350        let component_idx = self.select_component();
351        let mut value = self.distributions[component_idx].sample(&mut self.rng);
352
353        // Apply constraints
354        if !self.config.allow_negative {
355            value = value.abs();
356        }
357        if let Some(min) = self.config.min_value {
358            value = value.max(min);
359        }
360        if let Some(max) = self.config.max_value {
361            value = value.min(max);
362        }
363
364        SampleWithComponent {
365            value,
366            component_index: component_idx,
367            component_label: self.config.components[component_idx].label.clone(),
368        }
369    }
370
371    /// Sample multiple values.
372    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
373        (0..n).map(|_| self.sample()).collect()
374    }
375
376    /// Reset the sampler with a new seed.
377    pub fn reset(&mut self, seed: u64) {
378        self.rng = ChaCha8Rng::seed_from_u64(seed);
379    }
380
381    /// Get the configuration.
382    pub fn config(&self) -> &GaussianMixtureConfig {
383        &self.config
384    }
385}
386
387/// Log-Normal Mixture Model sampler for positive-only distributions.
388pub struct LogNormalMixtureSampler {
389    rng: ChaCha8Rng,
390    config: LogNormalMixtureConfig,
391    /// Pre-computed cumulative weights for O(log n) component selection
392    cumulative_weights: Vec<f64>,
393    /// Log-normal distributions for each component
394    distributions: Vec<LogNormal<f64>>,
395    /// Decimal multiplier for rounding
396    decimal_multiplier: f64,
397}
398
399impl LogNormalMixtureSampler {
400    /// Create a new Log-Normal mixture sampler.
401    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
402        config.validate()?;
403
404        // Pre-compute cumulative weights
405        let mut cumulative_weights = Vec::with_capacity(config.components.len());
406        let mut cumulative = 0.0;
407        for component in &config.components {
408            cumulative += component.weight;
409            cumulative_weights.push(cumulative);
410        }
411
412        // Create distributions
413        let distributions: Result<Vec<_>, _> = config
414            .components
415            .iter()
416            .map(|c| {
417                LogNormal::new(c.mu, c.sigma)
418                    .map_err(|e| format!("Invalid log-normal distribution: {}", e))
419            })
420            .collect();
421
422        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
423
424        Ok(Self {
425            rng: ChaCha8Rng::seed_from_u64(seed),
426            config,
427            cumulative_weights,
428            distributions: distributions?,
429            decimal_multiplier,
430        })
431    }
432
433    /// Select a component using binary search on cumulative weights.
434    fn select_component(&mut self) -> usize {
435        let p: f64 = self.rng.random();
436        match self.cumulative_weights.binary_search_by(|w| {
437            w.partial_cmp(&p).unwrap_or_else(|| {
438                tracing::debug!("NaN detected in mixture weight comparison");
439                std::cmp::Ordering::Less
440            })
441        }) {
442            Ok(i) => i,
443            Err(i) => i.min(self.distributions.len() - 1),
444        }
445    }
446
447    /// Sample a value from the mixture.
448    pub fn sample(&mut self) -> f64 {
449        let component_idx = self.select_component();
450        let mut value = self.distributions[component_idx].sample(&mut self.rng);
451
452        // Apply constraints
453        value = value.max(self.config.min_value);
454        if let Some(max) = self.config.max_value {
455            value = value.min(max);
456        }
457
458        // Round to decimal places
459        (value * self.decimal_multiplier).round() / self.decimal_multiplier
460    }
461
462    /// Sample a value as Decimal.
463    pub fn sample_decimal(&mut self) -> Decimal {
464        let value = self.sample();
465        Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
466    }
467
468    /// Sample a value with component information.
469    pub fn sample_with_component(&mut self) -> SampleWithComponent {
470        let component_idx = self.select_component();
471        let mut value = self.distributions[component_idx].sample(&mut self.rng);
472
473        // Apply constraints
474        value = value.max(self.config.min_value);
475        if let Some(max) = self.config.max_value {
476            value = value.min(max);
477        }
478
479        // Round to decimal places
480        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
481
482        SampleWithComponent {
483            value,
484            component_index: component_idx,
485            component_label: self.config.components[component_idx].label.clone(),
486        }
487    }
488
489    /// Sample multiple values.
490    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
491        (0..n).map(|_| self.sample()).collect()
492    }
493
494    /// Sample multiple values as Decimals.
495    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
496        (0..n).map(|_| self.sample_decimal()).collect()
497    }
498
499    /// Reset the sampler with a new seed.
500    pub fn reset(&mut self, seed: u64) {
501        self.rng = ChaCha8Rng::seed_from_u64(seed);
502    }
503
504    /// Get the configuration.
505    pub fn config(&self) -> &LogNormalMixtureConfig {
506        &self.config
507    }
508
509    /// Get the expected value of the mixture.
510    pub fn expected_value(&self) -> f64 {
511        self.config
512            .components
513            .iter()
514            .map(|c| c.weight * c.expected_value())
515            .sum()
516    }
517}
518
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two equally weighted unit-variance Gaussians centred at 0 and 10.
    fn bimodal_gaussian() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        // Weights sum to 1.0 and both sigmas are positive: accepted.
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights only sum to 0.6: rejected.
        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        // Negative sigma: rejected.
        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, bimodal_gaussian()).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // Split the samples at the midpoint between the two means.
        let (mut low_count, mut high_count) = (0_usize, 0_usize);
        for &x in &samples {
            if x < 5.0 {
                low_count += 1;
            }
            if x >= 5.0 {
                high_count += 1;
            }
        }

        // Each side should hold roughly half of the samples.
        assert!((351..650).contains(&low_count));
        assert!((351..650).contains(&high_count));
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = bimodal_gaussian();

        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();

        // Same seed and config must yield the same stream, draw by draw.
        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        // Weights sum to 1.0: accepted.
        let balanced = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights only sum to 0.4: rejected.
        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // Every sample is strictly positive.
        assert!(!samples.iter().any(|&x| x <= 0.0));

        // Every sample honours the configured minimum.
        assert!(samples.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small_count, mut large_count) = (0, 0);
        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            match drawn.component_label.as_deref() {
                Some("small") => small_count += 1,
                Some("large") => large_count += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Equal weights, so each label should appear about half the time.
        assert!((401..600).contains(&small_count));
        assert!((401..600).contains(&large_count));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();

        // Same seed and config must yield the same stream, draw by draw.
        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        // E[X] = exp(mu + sigma^2/2) = exp(7 + 0.5) = exp(7.5) ≈ 1808
        let deviation = (sampler.expected_value() - 1808.04).abs();
        assert!(deviation < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labeled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labeled.label, Some("test_label".to_string()));

        let unlabeled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(unlabeled.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        let mut config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)]);
        config.max_value = Some(1000.0);

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // No sample may exceed the configured maximum of 1000.
        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}