gllm_kernels/comm/
shared_memory.rs1use std::thread;
4
5use burn::tensor::TensorData;
6use crossbeam_channel::{unbounded, Receiver, Sender};
7
8use super::traits::{CommError, CommResult, Communicator};
9
10pub struct SharedMemoryGroup {
12 comms: Vec<SharedMemoryComm>,
13}
14
15impl SharedMemoryGroup {
16 pub fn new(world_size: usize) -> CommResult<Self> {
18 if world_size == 0 {
19 return Err(CommError::InvalidConfig(
20 "world_size must be > 0".to_string(),
21 ));
22 }
23
24 let (send_txs, recv_rxs) = build_ring_channels(world_size);
25 let (barrier_txs, barrier_rxs) = build_barrier_channels(world_size);
26
27 let mut comms = Vec::with_capacity(world_size);
28 for rank in 0..world_size {
29 let recv_idx = (rank + world_size - 1) % world_size;
30 comms.push(SharedMemoryComm {
31 rank,
32 world_size,
33 send_tx: send_txs[rank].clone(),
34 recv_rx: recv_rxs[recv_idx].clone(),
35 barrier_txs: barrier_txs.clone(),
36 barrier_rx: barrier_rxs[rank].clone(),
37 });
38 }
39
40 Ok(Self { comms })
41 }
42
43 pub fn into_comms(self) -> Vec<SharedMemoryComm> {
45 self.comms
46 }
47}
48
49pub struct SharedMemoryComm {
51 rank: usize,
52 world_size: usize,
53 send_tx: Sender<TensorData>,
54 recv_rx: Receiver<TensorData>,
55 barrier_txs: Vec<Sender<()>>,
56 barrier_rx: Receiver<()>,
57}
58
59impl Communicator for SharedMemoryComm {
60 fn rank(&self) -> usize {
61 self.rank
62 }
63
64 fn world_size(&self) -> usize {
65 self.world_size
66 }
67
68 fn send(&self, data: &TensorData) -> CommResult<()> {
69 self.send_tx
70 .send(data.clone())
71 .map_err(|_| CommError::Disconnected)
72 }
73
74 fn recv(&self) -> CommResult<TensorData> {
75 self.recv_rx.recv().map_err(|_| CommError::Disconnected)
76 }
77
78 fn send_recv(&self, send_data: &TensorData) -> CommResult<TensorData> {
79 let send_tx = self.send_tx.clone();
80 let data = send_data.clone();
81
82 let handle = thread::spawn(move || {
83 send_tx
84 .send(data)
85 .map_err(|_| CommError::Disconnected)
86 });
87
88 let recv_result = self.recv();
89 match handle.join() {
90 Ok(send_result) => send_result?,
91 Err(_) => {
92 return Err(CommError::SendFailed(
93 "send thread panicked".to_string(),
94 ))
95 }
96 }
97
98 recv_result
99 }
100
101 fn barrier(&self) -> CommResult<()> {
102 if self.world_size <= 1 {
103 return Ok(());
104 }
105
106 for (idx, tx) in self.barrier_txs.iter().enumerate() {
107 if idx != self.rank {
108 tx.send(()).map_err(|_| CommError::Disconnected)?;
109 }
110 }
111
112 for _ in 0..(self.world_size - 1) {
113 self.barrier_rx.recv().map_err(|_| CommError::Disconnected)?;
114 }
115
116 Ok(())
117 }
118}
119
120fn build_ring_channels(
121 world_size: usize,
122) -> (Vec<Sender<TensorData>>, Vec<Receiver<TensorData>>) {
123 let mut send_txs = Vec::with_capacity(world_size);
124 let mut recv_rxs = Vec::with_capacity(world_size);
125
126 for _ in 0..world_size {
127 let (tx, rx) = unbounded();
128 send_txs.push(tx);
129 recv_rxs.push(rx);
130 }
131
132 (send_txs, recv_rxs)
133}
134
135fn build_barrier_channels(world_size: usize) -> (Vec<Sender<()>>, Vec<Receiver<()>>) {
136 let mut send_txs = Vec::with_capacity(world_size);
137 let mut recv_rxs = Vec::with_capacity(world_size);
138
139 for _ in 0..world_size {
140 let (tx, rx) = unbounded();
141 send_txs.push(tx);
142 recv_rxs.push(rx);
143 }
144
145 (send_txs, recv_rxs)
146}