use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Pareto};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// Parameters of a Pareto (power-law) distribution.
///
/// When deserializing, `alpha` and `x_min` are required; `max_value` and
/// `decimal_places` are optional and fall back to `None` and the value
/// returned by `default_decimal_places` respectively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoConfig {
    /// Shape parameter. Validation requires it to be positive.
    pub alpha: f64,
    /// Scale parameter. Validation requires it to be positive.
    pub x_min: f64,
    /// Optional upper cap; the sampler clamps draws to this value.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places the sampler rounds its output to.
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
/// Serde fallback for `ParetoConfig::decimal_places` when the field is
/// missing from the input.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
impl Default for ParetoConfig {
fn default() -> Self {
Self {
alpha: 2.0, x_min: 100.0, max_value: None,
decimal_places: 2,
}
}
}
impl ParetoConfig {
    /// Creates a configuration with the given shape (`alpha`) and scale
    /// (`x_min`).
    ///
    /// Every other field is taken from [`Default`]. The arguments are not
    /// checked here; call [`ParetoConfig::validate`] before relying on them.
    pub fn new(alpha: f64, x_min: f64) -> Self {
        Self {
            alpha,
            x_min,
            ..Default::default()
        }
    }

    /// Preset for capital expenditure: `alpha = 1.5`, minimum 10,000,
    /// capped at 100,000,000, two decimal places.
    pub fn capital_expenditure() -> Self {
        Self {
            alpha: 1.5,
            x_min: 10_000.0,
            max_value: Some(100_000_000.0),
            decimal_places: 2,
        }
    }

    /// Preset for maintenance costs: `alpha = 2.5`, minimum 500,
    /// capped at 500,000, two decimal places.
    pub fn maintenance_costs() -> Self {
        Self {
            alpha: 2.5,
            x_min: 500.0,
            max_value: Some(500_000.0),
            decimal_places: 2,
        }
    }

    /// Preset for vendor spend: `alpha = 1.8`, minimum 1,000,
    /// capped at 10,000,000, two decimal places.
    pub fn vendor_spend() -> Self {
        Self {
            alpha: 1.8,
            x_min: 1_000.0,
            max_value: Some(10_000_000.0),
            decimal_places: 2,
        }
    }

    /// Checks that the parameters describe a usable distribution.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first problem found:
    /// - `alpha` or `x_min` is zero, negative, NaN or infinite;
    /// - `max_value` is set but is NaN or not strictly greater than `x_min`.
    pub fn validate(&self) -> Result<(), String> {
        Self::require_positive_finite("alpha", self.alpha)?;
        Self::require_positive_finite("x_min", self.x_min)?;
        if let Some(max) = self.max_value {
            // NaN fails every ordered comparison, so it has to be rejected
            // explicitly. An infinite cap is still accepted: it simply never
            // clamps anything.
            if max.is_nan() || max <= self.x_min {
                return Err("max_value must be greater than x_min".to_string());
            }
        }
        Ok(())
    }

    /// Theoretical mean `alpha * x_min / (alpha - 1)`.
    ///
    /// Returns `None` unless `alpha > 1`. The formula does not take
    /// `max_value` into account.
    pub fn expected_value(&self) -> Option<f64> {
        (self.alpha > 1.0).then(|| self.alpha * self.x_min / (self.alpha - 1.0))
    }

    /// Theoretical variance `x_min^2 * alpha / ((alpha - 1)^2 * (alpha - 2))`.
    ///
    /// Returns `None` unless `alpha > 2`. The formula does not take
    /// `max_value` into account.
    pub fn variance(&self) -> Option<f64> {
        (self.alpha > 2.0).then(|| {
            let numerator = self.x_min.powi(2) * self.alpha;
            let denominator = (self.alpha - 1.0).powi(2) * (self.alpha - 2.0);
            numerator / denominator
        })
    }

