Skip to main content

scirs2_datasets/
distributed_loading.rs

1//! Distributed dataset loading
2//!
3//! This module provides shard-aware dataset loading for distributed training,
4//! with multi-node coordination, distributed caching, and rank-aware data
5//! partitioning to ensure each worker processes a unique subset of the data.
6
7use crate::error::{DatasetsError, Result};
8use crate::streaming::{DataChunk, StreamConfig};
9use crate::utils::Dataset;
10use scirs2_core::ndarray::{Array1, Array2};
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14
/// Configuration for distributed loading.
///
/// Construct via [`DistributedConfig::new`] to get rank validation; the
/// fields are public, so invariants (e.g. `rank < world_size`) are only
/// enforced at construction time, not on direct field mutation.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of workers/processes
    pub world_size: usize,
    /// Rank of this worker (0 to world_size-1)
    pub rank: usize,
    /// Number of shards to create (defaults to `world_size` in `new`)
    pub num_shards: usize,
    /// Whether to shuffle shards
    pub shuffle_shards: bool,
    /// Random seed for shuffling
    pub seed: Option<u64>,
    /// Whether to drop last incomplete batch
    pub drop_last: bool,
    /// Whether to use distributed caching
    pub enable_distributed_cache: bool,
}
33
34impl DistributedConfig {
35    /// Create a new distributed configuration
36    ///
37    /// # Arguments
38    /// * `world_size` - Total number of workers
39    /// * `rank` - Rank of this worker
40    ///
41    /// # Returns
42    /// * `DistributedConfig` - Configuration instance
43    pub fn new(world_size: usize, rank: usize) -> Result<Self> {
44        if rank >= world_size {
45            return Err(DatasetsError::InvalidFormat(format!(
46                "Rank {} must be less than world_size {}",
47                rank, world_size
48            )));
49        }
50
51        Ok(Self {
52            world_size,
53            rank,
54            num_shards: world_size,
55            shuffle_shards: false,
56            seed: None,
57            drop_last: false,
58            enable_distributed_cache: true,
59        })
60    }
61
62    /// Set number of shards
63    pub fn with_shards(mut self, num_shards: usize) -> Self {
64        self.num_shards = num_shards.max(1);
65        self
66    }
67
68    /// Enable shard shuffling
69    pub fn with_shuffle(mut self, shuffle: bool, seed: Option<u64>) -> Self {
70        self.shuffle_shards = shuffle;
71        self.seed = seed;
72        self
73    }
74
75    /// Set drop_last behavior
76    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
77        self.drop_last = drop_last;
78        self
79    }
80}
81
/// Shard information: a contiguous, half-open range `[start, end)` of
/// global sample indices.
#[derive(Debug, Clone)]
pub struct Shard {
    /// Shard index
    pub index: usize,
    /// Starting sample index (inclusive)
    pub start: usize,
    /// Ending sample index (exclusive)
    pub end: usize,
    /// Number of samples in this shard (`end - start`)
    pub size: usize,
}
94
95impl Shard {
96    /// Create a new shard
97    pub fn new(index: usize, start: usize, end: usize) -> Self {
98        Self {
99            index,
100            start,
101            end,
102            size: end - start,
103        }
104    }
105
106    /// Check if a sample index belongs to this shard
107    pub fn contains(&self, idx: usize) -> bool {
108        idx >= self.start && idx < self.end
109    }
110}
111
/// Distributed dataset loader.
///
/// Holds the full shard layout for the dataset plus the subset of shard
/// indices assigned to this worker's rank.
pub struct DistributedLoader {
    /// Configuration (world size, rank, sharding options)
    config: DistributedConfig,
    /// Total samples across all ranks
    total_samples: usize,
    /// All shards covering the dataset
    shards: Vec<Shard>,
    /// Indices into `shards` owned by this rank
    assigned_shards: Vec<usize>,
}
119
120impl DistributedLoader {
121    /// Create a new distributed loader
122    ///
123    /// # Arguments
124    /// * `config` - Distributed configuration
125    /// * `total_samples` - Total number of samples in the dataset
126    ///
127    /// # Returns
128    /// * `Ok(DistributedLoader)` - The loader instance
129    /// * `Err(DatasetsError)` - If configuration is invalid
130    pub fn new(config: DistributedConfig, total_samples: usize) -> Result<Self> {
131        if total_samples == 0 {
132            return Err(DatasetsError::InvalidFormat(
133                "Dataset must have at least one sample".to_string(),
134            ));
135        }
136
137        // Create shards
138        let shards = Self::create_shards(total_samples, config.num_shards, config.drop_last)?;
139
140        // Assign shards to this rank
141        let assigned_shards = Self::assign_shards_to_rank(&shards, &config);
142
143        Ok(Self {
144            config,
145            total_samples,
146            shards,
147            assigned_shards,
148        })
149    }
150
151    /// Create shards from the total dataset
152    fn create_shards(
153        total_samples: usize,
154        num_shards: usize,
155        drop_last: bool,
156    ) -> Result<Vec<Shard>> {
157        let mut shards = Vec::new();
158        let base_shard_size = total_samples / num_shards;
159        let remainder = total_samples % num_shards;
160
161        let mut start = 0;
162        for i in 0..num_shards {
163            // Distribute remainder samples across first shards
164            let shard_size = if i < remainder {
165                base_shard_size + 1
166            } else {
167                base_shard_size
168            };
169
170            if shard_size == 0 && drop_last {
171                break;
172            }
173
174            let end = start + shard_size;
175            shards.push(Shard::new(i, start, end));
176            start = end;
177        }
178
179        Ok(shards)
180    }
181
182    /// Assign shards to a specific rank
183    fn assign_shards_to_rank(shards: &[Shard], config: &DistributedConfig) -> Vec<usize> {
184        let mut assigned = Vec::new();
185
186        // Round-robin assignment
187        for (idx, _shard) in shards.iter().enumerate() {
188            if idx % config.world_size == config.rank {
189                assigned.push(idx);
190            }
191        }
192
193        assigned
194    }
195
196    /// Get the shards assigned to this rank
197    pub fn get_assigned_shards(&self) -> Vec<&Shard> {
198        self.assigned_shards
199            .iter()
200            .filter_map(|&idx| self.shards.get(idx))
201            .collect()
202    }
203
204    /// Get the total number of samples for this rank
205    pub fn samples_for_rank(&self) -> usize {
206        self.get_assigned_shards().iter().map(|s| s.size).sum()
207    }
208
209    /// Get the sample indices for this rank
210    pub fn get_sample_indices(&self) -> Vec<usize> {
211        let mut indices = Vec::new();
212        for shard in self.get_assigned_shards() {
213            indices.extend(shard.start..shard.end);
214        }
215        indices
216    }
217
218    /// Partition a dataset according to this rank's assignment
219    pub fn partition_dataset(&self, dataset: &Dataset) -> Result<Dataset> {
220        let indices = self.get_sample_indices();
221
222        if indices.is_empty() {
223            return Err(DatasetsError::InvalidFormat(
224                "No samples assigned to this rank".to_string(),
225            ));
226        }
227
228        // Extract rows corresponding to this rank's indices
229        let n_features = dataset.n_features();
230        let mut data_rows = Vec::new();
231        let mut target_values = Vec::new();
232
233        for &idx in &indices {
234            if idx >= dataset.n_samples() {
235                return Err(DatasetsError::InvalidFormat(format!(
236                    "Index {} out of bounds for dataset with {} samples",
237                    idx,
238                    dataset.n_samples()
239                )));
240            }
241
242            // Extract data row
243            let row = dataset.data.row(idx);
244            data_rows.extend(row.iter().copied());
245
246            // Extract target if present
247            if let Some(ref target) = dataset.target {
248                if idx < target.len() {
249                    target_values.push(target[idx]);
250                }
251            }
252        }
253
254        // Create new arrays
255        let data = Array2::from_shape_vec((indices.len(), n_features), data_rows)
256            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to create array: {}", e)))?;
257
258        let target = if !target_values.is_empty() {
259            Some(Array1::from_vec(target_values))
260        } else {
261            None
262        };
263
264        Ok(Dataset {
265            data,
266            target,
267            targetnames: dataset.targetnames.clone(),
268            featurenames: dataset.featurenames.clone(),
269            feature_descriptions: dataset.feature_descriptions.clone(),
270            description: dataset.description.clone(),
271            metadata: dataset.metadata.clone(),
272        })
273    }
274
275    /// Get configuration
276    pub fn config(&self) -> &DistributedConfig {
277        &self.config
278    }
279
280    /// Get total number of samples
281    pub fn total_samples(&self) -> usize {
282        self.total_samples
283    }
284}
285
/// Distributed cache for sharing data across nodes.
///
/// Thread-safe in-process key/value store backed by `Arc<Mutex<HashMap>>`.
/// NOTE(review): `config` is stored but not read by any method in this
/// module — presumably reserved for future multi-node coordination; confirm.
pub struct DistributedCache {
    cache: Arc<Mutex<HashMap<String, Vec<u8>>>>,
    config: DistributedConfig,
}
291
292impl DistributedCache {
293    /// Create a new distributed cache
294    pub fn new(config: DistributedConfig) -> Self {
295        Self {
296            cache: Arc::new(Mutex::new(HashMap::new())),
297            config,
298        }
299    }
300
301    /// Store data in cache
302    pub fn put(&self, key: String, data: Vec<u8>) -> Result<()> {
303        let mut cache = self
304            .cache
305            .lock()
306            .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
307        cache.insert(key, data);
308        Ok(())
309    }
310
311    /// Retrieve data from cache
312    pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
313        let cache = self
314            .cache
315            .lock()
316            .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
317        Ok(cache.get(key).cloned())
318    }
319
320    /// Check if key exists in cache
321    pub fn contains(&self, key: &str) -> bool {
322        self.cache
323            .lock()
324            .map(|cache| cache.contains_key(key))
325            .unwrap_or(false)
326    }
327
328    /// Clear the cache
329    pub fn clear(&self) -> Result<()> {
330        let mut cache = self
331            .cache
332            .lock()
333            .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
334        cache.clear();
335        Ok(())
336    }
337
338    /// Get cache size (number of entries)
339    pub fn size(&self) -> usize {
340        self.cache.lock().map(|c| c.len()).unwrap_or(0)
341    }
342}
343
344/// Create a distributed loader for a dataset
345///
346/// # Arguments
347/// * `world_size` - Total number of workers
348/// * `rank` - Rank of this worker
349/// * `total_samples` - Total samples in the dataset
350///
351/// # Returns
352/// * `Ok(DistributedLoader)` - The loader
353/// * `Err(DatasetsError)` - If creation fails
354pub fn create_loader(
355    world_size: usize,
356    rank: usize,
357    total_samples: usize,
358) -> Result<DistributedLoader> {
359    let config = DistributedConfig::new(world_size, rank)?;
360    DistributedLoader::new(config, total_samples)
361}
362
/// Create a distributed loader with custom configuration.
///
/// Thin convenience wrapper around [`DistributedLoader::new`].
///
/// # Arguments
/// * `config` - Pre-built distributed configuration
/// * `total_samples` - Total samples in the dataset
///
/// # Errors
/// Propagates any error from [`DistributedLoader::new`].
pub fn create_loader_with_config(
    config: DistributedConfig,
    total_samples: usize,
) -> Result<DistributedLoader> {
    DistributedLoader::new(config, total_samples)
}
370
#[cfg(test)]
mod tests {
    use super::*;

