use std::collections::HashMap;
use tenflowers_core::{Result, TensorError};
/// Strategy that decides which dataset indices are visited, and in what order.
///
/// Implementors must be `Send + Sync`, and the iterator they hand back must be
/// `Send`.
pub trait Sampler: Send + Sync {
    /// Produces the indices to visit for a dataset that holds `len` items.
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send>;

    /// Reports whether the sampler produces a randomized order.
    fn is_random(&self) -> bool;

    /// Installs a new seed, or clears it when `None` is passed.
    ///
    /// The provided implementation ignores the seed; samplers that use one
    /// override this method.
    fn set_seed(&mut self, _seed: Option<u64>) {}
}
/// Sampler configuration describing a contiguous index range.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// First index of the range.
    start: usize,
    /// Exclusive upper bound; `None` means no explicit bound was configured.
    end: Option<usize>,
}

impl SequentialSampler {
    /// Creates a sampler that starts at index 0 and has no explicit end.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a sampler restricted to `start..end`.
    pub fn with_range(start: usize, end: usize) -> Self {
        Self {
            end: Some(end),
            start,
        }
    }
}

impl Default for SequentialSampler {
    /// Same configuration as [`SequentialSampler::new`].
    fn default() -> Self {
        Self {
            start: 0,
            end: None,
        }
    }
}
impl Sampler for SequentialSampler {
    /// Yields `start..upper` in ascending order, where `upper` is the
    /// configured end clamped to `len`, or `len` when no end was configured.
    /// The range is empty whenever `start` is not below `upper`.
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
        let upper = match self.end {
            Some(limit) if limit < len => limit,
            _ => len,
        };
        Box::new(self.start..upper)
    }

    /// Sequential order is never randomized.
    fn is_random(&self) -> bool {
        false
    }
}
/// Configuration for a sampler that visits indices in random order.
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Whether indices are drawn with replacement.
    replacement: bool,
    /// Optional fixed seed.
    seed: Option<u64>,
}

impl RandomSampler {
    /// Creates an unseeded sampler that draws without replacement.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an unseeded sampler that draws with replacement.
    pub fn with_replacement() -> Self {
        Self {
            replacement: true,
            ..Self::default()
        }
    }

    /// Creates a sampler without replacement that uses the given seed.
    pub fn with_seed(seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..Self::default()
        }
    }
}

impl Default for RandomSampler {
    /// No replacement, no seed.
    fn default() -> Self {
        Self {
            replacement: false,
            seed: None,
        }
    }
}
impl Sampler for RandomSampler {
    /// Returns `len` indices in random order.
    ///
    /// Without replacement the result is a Fisher-Yates permutation of
    /// `0..len`; with replacement every position is an independent draw from
    /// `0..len`. The same seed always yields the same sequence. When no seed is
    /// configured, one is derived from the wall clock at nanosecond resolution,
    /// so calls made within the same second no longer reshuffle identically.
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
        // SplitMix64 step. The previous generator took `state % n` from an LCG
        // with a power-of-two modulus; its low k bits repeat with period 2^k,
        // so sampling "with replacement" from 2^k items never produced a
        // repeat. The output mixing below makes every bit usable.
        fn next_u64(state: &mut u64) -> u64 {
            *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = *state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        // Draws a value in `0..bound`; callers guarantee `bound > 0`.
        fn next_below(state: &mut u64, bound: usize) -> usize {
            (next_u64(state) % bound as u64) as usize
        }

        let mut state = self.seed.unwrap_or_else(|| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                // Truncating to 64 bits is fine: only the entropy matters.
                .map(|elapsed| elapsed.as_nanos() as u64)
                // A clock set before the epoch falls back to 0 instead of panicking.
                .unwrap_or(0)
        });

