1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
/// One normal component of a [`GaussianMixtureConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianComponent {
    /// Mixing weight; validation requires it to lie in `[0.0, 1.0]` and all
    /// weights to sum to 1.0 (within ±0.01).
    pub weight: f64,
    /// Mean, passed as the first argument of `rand_distr::Normal::new`.
    pub mu: f64,
    /// Standard deviation, passed as the second argument of
    /// `rand_distr::Normal::new`; validation requires it to be positive.
    pub sigma: f64,
    /// Optional name, copied into `SampleWithComponent::component_label`.
    #[serde(default)]
    pub label: Option<String>,
}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
/// Configuration for [`GaussianMixtureSampler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianMixtureConfig {
    /// Mixture components; at least one is required by `validate`.
    pub components: Vec<GaussianComponent>,
    /// When `false`, each raw draw is replaced by its absolute value before
    /// the min/max bounds are applied. Defaults to `true` when deserialized.
    #[serde(default = "default_true")]
    pub allow_negative: bool,
    /// Optional lower bound; draws below it are raised to it.
    #[serde(default)]
    pub min_value: Option<f64>,
    /// Optional upper bound; draws above it are lowered to it.
    #[serde(default)]
    pub max_value: Option<f64>,
}
64
/// Serde default for `GaussianMixtureConfig::allow_negative`: negative
/// samples are permitted unless the configuration says otherwise.
fn default_true() -> bool {
    let allow = true;
    allow
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {weight_sum}"
99 ));
100 }
101
102 for (i, component) in self.components.iter().enumerate() {
103 if component.weight < 0.0 || component.weight > 1.0 {
104 return Err(format!(
105 "Component {} weight must be between 0.0 and 1.0, got {}",
106 i, component.weight
107 ));
108 }
109 if component.sigma <= 0.0 {
110 return Err(format!(
111 "Component {} sigma must be positive, got {}",
112 i, component.sigma
113 ));
114 }
115 }
116
117 Ok(())
118 }
119}
120
/// One log-normal component of a [`LogNormalMixtureConfig`].
///
/// `mu` and `sigma` parameterise the underlying normal of `ln(X)`: see
/// `expected_value` (`exp(mu + sigma²/2)`) and `median` (`exp(mu)`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalComponent {
    /// Mixing weight; validation requires it to lie in `[0.0, 1.0]` and all
    /// weights to sum to 1.0 (within ±0.01).
    pub weight: f64,
    /// Mean of `ln(X)`; passed as the first argument of `rand_distr::LogNormal::new`.
    pub mu: f64,
    /// Standard deviation of `ln(X)`; passed as the second argument of
    /// `rand_distr::LogNormal::new`. Validation requires it to be positive.
    pub sigma: f64,
    /// Optional name, copied into `SampleWithComponent::component_label`.
    #[serde(default)]
    pub label: Option<String>,
}
134
135impl LogNormalComponent {
136 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138 Self {
139 weight,
140 mu,
141 sigma,
142 label: None,
143 }
144 }
145
146 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148 Self {
149 weight,
150 mu,
151 sigma,
152 label: Some(label.into()),
153 }
154 }
155
156 pub fn expected_value(&self) -> f64 {
158 (self.mu + self.sigma.powi(2) / 2.0).exp()
159 }
160
161 pub fn median(&self) -> f64 {
163 self.mu.exp()
164 }
165}
166
/// Configuration for [`LogNormalMixtureSampler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalMixtureConfig {
    /// Mixture components; at least one is required by `validate`.
    pub components: Vec<LogNormalComponent>,
    /// Lower bound applied to every draw (draws below it are raised to it).
    /// Must be non-negative; defaults to 0.01 when deserialized.
    #[serde(default = "default_min_value")]
    pub min_value: f64,
    /// Optional upper bound; draws above it are lowered to it.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places samples are rounded to; defaults to 2.
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
182
/// Serde default for `LogNormalMixtureConfig::min_value`: one hundredth.
fn default_min_value() -> f64 {
    1e-2
}

