burn_collective/api.rs
1use burn_tensor::backend::Backend;
2
3use crate::{
4 CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError,
5 local::server::get_collective_client,
6};
7
8/// Errors from collective operations
9#[allow(unused)]
10#[derive(Debug, Clone)]
11pub enum CollectiveError {
12 /// The [config](CollectiveConfig) was invalid.
13 /// Usually happens if only some global parameters have been defined
14 InvalidConfig,
15 /// Cannot un-register a node twice
16 MultipleUnregister,
17 /// Cannot register a node twice
18 MultipleRegister,
19 /// Trying to register a different way than is currently being done
20 RegisterParamsMismatch,
21 /// Trying to all-reduce tensors of different shapes: shape must match
22 AllReduceShapeMismatch,
23 /// Trying to all-reduce a different way than is currently being done: op must match
24 AllReduceOperationMismatch,
25 /// Trying to reduce tensors of different shapes: shape must match
26 ReduceShapeMismatch,
27 /// Trying to reduce a different way than is currently being done: op must match
28 ReduceOperationMismatch,
29 /// Trying to reduce with different roots
30 ReduceRootMismatch,
31 /// Trying to broadcast with different roots
32 BroadcastRootMismatch,
33 /// Trying to broadcast but no peer sent a tensor
34 BroadcastNoTensor,
35 /// Trying to broadcast but multiple peers sent a tensor
36 BroadcastMultipleTensors,
37 /// Local collective server couldn't respond
38 LocalServerMissing,
39 /// Another operation was called before Register
40 RegisterNotFirstOperation,
41 /// The global orchestrator had an error
42 Global(GlobalCollectiveError),
43
44 #[allow(unused)]
45 Other(String),
46}
47
48/// Registers a device. `num_devices` must be the same for every register,
49/// and `device_id` must be unique.
50///
51/// * `id` - The peer id of the caller
52///
53/// With auto-diff backends, make sure to use the inner backend.
54pub fn register<B: Backend>(
55 id: PeerId,
56 device: B::Device,
57 config: CollectiveConfig,
58) -> Result<(), CollectiveError> {
59 log::info!("Registering peer {id} with config: {config}");
60 let mut client = get_collective_client::<B>();
61 client.register(id, device, config)
62}
63
64/// Calls for an all-reduce operation with the given parameters, and returns the result.
65/// The `params` must be the same as the parameters passed by the other nodes.
66///
67/// * `id` - The peer id of the caller
68/// * `tensor` - The input tensor to reduce with the peers' tensors
69/// * `config` - Config of the collective operation, must be coherent with the other calls
70pub fn all_reduce<B: Backend>(
71 id: PeerId,
72 tensor: B::FloatTensorPrimitive,
73 op: ReduceOperation,
74) -> Result<B::FloatTensorPrimitive, CollectiveError> {
75 let client = get_collective_client::<B>();
76 client.all_reduce(id, tensor, op)
77}
78
79/// Broadcasts, or receives a broadcasted tensor.
80///
81/// * `id` - The peer id of the caller
82/// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
83/// the broadcasted tensor.
84///
85/// Returns the broadcasted tensor.
86pub fn broadcast<B: Backend>(
87 id: PeerId,
88 tensor: Option<B::FloatTensorPrimitive>,
89) -> Result<B::FloatTensorPrimitive, CollectiveError> {
90 let client = get_collective_client::<B>();
91 client.broadcast(id, tensor)
92}
93
94/// Reduces a tensor onto one device.
95///
96/// * `id` - The peer id of the caller
97/// * `tensor` - The tensor to send as input
98/// * `root` - The ID of the peer that will receive the result.
99///
100/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
101pub fn reduce<B: Backend>(
102 id: PeerId,
103 tensor: B::FloatTensorPrimitive,
104 op: ReduceOperation,
105 root: PeerId,
106) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
107 let client = get_collective_client::<B>();
108 client.reduce(id, tensor, op, root)
109}
110
111/// Closes the collective session, unregistering the device
112pub fn finish_collective<B: Backend>(id: PeerId) -> Result<(), CollectiveError> {
113 let client = get_collective_client::<B>();
114 client.finish(id)
115}
116
117/// Resets the local collective server. All registered callers and ongoing operations are forgotten
118pub fn reset_collective<B: Backend>() {
119 let client = get_collective_client::<B>();
120 client.reset();
121}