1use alloc::vec;
3use alloc::vec::Vec;
4use burn_backend::tensor::BoolElem;
5use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
6use burn_backend::{
7 backend::ExecutionError,
8 ops::BoolTensorOps,
9 tensor::{BoolTensor, IntTensor},
10};
11use ndarray::IntoDimension;
12
13use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
15use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
16use crate::{NdArrayDevice, SharedArray};
17
18use burn_backend::{Shape, TensorData, backend::Backend};
20
21use super::{NdArrayBoolOps, NdArrayOps};
22
23impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
24 for NdArray<E, I, Q>
25where
26 NdArrayTensor: From<SharedArray<E>>,
27 NdArrayTensor: From<SharedArray<I>>,
28{
29 fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
30 if !data.dtype.is_bool() {
31 unimplemented!("Unsupported dtype for `bool_from_data`")
32 }
33 NdArrayTensor::from_data(data)
34 }
35
36 async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
37 Ok(tensor.into_data())
38 }
39
40 fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
41 tensor
42 }
43
44 fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
45 NdArrayOps::reshape(tensor.bool(), shape).into()
46 }
47
48 fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
49 NdArrayOps::slice(tensor.bool(), slices).into()
50 }
51
52 fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
53 let int_array: SharedArray<I> = tensor.bool().mapv(|b| b.elem()).into_shared();
55 int_array.into()
56 }
57
58 fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
59 NdArrayDevice::Cpu
60 }
61
62 fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
63 Self::bool_zeros(shape, _device)
64 }
65
66 fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
67 let values = vec![false; shape.num_elements()];
68 NdArrayTensor::from_data(TensorData::new(values, shape))
69 }
70
71 fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
72 let values = vec![true; shape.num_elements()];
73 NdArrayTensor::from_data(TensorData::new(values, shape))
74 }
75
76 fn bool_slice_assign(
77 tensor: NdArrayTensor,
78 slices: &[burn_backend::Slice],
79 value: NdArrayTensor,
80 ) -> NdArrayTensor {
81 NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
82 }
83
84 fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
85 NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
86 }
87
88 fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
89 NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
90 }
91
92 fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
93 tensor.bool().mapv(|a| !a).into_shared().into()
94 }
95
96 fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
97 NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
98 }
99
100 fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
101 NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
102 }
103
104 fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
105 let arr: SharedArray<E> = tensor.bool().mapv(|b| b.elem()).into_shared();
106 arr.into()
107 }
108
109 fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
110 NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
111 }
112
113 fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
114 tensor.bool().permuted_axes(axes.into_dimension()).into()
115 }
116
117 fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
118 NdArrayOps::expand(tensor.bool(), shape).into()
119 }
120
121 fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
122 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
123 let tensor_bool = tensor.bool();
124 let indices_vec: Vec<usize> = indices
125 .into_iter()
126 .map(|i| i.elem::<i64>() as usize)
127 .collect();
128
129 let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
130 selected.into_shared().into()
131 })
132 }
133
134 fn bool_select_or(
135 tensor: NdArrayTensor,
136 dim: usize,
137 indices: NdArrayTensor,
138 value: NdArrayTensor,
139 ) -> NdArrayTensor {
140 execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
141 let mut output_array = tensor.bool().into_owned();
142 let value_bool = value.bool();
143
144 for (index_value, index) in indices.into_iter().enumerate() {
145 let index_usize = index.elem::<i64>() as usize;
146 let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
147 let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
148 view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
150 }
151 output_array.into_shared().into()
152 })
153 }
154
155 fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
156 NdArrayOps::flip(tensor.bool(), axes).into()
157 }
158
159 fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
160 NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
161 }
162
163 fn bool_mask_where(
164 tensor: BoolTensor<Self>,
165 mask: BoolTensor<Self>,
166 value: BoolTensor<Self>,
167 ) -> BoolTensor<Self> {
168 NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
169 }
170
171 fn bool_mask_fill(
172 tensor: BoolTensor<Self>,
173 mask: BoolTensor<Self>,
174 value: BoolElem<Self>,
175 ) -> BoolTensor<Self> {
176 NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
177 }
178
179 fn bool_gather(
180 dim: usize,
181 tensor: BoolTensor<Self>,
182 indices: IntTensor<Self>,
183 ) -> BoolTensor<Self> {
184 execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
185 dim,
186 tensor.bool(),
187 indices
188 ))
189 }
190
191 fn bool_scatter_or(
192 dim: usize,
193 tensor: BoolTensor<Self>,
194 indices: IntTensor<Self>,
195 value: BoolTensor<Self>,
196 ) -> BoolTensor<Self> {
197 execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
198 dim,
199 tensor.bool(),
200 indices,
201 value.bool()
202 ))
203 }
204
205 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
206 NdArrayBoolOps::equal_elem(lhs.bool(), rhs).into()
207 }
208
209 fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
210 let result = NdArrayBoolOps::any_view(tensor.bool().view());
212 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
213 }
214
215 fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
216 let result = NdArrayBoolOps::all_view(tensor.bool().view());
218 NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
219 }
220}