use rand::Rng;
pub fn create_rng(seed: u64) -> rand::rngs::SmallRng {
use rand::SeedableRng;
rand::rngs::SmallRng::seed_from_u64(seed)
}
/// Shuffles `slice` in place using the Fisher–Yates (Durstenfeld) algorithm.
///
/// Positions are fixed from the back: for each `pos` from `len - 1` down to `1`, the element
/// at `pos` is swapped with one drawn from `0..=pos`. Slices of length 0 or 1 are left
/// untouched and draw nothing from `rng`.
pub fn shuffle<T, R: Rng>(slice: &mut [T], rng: &mut R) {
    // Starting one past the end and decrementing before use visits len-1, ..., 1;
    // for len 0 or 1 the loop body never runs.
    let mut pos = slice.len();
    while pos > 1 {
        pos -= 1;
        let other = rng.random_range(0..=pos);
        slice.swap(pos, other);
    }
}
/// Returns the indices `0..n` in a random order produced by [`shuffle`].
///
/// The result always has length `n` and contains each index in `0..n` exactly once.
pub fn shuffled_indices<R: Rng>(n: usize, rng: &mut R) -> Vec<usize> {
    let mut perm = Vec::with_capacity(n);
    perm.extend(0..n);
    shuffle(&mut perm, rng);
    perm
}
/// Picks a random index into `weights`, with probability proportional to each weight.
///
/// Only strictly positive weights take part. Zero, negative and NaN weights are ignored,
/// and their indices are never returned. Each call scans `weights` linearly. For repeated
/// draws from the same weights, prefer [`WeightedSampler`].
///
/// Returns `None` when `weights` is empty, when no weight is strictly positive, or when the
/// positive weights sum to a non-finite value (an infinite weight, or overflow). In those
/// cases there is no valid range to draw from.
pub fn weighted_choose<R: Rng>(weights: &[f64], rng: &mut R) -> Option<usize> {
    // `w > 0.0` is false for NaN, so NaN weights are dropped here along with zero/negative ones.
    // An empty slice sums to zero and is rejected by the check below.
    let total: f64 = weights.iter().filter(|w| **w > 0.0).sum();
    // A non-finite total would make `random_range` panic, so treat it as "no valid choice".
    if total <= 0.0 || !total.is_finite() {
        return None;
    }
    let threshold = rng.random_range(0.0..total);
    let mut cumulative = 0.0;
    let mut last_positive = None;
    for (i, &w) in weights.iter().enumerate() {
        if w > 0.0 {
            cumulative += w;
            if cumulative > threshold {
                return Some(i);
            }
            last_positive = Some(i);
        }
    }
    // Reachable only if floating-point rounding leaves `cumulative` at or below `threshold`
    // (e.g. the draw rounded up to `total`). Fall back to the last positive-weight index.
    // `weights.len() - 1` could be a zero-weight entry. `last_positive` is `Some` here
    // because `total > 0.0` guarantees at least one positive weight.
    last_positive
}
/// Precomputed sampler that draws indices in proportion to a fixed set of weights.
///
/// Construction is a single pass over the weights. Each [`sample`](WeightedSampler::sample)
/// is a binary search over the cumulative sums. Weights that are not strictly positive
/// (zero, negative, NaN) are treated as zero, and their indices are never drawn.
#[derive(Debug, Clone)]
pub struct WeightedSampler {
    /// Running sum of the positive weights: `cumulative[i]` is the sum over indices `0..=i`.
    /// Non-decreasing, one entry per input weight, and its last element equals `total`.
    cumulative: Vec<f64>,
    /// Sum of all positive weights. Always finite and `> 0.0` for a constructed sampler.
    total: f64,
}

impl WeightedSampler {
    /// Builds a sampler from `weights`.
    ///
    /// Returns `None` when `weights` is empty, when no weight is strictly positive, or when
    /// the positive weights sum to a non-finite value (an infinite weight, or overflow).
    /// No valid draw range exists in those cases.
    pub fn new(weights: &[f64]) -> Option<Self> {
        if weights.is_empty() {
            return None;
        }
        let mut cumulative = Vec::with_capacity(weights.len());
        let mut total = 0.0;
        for &w in weights {
            // `w > 0.0` is false for NaN, so NaN weights contribute nothing.
            if w > 0.0 {
                total += w;
            }
            cumulative.push(total);
        }
        // Rejecting a non-finite total here keeps `sample` from panicking inside `random_range`.
        if total <= 0.0 || !total.is_finite() {
            return None;
        }
        Some(Self { cumulative, total })
    }