    /// Ensures `value` is a strictly positive, finite number, naming the
    /// offending field in the error message.
    fn require_positive_finite(name: &str, value: f64) -> Result<(), String> {
        if value <= 0.0 {
            return Err(format!("{name} must be positive"));
        }
        // NaN slips past `<= 0.0` (all comparisons with NaN are false) and
        // +infinity is "positive" yet yields a degenerate distribution.
        if !value.is_finite() {
            return Err(format!("{name} must be finite"));
        }
        Ok(())
    }
}
/// Seeded sampler that draws values from a Pareto distribution described by
/// a [`ParetoConfig`].
pub struct ParetoSampler {
    /// ChaCha8 generator, seeded at construction and on `reset`.
    rng: ChaCha8Rng,
    /// Configuration the sampler was built from.
    config: ParetoConfig,
    /// Distribution built from `config.x_min` and `config.alpha`.
    distribution: Pareto<f64>,
    /// `10^decimal_places`, cached so rounding does not recompute it per sample.
    decimal_multiplier: f64,
}
impl ParetoSampler {
    /// Builds a sampler from `config`, seeding its generator with `seed`.
    ///
    /// # Errors
    ///
    /// Returns the message from [`ParetoConfig::validate`] when the
    /// configuration is rejected, or an `"Invalid Pareto distribution: …"`
    /// message when `Pareto::new` refuses the parameters.
    pub fn new(seed: u64, config: ParetoConfig) -> Result<Self, String> {
        config.validate()?;
        let distribution = Pareto::new(config.x_min, config.alpha)
            .map_err(|e| format!("Invalid Pareto distribution: {e}"))?;
        // `u8 -> i32` is lossless, so use `From` rather than an `as` cast.
        let decimal_multiplier = 10_f64.powi(i32::from(config.decimal_places));
        Ok(Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            config,
            distribution,
            decimal_multiplier,
        })
    }

    /// Draws one value, clamps it to `max_value` when one is configured, and
    /// rounds it to `decimal_places` decimal places.
    pub fn sample(&mut self) -> f64 {
        let raw = self.distribution.sample(&mut self.rng);
        let capped = match self.config.max_value {
            Some(max) => raw.min(max),
            None => raw,
        };
        let scaled = capped * self.decimal_multiplier;
        if scaled.is_finite() {
            scaled.round() / self.decimal_multiplier
        } else {
            // Scaling overflowed (or the draw was already infinite). Dividing
            // back would turn a finite draw into infinity, so return the
            // clamped value as-is; at that magnitude an f64 has no fractional
            // part left to round anyway.
            capped
        }
    }

    /// Draws one value with [`ParetoSampler::sample`] and converts it to a
    /// [`Decimal`] carrying at most `decimal_places` fractional digits.
    ///
    /// `Decimal::from_f64_retain` preserves the float's binary expansion
    /// (so a rounded `f64` still converts with trailing noise digits); the
    /// result is therefore rounded again in decimal space.
    ///
    /// Falls back to `Decimal::ONE` when the sample cannot be converted.
    /// NOTE(review): that fallback may lie below `x_min`; it is kept for
    /// backward compatibility — confirm whether callers rely on it.
    pub fn sample_decimal(&mut self) -> Decimal {
        let value = self.sample();
        let places = u32::from(self.config.decimal_places);
        Decimal::from_f64_retain(value).map_or(Decimal::ONE, |exact| exact.round_dp(places))
    }

    /// Draws `n` values with [`ParetoSampler::sample`], in order.
    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
        (0..n).map(|_| self.sample()).collect()
    }

    /// Replaces the generator with a fresh one seeded from `seed`; the
    /// configuration and distribution are left untouched.
    pub fn reset(&mut self, seed: u64) {
        self.rng = ChaCha8Rng::seed_from_u64(seed);
    }

    /// Returns the configuration this sampler was built from.
    pub fn config(&self) -> &ParetoConfig {
        &self.config
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `validate` accepts positive parameters and rejects negative ones.
    #[test]
    fn test_pareto_validation() {
        assert!(ParetoConfig::new(2.0, 100.0).validate().is_ok());

        let negative_alpha = ParetoConfig::new(-1.0, 100.0);
        assert!(negative_alpha.validate().is_err());

        let negative_x_min = ParetoConfig::new(2.0, -100.0);
        assert!(negative_x_min.validate().is_err());
    }

    /// Every sample is at least `x_min`.
    #[test]
    fn test_pareto_sampling() {
        let mut sampler = ParetoSampler::new(42, ParetoConfig::new(2.0, 100.0)).unwrap();
        let samples = sampler.sample_n(1000);

        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&x| x >= 100.0));
    }

    /// Two samplers with the same seed and config yield the same sequence.
    #[test]
    fn test_pareto_determinism() {
        let config = ParetoConfig::new(2.0, 100.0);
        let mut first = ParetoSampler::new(42, config.clone()).unwrap();
        let mut second = ParetoSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    /// Samples never exceed the configured `max_value`.
    #[test]
    fn test_pareto_max_constraint() {
        let config = ParetoConfig {
            max_value: Some(1000.0),
            ..ParetoConfig::new(2.0, 100.0)
        };
        let mut sampler = ParetoSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| x <= 1000.0));
    }

    /// The mean exists for `alpha = 2` and is undefined for `alpha = 1`.
    #[test]
    fn test_pareto_expected_value() {
        let finite_mean = ParetoConfig::new(2.0, 100.0);
        assert_eq!(finite_mean.expected_value(), Some(200.0));

        let heavy_tail = ParetoConfig::new(1.0, 100.0);
        assert_eq!(heavy_tail.expected_value(), None);
    }

    /// All built-in presets pass validation.
    #[test]
    fn test_pareto_presets() {
        let capex = ParetoConfig::capital_expenditure();
        assert!(capex.validate().is_ok());
        assert_eq!(capex.alpha, 1.5);

        assert!(ParetoConfig::maintenance_costs().validate().is_ok());
        assert!(ParetoConfig::vendor_spend().validate().is_ok());
    }

    /// With `alpha = 1.5` and `x_min = 100`, P(X > 1000) = 0.1^1.5 ≈ 3.16%,
    /// so roughly 316 of 10,000 samples should exceed 1000.
    #[test]
    fn test_heavy_tail_behavior() {
        let mut sampler = ParetoSampler::new(42, ParetoConfig::new(1.5, 100.0)).unwrap();
        let samples = sampler.sample_n(10000);

        let large_values = samples.iter().filter(|&&x| x > 1000.0).count();
        assert!(
            large_values > 200 && large_values < 500,
            "Expected ~316 values > 1000, got {}",
            large_values
        );
    }
}