pub mod active_learning;
pub mod adaptive;
pub mod advanced;
pub mod basic;
pub mod batch;
pub mod core;
pub mod curriculum;
pub mod distributed;
pub mod importance;
pub mod stratified;
pub mod weighted;
pub use core::{utils, BatchSampler, Sampler, SamplerIterator};
pub use basic::{
random, random_subset, random_with_replacement, sequential, RandomSampler, SequentialSampler,
};
pub use batch::{batch, batch_drop_last, batch_keep_last, BatchSamplerIter, BatchingSampler};
pub use distributed::{distributed, distributed_sampler, DistributedSampler, DistributedWrapper};
pub use weighted::{
balanced_weighted, subset_random, weighted_random, SubsetRandomSampler, WeightedRandomSampler,
};
pub use stratified::{
balanced_stratified, stratified, stratified_train_test_split, StratifiedSampler,
};
pub use curriculum::{
anti_curriculum, exponential_curriculum, linear_curriculum, step_curriculum, CurriculumSampler,
CurriculumStats, CurriculumStrategy,
};
pub use active_learning::{
diversity_sampler, uncertainty_diversity_sampler, uncertainty_sampler, AcquisitionStrategy,
ActiveLearningSampler, ActiveLearningStats,
};
pub use adaptive::{
frequency_balanced_sampler, hard_adaptive_sampler, uncertainty_adaptive_sampler,
AdaptiveSampler, AdaptiveStats, AdaptiveStrategy,
};
pub use importance::{
class_balanced_importance_sampler, exponential_importance_sampler,
loss_based_importance_sampler, uniform_importance_sampler, ImportanceSampler, ImportanceStats,
};
pub use advanced::{
GroupedSampler, ImportanceSampler as AdvancedImportanceSampler,
StratifiedSampler as AdvancedStratifiedSampler,
WeightedRandomSampler as AdvancedWeightedRandomSampler,
};
/// Default sampler type used when callers do not configure one explicitly:
/// uniform random sampling without replacement.
pub type DefaultSampler = RandomSampler;
/// Default batch sampler type: fixed-size batches drawn from a [`RandomSampler`].
pub type DefaultBatchSampler = BatchingSampler<RandomSampler>;
/// Creates the default sampler: a uniform random sampler over the indices
/// `0..dataset_size`.
///
/// # Arguments
/// * `dataset_size` - number of items in the dataset
/// * `seed` - optional RNG seed for a reproducible ordering
pub fn default_sampler(dataset_size: usize, seed: Option<u64>) -> RandomSampler {
    random(dataset_size, seed)
}
/// Creates the default batch sampler: random index order, chunked into
/// batches of `batch_size`.
///
/// # Arguments
/// * `dataset_size` - number of items in the dataset
/// * `batch_size` - number of indices per batch
/// * `drop_last` - when `true`, a trailing partial batch is discarded
/// * `seed` - optional RNG seed for a reproducible ordering
pub fn default_batch_sampler(
    dataset_size: usize,
    batch_size: usize,
    drop_last: bool,
    seed: Option<u64>,
) -> BatchingSampler<RandomSampler> {
    random(dataset_size, seed).into_batch_sampler(batch_size, drop_last)
}
/// Creates the default distributed sampler: a shuffled partitioning of
/// `dataset_size` indices across `num_replicas` processes, of which the
/// current process is `rank`.
///
/// Fix: `seed` was previously accepted but silently ignored (bound as
/// `_seed`). It is now forwarded to the sampler's generator so that all
/// ranks constructed with the same seed agree on the shuffled permutation —
/// which is what keeps the per-rank partitions disjoint. Passing `None`
/// preserves the old unseeded behavior.
pub fn default_distributed_sampler(
    dataset_size: usize,
    num_replicas: usize,
    rank: usize,
    seed: Option<u64>,
) -> DistributedSampler {
    let sampler = distributed_sampler(dataset_size, num_replicas, rank, true);
    match seed {
        // `with_generator` seeds the sampler (same builder used elsewhere in
        // this module's tests) — NOTE(review): confirm it only affects RNG
        // state and not epoch handling.
        Some(seed_value) => sampler.with_generator(seed_value),
        None => sampler,
    }
}
/// Builds a type-erased sampler from a string identifier and a string-keyed
/// configuration map.
///
/// Supported `sampler_type` values:
/// * `"sequential"` — in-order indices
/// * `"random"` — shuffled indices without replacement
/// * `"random_replacement"` — `dataset_size` draws with replacement
/// * `"weighted"` — requires `weights` (comma-separated floats); honors
///   optional `replacement` and `num_samples`
/// * `"distributed"` — requires `num_replicas` and `rank`; honors optional
///   `shuffle` and accepts (but does not use) `drop_last`
///
/// An optional `seed` entry applies to the random and weighted variants.
/// Unknown types and missing required keys yield a descriptive `Err`.
pub fn create_sampler(
    sampler_type: &str,
    dataset_size: usize,
    config: &std::collections::HashMap<String, String>,
) -> Result<Box<dyn Sampler<Iter = Box<dyn Iterator<Item = usize> + Send>> + Send>, String> {
    // Shared optional seed; ignored when absent or unparsable.
    let seed = config.get("seed").and_then(|v| v.parse::<u64>().ok());
    match sampler_type {
        "sequential" => Ok(Box::new(SamplerWrapper::Sequential(sequential(
            dataset_size,
        )))),
        "random" => Ok(Box::new(SamplerWrapper::Random(random(dataset_size, seed)))),
        "random_replacement" => Ok(Box::new(SamplerWrapper::Random(random_with_replacement(
            dataset_size,
            dataset_size,
            seed,
        )))),
        "weighted" => {
            // A comma-separated weight list is mandatory here.
            let raw_weights = config
                .get("weights")
                .ok_or("Weighted sampler requires 'weights' configuration")?;
            let weights = raw_weights
                .split(',')
                .map(|w| w.trim().parse::<f64>())
                .collect::<Result<Vec<f64>, _>>()
                .map_err(|_| "Invalid weights format")?;
            // Optional knobs fall back to their natural defaults when the
            // key is missing or fails to parse.
            let replacement = config
                .get("replacement")
                .map(|v| v.parse::<bool>().unwrap_or(false))
                .unwrap_or(false);
            let num_samples = config
                .get("num_samples")
                .and_then(|v| v.parse::<usize>().ok())
                .unwrap_or(weights.len());
            Ok(Box::new(SamplerWrapper::Weighted(weighted_random(
                weights,
                num_samples,
                replacement,
                seed,
            ))))
        }
        "distributed" => {
            let num_replicas = config
                .get("num_replicas")
                .and_then(|v| v.parse::<usize>().ok())
                .ok_or("Distributed sampler requires 'num_replicas' configuration")?;
            let rank = config
                .get("rank")
                .and_then(|v| v.parse::<usize>().ok())
                .ok_or("Distributed sampler requires 'rank' configuration")?;
            let shuffle = config
                .get("shuffle")
                .map(|v| v.parse::<bool>().unwrap_or(true))
                .unwrap_or(true);
            // Parsed for configuration parity, but the underlying sampler has
            // no drop_last concept, so the value is deliberately unused.
            let _drop_last = config
                .get("drop_last")
                .map(|v| v.parse::<bool>().unwrap_or(false))
                .unwrap_or(false);
            Ok(Box::new(SamplerWrapper::Distributed(
                DistributedSampler::new(dataset_size, num_replicas, rank, shuffle),
            )))
        }
        unknown => Err(format!("Unknown sampler type: {}", unknown)),
    }
}
/// Internal type-erasure wrapper so `create_sampler` can return a single
/// concrete `Sampler` implementation regardless of which variant the
/// configuration requested.
#[derive(Debug)]
enum SamplerWrapper {
    // In-order indices `0..n`.
    Sequential(SequentialSampler),
    // Shuffled indices (with or without replacement).
    Random(RandomSampler),
    // Probability-weighted draws.
    Weighted(WeightedRandomSampler),
    // Per-rank partition for multi-process training.
    Distributed(DistributedSampler),
}
// Forwards every `Sampler` operation to the wrapped concrete sampler,
// boxing iterators so all variants share the one erased `Iter` type.
impl Sampler for SamplerWrapper {
    // Boxed + `Send` so heterogeneous variants are interchangeable.
    type Iter = Box<dyn Iterator<Item = usize> + Send>;
    fn iter(&self) -> Self::Iter {
        match self {
            SamplerWrapper::Sequential(s) => Box::new(s.iter()),
            SamplerWrapper::Random(s) => Box::new(s.iter()),
            SamplerWrapper::Weighted(s) => Box::new(s.iter()),
            SamplerWrapper::Distributed(s) => Box::new(s.iter()),
        }
    }
    fn len(&self) -> usize {
        match self {
            SamplerWrapper::Sequential(s) => s.len(),
            SamplerWrapper::Random(s) => s.len(),
            SamplerWrapper::Weighted(s) => s.len(),
            SamplerWrapper::Distributed(s) => s.len(),
        }
    }
    // Consumes the wrapper and yields batches of indices.
    fn into_batch_sampler(self, batch_size: usize, drop_last: bool) -> BatchingSampler<Self>
    where
        Self: Sized,
    {
        BatchingSampler::new(self, batch_size, drop_last)
    }
    // Consumes the wrapper and shards indices across `num_replicas` ranks.
    fn into_distributed(self, num_replicas: usize, rank: usize) -> DistributedWrapper<Self>
    where
        Self: Sized,
    {
        DistributedWrapper::new(self, num_replicas, rank)
    }
}
/// Splits the index range `0..dataset_size` into `(train, validation)` sets.
///
/// When `seed` is provided the indices are Fisher-Yates shuffled first;
/// otherwise the split is a deterministic head/tail cut. The validation set
/// holds `round(dataset_size * val_ratio)` indices.
///
/// # Panics
/// Panics if `val_ratio` is outside `[0, 1]`.
pub fn train_val_split(
    dataset_size: usize,
    val_ratio: f32,
    seed: Option<u64>,
) -> (Vec<usize>, Vec<usize>) {
    assert!(
        (0.0..=1.0).contains(&val_ratio),
        "Validation ratio must be in [0, 1]"
    );
    let val_size = (dataset_size as f32 * val_ratio).round() as usize;
    let mut indices: Vec<usize> = (0..dataset_size).collect();
    if let Some(seed_value) = seed {
        use scirs2_core::random::Random;
        let mut rng = Random::seed(seed_value);
        // Fisher-Yates: walk backwards, swapping each slot with a random
        // earlier (or same) position.
        for current in (1..indices.len()).rev() {
            let partner = rng.gen_range(0..=current);
            indices.swap(current, partner);
        }
    }
    // The tail becomes validation; `indices` keeps the training head.
    let val_indices = indices.split_off(dataset_size - val_size);
    (indices, val_indices)
}
/// Splits the index range `0..dataset_size` into `(train, validation, test)`
/// sets of sizes `round(n * train_ratio)`, `round(n * val_ratio)`, and the
/// remainder, respectively.
///
/// When `seed` is provided the indices are Fisher-Yates shuffled first;
/// otherwise the three sets are contiguous slices of `0..dataset_size`.
///
/// # Panics
/// Panics if either ratio is outside `[0, 1]` or if their sum exceeds 1.
pub fn train_val_test_split(
    dataset_size: usize,
    train_ratio: f32,
    val_ratio: f32,
    seed: Option<u64>,
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    assert!(
        (0.0..=1.0).contains(&train_ratio),
        "Train ratio must be in [0, 1]"
    );
    assert!(
        (0.0..=1.0).contains(&val_ratio),
        "Val ratio must be in [0, 1]"
    );
    assert!(
        train_ratio + val_ratio <= 1.0,
        "Train + val ratios must not exceed 1.0"
    );
    let train_size = (dataset_size as f32 * train_ratio).round() as usize;
    let val_size = (dataset_size as f32 * val_ratio).round() as usize;
    let mut indices: Vec<usize> = (0..dataset_size).collect();
    if let Some(seed_value) = seed {
        use scirs2_core::random::Random;
        let mut rng = Random::seed(seed_value);
        // Fisher-Yates shuffle for a reproducible random permutation.
        for current in (1..indices.len()).rev() {
            let partner = rng.gen_range(0..=current);
            indices.swap(current, partner);
        }
    }
    // Peel off validation+test, then peel test off that remainder.
    let mut val_indices = indices.split_off(train_size);
    let test_indices = val_indices.split_off(val_size);
    (indices, val_indices, test_indices)
}
/// Produces `k` cross-validation folds over the index range `0..dataset_size`.
///
/// Each element of the returned vector is a `(train, validation)` pair; the
/// validation slices are contiguous, non-overlapping, and together cover all
/// indices (the final fold absorbs any remainder when `k` does not divide
/// `dataset_size` evenly). A provided `seed` shuffles the indices first.
///
/// # Panics
/// Panics if `k <= 1` or `k > dataset_size`.
pub fn kfold_splits(
    dataset_size: usize,
    k: usize,
    seed: Option<u64>,
) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(k > 1, "Number of folds must be greater than 1");
    assert!(
        k <= dataset_size,
        "Number of folds cannot exceed dataset size"
    );
    let mut indices: Vec<usize> = (0..dataset_size).collect();
    if let Some(seed_value) = seed {
        use scirs2_core::random::Random;
        let mut rng = Random::seed(seed_value);
        // Fisher-Yates shuffle for a reproducible random permutation.
        for current in (1..indices.len()).rev() {
            let partner = rng.gen_range(0..=current);
            indices.swap(current, partner);
        }
    }
    let fold_size = dataset_size / k;
    (0..k)
        .map(|fold| {
            let val_start = fold * fold_size;
            // The last fold runs to the end to absorb the remainder.
            let val_end = if fold + 1 == k {
                dataset_size
            } else {
                val_start + fold_size
            };
            let val_indices = indices[val_start..val_end].to_vec();
            // Training set = everything outside the validation window.
            let train_indices: Vec<usize> = indices[..val_start]
                .iter()
                .chain(&indices[val_end..])
                .copied()
                .collect();
            (train_indices, val_indices)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_default_sampler() {
        // A seeded sampler covers every index exactly once...
        let sampler = default_sampler(100, Some(42));
        let first_pass: Vec<usize> = sampler.iter().collect();
        assert_eq!(first_pass.len(), 100);
        assert_eq!(first_pass.iter().collect::<HashSet<_>>().len(), 100);
        // ...and the same seed reproduces the same ordering.
        let second_pass: Vec<usize> = default_sampler(100, Some(42)).iter().collect();
        assert_eq!(first_pass, second_pass);
    }

    #[test]
    fn test_default_batch_sampler() {
        // drop_last = true: 100 / 32 -> three full batches, remainder dropped.
        let dropped: Vec<Vec<usize>> = default_batch_sampler(100, 32, true, Some(42))
            .iter()
            .collect();
        assert_eq!(dropped.len(), 3);
        for batch in &dropped {
            assert_eq!(batch.len(), 32);
        }
        // drop_last = false: the trailing partial batch of 4 is kept.
        let kept: Vec<Vec<usize>> = default_batch_sampler(100, 32, false, Some(42))
            .iter()
            .collect();
        assert_eq!(kept.len(), 4);
        assert_eq!(kept[3].len(), 4);
    }

    #[test]
    fn test_default_distributed_sampler() {
        let rank0: Vec<usize> = default_distributed_sampler(1000, 4, 0, Some(42))
            .iter()
            .collect();
        assert_eq!(rank0.len(), 250);
        assert!(rank0.iter().all(|&idx| idx < 1000));
        let rank1: Vec<usize> = default_distributed_sampler(1000, 4, 1, Some(42))
            .iter()
            .collect();
        assert_eq!(rank1.len(), 250);
        // Different ranks must receive disjoint partitions of the dataset.
        let rank1_set: HashSet<usize> = rank1.into_iter().collect();
        assert!(rank0.iter().all(|idx| !rank1_set.contains(idx)));
    }

    #[test]
    fn test_sampler_factory() {
        let mut config = std::collections::HashMap::new();
        config.insert("seed".to_string(), "42".to_string());
        let sequential_sampler =
            create_sampler("sequential", 100, &config).expect("create sampler should succeed");
        assert_eq!(sequential_sampler.len(), 100);
        let random_sampler =
            create_sampler("random", 100, &config).expect("create sampler should succeed");
        assert_eq!(random_sampler.len(), 100);
        config.insert("weights".to_string(), "0.1,0.3,0.6".to_string());
        let weighted_sampler =
            create_sampler("weighted", 3, &config).expect("create sampler should succeed");
        assert_eq!(weighted_sampler.len(), 3);
        config.insert("num_replicas".to_string(), "4".to_string());
        config.insert("rank".to_string(), "0".to_string());
        let distributed_sampler =
            create_sampler("distributed", 1000, &config).expect("create sampler should succeed");
        assert_eq!(distributed_sampler.len(), 250);
    }

    #[test]
    fn test_sampler_factory_errors() {
        let empty_config = std::collections::HashMap::new();
        // Unknown type, and known types with missing required keys, all fail.
        for sampler_type in ["unknown", "weighted", "distributed"] {
            assert!(create_sampler(sampler_type, 100, &empty_config).is_err());
        }
    }

    #[test]
    fn test_backward_compatibility() {
        // Legacy constructor signatures must keep compiling unchanged.
        let _seq = SequentialSampler::new(100);
        let _rand = RandomSampler::new(100, None, false).with_generator(42);
        let _subset = SubsetRandomSampler::new(vec![0, 1, 2, 3, 4]).with_generator(42);
        let _distributed = DistributedSampler::new(100, 4, 0, true).with_generator(42);
        let (train, val) = train_val_split(1000, 0.2, Some(42));
        assert_eq!(train.len(), 800);
        assert_eq!(val.len(), 200);
        let _default: DefaultSampler = RandomSampler::new(100, None, false).with_generator(42);
    }

    #[test]
    fn test_modular_integration() {
        // Compose random -> batched -> distributed and check batch shapes.
        let composed = RandomSampler::new(1000, None, false)
            .with_generator(42)
            .into_batch_sampler(32, true)
            .into_distributed(4, 0);
        let batches: Vec<Vec<usize>> = composed.iter().collect();
        assert!(!batches.is_empty());
        // Every batch except possibly the last is full-sized.
        for batch in &batches[..batches.len() - 1] {
            assert_eq!(batch.len(), 32);
        }
    }

    #[test]
    fn test_comprehensive_api_coverage() {
        // Smoke-test that every public sampler constructor is reachable.
        let _seq = SequentialSampler::new(100);
        let _rand = RandomSampler::new(100, None, false).with_generator(42);
        let _subset = SubsetRandomSampler::new(vec![0, 1, 2]).with_generator(42);
        let _weighted =
            WeightedRandomSampler::new(vec![0.1, 0.3, 0.6], 3, false).with_generator(42);
        let labels = vec![0, 0, 0, 1, 1, 1];
        let _stratified = StratifiedSampler::new(&labels, 6, false);
        let _distributed = DistributedSampler::new(100, 4, 0, true);
        let _splits = kfold_splits(100, 5, Some(42));
        let (train, val, test) = train_val_test_split(1000, 0.6, 0.2, Some(42));
        assert_eq!(train.len() + val.len() + test.len(), 1000);
    }
}