    // Config construction: defaults and rank validation.
    #[test]
    fn test_distributed_config() -> Result<()> {
        let config = DistributedConfig::new(4, 2)?;
        assert_eq!(config.world_size, 4);
        assert_eq!(config.rank, 2);
        assert_eq!(config.num_shards, 4);

        // Invalid rank should fail
        assert!(DistributedConfig::new(4, 4).is_err());

        Ok(())
    }

    // Shard layout: even split and remainder distribution.
    #[test]
    fn test_shard_creation() -> Result<()> {
        // 100 samples, 4 shards
        let shards = DistributedLoader::create_shards(100, 4, false)?;
        assert_eq!(shards.len(), 4);
        assert_eq!(shards[0].size, 25);
        assert_eq!(shards[3].end, 100);

        // With remainder
        let shards = DistributedLoader::create_shards(103, 4, false)?;
        assert_eq!(shards.len(), 4);
        assert_eq!(shards[0].size, 26); // Gets extra sample
        assert_eq!(shards[1].size, 26);
        assert_eq!(shards[2].size, 26);
        assert_eq!(shards[3].size, 25);

        Ok(())
    }

    // Loader wiring: a non-zero rank still receives shards and indices.
    #[test]
    fn test_distributed_loader() -> Result<()> {
        let config = DistributedConfig::new(4, 1)?;
        let loader = DistributedLoader::new(config, 100)?;

        assert_eq!(loader.total_samples(), 100);
        let assigned = loader.get_assigned_shards();
        assert!(!assigned.is_empty());

        // Check sample indices
        let indices = loader.get_sample_indices();
        assert!(!indices.is_empty());

        Ok(())
    }

