1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
/// One weighted normal component of a Gaussian mixture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianComponent {
    /// Mixing weight; `GaussianMixtureConfig::validate` requires it to lie in `[0.0, 1.0]`.
    pub weight: f64,
    /// Mean passed to `Normal::new` when a sampler is built.
    pub mu: f64,
    /// Standard deviation passed to `Normal::new`; validation requires it to be positive.
    pub sigma: f64,
    /// Optional tag for the component; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub label: Option<String>,
}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
/// Configuration for a Gaussian mixture plus optional constraints on sampled values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianMixtureConfig {
    /// Components of the mixture; validation rejects an empty list.
    pub components: Vec<GaussianComponent>,
    /// When `false`, the sampler replaces each drawn value with its absolute value.
    /// Deserializes to `true` when the field is absent.
    #[serde(default = "default_true")]
    pub allow_negative: bool,
    /// Optional lower clamp applied to every sample.
    #[serde(default)]
    pub min_value: Option<f64>,
    /// Optional upper clamp applied to every sample, after the lower clamp.
    #[serde(default)]
    pub max_value: Option<f64>,
}
64
/// Serde default hook: boolean fields that opt in deserialize to `true` when missing.
fn default_true() -> bool {
    true
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {}",
99 weight_sum
100 ));
101 }
102
103 for (i, component) in self.components.iter().enumerate() {
104 if component.weight < 0.0 || component.weight > 1.0 {
105 return Err(format!(
106 "Component {} weight must be between 0.0 and 1.0, got {}",
107 i, component.weight
108 ));
109 }
110 if component.sigma <= 0.0 {
111 return Err(format!(
112 "Component {} sigma must be positive, got {}",
113 i, component.sigma
114 ));
115 }
116 }
117
118 Ok(())
119 }
120}
121
/// One weighted log-normal component of a mixture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalComponent {
    /// Mixing weight; `LogNormalMixtureConfig::validate` requires it to lie in `[0.0, 1.0]`.
    pub weight: f64,
    /// Location parameter passed to `LogNormal::new`; the component's median is `exp(mu)`.
    pub mu: f64,
    /// Shape parameter passed to `LogNormal::new`; validation requires it to be positive.
    pub sigma: f64,
    /// Optional tag for the component; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub label: Option<String>,
}
135
136impl LogNormalComponent {
137 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
139 Self {
140 weight,
141 mu,
142 sigma,
143 label: None,
144 }
145 }
146
147 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
149 Self {
150 weight,
151 mu,
152 sigma,
153 label: Some(label.into()),
154 }
155 }
156
157 pub fn expected_value(&self) -> f64 {
159 (self.mu + self.sigma.powi(2) / 2.0).exp()
160 }
161
162 pub fn median(&self) -> f64 {
164 self.mu.exp()
165 }
166}
167
/// Configuration for a log-normal mixture plus clamping and rounding of sampled values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalMixtureConfig {
    /// Components of the mixture; validation rejects an empty list.
    pub components: Vec<LogNormalComponent>,
    /// Lower clamp applied to every sample; deserializes to 0.01 when the field is absent.
    #[serde(default = "default_min_value")]
    pub min_value: f64,
    /// Optional upper clamp applied to every sample, after the lower clamp.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places samples are rounded to; deserializes to 2 when absent.
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
183
/// Serde default hook for `LogNormalMixtureConfig::min_value`.
fn default_min_value() -> f64 {
    0.01
}
187
/// Serde default hook for `LogNormalMixtureConfig::decimal_places`.
fn default_decimal_places() -> u8 {
    2
}
191
192impl Default for LogNormalMixtureConfig {
193 fn default() -> Self {
194 Self {
195 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
196 min_value: 0.01,
197 max_value: None,
198 decimal_places: 2,
199 }
200 }
201}
202
203impl LogNormalMixtureConfig {
204 pub fn new(components: Vec<LogNormalComponent>) -> Self {
206 Self {
207 components,
208 ..Default::default()
209 }
210 }
211
212 pub fn typical_transactions() -> Self {
214 Self {
215 components: vec![
216 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
217 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
218 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
219 ],
220 min_value: 0.01,
221 max_value: Some(100_000_000.0),
222 decimal_places: 2,
223 }
224 }
225
226 pub fn validate(&self) -> Result<(), String> {
228 if self.components.is_empty() {
229 return Err("At least one component is required".to_string());
230 }
231
232 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
233 if (weight_sum - 1.0).abs() > 0.01 {
234 return Err(format!(
235 "Component weights must sum to 1.0, got {}",
236 weight_sum
237 ));
238 }
239
240 for (i, component) in self.components.iter().enumerate() {
241 if component.weight < 0.0 || component.weight > 1.0 {
242 return Err(format!(
243 "Component {} weight must be between 0.0 and 1.0, got {}",
244 i, component.weight
245 ));
246 }
247 if component.sigma <= 0.0 {
248 return Err(format!(
249 "Component {} sigma must be positive, got {}",
250 i, component.sigma
251 ));
252 }
253 }
254
255 if self.min_value < 0.0 {
256 return Err("min_value must be non-negative".to_string());
257 }
258
259 Ok(())
260 }
261}
262
/// A sampled value together with the mixture component it was drawn from.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value, after the sampler's constraints have been applied.
    pub value: f64,
    /// Index of the selected component within the configuration's `components`.
    pub component_index: usize,
    /// Clone of the selected component's label, if it has one.
    pub component_label: Option<String>,
}
273
/// Seeded sampler for a Gaussian mixture described by a `GaussianMixtureConfig`.
pub struct GaussianMixtureSampler {
    /// ChaCha8 generator used for both component selection and value draws.
    rng: ChaCha8Rng,
    /// The validated configuration this sampler was built from.
    config: GaussianMixtureConfig,
    /// Running sum of component weights, in component order.
    cumulative_weights: Vec<f64>,
    /// One `Normal` per component, in component order.
    distributions: Vec<Normal<f64>>,
}
283
284impl GaussianMixtureSampler {
285 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
287 config.validate()?;
288
289 let mut cumulative_weights = Vec::with_capacity(config.components.len());
291 let mut cumulative = 0.0;
292 for component in &config.components {
293 cumulative += component.weight;
294 cumulative_weights.push(cumulative);
295 }
296
297 let distributions: Result<Vec<_>, _> = config
299 .components
300 .iter()
301 .map(|c| {
302 Normal::new(c.mu, c.sigma)
303 .map_err(|e| format!("Invalid normal distribution: {}", e))
304 })
305 .collect();
306
307 Ok(Self {
308 rng: ChaCha8Rng::seed_from_u64(seed),
309 config,
310 cumulative_weights,
311 distributions: distributions?,
312 })
313 }
314
315 fn select_component(&mut self) -> usize {
317 let p: f64 = self.rng.gen();
318 match self
319 .cumulative_weights
320 .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
321 {
322 Ok(i) => i,
323 Err(i) => i.min(self.distributions.len() - 1),
324 }
325 }
326
327 pub fn sample(&mut self) -> f64 {
329 let component_idx = self.select_component();
330 let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332 if !self.config.allow_negative {
334 value = value.abs();
335 }
336 if let Some(min) = self.config.min_value {
337 value = value.max(min);
338 }
339 if let Some(max) = self.config.max_value {
340 value = value.min(max);
341 }
342
343 value
344 }
345
346 pub fn sample_with_component(&mut self) -> SampleWithComponent {
348 let component_idx = self.select_component();
349 let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351 if !self.config.allow_negative {
353 value = value.abs();
354 }
355 if let Some(min) = self.config.min_value {
356 value = value.max(min);
357 }
358 if let Some(max) = self.config.max_value {
359 value = value.min(max);
360 }
361
362 SampleWithComponent {
363 value,
364 component_index: component_idx,
365 component_label: self.config.components[component_idx].label.clone(),
366 }
367 }
368
369 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371 (0..n).map(|_| self.sample()).collect()
372 }
373
374 pub fn reset(&mut self, seed: u64) {
376 self.rng = ChaCha8Rng::seed_from_u64(seed);
377 }
378
379 pub fn config(&self) -> &GaussianMixtureConfig {
381 &self.config
382 }
383}
384
/// Seeded sampler for a log-normal mixture described by a `LogNormalMixtureConfig`.
pub struct LogNormalMixtureSampler {
    /// ChaCha8 generator used for both component selection and value draws.
    rng: ChaCha8Rng,
    /// The validated configuration this sampler was built from.
    config: LogNormalMixtureConfig,
    /// Running sum of component weights, in component order.
    cumulative_weights: Vec<f64>,
    /// One `LogNormal` per component, in component order.
    distributions: Vec<LogNormal<f64>>,
    /// `10^decimal_places`, precomputed for rounding samples.
    decimal_multiplier: f64,
}
396
397impl LogNormalMixtureSampler {
398 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
400 config.validate()?;
401
402 let mut cumulative_weights = Vec::with_capacity(config.components.len());
404 let mut cumulative = 0.0;
405 for component in &config.components {
406 cumulative += component.weight;
407 cumulative_weights.push(cumulative);
408 }
409
410 let distributions: Result<Vec<_>, _> = config
412 .components
413 .iter()
414 .map(|c| {
415 LogNormal::new(c.mu, c.sigma)
416 .map_err(|e| format!("Invalid log-normal distribution: {}", e))
417 })
418 .collect();
419
420 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
421
422 Ok(Self {
423 rng: ChaCha8Rng::seed_from_u64(seed),
424 config,
425 cumulative_weights,
426 distributions: distributions?,
427 decimal_multiplier,
428 })
429 }
430
431 fn select_component(&mut self) -> usize {
433 let p: f64 = self.rng.gen();
434 match self
435 .cumulative_weights
436 .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
437 {
438 Ok(i) => i,
439 Err(i) => i.min(self.distributions.len() - 1),
440 }
441 }
442
443 pub fn sample(&mut self) -> f64 {
445 let component_idx = self.select_component();
446 let mut value = self.distributions[component_idx].sample(&mut self.rng);
447
448 value = value.max(self.config.min_value);
450 if let Some(max) = self.config.max_value {
451 value = value.min(max);
452 }
453
454 (value * self.decimal_multiplier).round() / self.decimal_multiplier
456 }
457
458 pub fn sample_decimal(&mut self) -> Decimal {
460 let value = self.sample();
461 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
462 }
463
464 pub fn sample_with_component(&mut self) -> SampleWithComponent {
466 let component_idx = self.select_component();
467 let mut value = self.distributions[component_idx].sample(&mut self.rng);
468
469 value = value.max(self.config.min_value);
471 if let Some(max) = self.config.max_value {
472 value = value.min(max);
473 }
474
475 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
477
478 SampleWithComponent {
479 value,
480 component_index: component_idx,
481 component_label: self.config.components[component_idx].label.clone(),
482 }
483 }
484
485 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
487 (0..n).map(|_| self.sample()).collect()
488 }
489
490 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
492 (0..n).map(|_| self.sample_decimal()).collect()
493 }
494
495 pub fn reset(&mut self, seed: u64) {
497 self.rng = ChaCha8Rng::seed_from_u64(seed);
498 }
499
500 pub fn config(&self) -> &LogNormalMixtureConfig {
502 &self.config
503 }
504
505 pub fn expected_value(&self) -> f64 {
507 self.config
508 .components
509 .iter()
510 .map(|c| c.weight * c.expected_value())
511 .sum()
512 }
513}
514
// Unit tests for the mixture configurations and samplers. The sampling tests use a fixed
// seed (42), so their statistical assertions are deterministic.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_gaussian_mixture_validation() {
        // Weights sum to 1.0 and sigmas are positive: valid.
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(config.validate().is_ok());

        // Weights sum to 0.6: invalid.
        let invalid_config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(invalid_config.validate().is_err());

        // Negative sigma: invalid.
        let invalid_config =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);
        let mut sampler = GaussianMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // The two modes (0 and 10) are well separated, so 5.0 splits them.
        let low_count = samples.iter().filter(|&&x| x < 5.0).count();
        let high_count = samples.iter().filter(|&&x| x >= 5.0).count();

        // Equal weights: expect roughly half of the samples on each side.
        assert!(low_count > 350 && low_count < 650);
        assert!(high_count > 350 && high_count < 650);
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);

        // Same seed and configuration must yield the same sequence.
        let mut sampler1 = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = GaussianMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        // Weights sum to 1.0: valid.
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(config.validate().is_ok());

        // Weights sum to 0.4: invalid.
        let invalid_config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // Every sample is strictly positive.
        assert!(samples.iter().all(|&x| x > 0.0));

        // Every sample respects the preset's floor.
        assert!(samples.iter().all(|&x| x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let mut small_count = 0;
        let mut large_count = 0;

        for _ in 0..1000 {
            let result = sampler.sample_with_component();
            match result.component_label.as_deref() {
                Some("small") => small_count += 1,
                Some("large") => large_count += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Equal weights: expect roughly half of the draws from each component.
        assert!(small_count > 400 && small_count < 600);
        assert!(large_count > 400 && large_count < 600);
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();

        // Same seed and configuration must yield the same sequence.
        let mut sampler1 = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = LogNormalMixtureSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let expected = sampler.expected_value();
        // exp(7.0 + 1.0^2 / 2) = exp(7.5), approximately 1808.04.
        assert!((expected - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let component = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(component.label, Some("test_label".to_string()));

        let component_no_label = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(component_no_label.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        // Median is exp(10), far above the cap, so the cap is exercised heavily.
        let mut config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)]);
        config.max_value = Some(1000.0);

        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        let samples = sampler.sample_n(1000);

        // No sample may exceed the configured cap.
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }
}