1use crate::symmetric::{CollectiveError, Rank, SymmetricBuffer, SymmetricTransport};
35
/// Element-wise reduction operator used by the collectives in this module.
///
/// A reduction is evaluated in three steps: start every element from
/// `identity()`, combine each contribution with `fold()`, then post-process the
/// accumulated value once with `finalize()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    /// Add all contributions together.
    Sum,
    /// Add all contributions, then divide by the count passed to `finalize`.
    Mean,
    /// Keep the largest contribution.
    Max,
    /// Keep the smallest contribution.
    Min,
}

impl ReduceKind {
    /// Returns the neutral starting value for this operator's accumulator.
    fn identity(self) -> f32 {
        match self {
            Self::Max => f32::NEG_INFINITY,
            Self::Min => f32::INFINITY,
            // `Mean` accumulates a plain sum; the division happens in `finalize`.
            Self::Sum | Self::Mean => 0.0,
        }
    }

    /// Combines the running accumulator `acc` with one more contribution `x`.
    fn fold(self, acc: f32, x: f32) -> f32 {
        match self {
            Self::Sum | Self::Mean => acc + x,
            // `f32::max` / `f32::min` return the other operand when exactly one
            // operand is NaN, so a lone NaN contribution does not propagate here.
            Self::Max => f32::max(acc, x),
            Self::Min => f32::min(acc, x),
        }
    }

