gllm_kernels/comm/
shared_memory.rs

1//! Shared memory communication for single-node ring attention.
2
3use std::thread;
4
5use burn::tensor::TensorData;
6use crossbeam_channel::{unbounded, Receiver, Sender};
7
8use super::traits::{CommError, CommResult, Communicator};
9
10/// A group of shared memory communicators for ring communication.
11pub struct SharedMemoryGroup {
12    comms: Vec<SharedMemoryComm>,
13}
14
15impl SharedMemoryGroup {
16    /// Create a new group of communicators with the specified world size.
17    pub fn new(world_size: usize) -> CommResult<Self> {
18        if world_size == 0 {
19            return Err(CommError::InvalidConfig(
20                "world_size must be > 0".to_string(),
21            ));
22        }
23
24        let (send_txs, recv_rxs) = build_ring_channels(world_size);
25        let (barrier_txs, barrier_rxs) = build_barrier_channels(world_size);
26
27        let mut comms = Vec::with_capacity(world_size);
28        for rank in 0..world_size {
29            let recv_idx = (rank + world_size - 1) % world_size;
30            comms.push(SharedMemoryComm {
31                rank,
32                world_size,
33                send_tx: send_txs[rank].clone(),
34                recv_rx: recv_rxs[recv_idx].clone(),
35                barrier_txs: barrier_txs.clone(),
36                barrier_rx: barrier_rxs[rank].clone(),
37            });
38        }
39
40        Ok(Self { comms })
41    }
42
43    /// Consume the group and return communicators for each rank.
44    pub fn into_comms(self) -> Vec<SharedMemoryComm> {
45        self.comms
46    }
47}
48
49/// Shared memory communicator for a single rank.
50pub struct SharedMemoryComm {
51    rank: usize,
52    world_size: usize,
53    send_tx: Sender<TensorData>,
54    recv_rx: Receiver<TensorData>,
55    barrier_txs: Vec<Sender<()>>,
56    barrier_rx: Receiver<()>,
57}
58
59impl Communicator for SharedMemoryComm {
60    fn rank(&self) -> usize {
61        self.rank
62    }
63
64    fn world_size(&self) -> usize {
65        self.world_size
66    }
67
68    fn send(&self, data: &TensorData) -> CommResult<()> {
69        self.send_tx
70            .send(data.clone())
71            .map_err(|_| CommError::Disconnected)
72    }
73
74    fn recv(&self) -> CommResult<TensorData> {
75        self.recv_rx.recv().map_err(|_| CommError::Disconnected)
76    }
77
78    fn send_recv(&self, send_data: &TensorData) -> CommResult<TensorData> {
79        let send_tx = self.send_tx.clone();
80        let data = send_data.clone();
81
82        let handle = thread::spawn(move || {
83            send_tx
84                .send(data)
85                .map_err(|_| CommError::Disconnected)
86        });
87
88        let recv_result = self.recv();
89        match handle.join() {
90            Ok(send_result) => send_result?,
91            Err(_) => {
92                return Err(CommError::SendFailed(
93                    "send thread panicked".to_string(),
94                ))
95            }
96        }
97
98        recv_result
99    }
100
101    fn barrier(&self) -> CommResult<()> {
102        if self.world_size <= 1 {
103            return Ok(());
104        }
105
106        for (idx, tx) in self.barrier_txs.iter().enumerate() {
107            if idx != self.rank {
108                tx.send(()).map_err(|_| CommError::Disconnected)?;
109            }
110        }
111
112        for _ in 0..(self.world_size - 1) {
113            self.barrier_rx.recv().map_err(|_| CommError::Disconnected)?;
114        }
115
116        Ok(())
117    }
118}
119
120fn build_ring_channels(
121    world_size: usize,
122) -> (Vec<Sender<TensorData>>, Vec<Receiver<TensorData>>) {
123    let mut send_txs = Vec::with_capacity(world_size);
124    let mut recv_rxs = Vec::with_capacity(world_size);
125
126    for _ in 0..world_size {
127        let (tx, rx) = unbounded();
128        send_txs.push(tx);
129        recv_rxs.push(rx);
130    }
131
132    (send_txs, recv_rxs)
133}
134
135fn build_barrier_channels(world_size: usize) -> (Vec<Sender<()>>, Vec<Receiver<()>>) {
136    let mut send_txs = Vec::with_capacity(world_size);
137    let mut recv_rxs = Vec::with_capacity(world_size);
138
139    for _ in 0..world_size {
140        let (tx, rx) = unbounded();
141        send_txs.push(tx);
142        recv_rxs.push(rx);
143    }
144
145    (send_txs, recv_rxs)
146}