#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
use scirs2_core::random::Random;
/// Small helpers around the seeded random number generator used by the samplers.
pub(crate) mod rng_utils {
    use super::*;

    /// Builds a seeded generator.
    ///
    /// When `seed` is `None` the fixed seed `42` is used, so the returned
    /// generator is deterministic in both cases.
    pub fn create_rng(seed: Option<u64>) -> Random<scirs2_core::rngs::StdRng> {
        Random::seed(seed.unwrap_or(42))
    }

    /// Shuffles `indices` in place, walking from the last element down to the
    /// second and swapping each with a position drawn from `0..=i`.
    ///
    /// The generator is created by [`create_rng`], so the same `seed` (or
    /// `None`) always produces the same permutation for a given length.
    /// Slices with fewer than two elements are left untouched.
    pub fn shuffle_indices<T>(indices: &mut [T], seed: Option<u64>) {
        let mut rng = create_rng(seed);
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }
    }

    /// Draws one `usize` from `range` using the supplied generator.
    ///
    /// The range type is spelled through `core` rather than `std` so the
    /// signature also resolves when the crate is built without `std`.
    pub fn gen_range(
        rng: &mut Random<scirs2_core::rngs::StdRng>,
        range: core::ops::Range<usize>,
    ) -> usize {
        rng.gen_range(range)
    }
}
/// A source of dataset indices.
///
/// Implementors must be `Send`, and so must the iterator they hand out.
pub trait Sampler: Send {
/// Iterator over `usize` indices returned by [`Sampler::iter`].
type Iter: Iterator<Item = usize> + Send;
/// Returns an iterator over the sampler's indices.
fn iter(&self) -> Self::Iter;
/// Number of indices reported by this sampler.
// NOTE(review): presumably equals the number of items `iter()` yields — confirm against implementors.
fn len(&self) -> usize;
/// Returns `true` when [`Sampler::len`] is zero.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Consumes the sampler and wraps it in a `BatchingSampler`, forwarding
/// `batch_size` and `drop_last` unchanged to `BatchingSampler::new`.
// NOTE(review): `drop_last` presumably discards a trailing short batch — its meaning is defined in `super::batch`, confirm there.
fn into_batch_sampler(
self,
batch_size: usize,
drop_last: bool,
) -> super::batch::BatchingSampler<Self>
where
Self: Sized,
{
super::batch::BatchingSampler::new(self, batch_size, drop_last)
}
/// Consumes the sampler and wraps it in a `DistributedWrapper`, forwarding
/// `num_replicas` and `rank` unchanged to `DistributedWrapper::new`.
// NOTE(review): how indices are partitioned across replicas is defined in `super::distributed` — not visible here.
fn into_distributed(
self,
num_replicas: usize,
rank: usize,
) -> super::distributed::DistributedWrapper<Self>
where
Self: Sized,
{
super::distributed::DistributedWrapper::new(self, num_replicas, rank)
}
}
/// A source of index batches, each batch being a `Vec<usize>`.
///
/// Implementors must be `Send`, and so must the iterator they hand out.
pub trait BatchSampler: Send {
/// Iterator over batches returned by [`BatchSampler::iter`].
type Iter: Iterator<Item = Vec<usize>> + Send;
/// Returns an iterator over the batches.
fn iter(&self) -> Self::Iter;
/// Number of batches reported by this sampler.
// NOTE(review): presumably equals the number of items `iter()` yields — confirm against implementors.
fn num_batches(&self) -> usize;
/// Defaults to [`BatchSampler::num_batches`]; the length is counted in batches, not indices.
fn len(&self) -> usize {
self.num_batches()
}
/// Returns `true` when [`BatchSampler::num_batches`] is zero.
// Reads `num_batches()` directly, so overriding `len` alone does not change this default.
fn is_empty(&self) -> bool {
self.num_batches() == 0
}
}
/// Iterator that walks a pre-computed list of indices from front to back.
///
/// Derives `Debug` so it can be inspected in logs and test failures, and
/// `Clone` so an iteration can be forked at its current position.
#[derive(Debug, Clone)]
pub struct SamplerIterator {
    /// Indices to yield, in order.
    indices: Vec<usize>,
    /// Offset into `indices` of the next item to yield.
    position: usize,
}
impl SamplerIterator {
pub fn new(indices: Vec<usize>) -> Self {
Self {
indices,
position: 0,
}
}
pub fn from_range(start: usize, end: usize) -> Self {
Self::new((start..end).collect())
}
pub fn shuffled(mut indices: Vec<usize>, seed: Option<u64>) -> Self {
let mut rng = match seed {
Some(s) => Random::seed(s),
None => Random::seed(42), };
for i in (1..indices.len()).rev() {
let j = rng.gen_range(0..=i);
indices.swap(i, j);
}
Self::new(indices)
}
pub fn remaining(&self) -> usize {
self.indices.len() - self.position
}
}
impl Iterator for SamplerIterator {
    type Item = usize;

