1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16 pub weight: f64,
18 pub mu: f64,
20 pub sigma: f64,
22 #[serde(default)]
24 pub label: Option<String>,
25}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52 pub components: Vec<GaussianComponent>,
54 #[serde(default = "default_true")]
56 pub allow_negative: bool,
57 #[serde(default)]
59 pub min_value: Option<f64>,
60 #[serde(default)]
62 pub max_value: Option<f64>,
63}
64
/// Serde default that always yields `true`; used for `GaussianMixtureConfig::allow_negative`.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {weight_sum}"
99 ));
100 }
101
102 for (i, component) in self.components.iter().enumerate() {
103 if component.weight < 0.0 || component.weight > 1.0 {
104 return Err(format!(
105 "Component {} weight must be between 0.0 and 1.0, got {}",
106 i, component.weight
107 ));
108 }
109 if component.sigma <= 0.0 {
110 return Err(format!(
111 "Component {} sigma must be positive, got {}",
112 i, component.sigma
113 ));
114 }
115 }
116
117 Ok(())
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124 pub weight: f64,
126 pub mu: f64,
128 pub sigma: f64,
130 #[serde(default)]
132 pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138 Self {
139 weight,
140 mu,
141 sigma,
142 label: None,
143 }
144 }
145
146 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148 Self {
149 weight,
150 mu,
151 sigma,
152 label: Some(label.into()),
153 }
154 }
155
156 pub fn expected_value(&self) -> f64 {
158 (self.mu + self.sigma.powi(2) / 2.0).exp()
159 }
160
161 pub fn median(&self) -> f64 {
163 self.mu.exp()
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170 pub components: Vec<LogNormalComponent>,
172 #[serde(default = "default_min_value")]
174 pub min_value: f64,
175 #[serde(default)]
177 pub max_value: Option<f64>,
178 #[serde(default = "default_decimal_places")]
180 pub decimal_places: u8,
181}
182
/// Serde default for `LogNormalMixtureConfig::min_value`: the floor samples are clamped to.
fn default_min_value() -> f64 {
    const FLOOR: f64 = 0.01;
    FLOOR
}
186
/// Serde default for `LogNormalMixtureConfig::decimal_places`: round samples to cents.
fn default_decimal_places() -> u8 {
    const PLACES: u8 = 2;
    PLACES
}
190
191impl Default for LogNormalMixtureConfig {
192 fn default() -> Self {
193 Self {
194 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195 min_value: 0.01,
196 max_value: None,
197 decimal_places: 2,
198 }
199 }
200}
201
202impl LogNormalMixtureConfig {
203 pub fn new(components: Vec<LogNormalComponent>) -> Self {
205 Self {
206 components,
207 ..Default::default()
208 }
209 }
210
211 pub fn typical_transactions() -> Self {
213 Self {
214 components: vec![
215 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218 ],
219 min_value: 0.01,
220 max_value: Some(100_000_000.0),
221 decimal_places: 2,
222 }
223 }
224
225 pub fn validate(&self) -> Result<(), String> {
227 if self.components.is_empty() {
228 return Err("At least one component is required".to_string());
229 }
230
231 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232 if (weight_sum - 1.0).abs() > 0.01 {
233 return Err(format!(
234 "Component weights must sum to 1.0, got {weight_sum}"
235 ));
236 }
237
238 for (i, component) in self.components.iter().enumerate() {
239 if component.weight < 0.0 || component.weight > 1.0 {
240 return Err(format!(
241 "Component {} weight must be between 0.0 and 1.0, got {}",
242 i, component.weight
243 ));
244 }
245 if component.sigma <= 0.0 {
246 return Err(format!(
247 "Component {} sigma must be positive, got {}",
248 i, component.sigma
249 ));
250 }
251 }
252
253 if self.min_value < 0.0 {
254 return Err("min_value must be non-negative".to_string());
255 }
256
257 Ok(())
258 }
259}
260
/// A single draw from a mixture sampler, tagged with the component that produced it.
#[derive(Clone, Debug)]
pub struct SampleWithComponent {
    /// The drawn value after the sampler's post-processing (clamping, and rounding for
    /// log-normal mixtures).
    pub value: f64,
    /// Position in the configuration's `components` of the component that was selected.
    pub component_index: usize,
    /// That component's label, if it has one.
    pub component_label: Option<String>,
}
271
272pub struct GaussianMixtureSampler {
274 rng: ChaCha8Rng,
275 config: GaussianMixtureConfig,
276 cumulative_weights: Vec<f64>,
278 distributions: Vec<Normal<f64>>,
280}
281
282impl GaussianMixtureSampler {
283 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
285 config.validate()?;
286
287 let mut cumulative_weights = Vec::with_capacity(config.components.len());
289 let mut cumulative = 0.0;
290 for component in &config.components {
291 cumulative += component.weight;
292 cumulative_weights.push(cumulative);
293 }
294
295 let distributions: Result<Vec<_>, _> = config
297 .components
298 .iter()
299 .map(|c| {
300 Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
301 })
302 .collect();
303
304 Ok(Self {
305 rng: ChaCha8Rng::seed_from_u64(seed),
306 config,
307 cumulative_weights,
308 distributions: distributions?,
309 })
310 }
311
312 fn select_component(&mut self) -> usize {
314 let p: f64 = self.rng.random();
315 match self.cumulative_weights.binary_search_by(|w| {
316 w.partial_cmp(&p).unwrap_or_else(|| {
317 tracing::debug!("NaN detected in mixture weight comparison");
318 std::cmp::Ordering::Less
319 })
320 }) {
321 Ok(i) => i,
322 Err(i) => i.min(self.distributions.len() - 1),
323 }
324 }
325
326 pub fn sample(&mut self) -> f64 {
328 let component_idx = self.select_component();
329 let mut value = self.distributions[component_idx].sample(&mut self.rng);
330
331 if !self.config.allow_negative {
333 value = value.abs();
334 }
335 if let Some(min) = self.config.min_value {
336 value = value.max(min);
337 }
338 if let Some(max) = self.config.max_value {
339 value = value.min(max);
340 }
341
342 value
343 }
344
345 pub fn sample_with_component(&mut self) -> SampleWithComponent {
347 let component_idx = self.select_component();
348 let mut value = self.distributions[component_idx].sample(&mut self.rng);
349
350 if !self.config.allow_negative {
352 value = value.abs();
353 }
354 if let Some(min) = self.config.min_value {
355 value = value.max(min);
356 }
357 if let Some(max) = self.config.max_value {
358 value = value.min(max);
359 }
360
361 SampleWithComponent {
362 value,
363 component_index: component_idx,
364 component_label: self.config.components[component_idx].label.clone(),
365 }
366 }
367
368 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
370 (0..n).map(|_| self.sample()).collect()
371 }
372
373 pub fn reset(&mut self, seed: u64) {
375 self.rng = ChaCha8Rng::seed_from_u64(seed);
376 }
377
378 pub fn config(&self) -> &GaussianMixtureConfig {
380 &self.config
381 }
382}
383
384pub struct LogNormalMixtureSampler {
386 rng: ChaCha8Rng,
387 config: LogNormalMixtureConfig,
388 cumulative_weights: Vec<f64>,
390 distributions: Vec<LogNormal<f64>>,
392 decimal_multiplier: f64,
394}
395
396impl LogNormalMixtureSampler {
397 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
399 config.validate()?;
400
401 let mut cumulative_weights = Vec::with_capacity(config.components.len());
403 let mut cumulative = 0.0;
404 for component in &config.components {
405 cumulative += component.weight;
406 cumulative_weights.push(cumulative);
407 }
408
409 let distributions: Result<Vec<_>, _> = config
411 .components
412 .iter()
413 .map(|c| {
414 LogNormal::new(c.mu, c.sigma)
415 .map_err(|e| format!("Invalid log-normal distribution: {e}"))
416 })
417 .collect();
418
419 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
420
421 Ok(Self {
422 rng: ChaCha8Rng::seed_from_u64(seed),
423 config,
424 cumulative_weights,
425 distributions: distributions?,
426 decimal_multiplier,
427 })
428 }
429
430 fn select_component(&mut self) -> usize {
432 let p: f64 = self.rng.random();
433 match self.cumulative_weights.binary_search_by(|w| {
434 w.partial_cmp(&p).unwrap_or_else(|| {
435 tracing::debug!("NaN detected in mixture weight comparison");
436 std::cmp::Ordering::Less
437 })
438 }) {
439 Ok(i) => i,
440 Err(i) => i.min(self.distributions.len() - 1),
441 }
442 }
443
444 pub fn sample(&mut self) -> f64 {
446 let component_idx = self.select_component();
447 let mut value = self.distributions[component_idx].sample(&mut self.rng);
448
449 value = value.max(self.config.min_value);
451 if let Some(max) = self.config.max_value {
452 value = value.min(max);
453 }
454
455 (value * self.decimal_multiplier).round() / self.decimal_multiplier
457 }
458
459 pub fn sample_decimal(&mut self) -> Decimal {
461 let value = self.sample();
462 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
463 }
464
465 pub fn sample_with_component(&mut self) -> SampleWithComponent {
467 let component_idx = self.select_component();
468 let mut value = self.distributions[component_idx].sample(&mut self.rng);
469
470 value = value.max(self.config.min_value);
472 if let Some(max) = self.config.max_value {
473 value = value.min(max);
474 }
475
476 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
478
479 SampleWithComponent {
480 value,
481 component_index: component_idx,
482 component_label: self.config.components[component_idx].label.clone(),
483 }
484 }
485
486 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
488 (0..n).map(|_| self.sample()).collect()
489 }
490
491 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
493 (0..n).map(|_| self.sample_decimal()).collect()
494 }
495
496 pub fn reset(&mut self, seed: u64) {
498 self.rng = ChaCha8Rng::seed_from_u64(seed);
499 }
500
501 pub fn config(&self) -> &LogNormalMixtureConfig {
503 &self.config
504 }
505
506 pub fn expected_value(&self) -> f64 {
508 self.config
509 .components
510 .iter()
511 .map(|c| c.weight * c.expected_value())
512 .sum()
513 }
514}
515
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_gaussian_mixture_validation() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(config.validate().is_ok());

        // Weights summing to 0.6 are rejected.
        let invalid_config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(invalid_config.validate().is_err());

        // A negative standard deviation is rejected.
        let invalid_config =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_empty_components_rejected() {
        assert!(GaussianMixtureConfig::new(vec![]).validate().is_err());
        assert!(LogNormalMixtureConfig::new(vec![]).validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);
        let mut sampler = GaussianMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // The two modes are well separated, so the split at 5.0 tracks the 50/50 weights.
        let low_count = samples.iter().filter(|&&x| x < 5.0).count();
        let high_count = samples.iter().filter(|&&x| x >= 5.0).count();

        assert!(low_count > 350 && low_count < 650);
        assert!(high_count > 350 && high_count < 650);
    }

    #[test]
    fn test_gaussian_constraints_applied() {
        let mut config =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, 3.0)]);
        config.allow_negative = false;
        config.min_value = Some(0.5);
        config.max_value = Some(2.0);

        let mut sampler = GaussianMixtureSampler::new(7, config).unwrap();
        assert!(sampler
            .sample_n(500)
            .iter()
            .all(|&x| (0.5..=2.0).contains(&x)));
    }

    #[test]
    fn test_gaussian_sample_with_component_labels() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::with_label(0.5, 0.0, 1.0, "low"),
            GaussianComponent::with_label(0.5, 10.0, 1.0, "high"),
        ]);
        let mut sampler = GaussianMixtureSampler::new(42, config).unwrap();

        for _ in 0..200 {
            let result = sampler.sample_with_component();
            let expected = if result.component_index == 0 { "low" } else { "high" };
            assert_eq!(result.component_label.as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);

        let mut sampler1 = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = GaussianMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(config.validate().is_ok());

        // Weights summing to 0.4 are rejected.
        let invalid_config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_typical_transactions_is_valid() {
        assert!(LogNormalMixtureConfig::typical_transactions()
            .validate()
            .is_ok());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // Every sample is clamped to the 0.01 floor (and is therefore positive).
        assert!(samples.iter().all(|&x| x > 0.0));
        assert!(samples.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_decimal_respects_bounds() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(3, config).unwrap();

        let min = Decimal::new(1, 2);
        let max = Decimal::new(100_000_000, 0);
        assert!(sampler
            .sample_n_decimal(200)
            .iter()
            .all(|d| *d >= min && *d <= max));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let mut small_count = 0;
        let mut large_count = 0;

        for _ in 0..1000 {
            let result = sampler.sample_with_component();
            match result.component_label.as_deref() {
                Some("small") => small_count += 1,
                Some("large") => large_count += 1,
                _ => panic!("Unexpected label"),
            }
        }

        assert!(small_count > 400 && small_count < 600);
        assert!(large_count > 400 && large_count < 600);
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        let mut sampler1 = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = LogNormalMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_reset_replays_sequence() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(9, config).unwrap();

        let first = sampler.sample_n(50);
        sampler.reset(9);
        assert_eq!(first, sampler.sample_n(50));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // exp(7 + 1/2) = exp(7.5) ≈ 1808.04
        let expected = sampler.expected_value();
        assert!((expected - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_lognormal_median() {
        let component = LogNormalComponent::new(1.0, 7.0, 1.0);
        assert!((component.median() - 7.0_f64.exp()).abs() < 1e-9);
    }

    #[test]
    fn test_component_label() {
        let component = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(component.label, Some("test_label".to_string()));

        let component_no_label = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(component_no_label.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        let mut config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)]);
        config.max_value = Some(1000.0);

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let samples = sampler.sample_n(1000);

        // Median exp(10) ≈ 22026 is far above the cap, so the clamp is exercised heavily.
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}