1use crate::error::{DistributedError, DistributedResult};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13use tracing::debug;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17pub enum AllReduceStrategy {
18 Ring,
21 Tree { arity: usize },
24 Centralized,
27 #[default]
29 Auto,
30}
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
34pub enum ReduceStrategy {
35 Tree { arity: usize },
37 Direct,
39 #[default]
41 Auto,
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
46pub enum BroadcastStrategy {
47 Tree { arity: usize },
49 Direct,
51 #[default]
53 Auto,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct CollectiveConfig {
59 pub num_devices: usize,
61 pub local_all_reduce: AllReduceStrategy,
63 pub local_reduce: ReduceStrategy,
65 pub local_broadcast: BroadcastStrategy,
67
68 pub num_nodes: Option<usize>,
71 pub global_all_reduce: Option<AllReduceStrategy>,
73 pub global_reduce: Option<ReduceStrategy>,
75 pub global_broadcast: Option<BroadcastStrategy>,
77
78 pub tree_threshold_bytes: usize,
81 pub tree_arity: usize,
83 pub timeout: Duration,
85}
86
87impl Default for CollectiveConfig {
88 fn default() -> Self {
89 Self {
90 num_devices: 1,
91 local_all_reduce: AllReduceStrategy::Auto,
92 local_reduce: ReduceStrategy::Auto,
93 local_broadcast: BroadcastStrategy::Auto,
94 num_nodes: None,
95 global_all_reduce: None,
96 global_reduce: None,
97 global_broadcast: None,
98 tree_threshold_bytes: 1024 * 1024, tree_arity: 2,
100 timeout: Duration::from_secs(60),
101 }
102 }
103}
104
105impl CollectiveConfig {
106 pub fn single_node(num_devices: usize) -> Self {
108 Self {
109 num_devices,
110 local_all_reduce: AllReduceStrategy::Ring,
111 ..Default::default()
112 }
113 }
114
115 pub fn multi_node(num_devices: usize, num_nodes: usize) -> Self {
117 Self {
118 num_devices,
119 num_nodes: Some(num_nodes),
120 local_all_reduce: AllReduceStrategy::Tree { arity: 2 },
121 global_all_reduce: Some(AllReduceStrategy::Ring),
122 global_reduce: Some(ReduceStrategy::Tree { arity: 2 }),
123 global_broadcast: Some(BroadcastStrategy::Tree { arity: 2 }),
124 ..Default::default()
125 }
126 }
127
128 pub fn validate(&self) -> DistributedResult<()> {
130 if self.num_devices == 0 {
131 return Err(DistributedError::Config("num_devices must be > 0".into()));
132 }
133
134 if let Some(n) = self.num_nodes {
135 if n == 0 {
136 return Err(DistributedError::Config("num_nodes must be > 0".into()));
137 }
138
139 if self.global_all_reduce.is_none()
141 || self.global_reduce.is_none()
142 || self.global_broadcast.is_none()
143 {
144 return Err(DistributedError::Config(
145 "All global strategies must be set for multi-node".into(),
146 ));
147 }
148 }
149
150 if self.tree_arity < 2 {
151 return Err(DistributedError::Config("tree_arity must be >= 2".into()));
152 }
153
154 Ok(())
155 }
156
157 pub fn select_all_reduce(&self, buffer_size: usize, world_size: usize) -> AllReduceStrategy {
159 match self.local_all_reduce {
160 AllReduceStrategy::Auto => {
161 if buffer_size < self.tree_threshold_bytes || world_size < 4 {
162 AllReduceStrategy::Tree {
163 arity: self.tree_arity,
164 }
165 } else {
166 AllReduceStrategy::Ring
167 }
168 }
169 other => other,
170 }
171 }
172}
173
174pub trait CollectiveOps: Send + Sync {
176 fn all_reduce(
178 &self,
179 buffer: &mut [f32],
180 strategy: AllReduceStrategy,
181 ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;
182
183 fn reduce(
185 &self,
186 buffer: &mut [f32],
187 root: usize,
188 strategy: ReduceStrategy,
189 ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;
190
191 fn broadcast(
193 &self,
194 buffer: &mut [f32],
195 root: usize,
196 strategy: BroadcastStrategy,
197 ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;
198}
199
200pub mod ring {
202 use super::*;
203
204 pub async fn all_reduce<S, R>(
210 buffer: &mut [f32],
211 rank: usize,
212 world_size: usize,
213 send: &S,
214 recv: &R,
215 ) -> DistributedResult<()>
216 where
217 S: Fn(
218 &[u8],
219 )
220 -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
221 + Send
222 + Sync,
223 R: Fn(
224 &mut [u8],
225 )
226 -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
227 + Send
228 + Sync,
229 {
230 if world_size < 2 {
231 return Ok(());
232 }
233
234 let len = buffer.len();
235 let chunk_size = len / world_size;
236 let remainder = len % world_size;
237
238 let get_chunk_range = |idx: usize| -> (usize, usize) {
240 let start = idx * chunk_size + idx.min(remainder);
241 let end = start + chunk_size + if idx < remainder { 1 } else { 0 };
242 (start, end)
243 };
244
245 for step in 0..(world_size - 1) {
247 let send_idx = (rank + world_size - step) % world_size;
248 let recv_idx = (rank + world_size - step - 1) % world_size;
249
250 let (send_start, send_end) = get_chunk_range(send_idx);
251 let (recv_start, recv_end) = get_chunk_range(recv_idx);
252
253 let send_bytes: Vec<u8> = buffer[send_start..send_end]
255 .iter()
256 .flat_map(|f| f.to_le_bytes())
257 .collect();
258
259 let recv_len = (recv_end - recv_start) * 4;
260 let mut recv_bytes = vec![0u8; recv_len];
261
262 tokio::try_join!(send(&send_bytes), recv(&mut recv_bytes))?;
264
265 for (i, chunk) in recv_bytes.chunks_exact(4).enumerate() {
267 let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
268 buffer[recv_start + i] += val;
269 }
270 }
271
272 for step in 0..(world_size - 1) {
274 let send_idx = (rank + world_size - step) % world_size;
275 let recv_idx = (rank + world_size - step - 1) % world_size;
276
277 let (send_start, send_end) = get_chunk_range(send_idx);
278 let (recv_start, recv_end) = get_chunk_range(recv_idx);
279
280 let send_bytes: Vec<u8> = buffer[send_start..send_end]
282 .iter()
283 .flat_map(|f| f.to_le_bytes())
284 .collect();
285
286 let recv_len = (recv_end - recv_start) * 4;
287 let mut recv_bytes = vec![0u8; recv_len];
288
289 tokio::try_join!(send(&send_bytes), recv(&mut recv_bytes))?;
291
292 for (i, chunk) in recv_bytes.chunks_exact(4).enumerate() {
294 let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
295 buffer[recv_start + i] = val;
296 }
297 }
298
299 debug!("Ring all-reduce complete: {} elements", len);
300 Ok(())
301 }
302}
303
pub mod tree {
    //! Rank arithmetic for an `arity`-ary tree rooted at rank 0.
    //!
    //! The children of rank `r` are ranks `r * arity + 1 ..= r * arity + arity`
    //! that are below `world_size`.

    /// Position of a rank within the tree.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum TreeRole {
        /// Non-root rank without children.
        Leaf,
        /// Non-root rank with at least one child.
        Internal {
            /// Number of children of this rank.
            num_children: usize,
        },
        /// Rank 0; may have zero children when it is the only rank.
        Root {
            /// Number of children of the root.
            num_children: usize,
        },
    }

    /// Rank of the first potential child of `rank`, saturating on overflow.
    fn first_child(rank: usize, arity: usize) -> usize {
        rank.saturating_mul(arity).saturating_add(1)
    }

    /// Number of children of `rank` that exist in a world of `world_size` ranks.
    fn child_count(rank: usize, world_size: usize, arity: usize) -> usize {
        world_size
            .saturating_sub(first_child(rank, arity))
            .min(arity)
    }

    /// Classifies `rank` as root, internal node or leaf.
    ///
    /// Rank 0 is always [`TreeRole::Root`], with zero children when
    /// `world_size` is 0 or 1. Any other rank is [`TreeRole::Internal`] when
    /// it has at least one child and [`TreeRole::Leaf`] otherwise; with an
    /// `arity` of 0 every non-root rank is therefore a leaf. Never panics.
    pub fn compute_role(rank: usize, world_size: usize, arity: usize) -> TreeRole {
        let num_children = child_count(rank, world_size, arity);
        if rank == 0 {
            TreeRole::Root { num_children }
        } else if num_children > 0 {
            TreeRole::Internal { num_children }
        } else {
            TreeRole::Leaf
        }
    }

    /// Returns the parent of `rank`, or `None` for the root.
    ///
    /// Also returns `None` when `arity` is 0, because no rank has children in
    /// such a tree (and dividing by the arity would otherwise panic).
    pub fn parent_rank(rank: usize, arity: usize) -> Option<usize> {
        rank.checked_sub(1)?.checked_div(arity)
    }

    /// Returns the children of `rank` in ascending order.
    ///
    /// The result is empty for leaves and for inputs whose child indices
    /// would overflow `usize`.
    pub fn child_ranks(rank: usize, world_size: usize, arity: usize) -> Vec<usize> {
        let first = first_child(rank, arity);
        let end = first.saturating_add(arity).min(world_size);
        (first..end).collect()
    }
}
353
354pub mod centralized {
356 use super::*;
357
358 #[allow(clippy::too_many_arguments)]
364 pub async fn all_reduce<S, R>(
365 buffer: &mut [f32],
366 _rank: usize,
367 world_size: usize,
368 is_root: bool,
369 send_to_root: &S,
370 recv_from_root: &R,
371 recv_from_peer: &R,
372 send_to_peer: &S,
373 ) -> DistributedResult<()>
374 where
375 S: Fn(
376 &[u8],
377 )
378 -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
379 + Send
380 + Sync,
381 R: Fn(
382 &mut [u8],
383 )
384 -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
385 + Send
386 + Sync,
387 {
388 if world_size < 2 {
389 return Ok(());
390 }
391
392 let len = buffer.len();
393 let byte_len = len * 4;
394
395 if is_root {
396 let mut recv_buf = vec![0u8; byte_len];
399
400 for _ in 1..world_size {
401 recv_from_peer(&mut recv_buf).await?;
402
403 for (i, chunk) in recv_buf.chunks_exact(4).enumerate() {
405 let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
406 buffer[i] += val;
407 }
408 }
409
410 let send_bytes: Vec<u8> = buffer.iter().flat_map(|f| f.to_le_bytes()).collect();
413
414 for _ in 1..world_size {
415 send_to_peer(&send_bytes).await?;
416 }
417 } else {
418 let send_bytes: Vec<u8> = buffer.iter().flat_map(|f| f.to_le_bytes()).collect();
421 send_to_root(&send_bytes).await?;
422
423 let mut recv_buf = vec![0u8; byte_len];
426 recv_from_root(&mut recv_buf).await?;
427
428 for (i, chunk) in recv_buf.chunks_exact(4).enumerate() {
430 buffer[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
431 }
432 }
433
434 debug!("Centralized all-reduce complete: {} elements", len);
435 Ok(())
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// `validate` accepts the defaults and rejects zero devices and
    /// multi-node setups that lack global strategies.
    #[test]
    fn test_config_validation() {
        let base = CollectiveConfig::default();
        assert!(base.validate().is_ok());

        let no_devices = CollectiveConfig {
            num_devices: 0,
            ..base.clone()
        };
        assert!(no_devices.validate().is_err());

        // Multi-node without any global strategy must be rejected ...
        let mut multi = CollectiveConfig {
            num_devices: 1,
            num_nodes: Some(2),
            ..base
        };
        assert!(multi.validate().is_err());

        // ... and accepted once all three are provided.
        multi.global_all_reduce = Some(AllReduceStrategy::Ring);
        multi.global_reduce = Some(ReduceStrategy::Tree { arity: 2 });
        multi.global_broadcast = Some(BroadcastStrategy::Tree { arity: 2 });
        assert!(multi.validate().is_ok());
    }

    /// `Auto` resolves to a tree for small buffers or small worlds and to a
    /// ring otherwise.
    #[test]
    fn test_strategy_selection() {
        let config = CollectiveConfig {
            tree_threshold_bytes: 1024,
            tree_arity: 2,
            local_all_reduce: AllReduceStrategy::Auto,
            ..Default::default()
        };

        // (buffer size, world size, whether a tree is expected)
        let cases = [(512usize, 4usize, true), (2048, 4, false), (2048, 2, true)];

        for (buffer_size, world_size, expect_tree) in cases {
            let picked = config.select_all_reduce(buffer_size, world_size);
            if expect_tree {
                assert!(matches!(picked, AllReduceStrategy::Tree { .. }));
            } else {
                assert!(matches!(picked, AllReduceStrategy::Ring));
            }
        }
    }

    /// Roles, parents and children in a binary tree over seven ranks.
    #[test]
    fn test_tree_roles() {
        const WORLD: usize = 7;
        const ARITY: usize = 2;

        let role_of = |rank: usize| tree::compute_role(rank, WORLD, ARITY);

        assert!(matches!(
            role_of(0),
            tree::TreeRole::Root { num_children: 2 }
        ));
        assert!(matches!(
            role_of(1),
            tree::TreeRole::Internal { num_children: 2 }
        ));
        assert!(matches!(role_of(3), tree::TreeRole::Leaf));

        for (rank, expected_parent) in [(3, Some(1)), (1, Some(0)), (0, None)] {
            assert_eq!(tree::parent_rank(rank, ARITY), expected_parent);
        }

        for (rank, expected_children) in [(0, vec![1, 2]), (1, vec![3, 4])] {
            assert_eq!(tree::child_ranks(rank, WORLD, ARITY), expected_children);
        }
    }
}
516
#[cfg(kani)]
mod verification {
    use super::*;

    /// Checks, for every world size in `1..=8` and arity in `2..=4`, that
    /// `compute_role`, `parent_rank` and `child_ranks` agree with each other
    /// for every rank.
    #[kani::proof]
    #[kani::unwind(9)]
    fn verify_tree_topology() {
        let world_size: usize = kani::any();
        let arity: usize = kani::any();

        // Keep the state space small enough for the unwind bound.
        kani::assume((1..=8).contains(&world_size));
        kani::assume((2..=4).contains(&arity));

        for rank in 0..world_size {
            let parent = tree::parent_rank(rank, arity);
            let children = tree::child_ranks(rank, world_size, arity);

            // The reported role must match the parent/children view.
            match tree::compute_role(rank, world_size, arity) {
                tree::TreeRole::Root { num_children } => {
                    assert!(rank == 0);
                    assert!(parent.is_none());
                    assert!(children.len() == num_children);
                }
                tree::TreeRole::Internal { num_children } => {
                    assert!(rank > 0);
                    assert!(parent.is_some());
                    assert!(children.len() == num_children);
                    assert!(num_children > 0);
                }
                tree::TreeRole::Leaf => {
                    assert!(rank > 0);
                    assert!(parent.is_some());
                    assert!(children.is_empty());
                }
            }

            // Each child is in range, comes after its parent, and points back.
            for &kid in &children {
                assert!(kid < world_size);
                assert!(kid > rank);
                assert!(tree::parent_rank(kid, arity) == Some(rank));
            }

            // The parent comes first and lists this rank among its children.
            if let Some(up) = parent {
                assert!(up < rank);
                let siblings = tree::child_ranks(up, world_size, arity);
                assert!(siblings.contains(&rank));
            }
        }
    }
}