Skip to main content

burn_ndarray/ops/
bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_backend::Scalar;
5use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
6use burn_backend::{
7    backend::ExecutionError,
8    ops::BoolTensorOps,
9    tensor::{BoolTensor, IntTensor},
10};
11use burn_std::{BoolDType, FloatDType, IntDType};
12use ndarray::IntoDimension;
13
14// Current crate
15use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
16use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
17use crate::{
18    NdArrayDevice, SharedArray, execute_with_float_out_dtype, execute_with_int_out_dtype, slice,
19};
20
21// Workspace crates
22use burn_backend::{Shape, TensorData, backend::Backend};
23
24use super::{NdArrayBoolOps, NdArrayOps};
25
26impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
27    for NdArray<E, I, Q>
28where
29    NdArrayTensor: From<SharedArray<E>>,
30    NdArrayTensor: From<SharedArray<I>>,
31{
32    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
33        if !data.dtype.is_bool() {
34            unimplemented!("Unsupported dtype for `bool_from_data`")
35        }
36        NdArrayTensor::from_data(data)
37    }
38
39    async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
40        Ok(tensor.into_data())
41    }
42
43    fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
44        tensor
45    }
46
47    fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
48        NdArrayOps::reshape(tensor.bool(), shape).into()
49    }
50
51    fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
52        slice!(tensor, slices)
53    }
54
55    fn bool_into_int(tensor: NdArrayTensor, out_dtype: IntDType) -> NdArrayTensor {
56        // Use mapv directly instead of collecting to Vec and going through TensorData
57        execute_with_int_out_dtype!(
58            out_dtype,
59            I,
60            tensor.bool().mapv(|b| b.elem::<I>()).into_shared().into()
61        )
62    }
63
64    fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
65        NdArrayDevice::Cpu
66    }
67
68    fn bool_empty(
69        shape: Shape,
70        _device: &<NdArray<E> as Backend>::Device,
71        dtype: BoolDType,
72    ) -> NdArrayTensor {
73        Self::bool_zeros(shape, _device, dtype)
74    }
75
76    fn bool_zeros(
77        shape: Shape,
78        _device: &<NdArray<E> as Backend>::Device,
79        _dtype: BoolDType,
80    ) -> NdArrayTensor {
81        let values = vec![false; shape.num_elements()];
82        NdArrayTensor::from_data(TensorData::new(values, shape))
83    }
84
85    fn bool_ones(
86        shape: Shape,
87        _device: &<NdArray<E> as Backend>::Device,
88        _dtype: BoolDType,
89    ) -> NdArrayTensor {
90        let values = vec![true; shape.num_elements()];
91        NdArrayTensor::from_data(TensorData::new(values, shape))
92    }
93
94    fn bool_slice_assign(
95        tensor: NdArrayTensor,
96        slices: &[burn_backend::Slice],
97        value: NdArrayTensor,
98    ) -> NdArrayTensor {
99        NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
100    }
101
102    fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
103        NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
104    }
105
106    fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
107        NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
108    }
109
110    fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
111        tensor.bool().mapv(|a| !a).into_shared().into()
112    }
113
114    fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
115        NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
116    }
117
118    fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
119        NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
120    }
121
122    fn bool_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
123        execute_with_float_out_dtype!(
124            out_dtype,
125            E,
126            tensor.bool().mapv(|b| b.elem::<E>()).into_shared().into()
127        )
128    }
129
130    fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
131        NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
132    }
133
134    fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
135        tensor.bool().permuted_axes(axes.into_dimension()).into()
136    }
137
138    fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
139        NdArrayOps::expand(tensor.bool(), shape).into()
140    }
141
142    fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
143        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
144            let tensor_bool = tensor.bool();
145            let indices_vec: Vec<usize> = indices
146                .into_iter()
147                .map(|i| i.elem::<i64>() as usize)
148                .collect();
149
150            let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
151            selected.into_shared().into()
152        })
153    }
154
155    fn bool_select_or(
156        tensor: NdArrayTensor,
157        dim: usize,
158        indices: NdArrayTensor,
159        value: NdArrayTensor,
160    ) -> NdArrayTensor {
161        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
162            let mut output_array = tensor.bool().into_owned();
163            let value_bool = value.bool();
164
165            for (index_value, index) in indices.into_iter().enumerate() {
166                let index_usize = index.elem::<i64>() as usize;
167                let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
168                let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
169                // For boolean tensors, select_assign should use logical OR operation
170                view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
171            }
172            output_array.into_shared().into()
173        })
174    }
175
176    fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
177        NdArrayOps::flip(tensor.bool(), axes).into()
178    }
179
180    fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
181        NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
182    }
183
184    fn bool_mask_where(
185        tensor: BoolTensor<Self>,
186        mask: BoolTensor<Self>,
187        value: BoolTensor<Self>,
188    ) -> BoolTensor<Self> {
189        NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
190    }
191
192    fn bool_mask_fill(
193        tensor: BoolTensor<Self>,
194        mask: BoolTensor<Self>,
195        value: Scalar,
196    ) -> BoolTensor<Self> {
197        NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
198    }
199
200    fn bool_gather(
201        dim: usize,
202        tensor: BoolTensor<Self>,
203        indices: IntTensor<Self>,
204    ) -> BoolTensor<Self> {
205        execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
206            dim,
207            tensor.bool(),
208            indices
209        ))
210    }
211
212    fn bool_scatter_or(
213        dim: usize,
214        tensor: BoolTensor<Self>,
215        indices: IntTensor<Self>,
216        value: BoolTensor<Self>,
217    ) -> BoolTensor<Self> {
218        execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
219            dim,
220            tensor.bool(),
221            indices,
222            value.bool()
223        ))
224    }
225
226    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
227        NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()
228    }
229
230    fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
231        // Use view() for zero-copy on borrowed storage with short-circuit evaluation
232        let result = NdArrayBoolOps::any_view(tensor.bool().view());
233        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
234    }
235
236    fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
237        // Use view() for zero-copy on borrowed storage with short-circuit evaluation
238        let result = NdArrayBoolOps::all_view(tensor.bool().view());
239        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
240    }
241}