use rand::Rng;
use crate::Sample;
use super::Re;
/// Selects how elements are drawn from the source data when building a
/// subsample.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SamplingMode {
/// Draw elements without putting them back: a given source position can
/// contribute to one subsample at most once.
WithoutReplacement,
/// Draw every element independently, so the same source position may be
/// picked several times within one subsample.
WithReplacement,
}
/// Resampler configuration that produces random subsamples of a `Sample`.
///
/// `R` is the random number generator type and `F` the size policy, which
/// defaults to a plain function pointer.
#[derive(Clone)]
pub struct Subsample<R: Rng, F = fn(usize) -> usize>
where
F: Fn(usize) -> usize + Clone,
{
/// Random number generator; `re` clones it for every iterator it creates.
pub rng: R,
/// Maps the length of the input data to the requested subsample size.
pub policy: F,
/// Whether elements are drawn with or without replacement.
pub mode: SamplingMode,
}
impl<R: Rng> Subsample<R, fn(usize) -> usize> {
    /// Builds a subsampler with the default configuration: a square-root size
    /// policy (`floor(sqrt(n))` elements for an input of length `n`) and
    /// sampling without replacement.
    pub fn new(rng: R) -> Self {
        // A non-capturing closure coerces to the `fn(usize) -> usize` pointer
        // type that this impl fixes for the policy parameter.
        let policy: fn(usize) -> usize = |n: usize| (n as f64).sqrt() as usize;
        let mode = SamplingMode::WithoutReplacement;
        Self { rng, policy, mode }
    }
}
impl<R: Rng, F> Subsample<R, F>
where
    F: Fn(usize) -> usize + Clone,
{
    /// Builds a subsampler that uses `policy` to turn the input length into a
    /// subsample size. The sampling mode starts as
    /// `SamplingMode::WithoutReplacement`.
    pub fn with_policy(rng: R, policy: F) -> Self {
        let mode = SamplingMode::WithoutReplacement;
        Self { rng, policy, mode }
    }

    /// Consumes the subsampler and returns it with the sampling mode replaced
    /// by `mode`; the generator and the policy are carried over unchanged.
    pub fn with_mode(self, mode: SamplingMode) -> Self {
        Self { mode, ..self }
    }
}
impl<T: Copy, R: Rng + Clone, F> Re<Sample<T>> for Subsample<R, F>
where
    F: Fn(usize) -> usize + Clone,
{
    type Item = Sample<T>;

    /// Returns an iterator of random subsamples drawn from `sample`.
    ///
    /// The subsample size is whatever the policy returns for the input
    /// length, clamped so it never exceeds the number of available elements.
    /// The iterator works on a clone of the stored generator, so `self` is
    /// left untouched and repeated calls start from the same generator state.
    fn re(&self, sample: &Sample<T>) -> impl Iterator<Item = Self::Item> {
        let n = sample.data.len();
        let subsample_size = (self.policy)(n).min(n);
        // The return type is opaque, so the concrete iterator can be returned
        // by value; boxing it would only add a heap allocation per call.
        SubsampleIter::new(&sample.data, self.rng.clone(), subsample_size, self.mode)
    }
}
/// Iterator that yields one freshly drawn subsample of `data` per call to
/// `next`.
pub struct SubsampleIter<'a, T, R: Rng> {
/// Source elements the subsamples are drawn from.
data: &'a [T],
/// Random number generator that drives the draws.
rng: R,
/// Scratch vector in which a subsample is assembled before it is moved out.
buffer: Vec<T>,
/// Number of elements placed in each subsample.
subsample_size: usize,
/// Whether elements are drawn with or without replacement.
mode: SamplingMode,
// NOTE(review): set to 0 by `new` and not read by any code visible in this
// file section; it may be a leftover that can be removed.
reservoir_idx: usize,
}
impl<'a, T: Copy, R: Rng> SubsampleIter<'a, T, R> {
    /// Creates an iterator that draws subsamples of `subsample_size` elements
    /// from `data` using `rng`, following the given sampling `mode`.
    fn new(data: &'a [T], rng: R, subsample_size: usize, mode: SamplingMode) -> Self {
        Self {
            buffer: Vec::with_capacity(subsample_size),
            data,
            rng,
            subsample_size,
            mode,
            reservoir_idx: 0,
        }
    }