        if self.replacement {
            // The closure never runs when `len == 0`, so the bound is non-zero.
            let drawn: Vec<usize> = (0..len).map(|_| next_below(&mut state, len)).collect();
            Box::new(drawn.into_iter())
        } else {
            let mut order: Vec<usize> = (0..len).collect();
            for i in (1..len).rev() {
                let j = next_below(&mut state, i + 1);
                order.swap(i, j);
            }
            Box::new(order.into_iter())
        }
    }

    /// This sampler always randomizes.
    fn is_random(&self) -> bool {
        true
    }

    /// Replaces the configured seed; `None` switches back to clock seeding.
    fn set_seed(&mut self, seed: Option<u64>) {
        self.seed = seed;
    }
}
/// Configuration for splitting a dataset's indices among several replicas.
#[derive(Debug, Clone)]
pub struct DistributedSampler {
    /// Total number of replicas sharing the dataset.
    num_replicas: usize,
    /// Position of this replica, always below `num_replicas`.
    rank: usize,
    /// Current epoch, set through `set_epoch`.
    epoch: usize,
    /// Whether indices are shuffled before being split.
    shuffle: bool,
    /// Optional fixed seed.
    seed: Option<u64>,
    /// Whether the uneven tail is dropped instead of padded.
    drop_last: bool,
}

impl DistributedSampler {
    /// Creates a sampler for replica `rank` out of `num_replicas`.
    ///
    /// Starts at epoch 0 with shuffling on, no seed, and `drop_last` off.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when `rank >= num_replicas`, which
    /// includes every call with `num_replicas == 0`.
    pub fn new(num_replicas: usize, rank: usize) -> Result<Self> {
        if rank < num_replicas {
            Ok(Self {
                num_replicas,
                rank,
                epoch: 0,
                shuffle: true,
                seed: None,
                drop_last: false,
            })
        } else {
            Err(TensorError::invalid_argument(format!(
                "Rank {rank} must be less than num_replicas {num_replicas}"
            )))
        }
    }

    /// Builder: turns shuffling on or off.
    pub fn with_shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Builder: fixes the seed.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Builder: chooses whether the uneven tail is dropped.
    pub fn with_drop_last(self, drop_last: bool) -> Self {
        Self { drop_last, ..self }
    }

    /// Records the current epoch.
    pub fn set_epoch(&mut self, epoch: usize) {
        self.epoch = epoch;
    }

    /// Current epoch.
    pub fn epoch(&self) -> usize {
        self.epoch
    }

    /// Rank of this replica.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Total number of replicas.
    pub fn num_replicas(&self) -> usize {
        self.num_replicas
    }

    /// Number of samples per replica: the floor of
    /// `total_size / num_replicas` when `drop_last` is set, the ceiling otherwise.
    fn samples_per_replica(&self, total_size: usize) -> usize {
        let replicas = self.num_replicas;
        if self.drop_last {
            total_size / replicas
        } else {
            (total_size + replicas - 1) / replicas
        }
    }

    /// Combined number of samples across all replicas, i.e. the per-replica
    /// count multiplied by the replica count.
    fn padded_size(&self, total_size: usize) -> usize {
        self.samples_per_replica(total_size) * self.num_replicas
    }
}
impl Sampler for DistributedSampler {
    /// Returns the shard of `0..len` that belongs to this replica.
    ///
    /// The index list is optionally shuffled, padded by repeating its leading
    /// entries (unless `drop_last` is set) so it divides evenly, and then cut
    /// into `num_replicas` equal contiguous shards; the shard at position
    /// `rank` is returned.
    ///
    /// The shuffle is seeded with `seed + epoch`. When no seed is configured
    /// the seed defaults to 0 rather than to the wall clock: samplers built for
    /// different ranks must compute the same permutation, otherwise their
    /// shards overlap and some items are never visited.
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
        // SplitMix64 step; its mixed output avoids the short-period low bits
        // that a plain power-of-two LCG exposes through `state % n`.
        fn next_u64(state: &mut u64) -> u64 {
            *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = *state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        let mut indices: Vec<usize> = (0..len).collect();

        if self.shuffle {
            let mut state = self.seed.unwrap_or(0).wrapping_add(self.epoch as u64);
            // Fisher-Yates shuffle.
            for i in (1..len).rev() {
                let j = (next_u64(&mut state) % (i as u64 + 1)) as usize;
                indices.swap(i, j);
            }
        }

        let per_replica = self.samples_per_replica(len);
        let padded = self.padded_size(len);

        if !self.drop_last && padded > len {
            // `padded > len` implies `len > 0`, so the modulo below is safe.
            for i in 0..padded - len {
                let repeated = indices[i % len];
                indices.push(repeated);
            }
        }

        let start = self.rank * per_replica;
        let end = (start + per_replica).min(indices.len());
        // `get` yields `None` when the shard starts past the end of the list.
        let shard = indices
            .get(start..end)
            .map(|slice| slice.to_vec())
            .unwrap_or_default();
        Box::new(shard.into_iter())
    }

