1use std::fmt::Display;
2
3use burn_communication::Address;
4use serde::{Deserialize, Serialize};
5
/// Configuration for collective operations, made of local settings plus optional
/// global settings.
///
/// Instances start from [`Default::default`] and are customised with the `with_*`
/// builder methods. The four global parameters (`num_nodes`, `global_address`,
/// `node_address`, `data_service_port`) must either all be set or all be left unset
/// for [`CollectiveConfig::is_valid`] to return `true`.
///
/// NOTE(review): "local" presumably refers to devices within one node and "global"
/// to communication across nodes — confirm against the callers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CollectiveConfig {
    /// Number of devices. Defaults to `1`.
    pub(crate) num_devices: usize,
    /// Strategy for local all-reduce. Defaults to `AllReduceStrategy::Tree(2)`.
    pub(crate) local_all_reduce_strategy: AllReduceStrategy,
    /// Strategy for local reduce. Defaults to `ReduceStrategy::Tree(2)`.
    pub(crate) local_reduce_strategy: ReduceStrategy,
    /// Strategy for local broadcast. Defaults to `BroadcastStrategy::Tree(2)`.
    pub(crate) local_broadcast_strategy: BroadcastStrategy,

    /// Number of nodes. Unset by default; one of the four global parameters.
    pub(crate) num_nodes: Option<u32>,
    /// Global address. Unset by default; one of the four global parameters.
    pub(crate) global_address: Option<Address>,
    /// Address of this node. Unset by default; one of the four global parameters.
    pub(crate) node_address: Option<Address>,
    /// Port of the data service. Unset by default; one of the four global parameters.
    pub(crate) data_service_port: Option<u16>,

    /// Strategy for global all-reduce. Defaults to `Some(AllReduceStrategy::Ring)`.
    pub(crate) global_all_reduce_strategy: Option<AllReduceStrategy>,
    /// Strategy for global reduce. Defaults to `Some(ReduceStrategy::Tree(2))`.
    pub(crate) global_reduce_strategy: Option<ReduceStrategy>,
    /// Strategy for global broadcast. Defaults to `Some(BroadcastStrategy::Tree(2))`.
    pub(crate) global_broadcast_strategy: Option<BroadcastStrategy>,
}
27
28impl Default for CollectiveConfig {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl Display for CollectiveConfig {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 let num_devices = self.num_devices;
37 let local_all_reduce_strategy = self.local_all_reduce_strategy;
38 let local_reduce_strategy = self.local_reduce_strategy;
39 let local_broadcast_strategy = self.local_broadcast_strategy;
40 let num_nodes = self.num_nodes;
41 let global_address = &self.global_address;
42 let node_address = &self.node_address;
43 let data_service_port = self.data_service_port;
44 let global_all_reduce_strategy = self.global_all_reduce_strategy;
45 let global_reduce_strategy = self.global_reduce_strategy;
46 let global_broadcast_strategy = self.global_broadcast_strategy;
47
48 write!(
49 f,
50 r#"
51CollectiveConfig {{
52 num_devices: {num_devices:?},
53 local_all_reduce_strategy: {local_all_reduce_strategy:?},
54 local_reduce_strategy: {local_reduce_strategy:?},
55 local_broadcast_strategy: {local_broadcast_strategy:?},
56 num_nodes: {num_nodes:?},
57 global_address: {global_address:?},
58 node_address: {node_address:?},
59 data_service_port: {data_service_port:?},
60 global_all_reduce_strategy: {global_all_reduce_strategy:?},
61 global_reduce_strategy: {global_reduce_strategy:?},
62 global_broadcast_strategy: {global_broadcast_strategy:?},
63}}
64"#
65 )
66 }
67}
68
69impl CollectiveConfig {
70 fn new() -> Self {
71 Self {
72 num_devices: 1,
73 local_all_reduce_strategy: AllReduceStrategy::Tree(2),
74 local_reduce_strategy: ReduceStrategy::Tree(2),
75 local_broadcast_strategy: BroadcastStrategy::Tree(2),
76
77 num_nodes: None,
78 global_address: None,
79 node_address: None,
80 data_service_port: None,
81 global_all_reduce_strategy: Some(AllReduceStrategy::Ring),
82 global_reduce_strategy: Some(ReduceStrategy::Tree(2)),
83 global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)),
84 }
85 }
86
87 pub fn with_num_devices(mut self, num: usize) -> Self {
89 self.num_devices = num;
90 self
91 }
92
93 pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
103 self.local_all_reduce_strategy = strategy;
104 self
105 }
106
107 pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
109 self.local_reduce_strategy = strategy;
110 self
111 }
112
113 pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
115 self.local_broadcast_strategy = strategy;
116 self
117 }
118
119 pub fn with_num_nodes(mut self, n: u32) -> Self {
123 self.num_nodes = Some(n);
124 self
125 }
126
127 pub fn with_global_address(mut self, addr: Address) -> Self {
131 self.global_address = Some(addr);
132 self
133 }
134
135 pub fn with_node_address(mut self, addr: Address) -> Self {
139 self.node_address = Some(addr);
140 self
141 }
142
143 pub fn with_data_service_port(mut self, port: u16) -> Self {
148 self.data_service_port = Some(port);
149 self
150 }
151
152 pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
157 self.global_all_reduce_strategy = Some(strategy);
158 self
159 }
160
161 pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
166 self.global_reduce_strategy = Some(strategy);
167 self
168 }
169
170 pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
175 self.global_broadcast_strategy = Some(strategy);
176 self
177 }
178
179 pub fn is_valid(&self) -> bool {
182 match (
183 self.num_nodes,
184 &self.global_address,
185 &self.node_address,
186 self.data_service_port,
187 ) {
188 (None, None, None, None) => true,
189 (Some(_), Some(_), Some(_), Some(_)) => true,
190 _ => false,
192 }
193 }
194
195 pub(crate) fn global_register_params(&self) -> Option<GlobalRegisterParams> {
200 match (
201 self.num_nodes,
202 &self.global_address,
203 &self.node_address,
204 self.data_service_port,
205 ) {
206 (None, None, None, None) => None,
208 (Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => {
210 Some(GlobalRegisterParams {
211 num_nodes,
212 global_address: global_addr.clone(),
213 node_address: node_addr.clone(),
214 data_service_port,
215 })
216 }
217 _ => None,
219 }
220 }
221}
222
/// The four global parameters of a [`CollectiveConfig`], bundled together.
///
/// `CollectiveConfig::global_register_params` builds this only when all four
/// values are set on the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRegisterParams {
    /// Global address, copied from `CollectiveConfig::global_address`.
    pub global_address: Address,
    /// Address of this node, copied from `CollectiveConfig::node_address`.
    pub node_address: Address,
    /// Port of the data service, copied from `CollectiveConfig::data_service_port`.
    pub data_service_port: u16,

    /// Number of nodes, copied from `CollectiveConfig::num_nodes`.
    pub num_nodes: u32,
}
239
/// Parameters describing an all-reduce: the reduce operation plus the local and
/// optional global strategy.
///
/// NOTE(review): the `Shared` prefix suggests these values must be identical for
/// every participant of the operation — confirm against the callers.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedAllReduceParams {
    /// Reduce operation to apply.
    pub op: ReduceOperation,
    /// Strategy used locally.
    pub local_strategy: AllReduceStrategy,
    /// Strategy used globally, if any.
    pub global_strategy: Option<AllReduceStrategy>,
}
247
/// Parameters describing a reduce. Currently carries no fields.
///
/// NOTE(review): `SharedAllReduceParams` and `SharedBroadcastParams` carry an
/// operation and strategies; confirm whether this struct is intentionally empty.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedReduceParams {}
251
/// Parameters describing a broadcast: an operation plus the local and optional
/// global strategy.
///
/// NOTE(review): a broadcast carrying a `ReduceOperation` looks unusual — confirm
/// that `op` is actually consumed.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedBroadcastParams {
    /// Reduce operation associated with the broadcast.
    pub op: ReduceOperation,
    /// Strategy used locally.
    pub local_strategy: BroadcastStrategy,
    /// Strategy used globally, if any.
    pub global_strategy: Option<BroadcastStrategy>,
}
259
/// Operation selected for a reduction.
///
/// NOTE(review): presumably `Sum` adds the participants' values and `Mean`
/// averages them — the implementation is not visible here, confirm.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceOperation {
    /// Sum reduction.
    Sum,
    /// Mean reduction.
    Mean,
}
266
/// Strategy selectable for an all-reduce.
///
/// `CollectiveConfig` defaults to `Tree(2)` locally and `Ring` globally.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// Centralized strategy.
    /// NOTE(review): presumably a single participant performs the reduction — confirm.
    Centralized,

    /// Tree strategy.
    /// NOTE(review): the `u32` is presumably the tree arity — confirm.
    Tree(u32),

    /// Ring strategy.
    Ring,
}
287
/// Strategy selectable for a reduce.
///
/// `CollectiveConfig` defaults to `Tree(2)` both locally and globally.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceStrategy {
    /// Centralized strategy.
    Centralized,

    /// Tree strategy.
    /// NOTE(review): the `u32` is presumably the tree arity — confirm.
    Tree(u32),
}
297
/// Strategy selectable for a broadcast.
///
/// `CollectiveConfig` defaults to `Tree(2)` both locally and globally.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum BroadcastStrategy {
    /// Centralized strategy.
    Centralized,

    /// Tree strategy.
    /// NOTE(review): the `u32` is presumably the tree arity — confirm.
    Tree(u32),
}
307
/// Identifier of a peer, stored as a `u32`.
///
/// The inner value is private; construct a `PeerId` through the `From<u32>`,
/// `From<i32>` or `From<usize>` conversions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(u32);
314
315impl Display for PeerId {
316 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317 write!(f, "PeerId({})", self.0)
318 }
319}
320
321impl From<u32> for PeerId {
322 fn from(value: u32) -> Self {
323 Self(value)
324 }
325}
326
327impl From<i32> for PeerId {
328 fn from(value: i32) -> Self {
329 Self(value as u32)
330 }
331}
332
333impl From<usize> for PeerId {
334 fn from(value: usize) -> Self {
335 Self(value as u32)
336 }
337}