pmetal_distributed/
ring.rs1use crate::{
2 DistributedBackend,
3 config::DistributedConfig,
4 error::DistributedError,
5 transport::{TcpTransport, TransportReceiver, TransportSender},
6};
7use anyhow::Result;
8use async_trait::async_trait;
9use tokio::sync::Mutex;
10use zerocopy::{FromBytes, IntoBytes};
11
/// Distributed backend implementing ring-based collectives (currently
/// all-reduce and a barrier built on top of it) over a TCP transport.
pub struct RingBackend {
    /// This node's 0-based position in the ring.
    rank: usize,
    /// Total number of participating nodes (taken from `config.nodes.len()`).
    world_size: usize,
    /// Outgoing transport link — presumably to the next rank in the ring
    /// (confirm against `TcpTransport::connect`). Mutex-wrapped so the
    /// collective ops can run through `&self`.
    sender: Mutex<TransportSender>,
    /// Incoming transport link — presumably from the previous rank.
    receiver: Mutex<TransportReceiver>,
}
18
19impl RingBackend {
20 pub async fn new(config: DistributedConfig) -> Result<Self> {
21 config.validate()?;
22 let (sender, receiver) = TcpTransport::connect(&config).await?;
23 Ok(Self {
24 rank: config.rank,
25 world_size: config.nodes.len(),
26 sender: Mutex::new(sender),
27 receiver: Mutex::new(receiver),
28 })
29 }
30}
31
32#[async_trait]
33impl DistributedBackend for RingBackend {
34 fn rank(&self) -> usize {
35 self.rank
36 }
37
38 fn world_size(&self) -> usize {
39 self.world_size
40 }
41
42 async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
43 if !buffer.len().is_multiple_of(4) {
45 return Err(DistributedError::Protocol(format!(
46 "Buffer length {} is not a multiple of 4 (f32 size)",
47 buffer.len()
48 ))
49 .into());
50 }
51
52 if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
53 return Err(DistributedError::Protocol(
54 "Buffer is not properly aligned for f32 operations".to_string(),
55 )
56 .into());
57 }
58
59 let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
60 .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
61 let len = floats.len();
62
63 let chunk_size = len / self.world_size;
64 let remainder = len % self.world_size;
65
66 let get_chunk_range = |idx: usize| -> (usize, usize) {
67 let start = idx * chunk_size + idx.min(remainder);
68 let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
69 (start, end)
70 };
71
72 let mut sender = self.sender.lock().await;
74 let mut receiver = self.receiver.lock().await;
75
76 let mut send_chunk_idx = self.rank;
78 let mut recv_chunk_idx = (self.rank + self.world_size - 1) % self.world_size;
79
80 let max_chunk_size = chunk_size + 1;
82 let mut recv_buf = vec![0u8; max_chunk_size * 4];
83
84 for _ in 0..self.world_size - 1 {
85 let (s_start, s_end) = get_chunk_range(send_chunk_idx);
86 let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
87
88 let send_buf = floats[s_start..s_end].as_bytes().to_vec();
98
99 let recv_bytes_len = (r_end - r_start) * 4;
100 let recv_slice = &mut recv_buf[..recv_bytes_len];
101
102 let send_fut = sender.send(&send_buf);
104 let recv_fut = receiver.recv(recv_slice);
105
106 tokio::try_join!(send_fut, recv_fut)?;
107
108 let recv_floats =
110 <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");
111
112 for i in 0..recv_floats.len() {
113 floats[r_start + i] += recv_floats[i];
114 }
115
116 send_chunk_idx = recv_chunk_idx;
117 recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
118 }
119
120 send_chunk_idx = (self.rank + 1) % self.world_size;
122 recv_chunk_idx = self.rank;
123
124 for _ in 0..self.world_size - 1 {
125 let (s_start, s_end) = get_chunk_range(send_chunk_idx);
126 let (r_start, r_end) = get_chunk_range(recv_chunk_idx);
127
128 let send_buf = floats[s_start..s_end].as_bytes().to_vec();
129
130 let recv_bytes_len = (r_end - r_start) * 4;
131 let recv_slice = &mut recv_buf[..recv_bytes_len];
132
133 let send_fut = sender.send(&send_buf);
134 let recv_fut = receiver.recv(recv_slice);
135
136 tokio::try_join!(send_fut, recv_fut)?;
137
138 let recv_floats =
140 <[f32]>::ref_from_bytes(recv_slice).expect("recv buffer aligned for f32");
141 floats[r_start..r_end].copy_from_slice(recv_floats);
142
143 send_chunk_idx = recv_chunk_idx;
144 recv_chunk_idx = (recv_chunk_idx + self.world_size - 1) % self.world_size;
145 }
146
147 Ok(())
148 }
149
150 async fn barrier(&self) -> Result<()> {
151 let mut buf = [0u8; 4]; self.all_reduce(&mut buf).await
153 }
154}
155
#[cfg(kani)]
mod verification {
    use super::*;

    /// Proves that the chunk partition used by `all_reduce` tiles the buffer
    /// exactly: chunks are contiguous, in-bounds, non-overlapping, and
    /// together cover every element.
    #[kani::proof]
    #[kani::unwind(17)]
    fn verify_get_chunk_range() {
        let len: usize = kani::any();
        let world_size: usize = kani::any();

        kani::assume(world_size > 0 && world_size <= 16);
        kani::assume(len >= world_size && len < 1024);

        let chunk_size = len / world_size;
        let remainder = len % world_size;

        // Mirror of the closure inside `RingBackend::all_reduce`.
        let range_of = |idx: usize| -> (usize, usize) {
            let extra = usize::from(idx < remainder);
            let start = idx * chunk_size + idx.min(remainder);
            (start, start + chunk_size + extra)
        };

        let mut covered = 0usize;
        let mut prev_end = 0usize;

        for idx in 0..world_size {
            let (start, end) = range_of(idx);

            // Non-degenerate, contiguous with the previous chunk, in-bounds.
            assert!(start <= end);
            assert_eq!(start, prev_end);
            assert!(end <= len);

            covered += end - start;
            prev_end = end;
        }

        // Every element is covered exactly once.
        assert_eq!(covered, len);
        assert_eq!(prev_end, len);
    }
}