burn_backend/
distribution.rs1use rand::{Rng, RngCore, distr::StandardUniform};
4
5use super::element::{Element, ElementConversion};
6
7#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
9pub enum Distribution {
10 #[default]
12 Default,
13
14 Bernoulli(f64),
16
17 Uniform(f64, f64),
19
20 Normal(f64, f64),
22}
23
24#[derive(new)]
26pub struct DistributionSampler<'a, E, R>
27where
28 StandardUniform: rand::distr::Distribution<E>,
29 E: rand::distr::uniform::SampleUniform,
30 R: RngCore,
31{
32 kind: DistributionSamplerKind<E>,
33 rng: &'a mut R,
34}
35
36pub enum DistributionSamplerKind<E>
38where
39 StandardUniform: rand::distr::Distribution<E>,
40 E: rand::distr::uniform::SampleUniform,
41{
42 Standard(rand::distr::StandardUniform),
44
45 Uniform(rand::distr::Uniform<E>),
47
48 Bernoulli(rand::distr::Bernoulli),
50
51 Normal(rand_distr::Normal<f64>),
53}
54
55impl<E, R> DistributionSampler<'_, E, R>
56where
57 StandardUniform: rand::distr::Distribution<E>,
58 E: rand::distr::uniform::SampleUniform,
59 E: Element,
60 R: RngCore,
61{
62 pub fn sample(&mut self) -> E {
64 match &self.kind {
65 DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
66 DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
67 DistributionSamplerKind::Bernoulli(distribution) => {
68 if self.rng.sample(distribution) {
69 1.elem()
70 } else {
71 0.elem()
72 }
73 }
74 DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
75 }
76 }
77}
78
79impl Distribution {
80 pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
90 where
91 R: RngCore,
92 E: Element + rand::distr::uniform::SampleUniform,
93 StandardUniform: rand::distr::Distribution<E>,
94 {
95 let kind = match self {
96 Distribution::Default => {
97 DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
98 }
99 Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
100 rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),
101 ),
102 Distribution::Bernoulli(prob) => {
103 DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())
104 }
105 Distribution::Normal(mean, std) => {
106 DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
107 }
108 };
109
110 DistributionSampler::new(kind, rng)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Default` implementation must yield the `Distribution::Default` variant,
    /// whether it is reached through the trait or through the type.
    #[test]
    fn test_distribution_default() {
        let from_trait: Distribution = Default::default();
        let from_type = Distribution::default();

        assert_eq!(from_trait, Distribution::Default);
        assert_eq!(from_type, Distribution::Default);
    }
}