use core::fmt::Display;
use alloc::{collections::BTreeSet, vec::Vec};
use ndarray::{Array, Ix1};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha12Rng;
use crate::{generators::xor_iteration, pdf::PdfGenerator, quickselect, Schedule};
use super::{Generator, Trace};
/// Weighted random index sampler over a 1-D distribution supplied by a [`PdfGenerator`].
///
/// Field `0` is the PDF generator; field `1` is the 32-byte seed used to build the
/// sampling RNG.
#[derive(Clone, Copy, Debug)]
pub struct RandomSampling<G: PdfGenerator<Ix1>>(G, [u8; 32]);
impl<G: PdfGenerator<Ix1>> RandomSampling<G> {
    /// Wraps `pdf` together with the RNG `seed` into a sampler.
    pub const fn new(pdf: G, seed: [u8; 32]) -> Self {
        Self(pdf, seed)
    }
}
/// Unit trace type that records no data.
///
/// Its `Display` output is always the fixed text `No trace information`.
#[derive(Clone, Copy, Debug)]
pub struct RandomSamplingTrace;
impl Display for RandomSamplingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
impl<G: PdfGenerator<Ix1>> Generator<Ix1> for RandomSampling<G> {
fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
let pdf = self.0.get(dims).pop().unwrap();
let mut rng = ChaCha12Rng::from_seed(xor_iteration(self.1, iteration));
let mut values = pdf
.get_distribution()
.iter()
.enumerate()
.map(|(i, v)| (rng.random::<f64>().powf(v.recip()), i))
.collect::<Vec<_>>();
quickselect(&mut rng, &mut values, |a, b| a.0.total_cmp(&b.0), count);
let mut sched = alloc::vec![false; pdf.len()];
for i in 0..count {
sched[values[i].1] = true;
}
Schedule::new(Array::from_vec(sched))
}
fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
Trace::new(
self._generate_no_trace(count, dims, iteration),
AveragingTrace,
)
}
}
/// Generator that averages the index positions chosen by several [`RandomSampling`] rounds.
#[derive(Clone, Copy, Debug)]
pub struct Averaging<G: PdfGenerator<Ix1>> {
    /// Number of sampling rounds whose positions are averaged per schedule.
    avg_count: usize,
    /// Inner sampler invoked once per round.
    random: RandomSampling<G>,
}
impl<G: PdfGenerator<Ix1>> Averaging<G> {
    /// Builds an averaging generator.
    ///
    /// The generator samples from `pdf` with `seed` and averages over `avg_count` rounds.
    pub const fn new(pdf: G, avg_count: usize, seed: [u8; 32]) -> Self {
        Self {
            random: RandomSampling::new(pdf, seed),
            avg_count,
        }
    }
}
/// Unit trace type returned by [`Averaging`]'s `_generate`; it records no data.
///
/// Its `Display` output is always the fixed text `No trace information`.
#[derive(Clone, Copy, Debug)]
pub struct AveragingTrace;
impl Display for AveragingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
impl<G: PdfGenerator<Ix1>> Generator<Ix1> for Averaging<G> {
    /// Builds a schedule by averaging the positions chosen over `avg_count` sampling rounds.
    ///
    /// Round `r` samples `count` indices with the inner [`RandomSampling`] at iteration
    /// `iteration + r` (wrapping on overflow). The `k`-th selected position of each round,
    /// in ascending order, is summed into slot `k`. Every slot is then floor-divided by
    /// `avg_count`, and the resulting mean positions are marked `true`.
    ///
    /// # Panics
    ///
    /// Panics if `avg_count` is zero, if a round selects more than `count` indices, or if
    /// the averaged positions are not pairwise distinct.
    fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
        assert!(self.avg_count > 0, "Averaging requires avg_count > 0");
        let mut sum = alloc::vec![0; count];
        for round in 0..self.avg_count {
            // wrapping_add: an `iteration` near u64::MAX must not overflow-panic in debug builds.
            let sched = self
                .random
                .generate_with_iter(count, dims, iteration.wrapping_add(round as u64));
            let mut found = 0;
            for (position, item) in sched.iter().enumerate() {
                if *item {
                    sum[found] += position;
                    found += 1;
                }
            }
        }
        for slot in sum.iter_mut() {
            *slot /= self.avg_count;
        }
        // Within a round, selected positions are strictly ascending. So, before division,
        // slot k+1 exceeds slot k by at least avg_count, and floor division keeps them
        // distinct. This can only fail if some round selected fewer than `count` indices.
        assert!(
            sum.iter().collect::<BTreeSet<_>>().len() == count,
            "averaged schedule positions collided; a sampling round selected too few indices"
        );
        let mut result = alloc::vec![false; dims[0]];
        for position in sum {
            result[position] = true;
        }
        Schedule::new(Array::from_vec(result))
    }
    /// Same as `_generate_no_trace`, wrapped with an empty [`AveragingTrace`].
    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
        Trace::new(
            self._generate_no_trace(count, dims, iteration),
            AveragingTrace,
        )
    }
}