Skip to main content

datasynth_core/distributions/
mixture.rs

1//! Mixture model distributions for multi-modal data generation.
2//!
3//! Provides Gaussian and Log-Normal mixture models that can generate
4//! realistic multi-modal distributions commonly observed in accounting data
5//! (e.g., routine vs. significant vs. major transactions).
6
7use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13/// Configuration for a single Gaussian component in a mixture.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
17    pub weight: f64,
18    /// Mean (mu) of the Gaussian distribution
19    pub mu: f64,
20    /// Standard deviation (sigma) of the Gaussian distribution
21    pub sigma: f64,
22    /// Optional label for this component (e.g., "routine", "significant")
23    #[serde(default)]
24    pub label: Option<String>,
25}
26
27impl GaussianComponent {
28    /// Create a new Gaussian component.
29    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30        Self {
31            weight,
32            mu,
33            sigma,
34            label: None,
35        }
36    }
37
38    /// Create a labeled Gaussian component.
39    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40        Self {
41            weight,
42            mu,
43            sigma,
44            label: Some(label.into()),
45        }
46    }
47}
48
49/// Configuration for a Gaussian Mixture Model.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52    /// Components of the mixture
53    pub components: Vec<GaussianComponent>,
54    /// Whether to allow negative values (default: true)
55    #[serde(default = "default_true")]
56    pub allow_negative: bool,
57    /// Minimum value (if not allowing negative)
58    #[serde(default)]
59    pub min_value: Option<f64>,
60    /// Maximum value (clamps output)
61    #[serde(default)]
62    pub max_value: Option<f64>,
63}
64
/// Serde default for boolean flags that stay enabled unless explicitly turned off.
const fn default_true() -> bool {
    true
}
68
69impl Default for GaussianMixtureConfig {
70    fn default() -> Self {
71        Self {
72            components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73            allow_negative: true,
74            min_value: None,
75            max_value: None,
76        }
77    }
78}
79
80impl GaussianMixtureConfig {
81    /// Create a new Gaussian mixture configuration.
82    pub fn new(components: Vec<GaussianComponent>) -> Self {
83        Self {
84            components,
85            ..Default::default()
86        }
87    }
88
89    /// Validate the configuration.
90    pub fn validate(&self) -> Result<(), String> {
91        if self.components.is_empty() {
92            return Err("At least one component is required".to_string());
93        }
94
95        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96        if (weight_sum - 1.0).abs() > 0.01 {
97            return Err(format!(
98                "Component weights must sum to 1.0, got {}",
99                weight_sum
100            ));
101        }
102
103        for (i, component) in self.components.iter().enumerate() {
104            if component.weight < 0.0 || component.weight > 1.0 {
105                return Err(format!(
106                    "Component {} weight must be between 0.0 and 1.0, got {}",
107                    i, component.weight
108                ));
109            }
110            if component.sigma <= 0.0 {
111                return Err(format!(
112                    "Component {} sigma must be positive, got {}",
113                    i, component.sigma
114                ));
115            }
116        }
117
118        Ok(())
119    }
120}
121
122/// Configuration for a single Log-Normal component in a mixture.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct LogNormalComponent {
125    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
126    pub weight: f64,
127    /// Mu parameter (location) of the log-normal distribution
128    pub mu: f64,
129    /// Sigma parameter (scale) of the log-normal distribution
130    pub sigma: f64,
131    /// Optional label for this component
132    #[serde(default)]
133    pub label: Option<String>,
134}
135
136impl LogNormalComponent {
137    /// Create a new Log-Normal component.
138    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
139        Self {
140            weight,
141            mu,
142            sigma,
143            label: None,
144        }
145    }
146
147    /// Create a labeled Log-Normal component.
148    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
149        Self {
150            weight,
151            mu,
152            sigma,
153            label: Some(label.into()),
154        }
155    }
156
157    /// Get the expected value (mean) of this component.
158    pub fn expected_value(&self) -> f64 {
159        (self.mu + self.sigma.powi(2) / 2.0).exp()
160    }
161
162    /// Get the median of this component.
163    pub fn median(&self) -> f64 {
164        self.mu.exp()
165    }
166}
167
168/// Configuration for a Log-Normal Mixture Model.
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct LogNormalMixtureConfig {
171    /// Components of the mixture
172    pub components: Vec<LogNormalComponent>,
173    /// Minimum value (default: 0.01)
174    #[serde(default = "default_min_value")]
175    pub min_value: f64,
176    /// Maximum value (clamps output)
177    #[serde(default)]
178    pub max_value: Option<f64>,
179    /// Number of decimal places for rounding
180    #[serde(default = "default_decimal_places")]
181    pub decimal_places: u8,
182}
183
/// Serde default for `LogNormalMixtureConfig::min_value`: 0.01, the smallest
/// positive value at the default two decimal places.
const fn default_min_value() -> f64 {
    0.01
}
187
/// Serde default for `LogNormalMixtureConfig::decimal_places`: two digits.
const fn default_decimal_places() -> u8 {
    2
}
191
192impl Default for LogNormalMixtureConfig {
193    fn default() -> Self {
194        Self {
195            components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
196            min_value: 0.01,
197            max_value: None,
198            decimal_places: 2,
199        }
200    }
201}
202
203impl LogNormalMixtureConfig {
204    /// Create a new Log-Normal mixture configuration.
205    pub fn new(components: Vec<LogNormalComponent>) -> Self {
206        Self {
207            components,
208            ..Default::default()
209        }
210    }
211
212    /// Create a typical transaction amount mixture (routine/significant/major).
213    pub fn typical_transactions() -> Self {
214        Self {
215            components: vec![
216                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
217                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
218                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
219            ],
220            min_value: 0.01,
221            max_value: Some(100_000_000.0),
222            decimal_places: 2,
223        }
224    }
225
226    /// Validate the configuration.
227    pub fn validate(&self) -> Result<(), String> {
228        if self.components.is_empty() {
229            return Err("At least one component is required".to_string());
230        }
231
232        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
233        if (weight_sum - 1.0).abs() > 0.01 {
234            return Err(format!(
235                "Component weights must sum to 1.0, got {}",
236                weight_sum
237            ));
238        }
239
240        for (i, component) in self.components.iter().enumerate() {
241            if component.weight < 0.0 || component.weight > 1.0 {
242                return Err(format!(
243                    "Component {} weight must be between 0.0 and 1.0, got {}",
244                    i, component.weight
245                ));
246            }
247            if component.sigma <= 0.0 {
248                return Err(format!(
249                    "Component {} sigma must be positive, got {}",
250                    i, component.sigma
251                ));
252            }
253        }
254
255        if self.min_value < 0.0 {
256            return Err("min_value must be non-negative".to_string());
257        }
258
259        Ok(())
260    }
261}
262
/// A mixture draw together with the component that produced it.
///
/// Derives `PartialEq` so callers can compare draws directly (e.g. in
/// determinism checks); `Eq` is not possible because `value` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleWithComponent {
    /// The sampled value, after the sampler's clamping (and, for the
    /// log-normal sampler, rounding) has been applied.
    pub value: f64,
    /// Index into the configuration's `components` of the generating component.
    pub component_index: usize,
    /// Copy of that component's label, if it has one.
    pub component_label: Option<String>,
}
273
274/// Gaussian Mixture Model sampler.
275pub struct GaussianMixtureSampler {
276    rng: ChaCha8Rng,
277    config: GaussianMixtureConfig,
278    /// Pre-computed cumulative weights for O(log n) component selection
279    cumulative_weights: Vec<f64>,
280    /// Normal distributions for each component
281    distributions: Vec<Normal<f64>>,
282}
283
284impl GaussianMixtureSampler {
285    /// Create a new Gaussian mixture sampler.
286    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
287        config.validate()?;
288
289        // Pre-compute cumulative weights
290        let mut cumulative_weights = Vec::with_capacity(config.components.len());
291        let mut cumulative = 0.0;
292        for component in &config.components {
293            cumulative += component.weight;
294            cumulative_weights.push(cumulative);
295        }
296
297        // Create distributions
298        let distributions: Result<Vec<_>, _> = config
299            .components
300            .iter()
301            .map(|c| {
302                Normal::new(c.mu, c.sigma)
303                    .map_err(|e| format!("Invalid normal distribution: {}", e))
304            })
305            .collect();
306
307        Ok(Self {
308            rng: ChaCha8Rng::seed_from_u64(seed),
309            config,
310            cumulative_weights,
311            distributions: distributions?,
312        })
313    }
314
315    /// Select a component using binary search on cumulative weights.
316    fn select_component(&mut self) -> usize {
317        let p: f64 = self.rng.gen();
318        match self
319            .cumulative_weights
320            .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
321        {
322            Ok(i) => i,
323            Err(i) => i.min(self.distributions.len() - 1),
324        }
325    }
326
327    /// Sample a value from the mixture.
328    pub fn sample(&mut self) -> f64 {
329        let component_idx = self.select_component();
330        let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332        // Apply constraints
333        if !self.config.allow_negative {
334            value = value.abs();
335        }
336        if let Some(min) = self.config.min_value {
337            value = value.max(min);
338        }
339        if let Some(max) = self.config.max_value {
340            value = value.min(max);
341        }
342
343        value
344    }
345
346    /// Sample a value with component information.
347    pub fn sample_with_component(&mut self) -> SampleWithComponent {
348        let component_idx = self.select_component();
349        let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351        // Apply constraints
352        if !self.config.allow_negative {
353            value = value.abs();
354        }
355        if let Some(min) = self.config.min_value {
356            value = value.max(min);
357        }
358        if let Some(max) = self.config.max_value {
359            value = value.min(max);
360        }
361
362        SampleWithComponent {
363            value,
364            component_index: component_idx,
365            component_label: self.config.components[component_idx].label.clone(),
366        }
367    }
368
369    /// Sample multiple values.
370    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371        (0..n).map(|_| self.sample()).collect()
372    }
373
374    /// Reset the sampler with a new seed.
375    pub fn reset(&mut self, seed: u64) {
376        self.rng = ChaCha8Rng::seed_from_u64(seed);
377    }
378
379    /// Get the configuration.
380    pub fn config(&self) -> &GaussianMixtureConfig {
381        &self.config
382    }
383}
384
385/// Log-Normal Mixture Model sampler for positive-only distributions.
386pub struct LogNormalMixtureSampler {
387    rng: ChaCha8Rng,
388    config: LogNormalMixtureConfig,
389    /// Pre-computed cumulative weights for O(log n) component selection
390    cumulative_weights: Vec<f64>,
391    /// Log-normal distributions for each component
392    distributions: Vec<LogNormal<f64>>,
393    /// Decimal multiplier for rounding
394    decimal_multiplier: f64,
395}
396
397impl LogNormalMixtureSampler {
398    /// Create a new Log-Normal mixture sampler.
399    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
400        config.validate()?;
401
402        // Pre-compute cumulative weights
403        let mut cumulative_weights = Vec::with_capacity(config.components.len());
404        let mut cumulative = 0.0;
405        for component in &config.components {
406            cumulative += component.weight;
407            cumulative_weights.push(cumulative);
408        }
409
410        // Create distributions
411        let distributions: Result<Vec<_>, _> = config
412            .components
413            .iter()
414            .map(|c| {
415                LogNormal::new(c.mu, c.sigma)
416                    .map_err(|e| format!("Invalid log-normal distribution: {}", e))
417            })
418            .collect();
419
420        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
421
422        Ok(Self {
423            rng: ChaCha8Rng::seed_from_u64(seed),
424            config,
425            cumulative_weights,
426            distributions: distributions?,
427            decimal_multiplier,
428        })
429    }
430
431    /// Select a component using binary search on cumulative weights.
432    fn select_component(&mut self) -> usize {
433        let p: f64 = self.rng.gen();
434        match self
435            .cumulative_weights
436            .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
437        {
438            Ok(i) => i,
439            Err(i) => i.min(self.distributions.len() - 1),
440        }
441    }
442
443    /// Sample a value from the mixture.
444    pub fn sample(&mut self) -> f64 {
445        let component_idx = self.select_component();
446        let mut value = self.distributions[component_idx].sample(&mut self.rng);
447
448        // Apply constraints
449        value = value.max(self.config.min_value);
450        if let Some(max) = self.config.max_value {
451            value = value.min(max);
452        }
453
454        // Round to decimal places
455        (value * self.decimal_multiplier).round() / self.decimal_multiplier
456    }
457
458    /// Sample a value as Decimal.
459    pub fn sample_decimal(&mut self) -> Decimal {
460        let value = self.sample();
461        Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
462    }
463
464    /// Sample a value with component information.
465    pub fn sample_with_component(&mut self) -> SampleWithComponent {
466        let component_idx = self.select_component();
467        let mut value = self.distributions[component_idx].sample(&mut self.rng);
468
469        // Apply constraints
470        value = value.max(self.config.min_value);
471        if let Some(max) = self.config.max_value {
472            value = value.min(max);
473        }
474
475        // Round to decimal places
476        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
477
478        SampleWithComponent {
479            value,
480            component_index: component_idx,
481            component_label: self.config.components[component_idx].label.clone(),
482        }
483    }
484
485    /// Sample multiple values.
486    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
487        (0..n).map(|_| self.sample()).collect()
488    }
489
490    /// Sample multiple values as Decimals.
491    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
492        (0..n).map(|_| self.sample_decimal()).collect()
493    }
494
495    /// Reset the sampler with a new seed.
496    pub fn reset(&mut self, seed: u64) {
497        self.rng = ChaCha8Rng::seed_from_u64(seed);
498    }
499
500    /// Get the configuration.
501    pub fn config(&self) -> &LogNormalMixtureConfig {
502        &self.config
503    }
504
505    /// Get the expected value of the mixture.
506    pub fn expected_value(&self) -> f64 {
507        self.config
508            .components
509            .iter()
510            .map(|c| c.weight * c.expected_value())
511            .sum()
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal-weight Gaussian mixture with modes at 0 and 10 (σ = 1 each).
    fn two_mode_gaussian() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights add up to 0.6, outside the 1.0 ± 0.01 tolerance.
        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, two_mode_gaussian()).unwrap();
        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // Each mode should attract roughly half of the draws.
        let near_zero = draws.iter().filter(|&&x| x < 5.0).count();
        let near_ten = draws.iter().filter(|&&x| x >= 5.0).count();
        for count in [near_zero, near_ten] {
            assert!(count > 350 && count < 650);
        }
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = two_mode_gaussian();
        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let balanced = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights add up to 0.4, outside the 1.0 ± 0.01 tolerance.
        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();
        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // Strictly positive, and never below the configured floor.
        assert!(draws.iter().all(|&x| x > 0.0));
        assert!(draws.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small, mut large) = (0, 0);
        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            match drawn.component_label.as_deref() {
                Some("small") => small += 1,
                Some("large") => large += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Equal weights: each label should appear roughly half the time.
        assert!((401..600).contains(&small));
        assert!((401..600).contains(&large));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let mean = LogNormalMixtureSampler::new(42, config)
            .unwrap()
            .expected_value();

        // E[X] = exp(mu + sigma^2 / 2) = exp(7.5) ≈ 1808.04
        assert!((mean - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let tagged = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(tagged.label.as_deref(), Some("test_label"));

        let untagged = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert!(untagged.label.is_none());
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // The component's median (exp(10) ≈ 22026) far exceeds the cap, so clamping is exercised.
        assert!(sampler.sample_n(1000).iter().all(|&x| x <= 1000.0));
    }
}
675}