use crate::rng::Rng32;
use crate::sampler::Sampler32;
#[derive(Debug)]
pub struct Bst32<'a, R: Rng32 + 'a> {
rng: &'a mut R,
cumulative: Vec<f32>,
}
impl<'a, R: Rng32 + 'a> Bst32<'a, R> {
pub fn new(rng: &'a mut R, weights: &[f32]) -> Self {
Self {
rng,
cumulative: Self::build_cumulative(weights),
}
}
fn build_cumulative(weights: &[f32]) -> Vec<f32> {
let mut cumulative = Vec::with_capacity(weights.len());
let mut sum = 0.0f32;
for &w in weights {
sum += w;
cumulative.push(sum);
}
cumulative
}
}
impl<'a, R: Rng32 + 'a> Sampler32<'a, R> for Bst32<'a, R> {
fn weights(&mut self, weights: &[f32]) {
self.cumulative = Self::build_cumulative(weights);
}
fn sample(&mut self) -> usize {
let total = *self.cumulative.last().unwrap_or(&0.0);
let r = self.rng.randf(0.0, total);
match self
.cumulative
.binary_search_by(|&c| c.partial_cmp(&r).unwrap())
{
Ok(i) | Err(i) => i.min(self.cumulative.len() - 1),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::rng32::Mt19937;
use crate::sampler::Sampler32;
#[test]
fn bst32_works() {
let mut rng = Mt19937::new(1);
let mut sampler = Bst32::new(&mut rng, &[1.0f32, 9.0f32]);
assert_eq!(sampler.sample(), 1);
assert_eq!(sampler.sample(), 1);
assert_eq!(sampler.sample(), 0);
assert_eq!(sampler.sample(), 1);
assert_eq!(sampler.sample(), 1);
}
}