#![allow(unused_imports, unused_variables, dead_code)]

//! Distributed data loading: an enhanced distributed sampler with
//! cross-node sample fetching over TCP, collective operations,
//! remote-sample caching, and optional RDMA support.

use crate::{
    dataloader::{BatchResult, DistributedSampler, Sampler},
    DataLoader, DataLoaderConfig, Dataset,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufReader, BufWriter, Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tenflowers_core::{Device, Result, Tensor, TensorError};

/// Configuration for distributed data loading across a cluster of nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedLoadingConfig {
    /// Total number of nodes participating in loading.
    pub world_size: usize,
    /// Rank of this node within the cluster.
    pub rank: usize,
    /// Address of the master (rank-0) node.
    pub master_addr: String,
    /// Port on which the master node listens.
    pub master_port: u16,
    /// Whether to use RDMA for data transfers.
    pub enable_rdma: bool,
    /// RDMA device name, if RDMA is enabled.
    pub rdma_device: Option<String>,
    /// Timeout applied to network operations.
    pub network_timeout: Duration,
    /// Whether to compress payloads before sending.
    pub enable_compression: bool,
    /// Batch size used for collective operations.
    pub collective_batch_size: usize,
    /// Number of worker threads for network I/O.
    pub network_workers: usize,
    /// Whether to prefetch samples from remote nodes.
    pub enable_remote_prefetch: bool,
    /// Number of remote samples to prefetch.
    pub remote_prefetch_size: usize,
}

impl Default for DistributedLoadingConfig {
    fn default() -> Self {
        Self {
            world_size: 1,
            rank: 0,
            master_addr: "127.0.0.1".to_string(),
            master_port: 29500,
            enable_rdma: false,
            rdma_device: None,
            network_timeout: Duration::from_secs(30),
            enable_compression: false,
            collective_batch_size: 32,
            network_workers: 4,
            enable_remote_prefetch: true,
            remote_prefetch_size: 64,
        }
    }
}
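
// Example (sketch): deployments override the defaults with struct-update
// syntax; the values below are illustrative, not recommended settings.
//
//     let config = DistributedLoadingConfig {
//         world_size: 4,
//         rank: 0,
//         master_addr: "10.0.0.1".to_string(),
//         enable_compression: true,
//         ..Default::default()
//     };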

/// Identity and capabilities of a node in the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    pub rank: usize,
    pub addr: SocketAddr,
    pub device_capabilities: Vec<String>,
    pub rdma_enabled: bool,
    pub rdma_device: Option<String>,
}

/// Messages exchanged between nodes during distributed loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedMessage {
    /// Initial handshake carrying the sender's node information.
    Handshake { node_info: NodeInfo },
    /// Request for the samples at the given indices.
    DataRequest {
        indices: Vec<usize>,
        requestor_rank: usize,
        request_id: u64,
    },
    /// Response carrying serialized, possibly compressed, sample data.
    DataResponse {
        data: Vec<u8>,
        request_id: u64,
        compressed: bool,
    },
    /// A collective operation spanning the cluster.
    CollectiveOp {
        op_type: CollectiveOpType,
        op_id: u64,
        data: Option<Vec<u8>>,
    },
    /// Liveness probe.
    Heartbeat { timestamp: u64 },
    /// Error report.
    Error { message: String },
}

/// Collective operations supported across the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CollectiveOpType {
    /// Synchronize all nodes at an epoch boundary.
    EpochSync { epoch: usize },
    /// Agree on a shared shuffle seed.
    ShuffleSync { seed: u64 },
    /// Gather loading statistics from every node.
    StatisticsGather,
    /// Broadcast configuration from the master node.
    ConfigBroadcast,
    /// Block until all nodes reach this point.
    Barrier,
    /// Broadcast arbitrary data from the master node.
    Broadcast,
}

/// Statistics collected while loading data across the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedLoadingStats {
    pub local_samples_loaded: u64,
    pub remote_samples_loaded: u64,
    pub network_bytes_sent: u64,
    pub network_bytes_received: u64,
    pub average_network_latency_ms: u64,
    pub cache_hit_rate: f64,
    pub rdma_transfers: u64,
    pub collective_operations: u64,
}

impl Default for DistributedLoadingStats {
    fn default() -> Self {
        Self {
            local_samples_loaded: 0,
            remote_samples_loaded: 0,
            network_bytes_sent: 0,
            network_bytes_received: 0,
            average_network_latency_ms: 0,
            cache_hit_rate: 0.0,
            rdma_transfers: 0,
            collective_operations: 0,
        }
    }
}

/// A distributed sampler with cross-node communication, remote-sample
/// caching, and optional RDMA support layered on a `DistributedSampler`.
pub struct EnhancedDistributedSampler {
    /// Underlying sampler that partitions indices across replicas.
    base_sampler: DistributedSampler,
    /// Distributed loading configuration.
    config: DistributedLoadingConfig,
    /// Manages connections and messaging with other nodes.
    comm_manager: Arc<Mutex<CommunicationManager>>,
    /// Shared loading statistics.
    stats: Arc<RwLock<DistributedLoadingStats>>,
    /// Cache of serialized samples fetched from remote nodes.
    sample_cache: Arc<Mutex<HashMap<usize, CachedSample>>>,
    /// RDMA context, present only when RDMA is enabled.
    rdma_context: Option<Arc<Mutex<RdmaContext>>>,
}

/// A serialized sample cached after being fetched from a remote node.
#[derive(Debug, Clone)]
struct CachedSample {
    data: Vec<u8>,
    timestamp: Instant,
    access_count: u64,
}

/// Placeholder RDMA context tracking registered memory regions.
#[derive(Debug)]
struct RdmaContext {
    device_name: String,
    initialized: bool,
    memory_regions: HashMap<String, RdmaMemoryRegion>,
}

/// A registered RDMA memory region.
#[derive(Debug)]
struct RdmaMemoryRegion {
    addr: usize,
    size: usize,
}

/// Manages TCP connections and message exchange between cluster nodes.
pub struct CommunicationManager {
    node_info: NodeInfo,
    cluster_nodes: HashMap<usize, NodeInfo>,
    connections: HashMap<usize, TcpStream>,
    listener: Option<TcpListener>,
    config: DistributedLoadingConfig,
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        String,
        Box<dyn Fn(&DistributedMessage) -> Result<Option<DistributedMessage>> + Send + Sync>,
    >,
}

impl EnhancedDistributedSampler {
    /// Creates a new enhanced distributed sampler for the given replica layout.
    pub fn new(num_replicas: usize, rank: usize, config: DistributedLoadingConfig) -> Result<Self> {
        let base_sampler = DistributedSampler::new(num_replicas, rank)?;

        let node_info = NodeInfo {
            rank,
            // The real address is assigned once the listener is bound.
            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
            device_capabilities: Self::detect_devices_as_strings(),
            rdma_enabled: config.enable_rdma,
            rdma_device: config.rdma_device.clone(),
        };

        let comm_manager = Arc::new(Mutex::new(CommunicationManager::new(
            node_info,
            config.clone(),
        )?));

        // Set up an RDMA context only when RDMA is enabled.
        let rdma_context = if config.enable_rdma {
            Some(Arc::new(Mutex::new(RdmaContext::new(
                config.rdma_device.as_ref(),
            )?)))
        } else {
            None
        };

        Ok(Self {
            base_sampler,
            config,
            comm_manager,
            stats: Arc::new(RwLock::new(DistributedLoadingStats::default())),
            sample_cache: Arc::new(Mutex::new(HashMap::new())),
            rdma_context,
        })
    }

    /// Brings up cluster membership, peer connections, RDMA, and workers.
    pub fn initialize(&mut self) -> Result<()> {
        // Register this node with the master.
        self.register_with_master()?;

        // Discover the other nodes in the cluster.
        self.discover_cluster_nodes()?;

        // Open connections to peers.
        self.establish_connections()?;

        // Initialize RDMA if enabled.
        if let Some(rdma_context) = &self.rdma_context {
            let mut ctx = rdma_context.lock().expect("lock should not be poisoned");
            ctx.initialize()?;
        }

        // Spin up the network worker threads.
        self.start_network_workers()?;

        Ok(())
    }

    /// Samples indices via the base sampler, then applies cluster-wide
    /// shuffle coordination and load balancing.
    pub fn sample_indices_distributed(
        &self,
        dataset_len: usize,
    ) -> Result<Box<dyn Iterator<Item = usize> + Send>> {
        let mut base_indices: Vec<usize> = self.base_sampler.sample_indices(dataset_len).collect();

        // Coordinate shuffling so every rank derives the same order.
        if self.base_sampler.is_random() {
            self.coordinate_shuffle(&mut base_indices)?;
        }

        let enhanced_indices = self.apply_load_balancing(base_indices)?;

        Ok(Box::new(enhanced_indices.into_iter()))
    }

    /// Loads a batch, reading locally owned samples directly and requesting
    /// the remainder from the owning remote nodes.
    pub fn load_batch_distributed<T, D>(
        &self,
        dataset: &D,
        indices: &[usize],
    ) -> Result<BatchResult<T>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + bytemuck::Zeroable
            + serde::Serialize
            + for<'de> serde::Deserialize<'de>
            + scirs2_core::numeric::Zero,
        D: Dataset<T> + Send + Sync,
    {
        let mut local_indices = Vec::new();
        let mut remote_requests = HashMap::new();

        // Partition indices into local ones and per-rank remote requests.
        for &index in indices {
            if self.is_local_index(index, dataset.len()) {
                local_indices.push(index);
            } else {
                let owner_rank = self.get_index_owner(index, dataset.len());
                remote_requests
                    .entry(owner_rank)
                    .or_insert_with(Vec::new)
                    .push(index);
            }
        }

        // Load local samples directly from the dataset.
        let mut batch_data = Vec::new();
        for &index in &local_indices {
            let (features, labels) = dataset.get(index)?;
            batch_data.push((features, labels));
        }

        // Fetch the remaining samples from their owning nodes.
        for (remote_rank, remote_indices) in remote_requests {
            let remote_data = self.fetch_remote_data_sync::<T>(remote_rank, &remote_indices)?;
            batch_data.extend(remote_data);
        }

        // Update the loading statistics.
        {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.local_samples_loaded += local_indices.len() as u64;
            stats.remote_samples_loaded += (indices.len() - local_indices.len()) as u64;
        }

        Ok(BatchResult::Samples(batch_data))
    }

    /// Executes a collective operation across the cluster, returning the
    /// aggregated payload when the operation produces one.
    pub fn collective_operation(
        &self,
        op_type: CollectiveOpType,
        data: Option<Vec<u8>>,
    ) -> Result<Option<Vec<u8>>> {
        let op_id = self.generate_operation_id();
        let message = DistributedMessage::CollectiveOp {
            op_type: op_type.clone(),
            op_id,
            data,
        };

        let comm_manager = self
            .comm_manager
            .lock()
            .expect("lock should not be poisoned");
        let results = comm_manager.broadcast_message(&message)?;

        // Dispatch on the operation type to produce the final result.
        match op_type {
            CollectiveOpType::EpochSync { epoch } => {
                self.synchronize_epoch(epoch)?;
                Ok(None)
            }
            CollectiveOpType::ShuffleSync { seed } => {
                self.coordinate_shuffle_seed(seed)?;
                Ok(None)
            }
            CollectiveOpType::StatisticsGather => {
                let aggregated_stats = self.aggregate_statistics(results)?;
                let serialized =
                    oxicode::serde::encode_to_vec(&aggregated_stats, oxicode::config::standard())
                        .map_err(|e| {
                            TensorError::invalid_argument(format!("Serialization error: {e}"))
                        })?;
                Ok(Some(serialized))
            }
            CollectiveOpType::ConfigBroadcast => {
                // Placeholder: configuration broadcast is not yet implemented.
                Ok(None)
            }
            CollectiveOpType::Barrier => {
                // Placeholder: barrier semantics are not yet implemented.
                Ok(None)
            }
            CollectiveOpType::Broadcast => {
                // Broadcast payloads carry no aggregated result.
                Ok(None)
            }
        }
    }

    /// Returns a snapshot of the current loading statistics.
    pub fn get_statistics(&self) -> DistributedLoadingStats {
        self.stats
            .read()
            .expect("read lock should not be poisoned")
            .clone()
    }

    /// Shuts down connections, RDMA resources, and the sample cache.
    pub fn shutdown(&mut self) -> Result<()> {
        // Close all network connections.
        {
            let mut comm_manager = self
                .comm_manager
                .lock()
                .expect("lock should not be poisoned");
            comm_manager.shutdown()?;
        }

        // Release RDMA resources, if any.
        if let Some(rdma_context) = &self.rdma_context {
            let mut ctx = rdma_context.lock().expect("lock should not be poisoned");
            ctx.cleanup()?;
        }

        // Drop all cached samples.
        {
            let mut cache = self
                .sample_cache
                .lock()
                .expect("lock should not be poisoned");
            cache.clear();
        }

        Ok(())
    }

    /// Detects compute devices available on this node.
    fn detect_devices() -> Vec<Device> {
        #[cfg_attr(not(feature = "gpu"), allow(unused_mut))]
        let mut devices = vec![Device::Cpu];

        #[cfg(feature = "gpu")]
        {
            // Probe for GPUs only when CUDA devices are visible.
            if std::env::var("CUDA_VISIBLE_DEVICES").is_ok() {
                for i in 0..4 {
                    if let Ok(gpu_device) = Device::from_str(&format!("gpu:{i}")) {
                        devices.push(gpu_device);
                    }
                }
            }
        }

        devices
    }

    /// Returns the detected devices as display strings.
    fn detect_devices_as_strings() -> Vec<String> {
        Self::detect_devices()
            .iter()
            .map(|d| format!("{d:?}"))
            .collect()
    }

    fn register_with_master(&self) -> Result<()> {
        let master_addr = format!("{}:{}", self.config.master_addr, self.config.master_port);

        // Placeholder: a full implementation would connect and perform a handshake.
        println!("Registering with master at {master_addr}");

        Ok(())
    }

    fn discover_cluster_nodes(&self) -> Result<()> {
        // Placeholder: cluster discovery is not yet implemented.
        Ok(())
    }

    fn establish_connections(&self) -> Result<()> {
        // Placeholder: peer connections are not yet implemented.
        Ok(())
    }

    fn start_network_workers(&self) -> Result<()> {
        // Placeholder: network worker threads are not yet implemented.
        Ok(())
    }

    /// Coordinates a deterministic shuffle across ranks: rank 0 derives the
    /// seed and broadcasts it; other ranks request it from rank 0.
    fn coordinate_shuffle(&self, indices: &mut [usize]) -> Result<()> {
        let seed = if self.config.rank == 0 {
            // Rank 0 derives the shuffle seed from the current time.
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0)
        } else {
            // Other ranks request the seed from rank 0.
            let collective_msg = DistributedMessage::CollectiveOp {
                op_type: CollectiveOpType::Broadcast,
                op_id: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(0),
                data: None,
            };

            let res = {
                let comm_manager = self
                    .comm_manager
                    .lock()
                    .expect("lock should not be poisoned");
                comm_manager.send_request(0, &collective_msg)
            };
            match res {
                Ok(DistributedMessage::CollectiveOp {
                    data: Some(seed_data),
                    ..
                }) => {
                    match oxicode::serde::decode_owned_from_slice::<u64, _>(
                        &seed_data,
                        oxicode::config::standard(),
                    )
                    .map(|(v, _)| v)
                    {
                        Ok(received_seed) => received_seed,
                        Err(_) => {
                            // Fall back to a locally derived seed; note this
                            // may diverge from rank 0's seed.
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0)
                        }
                    }
                }
                _ => {
                    // Request failed; fall back to a locally derived seed.
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_secs())
                        .unwrap_or(0)
                }
            }
        };

        self.coordinate_shuffle_seed(seed)?;

        // Fisher-Yates shuffle driven by a simple linear congruential
        // generator, so ranks sharing a seed produce the same permutation.
        let mut rng_state = seed;
        for i in (1..indices.len()).rev() {
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let j = (rng_state as usize) % (i + 1);
            indices.swap(i, j);
        }

        Ok(())
    }

    fn apply_load_balancing(&self, indices: Vec<usize>) -> Result<Vec<usize>> {
        // Placeholder: load balancing currently passes indices through unchanged.
        Ok(indices)
    }

    /// Returns true if this rank owns the sample at `index` under the
    /// contiguous block partitioning of the dataset.
    fn is_local_index(&self, index: usize, dataset_len: usize) -> bool {
        // Ceiling division: every replica gets an equal-sized block.
        let samples_per_replica =
            (dataset_len + self.config.world_size - 1) / self.config.world_size;
        let start_idx = self.config.rank * samples_per_replica;
        let end_idx = ((self.config.rank + 1) * samples_per_replica).min(dataset_len);

        index >= start_idx && index < end_idx
    }

    /// Returns the rank that owns the sample at `index`.
    fn get_index_owner(&self, index: usize, dataset_len: usize) -> usize {
        let samples_per_replica =
            (dataset_len + self.config.world_size - 1) / self.config.world_size;
        index / samples_per_replica
    }
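
    // Worked example of the block partitioning above: with dataset_len = 100
    // and world_size = 4, samples_per_replica = (100 + 3) / 4 = 25, so rank 1
    // owns indices [25, 50) and get_index_owner(75, 100) = 75 / 25 = 3
    // (see `test_index_ownership` below).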

    /// Synchronously fetches the samples at `indices` from `remote_rank`,
    /// consulting the local cache first.
    fn fetch_remote_data_sync<T>(
        &self,
        remote_rank: usize,
        indices: &[usize],
    ) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + bytemuck::Zeroable
            + serde::Serialize
            + for<'de> serde::Deserialize<'de>
            + scirs2_core::numeric::Zero,
    {
        // Serve from the cache when possible.
        let cached_data = self.check_cache::<T>(indices);
        if !cached_data.is_empty() {
            return Ok(cached_data);
        }

        // Build and send the data request.
        let request_id = self.generate_request_id();
        let request = DistributedMessage::DataRequest {
            indices: indices.to_vec(),
            requestor_rank: self.config.rank,
            request_id,
        };

        let comm_manager = self
            .comm_manager
            .lock()
            .expect("lock should not be poisoned");
        let response = comm_manager.send_request(remote_rank, &request)?;

        match response {
            DistributedMessage::DataResponse {
                data, compressed, ..
            } => {
                let data_len = data.len();
                let decompressed_data = if compressed {
                    self.decompress_data(&data)?
                } else {
                    data
                };

                // Decode the payload into (input, target) tensor pairs; each
                // entry carries the flattened data and shape of both tensors.
                let samples: Vec<(Tensor<T>, Tensor<T>)> =
                    match oxicode::serde::decode_owned_from_slice::<
                        Vec<(Vec<T>, Vec<usize>, Vec<T>, Vec<usize>)>,
                        _,
                    >(&decompressed_data, oxicode::config::standard())
                    .map(|(v, _)| v)
                    {
                        Ok(tensor_data) => tensor_data
                            .into_iter()
                            .map(|(input_data, input_shape, target_data, target_shape)| {
                                // Fall back to zero tensors on malformed shapes.
                                let input_tensor =
                                    match Tensor::from_vec(input_data, &input_shape) {
                                        Ok(tensor) => tensor,
                                        Err(_) => Tensor::zeros(&[1]),
                                    };

                                let target_tensor =
                                    match Tensor::from_vec(target_data, &target_shape) {
                                        Ok(tensor) => tensor,
                                        Err(_) => Tensor::zeros(&[1]),
                                    };

                                (input_tensor, target_tensor)
                            })
                            .collect(),
                        Err(_) => {
                            // Decode failure: substitute placeholder tensors.
                            indices
                                .iter()
                                .map(|_| {
                                    let input_data = vec![T::default(); 1];
                                    let target_data = vec![T::default(); 1];
                                    let input_tensor = Tensor::from_vec(input_data, &[1])
                                        .unwrap_or_else(|_| Tensor::zeros(&[1]));
                                    let target_tensor = Tensor::from_vec(target_data, &[1])
                                        .unwrap_or_else(|_| Tensor::zeros(&[1]));
                                    (input_tensor, target_tensor)
                                })
                                .collect()
                        }
                    };

                // Cache the raw payload for future requests.
                self.cache_samples(indices, &decompressed_data);

                // Record the received byte count.
                {
                    let mut stats = self
                        .stats
                        .write()
                        .expect("write lock should not be poisoned");
                    stats.network_bytes_received += data_len as u64;
                }

                Ok(samples)
            }
            _ => Err(TensorError::invalid_argument(
                "Invalid response from remote node".to_string(),
            )),
        }
    }

    fn check_cache<T>(&self, indices: &[usize]) -> Vec<(Tensor<T>, Tensor<T>)>
    where
        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
    {
        // Placeholder: cached payloads are not yet decoded back into tensors,
        // so this always reports a miss.
        Vec::new()
    }

    /// Caches the raw payload under each of the given indices.
    fn cache_samples(&self, indices: &[usize], data: &[u8]) {
        let mut cache = self
            .sample_cache
            .lock()
            .expect("lock should not be poisoned");
        let timestamp = Instant::now();

        for &index in indices {
            let cached_sample = CachedSample {
                data: data.to_vec(),
                timestamp,
                access_count: 1,
            };
            cache.insert(index, cached_sample);
        }

        // Keep the cache bounded.
        if cache.len() > 1000 {
            self.evict_old_cache_entries(&mut cache);
        }
    }

    /// Evicts cache entries older than five minutes.
    fn evict_old_cache_entries(&self, cache: &mut HashMap<usize, CachedSample>) {
        let cutoff_time = Instant::now() - Duration::from_secs(300);
        cache.retain(|_, sample| sample.timestamp > cutoff_time);
    }

    fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
        // Placeholder: compression is not yet wired up, so this is a copy.
        Ok(data.to_vec())
    }

    /// Generates a best-effort unique operation id from the current time.
    fn generate_operation_id(&self) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    }

    fn generate_request_id(&self) -> u64 {
        self.generate_operation_id()
    }

    fn synchronize_epoch(&self, epoch: usize) -> Result<()> {
        // Placeholder: epoch synchronization is not yet implemented.
        Ok(())
    }

    /// Broadcasts the shuffle seed from rank 0 to every other rank.
    fn coordinate_shuffle_seed(&self, seed: u64) -> Result<()> {
        if self.config.rank == 0 {
            // Serialize the seed for transmission.
            let seed_data = oxicode::serde::encode_to_vec(&seed, oxicode::config::standard())
                .map_err(|e| {
                    TensorError::invalid_operation_simple(format!("Seed serialization error: {e}"))
                })?;

            let broadcast_msg = DistributedMessage::CollectiveOp {
                op_type: CollectiveOpType::Broadcast,
                op_id: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(0),
                data: Some(seed_data),
            };

            // Send the seed to every non-master rank.
            for rank in 1..self.config.world_size {
                if let Err(e) = {
                    let comm_manager = self
                        .comm_manager
                        .lock()
                        .expect("lock should not be poisoned");
                    comm_manager.send_request(rank, &broadcast_msg)
                } {
                    return Err(TensorError::invalid_operation_simple(format!(
                        "Failed to send seed to rank {rank}: {e}"
                    )));
                }
            }
        }
        Ok(())
    }

    fn aggregate_statistics(
        &self,
        results: Vec<DistributedMessage>,
    ) -> Result<DistributedLoadingStats> {
        // Placeholder: statistics aggregation is not yet implemented.
        Ok(DistributedLoadingStats::default())
    }
}

impl Sampler for EnhancedDistributedSampler {
    fn sample_indices(&self, len: usize) -> Box<dyn Iterator<Item = usize> + Send> {
        // Delegate to the base distributed sampler.
        self.base_sampler.sample_indices(len)
    }

    fn is_random(&self) -> bool {
        self.base_sampler.is_random()
    }

    fn set_seed(&mut self, seed: Option<u64>) {
        // Seeding is coordinated across the cluster (see `coordinate_shuffle`)
        // rather than set locally, so this is left as a no-op.
    }
}

impl CommunicationManager {
    fn new(node_info: NodeInfo, config: DistributedLoadingConfig) -> Result<Self> {
        Ok(Self {
            node_info,
            cluster_nodes: HashMap::new(),
            connections: HashMap::new(),
            listener: None,
            config,
            message_handlers: HashMap::new(),
        })
    }

    fn broadcast_message(&self, message: &DistributedMessage) -> Result<Vec<DistributedMessage>> {
        // Placeholder: broadcasting to all peers is not yet implemented.
        Ok(Vec::new())
    }

    /// Sends a message to `dest_rank` and waits for the response. Messages
    /// are framed as a big-endian u32 length prefix followed by the payload.
    fn send_request(
        &self,
        dest_rank: usize,
        message: &DistributedMessage,
    ) -> Result<DistributedMessage> {
        // Validate the destination rank.
        if dest_rank >= self.config.world_size {
            return Ok(DistributedMessage::Error {
                message: format!("Invalid destination rank: {dest_rank}"),
            });
        }

        let connections = &self.connections;
        if let Some(connection) = connections.get(&dest_rank) {
            // Serialize the outgoing message.
            let serialized_message =
                oxicode::serde::encode_to_vec(message, oxicode::config::standard()).map_err(
                    |e| TensorError::invalid_operation_simple(format!("Serialization error: {e}")),
                )?;

            // Write the length prefix, then the payload.
            let mut stream = connection;
            let msg_len = serialized_message.len() as u32;
            let len_bytes = msg_len.to_be_bytes();

            if stream.write_all(&len_bytes).is_err() {
                return Ok(DistributedMessage::Error {
                    message: format!("Failed to send to rank {dest_rank}"),
                });
            }

            if stream.write_all(&serialized_message).is_err() {
                return Ok(DistributedMessage::Error {
                    message: format!("Failed to send message to rank {dest_rank}"),
                });
            }

            // Read the response length prefix, then the response body.
            let mut response_len_bytes = [0u8; 4];
            if stream.read_exact(&mut response_len_bytes).is_err() {
                return Ok(DistributedMessage::Error {
                    message: format!("Failed to read response length from rank {dest_rank}"),
                });
            }

            let response_len = u32::from_be_bytes(response_len_bytes) as usize;
            let mut response_data = vec![0u8; response_len];

            if stream.read_exact(&mut response_data).is_err() {
                return Ok(DistributedMessage::Error {
                    message: format!("Failed to read response from rank {dest_rank}"),
                });
            }

            // Decode the response message.
            match oxicode::serde::decode_owned_from_slice::<DistributedMessage, _>(
                &response_data,
                oxicode::config::standard(),
            )
            .map(|(v, _)| v)
            {
                Ok(response) => Ok(response),
                Err(e) => Ok(DistributedMessage::Error {
                    message: format!("Deserialization error: {e}"),
                }),
            }
        } else {
            Ok(DistributedMessage::Error {
                message: format!("No connection to rank {dest_rank}"),
            })
        }
    }
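
    // Receiving-side sketch (an assumption, not an existing method of this
    // type): a peer would decode one length-prefixed frame like this.
    //
    //     fn read_frame(stream: &mut TcpStream) -> std::io::Result<Vec<u8>> {
    //         let mut len_bytes = [0u8; 4];
    //         stream.read_exact(&mut len_bytes)?;
    //         let len = u32::from_be_bytes(len_bytes) as usize;
    //         let mut payload = vec![0u8; len];
    //         stream.read_exact(&mut payload)?;
    //         Ok(payload)
    //     }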

    fn shutdown(&mut self) -> Result<()> {
        // Drop all peer connections.
        self.connections.clear();

        // Close the listener, if one was bound.
        if let Some(listener) = self.listener.take() {
            drop(listener);
        }

        Ok(())
    }
}

impl RdmaContext {
    fn new(device_name: Option<&String>) -> Result<Self> {
        Ok(Self {
            // Default to a common Mellanox device name when none is given.
            device_name: device_name.cloned().unwrap_or_else(|| "mlx5_0".to_string()),
            initialized: false,
            memory_regions: HashMap::new(),
        })
    }

    fn initialize(&mut self) -> Result<()> {
        // Placeholder: a real implementation would open the RDMA device and
        // set up queue pairs; here the context is only marked initialized.
        self.initialized = true;
        Ok(())
    }

    fn cleanup(&mut self) -> Result<()> {
        self.memory_regions.clear();
        self.initialized = false;
        Ok(())
    }

    fn register_memory_region(&mut self, key: String, size: usize) -> Result<()> {
        // Placeholder region: a real implementation would pin memory and
        // record the address returned by the RDMA stack.
        let mr = RdmaMemoryRegion { addr: 0, size };

        self.memory_regions.insert(key, mr);
        Ok(())
    }
}

/// Convenience constructor that wires an `EnhancedDistributedSampler` into a
/// `DataLoader` for the given dataset.
pub fn create_distributed_dataloader<T, D>(
    dataset: D,
    config: DistributedLoadingConfig,
    dataloader_config: DataLoaderConfig,
) -> Result<DataLoader<T, D, EnhancedDistributedSampler>>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
    D: Dataset<T> + Send + Sync + 'static,
{
    let sampler = EnhancedDistributedSampler::new(config.world_size, config.rank, config)?;

    Ok(DataLoader::new(dataset, sampler, dataloader_config))
}
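
// Usage sketch (illustrative; `dataset` stands for any `Dataset<f32>`, such
// as the `TensorDataset` used in the tests below):
//
//     let config = DistributedLoadingConfig {
//         world_size: 2,
//         rank: 0,
//         ..Default::default()
//     };
//     let loader = create_distributed_dataloader::<f32, _>(
//         dataset,
//         config,
//         DataLoaderConfig::default(),
//     )?;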

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;

    #[test]
    fn test_distributed_loading_config() {
        let config = DistributedLoadingConfig::default();
        assert_eq!(config.world_size, 1);
        assert_eq!(config.rank, 0);
        assert!(!config.enable_rdma);
    }

    #[test]
    fn test_enhanced_distributed_sampler_creation() {
        let config = DistributedLoadingConfig::default();
        let sampler = EnhancedDistributedSampler::new(2, 0, config);
        assert!(sampler.is_ok());
    }

    #[test]
    fn test_communication_manager_creation() {
        let node_info = NodeInfo {
            rank: 0,
            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            device_capabilities: vec!["Cpu".to_string()],
            rdma_enabled: false,
            rdma_device: None,
        };

        let config = DistributedLoadingConfig::default();
        let comm_manager = CommunicationManager::new(node_info, config);
        assert!(comm_manager.is_ok());
    }

    #[test]
    fn test_index_ownership() {
        let config = DistributedLoadingConfig {
            world_size: 4,
            rank: 1,
            ..Default::default()
        };

        let sampler =
            EnhancedDistributedSampler::new(4, 1, config).expect("test: operation should succeed");

        // With 100 samples over 4 replicas, rank 1 owns indices [25, 50).
        let dataset_len = 100;
        assert!(sampler.is_local_index(25, dataset_len));
        assert!(!sampler.is_local_index(5, dataset_len));
        assert_eq!(sampler.get_index_owner(5, dataset_len), 0);
        assert_eq!(sampler.get_index_owner(75, dataset_len), 3);
    }

    #[test]
    fn test_rdma_context_initialization() {
        let rdma_ctx = RdmaContext::new(Some(&"mlx5_0".to_string()));
        assert!(rdma_ctx.is_ok());

        let mut ctx = rdma_ctx.expect("test: operation should succeed");
        assert!(ctx.initialize().is_ok());
        assert!(ctx.initialized);
    }
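
    // A minimal sketch exercising the remote-sample cache: the byte payload
    // is an arbitrary stand-in for serialized sample data.
    #[test]
    fn test_sample_cache_insertion() {
        let config = DistributedLoadingConfig::default();
        let sampler = EnhancedDistributedSampler::new(1, 0, config)
            .expect("test: sampler creation should succeed");

        // Cache the same payload under three indices.
        sampler.cache_samples(&[1, 2, 3], &[0u8; 8]);

        let cache = sampler
            .sample_cache
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(cache.len(), 3);
        assert!(cache.contains_key(&2));
    }

    // Sanity check that the coordinated shuffle permutes indices without
    // losing or duplicating any (world_size 1 avoids any network traffic).
    #[test]
    fn test_coordinate_shuffle_is_permutation() {
        let config = DistributedLoadingConfig::default();
        let sampler = EnhancedDistributedSampler::new(1, 0, config)
            .expect("test: sampler creation should succeed");

        let mut indices: Vec<usize> = (0..10).collect();
        sampler
            .coordinate_shuffle(&mut indices)
            .expect("test: shuffle should succeed");

        let mut sorted = indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<usize>>());
    }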
}