// pmetal_distributed/collective.rs
//! Configurable collective operations with pluggable strategies.
//!
//! Provides multiple all-reduce, reduce, and broadcast algorithms:
//! - Ring: Bandwidth-optimal for large tensors
//! - Tree: Latency-optimal for small tensors
//! - Centralized: Simple, works for small collectives
//!
//! Based on Burn's collective framework for distributed operations.
use crate::error::{DistributedError, DistributedResult};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::debug;
/// All-reduce algorithm strategy.
///
/// Selected per call via [`CollectiveConfig::select_all_reduce`]; `Auto`
/// defers the choice to buffer size and world size at call time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum AllReduceStrategy {
    /// Ring all-reduce: O(n) latency, O(1) bandwidth per node.
    /// Best for large tensors on high-bandwidth networks.
    Ring,
    /// Tree all-reduce: O(log n) latency, O(log n) bandwidth per node.
    /// Best for small tensors or latency-sensitive operations.
    /// `arity` is the branching factor of the tree.
    Tree { arity: usize },
    /// Centralized all-reduce: O(n) latency, O(n) bandwidth on root.
    /// Simple, works for small collectives.
    Centralized,
    /// Automatic selection based on tensor size and cluster topology.
    #[default]
    Auto,
}
/// Reduce algorithm strategy (result lands on a single root rank).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ReduceStrategy {
    /// Tree reduce with configurable arity (branching factor).
    Tree { arity: usize },
    /// Direct reduce to root: every rank sends straight to the root.
    Direct,
    /// Automatic selection.
    #[default]
    Auto,
}
/// Broadcast algorithm strategy (root's buffer is copied to all ranks).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum BroadcastStrategy {
    /// Tree broadcast with configurable arity (branching factor).
    Tree { arity: usize },
    /// Direct broadcast from root: root sends to every rank individually.
    Direct,
    /// Automatic selection.
    #[default]
    Auto,
}
/// Configuration for collective operations.
///
/// Groups local (intra-node) strategies, optional global (multi-node)
/// strategies, and the tuning knobs consulted by automatic selection.
/// `validate()` enforces the cross-field invariants (all three global
/// strategies must be set together when `num_nodes` is set).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectiveConfig {
    /// Number of local devices (GPUs). Must be > 0.
    pub num_devices: usize,
    /// Local all-reduce strategy.
    pub local_all_reduce: AllReduceStrategy,
    /// Local reduce strategy.
    pub local_reduce: ReduceStrategy,
    /// Local broadcast strategy.
    pub local_broadcast: BroadcastStrategy,

    // Global (multi-node) settings
    /// Number of nodes (None = single node).
    pub num_nodes: Option<usize>,
    /// Global all-reduce strategy. Required when `num_nodes` is set.
    pub global_all_reduce: Option<AllReduceStrategy>,
    /// Global reduce strategy. Required when `num_nodes` is set.
    pub global_reduce: Option<ReduceStrategy>,
    /// Global broadcast strategy. Required when `num_nodes` is set.
    pub global_broadcast: Option<BroadcastStrategy>,

    // Tuning parameters
    /// Threshold (bytes) below which tree is preferred over ring.
    pub tree_threshold_bytes: usize,
    /// Tree arity (branching factor). Must be >= 2.
    pub tree_arity: usize,
    /// Timeout for collective operations.
    pub timeout: Duration,
}
87impl Default for CollectiveConfig {
88    fn default() -> Self {
89        Self {
90            num_devices: 1,
91            local_all_reduce: AllReduceStrategy::Auto,
92            local_reduce: ReduceStrategy::Auto,
93            local_broadcast: BroadcastStrategy::Auto,
94            num_nodes: None,
95            global_all_reduce: None,
96            global_reduce: None,
97            global_broadcast: None,
98            tree_threshold_bytes: 1024 * 1024, // 1 MB
99            tree_arity: 2,
100            timeout: Duration::from_secs(60),
101        }
102    }
103}
104
105impl CollectiveConfig {
106    /// Create a config for a single node with multiple devices.
107    pub fn single_node(num_devices: usize) -> Self {
108        Self {
109            num_devices,
110            local_all_reduce: AllReduceStrategy::Ring,
111            ..Default::default()
112        }
113    }
114
115    /// Create a config for multi-node training.
116    pub fn multi_node(num_devices: usize, num_nodes: usize) -> Self {
117        Self {
118            num_devices,
119            num_nodes: Some(num_nodes),
120            local_all_reduce: AllReduceStrategy::Tree { arity: 2 },
121            global_all_reduce: Some(AllReduceStrategy::Ring),
122            global_reduce: Some(ReduceStrategy::Tree { arity: 2 }),
123            global_broadcast: Some(BroadcastStrategy::Tree { arity: 2 }),
124            ..Default::default()
125        }
126    }
127
128    /// Validate the configuration.
129    pub fn validate(&self) -> DistributedResult<()> {
130        if self.num_devices == 0 {
131            return Err(DistributedError::Config("num_devices must be > 0".into()));
132        }
133
134        if let Some(n) = self.num_nodes {
135            if n == 0 {
136                return Err(DistributedError::Config("num_nodes must be > 0".into()));
137            }
138
139            // All global settings must be set together
140            if self.global_all_reduce.is_none()
141                || self.global_reduce.is_none()
142                || self.global_broadcast.is_none()
143            {
144                return Err(DistributedError::Config(
145                    "All global strategies must be set for multi-node".into(),
146                ));
147            }
148        }
149
150        if self.tree_arity < 2 {
151            return Err(DistributedError::Config("tree_arity must be >= 2".into()));
152        }
153
154        Ok(())
155    }
156
157    /// Select the best all-reduce strategy for a given buffer size.
158    pub fn select_all_reduce(&self, buffer_size: usize, world_size: usize) -> AllReduceStrategy {
159        match self.local_all_reduce {
160            AllReduceStrategy::Auto => {
161                if buffer_size < self.tree_threshold_bytes || world_size < 4 {
162                    AllReduceStrategy::Tree {
163                        arity: self.tree_arity,
164                    }
165                } else {
166                    AllReduceStrategy::Ring
167                }
168            }
169            other => other,
170        }
171    }
172}
173
/// Trait for collective operation implementations.
///
/// Implementors supply the transport; the `strategy` argument selects the
/// algorithm per call. Methods use return-position `impl Future` bounds so
/// the returned futures are `Send` and can be driven by multi-threaded
/// executors. The reduction operator is not specified by the trait itself;
/// the implementations in this module (`ring`, `centralized`) sum element-wise.
pub trait CollectiveOps: Send + Sync {
    /// Perform all-reduce with the configured strategy.
    /// On success, every participant's `buffer` holds the reduced result.
    fn all_reduce(
        &self,
        buffer: &mut [f32],
        strategy: AllReduceStrategy,
    ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;

    /// Perform reduce to root with the configured strategy.
    /// Only rank `root`'s `buffer` is guaranteed to hold the reduced result.
    fn reduce(
        &self,
        buffer: &mut [f32],
        root: usize,
        strategy: ReduceStrategy,
    ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;

    /// Perform broadcast from root with the configured strategy.
    /// After completion, every rank's `buffer` matches rank `root`'s.
    fn broadcast(
        &self,
        buffer: &mut [f32],
        root: usize,
        strategy: BroadcastStrategy,
    ) -> impl std::future::Future<Output = DistributedResult<()>> + Send;
}
200/// Ring all-reduce implementation.
201pub mod ring {
202    use super::*;
203
204    /// Perform ring all-reduce (scatter-reduce + all-gather).
205    ///
206    /// This is bandwidth-optimal for large tensors:
207    /// - Total data transferred per node: 2 * (n-1) / n * buffer_size
208    /// - Number of steps: 2 * (n - 1)
209    pub async fn all_reduce<S, R>(
210        buffer: &mut [f32],
211        rank: usize,
212        world_size: usize,
213        send: &S,
214        recv: &R,
215    ) -> DistributedResult<()>
216    where
217        S: Fn(
218                &[u8],
219            )
220                -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
221            + Send
222            + Sync,
223        R: Fn(
224                &mut [u8],
225            )
226                -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
227            + Send
228            + Sync,
229    {
230        if world_size < 2 {
231            return Ok(());
232        }
233
234        let len = buffer.len();
235        let chunk_size = len / world_size;
236        let remainder = len % world_size;
237
238        // Helper to get chunk range
239        let get_chunk_range = |idx: usize| -> (usize, usize) {
240            let start = idx * chunk_size + idx.min(remainder);
241            let end = start + chunk_size + if idx < remainder { 1 } else { 0 };
242            (start, end)
243        };
244
245        // === SCATTER-REDUCE PHASE ===
246        for step in 0..(world_size - 1) {
247            let send_idx = (rank + world_size - step) % world_size;
248            let recv_idx = (rank + world_size - step - 1) % world_size;
249
250            let (send_start, send_end) = get_chunk_range(send_idx);
251            let (recv_start, recv_end) = get_chunk_range(recv_idx);
252
253            // Prepare send buffer
254            let send_bytes: Vec<u8> = buffer[send_start..send_end]
255                .iter()
256                .flat_map(|f| f.to_le_bytes())
257                .collect();
258
259            let recv_len = (recv_end - recv_start) * 4;
260            let mut recv_bytes = vec![0u8; recv_len];
261
262            // Send and receive concurrently
263            tokio::try_join!(send(&send_bytes), recv(&mut recv_bytes))?;
264
265            // Reduce received data
266            for (i, chunk) in recv_bytes.chunks_exact(4).enumerate() {
267                let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
268                buffer[recv_start + i] += val;
269            }
270        }
271
272        // === ALL-GATHER PHASE ===
273        for step in 0..(world_size - 1) {
274            let send_idx = (rank + world_size - step) % world_size;
275            let recv_idx = (rank + world_size - step - 1) % world_size;
276
277            let (send_start, send_end) = get_chunk_range(send_idx);
278            let (recv_start, recv_end) = get_chunk_range(recv_idx);
279
280            // Prepare send buffer
281            let send_bytes: Vec<u8> = buffer[send_start..send_end]
282                .iter()
283                .flat_map(|f| f.to_le_bytes())
284                .collect();
285
286            let recv_len = (recv_end - recv_start) * 4;
287            let mut recv_bytes = vec![0u8; recv_len];
288
289            // Send and receive concurrently
290            tokio::try_join!(send(&send_bytes), recv(&mut recv_bytes))?;
291
292            // Copy received data
293            for (i, chunk) in recv_bytes.chunks_exact(4).enumerate() {
294                let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
295                buffer[recv_start + i] = val;
296            }
297        }
298
299        debug!("Ring all-reduce complete: {} elements", len);
300        Ok(())
301    }
302}
303
/// Tree all-reduce implementation.
///
/// Topology helpers for a k-ary tree rooted at rank 0, where the children of
/// rank `r` are `r * arity + 1 ..= r * arity + arity` (clipped to
/// `world_size`) and the parent of rank `r > 0` is `(r - 1) / arity`.
pub mod tree {

    /// Tree node role for a given phase.
    #[derive(Debug, Clone, Copy)]
    pub enum TreeRole {
        /// Leaf node in this phase.
        Leaf,
        /// Internal node with children.
        Internal { num_children: usize },
        /// Root node.
        Root { num_children: usize },
    }

    /// Compute tree role for a node in a k-ary tree.
    ///
    /// Rank 0 is always the root (with zero children when `world_size == 1`);
    /// any other rank is `Internal` when it has at least one in-range child,
    /// otherwise `Leaf`. Assumes `world_size > 0` (underflows otherwise).
    pub fn compute_role(rank: usize, world_size: usize, arity: usize) -> TreeRole {
        if rank == 0 {
            // Root: up to `arity` children, bounded by the remaining ranks.
            let num_children = arity.min(world_size - 1);
            TreeRole::Root { num_children }
        } else {
            // A node has children iff its first child index is in range.
            let first_child = rank * arity + 1;
            if first_child < world_size {
                let num_children = (world_size - first_child).min(arity);
                TreeRole::Internal { num_children }
            } else {
                TreeRole::Leaf
            }
        }
    }

    /// Get parent rank in a k-ary tree, or `None` for the root (rank 0).
    ///
    /// Requires `arity >= 1` (division by zero otherwise); config validation
    /// enforces `tree_arity >= 2`. The parameter was previously named
    /// `_arity` despite being used — the underscore prefix (Rust's "unused"
    /// convention) was misleading, so it is renamed here.
    pub fn parent_rank(rank: usize, arity: usize) -> Option<usize> {
        if rank == 0 {
            None
        } else {
            Some((rank - 1) / arity)
        }
    }

    /// Get child ranks in a k-ary tree (only those within `world_size`).
    pub fn child_ranks(rank: usize, world_size: usize, arity: usize) -> Vec<usize> {
        let first_child = rank * arity + 1;
        (first_child..first_child + arity)
            .filter(|&c| c < world_size)
            .collect()
    }
}
354/// Centralized all-reduce implementation.
355pub mod centralized {
356    use super::*;
357
358    /// Perform centralized all-reduce (reduce to root + broadcast).
359    ///
360    /// Simple but not bandwidth-optimal:
361    /// - Root receives from all, reduces, broadcasts to all
362    /// - O(n) messages, O(n) bandwidth on root
363    #[allow(clippy::too_many_arguments)]
364    pub async fn all_reduce<S, R>(
365        buffer: &mut [f32],
366        _rank: usize,
367        world_size: usize,
368        is_root: bool,
369        send_to_root: &S,
370        recv_from_root: &R,
371        recv_from_peer: &R,
372        send_to_peer: &S,
373    ) -> DistributedResult<()>
374    where
375        S: Fn(
376                &[u8],
377            )
378                -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
379            + Send
380            + Sync,
381        R: Fn(
382                &mut [u8],
383            )
384                -> std::pin::Pin<Box<dyn std::future::Future<Output = DistributedResult<()>> + Send>>
385            + Send
386            + Sync,
387    {
388        if world_size < 2 {
389            return Ok(());
390        }
391
392        let len = buffer.len();
393        let byte_len = len * 4;
394
395        if is_root {
396            // === REDUCE PHASE ===
397            // Receive from all peers and accumulate
398            let mut recv_buf = vec![0u8; byte_len];
399
400            for _ in 1..world_size {
401                recv_from_peer(&mut recv_buf).await?;
402
403                // Accumulate
404                for (i, chunk) in recv_buf.chunks_exact(4).enumerate() {
405                    let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
406                    buffer[i] += val;
407                }
408            }
409
410            // === BROADCAST PHASE ===
411            // Send result to all peers
412            let send_bytes: Vec<u8> = buffer.iter().flat_map(|f| f.to_le_bytes()).collect();
413
414            for _ in 1..world_size {
415                send_to_peer(&send_bytes).await?;
416            }
417        } else {
418            // === REDUCE PHASE ===
419            // Send to root
420            let send_bytes: Vec<u8> = buffer.iter().flat_map(|f| f.to_le_bytes()).collect();
421            send_to_root(&send_bytes).await?;
422
423            // === BROADCAST PHASE ===
424            // Receive from root
425            let mut recv_buf = vec![0u8; byte_len];
426            recv_from_root(&mut recv_buf).await?;
427
428            // Copy result
429            for (i, chunk) in recv_buf.chunks_exact(4).enumerate() {
430                buffer[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
431            }
432        }
433
434        debug!("Centralized all-reduce complete: {} elements", len);
435        Ok(())
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config is valid; zero devices is rejected; multi-node requires
    /// all three global strategies before validation passes.
    #[test]
    fn test_config_validation() {
        let mut config = CollectiveConfig::default();
        assert!(config.validate().is_ok());

        config.num_devices = 0;
        assert!(config.validate().is_err());

        config.num_devices = 1;
        config.num_nodes = Some(2);
        assert!(config.validate().is_err()); // Missing global strategies

        config.global_all_reduce = Some(AllReduceStrategy::Ring);
        config.global_reduce = Some(ReduceStrategy::Tree { arity: 2 });
        config.global_broadcast = Some(BroadcastStrategy::Tree { arity: 2 });
        assert!(config.validate().is_ok());
    }

    /// `Auto` picks tree below the byte threshold or for worlds smaller
    /// than 4 ranks, and ring otherwise.
    #[test]
    fn test_strategy_selection() {
        let config = CollectiveConfig {
            tree_threshold_bytes: 1024,
            tree_arity: 2,
            local_all_reduce: AllReduceStrategy::Auto,
            ..Default::default()
        };

        // Small buffer -> tree
        let strategy = config.select_all_reduce(512, 4);
        assert!(matches!(strategy, AllReduceStrategy::Tree { .. }));

        // Large buffer -> ring
        let strategy = config.select_all_reduce(2048, 4);
        assert!(matches!(strategy, AllReduceStrategy::Ring));

        // Small world size -> tree
        let strategy = config.select_all_reduce(2048, 2);
        assert!(matches!(strategy, AllReduceStrategy::Tree { .. }));
    }

    /// Roles, parents, and children agree on the 7-node binary tree below.
    #[test]
    fn test_tree_roles() {
        // 2-ary tree with 7 nodes
        //       0
        //      / \
        //     1   2
        //    / \ / \
        //   3  4 5  6

        let world_size = 7;
        let arity = 2;

        assert!(matches!(
            tree::compute_role(0, world_size, arity),
            tree::TreeRole::Root { num_children: 2 }
        ));
        assert!(matches!(
            tree::compute_role(1, world_size, arity),
            tree::TreeRole::Internal { num_children: 2 }
        ));
        assert!(matches!(
            tree::compute_role(3, world_size, arity),
            tree::TreeRole::Leaf
        ));

        assert_eq!(tree::parent_rank(3, arity), Some(1));
        assert_eq!(tree::parent_rank(1, arity), Some(0));
        assert_eq!(tree::parent_rank(0, arity), None);

        assert_eq!(tree::child_ranks(0, world_size, arity), vec![1, 2]);
        assert_eq!(tree::child_ranks(1, world_size, arity), vec![3, 4]);
    }
}
#[cfg(kani)]
mod verification {
    use super::*;

    /// Exhaustively checks (within the bounded domain below) that the tree
    /// topology helpers are mutually consistent: roles match parent/child
    /// counts, and every parent/child relation is symmetric.
    #[kani::proof]
    #[kani::unwind(9)]
    fn verify_tree_topology() {
        let world_size: usize = kani::any();
        let arity: usize = kani::any();

        // Reduced bounds for tractable verification — Vec heap allocations
        // and nested iterator loops in child_ranks/contains make larger
        // bounds prohibitively expensive for CBMC.
        kani::assume(world_size > 0 && world_size <= 8);
        kani::assume(arity >= 2 && arity <= 4);

        for rank in 0..world_size {
            let role = tree::compute_role(rank, world_size, arity);
            let parent = tree::parent_rank(rank, arity);
            let children = tree::child_ranks(rank, world_size, arity);

            // Role must agree with the computed parent and child set.
            match role {
                tree::TreeRole::Root { num_children } => {
                    assert!(rank == 0);
                    assert!(parent.is_none());
                    assert!(children.len() == num_children);
                }
                tree::TreeRole::Internal { num_children } => {
                    assert!(rank > 0);
                    assert!(parent.is_some());
                    assert!(children.len() == num_children);
                    assert!(num_children > 0);
                }
                tree::TreeRole::Leaf => {
                    assert!(rank > 0);
                    assert!(parent.is_some());
                    assert!(children.is_empty());
                }
            }

            // Verify parent-child consistency: every child is in range,
            // strictly below us in the tree, and points back to us.
            for &child in &children {
                assert!(child < world_size);
                assert!(child > rank);
                assert!(tree::parent_rank(child, arity) == Some(rank));
            }

            // And symmetrically: our parent lists us among its children.
            if let Some(p) = parent {
                assert!(p < rank);
                let p_children = tree::child_ranks(p, world_size, arity);
                assert!(p_children.contains(&rank));
            }
        }
    }
}