oxirs_vec/
random_utils.rs1use scirs2_core::random::Rng;
7
8#[inline]
11pub fn sample_normal<R: Rng>(rng: &mut R, mean: f32, std: f32) -> f32 {
12 let u1: f32 = rng.random::<f32>();
13 let u2: f32 = rng.random::<f32>();
14 let z = (-2.0_f32 * u1.ln()).sqrt() * (2.0_f32 * std::f32::consts::PI * u2).cos();
15 mean + z * std
16}
17
18#[inline]
20pub fn sample_uniform<R: Rng>(rng: &mut R, low: f32, high: f32) -> f32 {
21 low + rng.random::<f32>() * (high - low)
22}
23
/// Reusable sampler for a normal (Gaussian) distribution with fixed parameters.
///
/// Construct it with [`NormalSampler::new`], which validates the parameters,
/// then draw values with [`NormalSampler::sample`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalSampler {
    /// Mean of the distribution.
    mean: f32,
    /// Standard deviation of the distribution.
    std: f32,
}
29
30impl NormalSampler {
31 pub fn new(mean: f32, std: f32) -> Result<Self, String> {
32 if std <= 0.0 {
33 return Err(format!("Standard deviation must be positive, got {}", std));
34 }
35 Ok(Self { mean, std })
36 }
37
38 pub fn sample<R: Rng>(&self, rng: &mut R) -> f32 {
39 sample_normal(rng, self.mean, self.std)
40 }
41}
42
/// Reusable sampler for a uniform distribution between two fixed bounds.
///
/// Construct it with [`UniformSampler::new`], which validates the bounds,
/// then draw values with [`UniformSampler::sample`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UniformSampler {
    /// Lower bound of the sampled range.
    low: f32,
    /// Upper bound of the sampled range.
    high: f32,
}
48
49impl UniformSampler {
50 pub fn new(low: f32, high: f32) -> Result<Self, String> {
51 if low >= high {
52 return Err(format!(
53 "Low must be less than high, got {} and {}",
54 low, high
55 ));
56 }
57 Ok(Self { low, high })
58 }
59
60 pub fn sample<R: Rng>(&self, rng: &mut R) -> f32 {
61 sample_uniform(rng, self.low, self.high)
62 }
63}