/// Serde default for `LogNormalMixtureConfig::decimal_places`: two places.
fn default_decimal_places() -> u8 {
    2u8
}
190
191impl Default for LogNormalMixtureConfig {
192 fn default() -> Self {
193 Self {
194 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195 min_value: 0.01,
196 max_value: None,
197 decimal_places: 2,
198 }
199 }
200}
201
202impl LogNormalMixtureConfig {
203 pub fn new(components: Vec<LogNormalComponent>) -> Self {
205 Self {
206 components,
207 ..Default::default()
208 }
209 }
210
211 pub fn typical_transactions() -> Self {
213 Self {
214 components: vec![
215 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218 ],
219 min_value: 0.01,
220 max_value: Some(100_000_000.0),
221 decimal_places: 2,
222 }
223 }
224
225 pub fn validate(&self) -> Result<(), String> {
227 if self.components.is_empty() {
228 return Err("At least one component is required".to_string());
229 }
230
231 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232 if (weight_sum - 1.0).abs() > 0.01 {
233 return Err(format!(
234 "Component weights must sum to 1.0, got {weight_sum}"
235 ));
236 }
237
238 for (i, component) in self.components.iter().enumerate() {
239 if component.weight < 0.0 || component.weight > 1.0 {
240 return Err(format!(
241 "Component {} weight must be between 0.0 and 1.0, got {}",
242 i, component.weight
243 ));
244 }
245 if component.sigma <= 0.0 {
246 return Err(format!(
247 "Component {} sigma must be positive, got {}",
248 i, component.sigma
249 ));
250 }
251 }
252
253 if self.min_value < 0.0 {
254 return Err("min_value must be non-negative".to_string());
255 }
256
257 Ok(())
258 }
259}
260
/// A single draw together with the mixture component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value after the sampler's bounds were applied (and, for
    /// the log-normal sampler, after rounding to `decimal_places`).
    pub value: f64,
    /// Index into the config's `components` of the component drawn from.
    pub component_index: usize,
    /// Clone of that component's `label`, if any.
    pub component_label: Option<String>,
}
271
/// Seeded sampler for a [`GaussianMixtureConfig`].
///
/// Construct it with `GaussianMixtureSampler::new`, which validates the
/// config. Cloning copies the RNG state, so a clone continues with the same
/// sequence of draws as the original.
#[derive(Clone)]
pub struct GaussianMixtureSampler {
    /// Deterministic RNG seeded from a `u64`.
    rng: ChaCha8Rng,
    /// The configuration the sampler was built from.
    config: GaussianMixtureConfig,
    /// Running sums of the component weights, in component order; used to
    /// map a uniform draw onto a component index.
    cumulative_weights: Vec<f64>,
    /// One normal distribution per component, in `config.components` order.
    distributions: Vec<Normal<f64>>,
}
282
283impl GaussianMixtureSampler {
284 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
286 config.validate()?;
287
288 let mut cumulative_weights = Vec::with_capacity(config.components.len());
290 let mut cumulative = 0.0;
291 for component in &config.components {
292 cumulative += component.weight;
293 cumulative_weights.push(cumulative);
294 }
295
296 let distributions: Result<Vec<_>, _> = config
298 .components
299 .iter()
300 .map(|c| {
301 Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
302 })
303 .collect();
304
305 Ok(Self {
306 rng: ChaCha8Rng::seed_from_u64(seed),
307 config,
308 cumulative_weights,
309 distributions: distributions?,
310 })
311 }
312
313 fn select_component(&mut self) -> usize {
315 let p: f64 = self.rng.random();
316 match self.cumulative_weights.binary_search_by(|w| {
317 w.partial_cmp(&p).unwrap_or_else(|| {
318 tracing::debug!("NaN detected in mixture weight comparison");
319 std::cmp::Ordering::Less
320 })
321 }) {
322 Ok(i) => i,
323 Err(i) => i.min(self.distributions.len() - 1),
324 }
325 }
326
327 pub fn sample(&mut self) -> f64 {
329 let component_idx = self.select_component();
330 let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332 if !self.config.allow_negative {
334 value = value.abs();
335 }
336 if let Some(min) = self.config.min_value {
337 value = value.max(min);
338 }
339 if let Some(max) = self.config.max_value {
340 value = value.min(max);
341 }
342
343 value
344 }
345
346 pub fn sample_with_component(&mut self) -> SampleWithComponent {
348 let component_idx = self.select_component();
349 let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351 if !self.config.allow_negative {
353 value = value.abs();
354 }
355 if let Some(min) = self.config.min_value {
356 value = value.max(min);
357 }
358 if let Some(max) = self.config.max_value {
359 value = value.min(max);
360 }
361
362 SampleWithComponent {
363 value,
364 component_index: component_idx,
365 component_label: self.config.components[component_idx].label.clone(),
366 }
367 }
368
369 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371 (0..n).map(|_| self.sample()).collect()
372 }
373
374 pub fn reset(&mut self, seed: u64) {
376 self.rng = ChaCha8Rng::seed_from_u64(seed);
377 }
378
379 pub fn config(&self) -> &GaussianMixtureConfig {
381 &self.config
382 }
383
384 pub fn ppf(&self, u: f64) -> f64 {
387 let u = u.clamp(1e-9, 1.0 - 1e-9);
388 let min = self.config.min_value.unwrap_or(-1e15);
389 let max = self.config.max_value.unwrap_or(1e15);
390 let (mut lo, mut hi) = (min, max);
391 for _ in 0..64 {
392 let mid = (lo + hi) / 2.0;
393 let f_mid = mixture_gaussian_cdf(&self.config.components, mid);
394 if f_mid < u {
395 lo = mid;
396 } else {
397 hi = mid;
398 }
399 if hi - lo < 1e-6 * mid.abs().max(1.0) {
400 break;
401 }
402 }
403 ((lo + hi) / 2.0).clamp(min, max)
404 }
405}
406
407fn mixture_gaussian_cdf(components: &[GaussianComponent], x: f64) -> f64 {
409 components
410 .iter()
411 .map(|c| c.weight * standard_normal_cdf_gauss((x - c.mu) / c.sigma))
412 .sum()
413}
414
415fn standard_normal_cdf_gauss(x: f64) -> f64 {
416 0.5 * (1.0 + erf_gauss(x / std::f64::consts::SQRT_2))
417}
418
/// Error function via Abramowitz & Stegun formula 7.1.26
/// (absolute error at most about 1.5e-7).
///
/// The approximation is evaluated on `|x|` and mirrored for negative inputs,
/// since erf is odd. Note that `-0.0` takes the positive branch.
fn erf_gauss(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients a5..a1, highest degree first, for Horner's rule.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS.iter().fold(0.0, |acc, &a| acc * t + a);
    let value = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -value
    } else {
        value
    }
}
430
/// Seeded sampler for a [`LogNormalMixtureConfig`].
///
/// Construct it with `LogNormalMixtureSampler::new`, which validates the
/// config. Cloning copies the RNG state, so a clone continues with the same
/// sequence of draws as the original.
#[derive(Clone)]
pub struct LogNormalMixtureSampler {
    /// Deterministic RNG seeded from a `u64`.
    rng: ChaCha8Rng,
    /// The configuration the sampler was built from.
    config: LogNormalMixtureConfig,
    /// Running sums of the component weights, in component order; used to
    /// map a uniform draw onto a component index.
    cumulative_weights: Vec<f64>,
    /// One log-normal distribution per component, in `config.components` order.
    distributions: Vec<LogNormal<f64>>,
    /// `10^decimal_places`, precomputed for rounding samples.
    decimal_multiplier: f64,
}
443
444impl LogNormalMixtureSampler {
445 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
447 config.validate()?;
448
449 let mut cumulative_weights = Vec::with_capacity(config.components.len());
451 let mut cumulative = 0.0;
452 for component in &config.components {
453 cumulative += component.weight;
454 cumulative_weights.push(cumulative);
455 }
456
457 let distributions: Result<Vec<_>, _> = config
459 .components
460 .iter()
461 .map(|c| {
462 LogNormal::new(c.mu, c.sigma)
463 .map_err(|e| format!("Invalid log-normal distribution: {e}"))
464 })
465 .collect();
466
467 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
468
469 Ok(Self {
470 rng: ChaCha8Rng::seed_from_u64(seed),
471 config,
472 cumulative_weights,
473 distributions: distributions?,
474 decimal_multiplier,
475 })
476 }
477
478 fn select_component(&mut self) -> usize {
480 let p: f64 = self.rng.random();
481 match self.cumulative_weights.binary_search_by(|w| {
482 w.partial_cmp(&p).unwrap_or_else(|| {
483 tracing::debug!("NaN detected in mixture weight comparison");
484 std::cmp::Ordering::Less
485 })
486 }) {
487 Ok(i) => i,
488 Err(i) => i.min(self.distributions.len() - 1),
489 }
490 }
491
492 pub fn sample(&mut self) -> f64 {
494 let component_idx = self.select_component();
495 let mut value = self.distributions[component_idx].sample(&mut self.rng);
496
497 value = value.max(self.config.min_value);
499 if let Some(max) = self.config.max_value {
500 value = value.min(max);
501 }
502
503 (value * self.decimal_multiplier).round() / self.decimal_multiplier
505 }
506
507 pub fn sample_decimal(&mut self) -> Decimal {
509 let value = self.sample();
510 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
511 }
512
513 pub fn sample_with_component(&mut self) -> SampleWithComponent {
515 let component_idx = self.select_component();
516 let mut value = self.distributions[component_idx].sample(&mut self.rng);
517
518 value = value.max(self.config.min_value);
520 if let Some(max) = self.config.max_value {
521 value = value.min(max);
522 }
523
524 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
526
527 SampleWithComponent {
528 value,
529 component_index: component_idx,
530 component_label: self.config.components[component_idx].label.clone(),
531 }
532 }
533
534 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
536 (0..n).map(|_| self.sample()).collect()
537 }
538
539 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
541 (0..n).map(|_| self.sample_decimal()).collect()
542 }
543
544 pub fn reset(&mut self, seed: u64) {
546 self.rng = ChaCha8Rng::seed_from_u64(seed);
547 }
548
549 pub fn config(&self) -> &LogNormalMixtureConfig {
551 &self.config
552 }
553
554 pub fn expected_value(&self) -> f64 {
556 self.config
557 .components
558 .iter()
559 .map(|c| c.weight * c.expected_value())
560 .sum()
561 }
562
563 pub fn ppf(&self, u: f64) -> f64 {
569 let u = u.clamp(1e-9, 1.0 - 1e-9);
570 let max = self.config.max_value.unwrap_or(1e15);
571 let min = self.config.min_value.max(1e-9);
572 let (mut lo, mut hi) = (min, max);
574 for _ in 0..64 {
575 let mid = (lo + hi) / 2.0;
576 let f_mid = mixture_log_normal_cdf(&self.config.components, mid);
577 if f_mid < u {
578 lo = mid;
579 } else {
580 hi = mid;
581 }
582 if hi - lo < 1e-6 * mid.abs().max(1.0) {
583 break;
584 }
585 }
586 let value = ((lo + hi) / 2.0).clamp(min, max);
587 (value * self.decimal_multiplier).round() / self.decimal_multiplier
588 }
589
590 pub fn ppf_decimal(&self, u: f64) -> Decimal {
592 Decimal::from_f64_retain(self.ppf(u)).unwrap_or(Decimal::ONE)
593 }
594}
595
596fn mixture_log_normal_cdf(components: &[LogNormalComponent], x: f64) -> f64 {
599 if x <= 0.0 {
600 return 0.0;
601 }
602 let log_x = x.ln();
603 components
604 .iter()
605 .map(|c| c.weight * standard_normal_cdf((log_x - c.mu) / c.sigma))
606 .sum()
607}
608
609fn standard_normal_cdf(x: f64) -> f64 {
612 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
613}
614
/// Error function via Abramowitz & Stegun formula 7.1.26
/// (absolute error at most about 1.5e-7).
///
/// Computed for `|x|` and given the sign of `x` afterwards, since erf is odd.
/// Note that `-0.0` takes the positive branch.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner form of t·(a1 + a2·t + a3·t² + a4·t³ + a5·t⁴).
    let poly_t = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;
    let magnitude = 1.0 - poly_t * (-ax * ax).exp();

    match x < 0.0 {
        true => -magnitude,
        false => magnitude,
    }
}
626
/// Unit tests for the mixture configs and samplers. Sampling tests use fixed
/// seeds (42) with 1000 draws and loose statistical bounds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid weights pass; weights summing to 0.6 and a negative sigma fail.
    #[test]
    fn test_gaussian_mixture_validation() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(config.validate().is_ok());

        // Weights sum to 0.6, outside the ±0.01 tolerance.
        let invalid_config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(invalid_config.validate().is_err());

        // Negative sigma.
        let invalid_config =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(invalid_config.validate().is_err());
    }

    /// Two equally weighted components 10 sigma apart: the split at 5.0
    /// should see roughly half of the draws on each side.
    #[test]
    fn test_gaussian_mixture_sampling() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);
        let mut sampler = GaussianMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        let low_count = samples.iter().filter(|&&x| x < 5.0).count();
        let high_count = samples.iter().filter(|&&x| x >= 5.0).count();

        // Each side should hold 35%-65% of the draws.
        assert!(low_count > 350 && low_count < 650);
        assert!(high_count > 350 && high_count < 650);
    }

    /// Two samplers with the same seed and config yield identical streams.
    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);

        let mut sampler1 = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = GaussianMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    /// Weights summing to 1.0 pass; weights summing to 0.4 fail.
    #[test]
    fn test_lognormal_mixture_validation() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(config.validate().is_ok());

        let invalid_config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(invalid_config.validate().is_err());
    }

    /// Every draw from the preset is positive and respects `min_value` (0.01).
    #[test]
    fn test_lognormal_mixture_sampling() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        assert!(samples.iter().all(|&x| x > 0.0));

        assert!(samples.iter().all(|&x| x >= 0.01));
    }

    /// Labels are reported and each of two equal-weight components is
    /// chosen for 40%-60% of 1000 draws.
    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let mut small_count = 0;
        let mut large_count = 0;

        for _ in 0..1000 {
            let result = sampler.sample_with_component();
            match result.component_label.as_deref() {
                Some("small") => small_count += 1,
                Some("large") => large_count += 1,
                _ => panic!("Unexpected label"),
            }
        }

        assert!(small_count > 400 && small_count < 600);
        assert!(large_count > 400 && large_count < 600);
    }

    /// Two log-normal samplers with the same seed and config yield identical streams.
    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        let mut sampler1 = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = LogNormalMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    /// Single component with mu = 7, sigma = 1: mean is exp(7.5) ≈ 1808.04.
    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let expected = sampler.expected_value();
        assert!((expected - 1808.04).abs() < 1.0);
    }

    /// `with_label` stores the label; `new` leaves it empty.
    #[test]
    fn test_component_label() {
        let component = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(component.label, Some("test_label".to_string()));

        let component_no_label = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(component_no_label.label, None);
    }

    /// With median exp(10) ≈ 22026 most draws exceed 1000, so this checks
    /// that `max_value` caps every draw.
    #[test]
    fn test_max_value_constraint() {
        let mut config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)]);
        config.max_value = Some(1000.0);

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let samples = sampler.sample_n(1000);

        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}