burn-backend 0.21.0

Core backend interfaces and data structures for executing tensor operations in Burn.
Documentation
use std::{sync::mpsc::Sender, thread::spawn};

use crate::tensor::Device;

use crate::distributed::{
    DistributedBackend, DistributedConfig, DistributedParams, TensorRef,
    server::DistributedSyncServer,
};

/// Envelope for everything a [`DistributedSyncClient`] sends to its server thread.
///
/// The server thread keeps receiving while it gets [`ActionMessage::Message`] and leaves its
/// loop on the first [`ActionMessage::Close`].
pub(crate) enum ActionMessage<B: DistributedBackend> {
    /// A request that the server thread hands to `DistributedSyncServer::process_message`.
    Message(DistributedSyncMessage<B>),
    /// Ends the server thread's receive loop.
    Close(),
}

/// Requests that a [`DistributedSyncClient`] can submit to the sync server.
pub(crate) enum DistributedSyncMessage<B: DistributedBackend> {
    /// Parameters to register; sent by `DistributedSyncClient::register_sync_parameters`.
    RegisterSyncParameters(Vec<DistributedParams>),
    /// A tensor reference paired with its parameters; sent by
    /// `DistributedSyncClient::submit_gradient_sync`.
    TensorSync((TensorRef<B>, DistributedParams)),
    /// A device paired with a one-shot reply channel; sent by
    /// `DistributedSyncClient::submit_sync_collective`, which blocks on the channel and then
    /// calls the closure it receives on the calling thread.
    // The boxed-closure sender type is spelled out inline instead of behind a type alias,
    // which is what trips `clippy::type_complexity`.
    #[allow(clippy::type_complexity)]
    CollectiveSync((Device<B>, oneshot::Sender<Box<dyn FnOnce() + Send>>)),
}

/// Handle used to send requests to the distributed sync server thread.
///
/// Cloning the client clones the underlying channel sender, so every clone feeds the same
/// server thread.
#[derive(Clone)]
pub struct DistributedSyncClient<B: DistributedBackend> {
    /// Sending half of the channel whose receiver is moved into the thread spawned by `new`.
    sender: Sender<ActionMessage<B>>,
}

impl<B: DistributedBackend> DistributedSyncClient<B> {
    pub(crate) fn new(num_devices: usize, config: DistributedConfig) -> Self {
        let (tx, rx) = std::sync::mpsc::channel();

        let mut server = DistributedSyncServer::new(num_devices, config);
        spawn(move || {
            while let ActionMessage::Message(msg) =
                rx.recv().expect("Gradient sync server disconnected.")
            {
                server.process_message(msg)
            }
        });
        Self { sender: tx }
    }

    pub fn register_sync_parameters(&self, sharded_params: Vec<DistributedParams>) {
        self.sender
            .send(ActionMessage::Message(
                DistributedSyncMessage::RegisterSyncParameters(sharded_params),
            ))
            .unwrap();
    }

    pub fn submit_gradient_sync(&self, tensor: TensorRef<B>, params: DistributedParams) {
        self.sender
            .send(ActionMessage::Message(DistributedSyncMessage::TensorSync(
                (tensor, params),
            )))
            .unwrap();
    }

    pub fn submit_sync_collective(&self, device: Device<B>) {
        let (tx, rx) = oneshot::channel();

        self.sender
            .send(ActionMessage::Message(
                DistributedSyncMessage::CollectiveSync((device.clone(), tx)),
            ))
            .unwrap();

        let sync = rx.recv().expect("Can receive callback");

        sync();
    }

    pub(crate) fn close(&self) {
        self.sender.send(ActionMessage::Close()).unwrap();
    }
}