1use crate::error::{DatasetsError, Result};
8use crate::streaming::{DataChunk, StreamConfig};
9use crate::utils::Dataset;
10use scirs2_core::ndarray::{Array1, Array2};
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14
/// Configuration for distributed data loading across multiple worker processes.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of participating workers (processes).
    pub world_size: usize,
    /// Zero-based rank of this worker; validated by [`DistributedConfig::new`]
    /// to be strictly less than `world_size`.
    pub rank: usize,
    /// Number of shards the dataset is split into (defaults to `world_size`).
    pub num_shards: usize,
    /// Whether shard order should be shuffled before assignment.
    /// NOTE(review): not consulted by `DistributedLoader` in this file — confirm intended use.
    pub shuffle_shards: bool,
    /// RNG seed intended for shard shuffling.
    /// NOTE(review): also unused in this file — confirm intended use.
    pub seed: Option<u64>,
    /// When true, zero-sized trailing shards are dropped instead of emitted.
    pub drop_last: bool,
    /// Whether the in-memory distributed cache should be enabled.
    pub enable_distributed_cache: bool,
}
33
34impl DistributedConfig {
35 pub fn new(world_size: usize, rank: usize) -> Result<Self> {
44 if rank >= world_size {
45 return Err(DatasetsError::InvalidFormat(format!(
46 "Rank {} must be less than world_size {}",
47 rank, world_size
48 )));
49 }
50
51 Ok(Self {
52 world_size,
53 rank,
54 num_shards: world_size,
55 shuffle_shards: false,
56 seed: None,
57 drop_last: false,
58 enable_distributed_cache: true,
59 })
60 }
61
62 pub fn with_shards(mut self, num_shards: usize) -> Self {
64 self.num_shards = num_shards.max(1);
65 self
66 }
67
68 pub fn with_shuffle(mut self, shuffle: bool, seed: Option<u64>) -> Self {
70 self.shuffle_shards = shuffle;
71 self.seed = seed;
72 self
73 }
74
75 pub fn with_drop_last(mut self, drop_last: bool) -> Self {
77 self.drop_last = drop_last;
78 self
79 }
80}
81
/// A contiguous, half-open range `[start, end)` of sample indices.
#[derive(Debug, Clone)]
pub struct Shard {
    /// Position of this shard within the full shard list.
    pub index: usize,
    /// First sample index covered by the shard (inclusive).
    pub start: usize,
    /// One past the last sample index covered (exclusive).
    pub end: usize,
    /// Number of samples in the shard (`end - start`).
    pub size: usize,
}

impl Shard {
    /// Builds a shard covering `[start, end)`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `start > end`. Previously an inverted
    /// range would panic without context in debug and silently wrap
    /// around in `end - start` in release.
    pub fn new(index: usize, start: usize, end: usize) -> Self {
        debug_assert!(
            start <= end,
            "shard range must satisfy start <= end (got {}..{})",
            start,
            end
        );
        Self {
            index,
            start,
            end,
            size: end - start,
        }
    }

