burn-candle 0.20.1

Candle backend for the Burn framework
Documentation
use burn_backend::{
    BackTrace, DType, ExecutionError, Shape, Slice, TensorData, TensorMetadata,
    ops::BoolTensorOps,
    tensor::{BoolElem, BoolTensor, Device, FloatTensor, IntTensor},
};

use crate::{
    Candle, CandleTensor,
    element::{CandleElement, FloatCandleElement, IntCandleElement},
};

use super::base::{expand, permute, unfold};

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
    // Representation: bool tensors in this backend are candle `U8` tensors (see the
    // constructors below). `0` is `false`; on read-back (`bool_into_data`) any non-zero byte
    // is `true`. Ops whose raw candle result can leave the {0, 1} range (accumulating adds)
    // normalize their output with `ne(0u8)` so that downstream ops such as `bool_equal`,
    // `bool_into_int` and `bool_into_float` keep seeing canonical 0/1 values.

    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        super::base::empty(shape, device, candle_core::DType::U8)
    }

    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        super::base::zeros(shape, device, candle_core::DType::U8)
    }

    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        super::base::ones(shape, device, candle_core::DType::U8)
    }

    // Reads the tensor back as a flat `Vec<u8>` and maps every non-zero byte to `true`.
    // Any candle error (flattening or host copy) is reported as `ExecutionError::Generic`.
    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
        // Chaining with `and_then` lets both fallible candle calls share one error mapping.
        let bytes = tensor
            .tensor
            .flatten_all()
            .and_then(|flat| flat.to_vec1::<u8>())
            .map_err(|err| ExecutionError::Generic {
                reason: format!("{err}"),
                backtrace: BackTrace::capture(),
            })?;

        let values: Vec<bool> = bytes.into_iter().map(|byte| byte != 0).collect();

        Ok(TensorData::new(values, tensor.shape()))
    }

    // Only `U8`-encoded payloads are accepted; every other dtype panics via `unimplemented!`.
    // NOTE(review): this includes `DType::Bool` data (as produced by `bool_into_data`);
    // presumably callers convert to the backend's bool element before calling — confirm.
    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
        match data.dtype {
            DType::U8 => super::base::from_data::<u8>(data, device),
            _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
        }
    }

    // Casts the stored bytes to the backend int dtype (0/1 for canonical bool tensors).
    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
    }

    // Casts the stored bytes to the backend float dtype (0.0/1.0 for canonical bool tensors).
    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
    }

    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
        super::base::device(tensor)
    }

    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
        super::base::to_device(tensor, device)
    }

    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        super::base::reshape(tensor, shape)
    }

    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
        super::base::slice_with_steps(tensor, slices)
    }

    fn bool_slice_assign(
        tensor: BoolTensor<Self>,
        slices: &[Slice],
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        super::base::slice_assign(tensor, slices, value)
    }

    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
        super::base::cat(tensors, dim)
    }

    // Element-wise equality after broadcasting both operands to a common shape.
    // Compares raw bytes, so it relies on both inputs holding canonical 0/1 values.
    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
    }

    // `true` exactly where the stored byte is zero; candle comparisons yield a 0/1 `U8` tensor.
    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
        CandleTensor::new(tensor.tensor.eq(0u8).unwrap())
    }

    // min(a, b) is non-zero iff both operands are non-zero. Unlike summing the operands, this
    // cannot overflow `u8` and stays correct when a `true` is stored as a value other than 1.
    // `broadcast_minimum` also accepts broadcastable shapes, consistent with `bool_equal`.
    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .broadcast_minimum(&rhs.tensor)
                .unwrap()
                .ne(0u8)
                .unwrap(),
        )
    }

    // max(a, b) is non-zero iff at least one operand is non-zero. Taking the maximum instead of
    // adding cannot overflow `u8`; `broadcast_maximum` also accepts broadcastable shapes.
    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .broadcast_maximum(&rhs.tensor)
                .unwrap()
                .ne(0u8)
                .unwrap(),
        )
    }

    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
        super::base::swap_dims(tensor, dim1, dim2)
    }

    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        permute(tensor, axes)
    }

    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        super::base::flip(tensor, axes)
    }

    // Picks the slices at `indices` along `dim`.
    fn bool_select(
        tensor: BoolTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> BoolTensor<Self> {
        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
    }

    // OR-accumulates `value` into `tensor` at `indices` along `dim`.
    fn bool_select_or(
        tensor: BoolTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        // `index_add` sums bytes, so overlapping trues produce counts such as 2; map any
        // non-zero count back to canonical 1.
        // NOTE(review): the sum is done in `u8`, so 256+ trues landing on one slot would wrap
        // to 0 before normalization — confirm whether that case can occur in practice.
        let accumulated = tensor
            .tensor
            .index_add(&indices.tensor, &value.tensor, dim)
            .unwrap();
        CandleTensor::new(accumulated.ne(0u8).unwrap())
    }

    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        expand(tensor, shape)
    }

    fn bool_unfold(
        tensor: BoolTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> BoolTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn bool_mask_where(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        super::base::mask_where_broadcasted(tensor, mask, value)
    }

    // Where `mask` is set, take the scalar `value`; elsewhere keep `tensor`.
    fn bool_mask_fill(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: BoolElem<Self>,
    ) -> BoolTensor<Self> {
        let filled = super::candle_utils::fill_like::<u8>(value, &tensor.tensor);
        CandleTensor::new(mask.tensor.where_cond(&filled, &tensor.tensor).unwrap())
    }

    // Both operands are made contiguous before the gather call.
    fn bool_gather(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
    ) -> BoolTensor<Self> {
        let tensor = tensor.tensor.contiguous().unwrap();
        let indices = indices.tensor.contiguous().unwrap();
        CandleTensor::new(tensor.gather(&indices, dim).unwrap())
    }

    // OR-scatters `value` into `tensor` at `indices` along `dim`.
    fn bool_scatter_or(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        // `scatter_add` sums bytes, so collisions produce counts > 1; normalize to 0/1.
        // NOTE(review): the sum is done in `u8`, so 256+ trues scattered to one slot would
        // wrap to 0 before normalization — confirm whether that case can occur in practice.
        let accumulated = tensor
            .tensor
            .scatter_add(&indices.tensor, &value.tensor, dim)
            .unwrap();
        CandleTensor::new(accumulated.ne(0u8).unwrap())
    }

    // Compares every stored byte against the scalar `rhs`.
    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(lhs.tensor.eq(rhs).unwrap())
    }
}