use crate::error::{DatasetsError, Result};
use crate::streaming::{DataChunk, StreamConfig};
use crate::utils::Dataset;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Configuration describing one worker's position and sharding policy in a
/// distributed data-loading job.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
/// Total number of worker processes participating.
pub world_size: usize,
/// This worker's index; always strictly less than `world_size`.
pub rank: usize,
/// Number of shards the dataset is split into (defaults to `world_size`).
pub num_shards: usize,
/// Whether shard order should be shuffled.
/// NOTE(review): not consumed by the sharding logic in this file — confirm
/// whether a caller applies it.
pub shuffle_shards: bool,
/// Optional RNG seed for shuffling (presumably used with `shuffle_shards`;
/// not read anywhere in this file — TODO confirm).
pub seed: Option<u64>,
/// When true, trailing empty shards are dropped during shard creation.
pub drop_last: bool,
/// Whether the distributed cache should be enabled.
/// NOTE(review): not read by any code in this file — verify downstream use.
pub enable_distributed_cache: bool,
}
impl DistributedConfig {
    /// Creates a configuration for worker `rank` in a group of `world_size`
    /// workers.
    ///
    /// Defaults: one shard per worker, no shuffling, no seed, keep trailing
    /// empty shards, distributed cache enabled.
    ///
    /// # Errors
    /// Fails unless `rank < world_size` (this also rejects `world_size == 0`).
    pub fn new(world_size: usize, rank: usize) -> Result<Self> {
        if rank < world_size {
            Ok(Self {
                world_size,
                rank,
                num_shards: world_size,
                shuffle_shards: false,
                seed: None,
                drop_last: false,
                enable_distributed_cache: true,
            })
        } else {
            Err(DatasetsError::InvalidFormat(format!(
                "Rank {rank} must be less than world_size {world_size}"
            )))
        }
    }

    /// Overrides the shard count; a value of zero is clamped to one.
    pub fn with_shards(mut self, num_shards: usize) -> Self {
        self.num_shards = if num_shards == 0 { 1 } else { num_shards };
        self
    }

    /// Sets the shard-shuffling flag and the optional RNG seed.
    pub fn with_shuffle(mut self, shuffle: bool, seed: Option<u64>) -> Self {
        self.shuffle_shards = shuffle;
        self.seed = seed;
        self
    }

    /// Sets whether trailing empty shards are dropped during sharding.
    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }
}
/// A contiguous half-open range `[start, end)` of sample indices.
#[derive(Debug, Clone)]
pub struct Shard {
    /// Position of this shard in the overall shard list.
    pub index: usize,
    /// First sample index covered (inclusive).
    pub start: usize,
    /// One past the last sample index covered (exclusive).
    pub end: usize,
    /// Number of samples in the shard, i.e. `end - start`.
    pub size: usize,
}

impl Shard {
    /// Builds a shard over `[start, end)`; `size` is derived from the bounds.
    pub fn new(index: usize, start: usize, end: usize) -> Self {
        let size = end - start;
        Self {
            index,
            start,
            end,
            size,
        }
    }

    /// Returns `true` when `idx` falls inside this shard's half-open range.
    pub fn contains(&self, idx: usize) -> bool {
        (self.start..self.end).contains(&idx)
    }
}
/// Partitions a dataset's sample indices across distributed workers using
/// contiguous shards assigned round-robin by rank.
pub struct DistributedLoader {
/// Sharding configuration for this worker.
config: DistributedConfig,
/// Total number of samples in the underlying dataset.
total_samples: usize,
/// All shards covering the full dataset.
shards: Vec<Shard>,
/// Indices into `shards` owned by this worker's rank.
assigned_shards: Vec<usize>,
}
impl DistributedLoader {
    /// Builds a loader that splits `total_samples` indices into shards and
    /// assigns a subset of those shards to this worker's rank.
    ///
    /// NOTE(review): `config.shuffle_shards` / `config.seed` are not applied
    /// anywhere in this file — assignment is deterministic round-robin.
    /// Confirm whether shuffling is implemented by a caller.
    ///
    /// # Errors
    /// Returns `DatasetsError::InvalidFormat` when `total_samples` is zero or
    /// when `config.num_shards` is zero.
    pub fn new(config: DistributedConfig, total_samples: usize) -> Result<Self> {
        if total_samples == 0 {
            return Err(DatasetsError::InvalidFormat(
                "Dataset must have at least one sample".to_string(),
            ));
        }
        let shards = Self::create_shards(total_samples, config.num_shards, config.drop_last)?;
        let assigned_shards = Self::assign_shards_to_rank(&shards, &config);
        Ok(Self {
            config,
            total_samples,
            shards,
            assigned_shards,
        })
    }