    /// Yields the index at the current position and advances, or `None` once
    /// every index has been consumed.
    fn next(&mut self) -> Option<Self::Item> {
        let current = *self.indices.get(self.position)?;
        self.position += 1;
        Some(current)
    }

    /// Exact bounds: both ends equal the number of indices still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining();
        (left, Some(left))
    }
}
impl ExactSizeIterator for SamplerIterator {
fn len(&self) -> usize {
self.remaining()
}
}
pub mod utils {
use super::*;
pub fn random_indices(n: usize, k: usize, seed: Option<u64>) -> Vec<usize> {
assert!(k <= n, "Cannot sample more items than available");
let mut rng = match seed {
Some(s) => Random::seed(s),
None => Random::seed(42),
};
if k == n {
let mut indices: Vec<usize> = (0..n).collect();
for i in (1..indices.len()).rev() {
let j = rng.gen_range(0..=i);
indices.swap(i, j);
}
indices
} else if k <= n / 2 {
let mut selected = std::collections::HashSet::new();
while selected.len() < k {
let idx = rng.gen_range(0..n);
selected.insert(idx);
}
let mut result: Vec<usize> = selected.into_iter().collect();
result.sort_unstable(); result
} else {
let mut excluded = std::collections::HashSet::new();
while excluded.len() < n - k {
let idx = rng.gen_range(0..n);
excluded.insert(idx);
}
let mut result: Vec<usize> = (0..n).filter(|&i| !excluded.contains(&i)).collect();
result.sort_unstable(); result
}
}
pub fn stratified_split(
indices: &[usize],
labels: &[usize],
test_ratio: f32,
seed: Option<u64>,
) -> (Vec<usize>, Vec<usize>) {
use std::collections::HashMap;
let mut label_groups: HashMap<usize, Vec<usize>> = HashMap::new();
for &idx in indices {
if idx < labels.len() {
label_groups
.entry(labels[idx])
.or_insert_with(Vec::new)
.push(idx);
}
}
let mut rng = match seed {
Some(s) => Random::seed(s),
None => Random::seed(42),
};
let mut train_indices = Vec::new();
let mut test_indices = Vec::new();
for (_, mut group_indices) in label_groups {
for i in (1..group_indices.len()).rev() {
let j = rng.gen_range(0..=i);
group_indices.swap(i, j);
}
let test_size = ((group_indices.len() as f32) * test_ratio).round() as usize;
let test_size = test_size.min(group_indices.len());
test_indices.extend(group_indices.iter().take(test_size));
train_indices.extend(group_indices.iter().skip(test_size));
}
(train_indices, test_indices)
}
/// Computes one weight per class as `total / (num_classes * count)`.
///
/// `total` is `labels.len()`. Labels `>= num_classes` are left out of the
/// per-class counts but still contribute to `total`. A class with no samples
/// gets weight `0.0`. The result has exactly `num_classes` entries.
pub fn calculate_class_weights(labels: &[usize], num_classes: usize) -> Vec<f32> {
    let mut tallies = vec![0usize; num_classes];
    for &label in labels {
        if let Some(slot) = tallies.get_mut(label) {
            *slot += 1;
        }
    }

    let total = labels.len() as f32;
    let classes = num_classes as f32;
    tallies
        .into_iter()
        .map(|seen| match seen {
            0 => 0.0,
            seen => total / (classes * seen as f32),
        })
        .collect()
}
/// Resolves and validates the number of samples to draw from a dataset.
///
/// `num_samples` defaults to `dataset_size` when `None`.
///
/// # Errors
///
/// Returns a message when:
/// * the dataset is empty and a non-zero sample count is requested;
/// * sampling without replacement and more samples than items are requested;
/// * sampling without replacement from a non-empty dataset with zero samples.
pub fn validate_sampling_params(
    dataset_size: usize,
    num_samples: Option<usize>,
    replacement: bool,
) -> Result<usize, String> {
    let requested = num_samples.unwrap_or(dataset_size);

    match (dataset_size, requested, replacement) {
        // Empty dataset: only a zero-sized request is acceptable.
        (0, 0, _) => Ok(0),
        (0, _, _) => Err("Cannot sample from empty dataset".to_string()),
        // With replacement any count is fine, including zero.
        (_, _, true) => Ok(requested),
        (size, wanted, false) if wanted > size => Err(format!(
            "Cannot sample {} items without replacement from dataset of size {}",
            wanted, size
        )),
        (_, 0, false) => Err(
            "Number of samples cannot be zero for non-empty dataset without replacement"
                .to_string(),
        ),
        _ => Ok(requested),
    }
}
/// Randomly partitions `0..dataset_size` into `(train, validation)`.
///
/// The validation set receives `round(dataset_size * val_ratio)` indices
/// taken from the front of a seeded permutation; the training set receives
/// the rest. The size is capped at `dataset_size`, so a ratio above `1.0`
/// (or float rounding on very large datasets) puts everything in the
/// validation set instead of panicking. A negative or NaN ratio becomes 0 in
/// the float-to-int cast, giving an empty validation set.
pub fn train_val_split(
    dataset_size: usize,
    val_ratio: f32,
    seed: Option<u64>,
) -> (Vec<usize>, Vec<usize>) {
    let val_size = ((dataset_size as f32 * val_ratio).round() as usize).min(dataset_size);
    let mut val_indices = random_indices(dataset_size, dataset_size, seed);
    // `split_off` leaves the first `val_size` entries behind and returns the tail.
    let train_indices = val_indices.split_off(val_size);
    (train_indices, val_indices)
}
/// Produces `k` `(train, validation)` splits over a seeded permutation of
/// `0..dataset_size`.
///
/// Every fold holds `dataset_size / k` validation indices, except the last
/// one, which also takes any remainder. The training part of each split is
/// everything outside that fold, in permutation order.
///
/// # Panics
///
/// Panics if `k <= 1` or `k > dataset_size`.
pub fn kfold_splits(
    dataset_size: usize,
    k: usize,
    seed: Option<u64>,
) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(k > 1, "K must be greater than 1");
    assert!(k <= dataset_size, "K cannot be larger than dataset size");

