use rand::{Rng, RngCore, distr::StandardUniform};

use crate::{Element, ElementConversion};

5#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
7pub enum Distribution {
8 #[default]
10 Default,
11
12 Bernoulli(f64),
14
15 Uniform(f64, f64),
17
18 Normal(f64, f64),
20}
21
22#[derive(new)]
24pub struct DistributionSampler<'a, E, R>
25where
26 StandardUniform: rand::distr::Distribution<E>,
27 E: rand::distr::uniform::SampleUniform,
28 R: RngCore,
29{
30 kind: DistributionSamplerKind<E>,
31 rng: &'a mut R,
32}
33
34pub enum DistributionSamplerKind<E>
36where
37 StandardUniform: rand::distr::Distribution<E>,
38 E: rand::distr::uniform::SampleUniform,
39{
40 Standard(rand::distr::StandardUniform),
42
43 Uniform(rand::distr::Uniform<E>),
45
46 Bernoulli(rand::distr::Bernoulli),
48
49 Normal(rand_distr::Normal<f64>),
51}
52
53impl<E, R> DistributionSampler<'_, E, R>
54where
55 StandardUniform: rand::distr::Distribution<E>,
56 E: rand::distr::uniform::SampleUniform,
57 E: Element,
58 R: RngCore,
59{
60 pub fn sample(&mut self) -> E {
62 match &self.kind {
63 DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
64 DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
65 DistributionSamplerKind::Bernoulli(distribution) => {
66 if self.rng.sample(distribution) {
67 1.elem()
68 } else {
69 0.elem()
70 }
71 }
72 DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
73 }
74 }
75}
76
77impl Distribution {
78 pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
88 where
89 R: RngCore,
90 E: Element + rand::distr::uniform::SampleUniform,
91 StandardUniform: rand::distr::Distribution<E>,
92 {
93 let kind = match self {
94 Distribution::Default => {
95 DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
96 }
97 Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
98 rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),
99 ),
100 Distribution::Bernoulli(prob) => {
101 DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())
102 }
103 Distribution::Normal(mean, std) => {
104 DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
105 }
106 };
107
108 DistributionSampler::new(kind, rng)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny deterministic generator (Weyl sequence) so tests need no RNG feature flags.
    struct StepRng(u64);

    impl RngCore for StepRng {
        fn next_u32(&mut self) -> u32 {
            (self.next_u64() >> 32) as u32
        }

        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            self.0
        }

        fn fill_bytes(&mut self, dst: &mut [u8]) {
            for chunk in dst.chunks_mut(8) {
                let bytes = self.next_u64().to_le_bytes();
                chunk.copy_from_slice(&bytes[..chunk.len()]);
            }
        }
    }

    #[test]
    fn test_distribution_default() {
        let dist: Distribution = Default::default();

        assert_eq!(dist, Distribution::Default);
        assert_eq!(Distribution::default(), Distribution::Default);
    }

    #[test]
    fn test_distribution_equality_uses_kind_and_parameters() {
        assert_eq!(Distribution::Uniform(0.0, 1.0), Distribution::Uniform(0.0, 1.0));
        assert_ne!(Distribution::Uniform(0.0, 1.0), Distribution::Uniform(0.0, 2.0));
        assert_ne!(Distribution::Uniform(0.0, 1.0), Distribution::Normal(0.0, 1.0));
        assert_ne!(Distribution::Bernoulli(0.5), Distribution::Default);
    }

    #[test]
    fn test_bernoulli_extremes_are_deterministic() {
        let mut rng = StepRng(1);

        let mut always = Distribution::Bernoulli(1.0).sampler::<StepRng, f32>(&mut rng);
        for _ in 0..16 {
            assert_eq!(always.sample(), 1.0);
        }

        let mut never = Distribution::Bernoulli(0.0).sampler::<StepRng, f32>(&mut rng);
        for _ in 0..16 {
            assert_eq!(never.sample(), 0.0);
        }
    }

    #[test]
    fn test_uniform_samples_stay_in_half_open_range() {
        let mut rng = StepRng(7);
        let mut sampler = Distribution::Uniform(2.0, 3.0).sampler::<StepRng, f32>(&mut rng);

        for _ in 0..64 {
            let value = sampler.sample();
            assert!((2.0..3.0).contains(&value), "{value} outside [2, 3)");
        }
    }

    #[test]
    #[should_panic(expected = "invalid uniform distribution range")]
    fn test_uniform_empty_range_panics() {
        let mut rng = StepRng(0);
        let _ = Distribution::Uniform(1.0, 1.0).sampler::<StepRng, f32>(&mut rng);
    }

    #[test]
    #[should_panic(expected = "invalid Bernoulli probability")]
    fn test_bernoulli_out_of_range_probability_panics() {
        let mut rng = StepRng(0);
        let _ = Distribution::Bernoulli(1.5).sampler::<StepRng, f32>(&mut rng);
    }
}