Skip to main content

datasynth_core/distributions/
mixture.rs

//! Mixture model distributions for multi-modal data generation.
//!
//! Provides Gaussian and Log-Normal mixture models that can generate
//! realistic multi-modal distributions commonly observed in accounting data
//! (e.g., routine vs. significant vs. major transactions).
6
7use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13/// Configuration for a single Gaussian component in a mixture.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
17    pub weight: f64,
18    /// Mean (mu) of the Gaussian distribution
19    pub mu: f64,
20    /// Standard deviation (sigma) of the Gaussian distribution
21    pub sigma: f64,
22    /// Optional label for this component (e.g., "routine", "significant")
23    #[serde(default)]
24    pub label: Option<String>,
25}
26
27impl GaussianComponent {
28    /// Create a new Gaussian component.
29    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30        Self {
31            weight,
32            mu,
33            sigma,
34            label: None,
35        }
36    }
37
38    /// Create a labeled Gaussian component.
39    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40        Self {
41            weight,
42            mu,
43            sigma,
44            label: Some(label.into()),
45        }
46    }
47}
48
49/// Configuration for a Gaussian Mixture Model.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52    /// Components of the mixture
53    pub components: Vec<GaussianComponent>,
54    /// Whether to allow negative values (default: true)
55    #[serde(default = "default_true")]
56    pub allow_negative: bool,
57    /// Minimum value (if not allowing negative)
58    #[serde(default)]
59    pub min_value: Option<f64>,
60    /// Maximum value (clamps output)
61    #[serde(default)]
62    pub max_value: Option<f64>,
63}
64
65fn default_true() -> bool {
66    true
67}
68
69impl Default for GaussianMixtureConfig {
70    fn default() -> Self {
71        Self {
72            components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73            allow_negative: true,
74            min_value: None,
75            max_value: None,
76        }
77    }
78}
79
80impl GaussianMixtureConfig {
81    /// Create a new Gaussian mixture configuration.
82    pub fn new(components: Vec<GaussianComponent>) -> Self {
83        Self {
84            components,
85            ..Default::default()
86        }
87    }
88
89    /// Validate the configuration.
90    pub fn validate(&self) -> Result<(), String> {
91        if self.components.is_empty() {
92            return Err("At least one component is required".to_string());
93        }
94
95        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96        if (weight_sum - 1.0).abs() > 0.01 {
97            return Err(format!(
98                "Component weights must sum to 1.0, got {weight_sum}"
99            ));
100        }
101
102        for (i, component) in self.components.iter().enumerate() {
103            if component.weight < 0.0 || component.weight > 1.0 {
104                return Err(format!(
105                    "Component {} weight must be between 0.0 and 1.0, got {}",
106                    i, component.weight
107                ));
108            }
109            if component.sigma <= 0.0 {
110                return Err(format!(
111                    "Component {} sigma must be positive, got {}",
112                    i, component.sigma
113                ));
114            }
115        }
116
117        Ok(())
118    }
119}
120
121/// Configuration for a single Log-Normal component in a mixture.
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
125    pub weight: f64,
126    /// Mu parameter (location) of the log-normal distribution
127    pub mu: f64,
128    /// Sigma parameter (scale) of the log-normal distribution
129    pub sigma: f64,
130    /// Optional label for this component
131    #[serde(default)]
132    pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136    /// Create a new Log-Normal component.
137    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138        Self {
139            weight,
140            mu,
141            sigma,
142            label: None,
143        }
144    }
145
146    /// Create a labeled Log-Normal component.
147    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148        Self {
149            weight,
150            mu,
151            sigma,
152            label: Some(label.into()),
153        }
154    }
155
156    /// Get the expected value (mean) of this component.
157    pub fn expected_value(&self) -> f64 {
158        (self.mu + self.sigma.powi(2) / 2.0).exp()
159    }
160
161    /// Get the median of this component.
162    pub fn median(&self) -> f64 {
163        self.mu.exp()
164    }
165}
166
167/// Configuration for a Log-Normal Mixture Model.
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170    /// Components of the mixture
171    pub components: Vec<LogNormalComponent>,
172    /// Minimum value (default: 0.01)
173    #[serde(default = "default_min_value")]
174    pub min_value: f64,
175    /// Maximum value (clamps output)
176    #[serde(default)]
177    pub max_value: Option<f64>,
178    /// Number of decimal places for rounding
179    #[serde(default = "default_decimal_places")]
180    pub decimal_places: u8,
181}
182
183fn default_min_value() -> f64 {
184    0.01
185}
186
187fn default_decimal_places() -> u8 {
188    2
189}
190
191impl Default for LogNormalMixtureConfig {
192    fn default() -> Self {
193        Self {
194            components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195            min_value: 0.01,
196            max_value: None,
197            decimal_places: 2,
198        }
199    }
200}
201
202impl LogNormalMixtureConfig {
203    /// Create a new Log-Normal mixture configuration.
204    pub fn new(components: Vec<LogNormalComponent>) -> Self {
205        Self {
206            components,
207            ..Default::default()
208        }
209    }
210
211    /// Create a typical transaction amount mixture (routine/significant/major).
212    pub fn typical_transactions() -> Self {
213        Self {
214            components: vec![
215                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218            ],
219            min_value: 0.01,
220            max_value: Some(100_000_000.0),
221            decimal_places: 2,
222        }
223    }
224
225    /// Validate the configuration.
226    pub fn validate(&self) -> Result<(), String> {
227        if self.components.is_empty() {
228            return Err("At least one component is required".to_string());
229        }
230
231        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232        if (weight_sum - 1.0).abs() > 0.01 {
233            return Err(format!(
234                "Component weights must sum to 1.0, got {weight_sum}"
235            ));
236        }
237
238        for (i, component) in self.components.iter().enumerate() {
239            if component.weight < 0.0 || component.weight > 1.0 {
240                return Err(format!(
241                    "Component {} weight must be between 0.0 and 1.0, got {}",
242                    i, component.weight
243                ));
244            }
245            if component.sigma <= 0.0 {
246                return Err(format!(
247                    "Component {} sigma must be positive, got {}",
248                    i, component.sigma
249                ));
250            }
251        }
252
253        if self.min_value < 0.0 {
254            return Err("min_value must be non-negative".to_string());
255        }
256
257        Ok(())
258    }
259}
260
/// Result of sampling with component information.
///
/// Derives `PartialEq` so that whole samples (value, component index and
/// label together) can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleWithComponent {
    /// The sampled value
    pub value: f64,
    /// Index of the component that generated this sample
    pub component_index: usize,
    /// Label of the component (if available)
    pub component_label: Option<String>,
}
271
272/// Gaussian Mixture Model sampler.
273pub struct GaussianMixtureSampler {
274    rng: ChaCha8Rng,
275    config: GaussianMixtureConfig,
276    /// Pre-computed cumulative weights for O(log n) component selection
277    cumulative_weights: Vec<f64>,
278    /// Normal distributions for each component
279    distributions: Vec<Normal<f64>>,
280}
281
282impl GaussianMixtureSampler {
283    /// Create a new Gaussian mixture sampler.
284    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
285        config.validate()?;
286
287        // Pre-compute cumulative weights
288        let mut cumulative_weights = Vec::with_capacity(config.components.len());
289        let mut cumulative = 0.0;
290        for component in &config.components {
291            cumulative += component.weight;
292            cumulative_weights.push(cumulative);
293        }
294
295        // Create distributions
296        let distributions: Result<Vec<_>, _> = config
297            .components
298            .iter()
299            .map(|c| {
300                Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
301            })
302            .collect();
303
304        Ok(Self {
305            rng: ChaCha8Rng::seed_from_u64(seed),
306            config,
307            cumulative_weights,
308            distributions: distributions?,
309        })
310    }
311
312    /// Select a component using binary search on cumulative weights.
313    fn select_component(&mut self) -> usize {
314        let p: f64 = self.rng.random();
315        match self.cumulative_weights.binary_search_by(|w| {
316            w.partial_cmp(&p).unwrap_or_else(|| {
317                tracing::debug!("NaN detected in mixture weight comparison");
318                std::cmp::Ordering::Less
319            })
320        }) {
321            Ok(i) => i,
322            Err(i) => i.min(self.distributions.len() - 1),
323        }
324    }
325
326    /// Sample a value from the mixture.
327    pub fn sample(&mut self) -> f64 {
328        let component_idx = self.select_component();
329        let mut value = self.distributions[component_idx].sample(&mut self.rng);
330
331        // Apply constraints
332        if !self.config.allow_negative {
333            value = value.abs();
334        }
335        if let Some(min) = self.config.min_value {
336            value = value.max(min);
337        }
338        if let Some(max) = self.config.max_value {
339            value = value.min(max);
340        }
341
342        value
343    }
344
345    /// Sample a value with component information.
346    pub fn sample_with_component(&mut self) -> SampleWithComponent {
347        let component_idx = self.select_component();
348        let mut value = self.distributions[component_idx].sample(&mut self.rng);
349
350        // Apply constraints
351        if !self.config.allow_negative {
352            value = value.abs();
353        }
354        if let Some(min) = self.config.min_value {
355            value = value.max(min);
356        }
357        if let Some(max) = self.config.max_value {
358            value = value.min(max);
359        }
360
361        SampleWithComponent {
362            value,
363            component_index: component_idx,
364            component_label: self.config.components[component_idx].label.clone(),
365        }
366    }
367
368    /// Sample multiple values.
369    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
370        (0..n).map(|_| self.sample()).collect()
371    }
372
373    /// Reset the sampler with a new seed.
374    pub fn reset(&mut self, seed: u64) {
375        self.rng = ChaCha8Rng::seed_from_u64(seed);
376    }
377
378    /// Get the configuration.
379    pub fn config(&self) -> &GaussianMixtureConfig {
380        &self.config
381    }
382}
383
384/// Log-Normal Mixture Model sampler for positive-only distributions.
385pub struct LogNormalMixtureSampler {
386    rng: ChaCha8Rng,
387    config: LogNormalMixtureConfig,
388    /// Pre-computed cumulative weights for O(log n) component selection
389    cumulative_weights: Vec<f64>,
390    /// Log-normal distributions for each component
391    distributions: Vec<LogNormal<f64>>,
392    /// Decimal multiplier for rounding
393    decimal_multiplier: f64,
394}
395
396impl LogNormalMixtureSampler {
397    /// Create a new Log-Normal mixture sampler.
398    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
399        config.validate()?;
400
401        // Pre-compute cumulative weights
402        let mut cumulative_weights = Vec::with_capacity(config.components.len());
403        let mut cumulative = 0.0;
404        for component in &config.components {
405            cumulative += component.weight;
406            cumulative_weights.push(cumulative);
407        }
408
409        // Create distributions
410        let distributions: Result<Vec<_>, _> = config
411            .components
412            .iter()
413            .map(|c| {
414                LogNormal::new(c.mu, c.sigma)
415                    .map_err(|e| format!("Invalid log-normal distribution: {e}"))
416            })
417            .collect();
418
419        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
420
421        Ok(Self {
422            rng: ChaCha8Rng::seed_from_u64(seed),
423            config,
424            cumulative_weights,
425            distributions: distributions?,
426            decimal_multiplier,
427        })
428    }
429
430    /// Select a component using binary search on cumulative weights.
431    fn select_component(&mut self) -> usize {
432        let p: f64 = self.rng.random();
433        match self.cumulative_weights.binary_search_by(|w| {
434            w.partial_cmp(&p).unwrap_or_else(|| {
435                tracing::debug!("NaN detected in mixture weight comparison");
436                std::cmp::Ordering::Less
437            })
438        }) {
439            Ok(i) => i,
440            Err(i) => i.min(self.distributions.len() - 1),
441        }
442    }
443
444    /// Sample a value from the mixture.
445    pub fn sample(&mut self) -> f64 {
446        let component_idx = self.select_component();
447        let mut value = self.distributions[component_idx].sample(&mut self.rng);
448
449        // Apply constraints
450        value = value.max(self.config.min_value);
451        if let Some(max) = self.config.max_value {
452            value = value.min(max);
453        }
454
455        // Round to decimal places
456        (value * self.decimal_multiplier).round() / self.decimal_multiplier
457    }
458
459    /// Sample a value as Decimal.
460    pub fn sample_decimal(&mut self) -> Decimal {
461        let value = self.sample();
462        Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
463    }
464
465    /// Sample a value with component information.
466    pub fn sample_with_component(&mut self) -> SampleWithComponent {
467        let component_idx = self.select_component();
468        let mut value = self.distributions[component_idx].sample(&mut self.rng);
469
470        // Apply constraints
471        value = value.max(self.config.min_value);
472        if let Some(max) = self.config.max_value {
473            value = value.min(max);
474        }
475
476        // Round to decimal places
477        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
478
479        SampleWithComponent {
480            value,
481            component_index: component_idx,
482            component_label: self.config.components[component_idx].label.clone(),
483        }
484    }
485
486    /// Sample multiple values.
487    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
488        (0..n).map(|_| self.sample()).collect()
489    }
490
491    /// Sample multiple values as Decimals.
492    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
493        (0..n).map(|_| self.sample_decimal()).collect()
494    }
495
496    /// Reset the sampler with a new seed.
497    pub fn reset(&mut self, seed: u64) {
498        self.rng = ChaCha8Rng::seed_from_u64(seed);
499    }
500
501    /// Get the configuration.
502    pub fn config(&self) -> &LogNormalMixtureConfig {
503        &self.config
504    }
505
506    /// Get the expected value of the mixture.
507    pub fn expected_value(&self) -> f64 {
508        self.config
509            .components
510            .iter()
511            .map(|c| c.weight * c.expected_value())
512            .sum()
513    }
514}
515
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two equally weighted, unit-sigma Gaussians centred on 0 and 10.
    fn bimodal_gaussian() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        // Weights summing to 1.0 with positive sigmas are accepted.
        let valid = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(valid.validate().is_ok());

        // Weights that sum to 0.6 are rejected.
        let bad_weights = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(bad_weights.validate().is_err());

        // A negative sigma is rejected.
        let bad_sigma = GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(bad_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, bimodal_gaussian()).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // Split the samples at the midpoint between the two means.
        let low_count = samples.iter().filter(|&&x| x < 5.0).count();
        let high_count = samples.iter().filter(|&&x| x >= 5.0).count();

        // Each side should hold roughly half of the samples (with some tolerance).
        for count in [low_count, high_count] {
            assert!((351..650).contains(&count));
        }
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = bimodal_gaussian();

        let mut sampler1 = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = GaussianMixtureSampler::new(42, config).unwrap();

        // Same seed and configuration must give the same sequence.
        assert_eq!(sampler1.sample_n(100), sampler2.sample_n(100));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        // Weights summing to 1.0 are accepted.
        let valid = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(valid.validate().is_ok());

        // Weights that sum to 0.4 are rejected.
        let bad_weights = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(bad_weights.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        for &x in &samples {
            // Every sample is positive...
            assert!(x > 0.0);
            // ...and respects the minimum value constraint.
            assert!(x >= 0.01);
        }
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let labels: Vec<Option<String>> = (0..1000)
            .map(|_| sampler.sample_with_component().component_label)
            .collect();
        let count_of = |wanted: &str| {
            labels
                .iter()
                .filter(|label| label.as_deref() == Some(wanted))
                .count()
        };
        let small_count = count_of("small");
        let large_count = count_of("large");

        // Every sample must carry one of the two configured labels.
        assert_eq!(small_count + large_count, labels.len(), "Unexpected label");

        // Both components should be selected roughly equally.
        for count in [small_count, large_count] {
            assert!((401..600).contains(&count));
        }
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        let mut sampler1 = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = LogNormalMixtureSampler::new(42, config).unwrap();

        // Same seed and configuration must give the same sequence.
        assert_eq!(sampler1.sample_n(100), sampler2.sample_n(100));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        // E[X] = exp(mu + sigma^2/2) = exp(7 + 0.5) = exp(7.5) ≈ 1808
        let difference = sampler.expected_value() - 1808.04;
        assert!(difference.abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labeled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labeled.label, Some("test_label".to_string()));

        let unlabeled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(unlabeled.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // No sample may exceed the configured maximum of 1000.
        for x in sampler.sample_n(1000) {
            assert!(x <= 1000.0);
        }
    }
}