use ferray_core::{Array, FerrayError, Ix1, IxDyn};
use crate::bitgen::BitGenerator;
use crate::generator::Generator;
impl<B: BitGenerator> Generator<B> {
    /// Shuffles a 1-D array in place using the Fisher–Yates algorithm.
    ///
    /// Arrays of length 0 or 1 are left untouched and consume no random draws.
    ///
    /// # Errors
    ///
    /// Returns an error if `arr` has more than one element and is not contiguous.
    pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
    where
        T: ferray_core::Element,
    {
        let n = arr.shape()[0];
        if n <= 1 {
            return Ok(());
        }
        let slice = arr
            .as_slice_mut()
            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
        fisher_yates(&mut self.bg, slice);
        Ok(())
    }

    /// Returns a shuffled copy of a 1-D array, leaving `arr` unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::shuffle`] on the copy.
    pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
    where
        T: ferray_core::Element,
    {
        let mut copy = arr.clone();
        self.shuffle(&mut copy)?;
        Ok(copy)
    }

    /// Returns a random permutation of `0..n` as an `i64` array.
    ///
    /// # Errors
    ///
    /// Returns an error if `n == 0`.
    pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
        if n == 0 {
            return Err(FerrayError::invalid_value("n must be > 0"));
        }
        let mut data: Vec<i64> = (0..n as i64).collect();
        fisher_yates(&mut self.bg, &mut data);
        Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
    }

    /// 1-D counterpart of [`Self::permuted_dyn`]: returns a shuffled copy of `arr`.
    ///
    /// A 1-D array only has axis 0. Any other `axis` is rejected rather than
    /// silently ignored, matching the axis check in [`Self::permuted_dyn`].
    ///
    /// # Errors
    ///
    /// Returns an axis-out-of-bounds error if `axis != 0`. Otherwise propagates
    /// errors from [`Self::permutation`].
    pub fn permuted<T>(
        &mut self,
        arr: &Array<T, Ix1>,
        axis: usize,
    ) -> Result<Array<T, Ix1>, FerrayError>
    where
        T: ferray_core::Element,
    {
        if axis != 0 {
            return Err(FerrayError::axis_out_of_bounds(axis, 1));
        }
        self.permutation(arr)
    }

    /// Shuffles `arr` in place along `axis`, moving whole slices together.
    ///
    /// One Fisher–Yates permutation of the positions along `axis` is drawn.
    /// For each swap `(i, j)`, every element at those positions is swapped,
    /// across every outer index and every trailing-axis offset. Each slice
    /// therefore keeps its internal contents.
    ///
    /// NOTE(review): the offset arithmetic assumes `as_slice_mut` exposes the
    /// data in row-major (C) order — confirm against `ferray_core`.
    ///
    /// # Errors
    ///
    /// Returns an error if `axis >= arr.ndim()`. When the axis has more than one
    /// position, also returns an error if `arr` is not contiguous.
    pub fn shuffle_dyn<T>(
        &mut self,
        arr: &mut Array<T, IxDyn>,
        axis: usize,
    ) -> Result<(), FerrayError>
    where
        T: ferray_core::Element,
    {
        let shape = arr.shape().to_vec();
        let ndim = shape.len();
        if axis >= ndim {
            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
        }
        let n = shape[axis];
        if n <= 1 {
            return Ok(());
        }
        // Elements per position along `axis`, and the span of one outer index.
        let inner_stride: usize = shape[axis + 1..].iter().product();
        let block = n * inner_stride;
        let outer_size: usize = shape[..axis].iter().product();
        let slice = arr
            .as_slice_mut()
            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
        for i in (1..n).rev() {
            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
            if i == j {
                continue;
            }
            // Apply the same swap to every outer index so whole slices move.
            for o in 0..outer_size {
                let base = o * block;
                for k in 0..inner_stride {
                    slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
                }
            }
        }
        Ok(())
    }

    /// Draws `size` slices along `axis` of `arr`.
    ///
    /// The result has `size` along `axis`; its other dimensions match `arr`.
    ///
    /// * `replace` — sample with (`true`) or without (`false`) replacement.
    /// * `p` — optional per-position weights of length `shape[axis]`. Each weight
    ///   must be finite and non-negative, and the weights must sum to 1 within `1e-6`.
    /// * `shuffle` — when `false` and `replace` is `false`, the chosen positions are
    ///   emitted in ascending order; otherwise they are emitted in draw order.
    ///
    /// `size == 0` returns an empty result before `p` or contiguity are checked.
    /// Arrays with a zero-length dimension other than `axis` yield an empty result
    /// rather than panicking.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// * `axis` is out of bounds;
    /// * `size > 0` but the axis has zero length;
    /// * `size` exceeds the axis length without replacement;
    /// * `p` is invalid;
    /// * `arr` is non-contiguous;
    /// * weighted sampling without replacement runs out of probability mass.
    pub fn choice_dyn<T>(
        &mut self,
        arr: &Array<T, IxDyn>,
        size: usize,
        replace: bool,
        p: Option<&[f64]>,
        axis: usize,
        shuffle: bool,
    ) -> Result<Array<T, IxDyn>, FerrayError>
    where
        T: ferray_core::Element,
    {
        let shape = arr.shape().to_vec();
        let ndim = shape.len();
        if axis >= ndim {
            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
        }
        let axis_len = shape[axis];
        if size == 0 {
            let mut out_shape = shape;
            out_shape[axis] = 0;
            return Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), Vec::new());
        }
        if axis_len == 0 {
            return Err(FerrayError::invalid_value(
                "choice_dyn: source array has zero length along axis",
            ));
        }
        if !replace && size > axis_len {
            return Err(FerrayError::invalid_value(format!(
                "cannot choose {size} elements without replacement from axis of size {axis_len}"
            )));
        }
        if let Some(probs) = p {
            if probs.len() != axis_len {
                return Err(FerrayError::invalid_value(format!(
                    "p must have length {axis_len} (size of axis {axis}), got {}",
                    probs.len()
                )));
            }
            validate_probs(probs)?;
        }
        let src = arr
            .as_slice()
            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for choice_dyn"))?;
        let mut indices = draw_indices(&mut self.bg, axis_len, size, replace, p)?;
        if !shuffle && !replace {
            indices.sort_unstable();
        }
        let inner_stride: usize = shape[axis + 1..].iter().product();
        let outer_size: usize = shape[..axis].iter().product();
        let src_block = axis_len * inner_stride;
        // Build the output by appending slices in output order. Unlike
        // pre-filling with `src[0]`, this never indexes into `src` when it is
        // empty (some non-axis dimension is 0).
        let mut out_data: Vec<T> = Vec::with_capacity(outer_size * size * inner_stride);
        for o in 0..outer_size {
            let src_base = o * src_block;
            for &idx in &indices {
                let start = src_base + idx * inner_stride;
                out_data.extend_from_slice(&src[start..start + inner_stride]);
            }
        }
        let mut out_shape = shape;
        out_shape[axis] = size;
        Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)
    }

    /// Returns a copy of `arr` in which every 1-D lane along `axis` is shuffled
    /// independently.
    ///
    /// This differs from [`Self::shuffle_dyn`], which applies one permutation to
    /// whole slices.
    ///
    /// NOTE(review): the offset arithmetic assumes the cloned array's slice is in
    /// row-major (C) order — confirm against `ferray_core`.
    ///
    /// # Errors
    ///
    /// Returns an error if `axis >= arr.ndim()`. When the axis has more than one
    /// position, also returns an error if the copy is not contiguous.
    pub fn permuted_dyn<T>(
        &mut self,
        arr: &Array<T, IxDyn>,
        axis: usize,
    ) -> Result<Array<T, IxDyn>, FerrayError>
    where
        T: ferray_core::Element,
    {
        let shape = arr.shape().to_vec();
        let ndim = shape.len();
        if axis >= ndim {
            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
        }
        let mut out = arr.clone();
        let n = shape[axis];
        if n <= 1 {
            return Ok(out);
        }
        let inner_stride: usize = shape[axis + 1..].iter().product();
        let block = n * inner_stride;
        let outer_size: usize = shape[..axis].iter().product();
        let slice = out
            .as_slice_mut()
            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for permuted"))?;
        for o in 0..outer_size {
            let base = o * block;
            for k in 0..inner_stride {
                // Independent Fisher–Yates over the lane at (outer `o`, offset `k`).
                for i in (1..n).rev() {
                    let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
                    slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
                }
            }
        }
        Ok(out)
    }

    /// Draws `size` elements from a 1-D array.
    ///
    /// * `replace` — sample with (`true`) or without (`false`) replacement.
    /// * `p` — optional weights, one per element. Each weight must be finite and
    ///   non-negative, and the weights must sum to 1 within `1e-6`.
    ///
    /// `size == 0` returns an empty array before any other validation.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// * `arr` is empty and `size > 0`;
    /// * `size` exceeds the length of `arr` without replacement;
    /// * `p` is invalid;
    /// * `arr` is non-contiguous;
    /// * weighted sampling without replacement runs out of probability mass.
    pub fn choice<T>(
        &mut self,
        arr: &Array<T, Ix1>,
        size: usize,
        replace: bool,
        p: Option<&[f64]>,
    ) -> Result<Array<T, Ix1>, FerrayError>
    where
        T: ferray_core::Element,
    {
        let n = arr.shape()[0];
        if size == 0 {
            return Array::from_vec(Ix1::new([0]), Vec::new());
        }
        if n == 0 {
            return Err(FerrayError::invalid_value("source array must be non-empty"));
        }
        if !replace && size > n {
            return Err(FerrayError::invalid_value(format!(
                "cannot choose {size} elements without replacement from array of size {n}"
            )));
        }
        if let Some(probs) = p {
            if probs.len() != n {
                return Err(FerrayError::invalid_value(format!(
                    "p must have same length as array ({n}), got {}",
                    probs.len()
                )));
            }
            validate_probs(probs)?;
        }
        let src = arr
            .as_slice()
            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
        let indices = draw_indices(&mut self.bg, n, size, replace, p)?;
        let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
        Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
    }
}

