// datasynth_core/distributions/mixture.rs

//! Mixture model distributions for multi-modal data generation.
//!
//! Provides Gaussian and log-normal mixture models that can generate
//! realistic multi-modal distributions commonly observed in accounting data
//! (e.g. routine vs. significant vs. major transactions).
6
7use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13/// Configuration for a single Gaussian component in a mixture.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
17    pub weight: f64,
18    /// Mean (mu) of the Gaussian distribution
19    pub mu: f64,
20    /// Standard deviation (sigma) of the Gaussian distribution
21    pub sigma: f64,
22    /// Optional label for this component (e.g., "routine", "significant")
23    #[serde(default)]
24    pub label: Option<String>,
25}
26
27impl GaussianComponent {
28    /// Create a new Gaussian component.
29    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30        Self {
31            weight,
32            mu,
33            sigma,
34            label: None,
35        }
36    }
37
38    /// Create a labeled Gaussian component.
39    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40        Self {
41            weight,
42            mu,
43            sigma,
44            label: Some(label.into()),
45        }
46    }
47}
48
49/// Configuration for a Gaussian Mixture Model.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52    /// Components of the mixture
53    pub components: Vec<GaussianComponent>,
54    /// Whether to allow negative values (default: true)
55    #[serde(default = "default_true")]
56    pub allow_negative: bool,
57    /// Minimum value (if not allowing negative)
58    #[serde(default)]
59    pub min_value: Option<f64>,
60    /// Maximum value (clamps output)
61    #[serde(default)]
62    pub max_value: Option<f64>,
63}
64
65fn default_true() -> bool {
66    true
67}
68
69impl Default for GaussianMixtureConfig {
70    fn default() -> Self {
71        Self {
72            components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73            allow_negative: true,
74            min_value: None,
75            max_value: None,
76        }
77    }
78}
79
80impl GaussianMixtureConfig {
81    /// Create a new Gaussian mixture configuration.
82    pub fn new(components: Vec<GaussianComponent>) -> Self {
83        Self {
84            components,
85            ..Default::default()
86        }
87    }
88
89    /// Validate the configuration.
90    pub fn validate(&self) -> Result<(), String> {
91        if self.components.is_empty() {
92            return Err("At least one component is required".to_string());
93        }
94
95        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96        if (weight_sum - 1.0).abs() > 0.01 {
97            return Err(format!(
98                "Component weights must sum to 1.0, got {weight_sum}"
99            ));
100        }
101
102        for (i, component) in self.components.iter().enumerate() {
103            if component.weight < 0.0 || component.weight > 1.0 {
104                return Err(format!(
105                    "Component {} weight must be between 0.0 and 1.0, got {}",
106                    i, component.weight
107                ));
108            }
109            if component.sigma <= 0.0 {
110                return Err(format!(
111                    "Component {} sigma must be positive, got {}",
112                    i, component.sigma
113                ));
114            }
115        }
116
117        Ok(())
118    }
119}
120
121/// Configuration for a single Log-Normal component in a mixture.
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124    /// Weight of this component (0.0-1.0, all weights should sum to 1.0)
125    pub weight: f64,
126    /// Mu parameter (location) of the log-normal distribution
127    pub mu: f64,
128    /// Sigma parameter (scale) of the log-normal distribution
129    pub sigma: f64,
130    /// Optional label for this component
131    #[serde(default)]
132    pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136    /// Create a new Log-Normal component.
137    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138        Self {
139            weight,
140            mu,
141            sigma,
142            label: None,
143        }
144    }
145
146    /// Create a labeled Log-Normal component.
147    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148        Self {
149            weight,
150            mu,
151            sigma,
152            label: Some(label.into()),
153        }
154    }
155
156    /// Get the expected value (mean) of this component.
157    pub fn expected_value(&self) -> f64 {
158        (self.mu + self.sigma.powi(2) / 2.0).exp()
159    }
160
161    /// Get the median of this component.
162    pub fn median(&self) -> f64 {
163        self.mu.exp()
164    }
165}
166
167/// Configuration for a Log-Normal Mixture Model.
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170    /// Components of the mixture
171    pub components: Vec<LogNormalComponent>,
172    /// Minimum value (default: 0.01)
173    #[serde(default = "default_min_value")]
174    pub min_value: f64,
175    /// Maximum value (clamps output)
176    #[serde(default)]
177    pub max_value: Option<f64>,
178    /// Number of decimal places for rounding
179    #[serde(default = "default_decimal_places")]
180    pub decimal_places: u8,
181}
182
183fn default_min_value() -> f64 {
184    0.01
185}
186
187fn default_decimal_places() -> u8 {
188    2
189}
190
191impl Default for LogNormalMixtureConfig {
192    fn default() -> Self {
193        Self {
194            components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195            min_value: 0.01,
196            max_value: None,
197            decimal_places: 2,
198        }
199    }
200}
201
202impl LogNormalMixtureConfig {
203    /// Create a new Log-Normal mixture configuration.
204    pub fn new(components: Vec<LogNormalComponent>) -> Self {
205        Self {
206            components,
207            ..Default::default()
208        }
209    }
210
211    /// Create a typical transaction amount mixture (routine/significant/major).
212    pub fn typical_transactions() -> Self {
213        Self {
214            components: vec![
215                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218            ],
219            min_value: 0.01,
220            max_value: Some(100_000_000.0),
221            decimal_places: 2,
222        }
223    }
224
225    /// Validate the configuration.
226    pub fn validate(&self) -> Result<(), String> {
227        if self.components.is_empty() {
228            return Err("At least one component is required".to_string());
229        }
230
231        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232        if (weight_sum - 1.0).abs() > 0.01 {
233            return Err(format!(
234                "Component weights must sum to 1.0, got {weight_sum}"
235            ));
236        }
237
238        for (i, component) in self.components.iter().enumerate() {
239            if component.weight < 0.0 || component.weight > 1.0 {
240                return Err(format!(
241                    "Component {} weight must be between 0.0 and 1.0, got {}",
242                    i, component.weight
243                ));
244            }
245            if component.sigma <= 0.0 {
246                return Err(format!(
247                    "Component {} sigma must be positive, got {}",
248                    i, component.sigma
249                ));
250            }
251        }
252
253        if self.min_value < 0.0 {
254            return Err("min_value must be non-negative".to_string());
255        }
256
257        Ok(())
258    }
259}
260
/// A sampled value together with the mixture component that produced it.
#[derive(Clone, Debug)]
pub struct SampleWithComponent {
    /// The value drawn, after the sampler's constraints were applied.
    pub value: f64,
    /// Position, within the configuration's `components`, of the component that was selected.
    pub component_index: usize,
    /// That component's label, if it has one.
    pub component_label: Option<String>,
}
271
272/// Gaussian Mixture Model sampler.
273#[derive(Clone)]
274pub struct GaussianMixtureSampler {
275    rng: ChaCha8Rng,
276    config: GaussianMixtureConfig,
277    /// Pre-computed cumulative weights for O(log n) component selection
278    cumulative_weights: Vec<f64>,
279    /// Normal distributions for each component
280    distributions: Vec<Normal<f64>>,
281}
282
283impl GaussianMixtureSampler {
284    /// Create a new Gaussian mixture sampler.
285    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
286        config.validate()?;
287
288        // Pre-compute cumulative weights
289        let mut cumulative_weights = Vec::with_capacity(config.components.len());
290        let mut cumulative = 0.0;
291        for component in &config.components {
292            cumulative += component.weight;
293            cumulative_weights.push(cumulative);
294        }
295
296        // Create distributions
297        let distributions: Result<Vec<_>, _> = config
298            .components
299            .iter()
300            .map(|c| {
301                Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
302            })
303            .collect();
304
305        Ok(Self {
306            rng: ChaCha8Rng::seed_from_u64(seed),
307            config,
308            cumulative_weights,
309            distributions: distributions?,
310        })
311    }
312
313    /// Select a component using binary search on cumulative weights.
314    fn select_component(&mut self) -> usize {
315        let p: f64 = self.rng.random();
316        match self.cumulative_weights.binary_search_by(|w| {
317            w.partial_cmp(&p).unwrap_or_else(|| {
318                tracing::debug!("NaN detected in mixture weight comparison");
319                std::cmp::Ordering::Less
320            })
321        }) {
322            Ok(i) => i,
323            Err(i) => i.min(self.distributions.len() - 1),
324        }
325    }
326
327    /// Sample a value from the mixture.
328    pub fn sample(&mut self) -> f64 {
329        let component_idx = self.select_component();
330        let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332        // Apply constraints
333        if !self.config.allow_negative {
334            value = value.abs();
335        }
336        if let Some(min) = self.config.min_value {
337            value = value.max(min);
338        }
339        if let Some(max) = self.config.max_value {
340            value = value.min(max);
341        }
342
343        value
344    }
345
346    /// Sample a value with component information.
347    pub fn sample_with_component(&mut self) -> SampleWithComponent {
348        let component_idx = self.select_component();
349        let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351        // Apply constraints
352        if !self.config.allow_negative {
353            value = value.abs();
354        }
355        if let Some(min) = self.config.min_value {
356            value = value.max(min);
357        }
358        if let Some(max) = self.config.max_value {
359            value = value.min(max);
360        }
361
362        SampleWithComponent {
363            value,
364            component_index: component_idx,
365            component_label: self.config.components[component_idx].label.clone(),
366        }
367    }
368
369    /// Sample multiple values.
370    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371        (0..n).map(|_| self.sample()).collect()
372    }
373
374    /// Reset the sampler with a new seed.
375    pub fn reset(&mut self, seed: u64) {
376        self.rng = ChaCha8Rng::seed_from_u64(seed);
377    }
378
379    /// Get the configuration.
380    pub fn config(&self) -> &GaussianMixtureConfig {
381        &self.config
382    }
383
384    /// v4.1.6+: inverse CDF (quantile) for the Gaussian mixture via
385    /// bisection. Result clamped to `[min_value, max_value]`.
386    pub fn ppf(&self, u: f64) -> f64 {
387        let u = u.clamp(1e-9, 1.0 - 1e-9);
388        let min = self.config.min_value.unwrap_or(-1e15);
389        let max = self.config.max_value.unwrap_or(1e15);
390        let (mut lo, mut hi) = (min, max);
391        for _ in 0..64 {
392            let mid = (lo + hi) / 2.0;
393            let f_mid = mixture_gaussian_cdf(&self.config.components, mid);
394            if f_mid < u {
395                lo = mid;
396            } else {
397                hi = mid;
398            }
399            if hi - lo < 1e-6 * mid.abs().max(1.0) {
400                break;
401            }
402        }
403        ((lo + hi) / 2.0).clamp(min, max)
404    }
405}
406
407/// CDF of a Gaussian mixture at `x`.
408fn mixture_gaussian_cdf(components: &[GaussianComponent], x: f64) -> f64 {
409    components
410        .iter()
411        .map(|c| c.weight * standard_normal_cdf_gauss((x - c.mu) / c.sigma))
412        .sum()
413}
414
415fn standard_normal_cdf_gauss(x: f64) -> f64 {
416    0.5 * (1.0 + erf_gauss(x / std::f64::consts::SQRT_2))
417}
418
419fn erf_gauss(x: f64) -> f64 {
420    let sign = if x < 0.0 { -1.0 } else { 1.0 };
421    let x = x.abs();
422    let t = 1.0 / (1.0 + 0.3275911 * x);
423    let y = 1.0
424        - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736) * t
425            + 0.254829592)
426            * t
427            * (-x * x).exp();
428    sign * y
429}
430
431/// Log-Normal Mixture Model sampler for positive-only distributions.
432#[derive(Clone)]
433pub struct LogNormalMixtureSampler {
434    rng: ChaCha8Rng,
435    config: LogNormalMixtureConfig,
436    /// Pre-computed cumulative weights for O(log n) component selection
437    cumulative_weights: Vec<f64>,
438    /// Log-normal distributions for each component
439    distributions: Vec<LogNormal<f64>>,
440    /// Decimal multiplier for rounding
441    decimal_multiplier: f64,
442}
443
444impl LogNormalMixtureSampler {
445    /// Create a new Log-Normal mixture sampler.
446    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
447        config.validate()?;
448
449        // Pre-compute cumulative weights
450        let mut cumulative_weights = Vec::with_capacity(config.components.len());
451        let mut cumulative = 0.0;
452        for component in &config.components {
453            cumulative += component.weight;
454            cumulative_weights.push(cumulative);
455        }
456
457        // Create distributions
458        let distributions: Result<Vec<_>, _> = config
459            .components
460            .iter()
461            .map(|c| {
462                LogNormal::new(c.mu, c.sigma)
463                    .map_err(|e| format!("Invalid log-normal distribution: {e}"))
464            })
465            .collect();
466
467        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
468
469        Ok(Self {
470            rng: ChaCha8Rng::seed_from_u64(seed),
471            config,
472            cumulative_weights,
473            distributions: distributions?,
474            decimal_multiplier,
475        })
476    }
477
478    /// Select a component using binary search on cumulative weights.
479    fn select_component(&mut self) -> usize {
480        let p: f64 = self.rng.random();
481        match self.cumulative_weights.binary_search_by(|w| {
482            w.partial_cmp(&p).unwrap_or_else(|| {
483                tracing::debug!("NaN detected in mixture weight comparison");
484                std::cmp::Ordering::Less
485            })
486        }) {
487            Ok(i) => i,
488            Err(i) => i.min(self.distributions.len() - 1),
489        }
490    }
491
492    /// Sample a value from the mixture.
493    pub fn sample(&mut self) -> f64 {
494        let component_idx = self.select_component();
495        let mut value = self.distributions[component_idx].sample(&mut self.rng);
496
497        // Apply constraints
498        value = value.max(self.config.min_value);
499        if let Some(max) = self.config.max_value {
500            value = value.min(max);
501        }
502
503        // Round to decimal places
504        (value * self.decimal_multiplier).round() / self.decimal_multiplier
505    }
506
507    /// Sample a value as Decimal.
508    pub fn sample_decimal(&mut self) -> Decimal {
509        let value = self.sample();
510        Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
511    }
512
513    /// Sample a value with component information.
514    pub fn sample_with_component(&mut self) -> SampleWithComponent {
515        let component_idx = self.select_component();
516        let mut value = self.distributions[component_idx].sample(&mut self.rng);
517
518        // Apply constraints
519        value = value.max(self.config.min_value);
520        if let Some(max) = self.config.max_value {
521            value = value.min(max);
522        }
523
524        // Round to decimal places
525        value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
526
527        SampleWithComponent {
528            value,
529            component_index: component_idx,
530            component_label: self.config.components[component_idx].label.clone(),
531        }
532    }
533
534    /// Sample multiple values.
535    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
536        (0..n).map(|_| self.sample()).collect()
537    }
538
539    /// Sample multiple values as Decimals.
540    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
541        (0..n).map(|_| self.sample_decimal()).collect()
542    }
543
544    /// Reset the sampler with a new seed.
545    pub fn reset(&mut self, seed: u64) {
546        self.rng = ChaCha8Rng::seed_from_u64(seed);
547    }
548
549    /// Get the configuration.
550    pub fn config(&self) -> &LogNormalMixtureConfig {
551        &self.config
552    }
553
554    /// Get the expected value of the mixture.
555    pub fn expected_value(&self) -> f64 {
556        self.config
557            .components
558            .iter()
559            .map(|c| c.weight * c.expected_value())
560            .sum()
561    }
562
563    /// v4.1.6+: inverse CDF (quantile) for the mixture, computed via
564    /// bisection.  Given `u ∈ (0, 1)` returns the value `x` such that
565    /// `F(x) = u`, where `F` is the mixture CDF (weighted sum of the
566    /// component log-normal CDFs). Result is clamped to
567    /// `[min_value, max_value]` and rounded to `decimal_places`.
568    pub fn ppf(&self, u: f64) -> f64 {
569        let u = u.clamp(1e-9, 1.0 - 1e-9);
570        let max = self.config.max_value.unwrap_or(1e15);
571        let min = self.config.min_value.max(1e-9);
572        // Bisection — the mixture CDF is monotone.
573        let (mut lo, mut hi) = (min, max);
574        for _ in 0..64 {
575            let mid = (lo + hi) / 2.0;
576            let f_mid = mixture_log_normal_cdf(&self.config.components, mid);
577            if f_mid < u {
578                lo = mid;
579            } else {
580                hi = mid;
581            }
582            if hi - lo < 1e-6 * mid.abs().max(1.0) {
583                break;
584            }
585        }
586        let value = ((lo + hi) / 2.0).clamp(min, max);
587        (value * self.decimal_multiplier).round() / self.decimal_multiplier
588    }
589
590    /// v4.1.6+: inverse CDF as Decimal.
591    pub fn ppf_decimal(&self, u: f64) -> Decimal {
592        Decimal::from_f64_retain(self.ppf(u)).unwrap_or(Decimal::ONE)
593    }
594}
595
596/// CDF of a log-normal mixture at `x` (standard normal CDF applied to
597/// `(ln(x) - μ) / σ` for each component, weighted by component weights).
598fn mixture_log_normal_cdf(components: &[LogNormalComponent], x: f64) -> f64 {
599    if x <= 0.0 {
600        return 0.0;
601    }
602    let log_x = x.ln();
603    components
604        .iter()
605        .map(|c| c.weight * standard_normal_cdf((log_x - c.mu) / c.sigma))
606        .sum()
607}
608
609/// Standard normal CDF via an erf approximation. Matches
610/// `crate::distributions::validation::erf` to 7 digits.
611fn standard_normal_cdf(x: f64) -> f64 {
612    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
613}
614
/// Error function `erf(x)` using the Abramowitz & Stegun 7.1.26 rational
/// approximation (maximum absolute error about 1.5e-7).
///
/// The approximation is evaluated for `|x|` and the sign is restored at the end,
/// since `erf` is an odd function.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients a5..a1, highest degree first, for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS.iter().fold(0.0, |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 {
        -y
    } else {
        y
    }
}
626
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two equally weighted Gaussian components centred at 0 and 10 (σ = 1).
    fn two_peaks() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        let accepted = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(accepted.validate().is_ok());

        // Weights total 0.6 rather than 1.0.
        let short_weights = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(short_weights.validate().is_err());

        // The standard deviation must be strictly positive.
        let bad_sigma = GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(bad_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, two_peaks()).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // 5.0 sits halfway between the two means, so each side should hold ~50 %.
        let below = draws.iter().filter(|&&v| v < 5.0).count();
        let above = draws.iter().filter(|&&v| v >= 5.0).count();
        assert!((351..650).contains(&below));
        assert!((351..650).contains(&above));
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let mut first = GaussianMixtureSampler::new(42, two_peaks()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, two_peaks()).unwrap();

        // The same seed and configuration must reproduce the same sequence.
        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let accepted = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(accepted.validate().is_ok());

        // Weights total 0.4 rather than 1.0.
        let short_weights = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(short_weights.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // Strictly positive, and never below the configured 0.01 floor.
        assert!(draws.iter().copied().all(|v| v > 0.0));
        assert!(draws.iter().copied().all(|v| v >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small, mut large) = (0_usize, 0_usize);
        for _ in 0..1000 {
            match sampler.sample_with_component().component_label.as_deref() {
                Some("small") => small += 1,
                Some("large") => large += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Equal weights: each label should appear roughly half the time.
        assert!((401..600).contains(&small));
        assert!((401..600).contains(&large));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let mut first =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();
        let mut second =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        // The same seed and configuration must reproduce the same sequence.
        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        // E[X] = exp(μ + σ²/2) = exp(7.5) ≈ 1808.04
        assert!((sampler.expected_value() - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labelled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labelled.label.as_deref(), Some("test_label"));

        let unlabelled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert!(unlabelled.label.is_none());
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // μ = 10 puts the median near 22,000, so the 1000 cap is hit constantly.
        assert!(sampler.sample_n(1000).iter().all(|&v| v <= 1000.0));
    }
}