pmetal_distributed/ring.rs

use crate::{
    DistributedBackend, ReduceOp,
    config::DistributedConfig,
    error::DistributedError,
    transport::{TcpTransport, TransportReceiver, TransportSender},
};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Mutex;
use zerocopy::{FromBytes, IntoBytes};

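/// Ring all-reduce backend: the buffer is split into `world_size` chunks, a
/// scatter-reduce pass leaves each node holding one fully-reduced chunk, and
/// an all-gather pass circulates those chunks until every node has the full
/// result.  Each node transfers about `2 * (world_size - 1) / world_size`
/// times the buffer size in total, roughly constant in the node count.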
pub struct RingBackend {
    rank: usize,
    world_size: usize,
    sender: Mutex<TransportSender>,
    receiver: Mutex<TransportReceiver>,
    /// Monotonically increasing counter used to assign unique sequence numbers
    /// to barrier rounds, preventing stale tokens from a previous barrier from
    /// being mistaken for tokens from the current one.
    barrier_counter: AtomicU64,
}

impl RingBackend {
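    /// Validates the configuration and establishes the ring transport links.
    ///
    /// A minimal usage sketch (illustrative only: it assumes a
    /// `DistributedConfig` built elsewhere, and `as_mut_bytes` comes from
    /// `zerocopy::IntoBytes`):
    ///
    /// ```ignore
    /// let backend = RingBackend::new(config).await?;
    /// let mut grads = vec![0.0f32; 1024];
    /// backend.all_reduce(grads.as_mut_bytes(), ReduceOp::Mean).await?;
    /// backend.barrier().await?;
    /// ```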
    pub async fn new(config: DistributedConfig) -> Result<Self> {
        config.validate()?;
        let (sender, receiver) = TcpTransport::connect(&config).await?;
        Ok(Self {
            rank: config.rank,
            world_size: config.nodes.len(),
            sender: Mutex::new(sender),
            receiver: Mutex::new(receiver),
            barrier_counter: AtomicU64::new(0),
        })
    }
}

#[async_trait]
impl DistributedBackend for RingBackend {
    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
        // Validate buffer alignment and size for f32 operations
        if !buffer.len().is_multiple_of(4) {
            return Err(DistributedError::Protocol(format!(
                "Buffer length {} is not a multiple of 4 (f32 size)",
                buffer.len()
            ))
            .into());
        }

        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
            return Err(DistributedError::Protocol(
                "Buffer is not properly aligned for f32 operations".to_string(),
            )
            .into());
        }

        let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
            .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
        let len = floats.len();

        let chunk_size = len / self.world_size;
        let remainder = len % self.world_size;

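        // Split `len` elements into `world_size` contiguous chunks, giving the
        // first `remainder` chunks one extra element.  Example: len = 10,
        // world_size = 4 -> chunk_size = 2, remainder = 2, ranges
        // [0, 3), [3, 6), [6, 8), [8, 10).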
        let get_chunk_range = |idx: usize| -> (usize, usize) {
            let start = idx * chunk_size + idx.min(remainder);
            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
            (start, end)
        };

        // Lock both transport halves
        let mut sender = self.sender.lock().await;
        let mut receiver = self.receiver.lock().await;

        // 1. Scatter-Reduce
        let mut send_chunk_idx = self.rank;
        let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;
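        // At step i, node r sends chunk (r - i) mod world_size and accumulates
        // the incoming chunk (r - i - 1) mod world_size from its left
        // neighbor.  After world_size - 1 steps, node r therefore holds the
        // fully-reduced chunk (r + 1) mod world_size.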

        // Receive buffer sized for the largest chunk.  Allocated as f32s so
        // the zero-copy casts below are guaranteed to be 4-byte aligned; a
        // Vec<u8> only guarantees byte alignment.
        let max_chunk_len = chunk_size + 1;
        let mut recv_storage = vec![0f32; max_chunk_len];
        let recv_buf = recv_storage.as_mut_bytes();

        for _ in 0..self.world_size - 1 {
            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);

            // Copy the outgoing chunk to a temporary buffer so `floats` is
            // not borrowed while the transfer is in flight.  `sender` and
            // `receiver` guard two distinct mutexes, so the send and recv
            // futures below can run concurrently inside `try_join!`.
            let send_buf = floats[s_start..s_end].as_bytes().to_vec();

            let recv_bytes_len = (r_end - r_start) * 4;
            let recv_slice = &mut recv_buf[..recv_bytes_len];

            // Concurrent send and recv with a timeout to prevent deadlock
            // if a peer crashes mid-transfer.
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                let send_fut = sender.send(&send_buf);
                let recv_fut = receiver.recv(recv_slice);
                tokio::try_join!(send_fut, recv_fut)
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Ring all-reduce scatter-reduce timed out after 30s — peer may have crashed"
                )
            })??;

            // Reduce: accumulate the received chunk into the local partial
            // sum.  (Mean is applied as a single division after the ring
            // completes.)
            let recv_floats =
                <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");

            for i in 0..recv_floats.len() {
                floats[r_start + i] += recv_floats[i];
            }

            send_chunk_idx = recv_chunk_idx;
            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
        }

        // 2. All-Gather: circulate the fully-reduced chunks around the ring.
        // After scatter-reduce, node r holds the fully-reduced chunk
        // (r + 1) mod world_size, so that is the first chunk it sends.
        // Step 0: send chunk (r + 1), receive fully-reduced chunk r from the
        // left neighbor.  Step k: forward what was received in step k - 1.
        send_chunk_idx = (self.rank + 1) % self.world_size;
        recv_chunk_idx = self.rank;

        for _ in 0..self.world_size - 1 {
            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);

            let send_buf = floats[s_start..s_end].as_bytes().to_vec();

            let recv_bytes_len = (r_end - r_start) * 4;
            let recv_slice = &mut recv_buf[..recv_bytes_len];

            // Timeout mirrors the scatter-reduce phase.
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                let send_fut = sender.send(&send_buf);
                let recv_fut = receiver.recv(recv_slice);
                tokio::try_join!(send_fut, recv_fut)
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Ring all-reduce all-gather timed out after 30s — peer may have crashed"
                )
            })??;

            // Gather: the received chunk is already fully reduced, so it
            // overwrites the local copy rather than accumulating into it.
            let recv_floats =
                <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");
            floats[r_start..r_end].copy_from_slice(recv_floats);

            send_chunk_idx = recv_chunk_idx;
            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
        }
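        // Every chunk has now made a full trip around the ring, so all nodes
        // hold the complete reduced buffer.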

        // Apply mean reduction: divide by world_size after the ring has summed.
        if op == ReduceOp::Mean {
            let divisor = self.world_size as f32;
            for f in floats.iter_mut() {
                *f /= divisor;
            }
        }

        Ok(())
    }

    /// Two-phase barrier using a monotonic sequence number.
    ///
    /// Phase 1 (propagate): send the barrier token with a unique sequence
    /// number around the ring; each node forwards it after receiving from its
    /// predecessor.  When the token reaches the initiator after `world_size - 1`
    /// hops, every node has observed it.
    ///
    /// Phase 2 (acknowledge): send the sequence number back around the ring
    /// in the same direction to signal completion.  When a node receives the
    /// acknowledgement it knows all nodes have finished Phase 1 and may proceed.
    ///
    /// The monotonic counter prevents tokens from a crashed/slow previous round
    /// from being mistaken for tokens from the current round.
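    ///
    /// Example: with `world_size = 3` and `seq = 7`, each node sends and
    /// receives the token `7` twice in Phase 1, then exchanges the offset
    /// ack token twice in Phase 2; a stale `6` left over from an earlier
    /// round would be rejected by the sequence check.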
    async fn barrier(&self) -> Result<()> {
        let world_size = self.world_size;
        if world_size < 2 {
            return Ok(());
        }

        // Allocate a fresh sequence number for this barrier invocation.
        let seq = self.barrier_counter.fetch_add(1, Ordering::SeqCst);

        let mut sender = self.sender.lock().await;
        let mut receiver = self.receiver.lock().await;

        // Each barrier token is 8 bytes: the little-endian u64 sequence number.
        let token: [u8; 8] = seq.to_le_bytes();

        // Phase 1: propagate the token around the ring.
        // Each node forwards after receiving (all world_size - 1 hops).
        for _ in 0..world_size - 1 {
            let mut recv_buf = [0u8; 8];
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!("Barrier phase-1 timed out after 30s — peer may have crashed")
            })??;

            // Verify the token sequence number to detect stale messages.
            let recv_seq = u64::from_le_bytes(recv_buf);
            if recv_seq != seq {
                return Err(DistributedError::Protocol(format!(
                    "Barrier sequence mismatch: expected {seq}, got {recv_seq}"
                ))
                .into());
            }
        }

        // Phase 2: acknowledge completion.
        // Offsetting by half the u64 range keeps ack tokens distinct from the
        // phase-1 tokens of any nearby round.
        let ack_seq = seq.wrapping_add(u64::MAX / 2);
        let ack_token: [u8; 8] = ack_seq.to_le_bytes();

        for _ in 0..world_size - 1 {
            let mut recv_buf = [0u8; 8];
            tokio::time::timeout(std::time::Duration::from_secs(30), async {
                tokio::try_join!(sender.send(&ack_token), receiver.recv(&mut recv_buf))
            })
            .await
            .map_err(|_| {
                anyhow::anyhow!("Barrier phase-2 timed out after 30s — peer may have crashed")
            })??;

            // Mirror the phase-1 check so a desynchronised round is detected
            // in this phase as well.
            let recv_ack = u64::from_le_bytes(recv_buf);
            if recv_ack != ack_seq {
                return Err(DistributedError::Protocol(format!(
                    "Barrier ack mismatch: expected {ack_seq}, got {recv_ack}"
                ))
                .into());
            }
        }

        Ok(())
    }
}

#[cfg(kani)]
mod verification {
    use super::*;

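    // Proof harness for the chunk-partitioning scheme, run with `cargo kani`.
    // `get_chunk_range` is duplicated here because the original is a closure
    // local to `all_reduce`.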
    #[kani::proof]
    #[kani::unwind(17)] // Sufficient to unroll the loop for world_size up to 16
    fn verify_get_chunk_range() {
        let len: usize = kani::any();
        let world_size: usize = kani::any();

        // Preconditions
        kani::assume(world_size > 0 && world_size <= 16);
        kani::assume(len >= world_size && len < 1024);

        let chunk_size = len / world_size;
        let remainder = len % world_size;

        let get_chunk_range = |idx: usize| -> (usize, usize) {
            let start = idx * chunk_size + idx.min(remainder);
            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
            (start, end)
        };

        let mut total_elements = 0;
        let mut last_end = 0;

        for i in 0..world_size {
            let (start, end) = get_chunk_range(i);

            // Chunks must be valid ranges
            assert!(start <= end);
            // Chunks must be contiguous
            assert!(start == last_end);
            // Chunks must be within bounds
            assert!(end <= len);

            total_elements += end - start;
            last_end = end;
        }

        // Total elements must match original length
        assert!(total_elements == len);
        assert!(last_end == len);
    }
}