use rust_decimal::Decimal;
use super::mixture::{
GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
LogNormalMixtureConfig, LogNormalMixtureSampler,
};
use super::pareto::{ParetoConfig, ParetoSampler};
/// Amount sampler that dispatches to one of three distribution families and
/// produces [`Decimal`] amounts.
///
/// The `LogNormal` and `Pareto` variants delegate decimal conversion to their
/// inner samplers. The `Gaussian` variant converts raw `f64` draws here and
/// clamps them at zero.
#[derive(Clone)]
pub enum AdvancedAmountSampler {
    /// Log-normal mixture (see [`LogNormalMixtureSampler`]).
    LogNormal(LogNormalMixtureSampler),
    /// Gaussian mixture (see [`GaussianMixtureSampler`]). Negative draws become
    /// `Decimal::ZERO` when converted.
    Gaussian(GaussianMixtureSampler),
    /// Pareto distribution (see [`ParetoSampler`]).
    Pareto(ParetoSampler),
}

impl AdvancedAmountSampler {
    /// Builds a log-normal mixture sampler seeded with `seed`.
    ///
    /// # Errors
    /// Propagates the error from [`LogNormalMixtureSampler::new`] (converted
    /// into a `String`) when the configuration is rejected.
    pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
        Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
    }

    /// Builds a Gaussian mixture sampler seeded with `seed`.
    ///
    /// # Errors
    /// Propagates the error from [`GaussianMixtureSampler::new`] (converted
    /// into a `String`) when the configuration is rejected.
    pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
        Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
    }

    /// Builds a Pareto sampler seeded with `seed`.
    ///
    /// # Errors
    /// Propagates the error from [`ParetoSampler::new`] (converted into a
    /// `String`) when the configuration is rejected.
    pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
        Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
    }

    /// Draws the next amount from the active distribution, advancing its RNG.
    pub fn sample_decimal(&mut self) -> Decimal {
        match self {
            Self::LogNormal(s) => s.sample_decimal(),
            Self::Gaussian(s) => Self::non_negative_decimal(s.sample()),
            Self::Pareto(s) => s.sample_decimal(),
        }
    }

    /// Re-seeds the active sampler with `seed`.
    ///
    /// The tests below rely on reseeding with the same seed replaying the same
    /// sequence.
    pub fn reset(&mut self, seed: u64) {
        match self {
            Self::LogNormal(s) => s.reset(seed),
            Self::Gaussian(s) => s.reset(seed),
            Self::Pareto(s) => s.reset(seed),
        }
    }

    /// Evaluates the active distribution's `ppf` at `u` and returns it as a
    /// [`Decimal`].
    ///
    /// Going by the name, this is presumably the percent-point (inverse-CDF)
    /// function with `u` in (0, 1). The valid range of `u` is defined by the
    /// inner samplers, so confirm there.
    pub fn ppf_decimal(&self, u: f64) -> Decimal {
        match self {
            Self::LogNormal(s) => s.ppf_decimal(u),
            Self::Gaussian(s) => Self::non_negative_decimal(s.ppf(u)),
            Self::Pareto(s) => s.ppf_decimal(u),
        }
    }

    /// Converts a raw Gaussian value into a non-negative [`Decimal`].
    ///
    /// - Negative values, NaN and both signed zeros map to `Decimal::ZERO`.
    /// - Positive values that `Decimal` cannot represent (e.g. `+inf`) also
    ///   fall back to `Decimal::ZERO`.
    ///
    /// NOTE(review): unlike `LogNormalMixtureConfig` and `ParetoConfig`, the
    /// Gaussian config exposes no `decimal_places`, so the result keeps the
    /// full `from_f64_retain` precision. Confirm this is intended.
    fn non_negative_decimal(value: f64) -> Decimal {
        // Use an explicit comparison rather than `value.max(0.0)`.
        // `f64::max` does not guarantee which zero it returns for
        // `-0.0` vs `0.0`, so `-0.0` could reach the conversion with its sign
        // bit set. This way every non-positive input yields the canonical zero.
        if value > 0.0 {
            Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
        } else {
            Decimal::ZERO
        }
    }
}
/// Builds a [`LogNormalMixtureConfig`] from `(weight, mu, sigma, label)` tuples.
///
/// A tuple with `Some(label)` becomes `LogNormalComponent::with_label`, and one
/// with `None` becomes `LogNormalComponent::new`. `min_value`, `max_value` and
/// `decimal_places` are copied into the config unchanged.
pub fn log_normal_config_from_components(
    components: Vec<(f64, f64, f64, Option<String>)>,
    min_value: f64,
    max_value: Option<f64>,
    decimal_places: u8,
) -> LogNormalMixtureConfig {
    let built = components
        .into_iter()
        .map(|(weight, mu, sigma, label)| {
            if let Some(name) = label {
                LogNormalComponent::with_label(weight, mu, sigma, name)
            } else {
                LogNormalComponent::new(weight, mu, sigma)
            }
        })
        .collect();

    LogNormalMixtureConfig {
        components: built,
        min_value,
        max_value,
        decimal_places,
    }
}
/// Builds a [`GaussianMixtureConfig`] from `(weight, mu, sigma)` tuples.
///
/// `allow_negative` is always set to `true`. Note that
/// `AdvancedAmountSampler` still clamps Gaussian draws at zero when it converts
/// them to `Decimal`. The optional bounds are copied through unchanged.
pub fn gaussian_config_from_components(
    components: Vec<(f64, f64, f64)>,
    min_value: Option<f64>,
    max_value: Option<f64>,
) -> GaussianMixtureConfig {
    let built = components
        .into_iter()
        .map(|(weight, mu, sigma)| GaussianComponent::new(weight, mu, sigma))
        .collect();

    GaussianMixtureConfig {
        components: built,
        allow_negative: true,
        min_value,
        max_value,
    }
}
#[cfg(test)]
// Unwrapping is acceptable in tests: a failed constructor should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let cfg = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO);
        }
    }

    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        // mu = 0 means roughly half the raw draws are negative, which
        // exercises the clamp.
        let cfg = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO);
        }
    }

    #[test]
    fn reset_restores_determinism() {
        let cfg = log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        let mut a = AdvancedAmountSampler::new_log_normal(7, cfg.clone()).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(7);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn gaussian_reset_is_deterministic() {
        let cfg = gaussian_config_from_components(vec![(1.0, 100.0, 10.0)], None, None);
        let mut a = AdvancedAmountSampler::new_gaussian(3, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(3);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let cfg = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, cfg).unwrap();
        let samples: Vec<Decimal> = (0..10_000).map(|_| sampler.sample_decimal()).collect();
        let min_sample = samples.iter().min().unwrap();
        assert!(
            *min_sample >= Decimal::from(1000),
            "Pareto sample {min_sample} < x_min"
        );
        // With a standard Pareto survival function (x_min / x)^alpha, about
        // 10_000 * 0.1^1.5 ≈ 316 samples should exceed 10_000, so a threshold
        // of 100 leaves ample margin.
        let extreme_count = samples
            .iter()
            .filter(|s| **s > Decimal::from(10_000))
            .count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    #[test]
    fn pareto_reset_is_deterministic() {
        let cfg = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut a = AdvancedAmountSampler::new_pareto(9, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(9);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn ppf_decimal_is_non_negative_and_monotone() {
        let samplers = [
            AdvancedAmountSampler::new_log_normal(
                1,
                log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2),
            )
            .unwrap(),
            AdvancedAmountSampler::new_gaussian(
                1,
                gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None),
            )
            .unwrap(),
            AdvancedAmountSampler::new_pareto(
                1,
                ParetoConfig {
                    alpha: 2.0,
                    x_min: 100.0,
                    max_value: None,
                    decimal_places: 2,
                },
            )
            .unwrap(),
        ];
        for sampler in &samplers {
            let quantiles: Vec<Decimal> = [0.1, 0.5, 0.9]
                .iter()
                .map(|&u| sampler.ppf_decimal(u))
                .collect();
            assert!(
                quantiles.iter().all(|q| *q >= Decimal::ZERO),
                "negative quantile in {quantiles:?}"
            );
            // A quantile function is non-decreasing in u. Rounding and the zero
            // clamp can tie neighbouring values but never invert them.
            assert!(
                quantiles.windows(2).all(|w| w[0] <= w[1]),
                "quantiles not monotone: {quantiles:?}"
            );
        }
    }
}