1use std::io::{Read, Write};
4use std::net::{TcpListener, TcpStream};
5use std::sync::Mutex;
6
7use burn::tensor::{Bytes, DType, TensorData};
8
9use super::traits::{CommError, CommResult, Communicator};
10
// One-byte wire tags identifying a tensor's element type. `encode_dtype`
// writes the tag as the first byte of a serialized tensor and `decode_dtype`
// maps it back, so these values must be identical on every rank.
const DTYPE_F64: u8 = 1;
const DTYPE_F32: u8 = 2;
const DTYPE_F16: u8 = 3;
const DTYPE_BF16: u8 = 4;
const DTYPE_I64: u8 = 5;
const DTYPE_I32: u8 = 6;
const DTYPE_I16: u8 = 7;
const DTYPE_I8: u8 = 8;
const DTYPE_U64: u8 = 9;
const DTYPE_U32: u8 = 10;
const DTYPE_U16: u8 = 11;
const DTYPE_U8: u8 = 12;
const DTYPE_BOOL: u8 = 13;
// Never written by `encode_dtype`; `decode_dtype` recognizes it only to
// reject it with a "quantized types not supported" error.
const DTYPE_QFLOAT: u8 = 14;
25
/// Communicator that exchanges tensors over two TCP connections: one outgoing
/// connection to the next rank (`(rank + 1) % world_size`) and one incoming
/// connection accepted on this rank's own address.
pub struct TcpComm {
    /// Index of this process within `0..world_size`.
    rank: usize,
    /// Total number of participants (the length of the address list).
    world_size: usize,
    /// Outgoing connection to the next rank; guarded so `send` can take `&self`.
    send_stream: Mutex<TcpStream>,
    /// Connection accepted on this rank's listener (presumably from the
    /// previous rank — the peer's identity is not verified); guarded so `recv`
    /// can take `&self`.
    recv_stream: Mutex<TcpStream>,
}
33
34impl TcpComm {
35 pub fn connect(addresses: Vec<String>, rank: usize) -> CommResult<Self> {
37 if addresses.is_empty() {
38 return Err(CommError::InvalidConfig(
39 "addresses must not be empty".to_string(),
40 ));
41 }
42
43 let world_size = addresses.len();
44 if rank >= world_size {
45 return Err(CommError::InvalidConfig(format!(
46 "rank {} >= world_size {}",
47 rank, world_size
48 )));
49 }
50
51 let bind_addr = &addresses[rank];
52 let listener = TcpListener::bind(bind_addr).map_err(|e| {
53 CommError::ConnectionFailed(format!("Failed to bind {}: {}", bind_addr, e))
54 })?;
55
56 let next_rank = (rank + 1) % world_size;
57 let next_addr = &addresses[next_rank];
58 let send_stream = TcpStream::connect(next_addr).map_err(|e| {
59 CommError::ConnectionFailed(format!("Failed to connect to {}: {}", next_addr, e))
60 })?;
61 let _ = send_stream.set_nodelay(true);
62
63 let (recv_stream, _) = listener.accept().map_err(|e| {
64 CommError::ConnectionFailed(format!("Failed to accept on {}: {}", bind_addr, e))
65 })?;
66 let _ = recv_stream.set_nodelay(true);
67
68 Ok(Self {
69 rank,
70 world_size,
71 send_stream: Mutex::new(send_stream),
72 recv_stream: Mutex::new(recv_stream),
73 })
74 }
75}
76
77impl Communicator for TcpComm {
78 fn rank(&self) -> usize {
79 self.rank
80 }
81
82 fn world_size(&self) -> usize {
83 self.world_size
84 }
85
86 fn send(&self, data: &TensorData) -> CommResult<()> {
87 let payload = serialize_tensor_data(data)?;
88 let len = payload.len() as u64;
89
90 let mut stream = self.send_stream.lock().map_err(|_| {
91 CommError::SendFailed("Failed to acquire send lock".to_string())
92 })?;
93
94 stream
95 .write_all(&len.to_le_bytes())
96 .map_err(|e| CommError::SendFailed(format!("Failed to send length: {}", e)))?;
97 stream
98 .write_all(&payload)
99 .map_err(|e| CommError::SendFailed(format!("Failed to send payload: {}", e)))?;
100 stream
101 .flush()
102 .map_err(|e| CommError::SendFailed(format!("Failed to flush: {}", e)))?;
103
104 Ok(())
105 }
106
107 fn recv(&self) -> CommResult<TensorData> {
108 let mut stream = self.recv_stream.lock().map_err(|_| {
109 CommError::RecvFailed("Failed to acquire recv lock".to_string())
110 })?;
111
112 let mut len_buf = [0u8; 8];
113 stream
114 .read_exact(&mut len_buf)
115 .map_err(|e| CommError::RecvFailed(format!("Failed to read length: {}", e)))?;
116 let len = u64::from_le_bytes(len_buf) as usize;
117
118 let mut buf = vec![0u8; len];
119 stream
120 .read_exact(&mut buf)
121 .map_err(|e| CommError::RecvFailed(format!("Failed to read payload: {}", e)))?;
122
123 deserialize_tensor_data(&buf)
124 }
125
126 fn barrier(&self) -> CommResult<()> {
127 if self.world_size <= 1 {
128 return Ok(());
129 }
130
131 let token = TensorData::new(Vec::<u8>::new(), [0]);
132
133 if self.rank == 0 {
134 self.send(&token)?;
135 let _ = self.recv()?;
136 self.send(&token)?;
137 let _ = self.recv()?;
138 } else {
139 let _ = self.recv()?;
140 self.send(&token)?;
141 let _ = self.recv()?;
142 self.send(&token)?;
143 }
144
145 Ok(())
146 }
147}
148
149fn serialize_tensor_data(data: &TensorData) -> CommResult<Vec<u8>> {
150 let mut buf = Vec::new();
151 encode_dtype(&mut buf, data.dtype)?;
152
153 write_u32(&mut buf, data.shape.len() as u32);
154 for dim in &data.shape {
155 write_u64(&mut buf, *dim as u64);
156 }
157
158 let bytes = data.as_bytes();
159 write_u64(&mut buf, bytes.len() as u64);
160 buf.extend_from_slice(bytes);
161
162 Ok(buf)
163}
164
165fn deserialize_tensor_data(buf: &[u8]) -> CommResult<TensorData> {
166 let mut pos = 0usize;
167 let dtype = decode_dtype(buf, &mut pos)?;
168
169 let shape_len = read_u32(buf, &mut pos)? as usize;
170 let mut shape = Vec::with_capacity(shape_len);
171 for _ in 0..shape_len {
172 shape.push(read_u64(buf, &mut pos)? as usize);
173 }
174
175 let bytes_len = read_u64(buf, &mut pos)? as usize;
176 let bytes = read_slice(buf, &mut pos, bytes_len)?.to_vec();
177
178 Ok(TensorData::from_bytes(Bytes::from_bytes_vec(bytes), shape, dtype))
179}
180
181fn encode_dtype(buf: &mut Vec<u8>, dtype: DType) -> CommResult<()> {
182 match dtype {
183 DType::F64 => buf.push(DTYPE_F64),
184 DType::F32 => buf.push(DTYPE_F32),
185 DType::F16 => buf.push(DTYPE_F16),
186 DType::BF16 => buf.push(DTYPE_BF16),
187 DType::I64 => buf.push(DTYPE_I64),
188 DType::I32 => buf.push(DTYPE_I32),
189 DType::I16 => buf.push(DTYPE_I16),
190 DType::I8 => buf.push(DTYPE_I8),
191 DType::U64 => buf.push(DTYPE_U64),
192 DType::U32 => buf.push(DTYPE_U32),
193 DType::U16 => buf.push(DTYPE_U16),
194 DType::U8 => buf.push(DTYPE_U8),
195 DType::Bool => buf.push(DTYPE_BOOL),
196 _ => {
198 return Err(CommError::Serialization(
199 "Quantized types not supported in TCP comm".to_string(),
200 ));
201 }
202 }
203
204 Ok(())
205}
206
207fn decode_dtype(buf: &[u8], pos: &mut usize) -> CommResult<DType> {
208 let tag = read_u8(buf, pos)?;
209 match tag {
210 DTYPE_F64 => Ok(DType::F64),
211 DTYPE_F32 => Ok(DType::F32),
212 DTYPE_F16 => Ok(DType::F16),
213 DTYPE_BF16 => Ok(DType::BF16),
214 DTYPE_I64 => Ok(DType::I64),
215 DTYPE_I32 => Ok(DType::I32),
216 DTYPE_I16 => Ok(DType::I16),
217 DTYPE_I8 => Ok(DType::I8),
218 DTYPE_U64 => Ok(DType::U64),
219 DTYPE_U32 => Ok(DType::U32),
220 DTYPE_U16 => Ok(DType::U16),
221 DTYPE_U8 => Ok(DType::U8),
222 DTYPE_BOOL => Ok(DType::Bool),
223 DTYPE_QFLOAT => Err(CommError::Serialization(
224 "Quantized types not supported in TCP comm".to_string(),
225 )),
226 _ => Err(CommError::Serialization(format!(
227 "Unknown dtype tag {}",
228 tag
229 ))),
230 }
231}
232
/// Appends `value` to `buf` as four little-endian bytes.
fn write_u32(buf: &mut Vec<u8>, value: u32) {
    let encoded: [u8; 4] = value.to_le_bytes();
    buf.extend(encoded);
}
236
/// Appends `value` to `buf` as eight little-endian bytes.
fn write_u64(buf: &mut Vec<u8>, value: u64) {
    let encoded: [u8; 8] = value.to_le_bytes();
    buf.extend(encoded);
}
240
241fn read_u8(buf: &[u8], pos: &mut usize) -> CommResult<u8> {
242 let slice = read_slice(buf, pos, 1)?;
243 Ok(slice[0])
244}
245
246fn read_u32(buf: &[u8], pos: &mut usize) -> CommResult<u32> {
247 let slice = read_slice(buf, pos, 4)?;
248 Ok(u32::from_le_bytes(slice.try_into().unwrap()))
249}
250
251fn read_u64(buf: &[u8], pos: &mut usize) -> CommResult<u64> {
252 let slice = read_slice(buf, pos, 8)?;
253 Ok(u64::from_le_bytes(slice.try_into().unwrap()))
254}
255
256fn read_slice<'a>(buf: &'a [u8], pos: &mut usize, len: usize) -> CommResult<&'a [u8]> {
257 if *pos + len > buf.len() {
258 return Err(CommError::Serialization(
259 "Buffer too small for decode".to_string(),
260 ));
261 }
262 let slice = &buf[*pos..*pos + len];
263 *pos += len;
264 Ok(slice)
265}