use pounce_common::types::Number;
use rand::distributions::{Distribution, Standard};
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
/// Seeded generator for candidate points and perturbations.
///
/// Pseudo-random draws come from a [`ChaCha8Rng`] seeded in [`Sampler::new`],
/// so a given seed reproduces the same stream. When `sobol` is enabled,
/// box-constrained samples instead take their coordinates from a seeded
/// Sobol low-discrepancy sequence (`sobol_burley`).
pub struct Sampler {
    /// Pseudo-random generator backing `uniform` and `standard_normal`.
    rng: ChaCha8Rng,
    /// Seed passed to `sobol_burley::sample`, folded down from the `u64` seed.
    sobol_seed: u32,
    /// Index of the next Sobol point; advanced once per box-constrained sample.
    sobol_index: u32,
    /// Whether box-constrained samples use the Sobol sequence.
    sobol: bool,
}

impl Sampler {
    /// Creates a sampler whose output is fully determined by `seed`.
    ///
    /// The 32-bit Sobol seed is derived by XOR-folding the high and low halves
    /// of `seed`. `sobol` selects whether [`Sampler::sample`] draws
    /// box-constrained points from the Sobol sequence or from the RNG.
    pub fn new(seed: u64, sobol: bool) -> Self {
        Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            // Fold the high half into the low half so every bit of `seed`
            // influences the 32-bit Sobol seed.
            sobol_seed: (seed ^ (seed >> 32)) as u32,
            sobol_index: 0,
            sobol,
        }
    }

    /// Draws a uniform `f64` in `[0, 1)` from the internal RNG.
    pub fn uniform(&mut self) -> f64 {
        Standard.sample(&mut self.rng)
    }

    /// Draws a standard normal deviate (mean 0, variance 1) via Box–Muller.
    ///
    /// Only the cosine branch is used; the paired sine deviate is discarded.
    pub fn standard_normal(&mut self) -> f64 {
        // `uniform` is in [0, 1); flipping it to (0, 1] keeps `ln` finite.
        let u1: f64 = 1.0 - self.uniform();
        let u2: f64 = self.uniform();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }

    /// Draws a candidate point with `x0.len()` coordinates.
    ///
    /// * With `has_box`, each coordinate is `lo[j] + (hi[j] - lo[j]) * u`,
    ///   where `u` comes from the Sobol sequence (when enabled) or the uniform
    ///   RNG. Here `x0` only supplies the dimension and `jitter` is ignored.
    ///   Sobol values are used only for dimensions that `sobol_burley` has
    ///   tables for; higher dimensions fall back to the uniform RNG instead
    ///   of panicking.
    /// * Without `has_box`, returns `x0` with independent Gaussian noise of
    ///   standard deviation `jitter` added to every coordinate; `lo` and `hi`
    ///   are not read.
    ///
    /// # Panics
    ///
    /// Panics if `has_box` is set and `lo` or `hi` is shorter than `x0`.
    pub fn sample(
        &mut self,
        x0: &[Number],
        lo: &[Number],
        hi: &[Number],
        has_box: bool,
        jitter: f64,
    ) -> Vec<Number> {
        let n = x0.len();
        if has_box {
            let idx = self.sobol_index;
            // Wrap instead of tripping the debug-build overflow check on very
            // long runs; the Sobol sequence simply restarts.
            self.sobol_index = self.sobol_index.wrapping_add(1);
            // `sobol_burley` only covers a fixed number of dimensions and
            // panics past that, so coordinates beyond it use the RNG.
            let sobol_dims = sobol_burley::NUM_DIMENSIONS as usize;
            (0..n)
                .map(|j| {
                    let u = if self.sobol && j < sobol_dims {
                        sobol_burley::sample(idx, j as u32, self.sobol_seed) as f64
                    } else {
                        self.uniform()
                    };
                    lo[j] + (hi[j] - lo[j]) * u
                })
                .collect()
        } else {
            x0.iter()
                .map(|&x| x + jitter * self.standard_normal())
                .collect()
        }
    }

    /// Returns `anchor` with independent Gaussian noise added to each coordinate.
    ///
    /// `scale` is the noise standard deviation. A one-element slice applies
    /// the same scale to every coordinate; otherwise `scale[j]` is used for
    /// coordinate `j`.
    ///
    /// # Panics
    ///
    /// Panics if `scale` has length other than 1 and is shorter than `anchor`.
    pub fn perturb(&mut self, anchor: &[Number], scale: &[Number]) -> Vec<Number> {
        anchor
            .iter()
            .enumerate()
            .map(|(j, &a)| {
                let s = if scale.len() == 1 { scale[0] } else { scale[j] };
                a + s * self.standard_normal()
            })
            .collect()
    }
}
/// Clamps each coordinate `x[j]` into `[lo[j], hi[j]]` in place.
///
/// When `has_box` is false the function returns without touching `x`, and
/// `lo` / `hi` are never read (so they may be empty).
///
/// # Panics
///
/// Panics if `has_box` is set and `lo` or `hi` is shorter than `x`. Also
/// panics if a bound pair is inverted (`lo[j] > hi[j]`) or NaN, because
/// `clamp` does.
pub fn clip(x: &mut [Number], lo: &[Number], hi: &[Number], has_box: bool) {
    if has_box {
        for (j, xi) in x.iter_mut().enumerate() {
            *xi = xi.clamp(lo[j], hi[j]);
        }
    }
}