1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
/// One normal (Gaussian) component of a Gaussian mixture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianComponent {
    /// Mixture weight. `GaussianMixtureConfig::validate` requires it to lie in [0.0, 1.0]
    /// and all component weights to sum to 1.0 (within 0.01).
    pub weight: f64,
    /// Mean of the normal distribution.
    pub mu: f64,
    /// Standard deviation of the normal distribution; `validate` requires it to be positive.
    pub sigma: f64,
    /// Optional human-readable name, echoed back by `sample_with_component`.
    #[serde(default)]
    pub label: Option<String>,
}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
/// Configuration for a [`GaussianMixtureSampler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianMixtureConfig {
    /// Mixture components; must be non-empty with weights summing to 1.0 (within 0.01).
    pub components: Vec<GaussianComponent>,
    /// When `false`, each raw sample is replaced by its absolute value before the bounds
    /// below are applied. Defaults to `true` when the field is missing from serialized input.
    #[serde(default = "default_true")]
    pub allow_negative: bool,
    /// Optional lower bound; samples below it are raised to this value.
    #[serde(default)]
    pub min_value: Option<f64>,
    /// Optional upper bound; samples above it are lowered to this value.
    #[serde(default)]
    pub max_value: Option<f64>,
}
64
/// Serde default for `allow_negative`: the flag is on unless explicitly disabled.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {}",
99 weight_sum
100 ));
101 }
102
103 for (i, component) in self.components.iter().enumerate() {
104 if component.weight < 0.0 || component.weight > 1.0 {
105 return Err(format!(
106 "Component {} weight must be between 0.0 and 1.0, got {}",
107 i, component.weight
108 ));
109 }
110 if component.sigma <= 0.0 {
111 return Err(format!(
112 "Component {} sigma must be positive, got {}",
113 i, component.sigma
114 ));
115 }
116 }
117
118 Ok(())
119 }
120}
121
/// One log-normal component of a log-normal mixture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalComponent {
    /// Mixture weight. `LogNormalMixtureConfig::validate` requires it to lie in [0.0, 1.0]
    /// and all component weights to sum to 1.0 (within 0.01).
    pub weight: f64,
    /// Mean of the underlying normal distribution (the median of the samples is `exp(mu)`).
    pub mu: f64,
    /// Standard deviation of the underlying normal distribution; must be positive.
    pub sigma: f64,
    /// Optional human-readable name, echoed back by `sample_with_component`.
    #[serde(default)]
    pub label: Option<String>,
}
135
136impl LogNormalComponent {
137 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
139 Self {
140 weight,
141 mu,
142 sigma,
143 label: None,
144 }
145 }
146
147 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
149 Self {
150 weight,
151 mu,
152 sigma,
153 label: Some(label.into()),
154 }
155 }
156
157 pub fn expected_value(&self) -> f64 {
159 (self.mu + self.sigma.powi(2) / 2.0).exp()
160 }
161
162 pub fn median(&self) -> f64 {
164 self.mu.exp()
165 }
166}
167
/// Configuration for a [`LogNormalMixtureSampler`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalMixtureConfig {
    /// Mixture components; must be non-empty with weights summing to 1.0 (within 0.01).
    pub components: Vec<LogNormalComponent>,
    /// Lower bound applied to every sample; must be non-negative. Defaults to 0.01 when
    /// the field is missing from serialized input.
    #[serde(default = "default_min_value")]
    pub min_value: f64,
    /// Optional upper bound applied to every sample.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places each sample is rounded to. Defaults to 2 when the field
    /// is missing from serialized input.
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
183
/// Serde default for `LogNormalMixtureConfig::min_value`: one hundredth.
fn default_min_value() -> f64 {
    1e-2
}
187
/// Serde default for `LogNormalMixtureConfig::decimal_places`: two places after the point.
fn default_decimal_places() -> u8 {
    const TWO_PLACES: u8 = 2;
    TWO_PLACES
}
191
192impl Default for LogNormalMixtureConfig {
193 fn default() -> Self {
194 Self {
195 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
196 min_value: 0.01,
197 max_value: None,
198 decimal_places: 2,
199 }
200 }
201}
202
203impl LogNormalMixtureConfig {
204 pub fn new(components: Vec<LogNormalComponent>) -> Self {
206 Self {
207 components,
208 ..Default::default()
209 }
210 }
211
212 pub fn typical_transactions() -> Self {
214 Self {
215 components: vec![
216 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
217 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
218 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
219 ],
220 min_value: 0.01,
221 max_value: Some(100_000_000.0),
222 decimal_places: 2,
223 }
224 }
225
226 pub fn validate(&self) -> Result<(), String> {
228 if self.components.is_empty() {
229 return Err("At least one component is required".to_string());
230 }
231
232 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
233 if (weight_sum - 1.0).abs() > 0.01 {
234 return Err(format!(
235 "Component weights must sum to 1.0, got {}",
236 weight_sum
237 ));
238 }
239
240 for (i, component) in self.components.iter().enumerate() {
241 if component.weight < 0.0 || component.weight > 1.0 {
242 return Err(format!(
243 "Component {} weight must be between 0.0 and 1.0, got {}",
244 i, component.weight
245 ));
246 }
247 if component.sigma <= 0.0 {
248 return Err(format!(
249 "Component {} sigma must be positive, got {}",
250 i, component.sigma
251 ));
252 }
253 }
254
255 if self.min_value < 0.0 {
256 return Err("min_value must be non-negative".to_string());
257 }
258
259 Ok(())
260 }
261}
262
/// A sampled value together with the mixture component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value, after the sampler's bounds (and, for the log-normal sampler,
    /// rounding) were applied.
    pub value: f64,
    /// Index into the configuration's `components` of the component that was drawn.
    pub component_index: usize,
    /// That component's label, if it has one.
    pub component_label: Option<String>,
}
273
/// Seeded sampler drawing values from a [`GaussianMixtureConfig`].
pub struct GaussianMixtureSampler {
    /// Seeded RNG; two samplers built with the same seed and config yield the same sequence.
    rng: ChaCha8Rng,
    config: GaussianMixtureConfig,
    /// Cumulative component weights, used to map a uniform draw in [0, 1) to a component.
    cumulative_weights: Vec<f64>,
    /// One normal distribution per component, in the same order as `config.components`.
    distributions: Vec<Normal<f64>>,
}
283
284impl GaussianMixtureSampler {
285 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
287 config.validate()?;
288
289 let mut cumulative_weights = Vec::with_capacity(config.components.len());
291 let mut cumulative = 0.0;
292 for component in &config.components {
293 cumulative += component.weight;
294 cumulative_weights.push(cumulative);
295 }
296
297 let distributions: Result<Vec<_>, _> = config
299 .components
300 .iter()
301 .map(|c| {
302 Normal::new(c.mu, c.sigma)
303 .map_err(|e| format!("Invalid normal distribution: {}", e))
304 })
305 .collect();
306
307 Ok(Self {
308 rng: ChaCha8Rng::seed_from_u64(seed),
309 config,
310 cumulative_weights,
311 distributions: distributions?,
312 })
313 }
314
315 fn select_component(&mut self) -> usize {
317 let p: f64 = self.rng.random();
318 match self.cumulative_weights.binary_search_by(|w| {
319 w.partial_cmp(&p).unwrap_or_else(|| {
320 tracing::debug!("NaN detected in mixture weight comparison");
321 std::cmp::Ordering::Less
322 })
323 }) {
324 Ok(i) => i,
325 Err(i) => i.min(self.distributions.len() - 1),
326 }
327 }
328
329 pub fn sample(&mut self) -> f64 {
331 let component_idx = self.select_component();
332 let mut value = self.distributions[component_idx].sample(&mut self.rng);
333
334 if !self.config.allow_negative {
336 value = value.abs();
337 }
338 if let Some(min) = self.config.min_value {
339 value = value.max(min);
340 }
341 if let Some(max) = self.config.max_value {
342 value = value.min(max);
343 }
344
345 value
346 }
347
348 pub fn sample_with_component(&mut self) -> SampleWithComponent {
350 let component_idx = self.select_component();
351 let mut value = self.distributions[component_idx].sample(&mut self.rng);
352
353 if !self.config.allow_negative {
355 value = value.abs();
356 }
357 if let Some(min) = self.config.min_value {
358 value = value.max(min);
359 }
360 if let Some(max) = self.config.max_value {
361 value = value.min(max);
362 }
363
364 SampleWithComponent {
365 value,
366 component_index: component_idx,
367 component_label: self.config.components[component_idx].label.clone(),
368 }
369 }
370
371 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
373 (0..n).map(|_| self.sample()).collect()
374 }
375
376 pub fn reset(&mut self, seed: u64) {
378 self.rng = ChaCha8Rng::seed_from_u64(seed);
379 }
380
381 pub fn config(&self) -> &GaussianMixtureConfig {
383 &self.config
384 }
385}
386
/// Seeded sampler drawing values from a [`LogNormalMixtureConfig`].
pub struct LogNormalMixtureSampler {
    /// Seeded RNG; two samplers built with the same seed and config yield the same sequence.
    rng: ChaCha8Rng,
    config: LogNormalMixtureConfig,
    /// Cumulative component weights, used to map a uniform draw in [0, 1) to a component.
    cumulative_weights: Vec<f64>,
    /// One log-normal distribution per component, in the same order as `config.components`.
    distributions: Vec<LogNormal<f64>>,
    /// `10^decimal_places`, precomputed for rounding samples.
    decimal_multiplier: f64,
}
398
399impl LogNormalMixtureSampler {
400 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
402 config.validate()?;
403
404 let mut cumulative_weights = Vec::with_capacity(config.components.len());
406 let mut cumulative = 0.0;
407 for component in &config.components {
408 cumulative += component.weight;
409 cumulative_weights.push(cumulative);
410 }
411
412 let distributions: Result<Vec<_>, _> = config
414 .components
415 .iter()
416 .map(|c| {
417 LogNormal::new(c.mu, c.sigma)
418 .map_err(|e| format!("Invalid log-normal distribution: {}", e))
419 })
420 .collect();
421
422 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
423
424 Ok(Self {
425 rng: ChaCha8Rng::seed_from_u64(seed),
426 config,
427 cumulative_weights,
428 distributions: distributions?,
429 decimal_multiplier,
430 })
431 }
432
433 fn select_component(&mut self) -> usize {
435 let p: f64 = self.rng.random();
436 match self.cumulative_weights.binary_search_by(|w| {
437 w.partial_cmp(&p).unwrap_or_else(|| {
438 tracing::debug!("NaN detected in mixture weight comparison");
439 std::cmp::Ordering::Less
440 })
441 }) {
442 Ok(i) => i,
443 Err(i) => i.min(self.distributions.len() - 1),
444 }
445 }
446
447 pub fn sample(&mut self) -> f64 {
449 let component_idx = self.select_component();
450 let mut value = self.distributions[component_idx].sample(&mut self.rng);
451
452 value = value.max(self.config.min_value);
454 if let Some(max) = self.config.max_value {
455 value = value.min(max);
456 }
457
458 (value * self.decimal_multiplier).round() / self.decimal_multiplier
460 }
461
462 pub fn sample_decimal(&mut self) -> Decimal {
464 let value = self.sample();
465 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
466 }
467
468 pub fn sample_with_component(&mut self) -> SampleWithComponent {
470 let component_idx = self.select_component();
471 let mut value = self.distributions[component_idx].sample(&mut self.rng);
472
473 value = value.max(self.config.min_value);
475 if let Some(max) = self.config.max_value {
476 value = value.min(max);
477 }
478
479 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
481
482 SampleWithComponent {
483 value,
484 component_index: component_idx,
485 component_label: self.config.components[component_idx].label.clone(),
486 }
487 }
488
489 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
491 (0..n).map(|_| self.sample()).collect()
492 }
493
494 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
496 (0..n).map(|_| self.sample_decimal()).collect()
497 }
498
499 pub fn reset(&mut self, seed: u64) {
501 self.rng = ChaCha8Rng::seed_from_u64(seed);
502 }
503
504 pub fn config(&self) -> &LogNormalMixtureConfig {
506 &self.config
507 }
508
509 pub fn expected_value(&self) -> f64 {
511 self.config
512 .components
513 .iter()
514 .map(|c| c.weight * c.expected_value())
515 .sum()
516 }
517}
518
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Equally weighted Gaussian mixture with unit-sigma modes at 0 and 10.
    fn bimodal_gaussian() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights summing to 0.6 are rejected.
        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        // A negative sigma is rejected.
        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, bimodal_gaussian()).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // Split at the midpoint between the modes; each side should hold roughly half.
        let (low, high): (Vec<f64>, Vec<f64>) = draws.iter().copied().partition(|&x| x < 5.0);
        for side in [low.len(), high.len()] {
            assert!(side > 350 && side < 650);
        }
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = bimodal_gaussian();
        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let valid = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(valid.validate().is_ok());

        // Weights summing to 0.4 are rejected.
        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        assert!(draws.iter().all(|&x| x > 0.0));
        // Every draw respects the preset's floor of 0.01.
        assert!(draws.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small, mut large) = (0usize, 0usize);
        for _ in 0..1000 {
            match sampler.sample_with_component().component_label.as_deref() {
                Some("small") => small += 1,
                Some("large") => large += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Each label should account for strictly between 400 and 600 of the draws.
        assert!((401..600).contains(&small));
        assert!((401..600).contains(&large));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        // exp(mu + sigma^2 / 2) = exp(7.5) ≈ 1808.04
        let mean = sampler.expected_value();
        assert!((mean - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labelled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labelled.label.as_deref(), Some("test_label"));

        let unlabelled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert!(unlabelled.label.is_none());
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let draws = sampler.sample_n(1000);
        assert!(draws.iter().all(|&x| x <= 1000.0));
    }
}