    /// Splits `total_samples` into `num_shards` contiguous shards whose sizes
    /// differ by at most one; the first `total_samples % num_shards` shards
    /// take the extra sample. With `drop_last`, trailing empty shards are
    /// omitted instead of being emitted with `size == 0`.
    ///
    /// # Errors
    /// Returns an error when `num_shards == 0`. (Previously this divided by
    /// zero and panicked, despite the `Result` return type.)
    fn create_shards(
        total_samples: usize,
        num_shards: usize,
        drop_last: bool,
    ) -> Result<Vec<Shard>> {
        if num_shards == 0 {
            return Err(DatasetsError::InvalidFormat(
                "num_shards must be at least 1".to_string(),
            ));
        }
        let base_shard_size = total_samples / num_shards;
        let remainder = total_samples % num_shards;
        let mut shards = Vec::with_capacity(num_shards);
        let mut start = 0;
        for i in 0..num_shards {
            // Leading shards absorb the remainder, one extra sample each.
            let shard_size = if i < remainder {
                base_shard_size + 1
            } else {
                base_shard_size
            };
            // Shard sizes are non-increasing, so the first empty shard means
            // every remaining shard would be empty too.
            if shard_size == 0 && drop_last {
                break;
            }
            let end = start + shard_size;
            shards.push(Shard::new(i, start, end));
            start = end;
        }
        Ok(shards)
    }

    /// Round-robin assignment: shard `i` belongs to rank `i % world_size`.
    fn assign_shards_to_rank(shards: &[Shard], config: &DistributedConfig) -> Vec<usize> {
        (0..shards.len())
            .filter(|&idx| idx % config.world_size == config.rank)
            .collect()
    }

    /// Returns references to the shards owned by this worker's rank.
    pub fn get_assigned_shards(&self) -> Vec<&Shard> {
        self.assigned_shards
            .iter()
            .filter_map(|&idx| self.shards.get(idx))
            .collect()
    }

    /// Total number of samples this rank will process.
    pub fn samples_for_rank(&self) -> usize {
        self.get_assigned_shards().iter().map(|s| s.size).sum()
    }

