use core::num::NonZeroUsize;
use crate::distr::Distribution;
use crate::distr::uniform::{UniformSampler, UniformUsize};
#[cfg(feature = "alloc")]
use alloc::string::String;
/// A distribution that yields a reference to one element of a borrowed slice
/// each time it is sampled.
///
/// Values are built with [`Choose::new`], which refuses empty slices, so an
/// instance always has at least one element to choose from.
#[derive(Debug, Clone, Copy)]
pub struct Choose<'a, T> {
/// The elements to choose between.
slice: &'a [T],
/// Index sampler; `new` constructs it as `UniformUsize::new(0, slice.len())`.
range: UniformUsize,
/// The length of `slice`, stored as a non-zero value.
num_choices: NonZeroUsize,
}
impl<'a, T> Choose<'a, T> {
    /// Builds a [`Choose`] distribution over the elements of `slice`.
    ///
    /// # Errors
    ///
    /// Returns [`Empty`] when `slice` has no elements.
    pub fn new(slice: &'a [T]) -> Result<Self, Empty> {
        match NonZeroUsize::new(slice.len()) {
            None => Err(Empty),
            Some(num_choices) => {
                // `num_choices` is non-zero, so `0..num_choices` is a non-empty
                // range; the `unwrap` is expected never to fire.
                let range = UniformUsize::new(0, num_choices.get()).unwrap();
                Ok(Self {
                    slice,
                    range,
                    num_choices,
                })
            }
        }
    }

    /// Returns the number of elements this distribution chooses between.
    ///
    /// The value equals the length of the slice passed to [`Choose::new`].
    pub fn num_choices(&self) -> NonZeroUsize {
        self.num_choices
    }
}
impl<'a, T> Distribution<&'a T> for Choose<'a, T> {
    /// Draws an index from the stored range and returns a reference to the
    /// slice element at that index.
    fn sample<R: crate::Rng + ?Sized>(&self, rng: &mut R) -> &'a T {
        let items: &'a [T] = self.slice;
        let picked = self.range.sample(rng);

        // Debug builds verify the index before the unchecked access below.
        debug_assert!(
            picked < items.len(),
            "Uniform::new(0, {}) somehow returned {}",
            items.len(),
            picked
        );

        // SAFETY: `Choose::new` only succeeds for a non-empty slice and builds
        // `range` as `UniformUsize::new(0, slice.len())`. This relies on that
        // sampler producing values strictly below its upper bound, so `picked`
        // is a valid index into `items`.
        unsafe { items.get_unchecked(picked) }
    }
}
/// Error returned by `Choose::new` when it is given an empty slice.
#[derive(Debug, Clone, Copy)]
pub struct Empty;

impl core::fmt::Display for Empty {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        const MESSAGE: &str =
            "Tried to create a `rand::distr::slice::Choose` with an empty slice";
        // The message has no format arguments, so it is written out verbatim.
        f.write_str(MESSAGE)
    }
}

// The standard error trait is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Empty {}
#[cfg(feature = "alloc")]
impl super::SampleString for Choose<'_, char> {
    /// Appends `len` characters, each sampled from this distribution, to
    /// `string`.
    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
        // Widest UTF-8 encoding (in bytes) among the candidate characters.
        // Slices of 200 or more elements are not scanned; the maximum of 4 is
        // assumed for them instead.
        let mut widest: usize = 4;
        if self.slice.len() < 200 {
            widest = 1;
            for ch in self.slice {
                widest = widest.max(ch.len_utf8());
                if widest >= 4 {
                    // Nothing can be wider than 4 bytes; stop scanning early.
                    widest = 4;
                    break;
                }
            }
        }

        // Single-byte alphabets and short requests are appended in one go.
        // Otherwise the work is split into chunks of a quarter of the request,
        // so each `reserve` call can count capacity that earlier chunks
        // reserved but did not fill.
        let mut chunk = if widest == 1 || len < 100 {
            len
        } else {
            len / 4
        };
        let mut remaining = len;
        while chunk != 0 {
            string.reserve(widest * chunk);
            string.extend(self.sample_iter(&mut *rng).take(chunk));
            remaining -= chunk;
            // The final chunk shrinks to whatever is still outstanding.
            chunk = chunk.min(remaining);
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Pins the first five elements sampled with the RNG returned by
    /// `crate::test::rng(651)`, so that a change in sampling output is noticed.
    #[test]
    fn value_stability() {
        let rng = crate::test::rng(651);
        let chooser = Choose::new(b"escaped emus explore extensively").unwrap();
        let expected = b"eaxee";

        // Compare samples pairwise against the expected bytes; `zip` stops
        // once `expected` is exhausted.
        let all_match = chooser
            .sample_iter(rng)
            .zip(expected)
            .all(|(got, want)| got == want);
        assert!(all_match);
    }
}