/// Shuffles `items` in place using Fisher–Yates.
///
/// Draws one bounded integer for each index from `len - 1` down to `1`. Slices
/// of length 0 or 1 consume no randomness.
fn fisher_yates<B: BitGenerator, U>(bg: &mut B, items: &mut [U]) {
    for i in (1..items.len()).rev() {
        let j = bg.next_u64_bounded((i + 1) as u64) as usize;
        items.swap(i, j);
    }
}

/// Checks the values of a probability vector for `choice` and `choice_dyn`.
///
/// Each entry must be finite and non-negative. The entries must sum to 1 within
/// an absolute tolerance of `1e-6`.
///
/// The entries are checked before the sum on purpose. A NaN entry makes the sum
/// NaN, and `NaN > tol` is false, so a sum-only check would accept NaN weights.
fn validate_probs(probs: &[f64]) -> Result<(), FerrayError> {
    for (i, &pi) in probs.iter().enumerate() {
        if !pi.is_finite() {
            return Err(FerrayError::invalid_value(format!(
                "p[{i}] = {pi} is not finite"
            )));
        }
        if pi < 0.0 {
            return Err(FerrayError::invalid_value(format!(
                "p[{i}] = {pi} is negative"
            )));
        }
    }
    let psum: f64 = probs.iter().sum();
    if (psum - 1.0).abs() > 1e-6 {
        return Err(FerrayError::invalid_value(format!(
            "p must sum to 1.0, got {psum}"
        )));
    }
    Ok(())
}

