1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_std::{BoolDType, FloatDType, IntDType};
4
5use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
6use burn_backend::ops::BoolTensorOps;
7use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor};
8use burn_backend::{Scalar, Shape, Slice, TensorData};
9use burn_ir::{
10 BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
11 GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
12 PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr,
13 SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
14};
15
16impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
17 fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
18 let client = get_client::<R>(device);
19 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
20
21 client
22 .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
23 .output()
24 }
25
26 fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
27 let client = get_client::<R>(device);
28 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
29
30 client
31 .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
32 .output()
33 }
34
35 fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
36 let client = get_client::<R>(device);
37 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
38
39 client
40 .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
41 .output()
42 }
43
44 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
45 tensor.into_data().await
46 }
47
48 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
49 let client = get_client::<R>(device);
50 let out = client.register_tensor_data(data);
51 let desc = InitOperationIr {
52 out: out.to_ir_out(),
53 };
54
55 client.register_op(OperationIr::Init(desc));
57
58 out
59 }
60
61 fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
62 let client = tensor.client.clone();
63 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
64 client.create_empty_handle()
65 });
66
67 client
68 .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
69 .output()
70 }
71
72 fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
73 let client = tensor.client.clone();
74 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
75 client.create_empty_handle()
76 });
77
78 client
79 .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
80 .output()
81 }
82
83 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
84 tensor.client.device()
85 }
86
87 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
88 if &tensor.client.device() == device {
89 return tensor;
90 }
91 R::change_client_backend(tensor, device)
92 }
93
94 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
95 let client = tensor.client.clone();
96 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
97
98 client
99 .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
100 .output()
101 }
102
103 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
104 let client = tensor.client.clone();
105 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
106 client.create_empty_handle()
107 });
108
109 client
110 .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
111 .output()
112 }
113
114 fn bool_slice_assign(
115 tensor: BoolTensor<Self>,
116 slices: &[burn_backend::Slice],
117 value: BoolTensor<Self>,
118 ) -> BoolTensor<Self> {
119 let client = tensor.client.clone();
120 let desc =
121 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
122 client.create_empty_handle()
123 });
124
125 client
126 .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
127 .output()
128 }
129
130 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
131 let client = lhs.client.clone();
132 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
133 client.create_empty_handle()
134 });
135
136 client
137 .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
138 .output()
139 }
140
141 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
142 let client = tensor.client.clone();
143 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
144
145 client
146 .register(OperationIr::Bool(BoolOperationIr::Not(desc)))
147 .output()
148 }
149
150 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
151 let client = lhs.client.clone();
152 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
153 client.create_empty_handle()
154 });
155
156 client
157 .register(OperationIr::Bool(BoolOperationIr::And(desc)))
158 .output()
159 }
160
161 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
162 let client = lhs.client.clone();
163 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
164 client.create_empty_handle()
165 });
166
167 client
168 .register(OperationIr::Bool(BoolOperationIr::Or(desc)))
169 .output()
170 }
171
172 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
173 let client = tensor.client.clone();
174 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
175 client.create_empty_handle()
176 });
177
178 client
179 .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
180 .output()
181 }
182
183 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
184 let client = tensor.client.clone();
185 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
186 client.create_empty_handle()
187 });
188
189 client
190 .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
191 .output()
192 }
193
194 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
195 let client = tensor.client.clone();
196 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
197 client.create_empty_handle()
198 });
199
200 client
201 .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
202 .output()
203 }
204
205 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
206 let client = tensor.client.clone();
207 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
208
209 client
210 .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
211 .output()
212 }
213
214 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
215 let client = tensors.first().unwrap().client.clone();
216 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
217 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
218
219 client
220 .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
221 .output()
222 }
223
224 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
225 let client = tensor.client.clone();
226 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
227 client.create_empty_handle()
228 });
229
230 client
231 .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
232 .output()
233 }
234
235 fn bool_unfold(
236 tensor: BoolTensor<Self>,
237 dim: usize,
238 size: usize,
239 step: usize,
240 ) -> BoolTensor<Self> {
241 let client = tensor.client.clone();
242 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
243 client.create_empty_handle()
244 });
245
246 client
247 .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
248 .output()
249 }
250
251 fn bool_mask_where(
252 tensor: BoolTensor<Self>,
253 mask: BoolTensor<Self>,
254 value: BoolTensor<Self>,
255 ) -> BoolTensor<Self> {
256 let client = tensor.client.clone();
257 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
258 client.create_empty_handle()
259 });
260
261 client
262 .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
263 .output()
264 }
265
266 fn bool_mask_fill(
267 tensor: BoolTensor<Self>,
268 mask: BoolTensor<Self>,
269 value: Scalar,
270 ) -> BoolTensor<Self> {
271 let client = tensor.client.clone();
272 let value = value.into();
273 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
274 client.create_empty_handle()
275 });
276
277 client
278 .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
279 .output()
280 }
281
282 fn bool_gather(
283 dim: usize,
284 tensor: BoolTensor<Self>,
285 indices: IntTensor<Self>,
286 ) -> BoolTensor<Self> {
287 let client = tensor.client.clone();
288 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
289 client.create_empty_handle()
290 });
291
292 client
293 .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
294 .output()
295 }
296
297 fn bool_scatter_or(
298 dim: usize,
299 tensor: BoolTensor<Self>,
300 indices: IntTensor<Self>,
301 value: BoolTensor<Self>,
302 ) -> BoolTensor<Self> {
303 let client = tensor.client.clone();
304 let desc = ScatterOpIr::create(
305 tensor.into_ir(),
306 dim,
307 indices.into_ir(),
308 value.into_ir(),
309 IndexingUpdateOp::Add,
310 || client.create_empty_handle(),
311 );
312
313 client
314 .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
315 .output()
316 }
317
318 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
319 let dtype = lhs.dtype;
320 let client = lhs.client.clone();
321 let rhs = rhs.into();
322 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, dtype, || {
323 client.create_empty_handle()
324 });
325
326 client
327 .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
328 .output()
329 }
330
331 fn bool_select(
332 tensor: BoolTensor<Self>,
333 dim: usize,
334 indices: IntTensor<Self>,
335 ) -> BoolTensor<Self> {
336 let client = tensor.client.clone();
337 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
338 client.create_empty_handle()
339 });
340
341 client
342 .register(OperationIr::BaseBool(BaseOperationIr::Select(desc)))
343 .output()
344 }
345
346 fn bool_select_or(
347 tensor: BoolTensor<Self>,
348 dim: usize,
349 indices: IntTensor<Self>,
350 value: BoolTensor<Self>,
351 ) -> BoolTensor<Self> {
352 let client = tensor.client.clone();
353 let desc = SelectAssignOpIr::create(
354 tensor.into_ir(),
355 dim,
356 indices.into_ir(),
357 value.into_ir(),
358 IndexingUpdateOp::Add,
359 || client.create_empty_handle(),
360 );
361
362 client
363 .register(OperationIr::BaseBool(BaseOperationIr::SelectAssign(desc)))
364 .output()
365 }
366}