burn_tensor/tensor/
distribution.rs

1use rand::{distributions::Standard, Rng, RngCore};
2
3use crate::{Element, ElementConversion};
4
5/// Distribution for random value of a tensor.
6#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
7pub enum Distribution {
8    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).
9    Default,
10
11    /// Bernoulli distribution with the given probability.
12    Bernoulli(f64),
13
14    /// Uniform distribution. The range is inclusive.
15    Uniform(f64, f64),
16
17    /// Normal distribution with the given mean and standard deviation.
18    Normal(f64, f64),
19}
20
21/// Distribution sampler for random value of a tensor.
22#[derive(new)]
23pub struct DistributionSampler<'a, E, R>
24where
25    Standard: rand::distributions::Distribution<E>,
26    E: rand::distributions::uniform::SampleUniform,
27    R: RngCore,
28{
29    kind: DistributionSamplerKind<E>,
30    rng: &'a mut R,
31}
32
33/// Distribution sampler kind for random value of a tensor.
34pub enum DistributionSamplerKind<E>
35where
36    Standard: rand::distributions::Distribution<E>,
37    E: rand::distributions::uniform::SampleUniform,
38{
39    /// Standard distribution.
40    Standard(rand::distributions::Standard),
41
42    /// Uniform distribution.
43    Uniform(rand::distributions::Uniform<E>),
44
45    /// Bernoulli distribution.
46    Bernoulli(rand::distributions::Bernoulli),
47
48    /// Normal distribution.
49    Normal(rand_distr::Normal<f64>),
50}
51
52impl<E, R> DistributionSampler<'_, E, R>
53where
54    Standard: rand::distributions::Distribution<E>,
55    E: rand::distributions::uniform::SampleUniform,
56    E: Element,
57    R: RngCore,
58{
59    /// Sames a random value from the distribution.
60    pub fn sample(&mut self) -> E {
61        match &self.kind {
62            DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
63            DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
64            DistributionSamplerKind::Bernoulli(distribution) => {
65                if self.rng.sample(distribution) {
66                    1.elem()
67                } else {
68                    0.elem()
69                }
70            }
71            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
72        }
73    }
74}
75
76impl Distribution {
77    /// Creates a new distribution sampler.
78    ///
79    /// # Arguments
80    ///
81    /// * `rng` - The random number generator.
82    ///
83    /// # Returns
84    ///
85    /// The distribution sampler.
86    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
87    where
88        R: RngCore,
89        E: Element + rand::distributions::uniform::SampleUniform,
90        Standard: rand::distributions::Distribution<E>,
91    {
92        let kind = match self {
93            Distribution::Default => {
94                DistributionSamplerKind::Standard(rand::distributions::Standard {})
95            }
96            Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
97                rand::distributions::Uniform::new(low.elem::<E>(), high.elem::<E>()),
98            ),
99            Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli(
100                rand::distributions::Bernoulli::new(prob).unwrap(),
101            ),
102            Distribution::Normal(mean, std) => {
103                DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
104            }
105        };
106
107        DistributionSampler::new(kind, rng)
108    }
109}