    let shuffled = random_indices(dataset_size, dataset_size, seed);
    let fold_size = dataset_size / k;

    (0..k)
        .map(|fold| {
            let start = fold * fold_size;
            let end = if fold + 1 == k {
                dataset_size
            } else {
                start + fold_size
            };
            let (before, rest) = shuffled.split_at(start);
            let (held_out, after) = rest.split_at(end - start);

            let mut train = Vec::with_capacity(before.len() + after.len());
            train.extend_from_slice(before);
            train.extend_from_slice(after);
            (train, held_out.to_vec())
        })
        .collect()
}
/// Randomly partitions `0..dataset_size` into `(train, validation, test)`.
///
/// A seeded permutation is cut into three consecutive pieces: the first
/// `round(dataset_size * train_ratio)` indices for training, the next
/// `round(dataset_size * val_ratio)` for validation, and whatever is left for
/// testing. Both sizes are capped to the indices actually available, so float
/// rounding can never push a slice boundary past the end of the permutation.
/// On small datasets any of the three parts may be empty.
///
/// # Panics
///
/// Panics if `train_ratio + val_ratio >= 1.0`, or if either ratio is not
/// strictly positive (NaN fails both checks).
pub fn train_val_test_split(
    dataset_size: usize,
    train_ratio: f32,
    val_ratio: f32,
    seed: Option<u64>,
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    assert!(
        train_ratio + val_ratio < 1.0,
        "Train and val ratios must sum to less than 1.0"
    );
    assert!(
        train_ratio > 0.0 && val_ratio > 0.0,
        "Ratios must be positive"
    );

    let train_size = ((dataset_size as f32 * train_ratio).round() as usize).min(dataset_size);
    let val_size =
        ((dataset_size as f32 * val_ratio).round() as usize).min(dataset_size - train_size);

    // Peel the pieces off one permutation: each `split_off` keeps the head
    // and hands back the tail.
    let mut train_indices = random_indices(dataset_size, dataset_size, seed);
    let mut val_indices = train_indices.split_off(train_size);
    let test_indices = val_indices.split_off(val_size);
    (train_indices, val_indices, test_indices)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built iterator reports its full length and yields the
    /// indices in their original order.
    #[test]
    fn test_sampler_iterator_basic() {
        let expected = vec![0, 1, 2, 3, 4];
        let iter = SamplerIterator::new(expected.clone());
        assert_eq!(iter.len(), 5);
        assert_eq!(iter.remaining(), 5);
        assert_eq!(iter.collect::<Vec<usize>>(), expected);
    }

    /// `from_range` covers the half-open range.
    #[test]
    fn test_sampler_iterator_from_range() {
        let yielded: Vec<usize> = SamplerIterator::from_range(0, 5).collect();
        assert_eq!(yielded, vec![0, 1, 2, 3, 4]);
    }

    /// Shuffling keeps every index and the overall count.
    #[test]
    fn test_sampler_iterator_shuffled() {
        let original = vec![0, 1, 2, 3, 4];
        let yielded: Vec<usize> =
            SamplerIterator::shuffled(original.clone(), Some(42)).collect();
        assert_eq!(yielded.len(), original.len());
        for idx in original.iter() {
            assert!(yielded.contains(idx));
        }
    }

    /// A partial draw returns distinct, in-range indices.
    #[test]
    fn test_utils_random_indices() {
        let picked = utils::random_indices(10, 5, Some(42));
        assert_eq!(picked.len(), 5);

        let mut unique = picked.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 5);

        for idx in picked.iter() {
            assert!(*idx < 10);
        }
    }

    /// Drawing everything returns a permutation of the full range.
    #[test]
    fn test_utils_random_indices_all() {
        let picked = utils::random_indices(5, 5, Some(42));
        assert_eq!(picked.len(), 5);

        let mut ordered = picked.clone();
        ordered.sort();
        assert_eq!(ordered, vec![0, 1, 2, 3, 4]);
    }

    /// The stratified split partitions the input without losing indices.
    #[test]
    fn test_utils_stratified_split() {
        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
        let (train, test) = utils::stratified_split(&indices, &labels, 0.3, Some(42));

        assert_eq!(train.len() + test.len(), indices.len());
        assert!(test.len() >= 2);

        let mut combined: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
        combined.sort();
        assert_eq!(combined, indices);
    }

    /// Rarer classes receive larger weights.
    #[test]
    fn test_utils_calculate_class_weights() {
        let labels = vec![0, 0, 1, 1, 1, 2];
        let weights = utils::calculate_class_weights(&labels, 3);
        assert_eq!(weights.len(), 3);
        assert!(weights[2] > weights[1]);
        assert!(weights[0] > weights[1]);
    }

    /// Accepted and rejected parameter combinations.
    #[test]
    fn test_utils_validate_sampling_params() {
        let accepted = [
            (10, Some(5), false),
            (10, Some(15), true),
            (10, None, false),
            (0, Some(0), false),
            (0, None, false),
            (10, Some(0), true),
        ];
        for (size, samples, replacement) in accepted {
            assert!(utils::validate_sampling_params(size, samples, replacement).is_ok());
        }

        let rejected = [
            (0, Some(5), false),
            (10, Some(0), false),
            (10, Some(15), false),
        ];
        for (size, samples, replacement) in rejected {
            assert!(utils::validate_sampling_params(size, samples, replacement).is_err());
        }
    }

    /// `size_hint` tracks the number of items left.
    #[test]
    fn test_size_hints() {
        let untouched = SamplerIterator::new(vec![0, 1, 2]);
        assert_eq!(untouched.size_hint(), (3, Some(3)));

        let mut advanced = SamplerIterator::new(vec![0, 1, 2]);
        advanced.next();
        assert_eq!(advanced.size_hint(), (2, Some(2)));
    }
}