    /// Draws one index, with probability proportional to its weight.
    ///
    /// Only indices whose weight was strictly positive can be returned.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> usize {
        let threshold = rng.random_range(0.0..self.total);
        // Find the first index whose cumulative sum exceeds `threshold`. Index `i` owns the
        // half-open interval [cumulative[i-1], cumulative[i]). A threshold exactly on a
        // boundary, or on a plateau created by zero weights, must therefore resolve to the
        // next positive-weight index. An exact-match binary search would instead return an
        // arbitrary index among equal values.
        let idx = self.cumulative.partition_point(|&c| c <= threshold);
        if idx < self.cumulative.len() {
            idx
        } else {
            // Defensive: only reachable if the draw rounded up to `total`. Choose the first
            // index whose cumulative sum reaches `total`, which is the last index with an
            // effective positive weight. `len - 1` could be a trailing zero-weight entry.
            self.cumulative.partition_point(|&c| c < self.total)
        }
    }

    /// Number of weights the sampler was built from, including zero/non-positive ones.
    pub fn len(&self) -> usize {
        self.cumulative.len()
    }

    /// Whether the sampler holds no weights.
    ///
    /// Always `false` for a sampler returned by [`new`](WeightedSampler::new), which rejects
    /// empty input.
    pub fn is_empty(&self) -> bool {
        self.cumulative.is_empty()
    }

    /// Sum of the strictly positive weights.
    pub fn total_weight(&self) -> f64 {
        self.total
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the shuffling and weighted-selection helpers. Every RNG is seeded via
    // `create_rng`, so results are reproducible run to run.
    use super::*;

    #[test]
    fn test_create_rng_deterministic() {
        let mut rng1 = create_rng(42);
        let mut rng2 = create_rng(42);
        let vals1: Vec<f64> = (0..10).map(|_| rng1.random()).collect();
        let vals2: Vec<f64> = (0..10).map(|_| rng2.random()).collect();
        assert_eq!(vals1, vals2);
    }

    #[test]
    fn test_shuffle_preserves_elements() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut rng = create_rng(123);
        shuffle(&mut v, &mut rng);
        v.sort();
        assert_eq!(v, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_empty() {
        let mut v: Vec<i32> = vec![];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng);
        assert!(v.is_empty());
    }

    #[test]
    fn test_shuffle_single() {
        let mut v = vec![42];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng);
        assert_eq!(v, vec![42]);
    }

    #[test]
    fn test_shuffle_actually_shuffles() {
        // The seed is fixed, so this is deterministic. It guards against `shuffle`
        // degenerating into a no-op.
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut v = original.clone();
        let mut rng = create_rng(42);
        shuffle(&mut v, &mut rng);
        assert_ne!(v, original, "shuffle should change order (probabilistic)");
    }

    #[test]
    fn test_shuffled_indices() {
        let mut rng = create_rng(42);
        let indices = shuffled_indices(10, &mut rng);
        assert_eq!(indices.len(), 10);
        let mut sorted = indices.clone();
        sorted.sort();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_shuffled_indices_zero() {
        let mut rng = create_rng(7);
        assert!(shuffled_indices(0, &mut rng).is_empty());
    }

    #[test]
    fn test_weighted_choose_basic() {
        let mut rng = create_rng(42);
        let weights = [0.0, 0.0, 1.0];
        for _ in 0..100 {
            assert_eq!(weighted_choose(&weights, &mut rng), Some(2));
        }
    }

    #[test]
    fn test_weighted_choose_ignores_non_positive_weights() {
        // Negative, zero and NaN weights must never be selected.
        let mut rng = create_rng(42);
        let weights = [-1.0, 0.0, f64::NAN, 2.0, 0.0];
        for _ in 0..100 {
            assert_eq!(weighted_choose(&weights, &mut rng), Some(3));
        }
    }

    #[test]
    fn test_weighted_choose_empty() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_all_zero() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[0.0, 0.0], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_distribution() {
        let mut rng = create_rng(42);
        let weights = [1.0, 3.0];
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            let idx = weighted_choose(&weights, &mut rng).unwrap();
            counts[idx] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_basic() {
        let sampler = WeightedSampler::new(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(sampler.len(), 3);
        assert!(!sampler.is_empty());
        assert!((sampler.total_weight() - 6.0).abs() < 1e-15);
    }

    #[test]
    fn test_weighted_sampler_negative_weights_ignored() {
        let sampler = WeightedSampler::new(&[-1.0, 2.0]).unwrap();
        assert_eq!(sampler.len(), 2);
        assert!((sampler.total_weight() - 2.0).abs() < 1e-15);
    }

    #[test]
    fn test_weighted_sampler_deterministic_weight() {
        let sampler = WeightedSampler::new(&[0.0, 0.0, 1.0]).unwrap();
        let mut rng = create_rng(42);
        for _ in 0..100 {
            assert_eq!(sampler.sample(&mut rng), 2);
        }
    }

    #[test]
    fn test_weighted_sampler_skips_zero_weights() {
        // Zero weights create plateaus in the cumulative sums; none of those indices may be drawn.
        let weights = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0];
        let sampler = WeightedSampler::new(&weights).unwrap();
        let mut rng = create_rng(7);
        for _ in 0..1000 {
            let idx = sampler.sample(&mut rng);
            assert!(weights[idx] > 0.0, "sampled zero-weight index {idx}");
        }
    }

    #[test]
    fn test_weighted_sampler_distribution() {
        let sampler = WeightedSampler::new(&[1.0, 3.0]).unwrap();
        let mut rng = create_rng(42);
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            counts[sampler.sample(&mut rng)] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_empty() {
        assert!(WeightedSampler::new(&[]).is_none());
    }

    #[test]
    fn test_weighted_sampler_all_zero() {
        assert!(WeightedSampler::new(&[0.0, 0.0]).is_none());
    }
}
#[cfg(test)]
mod proptests {
    // Property tests: shuffles are permutations, and weighted selection never yields an
    // index whose weight is not strictly positive.
    use super::*;
    use proptest::prelude::*;

    /// Weight vectors that frequently contain exact zeros (including all-zero vectors), so
    /// zero-weight handling is actually exercised. A plain `0.0..10.0` range almost never
    /// produces an exact zero.
    fn weights_strategy() -> impl Strategy<Value = Vec<f64>> {
        proptest::collection::vec(prop_oneof![Just(0.0_f64), 0.0_f64..10.0], 1..20)
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        #[test]
        fn shuffle_is_permutation(
            seed in 0_u64..10000,
            data in proptest::collection::vec(0_i32..1000, 0..50),
        ) {
            let mut shuffled = data.clone();
            let mut rng = create_rng(seed);
            shuffle(&mut shuffled, &mut rng);
            let mut sorted_orig = data.clone();
            let mut sorted_shuf = shuffled;
            sorted_orig.sort();
            sorted_shuf.sort();
            prop_assert_eq!(sorted_orig, sorted_shuf);
        }

        #[test]
        fn weighted_choose_returns_valid_index(
            seed in 0_u64..10000,
            weights in weights_strategy(),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            let mut rng = create_rng(seed);
            match weighted_choose(&weights, &mut rng) {
                Some(idx) => {
                    prop_assert!(idx < weights.len());
                    prop_assert!(weights[idx] > 0.0, "picked zero-weight index {}", idx);
                }
                None => {
                    prop_assert!(!has_positive, "returned None despite a positive weight");
                }
            }
        }

        #[test]
        fn weighted_sampler_returns_valid_index(
            seed in 0_u64..10000,
            weights in weights_strategy(),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            match WeightedSampler::new(&weights) {
                Some(sampler) => {
                    let mut rng = create_rng(seed);
                    for _ in 0..16 {
                        let idx = sampler.sample(&mut rng);
                        prop_assert!(idx < weights.len());
                        prop_assert!(weights[idx] > 0.0, "sampled zero-weight index {}", idx);
                    }
                }
                None => {
                    prop_assert!(!has_positive, "new() rejected weights with a positive entry");
                }
            }
        }
    }
}