//! Collective-operation configuration for the `burn_collective` crate (config.rs).
1use std::fmt::Display;
2
3use burn_communication::Address;
4use serde::{Deserialize, Serialize};
5
/// Parameter struct for setting up and getting parameters for collective operations.
/// Used in most collective api calls.
/// This config is per-node. It is passed to [register](crate::register).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CollectiveConfig {
    /// Number of devices (local peers) on this node.
    pub(crate) num_devices: usize,
    /// All-reduce strategy used between devices on this node.
    pub(crate) local_all_reduce_strategy: AllReduceStrategy,
    /// Reduce strategy used between devices on this node.
    pub(crate) local_reduce_strategy: ReduceStrategy,
    /// Broadcast strategy used between devices on this node.
    pub(crate) local_broadcast_strategy: BroadcastStrategy,

    // Global parameters (all are optional, but if one is defined they should all be)
    /// Total number of nodes in the collective.
    pub(crate) num_nodes: Option<u32>,
    /// Address of the global collective orchestrator.
    pub(crate) global_address: Option<Address>,
    /// Address of this node.
    pub(crate) node_address: Option<Address>,
    /// Port on which this node exposes the peer-to-peer tensor data service.
    pub(crate) data_service_port: Option<u16>,

    // These strategies may be defined when no other global params are defined
    pub(crate) global_all_reduce_strategy: Option<AllReduceStrategy>,
    pub(crate) global_reduce_strategy: Option<ReduceStrategy>,
    pub(crate) global_broadcast_strategy: Option<BroadcastStrategy>,
}
27
28impl Default for CollectiveConfig {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl Display for CollectiveConfig {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        let num_devices = self.num_devices;
37        let local_all_reduce_strategy = self.local_all_reduce_strategy;
38        let local_reduce_strategy = self.local_reduce_strategy;
39        let local_broadcast_strategy = self.local_broadcast_strategy;
40        let num_nodes = self.num_nodes;
41        let global_address = &self.global_address;
42        let node_address = &self.node_address;
43        let data_service_port = self.data_service_port;
44        let global_all_reduce_strategy = self.global_all_reduce_strategy;
45        let global_reduce_strategy = self.global_reduce_strategy;
46        let global_broadcast_strategy = self.global_broadcast_strategy;
47
48        write!(
49            f,
50            r#"
51CollectiveConfig {{
52    num_devices: {num_devices:?},
53    local_all_reduce_strategy: {local_all_reduce_strategy:?},
54    local_reduce_strategy: {local_reduce_strategy:?},
55    local_broadcast_strategy: {local_broadcast_strategy:?},
56    num_nodes: {num_nodes:?},
57    global_address: {global_address:?},
58    node_address: {node_address:?},
59    data_service_port: {data_service_port:?},
60    global_all_reduce_strategy: {global_all_reduce_strategy:?},
61    global_reduce_strategy: {global_reduce_strategy:?},
62    global_broadcast_strategy: {global_broadcast_strategy:?},
63}}
64"#
65        )
66    }
67}
68
69impl CollectiveConfig {
70    fn new() -> Self {
71        Self {
72            num_devices: 1,
73            local_all_reduce_strategy: AllReduceStrategy::Tree(2),
74            local_reduce_strategy: ReduceStrategy::Tree(2),
75            local_broadcast_strategy: BroadcastStrategy::Tree(2),
76
77            num_nodes: None,
78            global_address: None,
79            node_address: None,
80            data_service_port: None,
81            global_all_reduce_strategy: Some(AllReduceStrategy::Ring),
82            global_reduce_strategy: Some(ReduceStrategy::Tree(2)),
83            global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)),
84        }
85    }
86
87    /// Selects the number of devices (local peers) on the current node
88    pub fn with_num_devices(mut self, num: usize) -> Self {
89        self.num_devices = num;
90        self
91    }
92
93    /// Selects an all-reduce strategy to use on the local level.
94    ///
95    /// In multi-node contexts, use of the Ring strategy in the local level may be less
96    /// advantageous. With multiple nodes, the global all-reduce step is enabled, and its result
97    /// is redistributed to all devices.
98    /// The Ring strategy inherently distributes the result, which in this context would not be
99    /// necessary.
100    ///
101    /// It is recommended to use a tree strategy locally, and a ring strategy globally.
102    pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
103        self.local_all_reduce_strategy = strategy;
104        self
105    }
106
107    /// Selects a reduce strategy to use on the local level.
108    pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
109        self.local_reduce_strategy = strategy;
110        self
111    }
112
113    /// Selects a broadcast strategy to use on the local level.
114    pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
115        self.local_broadcast_strategy = strategy;
116        self
117    }
118
119    /// Set the number of nodes in the collective
120    ///
121    /// This parameter is a global parameter and should only be set in multi-node contexts
122    pub fn with_num_nodes(mut self, n: u32) -> Self {
123        self.num_nodes = Some(n);
124        self
125    }
126
127    /// Set the network address of the Global Collective Orchestrator
128    ///  
129    /// This parameter is a global parameter and should only be set in multi-node contexts
130    pub fn with_global_address(mut self, addr: Address) -> Self {
131        self.global_address = Some(addr);
132        self
133    }
134
135    /// Define the address for this node
136    ///
137    /// This parameter is a global parameter and should only be set in multi-node contexts
138    pub fn with_node_address(mut self, addr: Address) -> Self {
139        self.node_address = Some(addr);
140        self
141    }
142
143    /// Selects the network port on which to expose the tensor data service
144    /// used for peer-to-peer tensor downloading.
145    ///
146    /// This parameter is a global parameter and should only be set in multi-node contexts
147    pub fn with_data_service_port(mut self, port: u16) -> Self {
148        self.data_service_port = Some(port);
149        self
150    }
151
152    /// Selects an all-reduce strategy to use on the global level.
153    ///
154    /// This parameter is a global parameter and should only be set in multi-node contexts.
155    /// See [the local strategy](Self::with_local_all_reduce_strategy)
156    pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
157        self.global_all_reduce_strategy = Some(strategy);
158        self
159    }
160
161    /// Selects an reduce strategy to use on the global level.
162    ///
163    /// This parameter is a global parameter and should only be set in multi-node contexts.
164    /// See [the local strategy](Self::with_local_reduce_strategy)
165    pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
166        self.global_reduce_strategy = Some(strategy);
167        self
168    }
169
170    /// Selects an broadcst strategy to use on the global level.
171    ///
172    /// This parameter is a global parameter and should only be set in multi-node contexts.
173    /// See [the local strategy](Self::with_local_broadcast_strategy)
174    pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
175        self.global_broadcast_strategy = Some(strategy);
176        self
177    }
178
179    /// Returns whether the config is valid. If only some required global-level parameters are
180    /// defined and others are not, the config is invalid.  
181    pub fn is_valid(&self) -> bool {
182        match (
183            self.num_nodes,
184            &self.global_address,
185            &self.node_address,
186            self.data_service_port,
187        ) {
188            (None, None, None, None) => true,
189            (Some(_), Some(_), Some(_), Some(_)) => true,
190            // Global parameters have only been partially defined!
191            _ => false,
192        }
193    }
194
195    /// Return the global parameters for registering in a multi-node context.
196    ///
197    /// If only some global parameters are defined, returns None. Use [is_valid](Self::is_valid) to check for
198    /// validity in this case.
199    pub(crate) fn global_register_params(&self) -> Option<GlobalRegisterParams> {
200        match (
201            self.num_nodes,
202            &self.global_address,
203            &self.node_address,
204            self.data_service_port,
205        ) {
206            // Only local collective
207            (None, None, None, None) => None,
208            // Local + global collective
209            (Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => {
210                Some(GlobalRegisterParams {
211                    num_nodes,
212                    global_address: global_addr.clone(),
213                    node_address: node_addr.clone(),
214                    data_service_port,
215                })
216            }
217            // Config is invalid!
218            _ => None,
219        }
220    }
221}
222
/// Helper struct for parameters in a multi-node register operation. Either they are all defined,
/// or all not defined. Passed to the global client for registering on the global level and
/// opening the p2p tensor service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRegisterParams {
    /// The address for the connection to the global orchestrator.
    pub global_address: Address,
    /// The address for the connection to this node.
    pub node_address: Address,
    /// The port on which to open the tensor data service for peer-to-peer tensor transfers with
    /// other nodes. Should match the port given in the node url.
    pub data_service_port: u16,

    /// The number of nodes globally. Should be the same between different nodes
    pub num_nodes: u32,
}
239
/// Parameters for an all-reduce that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedAllReduceParams {
    /// The arithmetic reduction to apply (sum or mean).
    pub op: ReduceOperation,
    /// Strategy for the local (intra-node) all-reduce step.
    pub local_strategy: AllReduceStrategy,
    /// Strategy for the global (inter-node) step; `None` in single-node contexts.
    pub global_strategy: Option<AllReduceStrategy>,
}
247
/// Parameters for a reduce that should be the same between all devices
///
/// Currently empty — presumably a placeholder kept so the API shape mirrors
/// `SharedAllReduceParams` / `SharedBroadcastParams`; confirm before removing.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedReduceParams {}
251
/// Parameters for a broadcast that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedBroadcastParams {
    // NOTE(review): a broadcast does not reduce tensors, so a ReduceOperation here
    // is surprising — confirm with callers whether this field is actually used.
    pub op: ReduceOperation,
    /// Strategy for the local (intra-node) broadcast step.
    pub local_strategy: BroadcastStrategy,
    /// Strategy for the global (inter-node) step; `None` in single-node contexts.
    pub global_strategy: Option<BroadcastStrategy>,
}
259
260/// Reduce can be done different ways
261#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
262pub enum ReduceOperation {
263    Sum,
264    Mean,
265}
266
/// All reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// One device is the "central". The other devices, "peripherals", send their tensors to the
    /// central. The central does the reduction, and sends the result back to each peripheral.
    Centralized,

    /// Devices are organized in a tree structure (with a given arity). Each node reduces its
    /// children's tensors with its own, and sends the result to its parent. Leaf nodes will
    /// simply send their tensors to their parents.
    /// When the root node calculates the result, it is propagated down the tree.
    ///
    /// The `u32` payload is the tree's arity (children per node).
    Tree(u32),

    /// Devices are organized in a ring. The tensors are split into N slices, where N is the
    /// number of devices participating. The slices are progressively sent around the ring until
    /// every device has one fully reduced slice of the tensor. Then, the resulting slices are sent
    /// around until every device has the full result.
    /// See `ring.rs` for details.
    Ring,
}
287
/// Reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceStrategy {
    /// See [all-reduce](AllReduceStrategy::Centralized)
    Centralized,

    /// See [all-reduce](AllReduceStrategy::Tree); the `u32` payload is the tree arity.
    Tree(u32),
}
297
/// Broadcast can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum BroadcastStrategy {
    /// See [all-reduce](AllReduceStrategy::Centralized)
    Centralized,

    /// See [all-reduce](AllReduceStrategy::Tree); the `u32` payload is the tree arity.
    Tree(u32),
}
307
/// A unique identifier for a peer in the context of collective operations.
/// They must be unique, even in multi-node contexts.
///
/// This is like the rank in NCCL
// Newtype over a raw u32 rank; derives Hash/Eq so it can key peer maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(u32);
314
315impl Display for PeerId {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        write!(f, "PeerId({})", self.0)
318    }
319}
320
321impl From<u32> for PeerId {
322    fn from(value: u32) -> Self {
323        Self(value)
324    }
325}
326
impl From<i32> for PeerId {
    /// Converts a signed rank to a `PeerId`.
    ///
    /// NOTE(review): `value as u32` reinterprets negative inputs as very large
    /// unsigned ids (two's complement) instead of failing — confirm callers
    /// never pass a negative rank.
    fn from(value: i32) -> Self {
        Self(value as u32)
    }
}
332
impl From<usize> for PeerId {
    /// Converts an index-typed rank to a `PeerId`.
    ///
    /// NOTE(review): `value as u32` silently truncates values above `u32::MAX`
    /// on 64-bit targets — presumably peer counts never get near that, but
    /// confirm.
    fn from(value: usize) -> Self {
        Self(value as u32)
    }
}