1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
use crate::*;

use rand::rngs::ThreadRng;
use num_bigint::RandBigInt;
use num_integer::Integer;
use num_traits::Zero;

/// Something that can produce random values of type `T` on demand.
///
/// Each call to [`Distribution::sample`] yields one independent draw, pulling
/// its randomness from the caller-supplied thread-local generator.
pub trait Distribution<T> {
    /// Draws one value of type `T`, advancing `rng`.
    fn sample(&self, rng: &mut ThreadRng) -> T;
}

/// Parameters of a uniform distribution over non-negative [`Int`] values.
///
/// A sample `v` always satisfies `start <= v`, `v < end` and
/// `v % divisor == 0`, and every value meeting those three constraints is
/// equally likely to be drawn.
pub struct IntDistribution {
    /// Inclusive lower bound of the samples; must be non-negative.
    pub start: Int,
    /// Exclusive upper bound of the samples.
    pub end: Int,
    /// Every sample is an exact multiple of this value; must be positive.
    pub divisor: Int,
}

impl Distribution<Int> for IntDistribution {
    /// Draws a multiple of `divisor` uniformly from the half-open range `[start, end)`.
    ///
    /// Sampling is done on quotients. Every admissible value has the form
    /// `k * divisor` with `k` in `[ceil(start / divisor), ceil(end / divisor))`.
    /// Picking `k` uniformly from that half-open range therefore picks an
    /// admissible value uniformly.
    ///
    /// # Panics
    ///
    /// Panics if `start` is negative, if `divisor` is not positive, or if no
    /// multiple of `divisor` lies in `[start, end)`.
    fn sample(&self, rng: &mut ThreadRng) -> Int {
        let start = self.start.ext();
        let end = self.end.ext();
        let divisor = self.divisor.ext();

        assert!(
            start >= ExtInt::zero(),
            "IntDistribution: `start` must be non-negative"
        );
        assert!(
            divisor > ExtInt::zero(),
            "IntDistribution: `divisor` must be positive"
        );

        // Smallest k with k * divisor >= start.
        let lo = start.div_ceil(&divisor);
        // For integer k: k * divisor < end  <=>  k < ceil(end / divisor),
        // so this is exactly the exclusive upper bound on k.
        let hi = end.div_ceil(&divisor);

        // An empty quotient range means the interval holds no multiple of
        // `divisor`. This can happen even when start < end (e.g. [1, 2) with divisor 3).
        assert!(
            lo < hi,
            "IntDistribution: no multiple of `divisor` lies in [start, end)"
        );

        let k = rng.gen_bigint_range(&lo, &hi);
        Int::wrap(k * divisor)
    }
}

#[test]
fn test_int_distr() {
    // (start, end, divisor) triples. Some bounds are themselves multiples of
    // the divisor and some are not.
    let cases = [(0, 8, 4), (2, 5, 4), (0, 3, 3), (1, 4, 3)];
    let mut rng = rand::thread_rng();

    for &(start, end, divisor) in cases.iter() {
        let distr = IntDistribution {
            start: start.into(),
            end: end.into(),
            divisor: divisor.into(),
        };
        // Draw a batch of samples; each one must honour all three constraints.
        (0..20)
            .map(|_| distr.sample(&mut rng))
            .for_each(|v| {
                assert!(v >= distr.start);
                assert!(v < distr.end);
                assert!(v % distr.divisor == 0);
            });
    }
}