Skip to main content

burn_router/ops/
bool_tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_std::{BoolDType, FloatDType, IntDType};
4
5use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
6use burn_backend::ops::BoolTensorOps;
7use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor};
8use burn_backend::{Scalar, Shape, Slice, TensorData};
9use burn_ir::{
10    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
11    GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
12    PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr,
13    SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
14};
15
16impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
17    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
18        let client = get_client::<R>(device);
19        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
20
21        client
22            .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
23            .output()
24    }
25
26    fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
27        let client = get_client::<R>(device);
28        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
29
30        client
31            .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
32            .output()
33    }
34
35    fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
36        let client = get_client::<R>(device);
37        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
38
39        client
40            .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
41            .output()
42    }
43
44    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
45        tensor.into_data().await
46    }
47
48    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
49        let client = get_client::<R>(device);
50        let out = client.register_tensor_data(data);
51        let desc = InitOperationIr {
52            out: out.to_ir_out(),
53        };
54
55        // Call register op when output is already initialized
56        client.register_op(OperationIr::Init(desc));
57
58        out
59    }
60
61    fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
62        let client = tensor.client.clone();
63        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
64            client.create_empty_handle()
65        });
66
67        client
68            .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
69            .output()
70    }
71
72    fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
73        let client = tensor.client.clone();
74        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
75            client.create_empty_handle()
76        });
77
78        client
79            .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
80            .output()
81    }
82
83    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
84        tensor.client.device()
85    }
86
87    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
88        if &tensor.client.device() == device {
89            return tensor;
90        }
91        R::change_client_backend(tensor, device)
92    }
93
94    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
95        let client = tensor.client.clone();
96        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
97
98        client
99            .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
100            .output()
101    }
102
103    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
104        let client = tensor.client.clone();
105        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
106            client.create_empty_handle()
107        });
108
109        client
110            .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
111            .output()
112    }
113
114    fn bool_slice_assign(
115        tensor: BoolTensor<Self>,
116        slices: &[burn_backend::Slice],
117        value: BoolTensor<Self>,
118    ) -> BoolTensor<Self> {
119        let client = tensor.client.clone();
120        let desc =
121            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
122                client.create_empty_handle()
123            });
124
125        client
126            .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
127            .output()
128    }
129
130    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
131        let client = lhs.client.clone();
132        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
133            client.create_empty_handle()
134        });
135
136        client
137            .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
138            .output()
139    }
140
141    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
142        let client = tensor.client.clone();
143        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
144
145        client
146            .register(OperationIr::Bool(BoolOperationIr::Not(desc)))
147            .output()
148    }
149
150    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
151        let client = lhs.client.clone();
152        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
153            client.create_empty_handle()
154        });
155
156        client
157            .register(OperationIr::Bool(BoolOperationIr::And(desc)))
158            .output()
159    }
160
161    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
162        let client = lhs.client.clone();
163        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
164            client.create_empty_handle()
165        });
166
167        client
168            .register(OperationIr::Bool(BoolOperationIr::Or(desc)))
169            .output()
170    }
171
172    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
173        let client = tensor.client.clone();
174        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
175            client.create_empty_handle()
176        });
177
178        client
179            .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
180            .output()
181    }
182
183    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
184        let client = tensor.client.clone();
185        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
186            client.create_empty_handle()
187        });
188
189        client
190            .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
191            .output()
192    }
193
194    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
195        let client = tensor.client.clone();
196        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
197            client.create_empty_handle()
198        });
199
200        client
201            .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
202            .output()
203    }
204
205    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
206        let client = tensor.client.clone();
207        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
208
209        client
210            .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
211            .output()
212    }
213
214    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
215        let client = tensors.first().unwrap().client.clone();
216        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
217        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
218
219        client
220            .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
221            .output()
222    }
223
224    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
225        let client = tensor.client.clone();
226        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
227            client.create_empty_handle()
228        });
229
230        client
231            .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
232            .output()
233    }
234
235    fn bool_unfold(
236        tensor: BoolTensor<Self>,
237        dim: usize,
238        size: usize,
239        step: usize,
240    ) -> BoolTensor<Self> {
241        let client = tensor.client.clone();
242        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
243            client.create_empty_handle()
244        });
245
246        client
247            .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
248            .output()
249    }
250
251    fn bool_mask_where(
252        tensor: BoolTensor<Self>,
253        mask: BoolTensor<Self>,
254        value: BoolTensor<Self>,
255    ) -> BoolTensor<Self> {
256        let client = tensor.client.clone();
257        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
258            client.create_empty_handle()
259        });
260
261        client
262            .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
263            .output()
264    }
265
266    fn bool_mask_fill(
267        tensor: BoolTensor<Self>,
268        mask: BoolTensor<Self>,
269        value: Scalar,
270    ) -> BoolTensor<Self> {
271        let client = tensor.client.clone();
272        let value = value.into();
273        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
274            client.create_empty_handle()
275        });
276
277        client
278            .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
279            .output()
280    }
281
282    fn bool_gather(
283        dim: usize,
284        tensor: BoolTensor<Self>,
285        indices: IntTensor<Self>,
286    ) -> BoolTensor<Self> {
287        let client = tensor.client.clone();
288        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
289            client.create_empty_handle()
290        });
291
292        client
293            .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
294            .output()
295    }
296
297    fn bool_scatter_or(
298        dim: usize,
299        tensor: BoolTensor<Self>,
300        indices: IntTensor<Self>,
301        value: BoolTensor<Self>,
302    ) -> BoolTensor<Self> {
303        let client = tensor.client.clone();
304        let desc = ScatterOpIr::create(
305            tensor.into_ir(),
306            dim,
307            indices.into_ir(),
308            value.into_ir(),
309            IndexingUpdateOp::Add,
310            || client.create_empty_handle(),
311        );
312
313        client
314            .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
315            .output()
316    }
317
318    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
319        let dtype = lhs.dtype;
320        let client = lhs.client.clone();
321        let rhs = rhs.into();
322        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, dtype, || {
323            client.create_empty_handle()
324        });
325
326        client
327            .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
328            .output()
329    }
330
331    fn bool_select(
332        tensor: BoolTensor<Self>,
333        dim: usize,
334        indices: IntTensor<Self>,
335    ) -> BoolTensor<Self> {
336        let client = tensor.client.clone();
337        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
338            client.create_empty_handle()
339        });
340
341        client
342            .register(OperationIr::BaseBool(BaseOperationIr::Select(desc)))
343            .output()
344    }
345
346    fn bool_select_or(
347        tensor: BoolTensor<Self>,
348        dim: usize,
349        indices: IntTensor<Self>,
350        value: BoolTensor<Self>,
351    ) -> BoolTensor<Self> {
352        let client = tensor.client.clone();
353        let desc = SelectAssignOpIr::create(
354            tensor.into_ir(),
355            dim,
356            indices.into_ir(),
357            value.into_ir(),
358            IndexingUpdateOp::Add,
359            || client.create_empty_handle(),
360        );
361
362        client
363            .register(OperationIr::BaseBool(BaseOperationIr::SelectAssign(desc)))
364            .output()
365    }
366}