1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16 pub weight: f64,
18 pub mu: f64,
20 pub sigma: f64,
22 #[serde(default)]
24 pub label: Option<String>,
25}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52 pub components: Vec<GaussianComponent>,
54 #[serde(default = "default_true")]
56 pub allow_negative: bool,
57 #[serde(default)]
59 pub min_value: Option<f64>,
60 #[serde(default)]
62 pub max_value: Option<f64>,
63}
64
/// Serde default for `GaussianMixtureConfig::allow_negative`: negatives are allowed.
fn default_true() -> bool {
    let allow = true;
    allow
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {weight_sum}"
99 ));
100 }
101
102 for (i, component) in self.components.iter().enumerate() {
103 if component.weight < 0.0 || component.weight > 1.0 {
104 return Err(format!(
105 "Component {} weight must be between 0.0 and 1.0, got {}",
106 i, component.weight
107 ));
108 }
109 if component.sigma <= 0.0 {
110 return Err(format!(
111 "Component {} sigma must be positive, got {}",
112 i, component.sigma
113 ));
114 }
115 }
116
117 Ok(())
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124 pub weight: f64,
126 pub mu: f64,
128 pub sigma: f64,
130 #[serde(default)]
132 pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138 Self {
139 weight,
140 mu,
141 sigma,
142 label: None,
143 }
144 }
145
146 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148 Self {
149 weight,
150 mu,
151 sigma,
152 label: Some(label.into()),
153 }
154 }
155
156 pub fn expected_value(&self) -> f64 {
158 (self.mu + self.sigma.powi(2) / 2.0).exp()
159 }
160
161 pub fn median(&self) -> f64 {
163 self.mu.exp()
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170 pub components: Vec<LogNormalComponent>,
172 #[serde(default = "default_min_value")]
174 pub min_value: f64,
175 #[serde(default)]
177 pub max_value: Option<f64>,
178 #[serde(default = "default_decimal_places")]
180 pub decimal_places: u8,
181}
182
/// Serde default for `LogNormalMixtureConfig::min_value`.
fn default_min_value() -> f64 {
    const DEFAULT_MIN_VALUE: f64 = 0.01;
    DEFAULT_MIN_VALUE
}
186
/// Serde default for `LogNormalMixtureConfig::decimal_places`.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
190
191impl Default for LogNormalMixtureConfig {
192 fn default() -> Self {
193 Self {
194 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195 min_value: 0.01,
196 max_value: None,
197 decimal_places: 2,
198 }
199 }
200}
201
202impl LogNormalMixtureConfig {
203 pub fn new(components: Vec<LogNormalComponent>) -> Self {
205 Self {
206 components,
207 ..Default::default()
208 }
209 }
210
211 pub fn typical_transactions() -> Self {
213 Self {
214 components: vec![
215 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218 ],
219 min_value: 0.01,
220 max_value: Some(100_000_000.0),
221 decimal_places: 2,
222 }
223 }
224
225 pub fn validate(&self) -> Result<(), String> {
227 if self.components.is_empty() {
228 return Err("At least one component is required".to_string());
229 }
230
231 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232 if (weight_sum - 1.0).abs() > 0.01 {
233 return Err(format!(
234 "Component weights must sum to 1.0, got {weight_sum}"
235 ));
236 }
237
238 for (i, component) in self.components.iter().enumerate() {
239 if component.weight < 0.0 || component.weight > 1.0 {
240 return Err(format!(
241 "Component {} weight must be between 0.0 and 1.0, got {}",
242 i, component.weight
243 ));
244 }
245 if component.sigma <= 0.0 {
246 return Err(format!(
247 "Component {} sigma must be positive, got {}",
248 i, component.sigma
249 ));
250 }
251 }
252
253 if self.min_value < 0.0 {
254 return Err("min_value must be non-negative".to_string());
255 }
256
257 Ok(())
258 }
259}
260
/// A drawn value together with the mixture component it was drawn from.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value, after the sampler's clamping/rounding.
    pub value: f64,
    /// Index into the config's `components` of the component that produced `value`.
    pub component_index: usize,
    /// That component's label, if it has one.
    pub component_label: Option<String>,
}
271
272#[derive(Clone)]
274pub struct GaussianMixtureSampler {
275 rng: ChaCha8Rng,
276 config: GaussianMixtureConfig,
277 cumulative_weights: Vec<f64>,
279 distributions: Vec<Normal<f64>>,
281}
282
283impl GaussianMixtureSampler {
284 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
286 config.validate()?;
287
288 let mut cumulative_weights = Vec::with_capacity(config.components.len());
290 let mut cumulative = 0.0;
291 for component in &config.components {
292 cumulative += component.weight;
293 cumulative_weights.push(cumulative);
294 }
295
296 let distributions: Result<Vec<_>, _> = config
298 .components
299 .iter()
300 .map(|c| {
301 Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
302 })
303 .collect();
304
305 Ok(Self {
306 rng: ChaCha8Rng::seed_from_u64(seed),
307 config,
308 cumulative_weights,
309 distributions: distributions?,
310 })
311 }
312
313 fn select_component(&mut self) -> usize {
315 let p: f64 = self.rng.random();
316 match self.cumulative_weights.binary_search_by(|w| {
317 w.partial_cmp(&p).unwrap_or_else(|| {
318 tracing::debug!("NaN detected in mixture weight comparison");
319 std::cmp::Ordering::Less
320 })
321 }) {
322 Ok(i) => i,
323 Err(i) => i.min(self.distributions.len() - 1),
324 }
325 }
326
327 pub fn sample(&mut self) -> f64 {
329 let component_idx = self.select_component();
330 let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332 if !self.config.allow_negative {
334 value = value.abs();
335 }
336 if let Some(min) = self.config.min_value {
337 value = value.max(min);
338 }
339 if let Some(max) = self.config.max_value {
340 value = value.min(max);
341 }
342
343 value
344 }
345
346 pub fn sample_with_component(&mut self) -> SampleWithComponent {
348 let component_idx = self.select_component();
349 let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351 if !self.config.allow_negative {
353 value = value.abs();
354 }
355 if let Some(min) = self.config.min_value {
356 value = value.max(min);
357 }
358 if let Some(max) = self.config.max_value {
359 value = value.min(max);
360 }
361
362 SampleWithComponent {
363 value,
364 component_index: component_idx,
365 component_label: self.config.components[component_idx].label.clone(),
366 }
367 }
368
369 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371 (0..n).map(|_| self.sample()).collect()
372 }
373
374 pub fn reset(&mut self, seed: u64) {
376 self.rng = ChaCha8Rng::seed_from_u64(seed);
377 }
378
379 pub fn config(&self) -> &GaussianMixtureConfig {
381 &self.config
382 }
383
384 pub fn ppf(&self, u: f64) -> f64 {
387 let u = u.clamp(1e-9, 1.0 - 1e-9);
388 let min = self.config.min_value.unwrap_or(-1e15);
389 let max = self.config.max_value.unwrap_or(1e15);
390 let (mut lo, mut hi) = (min, max);
391 for _ in 0..64 {
392 let mid = (lo + hi) / 2.0;
393 let f_mid = mixture_gaussian_cdf(&self.config.components, mid);
394 if f_mid < u {
395 lo = mid;
396 } else {
397 hi = mid;
398 }
399 if hi - lo < 1e-6 * mid.abs().max(1.0) {
400 break;
401 }
402 }
403 ((lo + hi) / 2.0).clamp(min, max)
404 }
405}
406
407fn mixture_gaussian_cdf(components: &[GaussianComponent], x: f64) -> f64 {
409 components
410 .iter()
411 .map(|c| c.weight * standard_normal_cdf_gauss((x - c.mu) / c.sigma))
412 .sum()
413}
414
415fn standard_normal_cdf_gauss(x: f64) -> f64 {
416 0.5 * (1.0 + erf_gauss(x / std::f64::consts::SQRT_2))
417}
418
419fn erf_gauss(x: f64) -> f64 {
420 let sign = if x < 0.0 { -1.0 } else { 1.0 };
421 let x = x.abs();
422 let t = 1.0 / (1.0 + 0.3275911 * x);
423 let y = 1.0
424 - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736) * t
425 + 0.254829592)
426 * t
427 * (-x * x).exp();
428 sign * y
429}
430
431#[derive(Clone)]
433pub struct LogNormalMixtureSampler {
434 rng: ChaCha8Rng,
435 config: LogNormalMixtureConfig,
436 cumulative_weights: Vec<f64>,
438 distributions: Vec<LogNormal<f64>>,
440 decimal_multiplier: f64,
442}
443
444impl LogNormalMixtureSampler {
445 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
447 config.validate()?;
448
449 let mut cumulative_weights = Vec::with_capacity(config.components.len());
451 let mut cumulative = 0.0;
452 for component in &config.components {
453 cumulative += component.weight;
454 cumulative_weights.push(cumulative);
455 }
456
457 let distributions: Result<Vec<_>, _> = config
459 .components
460 .iter()
461 .map(|c| {
462 LogNormal::new(c.mu, c.sigma)
463 .map_err(|e| format!("Invalid log-normal distribution: {e}"))
464 })
465 .collect();
466
467 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
468
469 Ok(Self {
470 rng: ChaCha8Rng::seed_from_u64(seed),
471 config,
472 cumulative_weights,
473 distributions: distributions?,
474 decimal_multiplier,
475 })
476 }
477
478 fn select_component(&mut self) -> usize {
480 let p: f64 = self.rng.random();
481 match self.cumulative_weights.binary_search_by(|w| {
482 w.partial_cmp(&p).unwrap_or_else(|| {
483 tracing::debug!("NaN detected in mixture weight comparison");
484 std::cmp::Ordering::Less
485 })
486 }) {
487 Ok(i) => i,
488 Err(i) => i.min(self.distributions.len() - 1),
489 }
490 }
491
492 pub fn sample(&mut self) -> f64 {
494 let component_idx = self.select_component();
495 let mut value = self.distributions[component_idx].sample(&mut self.rng);
496
497 value = value.max(self.config.min_value);
499 if let Some(max) = self.config.max_value {
500 value = value.min(max);
501 }
502
503 (value * self.decimal_multiplier).round() / self.decimal_multiplier
505 }
506
507 pub fn sample_decimal(&mut self) -> Decimal {
509 let value = self.sample();
510 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
511 }
512
513 pub fn sample_with_component(&mut self) -> SampleWithComponent {
515 let component_idx = self.select_component();
516 let mut value = self.distributions[component_idx].sample(&mut self.rng);
517
518 value = value.max(self.config.min_value);
520 if let Some(max) = self.config.max_value {
521 value = value.min(max);
522 }
523
524 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
526
527 SampleWithComponent {
528 value,
529 component_index: component_idx,
530 component_label: self.config.components[component_idx].label.clone(),
531 }
532 }
533
534 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
536 (0..n).map(|_| self.sample()).collect()
537 }
538
539 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
541 (0..n).map(|_| self.sample_decimal()).collect()
542 }
543
544 pub fn reset(&mut self, seed: u64) {
546 self.rng = ChaCha8Rng::seed_from_u64(seed);
547 }
548
549 pub fn config(&self) -> &LogNormalMixtureConfig {
551 &self.config
552 }
553
554 pub fn expected_value(&self) -> f64 {
556 self.config
557 .components
558 .iter()
559 .map(|c| c.weight * c.expected_value())
560 .sum()
561 }
562
563 pub fn ppf(&self, u: f64) -> f64 {
569 let u = u.clamp(1e-9, 1.0 - 1e-9);
570 let max = self.config.max_value.unwrap_or(1e15);
571 let min = self.config.min_value.max(1e-9);
572 let (mut lo, mut hi) = (min, max);
574 for _ in 0..64 {
575 let mid = (lo + hi) / 2.0;
576 let f_mid = mixture_log_normal_cdf(&self.config.components, mid);
577 if f_mid < u {
578 lo = mid;
579 } else {
580 hi = mid;
581 }
582 if hi - lo < 1e-6 * mid.abs().max(1.0) {
583 break;
584 }
585 }
586 let value = ((lo + hi) / 2.0).clamp(min, max);
587 (value * self.decimal_multiplier).round() / self.decimal_multiplier
588 }
589
590 pub fn ppf_decimal(&self, u: f64) -> Decimal {
592 Decimal::from_f64_retain(self.ppf(u)).unwrap_or(Decimal::ONE)
593 }
594}
595
596fn mixture_log_normal_cdf(components: &[LogNormalComponent], x: f64) -> f64 {
599 if x <= 0.0 {
600 return 0.0;
601 }
602 let log_x = x.ln();
603 components
604 .iter()
605 .map(|c| c.weight * standard_normal_cdf((log_x - c.mu) / c.sigma))
606 .sum()
607}
608
609fn standard_normal_cdf(x: f64) -> f64 {
612 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
613}
614
/// Error function via Abramowitz & Stegun formula 7.1.26.
///
/// The maximum absolute error is about 1.5e-7. The rational approximation is evaluated on
/// `|x|`, and the sign is restored using `erf(-x) = -erf(x)`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Coefficients a1..a5, lowest order first.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner evaluation of a5*t^4 + a4*t^3 + a3*t^2 + a2*t + a1.
    let poly = A.iter().rev().fold(0.0, |acc, &a| acc * t + a);
    let value = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 {
        -value
    } else {
        value
    }
}
626
#[cfg(test)]
// Samplers here are built from known-good configs, so an `unwrap` failure is a real bug.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `true` when `lo < n < hi`.
    fn strictly_between(n: usize, lo: usize, hi: usize) -> bool {
        lo < n && n < hi
    }

    /// Two unit-variance normals at 0 and 10 with equal weight.
    fn separated_gaussians() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, separated_gaussians()).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        let below_midpoint = samples.iter().filter(|&&x| x < 5.0).count();
        let above_midpoint = samples.iter().filter(|&&x| x >= 5.0).count();

        assert!(strictly_between(below_midpoint, 350, 650));
        assert!(strictly_between(above_midpoint, 350, 650));
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = separated_gaussians();
        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let valid = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(valid.validate().is_ok());

        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        assert!(samples.iter().all(|&x| x > 0.0));
        assert!(samples.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small, mut large) = (0, 0);
        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            let counter = match drawn.component_label.as_deref() {
                Some("small") => &mut small,
                Some("large") => &mut large,
                _ => panic!("Unexpected label"),
            };
            *counter += 1;
        }

        assert!(strictly_between(small, 400, 600));
        assert!(strictly_between(large, 400, 600));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        // exp(7 + 1/2) = exp(7.5) ~= 1808.04
        let mean = sampler.expected_value();
        assert!((mean - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labelled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labelled.label, Some("test_label".to_string()));

        let unlabelled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(unlabelled.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let samples = sampler.sample_n(1000);

        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}