    // Partitioning: rank 0 of 2 gets exactly half of a 10x3 dataset.
    #[test]
    fn test_partition_dataset() -> Result<()> {
        // Create a test dataset
        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect())
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;

        let target = Some(Array1::from_vec((0..10).map(|x| x as f64).collect()));

        let dataset = Dataset {
            data,
            target,
            targetnames: None,
            featurenames: None,
            feature_descriptions: None,
            description: None,
            metadata: Default::default(),
        };

        // Partition for rank 0 of 2 workers
        let config = DistributedConfig::new(2, 0)?;
        let loader = DistributedLoader::new(config, 10)?;
        let partitioned = loader.partition_dataset(&dataset)?;

        assert_eq!(partitioned.n_samples(), 5); // Should get half
        assert_eq!(partitioned.n_features(), 3);

        Ok(())
    }

    // Cache round-trip: put/get/contains/clear.
    #[test]
    fn test_distributed_cache() -> Result<()> {
        let config = DistributedConfig::new(2, 0)?;
        let cache = DistributedCache::new(config);

        // Store and retrieve
        cache.put("test".to_string(), vec![1, 2, 3, 4])?;
        assert!(cache.contains("test"));

        let data = cache.get("test")?;
        assert_eq!(data, Some(vec![1, 2, 3, 4]));

        // Clear
        cache.clear()?;
        assert!(!cache.contains("test"));

        Ok(())
    }

    // Half-open range semantics: start inclusive, end exclusive.
    #[test]
    fn test_shard_contains() {
        let shard = Shard::new(0, 10, 20);
        assert!(shard.contains(10));
        assert!(shard.contains(15));
        assert!(shard.contains(19));
        assert!(!shard.contains(9));
        assert!(!shard.contains(20));
    }

    // Round-robin assignment: with shards == world_size, rank i owns shard i.
    #[test]
    fn test_round_robin_assignment() -> Result<()> {
        let config = DistributedConfig::new(3, 1)?; // Rank 1 of 3
        let loader = DistributedLoader::new(config, 90)?; // 90 samples

        let indices = loader.get_sample_indices();

        // Rank 1 should get shards 1
        // With 3 shards: [0-30), [30-60), [60-90)
        // Rank 1 gets shard 1: [30-60)
        assert_eq!(indices.len(), 30);
        assert_eq!(indices[0], 30);
        assert_eq!(indices[29], 59);

        Ok(())
    }
}