Skip to main content

ferrotorch_distributed/
backend.rs

1//! Communication backends for distributed training.
2//!
3//! A [`Backend`] abstracts point-to-point messaging so that collective
4//! operations and DDP are transport-agnostic. Two implementations are
5//! provided:
6//!
7//! - [`TcpBackend`] — real multi-process backend over TCP sockets.
8//! - [`SimulatedBackend`] — in-process backend using channels, suitable
9//!   for unit tests without spawning multiple processes.
10
11use std::io::{Read, Write};
12use std::net::{TcpListener, TcpStream};
13use std::sync::mpsc::{self, Receiver, Sender};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
17use ferrotorch_core::FerrotorchResult;
18
19use crate::error::DistributedError;
20
21// ---------------------------------------------------------------------------
22// Backend trait
23// ---------------------------------------------------------------------------
24
25/// Transport-agnostic communication backend.
26///
27/// Every rank in a distributed job holds a `Backend` that can send/receive
28/// raw byte buffers to/from any other rank, plus a collective barrier.
29pub trait Backend: Send + Sync {
30    /// This rank's index in the world (0-based).
31    fn rank(&self) -> usize;
32
33    /// Total number of ranks in the process group.
34    fn world_size(&self) -> usize;
35
36    /// Send `data` to `dst_rank`.
37    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()>;
38
39    /// Receive into `dst` from `src_rank`. The caller must allocate `dst`
40    /// with the correct length before calling.
41    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()>;
42
43    /// Receive into `dst` from `src_rank` with a timeout.
44    ///
45    /// Returns [`DistributedError::Timeout`] if the receive does not
46    /// complete within `timeout`. The default implementation delegates
47    /// to [`recv`](Self::recv) (no timeout).
48    fn recv_timeout(
49        &self,
50        dst: &mut [u8],
51        src_rank: usize,
52        timeout: Duration,
53    ) -> FerrotorchResult<()> {
54        let _ = timeout;
55        self.recv(dst, src_rank)
56    }
57
58    /// Block until every rank has reached this barrier.
59    fn barrier(&self) -> FerrotorchResult<()>;
60}
61
62// ---------------------------------------------------------------------------
63// TCP backend
64// ---------------------------------------------------------------------------
65
/// Real multi-process backend over TCP sockets.
///
/// Rendezvous protocol:
/// 1. Rank 0 binds `master_addr` and accepts `world_size - 1` connections.
/// 2. Every non-zero rank connects to rank 0 and announces its rank as a
///    little-endian `u64`.
///
/// The resulting topology is a **star**, not a full mesh:
/// - Rank 0 holds a stream to every other rank.
/// - A non-zero rank holds a stream to rank 0 only.
///
/// Point-to-point traffic between two non-zero ranks is therefore not
/// available and fails with [`DistributedError::NoConnection`].
///
/// Each stream sits behind its own `Mutex`. Threads can talk to different
/// peers concurrently, while access to any single peer is serialized.
#[derive(Debug)]
pub struct TcpBackend {
    rank: usize,
    world_size: usize,
    /// One TCP stream per peer, indexed by peer rank.
    ///
    /// An entry is `None` for the self-slot (no self-loop). It is also
    /// `None` for peers that are not directly connected in the star
    /// topology; on non-zero ranks, that is every peer except rank 0.
    connections: Vec<Option<Mutex<TcpStream>>>,
}
85
86impl TcpBackend {
87    /// Launch the TCP rendezvous and return a ready-to-use backend.
88    ///
89    /// * `rank` — this process's rank (0-based).
90    /// * `world_size` — total number of processes.
91    /// * `master_addr` — `host:port` where rank 0 listens.
92    pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
93        if world_size < 2 {
94            return Err(DistributedError::InvalidWorldSize { world_size }.into());
95        }
96        if rank >= world_size {
97            return Err(DistributedError::InvalidRank { rank, world_size }.into());
98        }
99
100        // Phase 1: rank 0 collects one connection per non-zero rank.
101        let mut peer_streams: Vec<Option<TcpStream>> = (0..world_size).map(|_| None).collect();
102
103        if rank == 0 {
104            let listener = TcpListener::bind(master_addr).map_err(|e| DistributedError::Io {
105                message: format!("rank 0 bind {master_addr}: {e}"),
106            })?;
107
108            // Accept connections from ranks 1..world_size-1.
109            for _ in 1..world_size {
110                let (mut stream, _addr) = listener.accept().map_err(|e| DistributedError::Io {
111                    message: format!("rank 0 accept: {e}"),
112                })?;
113                // First 8 bytes: the connecting rank as little-endian u64.
114                let mut rank_buf = [0u8; 8];
115                stream
116                    .read_exact(&mut rank_buf)
117                    .map_err(|e| DistributedError::Io {
118                        message: format!("rank 0 read peer rank: {e}"),
119                    })?;
120                let peer_rank = u64::from_le_bytes(rank_buf) as usize;
121                if peer_rank >= world_size || peer_rank == 0 {
122                    return Err(DistributedError::InvalidRank {
123                        rank: peer_rank,
124                        world_size,
125                    }
126                    .into());
127                }
128                peer_streams[peer_rank] = Some(stream);
129            }
130        } else {
131            // Non-zero rank: connect to rank 0 and announce our rank.
132            let mut stream = TcpStream::connect(master_addr).map_err(|e| DistributedError::Io {
133                message: format!("rank {rank} connect to {master_addr}: {e}"),
134            })?;
135            stream
136                .write_all(&(rank as u64).to_le_bytes())
137                .map_err(|e| DistributedError::Io {
138                    message: format!("rank {rank} announce: {e}"),
139                })?;
140            peer_streams[0] = Some(stream);
141        }
142
143        // Phase 2: rank 0 broadcasts peer addresses so every rank can
144        // form a full mesh. For simplicity in this MVP, we use a star
145        // topology where all traffic goes through rank 0's connections.
146        // A full mesh can be added later.
147
148        // Collect into the final connections vec. For the star topology,
149        // non-zero ranks only have a connection to rank 0, and rank 0 has
150        // connections to all others. The self-slot and unconnected peers
151        // are `None`.
152        let connections: Vec<Option<Mutex<TcpStream>>> = peer_streams
153            .into_iter()
154            .enumerate()
155            .map(|(i, opt)| {
156                if i == rank {
157                    // Self-slot: no self-loop needed.
158                    None
159                } else {
160                    // Some(stream) for connected peers, None for unconnected
161                    // peers (non-zero ranks only connect to rank 0 in star
162                    // topology).
163                    opt.map(Mutex::new)
164                }
165            })
166            .collect();
167
168        Ok(Self {
169            rank,
170            world_size,
171            connections,
172        })
173    }
174}
175
176impl Backend for TcpBackend {
177    fn rank(&self) -> usize {
178        self.rank
179    }
180
181    fn world_size(&self) -> usize {
182        self.world_size
183    }
184
185    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
186        if dst_rank == self.rank {
187            return Err(DistributedError::SelfSend { rank: self.rank }.into());
188        }
189        if dst_rank >= self.world_size {
190            return Err(DistributedError::InvalidRank {
191                rank: dst_rank,
192                world_size: self.world_size,
193            }
194            .into());
195        }
196
197        let conn = self.connections[dst_rank]
198            .as_ref()
199            .ok_or(DistributedError::NoConnection { rank: dst_rank })?;
200
201        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
202            message: format!("send to rank {dst_rank}: {e}"),
203        })?;
204
205        // Length-prefixed protocol: send length (8 bytes LE) then payload.
206        let len_bytes = (data.len() as u64).to_le_bytes();
207        stream
208            .write_all(&len_bytes)
209            .map_err(|e| DistributedError::Io {
210                message: format!("send len to rank {dst_rank}: {e}"),
211            })?;
212        stream.write_all(data).map_err(|e| DistributedError::Io {
213            message: format!("send data to rank {dst_rank}: {e}"),
214        })?;
215        stream.flush().map_err(|e| DistributedError::Io {
216            message: format!("flush to rank {dst_rank}: {e}"),
217        })?;
218
219        Ok(())
220    }
221
222    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
223        if src_rank == self.rank {
224            return Err(DistributedError::SelfSend { rank: self.rank }.into());
225        }
226        if src_rank >= self.world_size {
227            return Err(DistributedError::InvalidRank {
228                rank: src_rank,
229                world_size: self.world_size,
230            }
231            .into());
232        }
233
234        let conn = self.connections[src_rank]
235            .as_ref()
236            .ok_or(DistributedError::NoConnection { rank: src_rank })?;
237
238        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
239            message: format!("recv from rank {src_rank}: {e}"),
240        })?;
241
242        // Read length prefix.
243        let mut len_bytes = [0u8; 8];
244        stream
245            .read_exact(&mut len_bytes)
246            .map_err(|e| DistributedError::Io {
247                message: format!("recv len from rank {src_rank}: {e}"),
248            })?;
249        let len = u64::from_le_bytes(len_bytes) as usize;
250
251        if len != dst.len() {
252            return Err(DistributedError::SizeMismatch {
253                expected: dst.len(),
254                got: len,
255            }
256            .into());
257        }
258
259        stream.read_exact(dst).map_err(|e| DistributedError::Io {
260            message: format!("recv data from rank {src_rank}: {e}"),
261        })?;
262
263        Ok(())
264    }
265
266    fn recv_timeout(
267        &self,
268        dst: &mut [u8],
269        src_rank: usize,
270        timeout: Duration,
271    ) -> FerrotorchResult<()> {
272        if src_rank == self.rank {
273            return Err(DistributedError::SelfSend { rank: self.rank }.into());
274        }
275        if src_rank >= self.world_size {
276            return Err(DistributedError::InvalidRank {
277                rank: src_rank,
278                world_size: self.world_size,
279            }
280            .into());
281        }
282
283        let conn = self.connections[src_rank]
284            .as_ref()
285            .ok_or(DistributedError::NoConnection { rank: src_rank })?;
286
287        let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
288            message: format!("recv_timeout from rank {src_rank}: {e}"),
289        })?;
290
291        // Set the read timeout for this operation.
292        stream
293            .set_read_timeout(Some(timeout))
294            .map_err(|e| DistributedError::Io {
295                message: format!("set_read_timeout for rank {src_rank}: {e}"),
296            })?;
297
298        // Read length prefix.
299        let mut len_bytes = [0u8; 8];
300        let result = (|| {
301            stream.read_exact(&mut len_bytes).map_err(|e| {
302                if e.kind() == std::io::ErrorKind::WouldBlock
303                    || e.kind() == std::io::ErrorKind::TimedOut
304                {
305                    DistributedError::Timeout {
306                        seconds: timeout.as_secs(),
307                    }
308                } else {
309                    DistributedError::Io {
310                        message: format!("recv_timeout len from rank {src_rank}: {e}"),
311                    }
312                }
313            })?;
314            let len = u64::from_le_bytes(len_bytes) as usize;
315            if len != dst.len() {
316                return Err(DistributedError::SizeMismatch {
317                    expected: dst.len(),
318                    got: len,
319                });
320            }
321            stream.read_exact(dst).map_err(|e| {
322                if e.kind() == std::io::ErrorKind::WouldBlock
323                    || e.kind() == std::io::ErrorKind::TimedOut
324                {
325                    DistributedError::Timeout {
326                        seconds: timeout.as_secs(),
327                    }
328                } else {
329                    DistributedError::Io {
330                        message: format!("recv_timeout data from rank {src_rank}: {e}"),
331                    }
332                }
333            })?;
334            Ok(())
335        })();
336
337        // Restore blocking mode (no timeout) regardless of outcome.
338        let _ = stream.set_read_timeout(None);
339
340        result.map_err(Into::into)
341    }
342
343    fn barrier(&self) -> FerrotorchResult<()> {
344        // Simple barrier: all ranks send a byte to rank 0, rank 0 waits
345        // for all, then rank 0 sends a byte back to each.
346        let tag = [0u8; 1];
347        if self.rank == 0 {
348            let mut buf = [0u8; 1];
349            for r in 1..self.world_size {
350                self.recv(&mut buf, r)?;
351            }
352            for r in 1..self.world_size {
353                self.send(&tag, r)?;
354            }
355        } else {
356            self.send(&tag, 0)?;
357            let mut buf = [0u8; 1];
358            self.recv(&mut buf, 0)?;
359        }
360        Ok(())
361    }
362}
363
364// ---------------------------------------------------------------------------
365// Simulated backend (in-process, channel-based)
366// ---------------------------------------------------------------------------
367
/// One directed link: the sending half (used by the source rank) paired
/// with the receiving half (used by the destination rank).
type DirectedChannel = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);

