burn_ndarray/ops/
bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
5use burn_tensor::{ElementConversion, TensorMetadata};
6use core::ops::Range;
7use ndarray::IntoDimension;
8
9// Current crate
10use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
11use crate::{NdArray, tensor::NdArrayTensor};
12use crate::{NdArrayDevice, new_tensor_float};
13
14// Workspace crates
15use burn_tensor::{Shape, TensorData, backend::Backend};
16
17use super::{NdArrayBoolOps, NdArrayOps};
18
19impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
20    for NdArray<E, I, Q>
21{
22    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
23        if !data.dtype.is_bool() {
24            unimplemented!("Unsupported dtype for `bool_from_data`")
25        }
26        NdArrayTensor::from_data(data)
27    }
28
29    async fn bool_into_data(tensor: NdArrayTensor<bool>) -> TensorData {
30        NdArrayOps::into_data(tensor)
31    }
32
33    fn bool_to_device(tensor: NdArrayTensor<bool>, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
34        tensor
35    }
36
37    fn bool_reshape(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
38        NdArrayOps::reshape(tensor, shape)
39    }
40
41    fn bool_slice(tensor: NdArrayTensor<bool>, ranges: &[Range<usize>]) -> NdArrayTensor<bool> {
42        NdArrayOps::slice(tensor, ranges)
43    }
44
45    fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
46        let shape = tensor.shape();
47        let values = tensor.array.into_iter().collect();
48        NdArray::<E, I>::int_from_data(
49            TensorData::new(values, shape).convert::<I>(),
50            &NdArrayDevice::Cpu,
51        )
52    }
53
54    fn bool_device(_tensor: &NdArrayTensor<bool>) -> <NdArray<E> as Backend>::Device {
55        NdArrayDevice::Cpu
56    }
57
58    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<bool> {
59        let values = vec![false; shape.num_elements()];
60        NdArrayTensor::from_data(TensorData::new(values, shape))
61    }
62
63    fn bool_slice_assign(
64        tensor: NdArrayTensor<bool>,
65        ranges: &[Range<usize>],
66        value: NdArrayTensor<bool>,
67    ) -> NdArrayTensor<bool> {
68        NdArrayOps::slice_assign(tensor, ranges, value)
69    }
70
71    fn bool_cat(tensors: Vec<NdArrayTensor<bool>>, dim: usize) -> NdArrayTensor<bool> {
72        NdArrayOps::cat(tensors, dim)
73    }
74
75    fn bool_equal(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
76        NdArrayBoolOps::equal(lhs, rhs)
77    }
78
79    fn bool_not(tensor: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
80        let array = tensor.array.mapv(|a| !a).into_shared();
81        NdArrayTensor { array }
82    }
83
84    fn bool_and(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
85        NdArrayBoolOps::and(lhs, rhs)
86    }
87
88    fn bool_or(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
89        NdArrayBoolOps::or(lhs, rhs)
90    }
91
92    fn bool_into_float(tensor: NdArrayTensor<bool>) -> FloatTensor<Self> {
93        new_tensor_float!(NdArrayTensor {
94            array: tensor.array.mapv(|a| (a as i32).elem()).into_shared(),
95        })
96    }
97
98    fn bool_swap_dims(
99        tensor: NdArrayTensor<bool>,
100        dim1: usize,
101        dim2: usize,
102    ) -> NdArrayTensor<bool> {
103        NdArrayOps::swap_dims(tensor, dim1, dim2)
104    }
105
106    fn bool_permute(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
107        let array = tensor.array.permuted_axes(axes.into_dimension());
108        NdArrayTensor { array }
109    }
110
111    fn bool_expand(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
112        NdArrayOps::expand(tensor, shape)
113    }
114
115    fn bool_flip(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
116        NdArrayOps::flip(tensor, axes)
117    }
118}