    /// Flat list of every sample index assigned to this rank, in shard order.
    pub fn get_sample_indices(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.samples_for_rank());
        for shard in self.get_assigned_shards() {
            indices.extend(shard.start..shard.end);
        }
        indices
    }

    /// Copies the rows (and matching target entries) assigned to this rank
    /// into a new [`Dataset`], cloning all descriptive metadata.
    ///
    /// # Errors
    /// Fails when no samples are assigned to this rank, when an assigned
    /// index exceeds `dataset.n_samples()`, or when the dataset's target is
    /// too short for an assigned index. The last case was previously skipped
    /// silently, producing a partition whose target length did not match its
    /// row count (misaligned labels).
    pub fn partition_dataset(&self, dataset: &Dataset) -> Result<Dataset> {
        let indices = self.get_sample_indices();
        if indices.is_empty() {
            return Err(DatasetsError::InvalidFormat(
                "No samples assigned to this rank".to_string(),
            ));
        }
        let n_features = dataset.n_features();
        // Preallocate: sizes are known exactly up front.
        let mut data_rows = Vec::with_capacity(indices.len() * n_features);
        let mut target_values = Vec::with_capacity(indices.len());
        for &idx in &indices {
            if idx >= dataset.n_samples() {
                return Err(DatasetsError::InvalidFormat(format!(
                    "Index {} out of bounds for dataset with {} samples",
                    idx,
                    dataset.n_samples()
                )));
            }
            data_rows.extend(dataset.data.row(idx).iter().copied());
            if let Some(ref target) = dataset.target {
                // A target shorter than the data is a malformed dataset;
                // report it instead of silently dropping label entries.
                if idx >= target.len() {
                    return Err(DatasetsError::InvalidFormat(format!(
                        "Index {} out of bounds for target with {} entries",
                        idx,
                        target.len()
                    )));
                }
                target_values.push(target[idx]);
            }
        }
        let data = Array2::from_shape_vec((indices.len(), n_features), data_rows)
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to create array: {}", e)))?;
        // Non-empty exactly when the source dataset carried a target.
        let target = if target_values.is_empty() {
            None
        } else {
            Some(Array1::from_vec(target_values))
        };
        Ok(Dataset {
            data,
            target,
            targetnames: dataset.targetnames.clone(),
            featurenames: dataset.featurenames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: dataset.description.clone(),
            metadata: dataset.metadata.clone(),
        })
    }

    /// The configuration this loader was built with.
    pub fn config(&self) -> &DistributedConfig {
        &self.config
    }

    /// Total sample count of the full (unpartitioned) dataset.
    pub fn total_samples(&self) -> usize {
        self.total_samples
    }
}
/// Thread-safe in-memory byte cache keyed by string.
pub struct DistributedCache {
/// Keyed binary blobs; the mutex makes access safe across threads, and the
/// `Arc` allows the cache to be shared.
cache: Arc<Mutex<HashMap<String, Vec<u8>>>>,
/// NOTE(review): stored at construction but never read by any method in
/// this file — confirm whether downstream code inspects it (e.g. the
/// `enable_distributed_cache` flag).
config: DistributedConfig,
}
impl DistributedCache {
pub fn new(config: DistributedConfig) -> Self {
Self {
cache: Arc::new(Mutex::new(HashMap::new())),
config,
}
}
pub fn put(&self, key: String, data: Vec<u8>) -> Result<()> {
let mut cache = self
.cache
.lock()
.map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
cache.insert(key, data);
Ok(())
}
pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
let cache = self
.cache
.lock()
.map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
Ok(cache.get(key).cloned())
}
pub fn contains(&self, key: &str) -> bool {
self.cache
.lock()
.map(|cache| cache.contains_key(key))
.unwrap_or(false)
}
pub fn clear(&self) -> Result<()> {
let mut cache = self
.cache
.lock()
.map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
cache.clear();
Ok(())
}
pub fn size(&self) -> usize {
self.cache.lock().map(|c| c.len()).unwrap_or(0)
}
}
/// Convenience constructor: builds a [`DistributedLoader`] for worker `rank`
/// out of `world_size` workers over a dataset of `total_samples` samples,
/// using default [`DistributedConfig`] settings.
///
/// # Errors
/// Propagates validation failures from [`DistributedConfig::new`] and
/// [`DistributedLoader::new`].
pub fn create_loader(
    world_size: usize,
    rank: usize,
    total_samples: usize,
) -> Result<DistributedLoader> {
    DistributedLoader::new(DistributedConfig::new(world_size, rank)?, total_samples)
}
/// Builds a [`DistributedLoader`] from an explicit, pre-built configuration
/// (e.g. one customized via the `with_*` builder methods).
///
/// # Errors
/// Propagates validation failures from [`DistributedLoader::new`].
pub fn create_loader_with_config(
    config: DistributedConfig,
    total_samples: usize,
) -> Result<DistributedLoader> {
    DistributedLoader::new(config, total_samples)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_config() -> Result<()> {
        let cfg = DistributedConfig::new(4, 2)?;
        assert_eq!(cfg.world_size, 4);
        assert_eq!(cfg.rank, 2);
        assert_eq!(cfg.num_shards, 4);
        // A rank equal to world_size is out of range.
        assert!(DistributedConfig::new(4, 4).is_err());
        Ok(())
    }

    #[test]
    fn test_shard_creation() -> Result<()> {
        // Evenly divisible: four shards of 25 covering [0, 100).
        let even = DistributedLoader::create_shards(100, 4, false)?;
        assert_eq!(even.len(), 4);
        assert_eq!(even[0].size, 25);
        assert_eq!(even[3].end, 100);

        // 103 = 26 + 26 + 26 + 25: the remainder goes to the leading shards.
        let uneven = DistributedLoader::create_shards(103, 4, false)?;
        assert_eq!(uneven.len(), 4);
        assert_eq!(uneven[0].size, 26);
        assert_eq!(uneven[1].size, 26);
        assert_eq!(uneven[2].size, 26);
        assert_eq!(uneven[3].size, 25);
        Ok(())
    }

    #[test]
    fn test_distributed_loader() -> Result<()> {
        let loader = DistributedLoader::new(DistributedConfig::new(4, 1)?, 100)?;
        assert_eq!(loader.total_samples(), 100);
        assert!(!loader.get_assigned_shards().is_empty());
        assert!(!loader.get_sample_indices().is_empty());
        Ok(())
    }

    #[test]
    fn test_partition_dataset() -> Result<()> {
        // 10 rows x 3 features, values 0..30, with targets 0..10.
        let data = Array2::from_shape_vec((10, 3), (0..30).map(f64::from).collect())
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;
        let dataset = Dataset {
            data,
            target: Some(Array1::from_vec((0..10).map(f64::from).collect())),
            targetnames: None,
            featurenames: None,
            feature_descriptions: None,
            description: None,
            metadata: Default::default(),
        };
        // Two ranks over ten samples: rank 0 owns the first five rows.
        let loader = DistributedLoader::new(DistributedConfig::new(2, 0)?, 10)?;
        let part = loader.partition_dataset(&dataset)?;
        assert_eq!(part.n_samples(), 5);
        assert_eq!(part.n_features(), 3);
        Ok(())
    }

    #[test]
    fn test_distributed_cache() -> Result<()> {
        let cache = DistributedCache::new(DistributedConfig::new(2, 0)?);
        cache.put("test".to_string(), vec![1, 2, 3, 4])?;
        assert!(cache.contains("test"));
        assert_eq!(cache.get("test")?, Some(vec![1, 2, 3, 4]));
        cache.clear()?;
        assert!(!cache.contains("test"));
        Ok(())
    }

    #[test]
    fn test_shard_contains() {
        let shard = Shard::new(0, 10, 20);
        // Half-open range: start inclusive, end exclusive.
        assert!(shard.contains(10));
        assert!(shard.contains(15));
        assert!(shard.contains(19));
        assert!(!shard.contains(9));
        assert!(!shard.contains(20));
    }

    #[test]
    fn test_round_robin_assignment() -> Result<()> {
        // 90 samples over 3 ranks -> 3 shards of 30; rank 1 owns shard 1,
        // i.e. sample indices 30..60.
        let loader = DistributedLoader::new(DistributedConfig::new(3, 1)?, 90)?;
        let indices = loader.get_sample_indices();
        assert_eq!(indices.len(), 30);
        assert_eq!(indices[0], 30);
        assert_eq!(indices[29], 59);
        Ok(())
    }
}