//! Dataset sharding utilities for distributed training.
//!
//! A [`ShardedDataset`] wraps an existing [`Dataset`] and exposes only the
//! samples assigned to one worker (identified by `rank`) out of `world_size`
//! workers, as selected by a [`ShardStrategy`].

use crate::{error_taxonomy::helpers as error_helpers, Dataset};
use std::collections::HashMap;
use std::sync::Arc;
use tenflowers_core::{Result, Tensor};
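/// Configuration describing how a dataset is split across distributed workers.
///
/// Builder-style setters (`with_strategy`, `with_seed`, `with_drop_last`,
/// `with_num_replicas`) can be chained after [`ShardConfig::new`]:
///
/// ```ignore
/// // Illustrative sketch; error handling elided.
/// let config = ShardConfig::new(4, 0)?
///     .with_strategy(ShardStrategy::ShuffledRoundRobin)
///     .with_seed(42);
/// ```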
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total number of workers participating in training.
    pub world_size: usize,
    /// Zero-based index of this worker; must be less than `world_size`.
    pub rank: usize,
    /// Strategy used to assign samples to workers.
    pub strategy: ShardStrategy,
    /// Seed for deterministic shuffling (used by the shuffled and stratified strategies).
    pub seed: Option<u64>,
    /// Whether trailing samples that do not divide evenly across workers should be dropped.
    pub drop_last: bool,
    /// Number of dataset replicas; must be greater than zero.
    pub num_replicas: usize,
}
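/// Strategies for assigning dataset samples to distributed workers.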
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardStrategy {
    /// Assign sample `i` to worker `i % world_size`.
    RoundRobin,
    /// Give each worker one contiguous block of samples.
    Contiguous,
    /// Deterministically shuffle all indices (using the configured seed),
    /// then distribute them round-robin.
    ShuffledRoundRobin,
    /// Spread each class evenly across workers; requires a label extractor
    /// (see [`ShardedDataset::new_stratified`]).
    Stratified,
}
impl ShardConfig {
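    /// Creates a configuration for worker `rank` out of `world_size` workers,
    /// defaulting to [`ShardStrategy::RoundRobin`] with no seed.
    ///
    /// Returns an error if `world_size` is zero or `rank >= world_size`.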
pub fn new(world_size: usize, rank: usize) -> Result<Self> {
if world_size == 0 {
return Err(error_helpers::invalid_configuration(
"ShardConfig::new",
"world_size",
"world_size must be > 0",
));
}
if rank >= world_size {
return Err(error_helpers::invalid_configuration(
"ShardConfig::new",
"rank",
format!("rank {} must be < world_size {}", rank, world_size),
));
}
Ok(Self {
world_size,
rank,
strategy: ShardStrategy::RoundRobin,
seed: None,
drop_last: false,
num_replicas: 1,
})
}
pub fn with_strategy(mut self, strategy: ShardStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn with_drop_last(mut self, drop_last: bool) -> Self {
self.drop_last = drop_last;
self
}
pub fn with_num_replicas(mut self, num_replicas: usize) -> Self {
self.num_replicas = num_replicas;
self
}
pub fn validate(&self) -> Result<()> {
if self.world_size == 0 {
return Err(error_helpers::invalid_configuration(
"ShardConfig::validate",
"world_size",
"world_size must be > 0",
));
}
if self.rank >= self.world_size {
return Err(error_helpers::invalid_configuration(
"ShardConfig::validate",
"rank",
format!(
"rank {} must be < world_size {}",
self.rank, self.world_size
),
));
}
if self.num_replicas == 0 {
return Err(error_helpers::invalid_configuration(
"ShardConfig::validate",
"num_replicas",
"num_replicas must be > 0",
));
}
Ok(())
}
}
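/// A [`Dataset`] that can report which global sample indices belong to a shard.
///
/// The default `num_shards` is the configured `world_size`, and the default
/// `shard_size` is the length of `get_shard_indices` (or zero if that fails).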
pub trait ShardableDataset<T>: Dataset<T> {
fn get_shard_indices(&self, config: &ShardConfig) -> Result<Vec<usize>>;
fn num_shards(&self, config: &ShardConfig) -> usize {
config.world_size
}
fn shard_size(&self, config: &ShardConfig) -> usize {
let indices = self.get_shard_indices(config).unwrap_or_default();
indices.len()
}
}
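/// A view over a [`Dataset`] that exposes only the samples assigned to one
/// worker, as determined by a [`ShardConfig`].
///
/// The wrapped dataset is shared behind an [`Arc`], and the shard's global
/// indices are computed once at construction.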
pub struct ShardedDataset<T, D: Dataset<T>> {
dataset: Arc<D>,
config: ShardConfig,
indices: Vec<usize>,
_phantom: std::marker::PhantomData<T>,
}
impl<T, D: Dataset<T>> ShardedDataset<T, D> {
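    /// Validates `config` and wraps `dataset`, computing this worker's shard
    /// indices according to the configured strategy.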
pub fn new(dataset: D, config: ShardConfig) -> Result<Self> {
config.validate()?;
let dataset = Arc::new(dataset);
let indices = Self::compute_indices(&dataset, &config)?;
Ok(Self {
dataset,
config,
indices,
_phantom: std::marker::PhantomData,
})
}
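    /// Like [`ShardedDataset::new`], but distributes samples so that each
    /// class (as reported by `label_extractor`) is spread evenly across
    /// workers.
    ///
    /// ```ignore
    /// // Illustrative sketch: an extractor that reads the class id from a label tensor.
    /// let extractor = |label: &Tensor<f32>| -> Result<usize> {
    ///     let data = label
    ///         .to_vec()
    ///         .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
    ///     Ok(data[0] as usize)
    /// };
    /// let sharded = ShardedDataset::new_stratified(dataset, config, extractor)?;
    /// ```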
pub fn new_stratified<F>(dataset: D, config: ShardConfig, label_extractor: F) -> Result<Self>
where
F: Fn(&Tensor<T>) -> Result<usize>,
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
config.validate()?;
let dataset = Arc::new(dataset);
let indices = Self::compute_stratified_indices(&dataset, &config, label_extractor)?;
Ok(Self {
dataset,
config,
indices,
_phantom: std::marker::PhantomData,
})
}
fn compute_indices(dataset: &D, config: &ShardConfig) -> Result<Vec<usize>> {
let total_size = dataset.len();
if total_size == 0 {
return Ok(Vec::new());
}
let mut all_indices: Vec<usize> = (0..total_size).collect();
        // Only the shuffled strategy permutes the global order before the
        // shard assignment below; all other strategies keep natural order.
        if config.strategy == ShardStrategy::ShuffledRoundRobin {
            if let Some(seed) = config.seed {
                Self::deterministic_shuffle(&mut all_indices, seed);
            }
        }
let shard_indices = match &config.strategy {
ShardStrategy::RoundRobin | ShardStrategy::ShuffledRoundRobin => {
all_indices
.iter()
.enumerate()
.filter(|(i, _)| i % config.world_size == config.rank)
.map(|(_, &idx)| idx)
.collect()
}
ShardStrategy::Contiguous => {
let samples_per_worker = total_size / config.world_size;
let extra_samples = total_size % config.world_size;
let start = if config.rank < extra_samples {
config.rank * (samples_per_worker + 1)
} else {
config.rank * samples_per_worker + extra_samples
};
let count = if config.rank < extra_samples {
samples_per_worker + 1
} else {
samples_per_worker
};
all_indices[start..start + count].to_vec()
}
            // Stratified sharding needs a label extractor (see
            // `new_stratified`); without one, fall back to round-robin.
            ShardStrategy::Stratified => all_indices
                .iter()
                .enumerate()
                .filter(|(i, _)| i % config.world_size == config.rank)
                .map(|(_, &idx)| idx)
                .collect(),
};
Ok(shard_indices)
}
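    /// Fisher-Yates shuffle driven by a simple linear congruential generator,
    /// so the resulting permutation depends only on `seed` and the slice length.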
fn deterministic_shuffle(indices: &mut [usize], seed: u64) {
let mut rng_state = seed;
for i in (1..indices.len()).rev() {
rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
let j = (rng_state as usize) % (i + 1);
indices.swap(i, j);
}
}
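    /// Groups sample indices by class, deals each class's indices out
    /// round-robin across workers, then returns (and optionally reshuffles)
    /// this worker's share.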
fn compute_stratified_indices<F>(
dataset: &D,
config: &ShardConfig,
label_extractor: F,
) -> Result<Vec<usize>>
where
F: Fn(&Tensor<T>) -> Result<usize>,
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
let total_size = dataset.len();
if total_size == 0 {
return Ok(Vec::new());
}
let mut class_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
for i in 0..total_size {
let (_, label_tensor) = dataset.get(i)?;
let class = label_extractor(&label_tensor)?;
class_to_indices.entry(class).or_default().push(i);
}
let mut worker_indices: Vec<Vec<usize>> = vec![Vec::new(); config.world_size];
let mut classes: Vec<_> = class_to_indices.keys().cloned().collect();
classes.sort_unstable();
for class in classes {
let mut indices = class_to_indices
.remove(&class)
.expect("class should exist in map since we got it from keys()");
if let Some(seed) = config.seed {
Self::deterministic_shuffle(&mut indices, seed.wrapping_add(class as u64));
}
for (idx_pos, &global_idx) in indices.iter().enumerate() {
let worker_id = idx_pos % config.world_size;
worker_indices[worker_id].push(global_idx);
}
}
let mut shard_indices = worker_indices[config.rank].clone();
if let Some(seed) = config.seed {
Self::deterministic_shuffle(&mut shard_indices, seed.wrapping_add(config.rank as u64));
}
Ok(shard_indices)
}
pub fn inner(&self) -> &D {
&self.dataset
}
pub fn config(&self) -> &ShardConfig {
&self.config
}
pub fn indices(&self) -> &[usize] {
&self.indices
}
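    /// Computes summary statistics (shard sizes and imbalance ratio) for this
    /// shard relative to an even split of the underlying dataset.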
pub fn shard_stats(&self) -> ShardStatistics {
let total_size = self.dataset.len();
let shard_size = self.indices.len();
let min_shard_size = total_size / self.config.world_size;
let max_shard_size = (total_size + self.config.world_size - 1) / self.config.world_size;
ShardStatistics {
total_samples: total_size,
shard_size,
min_shard_size,
max_shard_size,
world_size: self.config.world_size,
rank: self.config.rank,
imbalance_ratio: if min_shard_size > 0 {
max_shard_size as f64 / min_shard_size as f64
} else {
0.0
},
}
}
}
impl<T, D: Dataset<T>> Dataset<T> for ShardedDataset<T, D> {
fn get(
&self,
index: usize,
) -> Result<(tenflowers_core::Tensor<T>, tenflowers_core::Tensor<T>)> {
if index >= self.indices.len() {
return Err(error_helpers::index_out_of_bounds(
"ShardedDataset::get",
index,
self.indices.len(),
));
}
let actual_index = self.indices[index];
self.dataset.get(actual_index)
}
fn len(&self) -> usize {
self.indices.len()
}
}
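/// Summary of how evenly a dataset is split across workers.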
#[derive(Debug, Clone)]
pub struct ShardStatistics {
pub total_samples: usize,
pub shard_size: usize,
pub min_shard_size: usize,
pub max_shard_size: usize,
pub world_size: usize,
pub rank: usize,
pub imbalance_ratio: f64,
}
impl ShardStatistics {
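    /// A split is considered balanced when the largest shard is at most 10%
    /// larger than the smallest.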
    pub fn is_balanced(&self) -> bool {
        self.imbalance_ratio <= 1.1
    }
pub fn report(&self) -> String {
format!(
"Shard Statistics:\n\
- Total samples: {}\n\
- World size: {} workers\n\
- Rank: {}\n\
- This shard size: {}\n\
- Min shard size: {}\n\
- Max shard size: {}\n\
- Imbalance ratio: {:.2}\n\
- Balanced: {}",
self.total_samples,
self.world_size,
self.rank,
self.shard_size,
self.min_shard_size,
self.max_shard_size,
self.imbalance_ratio,
if self.is_balanced() { "Yes" } else { "No" }
)
}
}
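/// Extension trait adding sharding convenience constructors to any [`Dataset`].
///
/// ```ignore
/// // Illustrative sketch; dataset construction elided.
/// let shard = dataset.shard_round_robin(world_size, rank)?;
/// println!("{}", shard.shard_stats().report());
/// ```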
pub trait DatasetShardingExt<T>: Dataset<T> + Sized {
fn shard(self, config: ShardConfig) -> Result<ShardedDataset<T, Self>> {
ShardedDataset::new(self, config)
}
fn shard_round_robin(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
let config = ShardConfig::new(world_size, rank)?;
ShardedDataset::new(self, config)
}
fn shard_contiguous(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
let config = ShardConfig::new(world_size, rank)?.with_strategy(ShardStrategy::Contiguous);
ShardedDataset::new(self, config)
}
fn shard_shuffled(
self,
world_size: usize,
rank: usize,
seed: u64,
) -> Result<ShardedDataset<T, Self>> {
let config = ShardConfig::new(world_size, rank)?
.with_strategy(ShardStrategy::ShuffledRoundRobin)
.with_seed(seed);
ShardedDataset::new(self, config)
}
}
impl<T, D: Dataset<T>> DatasetShardingExt<T> for D {}
#[cfg(test)]
mod tests {
use super::*;
use crate::TensorDataset;
use tenflowers_core::Tensor;
#[test]
fn test_shard_config_creation() {
let config = ShardConfig::new(4, 0).expect("config creation should succeed");
assert_eq!(config.world_size, 4);
assert_eq!(config.rank, 0);
assert_eq!(config.strategy, ShardStrategy::RoundRobin);
}
#[test]
fn test_shard_config_validation() {
assert!(ShardConfig::new(0, 0).is_err());
assert!(ShardConfig::new(4, 4).is_err());
assert!(ShardConfig::new(4, 5).is_err());
assert!(ShardConfig::new(4, 3).is_ok());
}
#[test]
fn test_round_robin_sharding() {
let features = Tensor::<f32>::from_vec(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
&[10, 1],
)
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let config = ShardConfig::new(3, 0).expect("config creation should succeed");
let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
assert_eq!(sharded.len(), 4);
assert_eq!(sharded.indices(), &[0, 3, 6, 9]);
}
#[test]
fn test_contiguous_sharding() {
let features = Tensor::<f32>::from_vec(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
&[10, 1],
)
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let config = ShardConfig::new(3, 1)
.expect("test: operation should succeed")
.with_strategy(ShardStrategy::Contiguous);
let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
assert_eq!(sharded.len(), 3);
assert_eq!(sharded.indices(), &[4, 5, 6]);
}
#[test]
fn test_shuffled_sharding_deterministic() {
let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
.expect("tensor creation should succeed");
let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
.expect("tensor creation should succeed");
let dataset1 = TensorDataset::new(features.clone(), labels.clone());
let dataset2 = TensorDataset::new(features, labels);
let config1 = ShardConfig::new(4, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::ShuffledRoundRobin)
.with_seed(42);
let config2 = ShardConfig::new(4, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::ShuffledRoundRobin)
.with_seed(42);
let sharded1 = ShardedDataset::new(dataset1, config1)
.expect("sharded dataset creation should succeed");
let sharded2 = ShardedDataset::new(dataset2, config2)
.expect("sharded dataset creation should succeed");
assert_eq!(sharded1.indices(), sharded2.indices());
}
#[test]
fn test_shard_statistics() {
let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
.expect("tensor creation should succeed");
let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
.expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let config = ShardConfig::new(3, 0).expect("config creation should succeed");
let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
let stats = sharded.shard_stats();
assert_eq!(stats.total_samples, 100);
assert_eq!(stats.world_size, 3);
assert_eq!(stats.rank, 0);
assert!(stats.imbalance_ratio >= 1.0);
}
#[test]
fn test_extension_trait_round_robin() {
let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let sharded = dataset
.shard_round_robin(2, 0)
.expect("shard_round_robin should succeed");
assert_eq!(sharded.len(), 5);
}
#[test]
fn test_extension_trait_contiguous() {
let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let sharded = dataset
.shard_contiguous(2, 0)
.expect("shard_contiguous should succeed");
assert_eq!(sharded.len(), 5);
assert_eq!(sharded.indices(), &[0, 1, 2, 3, 4]);
}
#[test]
fn test_extension_trait_shuffled() {
let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let sharded = dataset
.shard_shuffled(2, 0, 42)
.expect("shard_shuffled should succeed");
assert_eq!(sharded.len(), 5);
}
#[test]
fn test_shard_access() {
let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1])
.expect("tensor creation should succeed");
let labels = Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[6])
.expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let config = ShardConfig::new(2, 0).expect("config creation should succeed");
let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
let (f0, l0) = sharded.get(0).expect("index should be in bounds");
let (f1, l1) = sharded.get(1).expect("index should be in bounds");
let (f2, l2) = sharded.get(2).expect("index should be in bounds");
assert!((f0.to_vec().expect("to_vec should succeed")[0] - 1.0).abs() < 1e-6);
assert!((l0.to_vec().expect("to_vec should succeed")[0] - 10.0).abs() < 1e-6);
assert!((f1.to_vec().expect("to_vec should succeed")[0] - 3.0).abs() < 1e-6);
assert!((l1.to_vec().expect("to_vec should succeed")[0] - 30.0).abs() < 1e-6);
assert!((f2.to_vec().expect("to_vec should succeed")[0] - 5.0).abs() < 1e-6);
assert!((l2.to_vec().expect("to_vec should succeed")[0] - 50.0).abs() < 1e-6);
}
#[test]
fn test_shard_out_of_bounds() {
let features =
Tensor::<f32>::from_vec(vec![1.0; 6], &[6, 1]).expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 6], &[6]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let sharded = dataset
.shard_round_robin(2, 0)
.expect("shard_round_robin should succeed");
assert_eq!(sharded.len(), 3);
assert!(sharded.get(3).is_err());
}
#[test]
fn test_empty_dataset_sharding() {
let features =
Tensor::<f32>::from_vec(vec![], &[0, 1]).expect("empty tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![], &[0]).expect("empty tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let sharded = dataset
.shard_round_robin(2, 0)
.expect("shard_round_robin should succeed");
assert_eq!(sharded.len(), 0);
}
#[test]
fn test_shard_statistics_balanced() {
let features = Tensor::<f32>::from_vec(vec![1.0; 12], &[12, 1])
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 12], &[12]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
let stats = sharded.shard_stats();
assert!(stats.is_balanced());
assert_eq!(stats.imbalance_ratio, 1.0);
}
#[test]
fn test_shard_statistics_report() {
let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
.expect("tensor creation should succeed");
let labels =
Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let config = ShardConfig::new(3, 0).expect("config creation should succeed");
let sharded =
ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");
let report = sharded.shard_stats().report();
assert!(report.contains("Total samples: 10"));
assert!(report.contains("World size: 3"));
assert!(report.contains("Rank: 0"));
}
#[test]
fn test_stratified_sharding() {
let features = Tensor::<f32>::from_vec(
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
&[12, 1],
)
.expect("tensor creation should succeed");
let labels = Tensor::<f32>::from_vec(
vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
&[12],
)
.expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
let data = label_tensor
.to_vec()
.map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
Ok(data[0] as usize)
};
let config = ShardConfig::new(2, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::Stratified)
.with_seed(42);
let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
.expect("stratified sharding should succeed");
assert_eq!(sharded.len(), 6);
for i in 0..sharded.len() {
let (feature, label) = sharded.get(i).expect("get should succeed");
assert!(feature.to_vec().is_ok());
assert!(label.to_vec().is_ok());
}
}
#[test]
fn test_stratified_sharding_balanced_classes() {
let features = Tensor::<f32>::from_vec(vec![1.0; 60], &[60, 1])
.expect("tensor creation should succeed");
let mut label_data = vec![0.0; 20];
label_data.extend(vec![1.0; 20]);
label_data.extend(vec![2.0; 20]);
let labels =
Tensor::<f32>::from_vec(label_data, &[60]).expect("tensor creation should succeed");
let dataset = TensorDataset::new(features, labels);
let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
let data = label_tensor
.to_vec()
.map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
Ok(data[0] as usize)
};
let config = ShardConfig::new(3, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::Stratified)
.with_seed(123);
let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
.expect("stratified sharding should succeed");
assert!(sharded.len() >= 18 && sharded.len() <= 21);
}
#[test]
fn test_stratified_sharding_deterministic() {
let features = Tensor::<f32>::from_vec(vec![1.0; 30], &[30, 1])
.expect("test: tensor creation should succeed");
let labels = Tensor::<f32>::from_vec(
vec![
0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0,
1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0,
],
&[30],
)
.expect("tensor creation should succeed");
let dataset1 = TensorDataset::new(features.clone(), labels.clone());
let dataset2 = TensorDataset::new(features, labels);
let label_extractor1 = |label_tensor: &Tensor<f32>| -> Result<usize> {
let data = label_tensor
.to_vec()
.map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
Ok(data[0] as usize)
};
let label_extractor2 = |label_tensor: &Tensor<f32>| -> Result<usize> {
let data = label_tensor
.to_vec()
.map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
Ok(data[0] as usize)
};
let config1 = ShardConfig::new(2, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::Stratified)
.with_seed(999);
let config2 = ShardConfig::new(2, 0)
.expect("config creation should succeed")
.with_strategy(ShardStrategy::Stratified)
.with_seed(999);
let sharded1 = ShardedDataset::new_stratified(dataset1, config1, label_extractor1)
.expect("stratified sharding should succeed");
let sharded2 = ShardedDataset::new_stratified(dataset2, config2, label_extractor2)
.expect("stratified sharding should succeed");
assert_eq!(sharded1.indices(), sharded2.indices());
}
}