Skip to main content

datasynth_core/distributions/
mixture.rs

1//! Mixture model distributions for multi-modal data generation.
2//!
3//! Provides Gaussian and Log-Normal mixture models that can generate
4//! realistic multi-modal distributions commonly observed in accounting data
5//! (e.g., routine vs. significant vs. major transactions).
6
7use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13/// Configuration for a single Gaussian component in a mixture.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
17    pub weight: f64,
18    /// Mean (mu) of the Gaussian distribution
19    pub mu: f64,
20    /// Standard deviation (sigma) of the Gaussian distribution
21    pub sigma: f64,
22    /// Optional label for this component (e.g., "routine", "significant")
23    #[serde(default)]
24    pub label: Option<String>,
25}
26
27impl GaussianComponent {
28    /// Create a new Gaussian component.
29    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30        Self {
31            weight,
32            mu,
33            sigma,
34            label: None,
35        }
36    }
37
38    /// Create a labeled Gaussian component.
39    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40        Self {
41            weight,
42            mu,
43            sigma,
44            label: Some(label.into()),
45        }
46    }
47}
48
49/// Configuration for a Gaussian Mixture Model.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52    /// Components of the mixture
53    pub components: Vec<GaussianComponent>,
54    /// Whether to allow negative values (default: true)
55    #[serde(default = "default_true")]
56    pub allow_negative: bool,
57    /// Minimum value (if not allowing negative)
58    #[serde(default)]
59    pub min_value: Option<f64>,
60    /// Maximum value (clamps output)
61    #[serde(default)]
62    pub max_value: Option<f64>,
63}
64
65fn default_true() -> bool {
66    true
67}
68
69impl Default for GaussianMixtureConfig {
70    fn default() -> Self {
71        Self {
72            components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73            allow_negative: true,
74            min_value: None,
75            max_value: None,
76        }
77    }
78}
79
80impl GaussianMixtureConfig {
81    /// Create a new Gaussian mixture configuration.
82    pub fn new(components: Vec<GaussianComponent>) -> Self {
83        Self {
84            components,
85            ..Default::default()
86        }
87    }
88
89    /// Validate the configuration.
90    pub fn validate(&self) -> Result<(), String> {
91        if self.components.is_empty() {
92            return Err("At least one component is required".to_string());
93        }
94
95        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96        if (weight_sum - 1.0).abs() > 0.01 {
97            return Err(format!(
98                "Component weights must sum to 1.0, got {weight_sum}"
99            ));
100        }
101
102        for (i, component) in self.components.iter().enumerate() {
103            if component.weight < 0.0 || component.weight > 1.0 {
104                return Err(format!(
105                    "Component {} weight must be between 0.0 and 1.0, got {}",
106                    i, component.weight
107                ));
108            }
109            if component.sigma <= 0.0 {
110                return Err(format!(
111                    "Component {} sigma must be positive, got {}",
112                    i, component.sigma
113                ));
114            }
115        }
116
117        Ok(())
118    }
119}
120
121/// Configuration for a single Log-Normal component in a mixture.
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
125    pub weight: f64,
126    /// Mu parameter (location) of the log-normal distribution
127    pub mu: f64,
128    /// Sigma parameter (scale) of the log-normal distribution
129    pub sigma: f64,
130    /// Optional label for this component
131    #[serde(default)]
132    pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136    /// Create a new Log-Normal component.
137    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138        Self {
139            weight,
140            mu,
141            sigma,
142            label: None,
143        }
144    }
145
146    /// Create a labeled Log-Normal component.
147    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148        Self {
149            weight,
150            mu,
151            sigma,
152            label: Some(label.into()),
153        }
154    }
155
156    /// Get the expected value (mean) of this component.
157    pub fn expected_value(&self) -> f64 {
158        (self.mu + self.sigma.powi(2) / 2.0).exp()
159    }
160
161    /// Get the median of this component.
162    pub fn median(&self) -> f64 {
163        self.mu.exp()
164    }
165}
166
167/// Configuration for a Log-Normal Mixture Model.
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170    /// Components of the mixture
171    pub components: Vec<LogNormalComponent>,
172    /// Minimum value (default: 0.01)
173    #[serde(default = "default_min_value")]
174    pub min_value: f64,
175    /// Maximum value (clamps output)
176    #[serde(default)]
177    pub max_value: Option<f64>,
178    /// Number of decimal places for rounding
179    #[serde(default = "default_decimal_places")]
180    pub decimal_places: u8,
181}
182
183fn default_min_value() -> f64 {
184    0.01
185}
186
187fn default_decimal_places() -> u8 {
188    2
189}
190
191impl Default for LogNormalMixtureConfig {
192    fn default() -> Self {
193        Self {
194            components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195            min_value: 0.01,
196            max_value: None,
197            decimal_places: 2,
198        }
199    }
200}
201
202impl LogNormalMixtureConfig {
203    /// Create a new Log-Normal mixture configuration.
204    pub fn new(components: Vec<LogNormalComponent>) -> Self {
205        Self {
206            components,
207            ..Default::default()
208        }
209    }
210
211    /// Create a typical transaction amount mixture (routine/significant/major).
212    pub fn typical_transactions() -> Self {
213        Self {
214            components: vec![
215                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218            ],
219            min_value: 0.01,
220            max_value: Some(100_000_000.0),
221            decimal_places: 2,
222        }
223    }
224
225    /// Validate the configuration.
226    pub fn validate(&self) -> Result<(), String> {
227        if self.components.is_empty() {
228            return Err("At least one component is required".to_string());
229        }
230
231        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232        if (weight_sum - 1.0).abs() > 0.01 {
233            return Err(format!(
234                "Component weights must sum to 1.0, got {weight_sum}"
235            ));
236        }
237
238        for (i, component) in self.components.iter().enumerate() {
239            if component.weight < 0.0 || component.weight > 1.0 {
240                return Err(format!(
241                    "Component {} weight must be between 0.0 and 1.0, got {}",
242                    i, component.weight
243                ));
244            }
245            if component.sigma <= 0.0 {
246                return Err(format!(
247                    "Component {} sigma must be positive, got {}",
248                    i, component.sigma
249                ));
250            }
251        }
252
253        if self.min_value < 0.0 {
254            return Err("min_value must be non-negative".to_string());
255        }
256
257        Ok(())
258    }
259}
260
/// A single mixture draw, paired with the component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The drawn value, after the sampler's constraints were applied.
    pub value: f64,
    /// Position of the generating component in the config's `components` list.
    pub component_index: usize,
    /// That component's label, when it has one.
    pub component_label: Option<String>,
}
271
272/// Gaussian Mixture Model sampler.
273#[derive(Clone)]
274pub struct GaussianMixtureSampler {
275    rng: ChaCha8Rng,
276    config: GaussianMixtureConfig,
277    /// Pre-computed cumulative weights for O(log n) component selection
278    cumulative_weights: Vec<f64>,
279    /// Normal distributions for each component
280    distributions: Vec<Normal<f64>>,
281}
282
283impl GaussianMixtureSampler {
284    /// Create a new Gaussian mixture sampler.
285    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
286        config.validate()?;
287
288        // Pre-compute cumulative weights
289        let mut cumulative_weights = Vec::with_capacity(config.components.len());
290        let mut cumulative = 0.0;
291        for component in &config.components {
292            cumulative += component.weight;
293            cumulative_weights.push(cumulative);
294        }
295
296        // Create distributions
297        let distributions: Result<Vec<_>, _> = config
298            .components
299            .iter()
300            .map(|c| {
301                Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
302            })
303            .collect();
304
305        Ok(Self {
306            rng: ChaCha8Rng::seed_from_u64(seed),
307            config,
308            cumulative_weights,
309            distributions: distributions?,
310        })
311    }
312
313    /// Select a component using binary search on cumulative weights.
314    fn select_component(&mut self) -> usize {
315        let p: f64 = self.rng.random();
316        match self.cumulative_weights.binary_search_by(|w| {
317            w.partial_cmp(&p).unwrap_or_else(|| {
318                tracing::debug!("NaN detected in mixture weight comparison");
319                std::cmp::Ordering::Less
320            })
321        }) {
322            Ok(i) => i,
323            Err(i) => i.min(self.distributions.len() - 1),
324        }
325    }
326
327    /// Sample a value from the mixture.
328    pub fn sample(&mut self) -> f64 {
329        let component_idx = self.select_component();
330        let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332        // Apply constraints
333        if !self.config.allow_negative {
334            value = value.abs();
335        }
336        if let Some(min) = self.config.min_value {
337            value = value.max(min);
338        }
339        if let Some(max) = self.config.max_value {
340            value = value.min(max);
341        }
342
343        value
344    }
345
346    /// Sample a value with component information.
347    pub fn sample_with_component(&mut self) -> SampleWithComponent {
348        let component_idx = self.select_component();
349        let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351        // Apply constraints
352        if !self.config.allow_negative {
353            value = value.abs();
354        }
355        if let Some(min) = self.config.min_value {
356            value = value.max(min);
357        }
358        if let Some(max) = self.config.max_value {
359            value = value.min(max);
360        }
361
362        SampleWithComponent {
363            value,
364            component_index: component_idx,
365            component_label: self.config.components[component_idx].label.clone(),
366        }
367    }
368
369    /// Sample multiple values.
370    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371        (0..n).map(|_| self.sample()).collect()
372    }
373
374    /// Reset the sampler with a new seed.
375    pub fn reset(&mut self, seed: u64) {
376        self.rng = ChaCha8Rng::seed_from_u64(seed);
377    }
378
379    /// Get the configuration.
380    pub fn config(&self) -> &GaussianMixtureConfig {
381        &self.config
382    }
383}
384
385/// Log-Normal Mixture Model sampler for positive-only distributions.
386#[derive(Clone)]
387pub struct LogNormalMixtureSampler {
388    rng: ChaCha8Rng,
389    config: LogNormalMixtureConfig,
390    /// Pre-computed cumulative weights for O(log n) component selection
391    cumulative_weights: Vec<f64>,
392    /// Log-normal distributions for each component
393    distributions: Vec<LogNormal<f64>>,
394    /// Decimal multiplier for rounding
395    decimal_multiplier: f64,
396}
397
398impl LogNormalMixtureSampler {
399    /// Create a new Log-Normal mixture sampler.
400    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
401        config.validate()?;
402
403        // Pre-compute cumulative weights
404        let mut cumulative_weights = Vec::with_capacity(config.components.len());
405        let mut cumulative = 0.0;
406        for component in &config.components {
407            cumulative += component.weight;
408            cumulative_weights.push(cumulative);
409        }
410
411        // Create distributions
412        let distributions: Result<Vec<_>, _> = config
413            .components
414            .iter()
415            .map(|c| {
416                LogNormal::new(c.mu, c.sigma)
417                    .map_err(|e| format!("Invalid log-normal distribution: {e}"))
418            })
419            .collect();
420
421        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
422
423        Ok(Self {
424            rng: ChaCha8Rng::seed_from_u64(seed),
425            config,
426            cumulative_weights,
427            distributions: distributions?,
428            decimal_multiplier,
429        })
430    }
431
432    /// Select a component using binary search on cumulative weights.
433    fn select_component(&mut self) -> usize {
434        let p: f64 = self.rng.random();
435        match self.cumulative_weights.binary_search_by(|w| {
436            w.partial_cmp(&p).unwrap_or_else(|| {
437                tracing::debug!("NaN detected in mixture weight comparison");
438                std::cmp::Ordering::Less
439            })
440        }) {
441            Ok(i) => i,
442            Err(i) => i.min(self.distributions.len() - 1),
443        }
444    }
445
446    /// Sample a value from the mixture.
447    pub fn sample(&mut self) -> f64 {
448        let component_idx = self.select_component();
449        let mut value = self.distributions[component_idx].sample(&mut self.rng);
450
451        // Apply constraints
452        value = value.max(self.config.min_value);
453        if let Some(max) = self.config.max_value {
454            value = value.min(max);
455        }
456
457        // Round to decimal places
458        (value * self.decimal_multiplier).round() / self.decimal_multiplier
459    }
460
461    /// Sample a value as Decimal.
462    pub fn sample_decimal(&mut self) -> Decimal {
463        let value = self.sample();
464        Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
465    }
466
467    /// Sample a value with component information.
468    pub fn sample_with_component(&mut self) -> SampleWithComponent {
469        let component_idx = self.select_component();
470        let mut value = self.distributions[component_idx].sample(&mut self.rng);
471
472        // Apply constraints
473        value = value.max(self.config.min_value);
474        if let Some(max) = self.config.max_value {
475            value = value.min(max);
476        }
477
478        // Round to decimal places
479        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
480
481        SampleWithComponent {
482            value,
483            component_index: component_idx,
484            component_label: self.config.components[component_idx].label.clone(),
485        }
486    }
487
488    /// Sample multiple values.
489    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
490        (0..n).map(|_| self.sample()).collect()
491    }
492
493    /// Sample multiple values as Decimals.
494    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
495        (0..n).map(|_| self.sample_decimal()).collect()
496    }
497
498    /// Reset the sampler with a new seed.
499    pub fn reset(&mut self, seed: u64) {
500        self.rng = ChaCha8Rng::seed_from_u64(seed);
501    }
502
503    /// Get the configuration.
504    pub fn config(&self) -> &LogNormalMixtureConfig {
505        &self.config
506    }
507
508    /// Get the expected value of the mixture.
509    pub fn expected_value(&self) -> f64 {
510        self.config
511            .components
512            .iter()
513            .map(|c| c.weight * c.expected_value())
514            .sum()
515    }
516}
517
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two equally weighted unit-variance Gaussians centred at 0 and 10.
    fn bimodal_gaussian_config() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    /// True when `count` lies strictly between `low` and `high`.
    fn strictly_between(count: usize, low: usize, high: usize) -> bool {
        low < count && count < high
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        // Weights summing to exactly 1.0: accepted.
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights summing to 0.6: rejected.
        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        // Negative standard deviation: rejected.
        let bad_sigma = GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(bad_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, bimodal_gaussian_config()).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // 5.0 is the midpoint between the two means; each side should hold
        // roughly half of the draws.
        let below = draws.iter().filter(|&&x| x < 5.0).count();
        let above = draws.iter().filter(|&&x| x >= 5.0).count();
        assert!(strictly_between(below, 350, 650));
        assert!(strictly_between(above, 350, 650));
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let mut left = GaussianMixtureSampler::new(42, bimodal_gaussian_config()).unwrap();
        let mut right = GaussianMixtureSampler::new(42, bimodal_gaussian_config()).unwrap();

        for _ in 0..100 {
            assert_eq!(left.sample(), right.sample());
        }
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        // Weights summing to exactly 1.0: accepted.
        let balanced = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights summing to 0.4: rejected.
        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        for &x in &draws {
            // Strictly positive, and never below the 0.01 floor.
            assert!(x > 0.0);
            assert!(x >= 0.01);
        }
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // tally[0] counts "small" draws, tally[1] counts "large" draws.
        let mut tally = [0usize; 2];
        for _ in 0..1000 {
            let draw = sampler.sample_with_component();
            let slot = match draw.component_label.as_deref() {
                Some("small") => 0,
                Some("large") => 1,
                _ => panic!("Unexpected label"),
            };
            tally[slot] += 1;
        }

        // Equal weights: each component should be chosen roughly half the time.
        let [small, large] = tally;
        assert!(strictly_between(small, 400, 600));
        assert!(strictly_between(large, 400, 600));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let mut left =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();
        let mut right =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        for _ in 0..100 {
            assert_eq!(left.sample(), right.sample());
        }
    }

    #[test]
    fn test_lognormal_expected_value() {
        let sampler = LogNormalMixtureSampler::new(
            42,
            LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]),
        )
        .unwrap();

        // E[X] = exp(mu + sigma^2/2) = exp(7 + 0.5) = exp(7.5) ≈ 1808
        let mean = sampler.expected_value();
        assert!((mean - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labeled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labeled.label.as_deref(), Some("test_label"));

        let unlabeled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert!(unlabeled.label.is_none());
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let draws = sampler.sample_n(1000);

        // The median is exp(10) ≈ 22026, so most raw draws exceed the cap.
        assert!(draws.iter().all(|&x| x <= 1000.0));
    }
}