1use rand::prelude::*;
8use rand_chacha::ChaCha8Rng;
9use rand_distr::{Distribution, LogNormal, Normal};
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GaussianComponent {
16 pub weight: f64,
18 pub mu: f64,
20 pub sigma: f64,
22 #[serde(default)]
24 pub label: Option<String>,
25}
26
27impl GaussianComponent {
28 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
30 Self {
31 weight,
32 mu,
33 sigma,
34 label: None,
35 }
36 }
37
38 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
40 Self {
41 weight,
42 mu,
43 sigma,
44 label: Some(label.into()),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct GaussianMixtureConfig {
52 pub components: Vec<GaussianComponent>,
54 #[serde(default = "default_true")]
56 pub allow_negative: bool,
57 #[serde(default)]
59 pub min_value: Option<f64>,
60 #[serde(default)]
62 pub max_value: Option<f64>,
63}
64
/// Serde default provider that always yields `true`.
fn default_true() -> bool {
    const DEFAULT: bool = true;
    DEFAULT
}
68
69impl Default for GaussianMixtureConfig {
70 fn default() -> Self {
71 Self {
72 components: vec![GaussianComponent::new(1.0, 0.0, 1.0)],
73 allow_negative: true,
74 min_value: None,
75 max_value: None,
76 }
77 }
78}
79
80impl GaussianMixtureConfig {
81 pub fn new(components: Vec<GaussianComponent>) -> Self {
83 Self {
84 components,
85 ..Default::default()
86 }
87 }
88
89 pub fn validate(&self) -> Result<(), String> {
91 if self.components.is_empty() {
92 return Err("At least one component is required".to_string());
93 }
94
95 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
96 if (weight_sum - 1.0).abs() > 0.01 {
97 return Err(format!(
98 "Component weights must sum to 1.0, got {}",
99 weight_sum
100 ));
101 }
102
103 for (i, component) in self.components.iter().enumerate() {
104 if component.weight < 0.0 || component.weight > 1.0 {
105 return Err(format!(
106 "Component {} weight must be between 0.0 and 1.0, got {}",
107 i, component.weight
108 ));
109 }
110 if component.sigma <= 0.0 {
111 return Err(format!(
112 "Component {} sigma must be positive, got {}",
113 i, component.sigma
114 ));
115 }
116 }
117
118 Ok(())
119 }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct LogNormalComponent {
125 pub weight: f64,
127 pub mu: f64,
129 pub sigma: f64,
131 #[serde(default)]
133 pub label: Option<String>,
134}
135
136impl LogNormalComponent {
137 pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
139 Self {
140 weight,
141 mu,
142 sigma,
143 label: None,
144 }
145 }
146
147 pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
149 Self {
150 weight,
151 mu,
152 sigma,
153 label: Some(label.into()),
154 }
155 }
156
157 pub fn expected_value(&self) -> f64 {
159 (self.mu + self.sigma.powi(2) / 2.0).exp()
160 }
161
162 pub fn median(&self) -> f64 {
164 self.mu.exp()
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct LogNormalMixtureConfig {
171 pub components: Vec<LogNormalComponent>,
173 #[serde(default = "default_min_value")]
175 pub min_value: f64,
176 #[serde(default)]
178 pub max_value: Option<f64>,
179 #[serde(default = "default_decimal_places")]
181 pub decimal_places: u8,
182}
183
/// Serde default provider for a minimum value: `0.01`.
fn default_min_value() -> f64 {
    const MIN_VALUE_DEFAULT: f64 = 0.01;
    MIN_VALUE_DEFAULT
}
187
/// Serde default provider for a number of decimal places: `2`.
fn default_decimal_places() -> u8 {
    const DECIMAL_PLACES_DEFAULT: u8 = 2;
    DECIMAL_PLACES_DEFAULT
}
191
192impl Default for LogNormalMixtureConfig {
193 fn default() -> Self {
194 Self {
195 components: vec![LogNormalComponent::new(1.0, 7.0, 2.0)],
196 min_value: 0.01,
197 max_value: None,
198 decimal_places: 2,
199 }
200 }
201}
202
203impl LogNormalMixtureConfig {
204 pub fn new(components: Vec<LogNormalComponent>) -> Self {
206 Self {
207 components,
208 ..Default::default()
209 }
210 }
211
212 pub fn typical_transactions() -> Self {
214 Self {
215 components: vec![
216 LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
217 LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
218 LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
219 ],
220 min_value: 0.01,
221 max_value: Some(100_000_000.0),
222 decimal_places: 2,
223 }
224 }
225
226 pub fn validate(&self) -> Result<(), String> {
228 if self.components.is_empty() {
229 return Err("At least one component is required".to_string());
230 }
231
232 let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
233 if (weight_sum - 1.0).abs() > 0.01 {
234 return Err(format!(
235 "Component weights must sum to 1.0, got {}",
236 weight_sum
237 ));
238 }
239
240 for (i, component) in self.components.iter().enumerate() {
241 if component.weight < 0.0 || component.weight > 1.0 {
242 return Err(format!(
243 "Component {} weight must be between 0.0 and 1.0, got {}",
244 i, component.weight
245 ));
246 }
247 if component.sigma <= 0.0 {
248 return Err(format!(
249 "Component {} sigma must be positive, got {}",
250 i, component.sigma
251 ));
252 }
253 }
254
255 if self.min_value < 0.0 {
256 return Err("min_value must be non-negative".to_string());
257 }
258
259 Ok(())
260 }
261}
262
/// A sampled value together with the mixture component it was drawn from.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value after the sampler's post-processing.
    pub value: f64,
    /// Position of the originating component in the configuration's component list.
    pub component_index: usize,
    /// Copy of the originating component's label, if it has one.
    pub component_label: Option<String>,
}
273
274pub struct GaussianMixtureSampler {
276 rng: ChaCha8Rng,
277 config: GaussianMixtureConfig,
278 cumulative_weights: Vec<f64>,
280 distributions: Vec<Normal<f64>>,
282}
283
284impl GaussianMixtureSampler {
285 pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
287 config.validate()?;
288
289 let mut cumulative_weights = Vec::with_capacity(config.components.len());
291 let mut cumulative = 0.0;
292 for component in &config.components {
293 cumulative += component.weight;
294 cumulative_weights.push(cumulative);
295 }
296
297 let distributions: Result<Vec<_>, _> = config
299 .components
300 .iter()
301 .map(|c| {
302 Normal::new(c.mu, c.sigma)
303 .map_err(|e| format!("Invalid normal distribution: {}", e))
304 })
305 .collect();
306
307 Ok(Self {
308 rng: ChaCha8Rng::seed_from_u64(seed),
309 config,
310 cumulative_weights,
311 distributions: distributions?,
312 })
313 }
314
315 fn select_component(&mut self) -> usize {
317 let p: f64 = self.rng.gen();
318 match self
319 .cumulative_weights
320 .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
321 {
322 Ok(i) => i,
323 Err(i) => i.min(self.distributions.len() - 1),
324 }
325 }
326
327 pub fn sample(&mut self) -> f64 {
329 let component_idx = self.select_component();
330 let mut value = self.distributions[component_idx].sample(&mut self.rng);
331
332 if !self.config.allow_negative {
334 value = value.abs();
335 }
336 if let Some(min) = self.config.min_value {
337 value = value.max(min);
338 }
339 if let Some(max) = self.config.max_value {
340 value = value.min(max);
341 }
342
343 value
344 }
345
346 pub fn sample_with_component(&mut self) -> SampleWithComponent {
348 let component_idx = self.select_component();
349 let mut value = self.distributions[component_idx].sample(&mut self.rng);
350
351 if !self.config.allow_negative {
353 value = value.abs();
354 }
355 if let Some(min) = self.config.min_value {
356 value = value.max(min);
357 }
358 if let Some(max) = self.config.max_value {
359 value = value.min(max);
360 }
361
362 SampleWithComponent {
363 value,
364 component_index: component_idx,
365 component_label: self.config.components[component_idx].label.clone(),
366 }
367 }
368
369 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
371 (0..n).map(|_| self.sample()).collect()
372 }
373
374 pub fn reset(&mut self, seed: u64) {
376 self.rng = ChaCha8Rng::seed_from_u64(seed);
377 }
378
379 pub fn config(&self) -> &GaussianMixtureConfig {
381 &self.config
382 }
383}
384
385pub struct LogNormalMixtureSampler {
387 rng: ChaCha8Rng,
388 config: LogNormalMixtureConfig,
389 cumulative_weights: Vec<f64>,
391 distributions: Vec<LogNormal<f64>>,
393 decimal_multiplier: f64,
395}
396
397impl LogNormalMixtureSampler {
398 pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
400 config.validate()?;
401
402 let mut cumulative_weights = Vec::with_capacity(config.components.len());
404 let mut cumulative = 0.0;
405 for component in &config.components {
406 cumulative += component.weight;
407 cumulative_weights.push(cumulative);
408 }
409
410 let distributions: Result<Vec<_>, _> = config
412 .components
413 .iter()
414 .map(|c| {
415 LogNormal::new(c.mu, c.sigma)
416 .map_err(|e| format!("Invalid log-normal distribution: {}", e))
417 })
418 .collect();
419
420 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
421
422 Ok(Self {
423 rng: ChaCha8Rng::seed_from_u64(seed),
424 config,
425 cumulative_weights,
426 distributions: distributions?,
427 decimal_multiplier,
428 })
429 }
430
431 fn select_component(&mut self) -> usize {
433 let p: f64 = self.rng.gen();
434 match self
435 .cumulative_weights
436 .binary_search_by(|w| w.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
437 {
438 Ok(i) => i,
439 Err(i) => i.min(self.distributions.len() - 1),
440 }
441 }
442
443 pub fn sample(&mut self) -> f64 {
445 let component_idx = self.select_component();
446 let mut value = self.distributions[component_idx].sample(&mut self.rng);
447
448 value = value.max(self.config.min_value);
450 if let Some(max) = self.config.max_value {
451 value = value.min(max);
452 }
453
454 (value * self.decimal_multiplier).round() / self.decimal_multiplier
456 }
457
458 pub fn sample_decimal(&mut self) -> Decimal {
460 let value = self.sample();
461 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
462 }
463
464 pub fn sample_with_component(&mut self) -> SampleWithComponent {
466 let component_idx = self.select_component();
467 let mut value = self.distributions[component_idx].sample(&mut self.rng);
468
469 value = value.max(self.config.min_value);
471 if let Some(max) = self.config.max_value {
472 value = value.min(max);
473 }
474
475 value = (value * self.decimal_multiplier).round() / self.decimal_multiplier;
477
478 SampleWithComponent {
479 value,
480 component_index: component_idx,
481 component_label: self.config.components[component_idx].label.clone(),
482 }
483 }
484
485 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
487 (0..n).map(|_| self.sample()).collect()
488 }
489
490 pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
492 (0..n).map(|_| self.sample_decimal()).collect()
493 }
494
495 pub fn reset(&mut self, seed: u64) {
497 self.rng = ChaCha8Rng::seed_from_u64(seed);
498 }
499
500 pub fn config(&self) -> &LogNormalMixtureConfig {
502 &self.config
503 }
504
505 pub fn expected_value(&self) -> f64 {
507 self.config
508 .components
509 .iter()
510 .map(|c| c.weight * c.expected_value())
511 .sum()
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gaussian_mixture_validation() {
        // Two equally weighted components with positive sigmas are accepted.
        let balanced = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(balanced.validate().is_ok());

        // Weights that only add up to 0.6 are rejected.
        let underweight = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(underweight.validate().is_err());

        // A negative sigma is rejected.
        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let bimodal = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);
        let mut sampler = GaussianMixtureSampler::new(42, bimodal).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // The modes sit at 0 and 10, so 5.0 separates the two components.
        let low_count = draws.iter().filter(|x| **x < 5.0).count();
        let high_count = draws.iter().filter(|x| **x >= 5.0).count();

        assert!(low_count > 350 && low_count < 650);
        assert!(high_count > 350 && high_count < 650);
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let bimodal = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ]);

        let mut first = GaussianMixtureSampler::new(42, bimodal.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, bimodal).unwrap();

        // Identical seeds must yield identical streams.
        for _ in 0..100 {
            assert_eq!(first.sample(), second.sample());
        }
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        // Weights summing to 1.0 are accepted.
        let well_formed = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(well_formed.validate().is_ok());

        // Weights that only add up to 0.4 are rejected.
        let underweight = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(underweight.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let preset = LogNormalMixtureConfig::typical_transactions();
        let mut sampler = LogNormalMixtureSampler::new(42, preset).unwrap();

        let draws = sampler.sample_n(1000);
        assert_eq!(draws.len(), 1000);

        // Every draw is strictly positive.
        assert!(draws.iter().all(|x| *x > 0.0));

        // Every draw respects the preset's minimum of 0.01.
        assert!(draws.iter().all(|x| *x >= 0.01));
    }

    #[test]
    fn test_sample_with_component() {
        let labeled = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, labeled).unwrap();

        let mut small_count = 0;
        let mut large_count = 0;

        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            match drawn.component_label.as_deref() {
                Some("small") => small_count += 1,
                Some("large") => large_count += 1,
                _ => panic!("Unexpected label"),
            }
        }

        // Equal weights should split the draws roughly in half.
        assert!(small_count > 400 && small_count < 600);
        assert!(large_count > 400 && large_count < 600);
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let preset = LogNormalMixtureConfig::typical_transactions();

        let mut first = LogNormalMixtureSampler::new(42, preset.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, preset).unwrap();

        // Identical seeds must yield identical streams.
        for _ in 0..100 {
            assert_eq!(first.sample(), second.sample());
        }
    }

    #[test]
    fn test_lognormal_expected_value() {
        let single = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let sampler = LogNormalMixtureSampler::new(42, single).unwrap();

        // exp(7.0 + 1.0^2 / 2) = exp(7.5), approximately 1808.04.
        let mean = sampler.expected_value();
        assert!((mean - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labeled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labeled.label, Some("test_label".to_string()));

        let unlabeled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert_eq!(unlabeled.label, None);
    }

    #[test]
    fn test_max_value_constraint() {
        let mut capped = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)]);
        capped.max_value = Some(1000.0);

        let mut sampler = LogNormalMixtureSampler::new(42, capped).unwrap();
        let draws = sampler.sample_n(1000);

        // No draw may exceed the configured cap.
        assert!(draws.iter().all(|x| *x <= 1000.0));
    }
}