1use alloc::vec;
3use alloc::vec::Vec;
4use burn_backend::Scalar;
5use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
6use burn_backend::{
7 backend::ExecutionError,
8 ops::BoolTensorOps,
9 tensor::{BoolTensor, IntTensor},
10};
11use burn_std::{BoolDType, FloatDType, IntDType};
12use ndarray::IntoDimension;
13
14use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
16use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
17use crate::{
18 NdArrayDevice, SharedArray, execute_with_float_out_dtype, execute_with_int_out_dtype, slice,
19};
20
21use burn_backend::{Shape, TensorData, backend::Backend};
23
24use super::{NdArrayBoolOps, NdArrayOps};
25
26impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
27 for NdArray<E, I, Q>
28where
29 NdArrayTensor: From<SharedArray<E>>,
30 NdArrayTensor: From<SharedArray<I>>,
31{
32 fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
33 if !data.dtype.is_bool() {
34 unimplemented!("Unsupported dtype for `bool_from_data`")
35 }
36 NdArrayTensor::from_data(data)
37 }
38
39 async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
40 Ok(tensor.into_data())
41 }
42
43 fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
44 tensor
45 }
46
47 fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
48 NdArrayOps::reshape(tensor.bool(), shape).into()
49 }
50
51 fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
52 slice!(tensor, slices)
53 }
54
55 fn bool_into_int(tensor: NdArrayTensor, out_dtype: IntDType) -> NdArrayTensor {
56 execute_with_int_out_dtype!(
58 out_dtype,
59 I,
60 tensor.bool().mapv(|b| b.elem::<I>()).into_shared().into()
61 )
62 }
63
64 fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
65 NdArrayDevice::Cpu
66 }
67
68 fn bool_empty(
69 shape: Shape,
70 _device: &<NdArray<E> as Backend>::Device,
71 dtype: BoolDType,
72 ) -> NdArrayTensor {
73 Self::bool_zeros(shape, _device, dtype)
74 }
75
76 fn bool_zeros(
77 shape: Shape,
78 _device: &<NdArray<E> as Backend>::Device,
79 _dtype: BoolDType,
80 ) -> NdArrayTensor {
81 let values = vec![false; shape.num_elements()];
82 NdArrayTensor::from_data(TensorData::new(values, shape))
83 }
84
85 fn bool_ones(
86 shape: Shape,
87 _device: &<NdArray<E> as Backend>::Device,
88 _dtype: BoolDType,
89 ) -> NdArrayTensor {
90 let values = vec![true; shape.num_elements()];
91 NdArrayTensor::from_data(TensorData::new(values, shape))
92 }
93
94 fn bool_slice_assign(
95 tensor: NdArrayTensor,
96 slices: &[burn_backend::Slice],
97 value: NdArrayTensor,
98 ) -> NdArrayTensor {
99 NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
100 }
101
102 fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
103 NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
104 }
105
106 fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
107 NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
108 }
109
110 fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
111 tensor.bool().mapv(|a| !a).into_shared().into()
112 }
113
114 fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
115 NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
116 }
117
118 fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
119 NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
120 }
121
122 fn bool_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
123 execute_with_float_out_dtype!(
124 out_dtype,
125 E,
126 tensor.bool().mapv(|b| b.elem::<E>()).into_shared().into()
127 )
128 }
129
130 fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
131 NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
132 }
133
134 fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
135 tensor.bool().permuted_axes(axes.into_dimension()).into()
136 }
137
138 fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
139 NdArrayOps::expand(tensor.bool(), shape).into()
140 }
141
142 fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
143 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
144 let tensor_bool = tensor.bool();
145 let indices_vec: Vec<usize> = indices
146 .into_iter()
147 .map(|i| i.elem::<i64>() as usize)
148 .collect();
149
150 let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
151 selected.into_shared().into()
152 })
153 }
154
155 fn bool_select_or(
156 tensor: NdArrayTensor,
157 dim: usize,
158 indices: NdArrayTensor,
159 value: NdArrayTensor,
160 ) -> NdArrayTensor {
161 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
162 let mut output_array = tensor.bool().into_owned();
163 let value_bool = value.bool();
164
165 for (index_value, index) in indices.into_iter().enumerate() {
166 let index_usize = index.elem::<i64>() as usize;
167 let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
168 let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
169 view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
171 }
172 output_array.into_shared().into()
173 })
174 }
175
176 fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
177 NdArrayOps::flip(tensor.bool(), axes).into()
178 }
179
180 fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
181 NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
182 }
183
184 fn bool_mask_where(
185 tensor: BoolTensor<Self>,
186 mask: BoolTensor<Self>,
187 value: BoolTensor<Self>,
188 ) -> BoolTensor<Self> {
189 NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
190 }
191
192 fn bool_mask_fill(
193 tensor: BoolTensor<Self>,
194 mask: BoolTensor<Self>,
195 value: Scalar,
196 ) -> BoolTensor<Self> {
197 NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
198 }
199
200 fn bool_gather(
201 dim: usize,
202 tensor: BoolTensor<Self>,
203 indices: IntTensor<Self>,
204 ) -> BoolTensor<Self> {
205 execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
206 dim,
207 tensor.bool(),
208 indices
209 ))
210 }
211
212 fn bool_scatter_or(
213 dim: usize,
214 tensor: BoolTensor<Self>,
215 indices: IntTensor<Self>,
216 value: BoolTensor<Self>,
217 ) -> BoolTensor<Self> {
218 execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
219 dim,
220 tensor.bool(),
221 indices,
222 value.bool()
223 ))
224 }
225
226 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
227 NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()
228 }
229
230 fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
231 let result = NdArrayBoolOps::any_view(tensor.bool().view());
233 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
234 }
235
236 fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
237 let result = NdArrayBoolOps::all_view(tensor.bool().view());
239 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
240 }
241}