1use std::io::{Read, Write};
25use std::net::{TcpListener, TcpStream};
26use std::sync::mpsc::{self, Receiver, Sender};
27use std::sync::{Arc, Mutex};
28use std::time::Duration;
29
30use ferrotorch_core::FerrotorchResult;
31
32use crate::error::DistributedError;
33
34pub trait Backend: Send + Sync {
43 fn rank(&self) -> usize;
45
46 fn world_size(&self) -> usize;
48
49 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()>;
51
52 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()>;
55
56 fn recv_timeout(
62 &self,
63 dst: &mut [u8],
64 src_rank: usize,
65 timeout: Duration,
66 ) -> FerrotorchResult<()> {
67 let _ = timeout;
68 self.recv(dst, src_rank)
69 }
70
71 fn barrier(&self) -> FerrotorchResult<()>;
73
74 #[cfg(feature = "nccl")]
86 fn as_nccl_backend(&self) -> Option<&crate::nccl_backend::NcclBackend> {
87 None
88 }
89}
90
91pub struct TcpBackend {
106 rank: usize,
107 world_size: usize,
108 connections: Vec<Option<Mutex<TcpStream>>>,
113}
114
115impl TcpBackend {
116 pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
122 if world_size < 2 {
123 return Err(DistributedError::InvalidWorldSize { world_size }.into());
124 }
125 if rank >= world_size {
126 return Err(DistributedError::InvalidRank { rank, world_size }.into());
127 }
128
129 let mut peer_streams: Vec<Option<TcpStream>> = (0..world_size).map(|_| None).collect();
131
132 if rank == 0 {
133 let listener = TcpListener::bind(master_addr).map_err(|e| DistributedError::Io {
134 message: format!("rank 0 bind {master_addr}: {e}"),
135 })?;
136
137 for _ in 1..world_size {
139 let (mut stream, _addr) = listener.accept().map_err(|e| DistributedError::Io {
140 message: format!("rank 0 accept: {e}"),
141 })?;
142 let mut rank_buf = [0u8; 8];
144 stream
145 .read_exact(&mut rank_buf)
146 .map_err(|e| DistributedError::Io {
147 message: format!("rank 0 read peer rank: {e}"),
148 })?;
149 let peer_rank = u64::from_le_bytes(rank_buf) as usize;
150 if peer_rank >= world_size || peer_rank == 0 {
151 return Err(DistributedError::InvalidRank {
152 rank: peer_rank,
153 world_size,
154 }
155 .into());
156 }
157 peer_streams[peer_rank] = Some(stream);
158 }
159 } else {
160 let mut stream = TcpStream::connect(master_addr).map_err(|e| DistributedError::Io {
162 message: format!("rank {rank} connect to {master_addr}: {e}"),
163 })?;
164 stream
165 .write_all(&(rank as u64).to_le_bytes())
166 .map_err(|e| DistributedError::Io {
167 message: format!("rank {rank} announce: {e}"),
168 })?;
169 peer_streams[0] = Some(stream);
170 }
171
172 let connections: Vec<Option<Mutex<TcpStream>>> = peer_streams
182 .into_iter()
183 .enumerate()
184 .map(|(i, opt)| {
185 if i == rank {
186 None
188 } else {
189 opt.map(Mutex::new)
193 }
194 })
195 .collect();
196
197 Ok(Self {
198 rank,
199 world_size,
200 connections,
201 })
202 }
203}
204
205impl Backend for TcpBackend {
206 fn rank(&self) -> usize {
207 self.rank
208 }
209
210 fn world_size(&self) -> usize {
211 self.world_size
212 }
213
214 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
215 if dst_rank == self.rank {
216 return Err(DistributedError::SelfSend { rank: self.rank }.into());
217 }
218 if dst_rank >= self.world_size {
219 return Err(DistributedError::InvalidRank {
220 rank: dst_rank,
221 world_size: self.world_size,
222 }
223 .into());
224 }
225
226 let conn = self.connections[dst_rank]
227 .as_ref()
228 .ok_or(DistributedError::NoConnection { rank: dst_rank })?;
229
230 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
231 message: format!("send to rank {dst_rank}: {e}"),
232 })?;
233
234 let len_bytes = (data.len() as u64).to_le_bytes();
236 stream
237 .write_all(&len_bytes)
238 .map_err(|e| DistributedError::Io {
239 message: format!("send len to rank {dst_rank}: {e}"),
240 })?;
241 stream.write_all(data).map_err(|e| DistributedError::Io {
242 message: format!("send data to rank {dst_rank}: {e}"),
243 })?;
244 stream.flush().map_err(|e| DistributedError::Io {
245 message: format!("flush to rank {dst_rank}: {e}"),
246 })?;
247
248 Ok(())
249 }
250
251 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
252 if src_rank == self.rank {
253 return Err(DistributedError::SelfSend { rank: self.rank }.into());
254 }
255 if src_rank >= self.world_size {
256 return Err(DistributedError::InvalidRank {
257 rank: src_rank,
258 world_size: self.world_size,
259 }
260 .into());
261 }
262
263 let conn = self.connections[src_rank]
264 .as_ref()
265 .ok_or(DistributedError::NoConnection { rank: src_rank })?;
266
267 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
268 message: format!("recv from rank {src_rank}: {e}"),
269 })?;
270
271 let mut len_bytes = [0u8; 8];
273 stream
274 .read_exact(&mut len_bytes)
275 .map_err(|e| DistributedError::Io {
276 message: format!("recv len from rank {src_rank}: {e}"),
277 })?;
278 let len = u64::from_le_bytes(len_bytes) as usize;
279
280 if len != dst.len() {
281 return Err(DistributedError::SizeMismatch {
282 expected: dst.len(),
283 got: len,
284 }
285 .into());
286 }
287
288 stream.read_exact(dst).map_err(|e| DistributedError::Io {
289 message: format!("recv data from rank {src_rank}: {e}"),
290 })?;
291
292 Ok(())
293 }
294
295 fn recv_timeout(
296 &self,
297 dst: &mut [u8],
298 src_rank: usize,
299 timeout: Duration,
300 ) -> FerrotorchResult<()> {
301 if src_rank == self.rank {
302 return Err(DistributedError::SelfSend { rank: self.rank }.into());
303 }
304 if src_rank >= self.world_size {
305 return Err(DistributedError::InvalidRank {
306 rank: src_rank,
307 world_size: self.world_size,
308 }
309 .into());
310 }
311
312 let conn = self.connections[src_rank]
313 .as_ref()
314 .ok_or(DistributedError::NoConnection { rank: src_rank })?;
315
316 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
317 message: format!("recv_timeout from rank {src_rank}: {e}"),
318 })?;
319
320 stream
322 .set_read_timeout(Some(timeout))
323 .map_err(|e| DistributedError::Io {
324 message: format!("set_read_timeout for rank {src_rank}: {e}"),
325 })?;
326
327 let mut len_bytes = [0u8; 8];
329 let result = (|| {
330 stream.read_exact(&mut len_bytes).map_err(|e| {
331 if e.kind() == std::io::ErrorKind::WouldBlock
332 || e.kind() == std::io::ErrorKind::TimedOut
333 {
334 DistributedError::Timeout {
335 seconds: timeout.as_secs(),
336 }
337 } else {
338 DistributedError::Io {
339 message: format!("recv_timeout len from rank {src_rank}: {e}"),
340 }
341 }
342 })?;
343 let len = u64::from_le_bytes(len_bytes) as usize;
344 if len != dst.len() {
345 return Err(DistributedError::SizeMismatch {
346 expected: dst.len(),
347 got: len,
348 });
349 }
350 stream.read_exact(dst).map_err(|e| {
351 if e.kind() == std::io::ErrorKind::WouldBlock
352 || e.kind() == std::io::ErrorKind::TimedOut
353 {
354 DistributedError::Timeout {
355 seconds: timeout.as_secs(),
356 }
357 } else {
358 DistributedError::Io {
359 message: format!("recv_timeout data from rank {src_rank}: {e}"),
360 }
361 }
362 })?;
363 Ok(())
364 })();
365
366 let _ = stream.set_read_timeout(None);
368
369 result.map_err(Into::into)
370 }
371
372 fn barrier(&self) -> FerrotorchResult<()> {
373 let tag = [0u8; 1];
376 if self.rank == 0 {
377 let mut buf = [0u8; 1];
378 for r in 1..self.world_size {
379 self.recv(&mut buf, r)?;
380 }
381 for r in 1..self.world_size {
382 self.send(&tag, r)?;
383 }
384 } else {
385 self.send(&tag, 0)?;
386 let mut buf = [0u8; 1];
387 self.recv(&mut buf, 0)?;
388 }
389 Ok(())
390 }
391}
392
393type ChannelMatrix = Arc<Vec<Vec<(Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>)>>>;
402
403pub struct SimulatedBackend {
409 rank: usize,
410 world_size: usize,
411 channels: ChannelMatrix,
413}
414
415impl SimulatedBackend {
416 pub fn create_group(world_size: usize) -> FerrotorchResult<Vec<Self>> {
420 if world_size == 0 {
421 return Err(DistributedError::InvalidWorldSize { world_size }.into());
422 }
423
424 type ChannelPair = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);
426 let mut matrix: Vec<Vec<ChannelPair>> = Vec::new();
427
428 for _src in 0..world_size {
429 let mut row = Vec::new();
430 for _dst in 0..world_size {
431 let (tx, rx) = mpsc::channel();
432 row.push((Mutex::new(tx), Mutex::new(rx)));
433 }
434 matrix.push(row);
435 }
436
437 let shared = Arc::new(matrix);
438
439 let backends: Vec<Self> = (0..world_size)
440 .map(|rank| Self {
441 rank,
442 world_size,
443 channels: Arc::clone(&shared),
444 })
445 .collect();
446
447 Ok(backends)
448 }
449}
450
451impl Backend for SimulatedBackend {
452 fn rank(&self) -> usize {
453 self.rank
454 }
455
456 fn world_size(&self) -> usize {
457 self.world_size
458 }
459
460 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
461 if dst_rank >= self.world_size {
462 return Err(DistributedError::InvalidRank {
463 rank: dst_rank,
464 world_size: self.world_size,
465 }
466 .into());
467 }
468
469 let tx = self.channels[self.rank][dst_rank].0.lock().map_err(|e| {
471 DistributedError::LockPoisoned {
472 message: format!("send channel lock rank {} -> {dst_rank}: {e}", self.rank),
473 }
474 })?;
475
476 tx.send(data.to_vec())
477 .map_err(|e| DistributedError::ChannelClosed {
478 message: format!("send rank {} -> {dst_rank}: {e}", self.rank),
479 })?;
480
481 Ok(())
482 }
483
484 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
485 if src_rank >= self.world_size {
486 return Err(DistributedError::InvalidRank {
487 rank: src_rank,
488 world_size: self.world_size,
489 }
490 .into());
491 }
492
493 let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
495 DistributedError::LockPoisoned {
496 message: format!("recv channel lock rank {src_rank} -> {}: {e}", self.rank),
497 }
498 })?;
499
500 let data = rx.recv().map_err(|e| DistributedError::ChannelClosed {
501 message: format!("recv rank {src_rank} -> {}: {e}", self.rank),
502 })?;
503
504 if data.len() != dst.len() {
505 return Err(DistributedError::SizeMismatch {
506 expected: dst.len(),
507 got: data.len(),
508 }
509 .into());
510 }
511
512 dst.copy_from_slice(&data);
513 Ok(())
514 }
515
516 fn recv_timeout(
517 &self,
518 dst: &mut [u8],
519 src_rank: usize,
520 timeout: Duration,
521 ) -> FerrotorchResult<()> {
522 if src_rank >= self.world_size {
523 return Err(DistributedError::InvalidRank {
524 rank: src_rank,
525 world_size: self.world_size,
526 }
527 .into());
528 }
529
530 let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
531 DistributedError::LockPoisoned {
532 message: format!(
533 "recv_timeout channel lock rank {src_rank} -> {}: {e}",
534 self.rank
535 ),
536 }
537 })?;
538
539 let data = rx.recv_timeout(timeout).map_err(|e| match e {
540 mpsc::RecvTimeoutError::Timeout => DistributedError::Timeout {
541 seconds: timeout.as_secs(),
542 },
543 mpsc::RecvTimeoutError::Disconnected => DistributedError::ChannelClosed {
544 message: format!(
545 "recv_timeout rank {src_rank} -> {}: disconnected",
546 self.rank
547 ),
548 },
549 })?;
550
551 if data.len() != dst.len() {
552 return Err(DistributedError::SizeMismatch {
553 expected: dst.len(),
554 got: data.len(),
555 }
556 .into());
557 }
558
559 dst.copy_from_slice(&data);
560 Ok(())
561 }
562
563 fn barrier(&self) -> FerrotorchResult<()> {
564 let tag = [0u8; 1];
567 if self.rank == 0 {
568 let mut buf = [0u8; 1];
569 for r in 1..self.world_size {
570 self.recv(&mut buf, r)?;
571 }
572 for r in 1..self.world_size {
573 self.send(&tag, r)?;
574 }
575 } else {
576 self.send(&tag, 0)?;
577 let mut buf = [0u8; 1];
578 self.recv(&mut buf, 0)?;
579 }
580 Ok(())
581 }
582}
583
584pub struct SubBackend {
609 parent: Arc<dyn Backend>,
610 members: Vec<usize>,
612 local_rank: usize,
614}
615
616impl SubBackend {
617 pub fn new(parent: Arc<dyn Backend>, members: Vec<usize>) -> FerrotorchResult<Self> {
629 if members.is_empty() {
630 return Err(DistributedError::InvalidWorldSize { world_size: 0 }.into());
631 }
632
633 let parent_world = parent.world_size();
634 for &m in &members {
635 if m >= parent_world {
636 return Err(DistributedError::InvalidRank {
637 rank: m,
638 world_size: parent_world,
639 }
640 .into());
641 }
642 }
643
644 let mut sorted_members = members;
646 sorted_members.sort_unstable();
647 sorted_members.dedup();
648
649 let parent_rank = parent.rank();
650 let local_rank = sorted_members
651 .iter()
652 .position(|&r| r == parent_rank)
653 .ok_or(DistributedError::InvalidRank {
654 rank: parent_rank,
655 world_size: sorted_members.len(),
656 })?;
657
658 Ok(Self {
659 parent,
660 members: sorted_members,
661 local_rank,
662 })
663 }
664
665 pub fn members(&self) -> &[usize] {
668 &self.members
669 }
670
671 pub fn to_global(&self, local: usize) -> usize {
677 self.members[local]
678 }
679
680 pub fn to_local(&self, global: usize) -> Option<usize> {
683 self.members.iter().position(|&r| r == global)
684 }
685
686 pub fn parent(&self) -> &Arc<dyn Backend> {
688 &self.parent
689 }
690}
691
692impl Backend for SubBackend {
693 fn rank(&self) -> usize {
694 self.local_rank
695 }
696
697 fn world_size(&self) -> usize {
698 self.members.len()
699 }
700
701 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
702 if dst_rank >= self.members.len() {
703 return Err(DistributedError::InvalidRank {
704 rank: dst_rank,
705 world_size: self.members.len(),
706 }
707 .into());
708 }
709 self.parent.send(data, self.members[dst_rank])
710 }
711
712 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
713 if src_rank >= self.members.len() {
714 return Err(DistributedError::InvalidRank {
715 rank: src_rank,
716 world_size: self.members.len(),
717 }
718 .into());
719 }
720 self.parent.recv(dst, self.members[src_rank])
721 }
722
723 fn recv_timeout(
724 &self,
725 dst: &mut [u8],
726 src_rank: usize,
727 timeout: Duration,
728 ) -> FerrotorchResult<()> {
729 if src_rank >= self.members.len() {
730 return Err(DistributedError::InvalidRank {
731 rank: src_rank,
732 world_size: self.members.len(),
733 }
734 .into());
735 }
736 self.parent
737 .recv_timeout(dst, self.members[src_rank], timeout)
738 }
739
740 fn barrier(&self) -> FerrotorchResult<()> {
741 let tag = [0u8; 1];
747 let size = self.members.len();
748 if size <= 1 {
749 return Ok(());
750 }
751 if self.local_rank == 0 {
752 let mut buf = [0u8; 1];
753 for r in 1..size {
754 self.recv(&mut buf, r)?;
755 }
756 for r in 1..size {
757 self.send(&tag, r)?;
758 }
759 } else {
760 self.send(&tag, 0)?;
761 let mut buf = [0u8; 1];
762 self.recv(&mut buf, 0)?;
763 }
764 Ok(())
765 }
766}
767
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Creates a simulated group of `world_size` ranks, type-erased to `Arc<dyn Backend>`
    /// and ordered by rank.
    fn dyn_group(world_size: usize) -> Vec<Arc<dyn Backend>> {
        SimulatedBackend::create_group(world_size)
            .unwrap()
            .into_iter()
            .map(|backend| Arc::new(backend) as Arc<dyn Backend>)
            .collect()
    }

    // A message sent by rank 0 arrives unchanged at rank 1.
    #[test]
    fn test_simulated_send_recv() {
        let mut group = SimulatedBackend::create_group(2).unwrap();
        let rank1 = group.pop().unwrap();
        let rank0 = group.pop().unwrap();

        let sender = thread::spawn(move || {
            rank0.send(&[1, 2, 3, 4], 1).unwrap();
        });

        let mut received = [0u8; 4];
        rank1.recv(&mut received, 0).unwrap();
        sender.join().unwrap();

        assert_eq!(received, [1, 2, 3, 4]);
    }

    // Four ranks entering the barrier concurrently all get out of it.
    #[test]
    fn test_simulated_barrier() {
        let workers: Vec<_> = SimulatedBackend::create_group(4)
            .unwrap()
            .into_iter()
            .map(|backend| {
                thread::spawn(move || {
                    backend.barrier().unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }

    // Handles are returned in rank order and agree on the world size.
    #[test]
    fn test_simulated_rank_world_size() {
        let group = SimulatedBackend::create_group(3).unwrap();
        for (expected_rank, backend) in group.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
        }
        assert_eq!(group[0].world_size(), 3);
    }

    // An empty group cannot be created.
    #[test]
    fn test_invalid_world_size() {
        assert!(SimulatedBackend::create_group(0).is_err());
    }

    // Sending to a rank outside the group fails.
    #[test]
    fn test_send_to_invalid_rank() {
        let group = SimulatedBackend::create_group(2).unwrap();
        assert!(group[0].send(&[1], 5).is_err());
    }

    // Local ranks follow the sorted member list; world size is the member count.
    #[test]
    fn test_subbackend_local_rank_and_world_size() {
        let parents = dyn_group(4);

        let sub_for_rank1 = SubBackend::new(Arc::clone(&parents[1]), vec![1, 3]).unwrap();
        let sub_for_rank3 = SubBackend::new(Arc::clone(&parents[3]), vec![1, 3]).unwrap();

        assert_eq!((sub_for_rank1.rank(), sub_for_rank1.world_size()), (0, 2));
        assert_eq!((sub_for_rank3.rank(), sub_for_rank3.world_size()), (1, 2));
    }

    // A parent whose rank is not listed cannot build the sub-group.
    #[test]
    fn test_subbackend_global_rank_not_in_members_is_error() {
        let parents = dyn_group(4);
        assert!(SubBackend::new(Arc::clone(&parents[0]), vec![1, 3]).is_err());
    }

    // An empty member list is rejected.
    #[test]
    fn test_subbackend_empty_members_is_error() {
        let parent = dyn_group(2).remove(0);
        assert!(SubBackend::new(parent, vec![]).is_err());
    }

    // Local ranks are translated to parent ranks on both the sending and receiving side.
    #[test]
    fn test_subbackend_send_recv_routes_through_parent() {
        let parents = dyn_group(4);

        let sub0 = SubBackend::new(Arc::clone(&parents[0]), vec![0, 2]).unwrap();
        let sub2 = SubBackend::new(Arc::clone(&parents[2]), vec![0, 2]).unwrap();

        let sender = thread::spawn(move || {
            sub0.send(&[9, 8, 7], 1).unwrap();
        });

        let mut received = [0u8; 3];
        sub2.recv(&mut received, 0).unwrap();
        sender.join().unwrap();

        assert_eq!(received, [9, 8, 7]);
    }

    // A barrier over members {0, 2, 4} of a six-rank parent completes.
    #[test]
    fn test_subbackend_barrier() {
        let parents = dyn_group(6);
        let members = vec![0usize, 2, 4];

        let workers: Vec<_> = members
            .iter()
            .map(|&global_rank| {
                let parent = Arc::clone(&parents[global_rank]);
                let group_members = members.clone();
                thread::spawn(move || {
                    let sub = SubBackend::new(parent, group_members).unwrap();
                    sub.barrier().unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }

    // Rank translation works in both directions and reports non-members as `None`.
    #[test]
    fn test_subbackend_to_global_to_local() {
        let parent = dyn_group(4).remove(0);
        let sub = SubBackend::new(parent, vec![0, 2]).unwrap();

        assert_eq!(sub.to_global(0), 0);
        assert_eq!(sub.to_global(1), 2);

        assert_eq!(sub.to_local(0), Some(0));
        assert_eq!(sub.to_local(2), Some(1));
        assert_eq!(sub.to_local(1), None);
        assert_eq!(sub.to_local(3), None);
    }

    // Unordered input with duplicates is normalised before ranks are assigned.
    #[test]
    fn test_subbackend_sorts_and_dedups_members() {
        let parents = dyn_group(4);

        let sub = SubBackend::new(Arc::clone(&parents[2]), vec![3, 2, 0, 2, 3]).unwrap();

        assert_eq!(sub.members(), &[0, 2, 3]);
        assert_eq!(sub.rank(), 1);
        assert_eq!(sub.world_size(), 3);
    }
}