/// Channel grid shared by every simulated rank.
///
/// Entry `[src][dst]` is the link carrying messages from rank `src` to
/// rank `dst`.
type ChannelMatrix = Arc<Vec<Vec<DirectedChannel>>>;
373
374/// In-process backend using `std::sync::mpsc` channels.
375///
376/// Designed for testing collectives and DDP without spawning processes.
377/// Create all ranks via [`SimulatedBackend::create_group`], which returns
378/// one `SimulatedBackend` per rank.
379pub struct SimulatedBackend {
380    rank: usize,
381    world_size: usize,
382    /// `channels[src][dst]` — sender side is used by `src`, receiver by `dst`.
383    channels: ChannelMatrix,
384}
385
386impl SimulatedBackend {
387    /// Create a group of `world_size` simulated backends, one per rank.
388    ///
389    /// Returns a `Vec<SimulatedBackend>` where index `i` is rank `i`.
390    pub fn create_group(world_size: usize) -> FerrotorchResult<Vec<Self>> {
391        if world_size == 0 {
392            return Err(DistributedError::InvalidWorldSize { world_size }.into());
393        }
394
395        // Build the channel matrix: channels[src][dst].
396        type ChannelPair = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);
397        let mut matrix: Vec<Vec<ChannelPair>> = Vec::new();
398
399        for _src in 0..world_size {
400            let mut row = Vec::new();
401            for _dst in 0..world_size {
402                let (tx, rx) = mpsc::channel();
403                row.push((Mutex::new(tx), Mutex::new(rx)));
404            }
405            matrix.push(row);
406        }
407
408        let shared = Arc::new(matrix);
409
410        let backends: Vec<Self> = (0..world_size)
411            .map(|rank| Self {
412                rank,
413                world_size,
414                channels: Arc::clone(&shared),
415            })
416            .collect();
417
418        Ok(backends)
419    }
420}
421
422impl Backend for SimulatedBackend {
423    fn rank(&self) -> usize {
424        self.rank
425    }
426
427    fn world_size(&self) -> usize {
428        self.world_size
429    }
430
431    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
432        if dst_rank >= self.world_size {
433            return Err(DistributedError::InvalidRank {
434                rank: dst_rank,
435                world_size: self.world_size,
436            }
437            .into());
438        }
439
440        // channels[self.rank][dst_rank].sender
441        let tx = self.channels[self.rank][dst_rank].0.lock().map_err(|e| {
442            DistributedError::LockPoisoned {
443                message: format!("send channel lock rank {} -> {dst_rank}: {e}", self.rank),
444            }
445        })?;
446
447        tx.send(data.to_vec())
448            .map_err(|e| DistributedError::ChannelClosed {
449                message: format!("send rank {} -> {dst_rank}: {e}", self.rank),
450            })?;
451
452        Ok(())
453    }
454
455    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
456        if src_rank >= self.world_size {
457            return Err(DistributedError::InvalidRank {
458                rank: src_rank,
459                world_size: self.world_size,
460            }
461            .into());
462        }
463
464        // channels[src_rank][self.rank].receiver
465        let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
466            DistributedError::LockPoisoned {
467                message: format!("recv channel lock rank {src_rank} -> {}: {e}", self.rank),
468            }
469        })?;
470
471        let data = rx.recv().map_err(|e| DistributedError::ChannelClosed {
472            message: format!("recv rank {src_rank} -> {}: {e}", self.rank),
473        })?;
474
475        if data.len() != dst.len() {
476            return Err(DistributedError::SizeMismatch {
477                expected: dst.len(),
478                got: data.len(),
479            }
480            .into());
481        }
482
483        dst.copy_from_slice(&data);
484        Ok(())
485    }
486
487    fn recv_timeout(
488        &self,
489        dst: &mut [u8],
490        src_rank: usize,
491        timeout: Duration,
492    ) -> FerrotorchResult<()> {
493        if src_rank >= self.world_size {
494            return Err(DistributedError::InvalidRank {
495                rank: src_rank,
496                world_size: self.world_size,
497            }
498            .into());
499        }
500
501        let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
502            DistributedError::LockPoisoned {
503                message: format!(
504                    "recv_timeout channel lock rank {src_rank} -> {}: {e}",
505                    self.rank
506                ),
507            }
508        })?;
509
510        let data = rx.recv_timeout(timeout).map_err(|e| match e {
511            mpsc::RecvTimeoutError::Timeout => DistributedError::Timeout {
512                seconds: timeout.as_secs(),
513            },
514            mpsc::RecvTimeoutError::Disconnected => DistributedError::ChannelClosed {
515                message: format!(
516                    "recv_timeout rank {src_rank} -> {}: disconnected",
517                    self.rank
518                ),
519            },
520        })?;
521
522        if data.len() != dst.len() {
523            return Err(DistributedError::SizeMismatch {
524                expected: dst.len(),
525                got: data.len(),
526            }
527            .into());
528        }
529
530        dst.copy_from_slice(&data);
531        Ok(())
532    }
533
534    fn barrier(&self) -> FerrotorchResult<()> {
535        // Same star-topology barrier as TcpBackend: gather at rank 0,
536        // then scatter acknowledgement.
537        let tag = [0u8; 1];
538        if self.rank == 0 {
539            let mut buf = [0u8; 1];
540            for r in 1..self.world_size {
541                self.recv(&mut buf, r)?;
542            }
543            for r in 1..self.world_size {
544                self.send(&tag, r)?;
545            }
546        } else {
547            self.send(&tag, 0)?;
548            let mut buf = [0u8; 1];
549            self.recv(&mut buf, 0)?;
550        }
551        Ok(())
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A message sent by rank 0 arrives intact at rank 1.
    #[test]
    fn test_simulated_send_recv() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let (sender_rank, receiver_rank) = (&group[0], &group[1]);
        let mut received = [0u8; 4];

        thread::scope(|s| {
            s.spawn(move || sender_rank.send(&[1, 2, 3, 4], 1).unwrap());
            receiver_rank.recv(&mut received, 0).unwrap();
        });

        assert_eq!(received, [1, 2, 3, 4]);
    }

    /// Four ranks entering the barrier concurrently all get released.
    #[test]
    fn test_simulated_barrier() {
        let group = SimulatedBackend::create_group(4).unwrap();

        // `thread::scope` joins every worker and propagates any panic.
        thread::scope(|s| {
            for backend in &group {
                s.spawn(move || backend.barrier().unwrap());
            }
        });
    }

    /// Backend `i` reports rank `i`, and all report the group size.
    #[test]
    fn test_simulated_rank_world_size() {
        let group = SimulatedBackend::create_group(3).unwrap();
        assert_eq!(group.len(), 3);
        for (expected_rank, backend) in group.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
        }
        assert_eq!(group[0].world_size(), 3);
    }

    /// An empty group is rejected.
    #[test]
    fn test_invalid_world_size() {
        assert!(SimulatedBackend::create_group(0).is_err());
    }

    /// Sending to a rank outside the group fails.
    #[test]
    fn test_send_to_invalid_rank() {
        let group = SimulatedBackend::create_group(2).unwrap();
        assert!(group[0].send(&[1], 5).is_err());
    }
}
615        let group = SimulatedBackend::create_group(2).unwrap();
616        let result = group[0].send(&[1], 5);
617        assert!(result.is_err());
618    }
619}