use alloc::format;
use alloc::string::String;
use burn_backend::{
DType, Shape, TensorData,
backend::{Backend, DeviceId, DeviceOps, ExecutionError},
try_read_sync,
};
use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};
use burn_std::future::DynFut;
use crate::{
ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,
RunnerClient,
};
// Generates `pub mod $module_name` containing the multi-backend types for a fixed, ordered set
// of backends:
//
// - `Handle`: a tensor handle tagged with the backend that owns it.
// - `MultiDevice`: a device tagged with its backend; `Default` uses the first backend.
// - `MultiRunnerClient`: a `Runner` for one of the backends; every `RunnerClient` call is
//   forwarded to that runner.
// - `RunnerChannel` for `DirectChannel<(B1, .., Bn), Br>` and `MultiBackendBridge` for
//   `ByteBridge<(B1, .., Bn)>`.
//
// The first backend (`$DefaultBackend`) also supplies the float/int/bool element types of the
// generated channel.
macro_rules! impl_multi_backend_types {
    ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        /// Multi-backend handle, device and client types generated by `impl_multi_backend_types!`.
        pub mod $module_name {
            use super::*;

            /// A tensor handle owned by exactly one of the backends.
            pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Handle),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Handle),
                )+
            }

            /// A device belonging to exactly one of the backends.
            #[derive(Clone, Debug)]
            pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Device),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Device),
                )+
            }

            // Devices of different backends never compare equal.
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn eq(&self, other: &Self) -> bool {
                    match (self, other) {
                        (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,
                        $(
                            (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,
                        )+
                        _ => false,
                    }
                }
            }

            // The default device is the default device of the first backend.
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn default() -> Self {
                    Self::$DefaultBackend($DefaultBackend::Device::default())
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                // NOTE(review): the id is ignored and the default device is returned; presumably
                // because a bare `DeviceId` cannot tell which backend it belongs to — confirm.
                fn from_id(_device_id: DeviceId) -> Self {
                    Default::default()
                }

                fn to_id(&self) -> DeviceId {
                    match self {
                        Self::$DefaultBackend(device) => device.id(),
                        $(
                            Self::$OtherBackend(device) => device.id(),
                        )+
                    }
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}

            /// A runner client bound to one of the backends.
            #[derive(Clone)]
            pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend(Runner<$DefaultBackend>),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend(Runner<$OtherBackend>),
                )+
            }

            // Every method forwards to the runner of the selected backend.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>
            {
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn register_op(&self, op: OperationIr) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.register_op(op),
                        $(
                            Self::$OtherBackend(runner) => runner.register_op(op),
                        )+
                    }
                }

                fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),
                        $(
                            Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),
                        )+
                    }
                }

                fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
                    match self {
                        Self::$DefaultBackend(runner) => {
                            let desc = runner.register_tensor_data_desc(data);
                            RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                        }
                        $(
                            Self::$OtherBackend(runner) => {
                                let desc = runner.register_tensor_data_desc(data);
                                RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                            }
                        )+
                    }
                }

                fn device(&self) -> Self::Device {
                    match self {
                        Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),
                        $(
                            Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),
                        )+
                    }
                }

                fn sync(&self) -> Result<(), ExecutionError> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.sync(),
                        $(
                            Self::$OtherBackend(runner) => runner.sync(),
                        )+
                    }
                }

                fn seed(&self, seed: u64) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.seed(seed),
                        $(
                            Self::$OtherBackend(runner) => runner.seed(seed),
                        )+
                    }
                }

                fn create_empty_handle(&self) -> TensorId {
                    match self {
                        Self::$DefaultBackend(runner) => runner.create_empty_handle(),
                        $(
                            Self::$OtherBackend(runner) => runner.create_empty_handle(),
                        )+
                    }
                }

                fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
                    match self {
                        Self::$DefaultBackend(runner) => runner.dtype_usage(dtype),
                        $(
                            Self::$OtherBackend(runner) => runner.dtype_usage(dtype),
                        )+
                    }
                }
            }

            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>
            where
                Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,
            {
                type Device = Br::Device;
                type Bridge = Br;
                type FloatElem = $DefaultBackend::FloatElem;
                type IntElem = $DefaultBackend::IntElem;
                type BoolElem = $DefaultBackend::BoolElem;
                type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;

                fn init_client(device: &Self::Device) -> Self::Client {
                    match device {
                        MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),
                        $(
                            MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),
                        )+
                    }
                }

                fn get_tensor_handle(
                    tensor: &TensorIr,
                    client: &Self::Client,
                ) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),
                        )+
                    }
                }

                // Panics if `handle` belongs to a different backend than `client`.
                fn register_tensor(
                    client: &Self::Client,
                    handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,
                    shape: Shape,
                    dtype: DType,
                ) -> RouterTensor<Self::Client> {
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => match handle {
                            Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                            _ => unreachable!("Can't register tensor handle for another backend."),
                        },
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => match handle {
                                Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                                _ => unreachable!("Can't register tensor handle for another backend."),
                            },
                        )+
                    }
                }

                // Lists every backend of the tuple (each named with its default device), not only
                // the backend `_device` belongs to.
                fn name(_device: &Self::Device) -> String {
                    let mut name = $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default());
                    $(
                        name.push_str(", ");
                        name.push_str(&$OtherBackend::name(&<$OtherBackend::Device as Default>::default()));
                    )+
                    format!("direct<({})>", name)
                }
            }

            // Each tensor kind is bridged with its own backend operations, so int and bool
            // tensors are never reinterpreted as float tensors when changing backend.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {
                type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn change_backend_float(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(
                        [float_tensor, float_to_device, float_into_data, float_from_data, float_tensor_handle],
                        shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+
                    )
                }

                fn change_backend_int(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(
                        [int_tensor, int_to_device, int_into_data, int_from_data, int_tensor_handle],
                        shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+
                    )
                }

                fn change_backend_bool(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(
                        [bool_tensor, bool_to_device, bool_into_data, bool_from_data, bool_tensor_handle],
                        shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+
                    )
                }
            }
        }
    };
}

