pmetal_distributed/
ring.rs1use crate::{
2 DistributedBackend,
3 config::DistributedConfig,
4 error::DistributedError,
5 transport::{TcpTransport, TransportReceiver, TransportSender},
6};
7use anyhow::Result;
8use async_trait::async_trait;
9use bytemuck::cast_slice_mut;
10use tokio::sync::Mutex;
11
/// Distributed backend implementing ring-based collectives (all-reduce,
/// barrier) over TCP point-to-point links.
pub struct RingBackend {
    // This process's index among participants (taken from `DistributedConfig::rank`).
    rank: usize,
    // Total number of participants (`config.nodes.len()`).
    world_size: usize,
    // Outbound transport link; presumably connected to the next rank in the
    // ring — TODO confirm against `TcpTransport::connect`. Wrapped in a Mutex
    // so `&self` async methods can send.
    sender: Mutex<TransportSender>,
    // Inbound transport link; presumably from the previous rank — TODO confirm.
    receiver: Mutex<TransportReceiver>,
}
18
19impl RingBackend {
20 pub async fn new(config: DistributedConfig) -> Result<Self> {
21 config.validate()?;
22 let (sender, receiver) = TcpTransport::connect(&config).await?;
23 Ok(Self {
24 rank: config.rank,
25 world_size: config.nodes.len(),
26 sender: Mutex::new(sender),
27 receiver: Mutex::new(receiver),
28 })
29 }
30}
31
32#[async_trait]
33impl DistributedBackend for RingBackend {
34 fn rank(&self) -> usize {
35 self.rank
36 }
37
38 fn world_size(&self) -> usize {
39 self.world_size
40 }
41
42 async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
43 if !buffer.len().is_multiple_of(4) {
45 return Err(DistributedError::Protocol(format!(
46 "Buffer length {} is not a multiple of 4 (f32 size)",
47 buffer.len()
48 ))
49 .into());
50 }
51
52 if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
53 return Err(DistributedError::Protocol(
54 "Buffer is not properly aligned for f32 operations".to_string(),
55 )
56 .into());
57 }
58
59 let floats: &mut [f32] = cast_slice_mut(buffer);
61 let len = floats.len();
62
63 let chunk_size = len / self.world_size;
64 let remainder = len % self.world_size;
65
66 let get_chunk_range = |idx: usize| -> (usize, usize) {
67 let start = idx * chunk_size + idx.min(remainder);
68 let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
69 (start, end)
70 };
71
72 let mut sender = self.sender.lock().await;
74 let mut receiver = self.receiver.lock().await;
75
76 let mut send_chunk_idx = self.rank;
78 let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;
79
80 let max_chunk_size = chunk_size + 1;
82 let mut recv_buf = vec![0u8; max_chunk_size * 4];
83
84 for _ in 0..self.world_size - 1 {
85 let (s_start, s_end) = get_chunk_range(send_chunk_idx);
86 let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
87
88 let send_bytes_len = (s_end - s_start) * 4;
98 let mut send_buf = vec![0u8; send_bytes_len];
103 unsafe {
104 std::ptr::copy_nonoverlapping(
105 floats[s_start..s_end].as_ptr() as *const u8,
106 send_buf.as_mut_ptr(),
107 send_bytes_len,
108 );
109 }
110
111 let recv_bytes_len = (r_end - r_start) * 4;
112 let recv_slice = &mut recv_buf[..recv_bytes_len];
113
114 let send_fut = sender.send(&send_buf);
116 let recv_fut = receiver.recv(recv_slice);
117
118 tokio::try_join!(send_fut, recv_fut)?;
119
120 let recv_floats = unsafe {
122 std::slice::from_raw_parts(recv_slice.as_ptr() as *const f32, r_end - r_start)
123 };
124
125 for i in 0..recv_floats.len() {
126 floats[r_start + i] += recv_floats[i];
127 }
128
129 send_chunk_idx = recv_chunk_idx;
130 recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
131 }
132
133 send_chunk_idx = (self.rank + 1) % self.world_size;
135 recv_chunk_idx = self.rank;
136
137 for _ in 0..self.world_size - 1 {
138 let (s_start, s_end) = get_chunk_range(send_chunk_idx);
139 let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
140
141 let send_bytes_len = (s_end - s_start) * 4;
142 let mut send_buf = vec![0u8; send_bytes_len];
143 unsafe {
144 std::ptr::copy_nonoverlapping(
145 floats[s_start..s_end].as_ptr() as *const u8,
146 send_buf.as_mut_ptr(),
147 send_bytes_len,
148 );
149 }
150
151 let recv_bytes_len = (r_end - r_start) * 4;
152 let recv_slice = &mut recv_buf[..recv_bytes_len];
153
154 let send_fut = sender.send(&send_buf);
155 let recv_fut = receiver.recv(recv_slice);
156
157 tokio::try_join!(send_fut, recv_fut)?;
158
159 unsafe {
161 std::ptr::copy_nonoverlapping(
162 recv_slice.as_ptr(),
163 floats[r_start..r_end].as_mut_ptr() as *mut u8,
164 recv_bytes_len,
165 );
166 }
167
168 send_chunk_idx = recv_chunk_idx;
169 recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
170 }
171
172 Ok(())
173 }
174
175 async fn barrier(&self) -> Result<()> {
176 let mut buf = [0u8; 4]; self.all_reduce(&mut buf).await
178 }
179}