burn-router 0.21.0

Multi-backend router decorator for the Burn framework
Documentation
use alloc::vec::Vec;
use burn_backend::backend::ExecutionError;
use burn_std::{BoolDType, FloatDType, IntDType};

use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
use burn_backend::ops::BoolTensorOps;
use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor};
use burn_backend::{Scalar, Shape, Slice, TensorData};
use burn_ir::{
    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
    GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
    PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr,
    SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
};

/// Boolean tensor operations for the router backend.
///
/// Most operations follow the same pattern:
/// 1. Take the `RunnerClient` of the input tensor, or `get_client` for `device` in creation ops.
/// 2. Build the matching `*OpIr` description. Output handles are allocated through a
///    `client.create_empty_handle()` closure.
/// 3. Register it as an `OperationIr` and return the resulting `.output()`.
///
/// The exceptions are `bool_into_data`, `bool_device`, `bool_to_device`, and
/// `bool_from_data`.
impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
        let client = get_client::<R>(device);
        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
            .output()
    }

    fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
        let client = get_client::<R>(device);
        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
            .output()
    }

    fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
        let client = get_client::<R>(device);
        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
            .output()
    }

    /// Reads the tensor back through the router tensor's own `into_data`.
    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
        tensor.into_data().await
    }

    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
        let client = get_client::<R>(device);
        let out = client.register_tensor_data(data);
        let desc = InitOperationIr {
            out: out.to_ir_out(),
        };

        // `register_tensor_data` has already produced the output tensor. We therefore
        // only record the `Init` op with `register_op`. The other ops call `register`
        // instead, because they need the `.output()` it yields.
        client.register_op(OperationIr::Init(desc));

        out
    }

    fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
        let client = tensor.client.clone();
        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
            .output()
    }

    fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
        let client = tensor.client.clone();
        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
            .output()
    }

    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
        tensor.client.device()
    }

    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
        // Same device: hand the tensor back untouched instead of re-routing it.
        if &tensor.client.device() == device {
            return tensor;
        }
        R::change_client_backend(tensor, device)
    }

    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
            .output()
    }

    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
            .output()
    }

    fn bool_slice_assign(
        tensor: BoolTensor<Self>,
        slices: &[Slice],
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc =
            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
                client.create_empty_handle()
            });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
            .output()
    }

    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let client = lhs.client.clone();
        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
            .output()
    }

    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());

        client
            .register(OperationIr::Bool(BoolOperationIr::Not(desc)))
            .output()
    }

    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let client = lhs.client.clone();
        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::Bool(BoolOperationIr::And(desc)))
            .output()
    }

    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let client = lhs.client.clone();
        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::Bool(BoolOperationIr::Or(desc)))
            .output()
    }

    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
            .output()
    }

    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
            .output()
    }

    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
            .output()
    }

    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
            .output()
    }

    /// # Panics
    ///
    /// Panics if `tensors` is empty, because there is no client to register the op on.
    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
        // NOTE(review): the first tensor's client is used for the whole op; this
        // assumes every input lives on the same client/device, which is not checked here.
        let client = tensors
            .first()
            .expect("bool_cat requires at least one tensor")
            .client
            .clone();
        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
            .output()
    }

    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
            .output()
    }

    fn bool_unfold(
        tensor: BoolTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
            .output()
    }

    fn bool_mask_where(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
            .output()
    }

    fn bool_mask_fill(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let value = value.into();
        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
            .output()
    }

    fn bool_gather(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
            .output()
    }

    fn bool_scatter_or(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        // The OR update is encoded as `IndexingUpdateOp::Add` on a bool tensor.
        // NOTE(review): this relies on runners treating Add over bools as logical OR;
        // confirm against the runner's Scatter handling.
        let desc = ScatterOpIr::create(
            tensor.into_ir(),
            dim,
            indices.into_ir(),
            value.into_ir(),
            IndexingUpdateOp::Add,
            || client.create_empty_handle(),
        );

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
            .output()
    }

    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        // Capture the input dtype before `lhs` is consumed by `into_ir`.
        let dtype = lhs.dtype;
        let client = lhs.client.clone();
        let rhs = rhs.into();
        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, dtype, || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
            .output()
    }

    fn bool_select(
        tensor: BoolTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(OperationIr::BaseBool(BaseOperationIr::Select(desc)))
            .output()
    }

    fn bool_select_or(
        tensor: BoolTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let client = tensor.client.clone();
        // As in `bool_scatter_or`, the OR update is expressed as `IndexingUpdateOp::Add`.
        // NOTE(review): confirm runners interpret Add on bool tensors as logical OR.
        let desc = SelectAssignOpIr::create(
            tensor.into_ir(),
            dim,
            indices.into_ir(),
            value.into_ir(),
            IndexingUpdateOp::Add,
            || client.create_empty_handle(),
        );

        client
            .register(OperationIr::BaseBool(BaseOperationIr::SelectAssign(desc)))
            .output()
    }
}