/// Draws `size` indices into `0..n` according to `replace` and the optional weights `p`.
///
/// The caller must already have checked three things:
/// * `n > 0`;
/// * `size <= n` when `!replace`;
/// * `p`, if present, has length `n` and passes `validate_probs`.
fn draw_indices<B: BitGenerator>(
    bg: &mut B,
    n: usize,
    size: usize,
    replace: bool,
    p: Option<&[f64]>,
) -> Result<Vec<usize>, FerrayError> {
    let indices = match (p, replace) {
        (Some(probs), true) => weighted_sample_with_replacement(bg, probs, size),
        (Some(probs), false) => weighted_sample_without_replacement(bg, probs, size)?,
        (None, true) => (0..size)
            .map(|_| bg.next_u64_bounded(n as u64) as usize)
            .collect(),
        (None, false) => sample_without_replacement(bg, n, size),
    };
    Ok(indices)
}
/// Draws `size` distinct indices from `0..n` uniformly using a partial Fisher–Yates shuffle.
///
/// After step `k`, `pool[..=k]` holds the picks in draw order, and the prefix is
/// returned. The caller must ensure `size <= n`.
fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
    let mut pool: Vec<usize> = (0..n).collect();
    (0..size).for_each(|k| {
        // Pick uniformly from the not-yet-drawn tail `pool[k..]`.
        let offset = bg.next_u64_bounded((n - k) as u64) as usize;
        pool.swap(k, k + offset);
    });
    pool.truncate(size);
    pool
}
/// Samples `size` indices with replacement using Vose's alias method.
///
/// Index `i` is drawn with probability proportional to `probs[i]`. `probs` does
/// not need to be normalized, because it is rescaled by its sum. The sum must be
/// positive.
///
/// The alias table is built once in O(n). Each sample then costs one bounded
/// integer draw and one `f64` draw, in that order.
fn weighted_sample_with_replacement<B: BitGenerator>(
    bg: &mut B,
    probs: &[f64],
    size: usize,
) -> Vec<usize> {
    let len = probs.len();
    let len_f = len as f64;
    let mass: f64 = probs.iter().sum();
    // Rescale so that the mean column height is 1.0.
    let mut height: Vec<f64> = probs.iter().map(|&w| w * len_f / mass).collect();
    let mut keep = vec![0.0_f64; len];
    let mut redirect = vec![0_usize; len];
    let (mut short, mut tall): (Vec<usize>, Vec<usize>) =
        (0..len).partition(|&k| height[k] < 1.0);
    // Take the most recently pushed short column and top it up from the most
    // recently pushed tall column. The tall column's leftover height decides
    // which list it rejoins.
    while let (Some(&lo), Some(&hi)) = (short.last(), tall.last()) {
        short.pop();
        tall.pop();
        keep[lo] = height[lo];
        redirect[lo] = hi;
        height[hi] = (height[hi] + height[lo]) - 1.0;
        if height[hi] < 1.0 {
            short.push(hi);
        } else {
            tall.push(hi);
        }
    }
    // Columns still listed are full (up to floating-point error) and never redirect.
    for k in tall.into_iter().chain(short) {
        keep[k] = 1.0;
    }
    let mut picks = Vec::with_capacity(size);
    for _ in 0..size {
        let column = bg.next_u64_bounded(len as u64) as usize;
        let coin = bg.next_f64();
        picks.push(if coin < keep[column] { column } else { redirect[column] });
    }
    picks
}
/// Draws `size` distinct indices, each chosen with probability proportional to its
/// remaining weight.
///
/// After each draw the chosen weight is set to zero, so an index can never be
/// picked twice. Each step is O(n), so the whole draw is O(size · n).
///
/// Weights are assumed to be non-negative; `choice` and `choice_dyn` reject
/// negative `p` before calling this.
///
/// # Errors
///
/// Returns an error when the remaining total weight is not a positive number.
/// This happens when fewer than `size` entries have positive weight, or when the
/// weights contain NaN.
fn weighted_sample_without_replacement<B: BitGenerator>(
    bg: &mut B,
    probs: &[f64],
    size: usize,
) -> Result<Vec<usize>, FerrayError> {
    let mut weights: Vec<f64> = probs.to_vec();
    let mut selected = Vec::with_capacity(size);
    for _ in 0..size {
        let total: f64 = weights.iter().sum();
        // Check `is_nan` explicitly: `NaN <= 0.0` is false, so NaN would slip past.
        if total.is_nan() || total <= 0.0 {
            return Err(FerrayError::invalid_value(
                "insufficient probability mass for sampling without replacement",
            ));
        }
        let u = bg.next_f64() * total;
        let mut cumsum = 0.0;
        let chosen = weights
            .iter()
            .position(|&w| {
                cumsum += w;
                cumsum > u
            })
            // Reached only if rounding leaves `u >= total`, for example when the
            // generator can return 1.0. Fall back to the last entry that still has
            // weight, rather than blindly to `n - 1`. Index `n - 1` may already be
            // selected or have zero probability.
            .or_else(|| weights.iter().rposition(|&w| w > 0.0))
            .ok_or_else(|| {
                FerrayError::invalid_value(
                    "insufficient probability mass for sampling without replacement",
                )
            })?;
        selected.push(chosen);
        weights[chosen] = 0.0;
    }
    Ok(selected)
}
#[cfg(test)]
mod tests {
    // Unit tests for shuffling, permutation and choice on `Generator`.
    // Every test uses a fixed seed so failures are reproducible.
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1, IxDyn};

    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        rng.shuffle(&mut arr).unwrap();
        let mut sorted: Vec<i64> = arr.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let perm = rng.permutation(&arr).unwrap();
        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted, expected);
    }

    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), original.clone()).unwrap();
        rng.shuffle(&mut arr).unwrap();
        let mut shuffled = arr.as_slice().unwrap().to_vec();
        // The identity permutation of 10 elements has probability 1/10!, so the
        // order should have changed in place.
        assert_ne!(shuffled, original, "shuffle left the array in its original order");
        shuffled.sort_unstable();
        assert_eq!(shuffled, original);
    }

    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        for &v in chosen.as_slice().unwrap() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in slice {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        assert!(rng.choice(&arr, 10, false, None).is_err());
    }

    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for &v in chosen.as_slice().unwrap() {
            assert_eq!(v, 30);
        }
    }

    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in slice {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        // Wrong length, wrong sum, negative entry.
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
    }

    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let result = rng.permuted(&arr, 0).unwrap();
        let mut sorted: Vec<i64> = result.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn weighted_with_replacement_alias_distribution_recovers_probs() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        let p = [0.05, 0.15, 0.30, 0.40, 0.10];
        let n = 100_000;
        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
        let mut counts = [0_usize; 5];
        for &v in chosen.as_slice().unwrap() {
            counts[v as usize] += 1;
        }
        for (i, &c) in counts.iter().enumerate() {
            let observed = c as f64 / n as f64;
            assert!(
                (observed - p[i]).abs() < 0.015,
                "bin {i}: observed {observed}, expected {}",
                p[i]
            );
        }
    }

    #[test]
    fn choice_dyn_axis0_picks_whole_rows() {
        let mut rng = default_rng_seeded(42);
        let data: Vec<i64> = (0..5)
            .flat_map(|i| (0..3).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[5, 3]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 4, true, None, 0, true).unwrap();
        assert_eq!(chosen.shape(), &[4, 3]);
        let slice = chosen.as_slice().unwrap();
        for row in 0..4 {
            let v0 = slice[row * 3];
            let id = v0 / 100;
            assert!((0..5).contains(&id));
            assert_eq!(slice[row * 3 + 1], id * 100 + 1);
            assert_eq!(slice[row * 3 + 2], id * 100 + 2);
        }
    }

    #[test]
    fn choice_dyn_axis1_picks_whole_columns() {
        let mut rng = default_rng_seeded(7);
        let data: Vec<i64> = (0..3)
            .flat_map(|i| (0..6).map(move |j| i * 10 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 6]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 2, false, None, 1, true).unwrap();
        assert_eq!(chosen.shape(), &[3, 2]);
        let slice = chosen.as_slice().unwrap();
        for col in 0..2 {
            let v0 = slice[col];
            let v1 = slice[2 + col];
            let v2 = slice[4 + col];
            assert!((0..6).contains(&v0));
            assert_eq!(v1, v0 + 10);
            assert_eq!(v2, v0 + 20);
        }
    }

    #[test]
    fn choice_dyn_without_replacement_no_duplicate_rows() {
        let mut rng = default_rng_seeded(1);
        let data: Vec<i64> = (0..10)
            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[10, 2]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 5, false, None, 0, true).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut ids = std::collections::HashSet::new();
        for row in 0..5 {
            let id = slice[row * 2] / 100;
            assert!(ids.insert(id), "row id {id} repeated under replace=false");
        }
    }

    #[test]
    fn choice_dyn_shuffle_false_returns_sorted_indices() {
        let mut rng = default_rng_seeded(3);
        let data: Vec<i64> = (0..12)
            .flat_map(|i| (0..2).map(move |j| if j == 0 { i as i64 } else { i as i64 * 10 }))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[12, 2]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 6, false, None, 0, false).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut last = -1i64;
        for row in 0..6 {
            let id = slice[row * 2];
            assert!(
                id > last,
                "shuffle=false output not ascending: {id} after {last}"
            );
            last = id;
        }
    }

    #[test]
    fn choice_dyn_weighted_concentrates_on_high_p() {
        let mut rng = default_rng_seeded(0);
        let data: Vec<i64> = (0..4)
            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 2]), data).unwrap();
        let p = [0.0, 0.0, 1.0, 0.0];
        let chosen = rng.choice_dyn(&arr, 20, true, Some(&p), 0, true).unwrap();
        let slice = chosen.as_slice().unwrap();
        for row in 0..20 {
            assert_eq!(slice[row * 2], 200, "weighted choice strayed from row 2");
        }
    }

    #[test]
    fn choice_dyn_size_zero_returns_empty_axis() {
        let mut rng = default_rng_seeded(11);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), (0..12).collect()).unwrap();
        let chosen = rng.choice_dyn(&arr, 0, true, None, 0, true).unwrap();
        assert_eq!(chosen.shape(), &[0, 4]);
    }

    #[test]
    fn choice_dyn_bad_axis() {
        let mut rng = default_rng_seeded(0);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), (0..6).collect()).unwrap();
        assert!(rng.choice_dyn(&arr, 1, true, None, 5, true).is_err());
    }

    #[test]
    fn choice_dyn_too_many_no_replace_errors() {
        let mut rng = default_rng_seeded(0);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 2]), (0..6).collect()).unwrap();
        assert!(rng.choice_dyn(&arr, 5, false, None, 0, true).is_err());
    }

    #[test]
    fn shuffle_dyn_axis0_swaps_whole_rows() {
        let mut rng = default_rng_seeded(42);
        let data: Vec<i64> = (0..4)
            .flat_map(|i| (0..3).map(move |j| i * 10 + j))
            .collect();
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 3]), data).unwrap();
        rng.shuffle_dyn(&mut arr, 0).unwrap();
        let slice = arr.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for row in 0..4 {
            let row_first = slice[row * 3];
            let id = row_first / 10;
            assert!(
                (0..4).contains(&id),
                "row {row} starts with unexpected value {row_first}"
            );
            assert_eq!(slice[row * 3 + 1], id * 10 + 1);
            assert_eq!(slice[row * 3 + 2], id * 10 + 2);
            assert!(
                seen.insert(id),
                "row id {id} duplicated — shuffle broke a row"
            );
        }
    }

    #[test]
    fn shuffle_dyn_axis1_swaps_whole_columns() {
        let mut rng = default_rng_seeded(7);
        let data: Vec<i64> = (0..3)
            .flat_map(|i| (0..4).map(move |j| i * 10 + j))
            .collect();
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), data).unwrap();
        rng.shuffle_dyn(&mut arr, 1).unwrap();
        let slice = arr.as_slice().unwrap();
        let mut col_ids = Vec::new();
        for col in 0..4 {
            let v0 = slice[col];
            let v1 = slice[4 + col];
            let v2 = slice[8 + col];
            assert!((0..4).contains(&v0));
            assert_eq!(v1, v0 + 10);
            assert_eq!(v2, v0 + 20);
            col_ids.push(v0);
        }
        col_ids.sort_unstable();
        assert_eq!(col_ids, vec![0, 1, 2, 3]);
    }

    #[test]
    fn shuffle_dyn_axis_out_of_bounds() {
        let mut rng = default_rng_seeded(0);
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0; 6]).unwrap();
        assert!(rng.shuffle_dyn(&mut arr, 2).is_err());
    }

    #[test]
    fn permuted_dyn_axis0_each_column_independent() {
        let mut rng = default_rng_seeded(99);
        let n_rows = 5;
        let n_cols = 4;
        let data: Vec<i64> = (0..n_rows * n_cols).map(|x| x as i64).collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data).unwrap();
        let result = rng.permuted_dyn(&arr, 0).unwrap();
        let slice = result.as_slice().unwrap();
        for col in 0..n_cols {
            // Built with ascending `r`, so this is already sorted.
            let want: Vec<i64> = (0..n_rows).map(|r| (r * n_cols + col) as i64).collect();
            let mut got_col: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols + col]).collect();
            got_col.sort_unstable();
            assert_eq!(got_col, want, "col {col} lost values during permute");
        }
    }

    #[test]
    fn permuted_dyn_columns_can_diverge() {
        let mut rng = default_rng_seeded(1234);
        let n_rows = 5;
        let n_cols = 4;
        let data: Vec<i64> = (0..n_rows * n_cols)
            .map(|x| x as i64 % n_rows as i64)
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data).unwrap();
        let result = rng.permuted_dyn(&arr, 0).unwrap();
        let slice = result.as_slice().unwrap();
        let column =
            |col: usize| -> Vec<i64> { (0..n_rows).map(|r| slice[r * n_cols + col]).collect() };
        let col0 = column(0);
        let any_diff = (1..n_cols).any(|col| column(col) != col0);
        assert!(
            any_diff,
            "all columns matched — permuted didn't independently shuffle"
        );
    }

    #[test]
    fn permuted_dyn_seed_reproducible() {
        let mut a = default_rng_seeded(31);
        let mut b = default_rng_seeded(31);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 3]), (0..9).collect()).unwrap();
        let xa = a.permuted_dyn(&arr, 1).unwrap();
        let xb = b.permuted_dyn(&arr, 1).unwrap();
        assert_eq!(xa.as_slice().unwrap(), xb.as_slice().unwrap());
    }

    // `choice` rejects weights that do not sum to 1, so this exercises a second,
    // three-bin normalized distribution rather than unnormalized weights.
    #[test]
    fn weighted_with_replacement_three_bins_recovers_probs() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![0, 1, 2]).unwrap();
        let p = [0.2, 0.5, 0.3];
        let n = 50_000;
        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
        let mut counts = [0_usize; 3];
        for &v in chosen.as_slice().unwrap() {
            counts[v as usize] += 1;
        }
        for (i, &c) in counts.iter().enumerate() {
            let observed = c as f64 / n as f64;
            assert!(
                (observed - p[i]).abs() < 0.02,
                "bin {i}: observed {observed}, expected {}",
                p[i]
            );
        }
    }
}