    /// Random exactly when shuffling is enabled.
    fn is_random(&self) -> bool {
        self.shuffle
    }

    /// Replaces the configured seed; `None` selects the shared default of 0.
    fn set_seed(&mut self, seed: Option<u64>) {
        self.seed = seed;
    }
}
/// Configuration for sampling an equal number of items from every class.
#[derive(Debug, Clone)]
pub struct StratifiedSampler {
    /// Class label of each dataset item, indexed by item position.
    class_labels: Vec<usize>,
    /// Requested number of samples per class; `None` leaves it unspecified.
    samples_per_class: Option<usize>,
    /// Whether items are drawn with replacement.
    replacement: bool,
    /// Optional fixed seed.
    seed: Option<u64>,
    /// Whether shuffling is enabled.
    shuffle: bool,
}

impl StratifiedSampler {
    /// Creates a sampler over the given labels with shuffling on, no
    /// replacement, no seed, and no explicit per-class sample count.
    pub fn new(class_labels: Vec<usize>) -> Self {
        Self {
            class_labels,
            samples_per_class: None,
            replacement: false,
            seed: None,
            shuffle: true,
        }
    }

    /// Builder: requests a fixed number of samples per class.
    pub fn with_samples_per_class(self, samples_per_class: usize) -> Self {
        Self {
            samples_per_class: Some(samples_per_class),
            ..self
        }
    }

    /// Builder: enables sampling with replacement.
    pub fn with_replacement(self) -> Self {
        Self {
            replacement: true,
            ..self
        }
    }

    /// Builder: fixes the seed.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Builder: turns shuffling on or off.
    pub fn with_shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Counts how many items carry each class label.
    pub fn class_distribution(&self) -> HashMap<usize, usize> {
        self.class_labels
            .iter()
            .fold(HashMap::<usize, usize>::new(), |mut tally, &label| {
                *tally.entry(label).or_insert(0) += 1;
                tally
            })
    }

    /// Number of distinct class labels.
    pub fn num_classes(&self) -> usize {
        let distribution = self.class_distribution();
        distribution.len()
    }
}
impl Sampler for StratifiedSampler {
fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
if self.class_labels.len() != len {
return Box::new((0..len).collect::<Vec<_>>().into_iter());
}
let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
for (idx, &class_label) in self.class_labels.iter().enumerate() {
class_indices.entry(class_label).or_default().push(idx);
}
let seed = self.seed.unwrap_or_else(|| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system time before UNIX_EPOCH")
.as_secs()
});
let mut result_indices = Vec::new();
let mut rng_state = seed;
let samples_per_class = if let Some(spc) = self.samples_per_class {
spc
} else {
class_indices
.values()
.map(|indices| indices.len())
.min()
.unwrap_or(0)
};
for (_, mut indices) in class_indices {
if self.shuffle {
for i in (1..indices.len()).rev() {
rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
let j = (rng_state as usize) % (i + 1);
indices.swap(i, j);
}
}
if self.replacement {
for _ in 0..samples_per_class {
rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
let idx = (rng_state as usize) % indices.len();
result_indices.push(indices[idx]);
}
} else {
let sample_count = samples_per_class.min(indices.len());
result_indices.extend_from_slice(&indices[..sample_count]);
}
}
if self.shuffle {
for i in (1..result_indices.len()).rev() {
rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
let j = (rng_state as usize) % (i + 1);
result_indices.swap(i, j);
}
}
Box::new(result_indices.into_iter())
}
fn is_random(&self) -> bool {
self.shuffle
}
fn set_seed(&mut self, seed: Option<u64>) {
self.seed = seed;
}
}
/// Configuration for sampling items in proportion to per-item weights.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// One weight per dataset item.
    weights: Vec<f64>,
    /// Whether the exponentiated weights are normalized to sum to 1.
    normalize: bool,
    /// Optional fixed seed.
    seed: Option<u64>,
    /// Divisor applied to every weight before exponentiation.
    temperature: f64,
}

impl ImportanceSampler {
    /// Creates a sampler with a weight of 1.0 for each of `dataset_size` items,
    /// normalization on, no seed, and temperature 1.0.
    pub fn new(dataset_size: usize) -> Self {
        Self::with_weights(vec![1.0; dataset_size])
    }

