gllm_kernels/comm/
tcp.rs

1//! TCP communication for multi-node ring attention.
2
3use std::io::{Read, Write};
4use std::net::{TcpListener, TcpStream};
5use std::sync::Mutex;
6
7use burn::tensor::{Bytes, DType, TensorData};
8
9use super::traits::{CommError, CommResult, Communicator};
10
// One-byte wire tags identifying a tensor's element type. These values are
// written by `encode_dtype` and read back by `decode_dtype` on the peer, so
// they are part of the on-the-wire format: never renumber an existing tag.
const DTYPE_F64: u8 = 0x01;
const DTYPE_F32: u8 = 0x02;
const DTYPE_F16: u8 = 0x03;
const DTYPE_BF16: u8 = 0x04;
const DTYPE_I64: u8 = 0x05;
const DTYPE_I32: u8 = 0x06;
const DTYPE_I16: u8 = 0x07;
const DTYPE_I8: u8 = 0x08;
const DTYPE_U64: u8 = 0x09;
const DTYPE_U32: u8 = 0x0A;
const DTYPE_U16: u8 = 0x0B;
const DTYPE_U8: u8 = 0x0C;
const DTYPE_BOOL: u8 = 0x0D;
// Tag for quantized tensors: the encoder never emits it and the decoder
// rejects it with an explicit "not supported" error.
const DTYPE_QFLOAT: u8 = 0x0E;
25
/// TCP communicator for multi-node ring communication.
///
/// Each rank keeps exactly two sockets: an outbound one to the next rank in
/// the ring and an inbound one accepted on its own listener. The two are
/// guarded by independent locks, so a send and a receive can be in progress
/// at the same time from different threads.
pub struct TcpComm {
    /// Position of this process in the ring; always `< world_size`.
    rank: usize,
    /// Number of ranks in the ring (the number of configured addresses).
    world_size: usize,
    /// Outbound stream to rank `(rank + 1) % world_size`; held for one whole framed send.
    send_stream: Mutex<TcpStream>,
    /// Inbound stream accepted on this rank's listener; held for one whole framed receive.
    /// NOTE(review): presumably the previous rank — the accept does not verify the peer.
    recv_stream: Mutex<TcpStream>,
}
33
34impl TcpComm {
35    /// Connect to peers in a ring topology.
36    pub fn connect(addresses: Vec<String>, rank: usize) -> CommResult<Self> {
37        if addresses.is_empty() {
38            return Err(CommError::InvalidConfig(
39                "addresses must not be empty".to_string(),
40            ));
41        }
42
43        let world_size = addresses.len();
44        if rank >= world_size {
45            return Err(CommError::InvalidConfig(format!(
46                "rank {} >= world_size {}",
47                rank, world_size
48            )));
49        }
50
51        let bind_addr = &addresses[rank];
52        let listener = TcpListener::bind(bind_addr).map_err(|e| {
53            CommError::ConnectionFailed(format!("Failed to bind {}: {}", bind_addr, e))
54        })?;
55
56        let next_rank = (rank + 1) % world_size;
57        let next_addr = &addresses[next_rank];
58        let send_stream = TcpStream::connect(next_addr).map_err(|e| {
59            CommError::ConnectionFailed(format!("Failed to connect to {}: {}", next_addr, e))
60        })?;
61        let _ = send_stream.set_nodelay(true);
62
63        let (recv_stream, _) = listener.accept().map_err(|e| {
64            CommError::ConnectionFailed(format!("Failed to accept on {}: {}", bind_addr, e))
65        })?;
66        let _ = recv_stream.set_nodelay(true);
67
68        Ok(Self {
69            rank,
70            world_size,
71            send_stream: Mutex::new(send_stream),
72            recv_stream: Mutex::new(recv_stream),
73        })
74    }
75}
76
77impl Communicator for TcpComm {
78    fn rank(&self) -> usize {
79        self.rank
80    }
81
82    fn world_size(&self) -> usize {
83        self.world_size
84    }
85
86    fn send(&self, data: &TensorData) -> CommResult<()> {
87        let payload = serialize_tensor_data(data)?;
88        let len = payload.len() as u64;
89
90        let mut stream = self.send_stream.lock().map_err(|_| {
91            CommError::SendFailed("Failed to acquire send lock".to_string())
92        })?;
93
94        stream
95            .write_all(&len.to_le_bytes())
96            .map_err(|e| CommError::SendFailed(format!("Failed to send length: {}", e)))?;
97        stream
98            .write_all(&payload)
99            .map_err(|e| CommError::SendFailed(format!("Failed to send payload: {}", e)))?;
100        stream
101            .flush()
102            .map_err(|e| CommError::SendFailed(format!("Failed to flush: {}", e)))?;
103
104        Ok(())
105    }
106
107    fn recv(&self) -> CommResult<TensorData> {
108        let mut stream = self.recv_stream.lock().map_err(|_| {
109            CommError::RecvFailed("Failed to acquire recv lock".to_string())
110        })?;
111
112        let mut len_buf = [0u8; 8];
113        stream
114            .read_exact(&mut len_buf)
115            .map_err(|e| CommError::RecvFailed(format!("Failed to read length: {}", e)))?;
116        let len = u64::from_le_bytes(len_buf) as usize;
117
118        let mut buf = vec![0u8; len];
119        stream
120            .read_exact(&mut buf)
121            .map_err(|e| CommError::RecvFailed(format!("Failed to read payload: {}", e)))?;
122
123        deserialize_tensor_data(&buf)
124    }
125
126    fn barrier(&self) -> CommResult<()> {
127        if self.world_size <= 1 {
128            return Ok(());
129        }
130
131        let token = TensorData::new(Vec::<u8>::new(), [0]);
132
133        if self.rank == 0 {
134            self.send(&token)?;
135            let _ = self.recv()?;
136            self.send(&token)?;
137            let _ = self.recv()?;
138        } else {
139            let _ = self.recv()?;
140            self.send(&token)?;
141            let _ = self.recv()?;
142            self.send(&token)?;
143        }
144
145        Ok(())
146    }
147}
148
149fn serialize_tensor_data(data: &TensorData) -> CommResult<Vec<u8>> {
150    let mut buf = Vec::new();
151    encode_dtype(&mut buf, data.dtype)?;
152
153    write_u32(&mut buf, data.shape.len() as u32);
154    for dim in &data.shape {
155        write_u64(&mut buf, *dim as u64);
156    }
157
158    let bytes = data.as_bytes();
159    write_u64(&mut buf, bytes.len() as u64);
160    buf.extend_from_slice(bytes);
161
162    Ok(buf)
163}
164
165fn deserialize_tensor_data(buf: &[u8]) -> CommResult<TensorData> {
166    let mut pos = 0usize;
167    let dtype = decode_dtype(buf, &mut pos)?;
168
169    let shape_len = read_u32(buf, &mut pos)? as usize;
170    let mut shape = Vec::with_capacity(shape_len);
171    for _ in 0..shape_len {
172        shape.push(read_u64(buf, &mut pos)? as usize);
173    }
174
175    let bytes_len = read_u64(buf, &mut pos)? as usize;
176    let bytes = read_slice(buf, &mut pos, bytes_len)?.to_vec();
177
178    Ok(TensorData::from_bytes(Bytes::from_bytes_vec(bytes), shape, dtype))
179}
180
181fn encode_dtype(buf: &mut Vec<u8>, dtype: DType) -> CommResult<()> {
182    match dtype {
183        DType::F64 => buf.push(DTYPE_F64),
184        DType::F32 => buf.push(DTYPE_F32),
185        DType::F16 => buf.push(DTYPE_F16),
186        DType::BF16 => buf.push(DTYPE_BF16),
187        DType::I64 => buf.push(DTYPE_I64),
188        DType::I32 => buf.push(DTYPE_I32),
189        DType::I16 => buf.push(DTYPE_I16),
190        DType::I8 => buf.push(DTYPE_I8),
191        DType::U64 => buf.push(DTYPE_U64),
192        DType::U32 => buf.push(DTYPE_U32),
193        DType::U16 => buf.push(DTYPE_U16),
194        DType::U8 => buf.push(DTYPE_U8),
195        DType::Bool => buf.push(DTYPE_BOOL),
196        // Quantized types not supported in TCP comm yet
197        _ => {
198            return Err(CommError::Serialization(
199                "Quantized types not supported in TCP comm".to_string(),
200            ));
201        }
202    }
203
204    Ok(())
205}
206
207fn decode_dtype(buf: &[u8], pos: &mut usize) -> CommResult<DType> {
208    let tag = read_u8(buf, pos)?;
209    match tag {
210        DTYPE_F64 => Ok(DType::F64),
211        DTYPE_F32 => Ok(DType::F32),
212        DTYPE_F16 => Ok(DType::F16),
213        DTYPE_BF16 => Ok(DType::BF16),
214        DTYPE_I64 => Ok(DType::I64),
215        DTYPE_I32 => Ok(DType::I32),
216        DTYPE_I16 => Ok(DType::I16),
217        DTYPE_I8 => Ok(DType::I8),
218        DTYPE_U64 => Ok(DType::U64),
219        DTYPE_U32 => Ok(DType::U32),
220        DTYPE_U16 => Ok(DType::U16),
221        DTYPE_U8 => Ok(DType::U8),
222        DTYPE_BOOL => Ok(DType::Bool),
223        DTYPE_QFLOAT => Err(CommError::Serialization(
224            "Quantized types not supported in TCP comm".to_string(),
225        )),
226        _ => Err(CommError::Serialization(format!(
227            "Unknown dtype tag {}",
228            tag
229        ))),
230    }
231}
232
/// Append `value` to `buf` as 4 little-endian bytes.
fn write_u32(buf: &mut Vec<u8>, value: u32) {
    let le: [u8; 4] = u32::to_le_bytes(value);
    buf.extend(le.iter().copied());
}
236
/// Append `value` to `buf` as 8 little-endian bytes.
fn write_u64(buf: &mut Vec<u8>, value: u64) {
    let le: [u8; 8] = u64::to_le_bytes(value);
    buf.extend(le.iter().copied());
}
240
241fn read_u8(buf: &[u8], pos: &mut usize) -> CommResult<u8> {
242    let slice = read_slice(buf, pos, 1)?;
243    Ok(slice[0])
244}
245
246fn read_u32(buf: &[u8], pos: &mut usize) -> CommResult<u32> {
247    let slice = read_slice(buf, pos, 4)?;
248    Ok(u32::from_le_bytes(slice.try_into().unwrap()))
249}
250
251fn read_u64(buf: &[u8], pos: &mut usize) -> CommResult<u64> {
252    let slice = read_slice(buf, pos, 8)?;
253    Ok(u64::from_le_bytes(slice.try_into().unwrap()))
254}
255
256fn read_slice<'a>(buf: &'a [u8], pos: &mut usize, len: usize) -> CommResult<&'a [u8]> {
257    if *pos + len > buf.len() {
258        return Err(CommError::Serialization(
259            "Buffer too small for decode".to_string(),
260        ));
261    }
262    let slice = &buf[*pos..*pos + len];
263    *pos += len;
264    Ok(slice)
265}