1use rand_core::RngCore;
4use rand_distr::{
5 Beta, Binomial, Distribution, Gamma, Geometric, LogNormal, Normal, Pareto, Poisson, SkewNormal,
6 Weibull,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
11use crate::Error;
12
/// Smallest non-zero probability that [`Dist::validate`] accepts for the
/// Binomial and Geometric distributions (1e-9). Smaller non-zero values are
/// rejected by validation with a "too slow sampling" error.
pub const DIST_MIN_PROBABILITY: f64 = 0.000_000_001;
17
/// The kind of probability distribution to sample from, together with its
/// parameters.
///
/// `Uniform` is sampled directly from the random number generator; every
/// other variant hands its fields to the `rand_distr` constructor of the same
/// name. Parameter validity is checked by [`Dist::validate`].
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum DistType {
    /// Uniform distribution over the half-open range `low..high`. When
    /// `low == high`, sampling returns exactly `low`.
    Uniform {
        /// Lower bound of the range (inclusive).
        low: f64,
        /// Upper bound of the range (exclusive unless equal to `low`).
        high: f64,
    },
    /// Normal distribution, constructed with `Normal::new(mean, stdev)`.
    Normal {
        /// Mean of the distribution.
        mean: f64,
        /// Standard deviation of the distribution.
        stdev: f64,
    },
    /// Skew-normal distribution, constructed with
    /// `SkewNormal::new(location, scale, shape)`.
    SkewNormal {
        /// Location parameter.
        location: f64,
        /// Scale parameter.
        scale: f64,
        /// Shape parameter.
        shape: f64,
    },
    /// Log-normal distribution, constructed with `LogNormal::new(mu, sigma)`.
    LogNormal {
        /// First argument to `LogNormal::new`.
        mu: f64,
        /// Second argument to `LogNormal::new`.
        sigma: f64,
    },
    /// Binomial distribution, constructed with
    /// `Binomial::new(trials, probability)`; the sampled count is converted
    /// to `f64`.
    Binomial {
        /// Number of trials. Validation rejects values above 1e9.
        trials: u64,
        /// Probability parameter. Validation rejects non-zero values below
        /// [`DIST_MIN_PROBABILITY`].
        probability: f64,
    },
    /// Geometric distribution, constructed with `Geometric::new(probability)`;
    /// the sampled count is converted to `f64`.
    Geometric {
        /// Probability parameter. Validation rejects non-zero values below
        /// [`DIST_MIN_PROBABILITY`].
        probability: f64,
    },
    /// Pareto distribution, constructed with `Pareto::new(scale, shape)`.
    Pareto {
        /// Scale parameter.
        scale: f64,
        /// Shape parameter.
        shape: f64,
    },
    /// Poisson distribution, constructed with `Poisson::new(lambda)`.
    Poisson {
        /// Rate parameter. Validation rejects values above 1e42.
        lambda: f64,
    },
    /// Weibull distribution, constructed with `Weibull::new(scale, shape)`.
    Weibull {
        /// Scale parameter.
        scale: f64,
        /// Shape parameter.
        shape: f64,
    },
    /// Gamma distribution. Note the argument order: it is constructed with
    /// `Gamma::new(shape, scale)`.
    Gamma {
        /// Scale parameter (second argument to `Gamma::new`).
        scale: f64,
        /// Shape parameter (first argument to `Gamma::new`).
        shape: f64,
    },
    /// Beta distribution, constructed with `Beta::new(alpha, beta)`.
    Beta {
        /// First shape parameter.
        alpha: f64,
        /// Second shape parameter.
        beta: f64,
    },
}
106
107impl fmt::Display for DistType {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "{self:?}")
110 }
111}
112
/// A distribution plus the adjustments applied to every value drawn from it.
///
/// See [`Dist::sample`]: the raw sample is shifted by `start`, floored at
/// `0.0`, and optionally capped at `max`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Dist {
    /// The underlying distribution and its parameters.
    pub dist: DistType,
    /// Offset added to each raw sample before clamping.
    pub start: f64,
    /// Upper bound for the returned sample. Only applied when greater than
    /// `0.0`; otherwise the result has no upper clamp.
    pub max: f64,
}
124
125impl fmt::Display for Dist {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 let clamp;
128 if self.start > 0.0 && self.max > 0.0 {
129 clamp = format!(", start {}, clamped to [0.0, {}]", self.start, self.max);
130 } else if self.start > 0.0 {
131 clamp = format!(", start {}, clamped to [0.0, f64::MAX]", self.start);
132 } else if self.max > 0.0 {
133 clamp = format!(", clamped to [0.0, {}]", self.max);
134 } else {
135 clamp = ", clamped to [0.0, f64::MAX]".to_string();
136 }
137 write!(f, "{}{}", self.dist, clamp)
138 }
139}
140
141impl Default for Dist {
142 fn default() -> Self {
143 Self::new(
144 DistType::Uniform {
145 low: f64::MAX,
146 high: f64::MAX,
147 },
148 0.0,
149 0.0,
150 )
151 }
152}
153
154impl Dist {
155 pub fn new(dist: DistType, start: f64, max: f64) -> Self {
157 Dist { dist, start, max }
158 }
159
160 pub fn validate(&self) -> Result<(), Error> {
162 match self.dist {
163 DistType::Uniform { low, high } => {
164 if low.is_nan() || high.is_nan() {
165 Err(Error::Machine(
166 "for Uniform dist, got low or high as NaN".to_string(),
167 ))?;
168 }
169 if low.is_infinite() || high.is_infinite() {
170 Err(Error::Machine(
171 "for Uniform dist, got low or high as infinite".to_string(),
172 ))?;
173 }
174 if low > high {
175 Err(Error::Machine(
176 "for Uniform dist, got low > high".to_string(),
177 ))?;
178 }
179 let range = high - low;
180 if range.is_infinite() {
181 Err(Error::Machine(
182 "for Uniform dist, range hig - low overflows".to_string(),
183 ))?;
184 }
185 }
186 DistType::Normal { mean, stdev } => {
187 Normal::new(mean, stdev).map_err(|e| Error::Machine(e.to_string()))?;
188 }
189 DistType::SkewNormal {
190 location,
191 scale,
192 shape,
193 } => {
194 SkewNormal::new(location, scale, shape)
195 .map_err(|e| Error::Machine(e.to_string()))?;
196 }
197 DistType::LogNormal { mu, sigma } => {
198 LogNormal::new(mu, sigma).map_err(|e| Error::Machine(e.to_string()))?;
199 }
200 DistType::Binomial {
201 trials,
202 probability,
203 } => {
204 if probability != 0.0 && probability < DIST_MIN_PROBABILITY {
205 Err(Error::Machine(format!(
206 "for Binomial dist, probability 0.0 > {probability:?} < DIST_MIN_PROBABILITY (1e-9), error due to too slow sampling"
207 )))?;
208 }
209 if trials > 1_000_000_000 {
210 Err(Error::Machine(format!(
211 "for Binomial dist, {trials} trials > 1e9, error due to too slow sampling"
212 )))?;
213 }
214 Binomial::new(trials, probability).map_err(|e| Error::Machine(e.to_string()))?;
215 }
216 DistType::Geometric { probability } => {
217 if probability != 0.0 && probability < DIST_MIN_PROBABILITY {
218 Err(Error::Machine(format!(
219 "for Geometric dist, probability 0.0 > {probability:?} < DIST_MIN_PROBABILITY (1e-9), error due to too slow sampling"
220 )))?;
221 }
222 Geometric::new(probability).map_err(|e| Error::Machine(e.to_string()))?;
223 }
224 DistType::Pareto { scale, shape } => {
225 Pareto::new(scale, shape).map_err(|e| Error::Machine(e.to_string()))?;
226 }
227 DistType::Poisson { lambda } => {
228 if lambda > 1_000_000_000_000_000_000_000_000_000_000_000_000_000_000.0 {
229 Err(Error::Machine(format!(
230 "for Poisson dist, lambda {lambda} > 1e42, error due to too slow sampling"
231 )))?;
232 }
233 Poisson::new(lambda).map_err(|e| Error::Machine(e.to_string()))?;
234 }
235 DistType::Weibull { scale, shape } => {
236 Weibull::new(scale, shape).map_err(|e| Error::Machine(e.to_string()))?;
237 }
238 DistType::Gamma { scale, shape } => {
239 Gamma::new(shape, scale).map_err(|e| Error::Machine(e.to_string()))?;
242 }
243 DistType::Beta { alpha, beta } => {
244 Beta::new(alpha, beta).map_err(|e| Error::Machine(e.to_string()))?;
245 }
246 }
247
248 Ok(())
249 }
250
251 pub fn sample<R: RngCore>(self, rng: &mut R) -> f64 {
253 let sampled = self.dist_sample(rng);
254 let mut r: f64 = 0.0;
255 let adjusted = sampled + self.start;
256
257 if !adjusted.is_finite() {
259 return 0.0;
260 }
261
262 r = r.max(adjusted);
263 if self.max > 0.0 {
264 let clamped = r.min(self.max);
265 return if clamped.is_finite() { clamped } else { 0.0 };
267 }
268 r
269 }
270
271 fn dist_sample<R: RngCore>(self, rng: &mut R) -> f64 {
272 use rand::Rng;
273 match self.dist {
274 DistType::Uniform { low, high } => {
275 if low == high {
278 return low;
279 }
280 rng.random_range(low..high)
281 }
282 DistType::Normal { mean, stdev } => Normal::new(mean, stdev).unwrap().sample(rng),
283 DistType::SkewNormal {
284 location,
285 scale,
286 shape,
287 } => SkewNormal::new(location, scale, shape).unwrap().sample(rng),
288 DistType::LogNormal { mu, sigma } => LogNormal::new(mu, sigma).unwrap().sample(rng),
289 DistType::Binomial {
290 trials,
291 probability,
292 } => Binomial::new(trials, probability).unwrap().sample(rng) as f64,
293 DistType::Geometric { probability } => {
294 Geometric::new(probability).unwrap().sample(rng) as f64
295 }
296 DistType::Pareto { scale, shape } => Pareto::new(scale, shape).unwrap().sample(rng),
297 DistType::Poisson { lambda } => Poisson::new(lambda).unwrap().sample(rng),
298 DistType::Weibull { scale, shape } => Weibull::new(scale, shape).unwrap().sample(rng),
299 DistType::Gamma { scale, shape } => {
300 Gamma::new(shape, scale).unwrap().sample(rng)
303 }
304 DistType::Beta { alpha, beta } => Beta::new(alpha, beta).unwrap().sample(rng),
305 }
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `dist` in a [`Dist`] with no start offset and no upper clamp.
    fn unclamped(dist: DistType) -> Dist {
        Dist {
            dist,
            start: 0.0,
            max: 0.0,
        }
    }

    #[test]
    fn validate_uniform_dist() {
        // A degenerate range (low == high) is allowed.
        let valid = unclamped(DistType::Uniform {
            low: 10.0,
            high: 10.0,
        });
        assert!(valid.validate().is_ok());

        // An inverted range is not.
        let invalid = unclamped(DistType::Uniform {
            low: 15.0,
            high: 5.0,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_normal_dist() {
        let valid = unclamped(DistType::Normal {
            mean: 100.0,
            stdev: 15.0,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Normal {
            mean: 100.0,
            stdev: f64::INFINITY,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_skewnormal_dist() {
        let valid = unclamped(DistType::SkewNormal {
            location: 100.0,
            scale: 15.0,
            shape: -3.0,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::SkewNormal {
            location: 100.0,
            scale: 15.0,
            shape: f64::INFINITY,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_lognormal_dist() {
        let valid = unclamped(DistType::LogNormal {
            mu: 100.0,
            sigma: 15.0,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::LogNormal {
            mu: 100.0,
            sigma: f64::INFINITY,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_binomial_dist() {
        let valid = unclamped(DistType::Binomial {
            trials: 10,
            probability: 0.5,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Binomial {
            trials: 10,
            probability: 1.1,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_geometric_dist() {
        let valid = unclamped(DistType::Geometric { probability: 0.5 });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Geometric { probability: 1.1 });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_pareto_dist() {
        let valid = unclamped(DistType::Pareto {
            scale: 1.0,
            shape: 0.5,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Pareto {
            scale: -1.0,
            shape: 0.5,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_poisson_dist() {
        let valid = unclamped(DistType::Poisson { lambda: 1.0 });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Poisson { lambda: -1.0 });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_weibull_dist() {
        let valid = unclamped(DistType::Weibull {
            scale: 1.0,
            shape: 0.5,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Weibull {
            scale: 1.0,
            shape: -0.5,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_gamma_dist() {
        let valid = unclamped(DistType::Gamma {
            scale: 1.0,
            shape: 0.5,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Gamma {
            scale: 1.0,
            shape: -0.5,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn validate_beta_dist() {
        let valid = unclamped(DistType::Beta {
            alpha: 1.0,
            beta: 0.5,
        });
        assert!(valid.validate().is_ok());

        let invalid = unclamped(DistType::Beta {
            alpha: 1.0,
            beta: -0.5,
        });
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn sample_clamp() {
        let mut rng = rand::rng();

        // The start offset is added to the (constant) raw sample.
        let shifted = Dist::new(
            DistType::Uniform {
                low: 0.0,
                high: 0.0,
            },
            5.0,
            0.0,
        );
        assert_eq!(shifted.sample(&mut rng), 5.0);

        // A positive max caps the result.
        let capped = Dist::new(
            DistType::Uniform {
                low: 10.0,
                high: 10.0,
            },
            0.0,
            5.0,
        );
        assert_eq!(capped.sample(&mut rng), 5.0);

        // Negative samples are floored at zero.
        let negative = unclamped(DistType::Uniform {
            low: -20.0,
            high: -10.0,
        });
        assert_eq!(negative.sample(&mut rng), 0.0);
    }

    #[test]
    fn sample_nan_inf_robustness() {
        let mut rng = rand::rng();

        // Normal distribution with a huge standard deviation.
        let wide_normal = unclamped(DistType::Normal {
            mean: 0.0,
            stdev: 1e300,
        });
        for _ in 0..100 {
            let value = wide_normal.sample(&mut rng);
            assert!(
                value.is_finite(),
                "Normal distribution with large stdev should not produce non-finite values"
            );
            assert!(value >= 0.0, "Sample should respect minimum bound of 0.0");
        }

        // Heavy-tailed Pareto distribution with an upper clamp.
        let heavy_tail = Dist::new(
            DistType::Pareto {
                scale: 1.0,
                shape: 0.1,
            },
            0.0,
            1000.0,
        );
        for _ in 0..100 {
            let value = heavy_tail.sample(&mut rng);
            assert!(
                value.is_finite(),
                "Pareto distribution should not produce non-finite values"
            );
            assert!(value >= 0.0, "Sample should respect minimum bound of 0.0");
            assert!(value <= 1000.0, "Sample should respect maximum bound");
        }

        // Very large constant sample combined with a very large start offset.
        let extreme_start = Dist::new(
            DistType::Uniform {
                low: 1e300,
                high: 1e300,
            },
            1e300,
            0.0,
        );
        let value = extreme_start.sample(&mut rng);
        assert!(
            value.is_finite(),
            "Large start value should not produce non-finite values"
        );
        assert!(value >= 0.0, "Sample should respect minimum bound of 0.0");

        // Every raw sample lies above the configured max.
        let above_max = Dist::new(
            DistType::Uniform {
                low: 100.0,
                high: 200.0,
            },
            0.0,
            50.0,
        );
        for _ in 0..20 {
            let value = above_max.sample(&mut rng);
            assert!(value.is_finite(), "Clamped sample should be finite");
            assert!(value <= 50.0, "Sample should respect max bound");
            assert!(value >= 0.0, "Sample should respect minimum bound of 0.0");
        }
    }
}