//! Ring all-reduce distributed backend (`pmetal_distributed/ring.rs`).

1use crate::{
2    DistributedBackend,
3    config::DistributedConfig,
4    error::DistributedError,
5    transport::{TcpTransport, TransportReceiver, TransportSender},
6};
7use anyhow::Result;
8use async_trait::async_trait;
9use tokio::sync::Mutex;
10use zerocopy::{FromBytes, IntoBytes};
11
/// Distributed backend implementing collectives with the ring algorithm
/// over a TCP transport.
pub struct RingBackend {
    // This node's position among the participants, in `0..world_size`.
    rank: usize,
    // Total number of participating nodes (length of `config.nodes`).
    world_size: usize,
    // Sending half of the transport — presumably the link to the next rank
    // in the ring (see the chunk rotation in `all_reduce`); verify against
    // `TcpTransport::connect`. Mutex-guarded so a collective holds
    // exclusive use for its whole duration.
    sender: Mutex<TransportSender>,
    // Receiving half of the transport — presumably the link from the
    // previous rank; same locking discipline as `sender`.
    receiver: Mutex<TransportReceiver>,
}
18
19impl RingBackend {
20    pub async fn new(config: DistributedConfig) -> Result<Self> {
21        config.validate()?;
22        let (sender, receiver) = TcpTransport::connect(&config).await?;
23        Ok(Self {
24            rank: config.rank,
25            world_size: config.nodes.len(),
26            sender: Mutex::new(sender),
27            receiver: Mutex::new(receiver),
28        })
29    }
30}
31
32#[async_trait]
33impl DistributedBackend for RingBackend {
34    fn rank(&self) -> usize {
35        self.rank
36    }
37
38    fn world_size(&self) -> usize {
39        self.world_size
40    }
41
42    async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
43        // Validate buffer alignment and size for f32 operations
44        if !buffer.len().is_multiple_of(4) {
45            return Err(DistributedError::Protocol(format!(
46                "Buffer length {} is not a multiple of 4 (f32 size)",
47                buffer.len()
48            ))
49            .into());
50        }
51
52        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
53            return Err(DistributedError::Protocol(
54                "Buffer is not properly aligned for f32 operations".to_string(),
55            )
56            .into());
57        }
58
59        let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
60            .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
61        let len = floats.len();
62
63        let chunk_size = len / self.world_size;
64        let remainder = len % self.world_size;
65
66        let get_chunk_range = |idx: usize| -> (usize, usize) {
67            let start = idx * chunk_size + idx.min(remainder);
68            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
69            (start, end)
70        };
71
72        // Lock both transport halves
73        let mut sender = self.sender.lock().await;
74        let mut receiver = self.receiver.lock().await;
75
76        // 1. Scatter-Reduce
77        let mut send_chunk_idx = self.rank;
78        let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;
79
80        // Recv buffer
81        let max_chunk_size = chunk_size + 1;
82        let mut recv_buf = vec![0u8; max_chunk_size * 4];
83
84        for _ in 0..self.world_size - 1 {
85            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
86            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
87
88            // Prepare send data
89            // We need to copy to a temp buffer because floats is borrowed by recv logic?
90            // Actually, we can just send. send takes &[u8].
91            // BUT: We need to use &mut sender and &mut receiver concurrently.
92            // Since we locked them, we own them. We can pass them to concurrent futures.
93            // But MutexGuard is not Send if we hold it across await? No, it is.
94            // Wait, we can't borrow `sender` and `receiver` into `try_join` if they are MutexGuards?
95            // Yes we can, they are distinct.
96
97            let send_buf = floats[s_start..s_end].as_bytes().to_vec();
98
99            let recv_bytes_len = (r_end - r_start) * 4;
100            let recv_slice = &mut recv_buf[..recv_bytes_len];
101
102            // Concurrent Send and Recv
103            let send_fut = sender.send(&send_buf);
104            let recv_fut = receiver.recv(recv_slice);
105
106            tokio::try_join!(send_fut, recv_fut)?;
107
108            // Reduce
109            let recv_floats =
110                <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");
111
112            for i in 0..recv_floats.len() {
113                floats[r_start + i] += recv_floats[i];
114            }
115
116            send_chunk_idx = recv_chunk_idx;
117            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
118        }
119
120        // 2. All-Gather
121        send_chunk_idx = (self.rank + 1) % self.world_size;
122        recv_chunk_idx = self.rank;
123
124        for _ in 0..self.world_size - 1 {
125            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
126            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
127
128            let send_buf = floats[s_start..s_end].as_bytes().to_vec();
129
130            let recv_bytes_len = (r_end - r_start) * 4;
131            let recv_slice = &mut recv_buf[..recv_bytes_len];
132
133            let send_fut = sender.send(&send_buf);
134            let recv_fut = receiver.recv(recv_slice);
135
136            tokio::try_join!(send_fut, recv_fut)?;
137
138            // Copy (Gather)
139            let recv_floats =
140                <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");
141            floats[r_start..r_end].copy_from_slice(recv_floats);
142
143            send_chunk_idx = recv_chunk_idx;
144            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
145        }
146
147        Ok(())
148    }
149
150    async fn barrier(&self) -> Result<()> {
151        let mut buf = [0u8; 4]; // Minimum 1 float
152        self.all_reduce(&mut buf).await
153    }
154}
155
#[cfg(kani)]
mod verification {
    use super::*;

    /// Proves that the chunk partitioning used by `all_reduce` splits
    /// `0..len` into `world_size` contiguous, in-bounds, non-overlapping
    /// ranges that together cover every element exactly once.
    #[kani::proof]
    #[kani::unwind(17)] // Sufficient for world_size up to 16 for testing
    fn verify_get_chunk_range() {
        let len: usize = kani::any();
        let world_size: usize = kani::any();

        // Preconditions mirroring what all_reduce can encounter in tests.
        kani::assume(world_size > 0 && world_size <= 16);
        kani::assume(len >= world_size && len < 1024);

        let chunk_size = len / world_size;
        let remainder = len % world_size;

        // Same closure as in all_reduce: the first `remainder` chunks are
        // one element longer than the rest.
        let range_of = |idx: usize| -> (usize, usize) {
            let start = idx * chunk_size + idx.min(remainder);
            let extra = if idx < remainder { 1 } else { 0 };
            (start, start + chunk_size + extra)
        };

        let mut covered = 0;
        let mut expected_start = 0;

        for idx in 0..world_size {
            let (start, end) = range_of(idx);

            // Each range is well-formed...
            assert!(start <= end);
            // ...begins exactly where the previous one ended...
            assert!(start == expected_start);
            // ...and never runs past the buffer.
            assert!(end <= len);

            covered += end - start;
            expected_start = end;
        }

        // The partition covers the whole length with no gaps or overlaps.
        assert!(covered == len);
        assert!(expected_start == len);
    }
}
200}