1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
//! Distribution utilities
use rand::distributions::{Bernoulli as BoolBernoulli, Distribution};
use rand::prelude::*;
use std::cmp::PartialOrd;
use std::convert::Infallible;

/// A batch of probability distributions whose outcomes are multi-dimensional arrays.
///
/// Type parameters:
/// * `E` - array type holding elements (samples), shaped `[BATCH_SHAPE..., ELEMENT_SHAPE...]`.
/// * `T` - array type holding one scalar per distribution, shaped `[BATCH_SHAPE...]`.
pub trait ArrayDistribution<E, T> {
    /// Shape of the batch, i.e. how the individual distributions are arranged.
    fn batch_shape(&self) -> Vec<usize>;

    /// Shape of a single element drawn from one distribution of the batch.
    fn element_shape(&self) -> Vec<usize>;

    /// Draw one element from every distribution in the batch.
    ///
    /// # Returns
    /// An array shaped `[BATCH_SHAPE..., ELEMENT_SHAPE...]`.
    fn sample(&self) -> E;

    /// Log probability of each given element under its corresponding distribution.
    ///
    /// # Args
    /// * `elements` - One element per distribution, as an array shaped
    ///                `[BATCH_SHAPE..., ELEMENT_SHAPE...]`.
    ///
    /// # Returns
    /// An array of log probabilities shaped `[BATCH_SHAPE...]`.
    fn log_probs(&self, elements: &E) -> T;

    /// Entropy of every distribution in the batch.
    ///
    /// # Returns
    /// An array of entropies shaped `[BATCH_SHAPE...]`.
    fn entropy(&self) -> T;

    /// Elementwise KL divergence (relative entropy) of `self` from `other`: `KL(self || other)`.
    ///
    /// # Args
    /// * `other` - Another batch whose batch shape matches (or broadcasts with) this one.
    ///
    /// # Returns
    /// An array shaped `[BATCH_SHAPE...]` holding `KL(self[i] || other[i])`.
    fn kl_divergence_from(&self, other: &Self) -> T;
}

/// Distributions that can be constructed from a mean.
///
/// `T` is the type of the mean value.
pub trait FromMean<T>: Sized {
    /// Error returned when the requested mean cannot be realised by this distribution family.
    type Error;

    /// Construct a distribution having the given mean.
    ///
    /// # Errors
    /// Returns `Self::Error` if `mean` is not a valid mean for this distribution family.
    /// The exact conditions are defined by each implementation.
    fn from_mean(mean: T) -> Result<Self, Self::Error>;
}

/// Inclusive bounds on the scalar values a distribution can produce.
pub trait Bounded<T: PartialOrd> {
    /// Returns `(min, max)`; every admissible value `x` satisfies `min <= x && x <= max`.
    ///
    /// When `max < min` the interval contains no values at all.
    fn bounds(&self) -> (T, T);
}

/// A deterministic distribution.
///
/// Always produces the same value when sampled: a point mass on the wrapped value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Deterministic<T>(T);

impl<T> Deterministic<T> {
    /// Create a distribution that always yields `value`.
    #[must_use]
    pub const fn new(value: T) -> Self {
        Self(value)
    }
}

impl<T: Copy> Distribution<T> for Deterministic<T> {
    /// Ignores the random number generator and returns a copy of the stored value.
    fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> T {
        let &Self(value) = self;
        value
    }
}
impl<T: PartialOrd + Clone> Bounded<T> for Deterministic<T> {
    /// A deterministic distribution only ever yields its stored value, so both the lower and
    /// upper bound equal that value. Works for any ordered value type, not just `f64`.
    fn bounds(&self) -> (T, T) {
        (self.0.clone(), self.0.clone())
    }
}
impl<T> FromMean<T> for Deterministic<T> {
    /// Every value is the mean of the point mass placed on it, so construction cannot fail.
    type Error = Infallible;

    fn from_mean(mean: T) -> Result<Self, Self::Error> {
        let point_mass = Self(mean);
        Ok(point_mass)
    }
}

/// Bernoulli distribution whose samples are the floats `0.0` and `1.0`.
///
/// Thin wrapper around [`rand::distributions::Bernoulli`], which samples `bool`s.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli(BoolBernoulli);
impl Bernoulli {
    /// Build a `Bernoulli` that yields `1.0` with probability `mean` and `0.0` otherwise.
    ///
    /// # Errors
    /// Fails with the underlying `BernoulliError` when `mean` lies outside `[0, 1]`.
    pub fn new(mean: f64) -> Result<Self, rand::distributions::BernoulliError> {
        BoolBernoulli::new(mean).map(Self)
    }
}
impl Distribution<f64> for Bernoulli {
    /// Draws a `bool` from the wrapped distribution and maps `true` to `1.0`, `false` to `0.0`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // TODO: Benchmark this branch against a `bool` -> float cast.
        let success: bool = self.0.sample(rng);
        if success {
            1.0
        } else {
            0.0
        }
    }
}
impl Bounded<f64> for Bernoulli {
    /// Samples are always exactly `0.0` or `1.0`, so the support lies within `[0, 1]`.
    fn bounds(&self) -> (f64, f64) {
        let lower = 0.0;
        let upper = 1.0;
        (lower, upper)
    }
}
impl FromMean<f64> for Bernoulli {
    type Error = rand::distributions::BernoulliError;

