burn_ndarray/ops/
bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_tensor::backend::ExecutionError;
5use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
6use burn_tensor::{ElementConversion, TensorMetadata};
7use ndarray::IntoDimension;
8
9// Current crate
10use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
11use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
12use crate::{NdArrayDevice, SharedArray};
13
14// Workspace crates
15use burn_tensor::{Shape, TensorData, backend::Backend};
16
17use super::{NdArrayBoolOps, NdArrayOps};
18
19impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
20    for NdArray<E, I, Q>
21where
22    NdArrayTensor: From<SharedArray<E>>,
23    NdArrayTensor: From<SharedArray<I>>,
24{
25    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
26        if !data.dtype.is_bool() {
27            unimplemented!("Unsupported dtype for `bool_from_data`")
28        }
29        NdArrayTensor::from_data(data)
30    }
31
32    async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
33        Ok(tensor.into_data())
34    }
35
36    fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
37        tensor
38    }
39
40    fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
41        NdArrayOps::reshape(tensor.bool(), shape).into()
42    }
43
44    fn bool_slice(tensor: NdArrayTensor, slices: &[burn_tensor::Slice]) -> NdArrayTensor {
45        NdArrayOps::slice(tensor.bool(), slices).into()
46    }
47
48    fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
49        let shape = tensor.shape();
50        let values = tensor.bool().into_iter().collect();
51        NdArray::<E, I>::int_from_data(
52            TensorData::new(values, shape).convert::<I>(),
53            &NdArrayDevice::Cpu,
54        )
55    }
56
57    fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
58        NdArrayDevice::Cpu
59    }
60
61    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
62        Self::bool_zeros(shape, _device)
63    }
64
65    fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
66        let values = vec![false; shape.num_elements()];
67        NdArrayTensor::from_data(TensorData::new(values, shape))
68    }
69
70    fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
71        let values = vec![true; shape.num_elements()];
72        NdArrayTensor::from_data(TensorData::new(values, shape))
73    }
74
75    fn bool_slice_assign(
76        tensor: NdArrayTensor,
77        slices: &[burn_tensor::Slice],
78        value: NdArrayTensor,
79    ) -> NdArrayTensor {
80        NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
81    }
82
83    fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
84        NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
85    }
86
87    fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
88        NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
89    }
90
91    fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
92        tensor.bool().mapv(|a| !a).into_shared().into()
93    }
94
95    fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
96        NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
97    }
98
99    fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
100        NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
101    }
102
103    fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
104        let arr: SharedArray<E> = tensor.bool().mapv(|a| (a as i32).elem()).into_shared();
105        arr.into()
106    }
107
108    fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
109        NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
110    }
111
112    fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
113        tensor.bool().permuted_axes(axes.into_dimension()).into()
114    }
115
116    fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
117        NdArrayOps::expand(tensor.bool(), shape).into()
118    }
119
120    fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
121        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
122            let tensor_bool = tensor.bool();
123            let indices_vec: Vec<usize> = indices
124                .into_iter()
125                .map(|i| i.elem::<i64>() as usize)
126                .collect();
127
128            let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
129            selected.into_shared().into()
130        })
131    }
132
133    fn bool_select_assign(
134        tensor: NdArrayTensor,
135        dim: usize,
136        indices: NdArrayTensor,
137        value: NdArrayTensor,
138    ) -> NdArrayTensor {
139        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
140            let mut output_array = tensor.bool().into_owned();
141            let value_bool = value.bool();
142
143            for (index_value, index) in indices.into_iter().enumerate() {
144                let index_usize = index.elem::<i64>() as usize;
145                let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
146                let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
147                // For boolean tensors, select_assign should use logical OR operation
148                view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
149            }
150            output_array.into_shared().into()
151        })
152    }
153
154    fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
155        NdArrayOps::flip(tensor.bool(), axes).into()
156    }
157
158    fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
159        NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
160    }
161}