Skip to main content

burn_dispatch/ops/
bool_tensor.rs

1use burn_backend::{
2    ExecutionError, Scalar, TensorData,
3    ops::BoolTensorOps,
4    tensor::{BoolTensor, FloatTensor, IntTensor},
5};
6use burn_std::{Shape, Slice};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl BoolTensorOps<Self> for Dispatch {
12    fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
13        creation_op!(Bool, device, |device| B::bool_empty(shape, device))
14    }
15
16    fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
17        creation_op!(Bool, device, |device| B::bool_zeros(shape, device))
18    }
19
20    fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
21        creation_op!(Bool, device, |device| B::bool_ones(shape, device))
22    }
23
24    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
25        unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await)
26    }
27
28    fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor<Self> {
29        creation_op!(Bool, device, |device| B::bool_from_data(data, device))
30    }
31
32    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
33        unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int)
34    }
35
36    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
37        unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float)
38    }
39
40    fn bool_device(tensor: &BoolTensor<Self>) -> DispatchDevice {
41        tensor.device()
42    }
43
44    fn bool_to_device(tensor: BoolTensor<Self>, device: &DispatchDevice) -> BoolTensor<Self> {
45        to_device!(
46            Bool,
47            bool,
48            tensor,
49            device,
50            bool_to_device,
51            |inner, device| {
52                let data =
53                    burn_backend::read_sync(B1::bool_into_data(inner)).expect("Should read data");
54                B2::bool_from_data(data, device)
55            }
56        )
57    }
58
59    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
60        unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool)
61    }
62
63    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
64        unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool)
65    }
66
67    fn bool_slice_assign(
68        tensor: BoolTensor<Self>,
69        slices: &[Slice],
70        value: BoolTensor<Self>,
71    ) -> BoolTensor<Self> {
72        binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool)
73    }
74
75    fn bool_mask_where(
76        tensor: BoolTensor<Self>,
77        mask: BoolTensor<Self>,
78        value: BoolTensor<Self>,
79    ) -> BoolTensor<Self> {
80        multi_op!(
81            inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool,
82            B::bool_mask_where(tensor, mask, value)
83        )
84    }
85
86    fn bool_mask_fill(
87        tensor: BoolTensor<Self>,
88        mask: BoolTensor<Self>,
89        value: Scalar,
90    ) -> BoolTensor<Self> {
91        binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool)
92    }
93
94    fn bool_gather(
95        dim: usize,
96        tensor: BoolTensor<Self>,
97        indices: IntTensor<Self>,
98    ) -> BoolTensor<Self> {
99        binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool)
100    }
101
102    fn bool_scatter_or(
103        dim: usize,
104        tensor: BoolTensor<Self>,
105        indices: IntTensor<Self>,
106        value: BoolTensor<Self>,
107    ) -> BoolTensor<Self> {
108        multi_op!(
109            inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
110            B::bool_scatter_or(dim, tensor, indices, value)
111        )
112    }
113
114    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
115        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool)
116    }
117
118    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
119        unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool)
120    }
121
122    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
123        unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool)
124    }
125
126    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
127        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool)
128    }
129
130    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
131        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool)
132    }
133
134    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
135        unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool)
136    }
137
138    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
139        unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool)
140    }
141
142    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
143        unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool)
144    }
145
146    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
147        unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool)
148    }
149
150    fn bool_unfold(
151        tensor: BoolTensor<Self>,
152        dim: usize,
153        size: usize,
154        step: usize,
155    ) -> BoolTensor<Self> {
156        unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool)
157    }
158
159    fn bool_select(
160        tensor: BoolTensor<Self>,
161        dim: usize,
162        indices: IntTensor<Self>,
163    ) -> BoolTensor<Self> {
164        binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool)
165    }
166
167    fn bool_select_or(
168        tensor: BoolTensor<Self>,
169        dim: usize,
170        indices: IntTensor<Self>,
171        value: BoolTensor<Self>,
172    ) -> BoolTensor<Self> {
173        multi_op!(
174            inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
175            B::bool_select_or(tensor, dim, indices, value)
176        )
177    }
178
179    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
180        unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool)
181    }
182
183    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
184        vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool)
185    }
186
187    fn bool_not_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
188        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool)
189    }
190
191    fn bool_not_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
192        unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool)
193    }
194
195    fn bool_xor(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
196        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool)
197    }
198
199    fn bool_transpose(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
200        unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool)
201    }
202
203    fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
204        unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool)
205    }
206
207    fn bool_any_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
208        unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool)
209    }
210
211    fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
212        unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool)
213    }
214
215    fn bool_all_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
216        unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool)
217    }
218
219    async fn bool_argwhere(tensor: BoolTensor<Self>) -> IntTensor<Self> {
220        unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int)
221    }
222}