burn_backend/
distribution.rs

1//! Random value distributions used to initialize and populate tensor data.
2
3use rand::{Rng, RngCore, distr::StandardUniform};
4
5use super::element::{Element, ElementConversion};
6
/// Distribution used to generate random values for a tensor.
///
/// This is a plain, copyable, serializable description of a distribution; call
/// [`Distribution::sampler`] to turn it into a [`DistributionSampler`] that draws values.
///
/// NOTE(review): variant order may be part of the serialized form for index-based serde
/// formats — avoid reordering variants.
#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
    /// Samples `rand`'s `StandardUniform` distribution; this is the `Default` value.
    ///
    /// For floating-point element types this is uniform over `[0, 1)`.
    /// NOTE(review): for integer element types `StandardUniform` covers the whole value range
    /// of the type rather than `[0, 1)` — confirm this is intended.
    #[default]
    Default,

    /// Bernoulli distribution with success probability `p`: sampling yields one with
    /// probability `p` and zero otherwise.
    ///
    /// `p` must lie in `[0, 1]`; other values make [`Distribution::sampler`] panic.
    Bernoulli(f64),

    /// Uniform distribution over the half-open range `[low, high)`, given as `(low, high)`.
    ///
    /// Both bounds are converted to the element type before the range is built.
    Uniform(f64, f64),

    /// Normal (Gaussian) distribution, given as `(mean, std_dev)`.
    ///
    /// Values are sampled as `f64` and then converted to the element type.
    Normal(f64, f64),
}
23
/// Draws random values of element type `E` from a configured distribution.
///
/// Built by [`Distribution::sampler`]. It holds the concrete `rand` distribution (`kind`) and
/// mutably borrows the random number generator (`rng`) for the sampler's lifetime `'a`.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    R: RngCore,
{
    // Field order matters: `#[derive(new)]` generates `new(kind, rng)` in this order.
    kind: DistributionSamplerKind<E>,
    rng: &'a mut R,
}
35
/// Concrete `rand` distribution backing a [`DistributionSampler`].
///
/// [`Distribution::sampler`] builds one variant per [`Distribution`] variant.
pub enum DistributionSamplerKind<E>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
{
    /// `rand`'s standard distribution, sampled directly as `E` (from [`Distribution::Default`]).
    Standard(rand::distr::StandardUniform),

    /// Uniform distribution over `E` values; as built by [`Distribution::sampler`] the range
    /// is half-open, `[low, high)`.
    Uniform(rand::distr::Uniform<E>),

    /// Bernoulli distribution; a `true` draw is mapped to one and `false` to zero.
    Bernoulli(rand::distr::Bernoulli),

    /// Normal distribution over `f64`; each draw is converted to `E` after sampling.
    Normal(rand_distr::Normal<f64>),
}
54
55impl<E, R> DistributionSampler<'_, E, R>
56where
57    StandardUniform: rand::distr::Distribution<E>,
58    E: rand::distr::uniform::SampleUniform,
59    E: Element,
60    R: RngCore,
61{
62    /// Sames a random value from the distribution.
63    pub fn sample(&mut self) -> E {
64        match &self.kind {
65            DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
66            DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
67            DistributionSamplerKind::Bernoulli(distribution) => {
68                if self.rng.sample(distribution) {
69                    1.elem()
70                } else {
71                    0.elem()
72                }
73            }
74            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
75        }
76    }
77}
78
79impl Distribution {
80    /// Creates a new distribution sampler.
81    ///
82    /// # Arguments
83    ///
84    /// * `rng` - The random number generator.
85    ///
86    /// # Returns
87    ///
88    /// The distribution sampler.
89    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
90    where
91        R: RngCore,
92        E: Element + rand::distr::uniform::SampleUniform,
93        StandardUniform: rand::distr::Distribution<E>,
94    {
95        let kind = match self {
96            Distribution::Default => {
97                DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
98            }
99            Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
100                rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),
101            ),
102            Distribution::Bernoulli(prob) => {
103                DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())
104            }
105            Distribution::Normal(mean, std) => {
106                DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
107            }
108        };
109
110        DistributionSampler::new(kind, rng)
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Both ways of obtaining the default distribution must yield the `Default` variant.
    #[test]
    fn test_distribution_default() {
        let from_trait: Distribution = Default::default();
        let from_type = Distribution::default();

        for candidate in [from_trait, from_type] {
            assert_eq!(candidate, Distribution::Default);
        }
    }
}