maybenot/
dist.rs

1//! Distributions sampled as part of a [`State`](crate::state).
2
3use rand_core::RngCore;
4use rand_distr::{
5    Beta, Binomial, Distribution, Gamma, Geometric, LogNormal, Normal, Pareto, Poisson, SkewNormal,
6    Weibull,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
11use crate::Error;
12
/// The minimum probability of a [`Dist`](crate::dist) with a probability
/// parameter. This is set to prevent poor sampling performance for low
/// probabilities. Set to 1e-9.
///
/// Enforced by [`Dist::validate()`] for the Binomial and Geometric
/// distributions: a non-zero probability below this value is rejected, while
/// a probability of exactly 0.0 is not subject to this limit.
pub const DIST_MIN_PROBABILITY: f64 = 0.000_000_001;
17
/// DistType represents the type of a [`Dist`]. Supports a wide range of
/// different distributions. Some are probably useless and some are probably
/// missing. Uses the [`rand_distr`] crate for sampling.
///
/// The parameters of each variant are checked by [`Dist::validate()`]; apart
/// from Uniform, which is checked by hand, validity is decided by the
/// constructor of the matching [`rand_distr`] distribution.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum DistType {
    /// Uniformly random [low, high). If low == high, constant.
    ///
    /// Validation requires both bounds to be finite, low <= high, and the
    /// range high - low to not overflow.
    Uniform {
        /// The lower bound of the distribution.
        low: f64,
        /// The upper bound of the distribution.
        high: f64,
    },
    /// Normal distribution with set mean and standard deviation. Useful for
    /// real-valued quantities.
    Normal {
        /// The mean of the distribution.
        mean: f64,
        /// The standard deviation of the distribution.
        stdev: f64,
    },
    /// SkewNormal distribution with set location, scale, and shape. Useful for
    /// real-valued quantities.
    SkewNormal {
        /// The location of the distribution.
        location: f64,
        /// The scale of the distribution.
        scale: f64,
        /// The shape of the distribution.
        shape: f64,
    },
    /// LogNormal distribution with set mu and sigma. Useful for real-valued
    /// quantities.
    LogNormal {
        /// The mu of the distribution.
        mu: f64,
        /// The sigma of the distribution.
        sigma: f64,
    },
    /// Binomial distribution with set trials and probability. Useful for yes/no
    /// events.
    ///
    /// Validation rejects more than 1e9 trials and non-zero probabilities
    /// below [`DIST_MIN_PROBABILITY`], in both cases due to too slow sampling.
    Binomial {
        /// The number of trials.
        trials: u64,
        /// The probability of success.
        probability: f64,
    },
    /// Geometric distribution with set probability. Useful for yes/no events.
    ///
    /// Validation rejects non-zero probabilities below
    /// [`DIST_MIN_PROBABILITY`] due to too slow sampling.
    Geometric {
        /// The probability of success.
        probability: f64,
    },
    /// Pareto distribution with set scale and shape. Useful for occurrence of
    /// independent events at a given rate.
    Pareto {
        /// The scale of the distribution.
        scale: f64,
        /// The shape of the distribution.
        shape: f64,
    },
    /// Poisson distribution with set lambda. Useful for occurrence of
    /// independent events at a given rate.
    ///
    /// Validation rejects lambda above 1e42 due to too slow sampling.
    Poisson {
        /// The lambda of the distribution.
        lambda: f64,
    },
    /// Weibull distribution with set scale and shape. Useful for occurrence of
    /// independent events at a given rate.
    Weibull {
        /// The scale of the distribution.
        scale: f64,
        /// The shape of the distribution.
        shape: f64,
    },
    /// Gamma distribution with set scale and shape.
    ///
    /// Note that the fields are declared as (scale, shape), while the
    /// [`rand_distr`] constructor is called with (shape, scale).
    Gamma {
        /// The scale of the distribution.
        scale: f64,
        /// The shape of the distribution.
        shape: f64,
    },
    /// Beta distribution with set alpha and beta.
    Beta {
        /// The alpha of the distribution.
        alpha: f64,
        /// The beta of the distribution.
        beta: f64,
    },
}
106
107impl fmt::Display for DistType {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        write!(f, "{self:?}")
110    }
111}
112
/// A distribution used in a [`State`](crate::state). Can be sampled to get a
/// value. When sampling, `start` is added to the drawn value and the result
/// is clamped to the range [0.0, max]. The upper bound only applies if `max`
/// is greater than 0.0; otherwise the value is only clamped from below.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Dist {
    /// The type of distribution.
    pub dist: DistType,
    /// The starting value that the sampled value is added to.
    pub start: f64,
    /// The maximum value that can be sampled (including starting value). A
    /// value of 0.0 (or less) means that no maximum is applied.
    pub max: f64,
}
124
125impl fmt::Display for Dist {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        let clamp;
128        if self.start > 0.0 && self.max > 0.0 {
129            clamp = format!(", start {}, clamped to [0.0, {}]", self.start, self.max);
130        } else if self.start > 0.0 {
131            clamp = format!(", start {}, clamped to [0.0, f64::MAX]", self.start);
132        } else if self.max > 0.0 {
133            clamp = format!(", clamped to [0.0, {}]", self.max);
134        } else {
135            clamp = ", clamped to [0.0, f64::MAX]".to_string();
136        }
137        write!(f, "{}{}", self.dist, clamp)
138    }
139}
140
141impl Default for Dist {
142    fn default() -> Self {
143        Self::new(
144            DistType::Uniform {
145                low: f64::MAX,
146                high: f64::MAX,
147            },
148            0.0,
149            0.0,
150        )
151    }
152}
153
154impl Dist {
155    /// Create a new [`Dist`].
156    pub fn new(dist: DistType, start: f64, max: f64) -> Self {
157        Dist { dist, start, max }
158    }
159
160    /// Validate that the parameters are valid for the set [`DistType`].
161    pub fn validate(&self) -> Result<(), Error> {
162        match self.dist {
163            DistType::Uniform { low, high } => {
164                if low.is_nan() || high.is_nan() {
165                    Err(Error::Machine(
166                        "for Uniform dist, got low or high as NaN".to_string(),
167                    ))?;
168                }
169                if low.is_infinite() || high.is_infinite() {
170                    Err(Error::Machine(
171                        "for Uniform dist, got low or high as infinite".to_string(),
172                    ))?;
173                }
174                if low > high {
175                    Err(Error::Machine(
176                        "for Uniform dist, got low > high".to_string(),
177                    ))?;
178                }
179                let range = high - low;
180                if range.is_infinite() {
181                    Err(Error::Machine(
182                        "for Uniform dist, range hig - low overflows".to_string(),
183                    ))?;
184                }
185            }
186            DistType::Normal { mean, stdev } => {
187                Normal::new(mean, stdev).map_err(|e| Error::Machine(e.to_string()))?;
188            }
189            DistType::SkewNormal {
190                location,
191                scale,
192                shape,
193            } => {
194                SkewNormal::new(location, scale, shape)
195                    .map_err(|e| Error::Machine(e.to_string()))?;
196            }
197            DistType::LogNormal { mu, sigma } => {
198                LogNormal::new(mu, sigma).map_err(|e| Error::Machine(e.to_string()))?;
199            }
200            DistType::Binomial {
201                trials,
202                probability,
203            } => {
204                if probability != 0.0 && probability < DIST_MIN_PROBABILITY {
205                    Err(Error::Machine(format!(
206                        "for Binomial dist, probability 0.0 > {probability:?} < DIST_MIN_PROBABILITY (1e-9), error due to too slow sampling"
207                    )))?;
208                }
209                if trials > 1_000_000_000 {
210                    Err(Error::Machine(format!(
211                        "for Binomial dist, {trials} trials > 1e9, error due to too slow sampling"
212                    )))?;
213                }
214                Binomial::new(trials, probability).map_err(|e| Error::Machine(e.to_string()))?;
215            }
216            DistType::Geometric { probability } => {
217                if probability != 0.0 && probability < DIST_MIN_PROBABILITY {
218                    Err(Error::Machine(format!(
219                        "for Geometric dist, probability 0.0 > {probability:?} < DIST_MIN_PROBABILITY (1e-9), error due to too slow sampling"
220                    )))?;
221                }
222                Geometric::new(probability).map_err(|e| Error::Machine(e.to_string()))?;
223            }
224            DistType::Pareto { scale, shape } => {
225                Pareto::new(scale, shape).map_err(|e| Error::Machine(e.to_string()))?;
226            }
227            DistType::Poisson { lambda } => {
228                if lambda > 1_000_000_000_000_000_000_000_000_000_000_000_000_000_000.0 {
229                    Err(Error::Machine(format!(
230                        "for Poisson dist, lambda {lambda} > 1e42, error due to too slow sampling"
231                    )))?;
232                }
233                Poisson::new(lambda).map_err(|e| Error::Machine(e.to_string()))?;
234            }
235            DistType::Weibull { scale, shape } => {
236                Weibull::new(scale, shape).map_err(|e| Error::Machine(e.to_string()))?;
237            }
238            DistType::Gamma { scale, shape } => {
239                // note order below in inverse from others for some reason in
240                // rand_distr
241                Gamma::new(shape, scale).map_err(|e| Error::Machine(e.to_string()))?;
242            }
243            DistType::Beta { alpha, beta } => {
244                Beta::new(alpha, beta).map_err(|e| Error::Machine(e.to_string()))?;
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Sample the distribution. May panic if not valid (see [`Self::validate()`]).
252    pub fn sample<R: RngCore>(self, rng: &mut R) -> f64 {
253        let sampled = self.dist_sample(rng);
254        let mut r: f64 = 0.0;
255        let adjusted = sampled + self.start;
256
257        // Ensure the addition didn't produce NaN/inf (also catches NaN/inf from sampled)
258        if !adjusted.is_finite() {
259            return 0.0;
260        }
261
262        r = r.max(adjusted);
263        if self.max > 0.0 {
264            let clamped = r.min(self.max);
265            // Final safety check in case min() produced NaN
266            return if clamped.is_finite() { clamped } else { 0.0 };
267        }
268        r
269    }
270
271    fn dist_sample<R: RngCore>(self, rng: &mut R) -> f64 {
272        use rand::Rng;
273        match self.dist {
274            DistType::Uniform { low, high } => {
275                // special common case for handcrafted machines, also not
276                // supported by rand_dist::Uniform
277                if low == high {
278                    return low;
279                }
280                rng.random_range(low..high)
281            }
282            DistType::Normal { mean, stdev } => Normal::new(mean, stdev).unwrap().sample(rng),
283            DistType::SkewNormal {
284                location,
285                scale,
286                shape,
287            } => SkewNormal::new(location, scale, shape).unwrap().sample(rng),
288            DistType::LogNormal { mu, sigma } => LogNormal::new(mu, sigma).unwrap().sample(rng),
289            DistType::Binomial {
290                trials,
291                probability,
292            } => Binomial::new(trials, probability).unwrap().sample(rng) as f64,
293            DistType::Geometric { probability } => {
294                Geometric::new(probability).unwrap().sample(rng) as f64
295            }
296            DistType::Pareto { scale, shape } => Pareto::new(scale, shape).unwrap().sample(rng),
297            DistType::Poisson { lambda } => Poisson::new(lambda).unwrap().sample(rng),
298            DistType::Weibull { scale, shape } => Weibull::new(scale, shape).unwrap().sample(rng),
299            DistType::Gamma { scale, shape } => {
300                // note order below inverted from others for some reason in
301                // rand_distr
302                Gamma::new(shape, scale).unwrap().sample(rng)
303            }
304            DistType::Beta { alpha, beta } => Beta::new(alpha, beta).unwrap().sample(rng),
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`Dist`] of the given type with no start value and no maximum.
    fn unbounded(dist: DistType) -> Dist {
        Dist::new(dist, 0.0, 0.0)
    }

    #[test]
    fn validate_uniform_dist() {
        // valid dist (low == high is a constant)
        let valid = DistType::Uniform {
            low: 10.0,
            high: 10.0,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with low > high
        let inverted = DistType::Uniform {
            low: 15.0,
            high: 5.0,
        };
        assert!(unbounded(inverted).validate().is_err());

        // dist with NaN bound
        let nan = DistType::Uniform {
            low: f64::NAN,
            high: 5.0,
        };
        assert!(unbounded(nan).validate().is_err());

        // dist with infinite bound
        let infinite = DistType::Uniform {
            low: 0.0,
            high: f64::INFINITY,
        };
        assert!(unbounded(infinite).validate().is_err());

        // dist with finite bounds where the range high - low overflows
        let overflow = DistType::Uniform {
            low: -f64::MAX,
            high: f64::MAX,
        };
        assert!(unbounded(overflow).validate().is_err());
    }

    #[test]
    fn validate_normal_dist() {
        // valid dist
        let valid = DistType::Normal {
            mean: 100.0,
            stdev: 15.0,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with infinite variance
        let invalid = DistType::Normal {
            mean: 100.0,
            stdev: f64::INFINITY,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_skewnormal_dist() {
        // valid dist
        let valid = DistType::SkewNormal {
            location: 100.0,
            scale: 15.0,
            shape: -3.0,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with infinite shape
        let invalid = DistType::SkewNormal {
            location: 100.0,
            scale: 15.0,
            shape: f64::INFINITY,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_lognormal_dist() {
        // valid dist
        let valid = DistType::LogNormal {
            mu: 100.0,
            sigma: 15.0,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with infinite variance
        let invalid = DistType::LogNormal {
            mu: 100.0,
            sigma: f64::INFINITY,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_binomial_dist() {
        // valid dist
        let valid = DistType::Binomial {
            trials: 10,
            probability: 0.5,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with invalid probability
        let invalid = DistType::Binomial {
            trials: 10,
            probability: 1.1,
        };
        assert!(unbounded(invalid).validate().is_err());

        // dist with non-zero probability below DIST_MIN_PROBABILITY
        let too_small = DistType::Binomial {
            trials: 10,
            probability: DIST_MIN_PROBABILITY / 10.0,
        };
        assert!(unbounded(too_small).validate().is_err());

        // dist with too many trials
        let too_many = DistType::Binomial {
            trials: 1_000_000_001,
            probability: 0.5,
        };
        assert!(unbounded(too_many).validate().is_err());
    }

    #[test]
    fn validate_geometric_dist() {
        // valid dist
        let valid = DistType::Geometric { probability: 0.5 };
        assert!(unbounded(valid).validate().is_ok());

        // dist with invalid probability
        let invalid = DistType::Geometric { probability: 1.1 };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_pareto_dist() {
        // valid dist
        let valid = DistType::Pareto {
            scale: 1.0,
            shape: 0.5,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with negative scale
        let invalid = DistType::Pareto {
            scale: -1.0,
            shape: 0.5,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_poisson_dist() {
        // valid dist
        let valid = DistType::Poisson { lambda: 1.0 };
        assert!(unbounded(valid).validate().is_ok());

        // dist with negative lambda
        let invalid = DistType::Poisson { lambda: -1.0 };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_weibull_dist() {
        // valid dist
        let valid = DistType::Weibull {
            scale: 1.0,
            shape: 0.5,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with negative shape
        let invalid = DistType::Weibull {
            scale: 1.0,
            shape: -0.5,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_gamma_dist() {
        // valid dist
        let valid = DistType::Gamma {
            scale: 1.0,
            shape: 0.5,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with negative shape
        let invalid = DistType::Gamma {
            scale: 1.0,
            shape: -0.5,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn validate_beta_dist() {
        // valid dist
        let valid = DistType::Beta {
            alpha: 1.0,
            beta: 0.5,
        };
        assert!(unbounded(valid).validate().is_ok());

        // dist with negative beta
        let invalid = DistType::Beta {
            alpha: 1.0,
            beta: -0.5,
        };
        assert!(unbounded(invalid).validate().is_err());
    }

    #[test]
    fn sample_clamp() {
        // make sure start and max are applied
        let mut rng = rand::rng();

        // start: constant 0 with start 5, the start must be added
        let constant_zero = DistType::Uniform {
            low: 0.0,
            high: 0.0,
        };
        let d = Dist::new(constant_zero, 5.0, 0.0);
        assert_eq!(d.sample(&mut rng), 5.0);

        // max: constant 10 with max 5, the sample must be clamped down to 5
        let constant_ten = DistType::Uniform {
            low: 10.0,
            high: 10.0,
        };
        let d = Dist::new(constant_ten, 0.0, 5.0);
        assert_eq!(d.sample(&mut rng), 5.0);

        // finally, make sure values < 0.0 cannot be sampled
        let negative = DistType::Uniform {
            low: -20.0,
            high: -10.0,
        };
        assert_eq!(unbounded(negative).sample(&mut rng), 0.0);
    }

    #[test]
    fn sample_nan_inf_robustness() {
        // Test that sampling never returns NaN, infinite, or negative values,
        // even for parameters at the edge of what validate() accepts (or, for
        // start, parameters that validate() does not check at all).
        let mut rng = rand::rng();

        // A valid distribution with a very large standard deviation (still
        // passes validation). Sample multiple times to increase the chance of
        // hitting edge cases.
        let d = unbounded(DistType::Normal {
            mean: 0.0,
            stdev: 1e300,
        });
        for _ in 0..100 {
            let sampled = d.sample(&mut rng);
            assert!(
                sampled.is_finite(),
                "Normal distribution with large stdev should not produce non-finite values"
            );
            assert!(sampled >= 0.0, "Sample should respect minimum bound of 0.0");
        }

        // Pareto distribution (known for heavy tails): a very small shape
        // parameter creates a heavy tail, clamped by max.
        let d_pareto = Dist::new(
            DistType::Pareto {
                scale: 1.0,
                shape: 0.1,
            },
            0.0,
            1000.0,
        );
        for _ in 0..100 {
            let sampled = d_pareto.sample(&mut rng);
            assert!(
                sampled.is_finite(),
                "Pareto distribution should not produce non-finite values"
            );
            assert!(sampled >= 0.0, "Sample should respect minimum bound of 0.0");
            assert!(sampled <= 1000.0, "Sample should respect maximum bound");
        }

        // Very large (but not overflowing) sum of sample and start
        let d_large_start = Dist::new(
            DistType::Uniform {
                low: 1e300,
                high: 1e300,
            },
            1e300,
            0.0,
        );
        let sampled = d_large_start.sample(&mut rng);
        assert!(
            sampled.is_finite(),
            "Large start value should not produce non-finite values"
        );
        assert!(sampled >= 0.0, "Sample should respect minimum bound of 0.0");

        // Sum of sample and start that overflows to infinity: f64::MAX +
        // f64::MAX is not representable, so sampling must fall back to 0.0
        let d_overflow = Dist::new(
            DistType::Uniform {
                low: f64::MAX,
                high: f64::MAX,
            },
            f64::MAX,
            0.0,
        );
        assert_eq!(
            d_overflow.sample(&mut rng),
            0.0,
            "Overflowing start value should fall back to 0.0"
        );

        // NaN start makes the sum NaN, which must also fall back to 0.0
        let d_nan_start = Dist::new(
            DistType::Uniform {
                low: 1.0,
                high: 1.0,
            },
            f64::NAN,
            0.0,
        );
        assert_eq!(
            d_nan_start.sample(&mut rng),
            0.0,
            "NaN start value should fall back to 0.0"
        );

        // Max smaller than any possible sample: everything is clamped to max
        let d_with_max = Dist::new(
            DistType::Uniform {
                low: 100.0,
                high: 200.0,
            },
            0.0,
            50.0,
        );
        for _ in 0..20 {
            let sampled = d_with_max.sample(&mut rng);
            assert!(sampled.is_finite(), "Clamped sample should be finite");
            assert!(sampled <= 50.0, "Sample should respect max bound");
            assert!(sampled >= 0.0, "Sample should respect minimum bound of 0.0");
        }
    }
}