    /// Turns the accumulated value into the final result, given that `n`
    /// contributions were folded in. Only `Mean` uses `n`.
    fn finalize(self, acc: f32, n: usize) -> f32 {
        match self {
            Self::Mean => acc / (n as f32),
            Self::Sum | Self::Max | Self::Min => acc,
        }
    }
}
68
69pub fn all_reduce<T: SymmetricTransport>(
80 transport: &T,
81 buf: SymmetricBuffer, local: &mut [f32],
83 op: ReduceKind,
84) -> Result<(), CollectiveError> {
85 let elems = buf.len / 4;
86 if local.len() != elems {
87 return Err(CollectiveError::LengthMismatch {
88 expected: elems,
89 got: local.len(),
90 });
91 }
92 let me = transport.this_rank();
93 let n = transport.num_ranks();
94
95 let our_buf = SymmetricBuffer {
97 rank: me,
98 offset: buf.offset,
99 len: buf.len,
100 };
101 let bytes = unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, buf.len) };
102 transport.put(our_buf, bytes)?;
103
104 transport.barrier()?;
106
107 let mut acc: Vec<f32> = vec![op.identity(); elems];
109 let mut scratch_bytes = vec![0u8; buf.len];
110 for r in 0..n {
111 let src = SymmetricBuffer {
112 rank: Rank(r),
113 offset: buf.offset,
114 len: buf.len,
115 };
116 transport.get(src, &mut scratch_bytes)?;
117 let scratch =
118 unsafe { std::slice::from_raw_parts(scratch_bytes.as_ptr() as *const f32, elems) };
119 for (i, &v) in scratch.iter().enumerate() {
120 acc[i] = op.fold(acc[i], v);
121 }
122 }
123 for v in acc.iter_mut() {
124 *v = op.finalize(*v, n as usize);
125 }
126 local.copy_from_slice(&acc);
127 Ok(())
128}
129
130pub fn all_gather<T: SymmetricTransport>(
137 transport: &T,
138 buf: SymmetricBuffer, local: &[f32],
140 output: &mut [f32],
141) -> Result<(), CollectiveError> {
142 let elems_per_rank = buf.len / 4;
143 let n = transport.num_ranks() as usize;
144 if local.len() != elems_per_rank {
145 return Err(CollectiveError::LengthMismatch {
146 expected: elems_per_rank,
147 got: local.len(),
148 });
149 }
150 if output.len() != n * elems_per_rank {
151 return Err(CollectiveError::LengthMismatch {
152 expected: n * elems_per_rank,
153 got: output.len(),
154 });
155 }
156
157 let me = transport.this_rank();
158 let our_buf = SymmetricBuffer {
159 rank: me,
160 offset: buf.offset,
161 len: buf.len,
162 };
163 let bytes = unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, buf.len) };
164 transport.put(our_buf, bytes)?;
165 transport.barrier()?;
166
167 let mut scratch_bytes = vec![0u8; buf.len];
168 for r in 0..n {
169 let src = SymmetricBuffer {
170 rank: Rank(r as u32),
171 offset: buf.offset,
172 len: buf.len,
173 };
174 transport.get(src, &mut scratch_bytes)?;
175 let chunk = unsafe {
176 std::slice::from_raw_parts(scratch_bytes.as_ptr() as *const f32, elems_per_rank)
177 };
178 let dst_start = r * elems_per_rank;
179 output[dst_start..dst_start + elems_per_rank].copy_from_slice(chunk);
180 }
181 Ok(())
182}
183
184pub fn reduce_scatter<T: SymmetricTransport>(
192 transport: &T,
193 buf: SymmetricBuffer,
194 local: &[f32],
195 output: &mut [f32],
196 op: ReduceKind,
197) -> Result<(), CollectiveError> {
198 let total = buf.len / 4;
199 let n = transport.num_ranks() as usize;
200 if !total.is_multiple_of(n) {
201 return Err(CollectiveError::TransportError {
202 reason: format!("reduce_scatter: total elements {total} not divisible by {n} ranks"),
203 });
204 }
205 let chunk = total / n;
206 if local.len() != total {
207 return Err(CollectiveError::LengthMismatch {
208 expected: total,
209 got: local.len(),
210 });
211 }
212 if output.len() != chunk {
213 return Err(CollectiveError::LengthMismatch {
214 expected: chunk,
215 got: output.len(),
216 });
217 }
218
219 let me = transport.this_rank().0 as usize;
222 let mut full = local.to_vec();
223 all_reduce(transport, buf, &mut full, op)?;
224 output.copy_from_slice(&full[me * chunk..(me + 1) * chunk]);
225 Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric::LocalTransport;

    /// Encodes `values` as native-endian bytes, the wire format the
    /// collectives use.
    fn encode_f32s(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 4);
        for v in values {
            bytes.extend_from_slice(&v.to_ne_bytes());
        }
        bytes
    }

    /// Decodes native-endian bytes back into `f32` values. Done per element so
    /// a `Vec<u8>` (alignment 1) is never reinterpreted as `&[f32]`.
    fn decode_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Drives the put / get / sum data flow of an all-reduce by hand over a
    /// 4-rank `LocalTransport` fan-out. This test does not call `all_reduce`
    /// itself; it replays the two phases sequentially for every rank.
    #[test]
    fn all_reduce_sum_across_4_ranks() {
        let n_ranks = 4u32;
        let elems = 4usize;
        let bytes = elems * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);

        // Rank r starts with every element equal to r + 1, so the sum is 10.
        let mut state: Vec<Vec<f32>> = (0..n_ranks).map(|r| vec![(r + 1) as f32; elems]).collect();

        // Phase 1: every rank publishes its data.
        for (r, t) in ts.iter().enumerate() {
            let our_buf = SymmetricBuffer {
                rank: Rank(r as u32),
                offset: 0,
                len: bytes,
            };
            let raw = encode_f32s(&state[r]);
            t.put(our_buf, &raw[..]).unwrap();
        }

        // Phase 2: every rank reads all peers and accumulates.
        for (r, t) in ts.iter().enumerate() {
            let mut acc = vec![0f32; elems];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                let src = SymmetricBuffer {
                    rank: Rank(src_r),
                    offset: 0,
                    len: bytes,
                };
                t.get(src, &mut scratch).unwrap();
                for (slot, v) in acc.iter_mut().zip(decode_f32s(&scratch)) {
                    *slot += v;
                }
            }
            state[r] = acc;
        }

        for (r, slot) in state.iter().enumerate() {
            assert_eq!(slot, &vec![10.0; elems], "rank {r} after all-reduce");
        }
    }

    /// Drives the put / get / concatenate data flow of an all-gather by hand
    /// over a 3-rank `LocalTransport` fan-out. This test does not call
    /// `all_gather` itself.
    #[test]
    fn all_gather_concatenates_in_rank_order() {
        let n_ranks = 3u32;
        let chunk = 2usize;
        let bytes = chunk * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);

        // Rank r contributes [10r, 10r + 1].
        let local: Vec<Vec<f32>> = (0..n_ranks)
            .map(|r| {
                let r = r as f32;
                vec![10.0 * r, 10.0 * r + 1.0]
            })
            .collect();

        // Phase 1: every rank publishes its chunk.
        for (r, t) in ts.iter().enumerate() {
            let our_buf = SymmetricBuffer {
                rank: Rank(r as u32),
                offset: 0,
                len: bytes,
            };
            let raw = encode_f32s(&local[r]);
            t.put(our_buf, &raw[..]).unwrap();
        }

        // Phase 2: every rank reads all peers into rank-ordered output slots.
        for (r_idx, t) in ts.iter().enumerate() {
            let mut output = vec![0f32; n_ranks as usize * chunk];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                let src = SymmetricBuffer {
                    rank: Rank(src_r),
                    offset: 0,
                    len: bytes,
                };
                t.get(src, &mut scratch).unwrap();
                let view = decode_f32s(&scratch);
                let dst_start = src_r as usize * chunk;
                output[dst_start..dst_start + chunk].copy_from_slice(&view);
            }
            assert_eq!(
                output,
                vec![0.0, 1.0, 10.0, 11.0, 20.0, 21.0],
                "rank {r_idx} after all-gather"
            );
        }
    }

    /// Folding from the `Max` identity yields the largest input.
    #[test]
    fn reduce_kind_max_takes_pointwise_max() {
        let mut acc = ReduceKind::Max.identity();
        for v in [3.0, 1.0, 7.0, -2.0] {
            acc = ReduceKind::Max.fold(acc, v);
        }
        assert_eq!(acc, 7.0);
    }

    /// `Mean` accumulates a sum and divides only in `finalize`.
    #[test]
    fn reduce_kind_mean_divides_at_finalize() {
        let mut acc = ReduceKind::Mean.identity();
        for v in [2.0, 4.0, 6.0, 8.0] {
            acc = ReduceKind::Mean.fold(acc, v);
        }
        assert_eq!(ReduceKind::Mean.finalize(acc, 4), 5.0);
    }
}