    /// Returns true if `idx` falls inside the half-open range `[start, end)`.
    pub fn contains(&self, idx: usize) -> bool {
        (self.start..self.end).contains(&idx)
    }
}
111
/// Splits a dataset into contiguous shards and resolves which shards
/// belong to this worker's rank.
pub struct DistributedLoader {
    // Distributed topology (world size, rank, shard count, ...).
    config: DistributedConfig,
    // Total number of samples in the full, unsharded dataset.
    total_samples: usize,
    // All shards covering the dataset, in index order.
    shards: Vec<Shard>,
    // Indices into `shards` owned by `config.rank`.
    assigned_shards: Vec<usize>,
}
119
120impl DistributedLoader {
121 pub fn new(config: DistributedConfig, total_samples: usize) -> Result<Self> {
131 if total_samples == 0 {
132 return Err(DatasetsError::InvalidFormat(
133 "Dataset must have at least one sample".to_string(),
134 ));
135 }
136
137 let shards = Self::create_shards(total_samples, config.num_shards, config.drop_last)?;
139
140 let assigned_shards = Self::assign_shards_to_rank(&shards, &config);
142
143 Ok(Self {
144 config,
145 total_samples,
146 shards,
147 assigned_shards,
148 })
149 }
150
151 fn create_shards(
153 total_samples: usize,
154 num_shards: usize,
155 drop_last: bool,
156 ) -> Result<Vec<Shard>> {
157 let mut shards = Vec::new();
158 let base_shard_size = total_samples / num_shards;
159 let remainder = total_samples % num_shards;
160
161 let mut start = 0;
162 for i in 0..num_shards {
163 let shard_size = if i < remainder {
165 base_shard_size + 1
166 } else {
167 base_shard_size
168 };
169
170 if shard_size == 0 && drop_last {
171 break;
172 }
173
174 let end = start + shard_size;
175 shards.push(Shard::new(i, start, end));
176 start = end;
177 }
178
179 Ok(shards)
180 }
181
182 fn assign_shards_to_rank(shards: &[Shard], config: &DistributedConfig) -> Vec<usize> {
184 let mut assigned = Vec::new();
185
186 for (idx, _shard) in shards.iter().enumerate() {
188 if idx % config.world_size == config.rank {
189 assigned.push(idx);
190 }
191 }
192
193 assigned
194 }
195
196 pub fn get_assigned_shards(&self) -> Vec<&Shard> {
198 self.assigned_shards
199 .iter()
200 .filter_map(|&idx| self.shards.get(idx))
201 .collect()
202 }
203
204 pub fn samples_for_rank(&self) -> usize {
206 self.get_assigned_shards().iter().map(|s| s.size).sum()
207 }
208
209 pub fn get_sample_indices(&self) -> Vec<usize> {
211 let mut indices = Vec::new();
212 for shard in self.get_assigned_shards() {
213 indices.extend(shard.start..shard.end);
214 }
215 indices
216 }
217
218 pub fn partition_dataset(&self, dataset: &Dataset) -> Result<Dataset> {
220 let indices = self.get_sample_indices();
221
222 if indices.is_empty() {
223 return Err(DatasetsError::InvalidFormat(
224 "No samples assigned to this rank".to_string(),
225 ));
226 }
227
228 let n_features = dataset.n_features();
230 let mut data_rows = Vec::new();
231 let mut target_values = Vec::new();
232
233 for &idx in &indices {
234 if idx >= dataset.n_samples() {
235 return Err(DatasetsError::InvalidFormat(format!(
236 "Index {} out of bounds for dataset with {} samples",
237 idx,
238 dataset.n_samples()
239 )));
240 }
241
242 let row = dataset.data.row(idx);
244 data_rows.extend(row.iter().copied());
245
246 if let Some(ref target) = dataset.target {
248 if idx < target.len() {
249 target_values.push(target[idx]);
250 }
251 }
252 }
253
254 let data = Array2::from_shape_vec((indices.len(), n_features), data_rows)
256 .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to create array: {}", e)))?;
257
258 let target = if !target_values.is_empty() {
259 Some(Array1::from_vec(target_values))
260 } else {
261 None
262 };
263
264 Ok(Dataset {
265 data,
266 target,
267 targetnames: dataset.targetnames.clone(),
268 featurenames: dataset.featurenames.clone(),
269 feature_descriptions: dataset.feature_descriptions.clone(),
270 description: dataset.description.clone(),
271 metadata: dataset.metadata.clone(),
272 })
273 }
274
275 pub fn config(&self) -> &DistributedConfig {
277 &self.config
278 }
279
280 pub fn total_samples(&self) -> usize {
282 self.total_samples
283 }
284}
285
/// A thread-safe, in-memory byte cache keyed by string.
pub struct DistributedCache {
    // Cached blobs guarded by a mutex; the Arc allows cheap handle sharing.
    cache: Arc<Mutex<HashMap<String, Vec<u8>>>>,
    // Topology this cache was created for.
    // NOTE(review): stored but never read by any method in this file
    // (including `enable_distributed_cache`) — confirm intended use.
    config: DistributedConfig,
}
291
292impl DistributedCache {
293 pub fn new(config: DistributedConfig) -> Self {
295 Self {
296 cache: Arc::new(Mutex::new(HashMap::new())),
297 config,
298 }
299 }
300
301 pub fn put(&self, key: String, data: Vec<u8>) -> Result<()> {
303 let mut cache = self
304 .cache
305 .lock()
306 .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
307 cache.insert(key, data);
308 Ok(())
309 }
310
311 pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
313 let cache = self
314 .cache
315 .lock()
316 .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
317 Ok(cache.get(key).cloned())
318 }
319
320 pub fn contains(&self, key: &str) -> bool {
322 self.cache
323 .lock()
324 .map(|cache| cache.contains_key(key))
325 .unwrap_or(false)
326 }
327
328 pub fn clear(&self) -> Result<()> {
330 let mut cache = self
331 .cache
332 .lock()
333 .map_err(|e| DatasetsError::CacheError(format!("Lock error: {}", e)))?;
334 cache.clear();
335 Ok(())
336 }
337
338 pub fn size(&self) -> usize {
340 self.cache.lock().map(|c| c.len()).unwrap_or(0)
341 }
342}
343
344pub fn create_loader(
355 world_size: usize,
356 rank: usize,
357 total_samples: usize,
358) -> Result<DistributedLoader> {
359 let config = DistributedConfig::new(world_size, rank)?;
360 DistributedLoader::new(config, total_samples)
361}
362
363pub fn create_loader_with_config(
365 config: DistributedConfig,
366 total_samples: usize,
367) -> Result<DistributedLoader> {
368 DistributedLoader::new(config, total_samples)
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_config() -> Result<()> {
        let cfg = DistributedConfig::new(4, 2)?;
        assert_eq!(cfg.world_size, 4);
        assert_eq!(cfg.rank, 2);
        assert_eq!(cfg.num_shards, 4);

        // A rank equal to world_size is out of range.
        assert!(DistributedConfig::new(4, 4).is_err());

        Ok(())
    }

    #[test]
    fn test_shard_creation() -> Result<()> {
        // Evenly divisible: four shards of 25 samples each.
        let even = DistributedLoader::create_shards(100, 4, false)?;
        assert_eq!(even.len(), 4);
        assert_eq!(even[0].size, 25);
        assert_eq!(even[3].end, 100);

        // 103 = 26 + 26 + 26 + 25: the remainder lands on leading shards.
        let uneven = DistributedLoader::create_shards(103, 4, false)?;
        assert_eq!(uneven.len(), 4);
        assert_eq!(uneven[0].size, 26);
        assert_eq!(uneven[1].size, 26);
        assert_eq!(uneven[2].size, 26);
        assert_eq!(uneven[3].size, 25);

        Ok(())
    }

    #[test]
    fn test_distributed_loader() -> Result<()> {
        let loader = DistributedLoader::new(DistributedConfig::new(4, 1)?, 100)?;

        assert_eq!(loader.total_samples(), 100);
        assert!(!loader.get_assigned_shards().is_empty());
        assert!(!loader.get_sample_indices().is_empty());

        Ok(())
    }

    #[test]
    fn test_partition_dataset() -> Result<()> {
        // 10 samples x 3 features holding 0..30, with targets 0..10.
        let values: Vec<f64> = (0..30).map(f64::from).collect();
        let data = Array2::from_shape_vec((10, 3), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;
        let target = Some(Array1::from_vec((0..10).map(f64::from).collect()));

        let dataset = Dataset {
            data,
            target,
            targetnames: None,
            featurenames: None,
            feature_descriptions: None,
            description: None,
            metadata: Default::default(),
        };

        let loader = DistributedLoader::new(DistributedConfig::new(2, 0)?, 10)?;
        let partitioned = loader.partition_dataset(&dataset)?;

        // Rank 0 of 2 owns half the samples; the feature count is untouched.
        assert_eq!(partitioned.n_samples(), 5);
        assert_eq!(partitioned.n_features(), 3);

        Ok(())
    }

    #[test]
    fn test_distributed_cache() -> Result<()> {
        let cache = DistributedCache::new(DistributedConfig::new(2, 0)?);

        cache.put("test".to_string(), vec![1, 2, 3, 4])?;
        assert!(cache.contains("test"));
        assert_eq!(cache.get("test")?, Some(vec![1, 2, 3, 4]));

        cache.clear()?;
        assert!(!cache.contains("test"));

        Ok(())
    }

    #[test]
    fn test_shard_contains() {
        let shard = Shard::new(0, 10, 20);
        // Inclusive lower bound, exclusive upper bound.
        assert!(shard.contains(10));
        assert!(shard.contains(15));
        assert!(shard.contains(19));
        assert!(!shard.contains(9));
        assert!(!shard.contains(20));
    }

    #[test]
    fn test_round_robin_assignment() -> Result<()> {
        // 90 samples over 3 ranks -> shards [0,30), [30,60), [60,90);
        // rank 1 receives exactly the middle shard.
        let loader = DistributedLoader::new(DistributedConfig::new(3, 1)?, 90)?;
        let indices = loader.get_sample_indices();

        assert_eq!(indices.len(), 30);
        assert_eq!(indices[0], 30);
        assert_eq!(indices[29], 59);

        Ok(())
    }
}