Skip to main content

ferrotorch_distributed/
backend.rs

1//! Communication backends for distributed training.
2//!
3//! A [`Backend`] abstracts point-to-point messaging so that collective
4//! operations and DDP are transport-agnostic. Two implementations are
5//! provided:
6//!
7//! - [`TcpBackend`] — real multi-process backend over TCP sockets.
8//! - [`SimulatedBackend`] — in-process backend using channels, suitable
9//!   for unit tests without spawning multiple processes.
10//!
11//! ## REQ status (per `.design/ferrotorch-distributed/backend.md`)
12//!
13//! | REQ | Status | Evidence |
14//! |---|---|---|
15//! | REQ-1 (Backend trait) | SHIPPED | `pub trait Backend: Send + Sync` in `backend.rs`; consumers `collective.rs`, `p2p.rs`, `ddp.rs`, `fsdp.rs` all take `&dyn Backend`. |
16//! | REQ-2 (TcpBackend rendezvous) | SHIPPED | `pub struct TcpBackend` + `pub fn new` in `backend.rs`; consumer `hybrid_backend.rs` invokes `TcpBackend::new(rank, world_size, addr)`. |
17//! | REQ-3 (TcpBackend wire protocol) | SHIPPED | `impl Backend for TcpBackend` in `backend.rs` with length-prefix protocol; consumer `hybrid_backend.rs` delegates all P2P methods to inner TcpBackend. |
18//! | REQ-4 (SimulatedBackend channel matrix) | SHIPPED | `pub struct SimulatedBackend` + `create_group` in `backend.rs`; consumer crate-root re-export at `lib.rs`. |
19//! | REQ-5 (SimulatedBackend Backend impl) | SHIPPED | `impl Backend for SimulatedBackend` in `backend.rs`; consumer re-export at `lib.rs`. |
20//! | REQ-6 (SubBackend struct) | SHIPPED | `pub struct SubBackend` + `pub fn new` in `backend.rs`; consumer re-export at `lib.rs`. |
21//! | REQ-7 (SubBackend Backend impl) | SHIPPED | `impl Backend for SubBackend` in `backend.rs`; consumer re-export at `lib.rs` reached via `ferrotorch/src/lib.rs`. |
22//! | REQ-8 (barrier across all three backends) | SHIPPED | barrier methods in `backend.rs` for TcpBackend / SimulatedBackend / SubBackend; consumer `pub fn barrier` in `collective.rs`. |
23
24use std::io::{Read, Write};
25use std::net::{TcpListener, TcpStream};
26use std::sync::mpsc::{self, Receiver, Sender};
27use std::sync::{Arc, Mutex};
28use std::time::Duration;
29
30use ferrotorch_core::FerrotorchResult;
31
32use crate::error::DistributedError;
33
34// ---------------------------------------------------------------------------
35// Backend trait
36// ---------------------------------------------------------------------------
37
38/// Transport-agnostic communication backend.
39///
40/// Every rank in a distributed job holds a `Backend` that can send/receive
41/// raw byte buffers to/from any other rank, plus a collective barrier.
42pub trait Backend: Send + Sync {
43    /// This rank's index in the world (0-based).
44    fn rank(&self) -> usize;
45
46    /// Total number of ranks in the process group.
47    fn world_size(&self) -> usize;
48
49    /// Send `data` to `dst_rank`.
50    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()>;
51
52    /// Receive into `dst` from `src_rank`. The caller must allocate `dst`
53    /// with the correct length before calling.
54    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()>;
55
56    /// Receive into `dst` from `src_rank` with a timeout.
57    ///
58    /// Returns [`DistributedError::Timeout`] if the receive does not
59    /// complete within `timeout`. The default implementation delegates
60    /// to [`recv`](Self::recv) (no timeout).
61    fn recv_timeout(
62        &self,
63        dst: &mut [u8],
64        src_rank: usize,
65        timeout: Duration,
66    ) -> FerrotorchResult<()> {
67        let _ = timeout;
68        self.recv(dst, src_rank)
69    }
70
71    /// Block until every rank has reached this barrier.
72    fn barrier(&self) -> FerrotorchResult<()>;
73
74    /// Downcast hook for the NCCL fast path.
75    ///
76    /// Returns `Some(&NcclBackend)` if this backend *is* an
77    /// [`NcclBackend`](crate::nccl_backend::NcclBackend); `None` otherwise.
78    /// The default implementation returns `None`.
79    ///
80    /// [`gpu_collective::gpu_allreduce`](crate::gpu_collective::gpu_allreduce)
81    /// and [`gpu_collective::gpu_broadcast`](crate::gpu_collective::gpu_broadcast)
82    /// query this method to decide between the NCCL GPU-native fast path
83    /// and the host round-trip fallback. Only compiled under the `nccl`
84    /// feature gate (which gates the existence of [`NcclBackend`]).
85    #[cfg(feature = "nccl")]
86    fn as_nccl_backend(&self) -> Option<&crate::nccl_backend::NcclBackend> {
87        None
88    }
89}
90
91// ---------------------------------------------------------------------------
92// TCP backend
93// ---------------------------------------------------------------------------
94
/// Real multi-process backend over TCP sockets.
///
/// Rendezvous performed by [`TcpBackend::new`]:
/// 1. Rank 0 listens on `master_addr` and accepts `world_size - 1`
///    connections.
/// 2. Every non-zero rank connects to rank 0 and announces its own rank.
///
/// The result is a star topology: rank 0 holds a stream to every peer,
/// while a non-zero rank only holds a stream to rank 0. No peer addresses
/// are relayed, so two non-zero ranks have no direct connection to each
/// other.
///
/// Each connection is wrapped in a `Mutex` to allow concurrent
/// send/recv on different pairs from different threads.
pub struct TcpBackend {
    /// This process's rank (0-based).
    rank: usize,
    /// Total number of ranks in the group.
    world_size: usize,
    /// One TCP stream per peer, indexed by peer rank. `None` for the
    /// self-slot (no self-loop) and for peers that are not directly
    /// connected in the star topology (non-zero ranks only connect to
    /// rank 0).
    connections: Vec<Option<Mutex<TcpStream>>>,
}
114
115impl TcpBackend {
116    /// Launch the TCP rendezvous and return a ready-to-use backend.
117    ///
118    /// * `rank` — this process's rank (0-based).
119    /// * `world_size` — total number of processes.
120    /// * `master_addr` — `host:port` where rank 0 listens.
121    pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
122        if world_size < 2 {
123            return Err(DistributedError::InvalidWorldSize { world_size }.into());
124        }
125        if rank >= world_size {
126            return Err(DistributedError::InvalidRank { rank, world_size }.into());
127        }
128
129        // Phase 1: rank 0 collects one connection per non-zero rank.
130        let mut peer_streams: Vec<Option<TcpStream>> = (0..world_size).map(|_| None).collect();
131
132        if rank == 0 {
133            let listener = TcpListener::bind(master_addr).map_err(|e| DistributedError::Io {
134                message: format!("rank 0 bind {master_addr}: {e}"),
135            })?;
136
137            // Accept connections from ranks 1..world_size-1.
138            for _ in 1..world_size {
139                let (mut stream, _addr) = listener.accept().map_err(|e| DistributedError::Io {
140                    message: format!("rank 0 accept: {e}"),
141                })?;
142                // First 8 bytes: the connecting rank as little-endian u64.
143                let mut rank_buf = [0u8; 8];
144                stream
145                    .read_exact(&mut rank_buf)
146                    .map_err(|e| DistributedError::Io {
147                        message: format!("rank 0 read peer rank: {e}"),
148                    })?;
149                let peer_rank = u64::from_le_bytes(rank_buf) as usize;
150                if peer_rank >= world_size || peer_rank == 0 {
151                    return Err(DistributedError::InvalidRank {
152                        rank: peer_rank,
153                        world_size,
154                    }
155                    .into());
156                }
157                peer_streams[peer_rank] = Some(stream);
158            }
159        } else {
160            // Non-zero rank: connect to rank 0 and announce our rank.
161            let mut stream = TcpStream::connect(master_addr).map_err(|e| DistributedError::Io {
162                message: format!("rank {rank} connect to {master_addr}: {e}"),
163            })?;
164            stream
165                .write_all(&(rank as u64).to_le_bytes())
166                .map_err(|e| DistributedError::Io {
167                    message: format!("rank {rank} announce: {e}"),
168                })?;
169            peer_streams[0] = Some(stream);
170        }
171
172        // Phase 2: rank 0 broadcasts peer addresses so every rank can
173        // form a full mesh. For simplicity in this MVP, we use a star
174        // topology where all traffic goes through rank 0's connections.
175        // A full mesh can be added later.
176
177        // Collect into the final connections vec. For the star topology,
178        // non-zero ranks only have a connection to rank 0, and rank 0 has
179        // connections to all others. The self-slot and unconnected peers
180        // are `None`.
181        let connections: Vec<Option<Mutex<TcpStream>>> = peer_streams
182            .into_iter()
183            .enumerate()
184            .map(|(i, opt)| {
185                if i == rank {
186                    // Self-slot: no self-loop needed.
187                    None
188                } else {
189                    // Some(stream) for connected peers, None for unconnected
190                    // peers (non-zero ranks only connect to rank 0 in star
191                    // topology).
192                    opt.map(Mutex::new)
193                }
194            })
195            .collect();
196
197        Ok(Self {
198            rank,
199            world_size,
200            connections,
201        })
202    }
203}
204
205impl Backend for TcpBackend {
206    fn rank(&self) -> usize {
207        self.rank
208    }
209
210    fn world_size(&self) -> usize {
211        self.world_size
212    }
213
214    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
215        if dst_rank == self.rank {
216            return Err(DistributedError::SelfSend { rank: self.rank }.into());
217        }
218        if dst_rank >= self.world_size {
219            return Err(DistributedError::InvalidRank {
220                rank: dst_rank,
221                world_size: self.world_size,
222            }
223            .into());
224        }
225
226        let conn = self.connections[dst_rank]
227            .as_ref()
228            .ok_or(DistributedError::NoConnection { rank: dst_rank })?;
229
230        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
231            message: format!("send to rank {dst_rank}: {e}"),
232        })?;
233
234        // Length-prefixed protocol: send length (8 bytes LE) then payload.
235        let len_bytes = (data.len() as u64).to_le_bytes();
236        stream
237            .write_all(&len_bytes)
238            .map_err(|e| DistributedError::Io {
239                message: format!("send len to rank {dst_rank}: {e}"),
240            })?;
241        stream.write_all(data).map_err(|e| DistributedError::Io {
242            message: format!("send data to rank {dst_rank}: {e}"),
243        })?;
244        stream.flush().map_err(|e| DistributedError::Io {
245            message: format!("flush to rank {dst_rank}: {e}"),
246        })?;
247
248        Ok(())
249    }
250
251    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
252        if src_rank == self.rank {
253            return Err(DistributedError::SelfSend { rank: self.rank }.into());
254        }
255        if src_rank >= self.world_size {
256            return Err(DistributedError::InvalidRank {
257                rank: src_rank,
258                world_size: self.world_size,
259            }
260            .into());
261        }
262
263        let conn = self.connections[src_rank]
264            .as_ref()
265            .ok_or(DistributedError::NoConnection { rank: src_rank })?;
266
267        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
268            message: format!("recv from rank {src_rank}: {e}"),
269        })?;
270
271        // Read length prefix.
272        let mut len_bytes = [0u8; 8];
273        stream
274            .read_exact(&mut len_bytes)
275            .map_err(|e| DistributedError::Io {
276                message: format!("recv len from rank {src_rank}: {e}"),
277            })?;
278        let len = u64::from_le_bytes(len_bytes) as usize;
279
280        if len != dst.len() {
281            return Err(DistributedError::SizeMismatch {
282                expected: dst.len(),
283                got: len,
284            }
285            .into());
286        }
287
288        stream.read_exact(dst).map_err(|e| DistributedError::Io {
289            message: format!("recv data from rank {src_rank}: {e}"),
290        })?;
291
292        Ok(())
293    }
294
295    fn recv_timeout(
296        &self,
297        dst: &mut [u8],
298        src_rank: usize,
299        timeout: Duration,
300    ) -> FerrotorchResult<()> {
301        if src_rank == self.rank {
302            return Err(DistributedError::SelfSend { rank: self.rank }.into());
303        }
304        if src_rank >= self.world_size {
305            return Err(DistributedError::InvalidRank {
306                rank: src_rank,
307                world_size: self.world_size,
308            }
309            .into());
310        }
311
312        let conn = self.connections[src_rank]
313            .as_ref()
314            .ok_or(DistributedError::NoConnection { rank: src_rank })?;
315
316        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
317            message: format!("recv_timeout from rank {src_rank}: {e}"),
318        })?;
319
320        // Set the read timeout for this operation.
321        stream
322            .set_read_timeout(Some(timeout))
323            .map_err(|e| DistributedError::Io {
324                message: format!("set_read_timeout for rank {src_rank}: {e}"),
325            })?;
326
327        // Read length prefix.
328        let mut len_bytes = [0u8; 8];
329        let result = (|| {
330            stream.read_exact(&mut len_bytes).map_err(|e| {
331                if e.kind() == std::io::ErrorKind::WouldBlock
332                    || e.kind() == std::io::ErrorKind::TimedOut
333                {
334                    DistributedError::Timeout {
335                        seconds: timeout.as_secs(),
336                    }
337                } else {
338                    DistributedError::Io {
339                        message: format!("recv_timeout len from rank {src_rank}: {e}"),
340                    }
341                }
342            })?;
343            let len = u64::from_le_bytes(len_bytes) as usize;
344            if len != dst.len() {
345                return Err(DistributedError::SizeMismatch {
346                    expected: dst.len(),
347                    got: len,
348                });
349            }
350            stream.read_exact(dst).map_err(|e| {
351                if e.kind() == std::io::ErrorKind::WouldBlock
352                    || e.kind() == std::io::ErrorKind::TimedOut
353                {
354                    DistributedError::Timeout {
355                        seconds: timeout.as_secs(),
356                    }
357                } else {
358                    DistributedError::Io {
359                        message: format!("recv_timeout data from rank {src_rank}: {e}"),
360                    }
361                }
362            })?;
363            Ok(())
364        })();
365
366        // Restore blocking mode (no timeout) regardless of outcome.
367        let _ = stream.set_read_timeout(None);
368
369        result.map_err(Into::into)
370    }
371
372    fn barrier(&self) -> FerrotorchResult<()> {
373        // Simple barrier: all ranks send a byte to rank 0, rank 0 waits
374        // for all, then rank 0 sends a byte back to each.
375        let tag = [0u8; 1];
376        if self.rank == 0 {
377            let mut buf = [0u8; 1];
378            for r in 1..self.world_size {
379                self.recv(&mut buf, r)?;
380            }
381            for r in 1..self.world_size {
382                self.send(&tag, r)?;
383            }
384        } else {
385            self.send(&tag, 0)?;
386            let mut buf = [0u8; 1];
387            self.recv(&mut buf, 0)?;
388        }
389        Ok(())
390    }
391}
392
393// ---------------------------------------------------------------------------
394// Simulated backend (in-process, channel-based)
395// ---------------------------------------------------------------------------
396
/// Shared channel state for all simulated ranks.
///
/// `channels[src][dst]` is the `(Sender, Receiver)` pair for messages
/// from `src` to `dst`. Each end sits behind its own `Mutex`, and the
/// whole matrix is reference-counted with `Arc` so that every rank's
/// backend can hold a handle to the same set of channels.
type ChannelMatrix = Arc<Vec<Vec<(Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>)>>>;
402
/// In-process backend using `std::sync::mpsc` channels.
///
/// Designed for testing collectives and DDP without spawning processes.
/// Create all ranks via [`SimulatedBackend::create_group`], which returns
/// one `SimulatedBackend` per rank; all of them share one channel matrix.
pub struct SimulatedBackend {
    /// This backend's rank (0-based).
    rank: usize,
    /// Number of ranks in the simulated group.
    world_size: usize,
    /// `channels[src][dst]` — sender side is used by `src`, receiver by `dst`.
    channels: ChannelMatrix,
}
414
415impl SimulatedBackend {
416    /// Create a group of `world_size` simulated backends, one per rank.
417    ///
418    /// Returns a `Vec<SimulatedBackend>` where index `i` is rank `i`.
419    pub fn create_group(world_size: usize) -> FerrotorchResult<Vec<Self>> {
420        if world_size == 0 {
421            return Err(DistributedError::InvalidWorldSize { world_size }.into());
422        }
423
424        // Build the channel matrix: channels[src][dst].
425        type ChannelPair = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);
426        let mut matrix: Vec<Vec<ChannelPair>> = Vec::new();
427
428        for _src in 0..world_size {
429            let mut row = Vec::new();
430            for _dst in 0..world_size {
431                let (tx, rx) = mpsc::channel();
432                row.push((Mutex::new(tx), Mutex::new(rx)));
433            }
434            matrix.push(row);
435        }
436
437        let shared = Arc::new(matrix);
438
439        let backends: Vec<Self> = (0..world_size)
440            .map(|rank| Self {
441                rank,
442                world_size,
443                channels: Arc::clone(&shared),
444            })
445            .collect();
446
447        Ok(backends)
448    }
449}
450
451impl Backend for SimulatedBackend {
452    fn rank(&self) -> usize {
453        self.rank
454    }
455
456    fn world_size(&self) -> usize {
457        self.world_size
458    }
459
460    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
461        if dst_rank >= self.world_size {
462            return Err(DistributedError::InvalidRank {
463                rank: dst_rank,
464                world_size: self.world_size,
465            }
466            .into());
467        }
468
469        // channels[self.rank][dst_rank].sender
470        let tx = self.channels[self.rank][dst_rank].0.lock().map_err(|e| {
471            DistributedError::LockPoisoned {
472                message: format!("send channel lock rank {} -> {dst_rank}: {e}", self.rank),
473            }
474        })?;
475
476        tx.send(data.to_vec())
477            .map_err(|e| DistributedError::ChannelClosed {
478                message: format!("send rank {} -> {dst_rank}: {e}", self.rank),
479            })?;
480
481        Ok(())
482    }
483
484    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
485        if src_rank >= self.world_size {
486            return Err(DistributedError::InvalidRank {
487                rank: src_rank,
488                world_size: self.world_size,
489            }
490            .into());
491        }
492
493        // channels[src_rank][self.rank].receiver
494        let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
495            DistributedError::LockPoisoned {
496                message: format!("recv channel lock rank {src_rank} -> {}: {e}", self.rank),
497            }
498        })?;
499
500        let data = rx.recv().map_err(|e| DistributedError::ChannelClosed {
501            message: format!("recv rank {src_rank} -> {}: {e}", self.rank),
502        })?;
503
504        if data.len() != dst.len() {
505            return Err(DistributedError::SizeMismatch {
506                expected: dst.len(),
507                got: data.len(),
508            }
509            .into());
510        }
511
512        dst.copy_from_slice(&data);
513        Ok(())
514    }
515
516    fn recv_timeout(
517        &self,
518        dst: &mut [u8],
519        src_rank: usize,
520        timeout: Duration,
521    ) -> FerrotorchResult<()> {
522        if src_rank >= self.world_size {
523            return Err(DistributedError::InvalidRank {
524                rank: src_rank,
525                world_size: self.world_size,
526            }
527            .into());
528        }
529
530        let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
531            DistributedError::LockPoisoned {
532                message: format!(
533                    "recv_timeout channel lock rank {src_rank} -> {}: {e}",
534                    self.rank
535                ),
536            }
537        })?;
538
539        let data = rx.recv_timeout(timeout).map_err(|e| match e {
540            mpsc::RecvTimeoutError::Timeout => DistributedError::Timeout {
541                seconds: timeout.as_secs(),
542            },
543            mpsc::RecvTimeoutError::Disconnected => DistributedError::ChannelClosed {
544                message: format!(
545                    "recv_timeout rank {src_rank} -> {}: disconnected",
546                    self.rank
547                ),
548            },
549        })?;
550
551        if data.len() != dst.len() {
552            return Err(DistributedError::SizeMismatch {
553                expected: dst.len(),
554                got: data.len(),
555            }
556            .into());
557        }
558
559        dst.copy_from_slice(&data);
560        Ok(())
561    }
562
563    fn barrier(&self) -> FerrotorchResult<()> {
564        // Same star-topology barrier as TcpBackend: gather at rank 0,
565        // then scatter acknowledgement.
566        let tag = [0u8; 1];
567        if self.rank == 0 {
568            let mut buf = [0u8; 1];
569            for r in 1..self.world_size {
570                self.recv(&mut buf, r)?;
571            }
572            for r in 1..self.world_size {
573                self.send(&tag, r)?;
574            }
575        } else {
576            self.send(&tag, 0)?;
577            let mut buf = [0u8; 1];
578            self.recv(&mut buf, 0)?;
579        }
580        Ok(())
581    }
582}
583
584// ---------------------------------------------------------------------------
585// SubBackend — subgroup adapter
586// ---------------------------------------------------------------------------
587
/// A backend view that restricts communication to a subset of global ranks.
///
/// `SubBackend` wraps a parent [`Backend`] and a list of member global ranks
/// to form a logical sub-process-group. It translates every `rank()`,
/// `world_size()`, `send()`, and `recv()` call into the equivalent parent-
/// backend operation using the subgroup→global rank mapping.
///
/// Because `SubBackend` implements the [`Backend`] trait, the existing
/// [`allreduce`](crate::collective::allreduce),
/// [`all_gather`](crate::collective::all_gather), and
/// [`reduce_scatter`](crate::collective::reduce_scatter) collective functions
/// work on a subgroup without any changes: they only see the `SubBackend`
/// through the trait, so "rank 0" in the collective means "the first member
/// of the subgroup" and `world_size` means the subgroup size.
///
/// Used by FSDP's [`HybridShard`](crate::fsdp::ShardingStrategy::HybridShard)
/// strategy to form an intra-node (sharding) subgroup and an inter-node
/// (replication) subgroup.
///
/// CL-327.
pub struct SubBackend {
    /// Backend of the full group; all traffic is forwarded to it.
    parent: Arc<dyn Backend>,
    /// Global ranks that are members of this subgroup, sorted ascending
    /// and free of duplicates (established by [`SubBackend::new`]).
    members: Vec<usize>,
    /// This process's index within `members` (the local rank).
    local_rank: usize,
}
615
616impl SubBackend {
617    /// Create a subgroup view from a parent backend and a list of member
618    /// global ranks.
619    ///
620    /// The caller's rank (read from `parent.rank()`) must be in `members`.
621    /// `members` is sorted and deduplicated before being stored.
622    ///
623    /// # Errors
624    ///
625    /// - [`DistributedError::InvalidRank`] if the parent's rank is not in
626    ///   `members`, or if any member is ≥ parent `world_size`.
627    /// - [`DistributedError::InvalidWorldSize`] if `members` is empty.
628    pub fn new(parent: Arc<dyn Backend>, members: Vec<usize>) -> FerrotorchResult<Self> {
629        if members.is_empty() {
630            return Err(DistributedError::InvalidWorldSize { world_size: 0 }.into());
631        }
632
633        let parent_world = parent.world_size();
634        for &m in &members {
635            if m >= parent_world {
636                return Err(DistributedError::InvalidRank {
637                    rank: m,
638                    world_size: parent_world,
639                }
640                .into());
641            }
642        }
643
644        // Sort and dedup so local rank ordering is deterministic.
645        let mut sorted_members = members;
646        sorted_members.sort_unstable();
647        sorted_members.dedup();
648
649        let parent_rank = parent.rank();
650        let local_rank = sorted_members
651            .iter()
652            .position(|&r| r == parent_rank)
653            .ok_or(DistributedError::InvalidRank {
654                rank: parent_rank,
655                world_size: sorted_members.len(),
656            })?;
657
658        Ok(Self {
659            parent,
660            members: sorted_members,
661            local_rank,
662        })
663    }
664
665    /// Return the global ranks that make up this subgroup, in ascending
666    /// order. The index of this rank's entry is its local rank.
667    pub fn members(&self) -> &[usize] {
668        &self.members
669    }
670
671    /// Map a local (subgroup-relative) rank to its global rank.
672    ///
673    /// # Panics
674    ///
675    /// Panics if `local` is out of bounds.
676    pub fn to_global(&self, local: usize) -> usize {
677        self.members[local]
678    }
679
680    /// Map a global rank to its local subgroup rank, or `None` if the
681    /// global rank is not a member of this subgroup.
682    pub fn to_local(&self, global: usize) -> Option<usize> {
683        self.members.iter().position(|&r| r == global)
684    }
685
686    /// The parent backend this subgroup was derived from.
687    pub fn parent(&self) -> &Arc<dyn Backend> {
688        &self.parent
689    }
690}
691
692impl Backend for SubBackend {
693    fn rank(&self) -> usize {
694        self.local_rank
695    }
696
697    fn world_size(&self) -> usize {
698        self.members.len()
699    }
700
701    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
702        if dst_rank >= self.members.len() {
703            return Err(DistributedError::InvalidRank {
704                rank: dst_rank,
705                world_size: self.members.len(),
706            }
707            .into());
708        }
709        self.parent.send(data, self.members[dst_rank])
710    }
711
712    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
713        if src_rank >= self.members.len() {
714            return Err(DistributedError::InvalidRank {
715                rank: src_rank,
716                world_size: self.members.len(),
717            }
718            .into());
719        }
720        self.parent.recv(dst, self.members[src_rank])
721    }
722
723    fn recv_timeout(
724        &self,
725        dst: &mut [u8],
726        src_rank: usize,
727        timeout: Duration,
728    ) -> FerrotorchResult<()> {
729        if src_rank >= self.members.len() {
730            return Err(DistributedError::InvalidRank {
731                rank: src_rank,
732                world_size: self.members.len(),
733            }
734            .into());
735        }
736        self.parent
737            .recv_timeout(dst, self.members[src_rank], timeout)
738    }
739
740    fn barrier(&self) -> FerrotorchResult<()> {
741        // Gather-scatter barrier within the subgroup: local rank 0 waits
742        // for a byte from every other member, then sends one back. We
743        // use the parent backend's send/recv via the local→global rank
744        // map, so this doesn't conflict with a simultaneous parent-level
745        // barrier as long as the subgroups are non-overlapping at rank 0.
746        let tag = [0u8; 1];
747        let size = self.members.len();
748        if size <= 1 {
749            return Ok(());
750        }
751        if self.local_rank == 0 {
752            let mut buf = [0u8; 1];
753            for r in 1..size {
754                self.recv(&mut buf, r)?;
755            }
756            for r in 1..size {
757                self.send(&tag, r)?;
758            }
759        } else {
760            self.send(&tag, 0)?;
761            let mut buf = [0u8; 1];
762            self.recv(&mut buf, 0)?;
763        }
764        Ok(())
765    }
766}
767
768#[cfg(test)]
769mod tests {
770    use super::*;
771    use std::thread;
772
773    #[test]
774    fn test_simulated_send_recv() {
775        let group = SimulatedBackend::create_group(2).unwrap();
776        let mut iter = group.into_iter();
777        let b0 = Arc::new(iter.next().unwrap());
778        let b1 = Arc::new(iter.next().unwrap());
779
780        let b0c = Arc::clone(&b0);
781        let sender = thread::spawn(move || {
782            b0c.send(&[1, 2, 3, 4], 1).unwrap();
783        });
784
785        let mut buf = [0u8; 4];
786        b1.recv(&mut buf, 0).unwrap();
787        sender.join().unwrap();
788
789        assert_eq!(buf, [1, 2, 3, 4]);
790    }
791
792    #[test]
793    fn test_simulated_barrier() {
794        let group = SimulatedBackend::create_group(4).unwrap();
795        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
796
797        let handles: Vec<_> = arcs
798            .into_iter()
799            .map(|b| {
800                thread::spawn(move || {
801                    b.barrier().unwrap();
802                })
803            })
804            .collect();
805
806        for h in handles {
807            h.join().unwrap();
808        }
809    }
810
811    #[test]
812    fn test_simulated_rank_world_size() {
813        let group = SimulatedBackend::create_group(3).unwrap();
814        assert_eq!(group[0].rank(), 0);
815        assert_eq!(group[1].rank(), 1);
816        assert_eq!(group[2].rank(), 2);
817        assert_eq!(group[0].world_size(), 3);
818    }
819
820    #[test]
821    fn test_invalid_world_size() {
822        let result = SimulatedBackend::create_group(0);
823        assert!(result.is_err());
824    }
825
826    #[test]
827    fn test_send_to_invalid_rank() {
828        let group = SimulatedBackend::create_group(2).unwrap();
829        let result = group[0].send(&[1], 5);
830        assert!(result.is_err());
831    }
832
833    // -----------------------------------------------------------------------
834    // SubBackend tests. CL-327
835    // -----------------------------------------------------------------------
836
/// With four parent ranks and the subgroup `{1, 3}`, the sub-backend built
/// on parent rank 1 reports local rank 0 and the one built on parent rank 3
/// reports local rank 1; both report a world size of 2.
#[test]
fn test_subbackend_local_rank_and_world_size() {
    let parents: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(4)
        .unwrap()
        .into_iter()
        .map(|backend| -> Arc<dyn Backend> { Arc::new(backend) })
        .collect();

    // Build both sub-backends up front, in ascending parent-rank order, so
    // that a position in `subs` is exactly the local rank we expect.
    let subs: Vec<_> = [1usize, 3]
        .into_iter()
        .map(|global_rank| {
            SubBackend::new(Arc::clone(&parents[global_rank]), vec![1, 3]).unwrap()
        })
        .collect();

    for (expected_local, sub) in subs.iter().enumerate() {
        assert_eq!(sub.rank(), expected_local);
        assert_eq!(sub.world_size(), 2);
    }
}
854
/// Building a `SubBackend` on a parent whose own rank is absent from the
/// member list must fail.
#[test]
fn test_subbackend_global_rank_not_in_members_is_error() {
    let parents: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(4)
        .unwrap()
        .into_iter()
        .map(|backend| -> Arc<dyn Backend> { Arc::new(backend) })
        .collect();

    // The parent handed in is rank 0, but the requested subgroup is {1, 3}.
    let outsider = Arc::clone(&parents[0]);
    assert!(SubBackend::new(outsider, vec![1, 3]).is_err());
}
867
/// Building a `SubBackend` with an empty member list must fail.
#[test]
fn test_subbackend_empty_members_is_error() {
    // Only the first backend of the group is kept as the parent.
    let first = SimulatedBackend::create_group(2).unwrap().into_iter().next().unwrap();
    let parent: Arc<dyn Backend> = Arc::new(first);

    assert!(SubBackend::new(parent, Vec::new()).is_err());
}
875
/// Point-to-point traffic addressed with subgroup-local ranks reaches the
/// right peer: in subgroup `{0, 2}` of a 4-rank parent, local rank 0 sends
/// three bytes to local rank 1, and the sub-backend built on parent rank 2
/// receives exactly those bytes from local rank 0.
#[test]
fn test_subbackend_send_recv_routes_through_parent() {
    let parents: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(4)
        .unwrap()
        .into_iter()
        .map(|backend| -> Arc<dyn Backend> { Arc::new(backend) })
        .collect();

    let payload = [9u8, 8, 7];
    let sending_side = SubBackend::new(Arc::clone(&parents[0]), vec![0, 2]).unwrap();
    let receiving_side = SubBackend::new(Arc::clone(&parents[2]), vec![0, 2]).unwrap();

    // The send runs on its own thread; the receive runs on the test thread.
    let send_thread = thread::spawn(move || {
        sending_side.send(&payload, 1).unwrap();
    });

    let mut received = [0u8; 3];
    receiving_side.recv(&mut received, 0).unwrap();
    send_thread.join().unwrap();

    assert_eq!(received, payload);
}
899
/// A barrier on a `SubBackend` completes when only the subgroup takes part:
/// of six parent ranks, just the even ones `{0, 2, 4}` build a sub-backend
/// and call `barrier()`, each on its own thread. Parent ranks 1, 3 and 5
/// never call it.
#[test]
fn test_subbackend_barrier() {
    let parents: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(6)
        .unwrap()
        .into_iter()
        .map(|backend| -> Arc<dyn Backend> { Arc::new(backend) })
        .collect();

    let members = vec![0usize, 2, 4];

    // One thread per subgroup member. All threads are started before any is
    // joined; a barrier presumably blocks until every member has arrived,
    // so joining early would hang.
    let mut workers = Vec::with_capacity(members.len());
    for &global_rank in &members {
        let parent = Arc::clone(&parents[global_rank]);
        let subgroup = members.clone();
        workers.push(thread::spawn(move || {
            let sub = SubBackend::new(parent, subgroup).unwrap();
            sub.barrier().unwrap();
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
928
/// `to_global` and `to_local` translate between subgroup-local and parent
/// ranks for the subgroup `{0, 2}`; parent ranks outside the subgroup map
/// to `None`.
#[test]
fn test_subbackend_to_global_to_local() {
    let first = SimulatedBackend::create_group(4).unwrap().into_iter().next().unwrap();
    let parent: Arc<dyn Backend> = Arc::new(first);
    let sub = SubBackend::new(parent, vec![0, 2]).unwrap();

    // local -> global: members are sorted, so local index i is the i-th
    // smallest member.
    for (local, global) in [(0, 0), (1, 2)] {
        assert_eq!(sub.to_global(local), global);
    }

    // global -> local: members map to their position, non-members to None.
    for (global, local) in [(0, Some(0)), (2, Some(1)), (1, None), (3, None)] {
        assert_eq!(sub.to_local(global), local);
    }
}
943
/// `SubBackend::new` normalises its member list: the input `[3, 2, 0, 2, 3]`
/// is exposed by `members()` as `[0, 2, 3]`, so a sub-backend built on
/// parent rank 2 has local rank 1 in a subgroup of size 3.
#[test]
fn test_subbackend_sorts_and_dedups_members() {
    let parents: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(4)
        .unwrap()
        .into_iter()
        .map(|backend| -> Arc<dyn Backend> { Arc::new(backend) })
        .collect();

    // Deliberately out of order and with repeated entries.
    let messy_members = vec![3, 2, 0, 2, 3];
    let sub = SubBackend::new(Arc::clone(&parents[2]), messy_members).unwrap();

    assert_eq!(sub.members(), &[0, 2, 3]);
    assert_eq!(sub.rank(), 1); // parent rank 2 sits at index 1 of [0, 2, 3]
    assert_eq!(sub.world_size(), 3);
}
958}