datasynth_core/distributions/
pareto.rs1use rand::prelude::*;
9use rand_chacha::ChaCha8Rng;
10use rand_distr::{Distribution, Pareto};
11use rust_decimal::Decimal;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ParetoConfig {
17 pub alpha: f64,
21 pub x_min: f64,
24 #[serde(default)]
26 pub max_value: Option<f64>,
27 #[serde(default = "default_decimal_places")]
29 pub decimal_places: u8,
30}
31
/// Serde fallback for `ParetoConfig::decimal_places`: two decimal places.
fn default_decimal_places() -> u8 {
    const DEFAULT_DECIMAL_PLACES: u8 = 2;
    DEFAULT_DECIMAL_PLACES
}
35
36impl Default for ParetoConfig {
37 fn default() -> Self {
38 Self {
39 alpha: 2.0, x_min: 100.0, max_value: None,
42 decimal_places: 2,
43 }
44 }
45}
46
47impl ParetoConfig {
48 pub fn new(alpha: f64, x_min: f64) -> Self {
50 Self {
51 alpha,
52 x_min,
53 ..Default::default()
54 }
55 }
56
57 pub fn capital_expenditure() -> Self {
59 Self {
60 alpha: 1.5, x_min: 10_000.0, max_value: Some(100_000_000.0),
63 decimal_places: 2,
64 }
65 }
66
67 pub fn maintenance_costs() -> Self {
69 Self {
70 alpha: 2.5, x_min: 500.0, max_value: Some(500_000.0),
73 decimal_places: 2,
74 }
75 }
76
77 pub fn vendor_spend() -> Self {
79 Self {
80 alpha: 1.8, x_min: 1_000.0,
82 max_value: Some(10_000_000.0),
83 decimal_places: 2,
84 }
85 }
86
87 pub fn validate(&self) -> Result<(), String> {
89 if self.alpha <= 0.0 {
90 return Err("alpha must be positive".to_string());
91 }
92 if self.x_min <= 0.0 {
93 return Err("x_min must be positive".to_string());
94 }
95 if let Some(max) = self.max_value {
96 if max <= self.x_min {
97 return Err("max_value must be greater than x_min".to_string());
98 }
99 }
100 Ok(())
101 }
102
103 pub fn expected_value(&self) -> Option<f64> {
106 if self.alpha > 1.0 {
107 Some(self.alpha * self.x_min / (self.alpha - 1.0))
108 } else {
109 None }
111 }
112
113 pub fn variance(&self) -> Option<f64> {
116 if self.alpha > 2.0 {
117 let numerator = self.x_min.powi(2) * self.alpha;
118 let denominator = (self.alpha - 1.0).powi(2) * (self.alpha - 2.0);
119 Some(numerator / denominator)
120 } else {
121 None }
123 }
124}
125
126#[derive(Clone)]
128pub struct ParetoSampler {
129 rng: ChaCha8Rng,
130 config: ParetoConfig,
131 distribution: Pareto<f64>,
132 decimal_multiplier: f64,
133}
134
135impl ParetoSampler {
136 pub fn new(seed: u64, config: ParetoConfig) -> Result<Self, String> {
138 config.validate()?;
139
140 let distribution = Pareto::new(config.x_min, config.alpha)
141 .map_err(|e| format!("Invalid Pareto distribution: {e}"))?;
142
143 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
144
145 Ok(Self {
146 rng: ChaCha8Rng::seed_from_u64(seed),
147 config,
148 distribution,
149 decimal_multiplier,
150 })
151 }
152
153 pub fn sample(&mut self) -> f64 {
155 let mut value = self.distribution.sample(&mut self.rng);
156
157 if let Some(max) = self.config.max_value {
159 value = value.min(max);
160 }
161
162 (value * self.decimal_multiplier).round() / self.decimal_multiplier
164 }
165
166 pub fn sample_decimal(&mut self) -> Decimal {
168 let value = self.sample();
169 Decimal::from_f64_retain(value).unwrap_or(Decimal::ONE)
170 }
171
172 pub fn ppf(&self, u: f64) -> f64 {
178 let u = u.clamp(1e-12, 1.0 - 1e-12);
179 let mut value = self.config.x_min * (1.0 - u).powf(-1.0 / self.config.alpha);
180 if let Some(max) = self.config.max_value {
181 value = value.min(max);
182 }
183 (value * self.decimal_multiplier).round() / self.decimal_multiplier
184 }
185
186 pub fn ppf_decimal(&self, u: f64) -> Decimal {
188 Decimal::from_f64_retain(self.ppf(u)).unwrap_or(Decimal::ONE)
189 }
190
191 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
193 (0..n).map(|_| self.sample()).collect()
194 }
195
196 pub fn reset(&mut self, seed: u64) {
198 self.rng = ChaCha8Rng::seed_from_u64(seed);
199 }
200
201 pub fn config(&self) -> &ParetoConfig {
203 &self.config
204 }
205}
206
#[cfg(test)]
// Unwrapping is fine here: a failed construction should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Positive parameters validate; a negative shape or scale does not.
    #[test]
    fn test_pareto_validation() {
        assert!(ParetoConfig::new(2.0, 100.0).validate().is_ok());
        assert!(ParetoConfig::new(-1.0, 100.0).validate().is_err());
        assert!(ParetoConfig::new(2.0, -100.0).validate().is_err());
    }

    /// Every sampled value is at least `x_min`.
    #[test]
    fn test_pareto_sampling() {
        let mut sampler = ParetoSampler::new(42, ParetoConfig::new(2.0, 100.0)).unwrap();

        let drawn = sampler.sample_n(1000);

        assert_eq!(drawn.len(), 1000);
        assert!(drawn.iter().all(|&x| x >= 100.0));
    }

    /// Two samplers with the same seed and configuration produce the same stream.
    #[test]
    fn test_pareto_determinism() {
        let config = ParetoConfig::new(2.0, 100.0);
        let mut first = ParetoSampler::new(42, config.clone()).unwrap();
        let mut second = ParetoSampler::new(42, config).unwrap();

        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    /// With `max_value` set, no sampled value exceeds it.
    #[test]
    fn test_pareto_max_constraint() {
        let config = ParetoConfig {
            max_value: Some(1000.0),
            ..ParetoConfig::new(2.0, 100.0)
        };
        let mut sampler = ParetoSampler::new(42, config).unwrap();

        let drawn = sampler.sample_n(1000);

        assert!(drawn.iter().all(|&x| x <= 1000.0));
    }

    /// The mean is `alpha * x_min / (alpha - 1)` and undefined for `alpha <= 1`.
    #[test]
    fn test_pareto_expected_value() {
        assert_eq!(ParetoConfig::new(2.0, 100.0).expected_value(), Some(200.0));
        assert_eq!(ParetoConfig::new(1.0, 100.0).expected_value(), None);
    }

    /// All presets pass validation.
    #[test]
    fn test_pareto_presets() {
        let capex = ParetoConfig::capital_expenditure();
        assert!(capex.validate().is_ok());
        assert_eq!(capex.alpha, 1.5);

        assert!(ParetoConfig::maintenance_costs().validate().is_ok());
        assert!(ParetoConfig::vendor_spend().validate().is_ok());
    }

    /// For `alpha = 1.5`, `x_min = 100`: P(X > 1000) = (100 / 1000)^1.5 ≈ 3.16 %,
    /// i.e. roughly 316 of 10 000 draws.
    #[test]
    fn test_heavy_tail_behavior() {
        let mut sampler = ParetoSampler::new(42, ParetoConfig::new(1.5, 100.0)).unwrap();

        let tail_count = sampler
            .sample_n(10000)
            .iter()
            .filter(|&&x| x > 1000.0)
            .count();

        assert!(
            tail_count > 200 && tail_count < 500,
            "Expected ~316 values > 1000, got {}",
            tail_count
        );
    }
}