1use alloc::vec;
3use alloc::vec::Vec;
4use burn_backend::Scalar;
5use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
6use burn_backend::{
7 backend::ExecutionError,
8 ops::BoolTensorOps,
9 tensor::{BoolTensor, IntTensor},
10};
11use burn_std::{BoolDType, FloatDType, IntDType};
12use ndarray::IntoDimension;
13
14use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
16use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
17use crate::{
18 NdArrayDevice, SharedArray, execute_with_float_out_dtype, execute_with_int_out_dtype, slice,
19};
20
21use burn_backend::{Shape, TensorData};
23
24use super::{NdArrayBoolOps, NdArrayOps};
25
26impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
27 for NdArray<E, I, Q>
28where
29 NdArrayTensor: From<SharedArray<E>>,
30 NdArrayTensor: From<SharedArray<I>>,
31{
32 fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
33 if !data.dtype.is_bool() {
34 unimplemented!("Unsupported dtype for `bool_from_data`")
35 }
36 NdArrayTensor::from_data(data)
37 }
38
39 async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
40 Ok(tensor.into_data())
41 }
42
43 fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
44 tensor
45 }
46
47 fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
48 NdArrayOps::reshape(tensor.bool(), shape).into()
49 }
50
51 fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
52 slice!(tensor, slices)
53 }
54
55 fn bool_into_int(tensor: NdArrayTensor, out_dtype: IntDType) -> NdArrayTensor {
56 execute_with_int_out_dtype!(
58 out_dtype,
59 I,
60 tensor.bool().mapv(|b| b.elem::<I>()).into_shared().into()
61 )
62 }
63
64 fn bool_device(_tensor: &NdArrayTensor) -> NdArrayDevice {
65 NdArrayDevice::Cpu
66 }
67
68 fn bool_empty(shape: Shape, _device: &NdArrayDevice, dtype: BoolDType) -> NdArrayTensor {
69 Self::bool_zeros(shape, _device, dtype)
70 }
71
72 fn bool_zeros(shape: Shape, _device: &NdArrayDevice, _dtype: BoolDType) -> NdArrayTensor {
73 let values = vec![false; shape.num_elements()];
74 NdArrayTensor::from_data(TensorData::new(values, shape))
75 }
76
77 fn bool_ones(shape: Shape, _device: &NdArrayDevice, _dtype: BoolDType) -> NdArrayTensor {
78 let values = vec![true; shape.num_elements()];
79 NdArrayTensor::from_data(TensorData::new(values, shape))
80 }
81
82 fn bool_slice_assign(
83 tensor: NdArrayTensor,
84 slices: &[burn_backend::Slice],
85 value: NdArrayTensor,
86 ) -> NdArrayTensor {
87 NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
88 }
89
90 fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
91 NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
92 }
93
94 fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
95 NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
96 }
97
98 fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
99 tensor.bool().mapv(|a| !a).into_shared().into()
100 }
101
102 fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
103 NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
104 }
105
106 fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
107 NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
108 }
109
110 fn bool_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
111 execute_with_float_out_dtype!(
112 out_dtype,
113 E,
114 tensor.bool().mapv(|b| b.elem::<E>()).into_shared().into()
115 )
116 }
117
118 fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
119 NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
120 }
121
122 fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
123 tensor.bool().permuted_axes(axes.into_dimension()).into()
124 }
125
126 fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
127 NdArrayOps::expand(tensor.bool(), shape).into()
128 }
129
130 fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
131 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
132 let tensor_bool = tensor.bool();
133 let indices_vec: Vec<usize> = indices
134 .into_iter()
135 .map(|i| i.elem::<i64>() as usize)
136 .collect();
137
138 let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
139 selected.into_shared().into()
140 })
141 }
142
143 fn bool_select_or(
144 tensor: NdArrayTensor,
145 dim: usize,
146 indices: NdArrayTensor,
147 value: NdArrayTensor,
148 ) -> NdArrayTensor {
149 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
150 let mut output_array = tensor.bool().into_owned();
151 let value_bool = value.bool();
152
153 for (index_value, index) in indices.into_iter().enumerate() {
154 let index_usize = index.elem::<i64>() as usize;
155 let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
156 let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
157 view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
159 }
160 output_array.into_shared().into()
161 })
162 }
163
164 fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
165 NdArrayOps::flip(tensor.bool(), axes).into()
166 }
167
168 fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
169 NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
170 }
171
172 fn bool_mask_where(
173 tensor: BoolTensor<Self>,
174 mask: BoolTensor<Self>,
175 value: BoolTensor<Self>,
176 ) -> BoolTensor<Self> {
177 NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
178 }
179
180 fn bool_mask_fill(
181 tensor: BoolTensor<Self>,
182 mask: BoolTensor<Self>,
183 value: Scalar,
184 ) -> BoolTensor<Self> {
185 NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
186 }
187
188 fn bool_gather(
189 dim: usize,
190 tensor: BoolTensor<Self>,
191 indices: IntTensor<Self>,
192 ) -> BoolTensor<Self> {
193 execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
194 dim,
195 tensor.bool(),
196 indices
197 ))
198 }
199
200 fn bool_scatter_or(
201 dim: usize,
202 tensor: BoolTensor<Self>,
203 indices: IntTensor<Self>,
204 value: BoolTensor<Self>,
205 ) -> BoolTensor<Self> {
206 execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
207 dim,
208 tensor.bool(),
209 indices,
210 value.bool()
211 ))
212 }
213
214 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
215 NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()
216 }
217
218 fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
219 let result = NdArrayBoolOps::any_view(tensor.bool().view());
221 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
222 }
223
224 fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
225 let result = NdArrayBoolOps::all_view(tensor.bool().view());
227 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
228 }
229}