    /// Draws `subsample_size` distinct positions of `data` uniformly at random
    /// and returns the elements found there.
    ///
    /// Two strategies are used depending on how large the subsample is
    /// relative to the data:
    ///
    /// * small subsamples (`k < n / 4`) use reservoir sampling, which needs
    ///   only `k` elements of storage;
    /// * larger ones copy the data and run a partial Fisher–Yates shuffle.
    ///
    /// A request for more elements than are available is capped at
    /// `data.len()`; an empty subsample is returned when either the data or
    /// the requested size is empty.
    #[inline]
    fn sample_without_replacement(&mut self) -> Sample<T> {
        let data = self.data;
        let n = data.len();
        // Cap the request so the shuffle below can never index past the data.
        let k = self.subsample_size.min(n);
        if k == 0 {
            return Sample::new(Vec::new());
        }
        self.buffer.clear();
        if k < n / 4 {
            // Reservoir sampling: seed the reservoir with the first `k`
            // elements, then let element `i` replace a random slot with
            // probability `k / (i + 1)`.
            self.buffer.extend_from_slice(&data[..k]);
            for (i, &item) in data.iter().enumerate().skip(k) {
                let j = self.rng.gen_range(0..=i);
                if j < k {
                    self.buffer[j] = item;
                }
            }
        } else {
            // Partial Fisher–Yates shuffle: after step `i`, positions `0..=i`
            // hold a uniformly chosen arrangement of `i + 1` distinct
            // elements. The shuffled positions are the *leading* ones, so
            // truncating to `k` keeps exactly the elements that were selected.
            self.buffer.extend_from_slice(data);
            for i in 0..k {
                let j = self.rng.gen_range(i..n);
                self.buffer.swap(i, j);
            }
            self.buffer.truncate(k);
        }
        Sample::new(std::mem::take(&mut self.buffer))
    }

    /// Draws `subsample_size` elements of `data` independently and uniformly
    /// at random, so the same position may appear more than once.
    ///
    /// Returns an empty subsample when either the data or the requested size
    /// is empty, because there is nothing to draw from in that case.
    #[inline]
    fn sample_with_replacement(&mut self) -> Sample<T> {
        let data = self.data;
        let n = data.len();
        let k = self.subsample_size;
        if n == 0 || k == 0 {
            return Sample::new(Vec::new());
        }
        self.buffer.clear();
        let rng = &mut self.rng;
        // One generator call per output element, in output order. `extend`
        // reserves space up front from the range's exact size hint.
        self.buffer
            .extend((0..k).map(|_| data[rng.gen_range(0..n)]));
        Sample::new(std::mem::take(&mut self.buffer))
    }
}
impl<T: Copy, R: Rng> Iterator for SubsampleIter<'_, T, R> {
    type Item = Sample<T>;

    /// Produces the next subsample.
    ///
    /// This never returns `None`, so the iterator is unbounded and callers
    /// are expected to limit it themselves. When there is no data or the
    /// subsample size is zero, every item is an empty sample.
    fn next(&mut self) -> Option<Self::Item> {
        let nothing_to_draw = self.data.is_empty() || self.subsample_size == 0;
        let drawn = if nothing_to_draw {
            Sample::new(Vec::new())
        } else {
            match self.mode {
                SamplingMode::WithoutReplacement => self.sample_without_replacement(),
                SamplingMode::WithReplacement => self.sample_with_replacement(),
            }
        };
        Some(drawn)
    }
}
/// Size policy that requests `floor(sqrt(n))` elements for an input of
/// length `n`.
pub fn sqrt_policy(n: usize) -> usize {
    let root = f64::sqrt(n as f64);
    // The float-to-integer cast drops the fractional part.
    root as usize
}
/// Size policy that requests `floor(ln(n))` elements for an input of length
/// `n`, but never fewer than one.
pub fn log_policy(n: usize) -> usize {
    let natural_log = f64::ln(n as f64);
    // `ln(0)` is negative infinity and `ln(1)` is zero; both are raised to 1.
    let at_least_one = f64::max(natural_log, 1.0);
    at_least_one as usize
}
/// Returns a size policy that requests `floor(n * ratio)` elements for an
/// input of length `n`, but never fewer than one.
///
/// The returned closure is `Clone`, so it satisfies the `F: Clone` bound that
/// `Subsample` places on its policy and can be passed straight to
/// `Subsample::with_policy`.
pub fn fixed_ratio_policy(ratio: f64) -> impl Fn(usize) -> usize + Clone {
    move |n| ((n as f64) * ratio).max(1.0) as usize
}
/// Returns a size policy that always requests `size` elements, whatever the
/// input length.
///
/// The returned closure is `Clone`, so it satisfies the `F: Clone` bound that
/// `Subsample` places on its policy and can be passed straight to
/// `Subsample::with_policy`.
pub fn fixed_size_policy(size: usize) -> impl Fn(usize) -> usize + Clone {
    move |_| size
}