// Moves a single tensor handle to `$device` and returns the `Handle` variant of the target
// backend.
//
// The leading bracketed list names the kind-specific operations to use, in this order:
// `[handle -> primitive, to_device, into_data, from_data, primitive -> handle]`. For example:
// `[float_tensor, float_to_device, float_into_data, float_from_data, float_tensor_handle]`.
//
// - `bridge!(ops, B, handle, device, shape)`: source and target are the same backend, so the
//   tensor is moved with that backend's own `*_to_device`.
// - `bridge!(ops, A, B, handle, device, shape)`: different backends. The data is read back
//   synchronously from `A` and uploaded to `B`.
macro_rules! bridge {
    (
        [$from_handle:ident, $to_device:ident, $into_data:ident, $from_data:ident, $into_handle:ident],
        $Backend:ident, $handle:expr, $device:expr, $shape:expr
    ) => {{
        let tensor = $Backend::$from_handle(TensorHandle {
            handle: $handle,
            shape: $shape,
        });
        let tensor = $Backend::$to_device(tensor, $device);
        let handle = $Backend::$into_handle(tensor);
        Handle::$Backend(handle)
    }};
    (
        [$from_handle:ident, $to_device:ident, $into_data:ident, $from_data:ident, $into_handle:ident],
        $BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr
    ) => {{
        let tensor = $BackendA::$from_handle(TensorHandle { handle: $handle, shape: $shape });
        let data = try_read_sync($BackendA::$into_data(tensor))
            // Outer layer: `try_read_sync` could not complete the read on this thread.
            .expect(
                "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM."
            )
            // Inner layer: the source backend itself reported an error while reading.
            .expect("Failed to read tensor data from the source backend.");
        let tensor = $BackendB::$from_data(data, $device);
        let handle = $BackendB::$into_handle(tensor);
        Handle::$BackendB(handle)
    }};
}

// Builds a `match (handle, device)` with one arm for every (source backend, target backend)
// pair. Each arm calls `bridge!` with the kind-specific operation list `$ops`.
//
// Entry: `multi_backend_match!(ops, shape, (handle, device) : B1, B2, ..)`.
// The first rule emits every pair involving `B1`, plus the same-backend arm of each other
// backend. Each `@step` then emits the pairs between the head of the remaining list and its
// tail. The final `@step` (at most one backend left) emits the `match`.
macro_rules! multi_backend_match {
    ($ops:tt, $shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        multi_backend_match!(
            @step $ops,
            $shape,
            ($handle, $device);
            {
                (Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($ops, $DefaultBackend, handle, device, $shape),
                $(
                    (Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($ops, $DefaultBackend, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($ops, $OtherBackend, $DefaultBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($ops, $OtherBackend, handle, device, $shape),
                )+
            };
            $($OtherBackend),+
        )
    };
    (@step
        $ops:tt,
        $shape:expr,
        $pats:tt;
        { $($arms:tt)* };
        $BackendA:ident,
        $($OtherBackend:ident),+
    ) => {
        multi_backend_match!(
            @step $ops,
            $shape,
            $pats;
            {
                $($arms)*
                $(
                    (Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($ops, $BackendA, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($ops, $OtherBackend, $BackendA, handle, device, $shape),
                )*
            };
            $($OtherBackend),*
        )
    };
    (@step
        $ops:tt,
        $shape:expr,
        ($handle:expr, $device:expr);
        { $($arms:tt)* };
        $($BackendA:ident)?
    ) => {
        match ($handle, $device) {
            $($arms)*
        }
    };
}
// Instantiate the multi-backend types for routers over two, three and four backends. Each
// invocation generates a module (`duo`, `trio`, `quad`) whose generic parameters are named
// after the backend identifiers passed here.
impl_multi_backend_types!(duo, B1, B2);
impl_multi_backend_types!(trio, B1, B2, B3);
impl_multi_backend_types!(quad, B1, B2, B3, B4);
#[cfg(all(test, not(target_os = "windows")))]
mod tests {
    use super::*;
    use crate::tests::TestBackend;
    use burn_tensor::Tensor;

    /// Moves float tensors between the two devices of a `duo` router and checks that the
    /// values are unchanged in both directions.
    #[test]
    fn should_support_dual_byte_bridge() {
        let first_device = duo::MultiDevice::B1(Default::default());
        let second_device = duo::MultiDevice::B2(Default::default());

        let lhs = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &first_device);
        let rhs = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &second_device);

        // First backend -> second backend.
        let lhs_moved = lhs.clone().to_device(&second_device);
        lhs.into_data().assert_eq(&lhs_moved.into_data(), true);

        // Second backend -> first backend.
        let rhs_moved = rhs.clone().to_device(&first_device);
        rhs.into_data().assert_eq(&rhs_moved.into_data(), true);
    }
}