// burn_collective/api.rs

1use burn_tensor::backend::Backend;
2
3use crate::{
4    CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError,
5    local::server::get_collective_client,
6};
7
8/// Errors from collective operations
9#[allow(unused)]
10#[derive(Debug, Clone)]
11pub enum CollectiveError {
12    /// The [config](CollectiveConfig) was invalid.
13    /// Usually happens if only some global parameters have been defined
14    InvalidConfig,
15    /// Cannot un-register a node twice
16    MultipleUnregister,
17    /// Cannot register a node twice
18    MultipleRegister,
19    /// Trying to register a different way than is currently being done
20    RegisterParamsMismatch,
21    /// Trying to all-reduce tensors of different shapes: shape must match
22    AllReduceShapeMismatch,
23    /// Trying to all-reduce a different way than is currently being done: op must match
24    AllReduceOperationMismatch,
25    /// Trying to reduce tensors of different shapes: shape must match
26    ReduceShapeMismatch,
27    /// Trying to reduce a different way than is currently being done: op must match
28    ReduceOperationMismatch,
29    /// Trying to reduce with different roots
30    ReduceRootMismatch,
31    /// Trying to broadcast with different roots
32    BroadcastRootMismatch,
33    /// Trying to broadcast but no peer sent a tensor
34    BroadcastNoTensor,
35    /// Trying to broadcast but multiple peers sent a tensor
36    BroadcastMultipleTensors,
37    /// Local collective server couldn't respond
38    LocalServerMissing,
39    /// Another operation was called before Register
40    RegisterNotFirstOperation,
41    /// The global orchestrator had an error
42    Global(GlobalCollectiveError),
43
44    #[allow(unused)]
45    Other(String),
46}
47
48/// Registers a device. `num_devices` must be the same for every register,
49/// and `device_id` must be unique.
50///
51/// * `id` - The peer id of the caller
52///
53/// With auto-diff backends, make sure to use the inner backend.
54pub fn register<B: Backend>(
55    id: PeerId,
56    device: B::Device,
57    config: CollectiveConfig,
58) -> Result<(), CollectiveError> {
59    log::info!("Registering peer {id} with config: {config}");
60    let mut client = get_collective_client::<B>();
61    client.register(id, device, config)
62}
63
64/// Calls for an all-reduce operation with the given parameters, and returns the result.
65/// The `params` must be the same as the parameters passed by the other nodes.
66///
67/// * `id` - The peer id of the caller
68/// * `tensor` - The input tensor to reduce with the peers' tensors
69/// * `config` - Config of the collective operation, must be coherent with the other calls
70pub fn all_reduce<B: Backend>(
71    id: PeerId,
72    tensor: B::FloatTensorPrimitive,
73    op: ReduceOperation,
74) -> Result<B::FloatTensorPrimitive, CollectiveError> {
75    let client = get_collective_client::<B>();
76    client.all_reduce(id, tensor, op)
77}
78
79/// Broadcasts, or receives a broadcasted tensor.
80///
81/// * `id` - The peer id of the caller
82/// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
83///   the broadcasted tensor.
84///
85/// Returns the broadcasted tensor.
86pub fn broadcast<B: Backend>(
87    id: PeerId,
88    tensor: Option<B::FloatTensorPrimitive>,
89) -> Result<B::FloatTensorPrimitive, CollectiveError> {
90    let client = get_collective_client::<B>();
91    client.broadcast(id, tensor)
92}
93
94/// Reduces a tensor onto one device.
95///
96/// * `id` - The peer id of the caller
97/// * `tensor` - The tensor to send as input
98/// * `root` - The ID of the peer that will receive the result.
99///
100/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
101pub fn reduce<B: Backend>(
102    id: PeerId,
103    tensor: B::FloatTensorPrimitive,
104    op: ReduceOperation,
105    root: PeerId,
106) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
107    let client = get_collective_client::<B>();
108    client.reduce(id, tensor, op, root)
109}
110
111/// Closes the collective session, unregistering the device
112pub fn finish_collective<B: Backend>(id: PeerId) -> Result<(), CollectiveError> {
113    let client = get_collective_client::<B>();
114    client.finish(id)
115}
116
117/// Resets the local collective server. All registered callers and ongoing operations are forgotten
118pub fn reset_collective<B: Backend>() {
119    let client = get_collective_client::<B>();
120    client.reset();
121}