use ferray_core::{Array, FerrayError, Ix1};
use crate::bitgen::BitGenerator;
use crate::generator::Generator;
impl<B: BitGenerator> Generator<B> {
pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
where
T: ferray_core::Element,
{
let n = arr.shape()[0];
if n <= 1 {
return Ok(());
}
let slice = arr
.as_slice_mut()
.ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
for i in (1..n).rev() {
let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
slice.swap(i, j);
}
Ok(())
}
pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
where
T: ferray_core::Element,
{
let mut copy = arr.clone();
self.shuffle(&mut copy)?;
Ok(copy)
}
pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
if n == 0 {
return Err(FerrayError::invalid_value("n must be > 0"));
}
let mut data: Vec<i64> = (0..n as i64).collect();
for i in (1..n).rev() {
let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
data.swap(i, j);
}
Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
}
pub fn permuted<T>(
&mut self,
arr: &Array<T, Ix1>,
_axis: usize,
) -> Result<Array<T, Ix1>, FerrayError>
where
T: ferray_core::Element,
{
self.permutation(arr)
}
pub fn choice<T>(
&mut self,
arr: &Array<T, Ix1>,
size: usize,
replace: bool,
p: Option<&[f64]>,
) -> Result<Array<T, Ix1>, FerrayError>
where
T: ferray_core::Element,
{
let n = arr.shape()[0];
if size == 0 {
return Array::from_vec(Ix1::new([0]), Vec::new());
}
if n == 0 {
return Err(FerrayError::invalid_value("source array must be non-empty"));
}
if !replace && size > n {
return Err(FerrayError::invalid_value(format!(
"cannot choose {size} elements without replacement from array of size {n}"
)));
}
if let Some(probs) = p {
if probs.len() != n {
return Err(FerrayError::invalid_value(format!(
"p must have same length as array ({n}), got {}",
probs.len()
)));
}
let psum: f64 = probs.iter().sum();
if (psum - 1.0).abs() > 1e-6 {
return Err(FerrayError::invalid_value(format!(
"p must sum to 1.0, got {psum}"
)));
}
for (i, &pi) in probs.iter().enumerate() {
if pi < 0.0 {
return Err(FerrayError::invalid_value(format!(
"p[{i}] = {pi} is negative"
)));
}
}
}
let src = arr
.as_slice()
.ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
let indices = if let Some(probs) = p {
if replace {
weighted_sample_with_replacement(&mut self.bg, probs, size)
} else {
weighted_sample_without_replacement(&mut self.bg, probs, size)?
}
} else if replace {
(0..size)
.map(|_| self.bg.next_u64_bounded(n as u64) as usize)
.collect()
} else {
sample_without_replacement(&mut self.bg, n, size)
};
let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
}
}
/// Picks `size` distinct indices from `0..n` with a partial Fisher–Yates pass.
///
/// Only the first `size` slots of the index pool are randomised; each slot is
/// swapped with a random slot at or after it, and the randomised prefix is
/// returned. Callers are expected to pass `size <= n`.
fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
    let mut candidates: Vec<usize> = Vec::with_capacity(n);
    candidates.extend(0..n);
    let mut front = 0;
    while front < size {
        // Number of slots that have not been fixed yet, `front` included.
        let remaining = (n - front) as u64;
        let target = front + bg.next_u64_bounded(remaining) as usize;
        candidates.swap(front, target);
        front += 1;
    }
    candidates.truncate(size);
    candidates
}
/// Draws `size` indices into `probs`, with replacement, by inverting the
/// cumulative distribution.
///
/// Each draw `u = bg.next_f64()` selects the first index whose cumulative
/// probability is strictly greater than `u`. A draw that lands exactly on a
/// cumulative value therefore belongs to the next bucket, and entries with
/// zero probability are never selected. If rounding leaves the final
/// cumulative value at or below `u`, the last entry with positive probability
/// is used.
///
/// `probs` is expected to be non-empty, non-negative and to sum to roughly 1;
/// the caller is responsible for validating that.
fn weighted_sample_with_replacement<B: BitGenerator>(
    bg: &mut B,
    probs: &[f64],
    size: usize,
) -> Vec<usize> {
    // Running sums of the probabilities (non-decreasing for valid input).
    let cdf: Vec<f64> = probs
        .iter()
        .scan(0.0_f64, |acc, &p| {
            *acc += p;
            Some(*acc)
        })
        .collect();
    // Upper clamp for the search result: never fall through onto a trailing
    // zero-probability entry when the sum is slightly short of the draw.
    let last_positive = probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or_else(|| probs.len().saturating_sub(1));
    (0..size)
        .map(|_| {
            let u = bg.next_f64();
            // First index with cdf > u; a plain binary search would stop on
            // an arbitrary element that merely equals `u`.
            cdf.partition_point(|&c| c <= u).min(last_positive)
        })
        .collect()
}
/// Draws `size` distinct indices into `probs`, weighted by `probs`.
///
/// Works on a private copy of the weights: each round draws an index with
/// probability proportional to its remaining weight, then zeroes that weight
/// so the index cannot be drawn again.
///
/// # Errors
///
/// Returns `FerrayError::invalid_value` when the remaining weights do not sum
/// to a positive, finite value. This covers running out of non-zero entries
/// as well as NaN or infinite weights.
fn weighted_sample_without_replacement<B: BitGenerator>(
    bg: &mut B,
    probs: &[f64],
    size: usize,
) -> Result<Vec<usize>, FerrayError> {
    let mut weights: Vec<f64> = probs.to_vec();
    // At most `probs.len()` picks can ever succeed, so never reserve more.
    let mut selected = Vec::with_capacity(size.min(probs.len()));
    for _ in 0..size {
        let total: f64 = weights.iter().sum();
        // `total <= 0.0` alone lets NaN through (every comparison with NaN is
        // false), which would make each round pick the same fallback index.
        if total.is_nan() || total.is_infinite() || total <= 0.0 {
            return Err(FerrayError::invalid_value(
                "insufficient probability mass for sampling without replacement",
            ));
        }
        let u = bg.next_f64() * total;
        let mut cumsum = 0.0;
        let mut hit = None;
        // Last index seen with weight left; used only if rounding keeps the
        // running sum from ever exceeding `u`.
        let mut last_positive = None;
        for (i, &w) in weights.iter().enumerate() {
            if w > 0.0 {
                last_positive = Some(i);
            }
            cumsum += w;
            if cumsum > u {
                hit = Some(i);
                break;
            }
        }
        let chosen = match hit.or(last_positive) {
            Some(i) => i,
            None => {
                return Err(FerrayError::invalid_value(
                    "insufficient probability mass for sampling without replacement",
                ));
            }
        };
        selected.push(chosen);
        weights[chosen] = 0.0;
    }
    Ok(selected)
}
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// Builds a 1-D `i64` array whose length matches `values`.
    fn array_of(values: Vec<i64>) -> Array<i64, Ix1> {
        let len = values.len();
        Array::<i64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Returns the array's contents as an ascending `Vec`.
    fn sorted_copy(arr: &Array<i64, Ix1>) -> Vec<i64> {
        let mut values = arr.as_slice().unwrap().to_vec();
        values.sort_unstable();
        values
    }

    /// Returns the first value that occurs more than once, if any.
    fn first_duplicate(values: &[i64]) -> Option<i64> {
        let mut seen = std::collections::HashSet::new();
        values.iter().copied().find(|&v| !seen.insert(v))
    }

    /// Shuffling must keep exactly the same multiset of elements.
    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = array_of(vec![1, 2, 3, 4, 5]);
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_copy(&arr), vec![1, 2, 3, 4, 5]);
    }

    /// A permutation of an array contains the same elements as its input.
    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30, 40, 50]);
        let perm = rng.permutation(&arr).unwrap();
        assert_eq!(sorted_copy(&perm), vec![10, 20, 30, 40, 50]);
    }

    /// `permutation_range(10)` yields every integer in `0..10` exactly once.
    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted_copy(&perm), expected);
    }

    /// After an in-place shuffle the array still holds the original values.
    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = array_of(original.clone());
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_copy(&arr), original);
    }

    /// Uniform sampling with replacement returns only source values.
    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        let arr = array_of(src.clone());
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        chosen.as_slice().unwrap().iter().for_each(|&v| {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        });
    }

    /// Uniform sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of((0..10).collect());
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        if let Some(v) = first_duplicate(chosen.as_slice().unwrap()) {
            panic!("duplicate value {v} in choice without replacement");
        }
    }

    /// Asking for more distinct elements than exist is an error.
    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let outcome = rng.choice(&arr, 10, false, None);
        assert!(outcome.is_err());
    }

    /// With all probability on the last element, only that element is drawn.
    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30]);
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 30);
        }
    }

    /// Weighted sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        if let Some(v) = first_duplicate(chosen.as_slice().unwrap()) {
            panic!("duplicate value {v}");
        }
    }

    /// Wrong length, wrong sum and negative entries are all rejected.
    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3]);
        // Wrong length.
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
        // Does not sum to 1.
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
        // Contains a negative entry.
        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
    }

    /// `permuted` on a 1-D array keeps the same multiset of elements.
    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let result = rng.permuted(&arr, 0).unwrap();
        assert_eq!(sorted_copy(&result), vec![1, 2, 3, 4, 5]);
    }
}