// pmetal_distributed/ring.rs — ring all-reduce backend.
use crate::{
    DistributedBackend,
    config::DistributedConfig,
    error::DistributedError,
    transport::{TcpTransport, TransportReceiver, TransportSender},
};
use anyhow::Result;
use async_trait::async_trait;
use bytemuck::{cast_slice, cast_slice_mut};
use tokio::sync::Mutex;
11
/// Distributed backend implementing ring-based collectives (all-reduce,
/// barrier) over a TCP transport.
pub struct RingBackend {
    // This node's position in the ring, in [0, world_size).
    rank: usize,
    // Total number of participating nodes.
    world_size: usize,
    // Send half of the transport; Mutex lets &self methods obtain exclusive
    // access for the duration of a collective.
    // NOTE(review): presumably connected to the next rank in the ring —
    // confirm against TcpTransport::connect.
    sender: Mutex<TransportSender>,
    // Receive half of the transport (presumably from the previous rank).
    receiver: Mutex<TransportReceiver>,
}
18
19impl RingBackend {
20    pub async fn new(config: DistributedConfig) -> Result<Self> {
21        config.validate()?;
22        let (sender, receiver) = TcpTransport::connect(&config).await?;
23        Ok(Self {
24            rank: config.rank,
25            world_size: config.nodes.len(),
26            sender: Mutex::new(sender),
27            receiver: Mutex::new(receiver),
28        })
29    }
30}
31
32#[async_trait]
33impl DistributedBackend for RingBackend {
34    fn rank(&self) -> usize {
35        self.rank
36    }
37
38    fn world_size(&self) -> usize {
39        self.world_size
40    }
41
42    async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
43        // Validate buffer alignment and size for f32 operations
44        if !buffer.len().is_multiple_of(4) {
45            return Err(DistributedError::Protocol(format!(
46                "Buffer length {} is not a multiple of 4 (f32 size)",
47                buffer.len()
48            ))
49            .into());
50        }
51
52        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
53            return Err(DistributedError::Protocol(
54                "Buffer is not properly aligned for f32 operations".to_string(),
55            )
56            .into());
57        }
58
59        // Safe cast using bytemuck (validates alignment at compile time for Pod types)
60        let floats: &mut [f32] = cast_slice_mut(buffer);
61        let len = floats.len();
62
63        let chunk_size = len / self.world_size;
64        let remainder = len % self.world_size;
65
66        let get_chunk_range = |idx: usize| -> (usize, usize) {
67            let start = idx * chunk_size + idx.min(remainder);
68            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
69            (start, end)
70        };
71
72        // Lock both transport halves
73        let mut sender = self.sender.lock().await;
74        let mut receiver = self.receiver.lock().await;
75
76        // 1. Scatter-Reduce
77        let mut send_chunk_idx = self.rank;
78        let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;
79
80        // Recv buffer
81        let max_chunk_size = chunk_size + 1;
82        let mut recv_buf = vec![0u8; max_chunk_size * 4];
83
84        for _ in 0..self.world_size - 1 {
85            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
86            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
87
88            // Prepare send data
89            // We need to copy to a temp buffer because floats is borrowed by recv logic?
90            // Actually, we can just send. send takes &[u8].
91            // BUT: We need to use &mut sender and &mut receiver concurrently.
92            // Since we locked them, we own them. We can pass them to concurrent futures.
93            // But MutexGuard is not Send if we hold it across await? No, it is.
94            // Wait, we can't borrow `sender` and `receiver` into `try_join` if they are MutexGuards?
95            // Yes we can, they are distinct.
96
97            let send_bytes_len = (s_end - s_start) * 4;
98            // Create a temporary send buffer to avoid borrow checker issues with `floats`
99            // (One future reads floats, the other writes floats).
100            // Rust borrow checker will complain if we access `floats` in join! blocks if one is mutable.
101
102            let mut send_buf = vec![0u8; send_bytes_len];
103            unsafe {
104                std::ptr::copy_nonoverlapping(
105                    floats[s_start..s_end].as_ptr() as *const u8,
106                    send_buf.as_mut_ptr(),
107                    send_bytes_len,
108                );
109            }
110
111            let recv_bytes_len = (r_end - r_start) * 4;
112            let recv_slice = &mut recv_buf[..recv_bytes_len];
113
114            // Concurrent Send and Recv
115            let send_fut = sender.send(&send_buf);
116            let recv_fut = receiver.recv(recv_slice);
117
118            tokio::try_join!(send_fut, recv_fut)?;
119
120            // Reduce
121            let recv_floats = unsafe {
122                std::slice::from_raw_parts(recv_slice.as_ptr() as *const f32, r_end - r_start)
123            };
124
125            for i in 0..recv_floats.len() {
126                floats[r_start + i] += recv_floats[i];
127            }
128
129            send_chunk_idx = recv_chunk_idx;
130            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
131        }
132
133        // 2. All-Gather
134        send_chunk_idx = (self.rank + 1) % self.world_size;
135        recv_chunk_idx = self.rank;
136
137        for _ in 0..self.world_size - 1 {
138            let (s_start, s_end) = get_chunk_range(send_chunk_idx);
139            let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
140
141            let send_bytes_len = (s_end - s_start) * 4;
142            let mut send_buf = vec![0u8; send_bytes_len];
143            unsafe {
144                std::ptr::copy_nonoverlapping(
145                    floats[s_start..s_end].as_ptr() as *const u8,
146                    send_buf.as_mut_ptr(),
147                    send_bytes_len,
148                );
149            }
150
151            let recv_bytes_len = (r_end - r_start) * 4;
152            let recv_slice = &mut recv_buf[..recv_bytes_len];
153
154            let send_fut = sender.send(&send_buf);
155            let recv_fut = receiver.recv(recv_slice);
156
157            tokio::try_join!(send_fut, recv_fut)?;
158
159            // Copy (Gather)
160            unsafe {
161                std::ptr::copy_nonoverlapping(
162                    recv_slice.as_ptr(),
163                    floats[r_start..r_end].as_mut_ptr() as *mut u8,
164                    recv_bytes_len,
165                );
166            }
167
168            send_chunk_idx = recv_chunk_idx;
169            recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
170        }
171
172        Ok(())
173    }
174
175    async fn barrier(&self) -> Result<()> {
176        let mut buf = [0u8; 4]; // Minimum 1 float
177        self.all_reduce(&mut buf).await
178    }
179}