    /// Creates a sampler from explicit weights, with normalization on, no
    /// seed, and temperature 1.0.
    pub fn with_weights(weights: Vec<f64>) -> Self {
        Self {
            weights,
            normalize: true,
            seed: None,
            temperature: 1.0,
        }
    }

    /// Builder: turns normalization on or off.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Builder: fixes the seed.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Builder: sets the temperature that divides every weight.
    pub fn with_temperature(self, temperature: f64) -> Self {
        Self {
            temperature,
            ..self
        }
    }

    /// Overwrites the weight at `index`; an out-of-range index is ignored.
    pub fn update_weight(&mut self, index: usize, weight: f64) {
        if let Some(slot) = self.weights.get_mut(index) {
            *slot = weight;
        }
    }

    /// Applies every `(index, weight)` pair in order via `update_weight`.
    pub fn update_weights(&mut self, updates: &[(usize, f64)]) {
        for &(index, weight) in updates {
            self.update_weight(index, weight);
        }
    }

    /// Current weights.
    pub fn weights(&self) -> &[f64] {
        &self.weights
    }

    /// Turns the weights into sampling scores.
    ///
    /// With normalization off this is `exp(weight / temperature)` per item.
    /// With normalization on it is the softmax of `weight / temperature`: the
    /// largest logit is subtracted before exponentiating, which leaves the
    /// result mathematically unchanged but keeps large weights from
    /// overflowing to infinity (and then NaN) and very negative weights from
    /// all underflowing to zero. If the sum is still not a positive finite
    /// number, every item gets the uniform probability `1 / len`.
    fn compute_probabilities(&self) -> Vec<f64> {
        let logits: Vec<f64> = self
            .weights
            .iter()
            .map(|&w| w / self.temperature)
            .collect();

        if !self.normalize {
            return logits.into_iter().map(f64::exp).collect();
        }

        // `f64::max` skips NaN operands, so `peak` is the largest non-NaN logit.
        let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut probs: Vec<f64> = logits.iter().map(|&l| (l - peak).exp()).collect();

        let sum: f64 = probs.iter().sum();
        if sum.is_finite() && sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        } else {
            let uniform_prob = 1.0 / probs.len() as f64;
            probs.fill(uniform_prob);
        }
        probs
    }
}
impl Sampler for ImportanceSampler {
    /// Draws `len` indices with replacement, each chosen with probability
    /// proportional to its score from `compute_probabilities`.
    ///
    /// If the number of stored weights differs from `len`, the method falls
    /// back to the plain sequence `0..len`.
    ///
    /// Each uniform draw is scaled by the total of the scores before the
    /// cumulative table is searched, so scores that do not sum to 1 (as with
    /// normalization off) are still sampled proportionally instead of always
    /// landing on index 0. If the total is not a positive finite number, the
    /// indices are drawn uniformly.
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
        // SplitMix64 step. Dividing the raw state of the previous LCG by
        // `u64::MAX` gave a first value close to 0 for small seeds.
        fn next_u64(state: &mut u64) -> u64 {
            *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = *state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        // Uniform value in `[0, 1)` built from the top 53 bits.
        fn next_unit(state: &mut u64) -> f64 {
            (next_u64(state) >> 11) as f64 / (1u64 << 53) as f64
        }

        if self.weights.len() != len {
            return Box::new(0..len);
        }
        if len == 0 {
            return Box::new(std::iter::empty());
        }

        // Running totals of the scores: `cumulative[i]` covers items `0..=i`.
        let mut running = 0.0_f64;
        let cumulative: Vec<f64> = self
            .compute_probabilities()
            .into_iter()
            .map(|score| {
                running += score;
                running
            })
            .collect();
        let total = running;
        let usable = total.is_finite() && total > 0.0;

        let mut state = self.seed.unwrap_or_else(|| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                // Nanosecond resolution, so back-to-back calls get distinct seeds.
                .map(|elapsed| elapsed.as_nanos() as u64)
                // A clock set before the epoch falls back to 0 instead of panicking.
                .unwrap_or(0)
        });