    fn from_mean(mean: f64) -> Result<Self, Self::Error> {
        Self::new(mean)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Asserts that distribution samples do not violate its reported bounds.
    ///
    /// # Args
    /// * `d` - The distribution to test.
    /// * `num_samples` - Number of samples to draw.
    /// * `seed` - Random seed, so failures are reproducible.
    fn check_bounded<T: PartialOrd + Debug, D: Distribution<T> + Bounded<T>>(
        d: &D,
        num_samples: usize,
        seed: u64,
    ) {
        let (lower_bound, upper_bound) = d.bounds();
        let mut rng = StdRng::seed_from_u64(seed);
        for _ in 0..num_samples {
            let x: T = d.sample(&mut rng);
            assert!(
                x >= lower_bound,
                "sample {:?} is below lower bound {:?}",
                x,
                lower_bound
            );
            assert!(
                x <= upper_bound,
                "sample {:?} is above upper bound {:?}",
                x,
                upper_bound
            );
        }
    }

    /// Asserts that the distribution empirical mean is close to the given expected value.
    ///
    /// The tolerance `t` comes from the two-sided sub-gaussian tail bound
    /// `P(|mean_n - E[d]| >= t) <= 2 exp(-n t^2 / (2 s^2))` with `s = stddev_upper_bound`,
    /// solved for a false positive probability of `1e-5`.
    ///
    /// # Args
    /// * `d` - The distribution to test. `d - E[d]` must be sub-gaussian with scale
    ///         parameter at most `stddev_upper_bound`.
    /// * `expected_mean` - The expected value for the mean.
    /// * `num_samples` - Number of samples to generate.
    /// * `stddev_upper_bound` - Sub-gaussian scale of a single sample (for a Gaussian this is
    ///                          its standard deviation; for values in `[a, b]` use `(b - a) / 2`).
    /// * `seed` - Random seed.
    fn check_mean<D: Distribution<f64>>(
        d: &D,
        expected_mean: f64,
        num_samples: usize,
        stddev_upper_bound: f64,
        seed: u64,
    ) {
        let rng = StdRng::seed_from_u64(seed);
        let empirical_mean = Distribution::<f64>::sample_iter(d, rng)
            .take(num_samples)
            .sum::<f64>()
            / (num_samples as f64);
        // Want to be close enough to the mean to have false positive probability < 1e-5.
        let false_positive_prob: f64 = 1e-5;
        // The check below is two-sided, so split the failure probability between both tails.
        let error_bound = stddev_upper_bound
            * (-2.0 / (num_samples as f64) * (false_positive_prob / 2.0).ln()).sqrt();
        // Make sure our error bound isn't huge
        assert!(
            error_bound < 0.1 || error_bound < 0.1 * expected_mean.abs(),
            "Use more samples"
        );
        // `abs` is essential: without it, an empirical mean far *below* the target would pass.
        assert!(
            (empirical_mean - expected_mean).abs() < error_bound,
            "empirical mean {} is not within {} of expected mean {}",
            empirical_mean,
            error_bound,
            expected_mean
        );
    }

    #[test]
    fn deterministic_bounded() {
        check_bounded(&Deterministic::new(0.7), 1000, 1);
    }

    #[test]
    fn deterministic_mean() {
        check_mean(&Deterministic::new(0.7), 0.7, 100, 1e-6, 2);
    }

    #[test]
    fn deterministic_from_mean() {
        let d = Deterministic::from_mean(0.7).unwrap();
        check_mean(&d, 0.7, 100, 1e-6, 3);
    }

    /// Regression test: the mean check must also reject empirical means below the target.
    #[test]
    #[should_panic(expected = "is not within")]
    fn check_mean_rejects_low_empirical_mean() {
        check_mean(&Deterministic::new(0.5), 0.7, 100, 1e-6, 2);
    }

    #[test]
    fn bernoulli_bounded() {
        check_bounded(&Bernoulli::new(0.7).unwrap(), 1000, 1);
    }

    #[test]
    fn bernoulli_mean() {
        let p: f64 = 0.7;
        // Samples lie in [0, 1], so by Hoeffding's lemma they are sub-gaussian with scale 1/2.
        check_mean(&Bernoulli::new(p).unwrap(), p, 1000, 0.5, 2);
    }

    #[test]
    fn bernoulli_from_mean() {
        let p: f64 = 0.3;
        check_mean(&Bernoulli::from_mean(p).unwrap(), p, 1000, 0.5, 4);
    }

    #[test]
    fn bernoulli_rejects_invalid_mean() {
        assert!(Bernoulli::new(-0.1).is_err());
        assert!(Bernoulli::from_mean(1.5).is_err());
    }
}