burn_tensor/tensor/
distribution.rs

1use rand::{Rng, RngCore, distr::StandardUniform};
2
3use crate::{Element, ElementConversion};
4
/// Distribution for random value of a tensor.
///
/// This is a plain, copyable description of *which* distribution to draw from;
/// it holds no random state itself. Call [`Distribution::sampler`] with a random
/// number generator to obtain a [`DistributionSampler`] that produces values.
#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).
    ///
    /// This is the variant returned by `Distribution::default()`. The sampler
    /// backs it with `rand::distr::StandardUniform`.
    #[default]
    Default,

    /// Bernoulli distribution with the given probability.
    ///
    /// The sampler yields the element-type conversion of `1` on success and of
    /// `0` otherwise.
    Bernoulli(f64),

    /// Uniform distribution `[low, high)`.
    ///
    /// The fields are `(low, high)`; both bounds are converted to the target
    /// element type before the underlying uniform sampler is built.
    Uniform(f64, f64),

    /// Normal distribution with the given mean and standard deviation.
    ///
    /// The fields are `(mean, std)`. Values are drawn as `f64` and then
    /// converted to the target element type.
    Normal(f64, f64),
}
21
/// Distribution sampler for random value of a tensor.
///
/// Pairs a concrete sampling strategy ([`DistributionSamplerKind`]) with a
/// mutably borrowed random number generator. The `new` constructor is generated
/// by the `new` derive and takes the fields in declaration order
/// (`kind`, then `rng`).
///
/// # Type parameters
///
/// * `E` - The element type produced by the sampler.
/// * `R` - The random number generator type.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    R: RngCore,
{
    // Concrete distribution that values are drawn from.
    kind: DistributionSamplerKind<E>,
    // Borrowed generator; every draw advances the caller's generator.
    rng: &'a mut R,
}
33
/// Distribution sampler kind for random value of a tensor.
///
/// Each variant wraps the already-constructed `rand` / `rand_distr`
/// distribution object that [`DistributionSampler::sample`] draws from.
pub enum DistributionSamplerKind<E>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
{
    /// Standard distribution.
    ///
    /// Samples the element type `E` directly.
    Standard(rand::distr::StandardUniform),

    /// Uniform distribution.
    ///
    /// Samples the element type `E` directly from the stored range.
    Uniform(rand::distr::Uniform<E>),

    /// Bernoulli distribution.
    ///
    /// Produces a `bool` that the sampler maps to `1` or `0` in the element type.
    Bernoulli(rand::distr::Bernoulli),

    /// Normal distribution.
    ///
    /// Always sampled as `f64`; the sampler converts the result to `E`.
    Normal(rand_distr::Normal<f64>),
}
52
53impl<E, R> DistributionSampler<'_, E, R>
54where
55    StandardUniform: rand::distr::Distribution<E>,
56    E: rand::distr::uniform::SampleUniform,
57    E: Element,
58    R: RngCore,
59{
60    /// Sames a random value from the distribution.
61    pub fn sample(&mut self) -> E {
62        match &self.kind {
63            DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
64            DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
65            DistributionSamplerKind::Bernoulli(distribution) => {
66                if self.rng.sample(distribution) {
67                    1.elem()
68                } else {
69                    0.elem()
70                }
71            }
72            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
73        }
74    }
75}
76
77impl Distribution {
78    /// Creates a new distribution sampler.
79    ///
80    /// # Arguments
81    ///
82    /// * `rng` - The random number generator.
83    ///
84    /// # Returns
85    ///
86    /// The distribution sampler.
87    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
88    where
89        R: RngCore,
90        E: Element + rand::distr::uniform::SampleUniform,
91        StandardUniform: rand::distr::Distribution<E>,
92    {
93        let kind = match self {
94            Distribution::Default => {
95                DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
96            }
97            Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
98                rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),
99            ),
100            Distribution::Bernoulli(prob) => {
101                DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())
102            }
103            Distribution::Normal(mean, std) => {
104                DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
105            }
106        };
107
108        DistributionSampler::new(kind, rng)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived `Default` must resolve to the `Default` variant, both through
    /// the trait method and through the type-qualified call.
    #[test]
    fn test_distribution_default() {
        let via_trait: Distribution = Default::default();
        let via_type = Distribution::default();

        assert_eq!(via_trait, Distribution::Default);
        assert_eq!(via_type, Distribution::Default);
    }
}