Skip to main content

burn_router/ops/
bool_tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::ops::BoolTensorOps;
6use burn_backend::tensor::{
7    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
8};
9use burn_backend::{Element, Shape, Slice, TensorData};
10use burn_ir::{
11    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
12    GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
13    PermuteOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr,
14    SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
15};
16
17impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
18    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
19        let client = get_client::<R>(device);
20        let desc =
21            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
22
23        client
24            .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
25            .output()
26    }
27
28    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
29        let client = get_client::<R>(device);
30        let desc =
31            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
32
33        client
34            .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
35            .output()
36    }
37
38    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
39        let client = get_client::<R>(device);
40        let desc =
41            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
42
43        client
44            .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
45            .output()
46    }
47
48    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
49        tensor.into_data().await
50    }
51
52    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
53        let client = get_client::<R>(device);
54        let out = client.register_tensor_data(data);
55        let desc = InitOperationIr {
56            out: out.to_ir_out(),
57        };
58
59        // Call register op when output is already initialized
60        client.register_op(OperationIr::Init(desc));
61
62        out
63    }
64
65    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
66        let client = tensor.client.clone();
67        let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {
68            client.create_empty_handle()
69        });
70
71        client
72            .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
73            .output()
74    }
75
76    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
77        let client = tensor.client.clone();
78        let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {
79            client.create_empty_handle()
80        });
81
82        client
83            .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
84            .output()
85    }
86
87    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
88        tensor.client.device()
89    }
90
91    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
92        if &tensor.client.device() == device {
93            return tensor;
94        }
95        R::change_client_backend(tensor, device)
96    }
97
98    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
99        let client = tensor.client.clone();
100        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
101
102        client
103            .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
104            .output()
105    }
106
107    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
108        let client = tensor.client.clone();
109        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
110            client.create_empty_handle()
111        });
112
113        client
114            .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
115            .output()
116    }
117
118    fn bool_slice_assign(
119        tensor: BoolTensor<Self>,
120        slices: &[burn_backend::Slice],
121        value: BoolTensor<Self>,
122    ) -> BoolTensor<Self> {
123        let client = tensor.client.clone();
124        let desc =
125            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
126                client.create_empty_handle()
127            });
128
129        client
130            .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
131            .output()
132    }
133
134    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
135        let client = lhs.client.clone();
136        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
137            client.create_empty_handle()
138        });
139
140        client
141            .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
142            .output()
143    }
144
145    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
146        let client = tensor.client.clone();
147        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
148
149        client
150            .register(OperationIr::Bool(BoolOperationIr::Not(desc)))
151            .output()
152    }
153
154    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
155        let client = lhs.client.clone();
156        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
157            client.create_empty_handle()
158        });
159
160        client
161            .register(OperationIr::Bool(BoolOperationIr::And(desc)))
162            .output()
163    }
164
165    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
166        let client = lhs.client.clone();
167        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
168            client.create_empty_handle()
169        });
170
171        client
172            .register(OperationIr::Bool(BoolOperationIr::Or(desc)))
173            .output()
174    }
175
176    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
177        let client = tensor.client.clone();
178        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
179            client.create_empty_handle()
180        });
181
182        client
183            .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
184            .output()
185    }
186
187    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
188        let client = tensor.client.clone();
189        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
190            client.create_empty_handle()
191        });
192
193        client
194            .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
195            .output()
196    }
197
198    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
199        let client = tensor.client.clone();
200        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
201            client.create_empty_handle()
202        });
203
204        client
205            .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
206            .output()
207    }
208
209    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
210        let client = tensor.client.clone();
211        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
212
213        client
214            .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
215            .output()
216    }
217
218    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
219        let client = tensors.first().unwrap().client.clone();
220        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
221        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
222
223        client
224            .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
225            .output()
226    }
227
228    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
229        let client = tensor.client.clone();
230        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
231            client.create_empty_handle()
232        });
233
234        client
235            .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
236            .output()
237    }
238
239    fn bool_unfold(
240        tensor: BoolTensor<Self>,
241        dim: usize,
242        size: usize,
243        step: usize,
244    ) -> BoolTensor<Self> {
245        let client = tensor.client.clone();
246        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
247            client.create_empty_handle()
248        });
249
250        client
251            .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
252            .output()
253    }
254
255    fn bool_mask_where(
256        tensor: BoolTensor<Self>,
257        mask: BoolTensor<Self>,
258        value: BoolTensor<Self>,
259    ) -> BoolTensor<Self> {
260        let client = tensor.client.clone();
261        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
262            client.create_empty_handle()
263        });
264
265        client
266            .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
267            .output()
268    }
269
270    fn bool_mask_fill(
271        tensor: BoolTensor<Self>,
272        mask: BoolTensor<Self>,
273        value: burn_backend::tensor::BoolElem<Self>,
274    ) -> BoolTensor<Self> {
275        let client = tensor.client.clone();
276        let value = ScalarIr::with_dtype(value, &tensor.dtype);
277        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
278            client.create_empty_handle()
279        });
280
281        client
282            .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
283            .output()
284    }
285
286    fn bool_gather(
287        dim: usize,
288        tensor: BoolTensor<Self>,
289        indices: IntTensor<Self>,
290    ) -> BoolTensor<Self> {
291        let client = tensor.client.clone();
292        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
293            client.create_empty_handle()
294        });
295
296        client
297            .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
298            .output()
299    }
300
301    fn bool_scatter_or(
302        dim: usize,
303        tensor: BoolTensor<Self>,
304        indices: IntTensor<Self>,
305        value: BoolTensor<Self>,
306    ) -> BoolTensor<Self> {
307        let client = tensor.client.clone();
308        let desc = ScatterOpIr::create(
309            tensor.into_ir(),
310            dim,
311            indices.into_ir(),
312            value.into_ir(),
313            IndexingUpdateOp::Add,
314            || client.create_empty_handle(),
315        );
316
317        client
318            .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
319            .output()
320    }
321
322    fn bool_equal_elem(
323        lhs: BoolTensor<Self>,
324        rhs: burn_backend::tensor::BoolElem<Self>,
325    ) -> BoolTensor<Self> {
326        let client = lhs.client.clone();
327        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
328        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
329            client.create_empty_handle()
330        });
331
332        client
333            .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
334            .output()
335    }
336}