// pmetal_distributed/ring.rs

use crate::{
    DistributedBackend, ReduceOp,
    config::DistributedConfig,
    error::DistributedError,
    transport::{TcpTransport, TransportReceiver, TransportSender},
};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Mutex;
use zerocopy::{FromBytes, IntoBytes};
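/// Ring-based distributed backend: every collective is implemented as
/// point-to-point exchanges with this rank's ring neighbours over the TCP
/// transport, so per-rank traffic does not grow with the world size.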
pub struct RingBackend {
    rank: usize,
    world_size: usize,
    sender: Mutex<TransportSender>,
    receiver: Mutex<TransportReceiver>,
    barrier_counter: AtomicU64,
}

impl RingBackend {
    pub async fn new(config: DistributedConfig) -> Result<Self> {
        config.validate()?;
        let (sender, receiver) = TcpTransport::connect(&config).await?;
        Ok(Self {
            rank: config.rank,
            world_size: config.nodes.len(),
            sender: Mutex::new(sender),
            receiver: Mutex::new(receiver),
            barrier_counter: AtomicU64::new(0),
        })
    }
}

#[async_trait]
impl DistributedBackend for RingBackend {
    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }
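    // Ring all-reduce over the buffer reinterpreted as f32, in two phases: a
    // scatter-reduce pass after which each rank holds one fully reduced chunk,
    // followed by an all-gather pass that circulates the reduced chunks until
    // every rank has the complete result.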
    async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
        if !buffer.len().is_multiple_of(4) {
            return Err(DistributedError::Protocol(format!(
                "Buffer length {} is not a multiple of 4 (f32 size)",
                buffer.len()
            ))
            .into());
        }

        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
            return Err(DistributedError::Protocol(
                "Buffer is not properly aligned for f32 operations".to_string(),
            )
            .into());
        }

        let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
            .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
        let len = floats.len();
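        // Split the flattened buffer into `world_size` chunks. The division
        // may not be exact, so the first `remainder` chunks carry one extra
        // element; together the chunks cover every element exactly once.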
        let chunk_size = len / self.world_size;
        let remainder = len % self.world_size;

        let get_chunk_range = |idx: usize| -> (usize, usize) {
            let start = idx * chunk_size + idx.min(remainder);
            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
            (start, end)
        };
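        // Hold both transport locks for the whole collective so that a
        // concurrent call cannot interleave its messages with this one.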
        let mut sender = self.sender.lock().await;
        let mut receiver = self.receiver.lock().await;

        let mut send_chunk_idx = self.rank;
        let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;

        // Reusable receive buffer sized for the largest chunk. Allocating it
        // as f32 (and viewing it as bytes only for the transport) guarantees
        // the alignment needed to read the received values back as floats.
        let max_chunk_size = chunk_size + 1;
        let mut recv_buf = vec![0f32; max_chunk_size];
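        // Phase 1: scatter-reduce. In each of the `world_size - 1` steps this
        // rank sends one chunk around the ring and accumulates the chunk it
        // receives; afterwards each rank holds the fully reduced values for
        // exactly one chunk.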
        for _ in 0..self.world_size - 1 {
            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);

            // Copy the outgoing chunk so that `floats` can be mutated with the
            // received data while the send is in flight.
            let send_buf = floats[s_start..s_end].as_bytes().to_vec();

            let recv_len = r_end - r_start;
            let recv_slice = recv_buf[..recv_len].as_mut_bytes();

            // Send and receive concurrently (a sequential send-then-recv could
            // deadlock the ring); the 30s timeout guards against a dead peer.
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                let send_fut = sender.send(&send_buf);
                let recv_fut = receiver.recv(recv_slice);
                tokio::try_join!(send_fut, recv_fut)
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Ring all-reduce scatter-reduce timed out after 30s — peer may have crashed"
                )
            })??;

            // Accumulate the received chunk into the local buffer.
            for i in 0..recv_len {
                floats[r_start + i] += recv_buf[i];
            }

            send_chunk_idx = recv_chunk_idx;
            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
        }
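        // Phase 2: all-gather. After the scatter-reduce, rank r owns the fully
        // reduced chunk (r + 1) % world_size, so the gather starts from that
        // chunk and circulates the reduced chunks until every rank has all of
        // them.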
        send_chunk_idx = (self.rank + 1) % self.world_size;
        recv_chunk_idx = self.rank;
        for _ in 0..self.world_size - 1 {
            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);

            let send_buf = floats[s_start..s_end].as_bytes().to_vec();

            let recv_len = r_end - r_start;
            let recv_slice = recv_buf[..recv_len].as_mut_bytes();

            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                let send_fut = sender.send(&send_buf);
                let recv_fut = receiver.recv(recv_slice);
                tokio::try_join!(send_fut, recv_fut)
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Ring all-reduce all-gather timed out after 30s — peer may have crashed"
                )
            })??;

            // Received chunks are already fully reduced; just copy them in.
            floats[r_start..r_end].copy_from_slice(&recv_buf[..recv_len]);

            send_chunk_idx = recv_chunk_idx;
            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
        }
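        // Values are combined by summation during the scatter-reduce; Mean is
        // obtained by dividing the summed result by the world size.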
        if op == ReduceOp::Mean {
            let divisor = self.world_size as f32;
            for f in floats.iter_mut() {
                *f /= divisor;
            }
        }

        Ok(())
    }
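    // Two-phase ring barrier: phase 1 exchanges this barrier's sequence token
    // with the ring neighbours `world_size - 1` times, phase 2 does the same
    // with an acknowledgement token so that no rank can leave the barrier
    // before every rank has completed phase 1.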
    async fn barrier(&self) -> Result<()> {
        let world_size = self.world_size;
        if world_size < 2 {
            return Ok(());
        }

        // Per-rank sequence number: the mismatch check below only holds if
        // every rank calls barrier() the same number of times.
        let seq = self.barrier_counter.fetch_add(1, Ordering::SeqCst);

        let mut sender = self.sender.lock().await;
        let mut receiver = self.receiver.lock().await;

        let token: [u8; 8] = seq.to_le_bytes();

        for _ in 0..world_size - 1 {
            let mut recv_buf = [0u8; 8];
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!("Barrier phase-1 timed out after 30s — peer may have crashed")
            })??;

            let recv_seq = u64::from_le_bytes(recv_buf);
            if recv_seq != seq {
                return Err(DistributedError::Protocol(format!(
                    "Barrier sequence mismatch: expected {seq}, got {recv_seq}"
                ))
                .into());
            }
        }
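        // Phase 2: exchange an acknowledgement token (offset from `seq` so it
        // cannot be confused with a phase-1 token).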
        let ack_seq = seq.wrapping_add(u64::MAX / 2);
        let ack_token: [u8; 8] = ack_seq.to_le_bytes();

        for _ in 0..world_size - 1 {
            let mut recv_buf = [0u8; 8];
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                tokio::try_join!(sender.send(&ack_token), receiver.recv(&mut recv_buf))
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!("Barrier phase-2 timed out after 30s — peer may have crashed")
            })??;
        }

        Ok(())
    }
}

#[cfg(kani)]
mod verification {
    use super::*;
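    // Proof that `get_chunk_range` partitions `0..len` into contiguous,
    // in-bounds chunks that cover every element exactly once, for any
    // world_size up to 16 (hence the unwind bound of 17).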
    #[kani::proof]
    #[kani::unwind(17)]
    fn verify_get_chunk_range() {
        let len: usize = kani::any();
        let world_size: usize = kani::any();

        kani::assume(world_size > 0 && world_size <= 16);
        kani::assume(len >= world_size && len < 1024);

        let chunk_size = len / world_size;
        let remainder = len % world_size;

        let get_chunk_range = |idx: usize| -> (usize, usize) {
            let start = idx * chunk_size + idx.min(remainder);
            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
            (start, end)
        };

        let mut total_elements = 0;
        let mut last_end = 0;

        for i in 0..world_size {
            let (start, end) = get_chunk_range(i);

            // Chunks must be well-formed, contiguous, and stay in bounds.
            assert!(start <= end);
            assert!(start == last_end);
            assert!(end <= len);

            total_elements += end - start;
            last_end = end;
        }

        // Together the chunks cover the whole buffer exactly once.
        assert!(total_elements == len);
        assert!(last_end == len);
    }
}