1use burn_backend::{
2 ExecutionError, Scalar, TensorData,
3 ops::BoolTensorOps,
4 tensor::{BoolTensor, FloatTensor, IntTensor},
5};
6use burn_std::{Shape, Slice};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl BoolTensorOps<Self> for Dispatch {
12 fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
13 creation_op!(Bool, device, |device| B::bool_empty(shape, device))
14 }
15
16 fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
17 creation_op!(Bool, device, |device| B::bool_zeros(shape, device))
18 }
19
20 fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
21 creation_op!(Bool, device, |device| B::bool_ones(shape, device))
22 }
23
24 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
25 unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await)
26 }
27
28 fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor<Self> {
29 creation_op!(Bool, device, |device| B::bool_from_data(data, device))
30 }
31
32 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
33 unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int)
34 }
35
36 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
37 unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float)
38 }
39
40 fn bool_device(tensor: &BoolTensor<Self>) -> DispatchDevice {
41 tensor.device()
42 }
43
44 fn bool_to_device(tensor: BoolTensor<Self>, device: &DispatchDevice) -> BoolTensor<Self> {
45 to_device!(
46 Bool,
47 bool,
48 tensor,
49 device,
50 bool_to_device,
51 |inner, device| {
52 let data =
53 burn_backend::read_sync(B1::bool_into_data(inner)).expect("Should read data");
54 B2::bool_from_data(data, device)
55 }
56 )
57 }
58
59 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
60 unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool)
61 }
62
63 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
64 unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool)
65 }
66
67 fn bool_slice_assign(
68 tensor: BoolTensor<Self>,
69 slices: &[Slice],
70 value: BoolTensor<Self>,
71 ) -> BoolTensor<Self> {
72 binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool)
73 }
74
75 fn bool_mask_where(
76 tensor: BoolTensor<Self>,
77 mask: BoolTensor<Self>,
78 value: BoolTensor<Self>,
79 ) -> BoolTensor<Self> {
80 multi_op!(
81 inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool,
82 B::bool_mask_where(tensor, mask, value)
83 )
84 }
85
86 fn bool_mask_fill(
87 tensor: BoolTensor<Self>,
88 mask: BoolTensor<Self>,
89 value: Scalar,
90 ) -> BoolTensor<Self> {
91 binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool)
92 }
93
94 fn bool_gather(
95 dim: usize,
96 tensor: BoolTensor<Self>,
97 indices: IntTensor<Self>,
98 ) -> BoolTensor<Self> {
99 binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool)
100 }
101
102 fn bool_scatter_or(
103 dim: usize,
104 tensor: BoolTensor<Self>,
105 indices: IntTensor<Self>,
106 value: BoolTensor<Self>,
107 ) -> BoolTensor<Self> {
108 multi_op!(
109 inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
110 B::bool_scatter_or(dim, tensor, indices, value)
111 )
112 }
113
114 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
115 binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool)
116 }
117
118 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
119 unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool)
120 }
121
122 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
123 unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool)
124 }
125
126 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
127 binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool)
128 }
129
130 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
131 binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool)
132 }
133
134 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
135 unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool)
136 }
137
138 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
139 unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool)
140 }
141
142 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
143 unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool)
144 }
145
146 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
147 unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool)
148 }
149
150 fn bool_unfold(
151 tensor: BoolTensor<Self>,
152 dim: usize,
153 size: usize,
154 step: usize,
155 ) -> BoolTensor<Self> {
156 unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool)
157 }
158
159 fn bool_select(
160 tensor: BoolTensor<Self>,
161 dim: usize,
162 indices: IntTensor<Self>,
163 ) -> BoolTensor<Self> {
164 binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool)
165 }
166
167 fn bool_select_or(
168 tensor: BoolTensor<Self>,
169 dim: usize,
170 indices: IntTensor<Self>,
171 value: BoolTensor<Self>,
172 ) -> BoolTensor<Self> {
173 multi_op!(
174 inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
175 B::bool_select_or(tensor, dim, indices, value)
176 )
177 }
178
179 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
180 unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool)
181 }
182
183 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
184 vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool)
185 }
186
187 fn bool_not_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
188 binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool)
189 }
190
191 fn bool_not_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
192 unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool)
193 }
194
195 fn bool_xor(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
196 binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool)
197 }
198
199 fn bool_transpose(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
200 unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool)
201 }
202
203 fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
204 unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool)
205 }
206
207 fn bool_any_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
208 unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool)
209 }
210
211 fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
212 unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool)
213 }
214
215 fn bool_all_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
216 unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool)
217 }
218
219 async fn bool_argwhere(tensor: BoolTensor<Self>) -> IntTensor<Self> {
220 unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int)
221 }
222}