1use std::io::{Read, Write};
12use std::net::{TcpListener, TcpStream};
13use std::sync::mpsc::{self, Receiver, Sender};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
17use ferrotorch_core::FerrotorchResult;
18
19use crate::error::DistributedError;
20
21pub trait Backend: Send + Sync {
30 fn rank(&self) -> usize;
32
33 fn world_size(&self) -> usize;
35
36 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()>;
38
39 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()>;
42
43 fn recv_timeout(
49 &self,
50 dst: &mut [u8],
51 src_rank: usize,
52 timeout: Duration,
53 ) -> FerrotorchResult<()> {
54 let _ = timeout;
55 self.recv(dst, src_rank)
56 }
57
58 fn barrier(&self) -> FerrotorchResult<()>;
60}
61
62pub struct TcpBackend {
77 rank: usize,
78 world_size: usize,
79 connections: Vec<Option<Mutex<TcpStream>>>,
84}
85
86impl TcpBackend {
87 pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
93 if world_size < 2 {
94 return Err(DistributedError::InvalidWorldSize { world_size }.into());
95 }
96 if rank >= world_size {
97 return Err(DistributedError::InvalidRank { rank, world_size }.into());
98 }
99
100 let mut peer_streams: Vec<Option<TcpStream>> = (0..world_size).map(|_| None).collect();
102
103 if rank == 0 {
104 let listener = TcpListener::bind(master_addr).map_err(|e| DistributedError::Io {
105 message: format!("rank 0 bind {master_addr}: {e}"),
106 })?;
107
108 for _ in 1..world_size {
110 let (mut stream, _addr) = listener.accept().map_err(|e| DistributedError::Io {
111 message: format!("rank 0 accept: {e}"),
112 })?;
113 let mut rank_buf = [0u8; 8];
115 stream
116 .read_exact(&mut rank_buf)
117 .map_err(|e| DistributedError::Io {
118 message: format!("rank 0 read peer rank: {e}"),
119 })?;
120 let peer_rank = u64::from_le_bytes(rank_buf) as usize;
121 if peer_rank >= world_size || peer_rank == 0 {
122 return Err(DistributedError::InvalidRank {
123 rank: peer_rank,
124 world_size,
125 }
126 .into());
127 }
128 peer_streams[peer_rank] = Some(stream);
129 }
130 } else {
131 let mut stream = TcpStream::connect(master_addr).map_err(|e| DistributedError::Io {
133 message: format!("rank {rank} connect to {master_addr}: {e}"),
134 })?;
135 stream
136 .write_all(&(rank as u64).to_le_bytes())
137 .map_err(|e| DistributedError::Io {
138 message: format!("rank {rank} announce: {e}"),
139 })?;
140 peer_streams[0] = Some(stream);
141 }
142
143 let connections: Vec<Option<Mutex<TcpStream>>> = peer_streams
153 .into_iter()
154 .enumerate()
155 .map(|(i, opt)| {
156 if i == rank {
157 None
159 } else {
160 opt.map(Mutex::new)
164 }
165 })
166 .collect();
167
168 Ok(Self {
169 rank,
170 world_size,
171 connections,
172 })
173 }
174}
175
176impl Backend for TcpBackend {
177 fn rank(&self) -> usize {
178 self.rank
179 }
180
181 fn world_size(&self) -> usize {
182 self.world_size
183 }
184
185 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
186 if dst_rank == self.rank {
187 return Err(DistributedError::SelfSend { rank: self.rank }.into());
188 }
189 if dst_rank >= self.world_size {
190 return Err(DistributedError::InvalidRank {
191 rank: dst_rank,
192 world_size: self.world_size,
193 }
194 .into());
195 }
196
197 let conn = self.connections[dst_rank]
198 .as_ref()
199 .ok_or(DistributedError::NoConnection { rank: dst_rank })?;
200
201 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
202 message: format!("send to rank {dst_rank}: {e}"),
203 })?;
204
205 let len_bytes = (data.len() as u64).to_le_bytes();
207 stream
208 .write_all(&len_bytes)
209 .map_err(|e| DistributedError::Io {
210 message: format!("send len to rank {dst_rank}: {e}"),
211 })?;
212 stream.write_all(data).map_err(|e| DistributedError::Io {
213 message: format!("send data to rank {dst_rank}: {e}"),
214 })?;
215 stream.flush().map_err(|e| DistributedError::Io {
216 message: format!("flush to rank {dst_rank}: {e}"),
217 })?;
218
219 Ok(())
220 }
221
222 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
223 if src_rank == self.rank {
224 return Err(DistributedError::SelfSend { rank: self.rank }.into());
225 }
226 if src_rank >= self.world_size {
227 return Err(DistributedError::InvalidRank {
228 rank: src_rank,
229 world_size: self.world_size,
230 }
231 .into());
232 }
233
234 let conn = self.connections[src_rank]
235 .as_ref()
236 .ok_or(DistributedError::NoConnection { rank: src_rank })?;
237
238 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
239 message: format!("recv from rank {src_rank}: {e}"),
240 })?;
241
242 let mut len_bytes = [0u8; 8];
244 stream
245 .read_exact(&mut len_bytes)
246 .map_err(|e| DistributedError::Io {
247 message: format!("recv len from rank {src_rank}: {e}"),
248 })?;
249 let len = u64::from_le_bytes(len_bytes) as usize;
250
251 if len != dst.len() {
252 return Err(DistributedError::SizeMismatch {
253 expected: dst.len(),
254 got: len,
255 }
256 .into());
257 }
258
259 stream.read_exact(dst).map_err(|e| DistributedError::Io {
260 message: format!("recv data from rank {src_rank}: {e}"),
261 })?;
262
263 Ok(())
264 }
265
266 fn recv_timeout(
267 &self,
268 dst: &mut [u8],
269 src_rank: usize,
270 timeout: Duration,
271 ) -> FerrotorchResult<()> {
272 if src_rank == self.rank {
273 return Err(DistributedError::SelfSend { rank: self.rank }.into());
274 }
275 if src_rank >= self.world_size {
276 return Err(DistributedError::InvalidRank {
277 rank: src_rank,
278 world_size: self.world_size,
279 }
280 .into());
281 }
282
283 let conn = self.connections[src_rank]
284 .as_ref()
285 .ok_or(DistributedError::NoConnection { rank: src_rank })?;
286
287 let mut stream = conn.lock().map_err(|e| DistributedError::LockPoisoned {
288 message: format!("recv_timeout from rank {src_rank}: {e}"),
289 })?;
290
291 stream
293 .set_read_timeout(Some(timeout))
294 .map_err(|e| DistributedError::Io {
295 message: format!("set_read_timeout for rank {src_rank}: {e}"),
296 })?;
297
298 let mut len_bytes = [0u8; 8];
300 let result = (|| {
301 stream.read_exact(&mut len_bytes).map_err(|e| {
302 if e.kind() == std::io::ErrorKind::WouldBlock
303 || e.kind() == std::io::ErrorKind::TimedOut
304 {
305 DistributedError::Timeout {
306 seconds: timeout.as_secs(),
307 }
308 } else {
309 DistributedError::Io {
310 message: format!("recv_timeout len from rank {src_rank}: {e}"),
311 }
312 }
313 })?;
314 let len = u64::from_le_bytes(len_bytes) as usize;
315 if len != dst.len() {
316 return Err(DistributedError::SizeMismatch {
317 expected: dst.len(),
318 got: len,
319 });
320 }
321 stream.read_exact(dst).map_err(|e| {
322 if e.kind() == std::io::ErrorKind::WouldBlock
323 || e.kind() == std::io::ErrorKind::TimedOut
324 {
325 DistributedError::Timeout {
326 seconds: timeout.as_secs(),
327 }
328 } else {
329 DistributedError::Io {
330 message: format!("recv_timeout data from rank {src_rank}: {e}"),
331 }
332 }
333 })?;
334 Ok(())
335 })();
336
337 let _ = stream.set_read_timeout(None);
339
340 result.map_err(Into::into)
341 }
342
343 fn barrier(&self) -> FerrotorchResult<()> {
344 let tag = [0u8; 1];
347 if self.rank == 0 {
348 let mut buf = [0u8; 1];
349 for r in 1..self.world_size {
350 self.recv(&mut buf, r)?;
351 }
352 for r in 1..self.world_size {
353 self.send(&tag, r)?;
354 }
355 } else {
356 self.send(&tag, 0)?;
357 let mut buf = [0u8; 1];
358 self.recv(&mut buf, 0)?;
359 }
360 Ok(())
361 }
362}
363
364type ChannelMatrix = Arc<Vec<Vec<(Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>)>>>;
373
374pub struct SimulatedBackend {
380 rank: usize,
381 world_size: usize,
382 channels: ChannelMatrix,
384}
385
386impl SimulatedBackend {
387 pub fn create_group(world_size: usize) -> FerrotorchResult<Vec<Self>> {
391 if world_size == 0 {
392 return Err(DistributedError::InvalidWorldSize { world_size }.into());
393 }
394
395 type ChannelPair = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);
397 let mut matrix: Vec<Vec<ChannelPair>> = Vec::new();
398
399 for _src in 0..world_size {
400 let mut row = Vec::new();
401 for _dst in 0..world_size {
402 let (tx, rx) = mpsc::channel();
403 row.push((Mutex::new(tx), Mutex::new(rx)));
404 }
405 matrix.push(row);
406 }
407
408 let shared = Arc::new(matrix);
409
410 let backends: Vec<Self> = (0..world_size)
411 .map(|rank| Self {
412 rank,
413 world_size,
414 channels: Arc::clone(&shared),
415 })
416 .collect();
417
418 Ok(backends)
419 }
420}
421
422impl Backend for SimulatedBackend {
423 fn rank(&self) -> usize {
424 self.rank
425 }
426
427 fn world_size(&self) -> usize {
428 self.world_size
429 }
430
431 fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
432 if dst_rank >= self.world_size {
433 return Err(DistributedError::InvalidRank {
434 rank: dst_rank,
435 world_size: self.world_size,
436 }
437 .into());
438 }
439
440 let tx = self.channels[self.rank][dst_rank].0.lock().map_err(|e| {
442 DistributedError::LockPoisoned {
443 message: format!("send channel lock rank {} -> {dst_rank}: {e}", self.rank),
444 }
445 })?;
446
447 tx.send(data.to_vec())
448 .map_err(|e| DistributedError::ChannelClosed {
449 message: format!("send rank {} -> {dst_rank}: {e}", self.rank),
450 })?;
451
452 Ok(())
453 }
454
455 fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
456 if src_rank >= self.world_size {
457 return Err(DistributedError::InvalidRank {
458 rank: src_rank,
459 world_size: self.world_size,
460 }
461 .into());
462 }
463
464 let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
466 DistributedError::LockPoisoned {
467 message: format!("recv channel lock rank {src_rank} -> {}: {e}", self.rank),
468 }
469 })?;
470
471 let data = rx.recv().map_err(|e| DistributedError::ChannelClosed {
472 message: format!("recv rank {src_rank} -> {}: {e}", self.rank),
473 })?;
474
475 if data.len() != dst.len() {
476 return Err(DistributedError::SizeMismatch {
477 expected: dst.len(),
478 got: data.len(),
479 }
480 .into());
481 }
482
483 dst.copy_from_slice(&data);
484 Ok(())
485 }
486
487 fn recv_timeout(
488 &self,
489 dst: &mut [u8],
490 src_rank: usize,
491 timeout: Duration,
492 ) -> FerrotorchResult<()> {
493 if src_rank >= self.world_size {
494 return Err(DistributedError::InvalidRank {
495 rank: src_rank,
496 world_size: self.world_size,
497 }
498 .into());
499 }
500
501 let rx = self.channels[src_rank][self.rank].1.lock().map_err(|e| {
502 DistributedError::LockPoisoned {
503 message: format!(
504 "recv_timeout channel lock rank {src_rank} -> {}: {e}",
505 self.rank
506 ),
507 }
508 })?;
509
510 let data = rx.recv_timeout(timeout).map_err(|e| match e {
511 mpsc::RecvTimeoutError::Timeout => DistributedError::Timeout {
512 seconds: timeout.as_secs(),
513 },
514 mpsc::RecvTimeoutError::Disconnected => DistributedError::ChannelClosed {
515 message: format!(
516 "recv_timeout rank {src_rank} -> {}: disconnected",
517 self.rank
518 ),
519 },
520 })?;
521
522 if data.len() != dst.len() {
523 return Err(DistributedError::SizeMismatch {
524 expected: dst.len(),
525 got: data.len(),
526 }
527 .into());
528 }
529
530 dst.copy_from_slice(&data);
531 Ok(())
532 }
533
534 fn barrier(&self) -> FerrotorchResult<()> {
535 let tag = [0u8; 1];
538 if self.rank == 0 {
539 let mut buf = [0u8; 1];
540 for r in 1..self.world_size {
541 self.recv(&mut buf, r)?;
542 }
543 for r in 1..self.world_size {
544 self.send(&tag, r)?;
545 }
546 } else {
547 self.send(&tag, 0)?;
548 let mut buf = [0u8; 1];
549 self.recv(&mut buf, 0)?;
550 }
551 Ok(())
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A message sent by rank 0 on another thread arrives intact at rank 1.
    #[test]
    fn test_simulated_send_recv() {
        let mut members = SimulatedBackend::create_group(2).unwrap().into_iter();
        let rank0 = members.next().unwrap();
        let rank1 = members.next().unwrap();

        let sender = thread::spawn(move || {
            rank0.send(&[1, 2, 3, 4], 1).unwrap();
        });

        let mut received = [0u8; 4];
        rank1.recv(&mut received, 0).unwrap();
        sender.join().unwrap();

        assert_eq!(received, [1, 2, 3, 4]);
    }

    /// Four ranks, each on its own thread, all get through the barrier.
    #[test]
    fn test_simulated_barrier() {
        // Spawn every rank before joining any, otherwise the barrier would
        // never complete.
        let workers: Vec<_> = SimulatedBackend::create_group(4)
            .unwrap()
            .into_iter()
            .map(|backend| {
                thread::spawn(move || {
                    backend.barrier().unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Handles come back in rank order and report the group size.
    #[test]
    fn test_simulated_rank_world_size() {
        let group = SimulatedBackend::create_group(3).unwrap();
        assert_eq!(group[0].rank(), 0);
        assert_eq!(group[1].rank(), 1);
        assert_eq!(group[2].rank(), 2);
        assert_eq!(group[0].world_size(), 3);
    }

    /// An empty group is rejected.
    #[test]
    fn test_invalid_world_size() {
        assert!(SimulatedBackend::create_group(0).is_err());
    }

    /// Sending to a rank outside the group fails.
    #[test]
    fn test_send_to_invalid_rank() {
        let group = SimulatedBackend::create_group(2).unwrap();
        assert!(group[0].send(&[1], 5).is_err());
    }
}