burn_ndarray/ops/bool_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_backend::tensor::BoolElem;
5use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
6use burn_backend::{
7    backend::ExecutionError,
8    ops::BoolTensorOps,
9    tensor::{BoolTensor, IntTensor},
10};
11use ndarray::IntoDimension;
12
13// Current crate
14use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
15use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
16use crate::{NdArrayDevice, SharedArray};
17
18// Workspace crates
19use burn_backend::{Shape, TensorData, backend::Backend};
20
21use super::{NdArrayBoolOps, NdArrayOps};
22
23impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
24    for NdArray<E, I, Q>
25where
26    NdArrayTensor: From<SharedArray<E>>,
27    NdArrayTensor: From<SharedArray<I>>,
28{
29    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
30        if !data.dtype.is_bool() {
31            unimplemented!("Unsupported dtype for `bool_from_data`")
32        }
33        NdArrayTensor::from_data(data)
34    }
35
36    async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
37        Ok(tensor.into_data())
38    }
39
40    fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
41        tensor
42    }
43
44    fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
45        NdArrayOps::reshape(tensor.bool(), shape).into()
46    }
47
48    fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
49        NdArrayOps::slice(tensor.bool(), slices).into()
50    }
51
52    fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
53        // Use mapv directly instead of collecting to Vec and going through TensorData
54        let int_array: SharedArray<I> = tensor.bool().mapv(|b| b.elem()).into_shared();
55        int_array.into()
56    }
57
58    fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
59        NdArrayDevice::Cpu
60    }
61
62    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
63        Self::bool_zeros(shape, _device)
64    }
65
66    fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
67        let values = vec![false; shape.num_elements()];
68        NdArrayTensor::from_data(TensorData::new(values, shape))
69    }
70
71    fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
72        let values = vec![true; shape.num_elements()];
73        NdArrayTensor::from_data(TensorData::new(values, shape))
74    }
75
76    fn bool_slice_assign(
77        tensor: NdArrayTensor,
78        slices: &[burn_backend::Slice],
79        value: NdArrayTensor,
80    ) -> NdArrayTensor {
81        NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
82    }
83
84    fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
85        NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
86    }
87
88    fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
89        NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
90    }
91
92    fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
93        tensor.bool().mapv(|a| !a).into_shared().into()
94    }
95
96    fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
97        NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
98    }
99
100    fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
101        NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
102    }
103
104    fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
105        let arr: SharedArray<E> = tensor.bool().mapv(|b| b.elem()).into_shared();
106        arr.into()
107    }
108
109    fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
110        NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
111    }
112
113    fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
114        tensor.bool().permuted_axes(axes.into_dimension()).into()
115    }
116
117    fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
118        NdArrayOps::expand(tensor.bool(), shape).into()
119    }
120
121    fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
122        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
123            let tensor_bool = tensor.bool();
124            let indices_vec: Vec<usize> = indices
125                .into_iter()
126                .map(|i| i.elem::<i64>() as usize)
127                .collect();
128
129            let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
130            selected.into_shared().into()
131        })
132    }
133
134    fn bool_select_or(
135        tensor: NdArrayTensor,
136        dim: usize,
137        indices: NdArrayTensor,
138        value: NdArrayTensor,
139    ) -> NdArrayTensor {
140        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
141            let mut output_array = tensor.bool().into_owned();
142            let value_bool = value.bool();
143
144            for (index_value, index) in indices.into_iter().enumerate() {
145                let index_usize = index.elem::<i64>() as usize;
146                let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
147                let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
148                // For boolean tensors, select_assign should use logical OR operation
149                view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
150            }
151            output_array.into_shared().into()
152        })
153    }
154
155    fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
156        NdArrayOps::flip(tensor.bool(), axes).into()
157    }
158
159    fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
160        NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
161    }
162
163    fn bool_mask_where(
164        tensor: BoolTensor<Self>,
165        mask: BoolTensor<Self>,
166        value: BoolTensor<Self>,
167    ) -> BoolTensor<Self> {
168        NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
169    }
170
171    fn bool_mask_fill(
172        tensor: BoolTensor<Self>,
173        mask: BoolTensor<Self>,
174        value: BoolElem<Self>,
175    ) -> BoolTensor<Self> {
176        NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
177    }
178
179    fn bool_gather(
180        dim: usize,
181        tensor: BoolTensor<Self>,
182        indices: IntTensor<Self>,
183    ) -> BoolTensor<Self> {
184        execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
185            dim,
186            tensor.bool(),
187            indices
188        ))
189    }
190
191    fn bool_scatter_or(
192        dim: usize,
193        tensor: BoolTensor<Self>,
194        indices: IntTensor<Self>,
195        value: BoolTensor<Self>,
196    ) -> BoolTensor<Self> {
197        execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
198            dim,
199            tensor.bool(),
200            indices,
201            value.bool()
202        ))
203    }
204
205    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
206        NdArrayBoolOps::equal_elem(lhs.bool(), rhs).into()
207    }
208
209    fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
210        // Use view() for zero-copy on borrowed storage with short-circuit evaluation
211        let result = NdArrayBoolOps::any_view(tensor.bool().view());
212        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
213    }
214
215    fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
216        // Use view() for zero-copy on borrowed storage with short-circuit evaluation
217        let result = NdArrayBoolOps::all_view(tensor.bool().view());
218        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
219    }
220}