1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, AndOp, OrOp},
5};
6use burn_backend::{
7 ExecutionError, Slice,
8 ops::BoolTensorOps,
9 tensor::{BoolTensor, Device, FloatTensor, IntTensor},
10};
11use burn_backend::{Shape, TensorData, tensor::BoolElem};
12use cubecl::prelude::InputScalar;
13use std::ops::Range;
14
15use super::{expand, numeric, permute, unfold};
16
17impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
18where
19 R: CubeRuntime,
20 F: FloatElement,
21 I: IntElement,
22 BT: BoolElement,
23{
24 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
25 super::empty(shape, device, BT::dtype())
26 }
27
28 fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
29 numeric::zeros(device.clone(), shape, BT::dtype())
30 }
31
32 fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
33 numeric::ones(device.clone(), shape, BT::dtype())
34 }
35
36 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
37 super::into_data(tensor).await
38 }
39
40 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
41 if data.dtype != BT::dtype() {
42 unimplemented!("Unsupported dtype for `bool_from_data`")
43 }
44 super::from_data(data, device)
45 }
46
47 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
48 kernel::bool_cast::<R, BT, I>(tensor)
49 }
50
51 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
52 tensor.device.clone()
53 }
54
55 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
56 super::to_device(tensor, device)
57 }
58
59 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
60 super::reshape(tensor, shape)
61 }
62
63 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
64 let all_steps_one = slices.iter().all(|info| info.step == 1);
66
67 if all_steps_one {
68 let simple_ranges: Vec<Range<usize>> = slices
70 .iter()
71 .enumerate()
72 .map(|(i, slice)| slice.to_range(tensor.shape[i]))
73 .collect();
74
75 kernel::slice(tensor, &simple_ranges)
76 } else {
77 kernel::slice_with_steps(tensor, slices)
79 }
80 }
81
82 fn bool_slice_assign(
83 tensor: BoolTensor<Self>,
84 ranges: &[Slice],
85 value: BoolTensor<Self>,
86 ) -> BoolTensor<Self> {
87 kernel::slice_assign(tensor, ranges, value)
88 }
89
90 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
91 kernel::equal(lhs, rhs, BT::dtype())
92 }
93
94 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
95 kernel::equal_elem(
96 tensor,
97 InputScalar::new(BT::false_val(), BT::dtype()),
98 BT::dtype(),
99 )
100 }
101
102 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
103 kernel::launch_binop::<R, AndOp>(lhs, rhs)
104 }
105
106 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
107 kernel::launch_binop::<R, OrOp>(lhs, rhs)
108 }
109
110 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
111 kernel::bool_cast::<R, BT, F>(tensor)
112 }
113
114 fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
115 tensor.strides.swap(dim1, dim2);
116 tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
117
118 tensor
119 }
120
121 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
122 kernel::repeat_dim(tensor, dim, times)
123 }
124
125 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
126 permute(tensor, axes)
127 }
128
129 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
130 expand(tensor, shape)
131 }
132
133 fn bool_select(
134 tensor: BoolTensor<Self>,
135 dim: usize,
136 indices: IntTensor<Self>,
137 ) -> BoolTensor<Self> {
138 kernel::select(tensor, dim, indices)
139 }
140
141 fn bool_select_or(
142 tensor: BoolTensor<Self>,
143 dim: usize,
144 indices: IntTensor<Self>,
145 value: BoolTensor<Self>,
146 ) -> BoolTensor<Self> {
147 kernel::select_assign(tensor, dim, indices, value, true)
148 }
149
150 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
151 kernel::flip(tensor, axes, BT::dtype())
152 }
153
154 fn bool_unfold(
155 tensor: FloatTensor<Self>,
156 dim: usize,
157 size: usize,
158 step: usize,
159 ) -> FloatTensor<Self> {
160 unfold(tensor, dim, size, step)
161 }
162
163 fn bool_mask_where(
164 tensor: BoolTensor<Self>,
165 mask: BoolTensor<Self>,
166 value: BoolTensor<Self>,
167 ) -> BoolTensor<Self> {
168 kernel::mask_where_auto(tensor, mask, value, BT::dtype())
169 }
170
171 fn bool_mask_fill(
172 tensor: BoolTensor<Self>,
173 mask: BoolTensor<Self>,
174 value: BoolElem<Self>,
175 ) -> BoolTensor<Self> {
176 let dtype = tensor.dtype;
177 kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype)
178 }
179
180 fn bool_gather(
181 dim: usize,
182 tensor: BoolTensor<Self>,
183 indices: IntTensor<Self>,
184 ) -> BoolTensor<Self> {
185 kernel::gather(dim, tensor, indices)
186 }
187
188 fn bool_scatter_or(
189 dim: usize,
190 tensor: BoolTensor<Self>,
191 indices: IntTensor<Self>,
192 value: BoolTensor<Self>,
193 ) -> BoolTensor<Self> {
194 kernel::scatter(dim, tensor, indices, value, true)
195 }
196
197 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
198 let dtype = lhs.dtype;
199 kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype)
200 }
201}