1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16 pub weight: f64,
18 pub mu: f64,
20 pub sigma: f64,
22 #[serde(default)]
24 pub label: Option<String>,
25}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52 pub components: Vec<GaussianComponent>,
54 #[serde(default = "default_true")]
56 pub allow_negative: bool,
57 #[serde(default)]
59 pub min_value: Option<f64>,
60 #[serde(default)]
62 pub max_value: Option<f64>,
63}
64
/// Serde default helper: yields `true` for boolean fields that should be
/// enabled when the key is missing from the input.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {weight_sum}"
99 ));
100 }
101
102 for (i, component) in self.components.iter().enumerate() {
103 if component.weight < 0.0 || component.weight > 1.0 {
104 return Err(format!(
105 "Component {} weight must be between 0.0 and 1.0, got {}",
106 i, component.weight
107 ));
108 }
109 if component.sigma <= 0.0 {
110 return Err(format!(
111 "Component {} sigma must be positive, got {}",
112 i, component.sigma
113 ));
114 }
115 }
116
117 Ok(())
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct LogNormalComponent {
124 pub weight: f64,
126 pub mu: f64,
128 pub sigma: f64,
130 #[serde(default)]
132 pub label: Option<String>,
133}
134
135impl LogNormalComponent {
136 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
138 Self {
139 weight,
140 mu,
141 sigma,
142 label: None,
143 }
144 }
145
146 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
148 Self {
149 weight,
150 mu,
151 sigma,
152 label: Some(label.into()),
153 }
154 }
155
156 pub fn expected_value(&self) -> f64 {
158 (self.mu + self.sigma.powi(2) / 2.0).exp()
159 }
160
161 pub fn median(&self) -> f64 {
163 self.mu.exp()
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct LogNormalMixtureConfig {
170 pub components: Vec<LogNormalComponent>,
172 #[serde(default = "default_min_value")]
174 pub min_value: f64,
175 #[serde(default)]
177 pub max_value: Option<f64>,
178 #[serde(default = "default_decimal_places")]
180 pub decimal_places: u8,
181}
182
/// Serde default helper: the lower bound used when `min_value` is missing
/// from the input.
fn default_min_value() -> f64 {
    const MIN_VALUE: f64 = 0.01;
    MIN_VALUE
}
186
/// Serde default helper: the rounding precision used when `decimal_places`
/// is missing from the input.
fn default_decimal_places() -> u8 {
    const DECIMAL_PLACES: u8 = 2;
    DECIMAL_PLACES
}
190
191impl Default for LogNormalMixtureConfig {
192 fn default() -> Self {
193 Self {
194 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
195 min_value: 0.01,
196 max_value: None,
197 decimal_places: 2,
198 }
199 }
200}
201
202impl LogNormalMixtureConfig {
203 pub fn new(components: Vec<LogNormalComponent>) -> Self {
205 Self {
206 components,
207 ..Default::default()
208 }
209 }
210
211 pub fn typical_transactions() -> Self {
213 Self {
214 components: vec![
215 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
216 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
217 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
218 ],
219 min_value: 0.01,
220 max_value: Some(100_000_000.0),
221 decimal_places: 2,
222 }
223 }
224
225 pub fn validate(&self) -> Result<(), String> {
227 if self.components.is_empty() {
228 return Err("At least one component is required".to_string());
229 }
230
231 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
232 if (weight_sum - 1.0).abs() > 0.01 {
233 return Err(format!(
234 "Component weights must sum to 1.0, got {weight_sum}"
235 ));
236 }
237
238 for (i, component) in self.components.iter().enumerate() {
239 if component.weight < 0.0 || component.weight > 1.0 {
240 return Err(format!(
241 "Component {} weight must be between 0.0 and 1.0, got {}",
242 i, component.weight
243 ));
244 }
245 if component.sigma <= 0.0 {
246 return Err(format!(
247 "Component {} sigma must be positive, got {}",
248 i, component.sigma
249 ));
250 }
251 }
252
253 if self.min_value < 0.0 {
254 return Err("min_value must be non-negative".to_string());
255 }
256
257 Ok(())
258 }
259}
260
/// A drawn value together with the mixture component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The drawn value, after the sampler's post-processing.
    pub value: f64,
    /// Index of the originating component in the configuration's component list.
    pub component_index: usize,
    /// Copy of the originating component's label, if it has one.
    pub component_label: Option<String>,
}
271
272#[derive(Clone)]
274pub struct GaussianMixtureSampler {
275 rng: ChaCha8Rng,
276 config: GaussianMixtureConfig,
277 cumulative_weights: Vec<f64>,
279 distributions: Vec<Normal<f64>>,
281}
282
283impl GaussianMixtureSampler {
284 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
286 config.validate()?;
287
288 let mut cumulative_weights = Vec::with_capacity(config.components.len());
290 let mut cumulative = 0.0;
291 for component in &config.components {
292 cumulative += component.weight;
293 cumulative_weights.push(cumulative);
294 }
295
296 let distributions: Result<Vec<_>, _> = config
298 .components
299 .iter()
300 .map(|c| {
301 Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
302 })
303 .collect();
304
305 Ok(Self {
306 rng: ChaCha8Rng::seed_from_u64(seed),
307 config,
308 cumulative_weights,
309 distributions: distributions?,
310 })
311 }
312
313 fn select_component(&mut self) -> usize {
315 let p: f64 = self.rng.random();
316 match self.cumulative_weights.binary_search_by(|w| {
317 w.partial_cmp(&p).unwrap_or_else(|| {
318 tracing::debug!("NaN detected in mixture weight comparison");
319 std::cmp::Ordering::Less
320 })
321 }) {
322 Ok(i) => i,
323 Err(i) => i.min(self.distributions.len() - 1),
324 }
325 }
326
327 pub fn sample(&mut self) -> f64 {
329 let component_idx = self.select_component();
330 let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332 if !self.config.allow_negative {
334 value = value.abs();
335 }
336 if let Some(min) = self.config.min_value {
337 value = value.max(min);
338 }
339 if let Some(max) = self.config.max_value {
340 value = value.min(max);
341 }
342
343 value
344 }
345
346 pub fn sample_with_component(&mut self) -> SampleWithComponent {
348 let component_idx = self.select_component();
349 let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351 if !self.config.allow_negative {
353 value = value.abs();
354 }
355 if let Some(min) = self.config.min_value {
356 value = value.max(min);
357 }
358 if let Some(max) = self.config.max_value {
359 value = value.min(max);
360 }
361
362 SampleWithComponent {
363 value,
364 component_index: component_idx,
365 component_label: self.config.components[component_idx].label.clone(),
366 }
367 }
368
369 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371 (0..n).map(|_| self.sample()).collect()
372 }
373
374 pub fn reset(&mut self, seed: u64) {
376 self.rng = ChaCha8Rng::seed_from_u64(seed);
377 }
378
379 pub fn config(&self) -> &GaussianMixtureConfig {
381 &self.config
382 }
383}
384
385#[derive(Clone)]
387pub struct LogNormalMixtureSampler {
388 rng: ChaCha8Rng,
389 config: LogNormalMixtureConfig,
390 cumulative_weights: Vec<f64>,
392 distributions: Vec<LogNormal<f64>>,
394 decimal_multiplier: f64,
396}
397
398impl LogNormalMixtureSampler {
399 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
401 config.validate()?;
402
403 let mut cumulative_weights = Vec::with_capacity(config.components.len());
405 let mut cumulative = 0.0;
406 for component in &config.components {
407 cumulative += component.weight;
408 cumulative_weights.push(cumulative);
409 }
410
411 let distributions: Result<Vec<_>, _> = config
413 .components
414 .iter()
415 .map(|c| {
416 LogNormal::new(c.mu, c.sigma)
417 .map_err(|e| format!("Invalid log-normal distribution: {e}"))
418 })
419 .collect();
420
421 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
422
423 Ok(Self {
424 rng: ChaCha8Rng::seed_from_u64(seed),
425 config,
426 cumulative_weights,
427 distributions: distributions?,
428 decimal_multiplier,
429 })
430 }
431
432 fn select_component(&mut self) -> usize {
434 let p: f64 = self.rng.random();
435 match self.cumulative_weights.binary_search_by(|w| {
436 w.partial_cmp(&p).unwrap_or_else(|| {
437 tracing::debug!("NaN detected in mixture weight comparison");
438 std::cmp::Ordering::Less
439 })
440 }) {
441 Ok(i) => i,
442 Err(i) => i.min(self.distributions.len() - 1),
443 }
444 }
445
446 pub fn sample(&mut self) -> f64 {
448 let component_idx = self.select_component();
449 let mut value = self.distributions[component_idx].sample(&mut self.rng);
450
451 value = value.max(self.config.min_value);
453 if let Some(max) = self.config.max_value {
454 value = value.min(max);
455 }
456
457 (value * self.decimal_multiplier).round() / self.decimal_multiplier
459 }
460
461 pub fn sample_decimal(&mut self) -> Decimal {
463 let value = self.sample();
464 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
465 }
466
467 pub fn sample_with_component(&mut self) -> SampleWithComponent {
469 let component_idx = self.select_component();
470 let mut value = self.distributions[component_idx].sample(&mut self.rng);
471
472 value = value.max(self.config.min_value);
474 if let Some(max) = self.config.max_value {
475 value = value.min(max);
476 }
477
478 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
480
481 SampleWithComponent {
482 value,
483 component_index: component_idx,
484 component_label: self.config.components[component_idx].label.clone(),
485 }
486 }
487
488 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
490 (0..n).map(|_| self.sample()).collect()
491 }
492
493 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
495 (0..n).map(|_| self.sample_decimal()).collect()
496 }
497
498 pub fn reset(&mut self, seed: u64) {
500 self.rng = ChaCha8Rng::seed_from_u64(seed);
501 }
502
503 pub fn config(&self) -> &LogNormalMixtureConfig {
505 &self.config
506 }
507
508 pub fn expected_value(&self) -> f64 {
510 self.config
511 .components
512 .iter()
513 .map(|c| c.weight * c.expected_value())
514 .sum()
515 }
516}
517
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Valid weights pass; weights that do not sum to one and a negative
    /// sigma are both rejected.
    #[test]
    fn test_gaussian_mixture_validation() {
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    /// Two well-separated, equally weighted components should each receive
    /// roughly half of the draws.
    #[test]
    fn test_gaussian_mixture_sampling() {
        let components = vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ];
        let mut sampler =
            GaussianMixtureSampler::new(42, GaussianMixtureConfig::new(components)).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        let low_count = draws.iter().filter(|&&x| x < 5.0).count();
        let high_count = draws.iter().filter(|&&x| x >= 5.0).count();

        assert!((351..650).contains(&low_count));
        assert!((351..650).contains(&high_count));
    }

    /// Two samplers with the same seed and configuration produce the same
    /// sequence.
    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);

        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    /// Valid weights pass; weights that do not sum to one are rejected.
    #[test]
    fn test_lognormal_mixture_validation() {
        let balanced = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(balanced.validate().is_ok());

        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    /// Every draw from the preset is positive and respects the lower bound.
    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        assert!(draws.iter().all(|&x| x > 0.0));

        assert!(draws.iter().all(|&x| x >= 0.01));
    }

    /// Labelled draws report one of the configured labels, split roughly
    /// evenly for equal weights.
    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let mut small_count = 0;
        let mut large_count = 0;

        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            let counter = match drawn.component_label.as_deref() {
                Some("small") => &mut small_count,
                Some("large") => &mut large_count,
                _ => panic!("Unexpected label"),
            };
            *counter += 1;
        }

        assert!(small_count > 400 && small_count < 600);
        assert!(large_count > 400 && large_count < 600);
    }

    /// Two samplers with the same seed and configuration produce the same
    /// sequence.
    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    /// For a single component with mu 7 and sigma 1 the mean is
    /// exp(7.5), approximately 1808.04.
    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        let mean = sampler.expected_value();
        assert!((mean - 1808.04).abs() < 1.0);
    }

    /// `with_label` stores the label; `new` leaves it empty.
    #[test]
    fn test_component_label() {
        let labelled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labelled.label, Some("test_label".to_string()));

        let unlabelled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(unlabelled.label, None);
    }

    /// No draw exceeds the configured upper bound.
    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let draws = sampler.sample_n(1000);

        assert!(draws.iter().all(|&x| x <= 1000.0));
    }
}