        let picks: Vec<usize> = (0..len)
            .map(|_| {
                if usable {
                    let target = next_unit(&mut state) * total;
                    // First item whose running total exceeds the target; the
                    // clamp guards against rounding at the upper end.
                    cumulative
                        .partition_point(|&upto| upto <= target)
                        .min(len - 1)
                } else {
                    (next_u64(&mut state) % len as u64) as usize
                }
            })
            .collect();
        Box::new(picks.into_iter())
    }

    /// This sampler always randomizes.
    fn is_random(&self) -> bool {
        true
    }

    /// Replaces the configured seed; `None` switches back to clock seeding.
    fn set_seed(&mut self, seed: Option<u64>) {
        self.seed = seed;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects everything a sampler yields for a dataset of `len` items.
    fn drain(sampler: &dyn Sampler, len: usize) -> Vec<usize> {
        sampler.sample_indices(len).collect()
    }

    // A default sequential sampler walks the whole dataset in order.
    #[test]
    fn test_sequential_sampler() {
        let seq = SequentialSampler::new();
        assert_eq!(drain(&seq, 5), vec![0, 1, 2, 3, 4]);
        assert!(!seq.is_random());
    }

    // An explicit range limits the indices that are produced.
    #[test]
    fn test_sequential_sampler_with_range() {
        let seq = SequentialSampler::with_range(2, 5);
        assert_eq!(drain(&seq, 10), vec![2, 3, 4]);
    }

    // Two samplers built from the same seed agree on the order.
    #[test]
    fn test_random_sampler() {
        let first = RandomSampler::with_seed(42);
        let order_a = drain(&first, 5);
        assert_eq!(order_a.len(), 5);
        assert!(first.is_random());

        let second = RandomSampler::with_seed(42);
        let order_b = drain(&second, 5);
        assert_eq!(order_a, order_b);
    }

    // Sampling with replacement still yields one index per item.
    #[test]
    fn test_random_sampler_with_replacement() {
        let rnd = RandomSampler::with_replacement();
        assert_eq!(drain(&rnd, 3).len(), 3);
    }

    // One of two replicas receives roughly half of the dataset.
    #[test]
    fn test_distributed_sampler() {
        let dist = DistributedSampler::new(2, 0).expect("test: operation should succeed");
        let shard = drain(&dist, 10);
        assert!((4..=6).contains(&shard.len()));
    }

    // A rank equal to the replica count is rejected.
    #[test]
    fn test_distributed_sampler_invalid_rank() {
        assert!(DistributedSampler::new(2, 2).is_err());
    }

    // Class statistics reflect the labels handed to the constructor.
    #[test]
    fn test_stratified_sampler() {
        let strat = StratifiedSampler::new(vec![0, 0, 1, 1, 2, 2]);
        assert_eq!(strat.num_classes(), 3);

        let counts = strat.class_distribution();
        assert_eq!(counts[&0], 2);
        assert_eq!(counts[&1], 2);
        assert_eq!(counts[&2], 2);
    }

    // One sample per class across two classes gives two indices.
    #[test]
    fn test_stratified_sampler_with_samples_per_class() {
        let strat = StratifiedSampler::new(vec![0, 0, 0, 1, 1, 1])
            .with_samples_per_class(1)
            .with_seed(42);
        assert_eq!(drain(&strat, 6).len(), 2);
    }

    // A fresh importance sampler starts with unit weights.
    #[test]
    fn test_importance_sampler() {
        let imp = ImportanceSampler::new(5);
        assert_eq!(imp.weights().len(), 5);
        assert!(imp.weights().iter().all(|&w| w == 1.0));
    }

    // Explicit weights are stored unchanged.
    #[test]
    fn test_importance_sampler_with_weights() {
        let expected = vec![1.0, 2.0, 3.0];
        let imp = ImportanceSampler::with_weights(expected.clone());
        assert_eq!(imp.weights(), expected.as_slice());
    }

    // A single weight can be overwritten.
    #[test]
    fn test_importance_sampler_update_weight() {
        let mut imp = ImportanceSampler::new(3);
        imp.update_weight(1, 5.0);
        assert_eq!(imp.weights()[1], 5.0);
    }

    // Several weights can be overwritten in one call.
    #[test]
    fn test_importance_sampler_update_weights() {
        let mut imp = ImportanceSampler::new(3);
        imp.update_weights(&[(0, 2.0), (2, 4.0)]);
        assert_eq!(imp.weights()[0], 2.0);
        assert_eq!(imp.weights()[2], 4.0);
    }
}