1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::ops::BoolTensorOps;
6use burn_backend::tensor::{
7 BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
8};
9use burn_backend::{Element, Shape, Slice, TensorData};
10use burn_ir::{
11 BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
12 GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
13 PermuteOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr,
14 SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
15};
16
17impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
18 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
19 let client = get_client::<R>(device);
20 let desc =
21 CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
22
23 client
24 .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
25 .output()
26 }
27
28 fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
29 let client = get_client::<R>(device);
30 let desc =
31 CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
32
33 client
34 .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
35 .output()
36 }
37
38 fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
39 let client = get_client::<R>(device);
40 let desc =
41 CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
42
43 client
44 .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
45 .output()
46 }
47
48 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
49 tensor.into_data().await
50 }
51
52 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
53 let client = get_client::<R>(device);
54 let out = client.register_tensor_data(data);
55 let desc = InitOperationIr {
56 out: out.to_ir_out(),
57 };
58
59 client.register_op(OperationIr::Init(desc));
61
62 out
63 }
64
65 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
66 let client = tensor.client.clone();
67 let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {
68 client.create_empty_handle()
69 });
70
71 client
72 .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
73 .output()
74 }
75
76 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
77 let client = tensor.client.clone();
78 let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {
79 client.create_empty_handle()
80 });
81
82 client
83 .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
84 .output()
85 }
86
87 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
88 tensor.client.device()
89 }
90
91 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
92 if &tensor.client.device() == device {
93 return tensor;
94 }
95 R::change_client_backend(tensor, device)
96 }
97
98 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
99 let client = tensor.client.clone();
100 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
101
102 client
103 .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
104 .output()
105 }
106
107 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
108 let client = tensor.client.clone();
109 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
110 client.create_empty_handle()
111 });
112
113 client
114 .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
115 .output()
116 }
117
118 fn bool_slice_assign(
119 tensor: BoolTensor<Self>,
120 slices: &[burn_backend::Slice],
121 value: BoolTensor<Self>,
122 ) -> BoolTensor<Self> {
123 let client = tensor.client.clone();
124 let desc =
125 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
126 client.create_empty_handle()
127 });
128
129 client
130 .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
131 .output()
132 }
133
134 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
135 let client = lhs.client.clone();
136 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
137 client.create_empty_handle()
138 });
139
140 client
141 .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
142 .output()
143 }
144
145 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
146 let client = tensor.client.clone();
147 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
148
149 client
150 .register(OperationIr::Bool(BoolOperationIr::Not(desc)))
151 .output()
152 }
153
154 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
155 let client = lhs.client.clone();
156 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
157 client.create_empty_handle()
158 });
159
160 client
161 .register(OperationIr::Bool(BoolOperationIr::And(desc)))
162 .output()
163 }
164
165 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
166 let client = lhs.client.clone();
167 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
168 client.create_empty_handle()
169 });
170
171 client
172 .register(OperationIr::Bool(BoolOperationIr::Or(desc)))
173 .output()
174 }
175
176 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
177 let client = tensor.client.clone();
178 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
179 client.create_empty_handle()
180 });
181
182 client
183 .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
184 .output()
185 }
186
187 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
188 let client = tensor.client.clone();
189 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
190 client.create_empty_handle()
191 });
192
193 client
194 .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
195 .output()
196 }
197
198 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
199 let client = tensor.client.clone();
200 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
201 client.create_empty_handle()
202 });
203
204 client
205 .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
206 .output()
207 }
208
209 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
210 let client = tensor.client.clone();
211 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
212
213 client
214 .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
215 .output()
216 }
217
218 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
219 let client = tensors.first().unwrap().client.clone();
220 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
221 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
222
223 client
224 .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
225 .output()
226 }
227
228 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
229 let client = tensor.client.clone();
230 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
231 client.create_empty_handle()
232 });
233
234 client
235 .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
236 .output()
237 }
238
239 fn bool_unfold(
240 tensor: BoolTensor<Self>,
241 dim: usize,
242 size: usize,
243 step: usize,
244 ) -> BoolTensor<Self> {
245 let client = tensor.client.clone();
246 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
247 client.create_empty_handle()
248 });
249
250 client
251 .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
252 .output()
253 }
254
255 fn bool_mask_where(
256 tensor: BoolTensor<Self>,
257 mask: BoolTensor<Self>,
258 value: BoolTensor<Self>,
259 ) -> BoolTensor<Self> {
260 let client = tensor.client.clone();
261 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
262 client.create_empty_handle()
263 });
264
265 client
266 .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
267 .output()
268 }
269
270 fn bool_mask_fill(
271 tensor: BoolTensor<Self>,
272 mask: BoolTensor<Self>,
273 value: burn_backend::tensor::BoolElem<Self>,
274 ) -> BoolTensor<Self> {
275 let client = tensor.client.clone();
276 let value = ScalarIr::with_dtype(value, &tensor.dtype);
277 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
278 client.create_empty_handle()
279 });
280
281 client
282 .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
283 .output()
284 }
285
286 fn bool_gather(
287 dim: usize,
288 tensor: BoolTensor<Self>,
289 indices: IntTensor<Self>,
290 ) -> BoolTensor<Self> {
291 let client = tensor.client.clone();
292 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
293 client.create_empty_handle()
294 });
295
296 client
297 .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
298 .output()
299 }
300
301 fn bool_scatter_or(
302 dim: usize,
303 tensor: BoolTensor<Self>,
304 indices: IntTensor<Self>,
305 value: BoolTensor<Self>,
306 ) -> BoolTensor<Self> {
307 let client = tensor.client.clone();
308 let desc = ScatterOpIr::create(
309 tensor.into_ir(),
310 dim,
311 indices.into_ir(),
312 value.into_ir(),
313 IndexingUpdateOp::Add,
314 || client.create_empty_handle(),
315 );
316
317 client
318 .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
319 .output()
320 }
321
322 fn bool_equal_elem(
323 lhs: BoolTensor<Self>,
324 rhs: burn_backend::tensor::BoolElem<Self>,
325 ) -> BoolTensor<Self> {
326 let client = lhs.client.clone();
327 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
328 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
329 client.create_empty_handle()
330 });
331
332 client
333 .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
334 .output()
335 }
336}