burn_ndarray/ops/bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
5use burn_tensor::{ElementConversion, TensorMetadata};
6use core::ops::Range;
7use ndarray::{IntoDimension, Zip};
8
9// Current crate
10use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
11use crate::{new_tensor_float, NdArrayDevice};
12use crate::{tensor::NdArrayTensor, NdArray};
13
14// Workspace crates
15use burn_tensor::{backend::Backend, Shape, TensorData};
16
17use super::NdArrayOps;
18
19impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
20    for NdArray<E, I, Q>
21{
22    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
23        NdArrayTensor::from_data(data)
24    }
25
26    async fn bool_into_data(tensor: NdArrayTensor<bool>) -> TensorData {
27        let shape = tensor.shape();
28        let values = tensor.array.into_iter().collect();
29        TensorData::new(values, shape)
30    }
31
32    fn bool_to_device(tensor: NdArrayTensor<bool>, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
33        tensor
34    }
35
36    fn bool_reshape(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
37        NdArrayOps::reshape(tensor, shape)
38    }
39
40    fn bool_slice(tensor: NdArrayTensor<bool>, ranges: &[Range<usize>]) -> NdArrayTensor<bool> {
41        NdArrayOps::slice(tensor, ranges)
42    }
43
44    fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
45        let shape = tensor.shape();
46        let values = tensor.array.into_iter().collect();
47        NdArray::<E, I>::int_from_data(
48            TensorData::new(values, shape).convert::<I>(),
49            &NdArrayDevice::Cpu,
50        )
51    }
52
53    fn bool_device(_tensor: &NdArrayTensor<bool>) -> <NdArray<E> as Backend>::Device {
54        NdArrayDevice::Cpu
55    }
56
57    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<bool> {
58        let values = vec![false; shape.num_elements()];
59        NdArrayTensor::from_data(TensorData::new(values, shape))
60    }
61
62    fn bool_slice_assign(
63        tensor: NdArrayTensor<bool>,
64        ranges: &[Range<usize>],
65        value: NdArrayTensor<bool>,
66    ) -> NdArrayTensor<bool> {
67        NdArrayOps::slice_assign(tensor, ranges, value)
68    }
69
70    fn bool_cat(tensors: Vec<NdArrayTensor<bool>>, dim: usize) -> NdArrayTensor<bool> {
71        NdArrayOps::cat(tensors, dim)
72    }
73
74    fn bool_equal(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
75        let output = Zip::from(&lhs.array)
76            .and(&rhs.array)
77            .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
78            .into_shared();
79        NdArrayTensor::new(output)
80    }
81
82    fn bool_not(tensor: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
83        let array = tensor.array.mapv(|a| !a).into_shared();
84        NdArrayTensor { array }
85    }
86
87    fn bool_into_float(tensor: NdArrayTensor<bool>) -> FloatTensor<Self> {
88        new_tensor_float!(NdArrayTensor {
89            array: tensor.array.mapv(|a| (a as i32).elem()).into_shared(),
90        })
91    }
92
93    fn bool_swap_dims(
94        tensor: NdArrayTensor<bool>,
95        dim1: usize,
96        dim2: usize,
97    ) -> NdArrayTensor<bool> {
98        NdArrayOps::swap_dims(tensor, dim1, dim2)
99    }
100
101    fn bool_permute(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
102        let array = tensor.array.permuted_axes(axes.into_dimension());
103        NdArrayTensor { array }
104    }
105
106    fn bool_expand(tensor: NdArrayTensor<bool>, shape: Shape) -> NdArrayTensor<bool> {
107        NdArrayOps::expand(tensor, shape)
108    }
109
110    fn bool_flip(tensor: NdArrayTensor<bool>, axes: &[usize]) -> NdArrayTensor<bool> {
111        NdArrayOps::flip(tensor, axes)
112    }
113}