burn_tensor/tensor/
distribution.rs1use rand::{distributions::Standard, Rng, RngCore};
2
3use crate::{Element, ElementConversion};
4
5#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
7pub enum Distribution {
8 Default,
10
11 Bernoulli(f64),
13
14 Uniform(f64, f64),
16
17 Normal(f64, f64),
19}
20
21#[derive(new)]
23pub struct DistributionSampler<'a, E, R>
24where
25 Standard: rand::distributions::Distribution<E>,
26 E: rand::distributions::uniform::SampleUniform,
27 R: RngCore,
28{
29 kind: DistributionSamplerKind<E>,
30 rng: &'a mut R,
31}
32
33pub enum DistributionSamplerKind<E>
35where
36 Standard: rand::distributions::Distribution<E>,
37 E: rand::distributions::uniform::SampleUniform,
38{
39 Standard(rand::distributions::Standard),
41
42 Uniform(rand::distributions::Uniform<E>),
44
45 Bernoulli(rand::distributions::Bernoulli),
47
48 Normal(rand_distr::Normal<f64>),
50}
51
52impl<E, R> DistributionSampler<'_, E, R>
53where
54 Standard: rand::distributions::Distribution<E>,
55 E: rand::distributions::uniform::SampleUniform,
56 E: Element,
57 R: RngCore,
58{
59 pub fn sample(&mut self) -> E {
61 match &self.kind {
62 DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
63 DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
64 DistributionSamplerKind::Bernoulli(distribution) => {
65 if self.rng.sample(distribution) {
66 1.elem()
67 } else {
68 0.elem()
69 }
70 }
71 DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
72 }
73 }
74}
75
76impl Distribution {
77 pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
87 where
88 R: RngCore,
89 E: Element + rand::distributions::uniform::SampleUniform,
90 Standard: rand::distributions::Distribution<E>,
91 {
92 let kind = match self {
93 Distribution::Default => {
94 DistributionSamplerKind::Standard(rand::distributions::Standard {})
95 }
96 Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
97 rand::distributions::Uniform::new(low.elem::<E>(), high.elem::<E>()),
98 ),
99 Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli(
100 rand::distributions::Bernoulli::new(prob).unwrap(),
101 ),
102 Distribution::Normal(mean, std) => {
103 DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
104 }
105 };
106
107 DistributionSampler::new(kind, rng)
108 }
109}