1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, AndOp, OrOp},
5};
6use burn_backend::{
7 ExecutionError, Slice,
8 ops::BoolTensorOps,
9 tensor::{BoolTensor, Device, FloatTensor, IntTensor},
10};
11use burn_backend::{Scalar, Shape, TensorData};
12use burn_std::{BoolDType, BoolStore, DType, FloatDType, IntDType};
13use cubecl::prelude::InputScalar;
14use std::ops::Range;
15
16use super::{expand, numeric, permute, unfold};
17
18impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
19where
20 R: CubeRuntime,
21 F: FloatElement,
22 I: IntElement,
23 BT: BoolElement,
24{
25 fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
26 super::empty(shape, device, dtype.into())
27 }
28
29 fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
30 numeric::zeros(device.clone(), shape, dtype.into())
31 }
32
33 fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
34 numeric::ones(device.clone(), shape, dtype.into())
35 }
36
37 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
38 super::into_data(tensor).await
39 }
40
41 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
42 if !matches!(
43 data.dtype,
44 DType::Bool(BoolStore::U8) | DType::Bool(BoolStore::U32)
45 ) {
46 unimplemented!("Unsupported dtype for `bool_from_data` {:?}", data.dtype);
47 }
48 super::from_data(data, device)
49 }
50
51 fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
52 kernel::bool_cast(tensor, out_dtype.into())
53 }
54
55 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
56 tensor.device.clone()
57 }
58
59 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
60 super::to_device(tensor, device)
61 }
62
63 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
64 super::reshape(tensor, shape)
65 }
66
67 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
68 let all_steps_one = slices.iter().all(|info| info.step == 1);
70
71 if all_steps_one {
72 let simple_ranges: Vec<Range<usize>> = slices
74 .iter()
75 .enumerate()
76 .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
77 .collect();
78
79 kernel::slice(tensor, &simple_ranges)
80 } else {
81 kernel::slice_with_steps(tensor, slices)
83 }
84 }
85
86 fn bool_slice_assign(
87 tensor: BoolTensor<Self>,
88 ranges: &[Slice],
89 value: BoolTensor<Self>,
90 ) -> BoolTensor<Self> {
91 kernel::slice_assign(tensor, ranges, value)
92 }
93
94 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
95 let dtype = lhs.dtype;
96 kernel::equal(lhs, rhs, dtype)
97 }
98
99 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
100 let dtype = tensor.dtype;
101 let scalar = match dtype {
102 DType::Bool(BoolStore::U32) => InputScalar::new(u32::false_val(), dtype),
103 DType::Bool(BoolStore::U8) => InputScalar::new(u8::false_val(), dtype),
104 other => unimplemented!("Unsupported dtype for `bool_from_data` {other:?}"),
105 };
106 kernel::equal_elem(tensor, scalar, dtype)
107 }
108
109 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
110 kernel::launch_binop::<R, AndOp>(lhs, rhs)
111 }
112
113 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
114 kernel::launch_binop::<R, OrOp>(lhs, rhs)
115 }
116
117 fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
118 kernel::bool_cast(tensor, out_dtype.into())
119 }
120
121 fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
122 tensor.meta.swap(dim1, dim2);
123
124 tensor
125 }
126
127 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
128 kernel::repeat_dim(tensor, dim, times)
129 }
130
131 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
132 permute(tensor, axes)
133 }
134
135 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
136 expand(tensor, shape)
137 }
138
139 fn bool_select(
140 tensor: BoolTensor<Self>,
141 dim: usize,
142 indices: IntTensor<Self>,
143 ) -> BoolTensor<Self> {
144 kernel::select(tensor, dim, indices)
145 }
146
147 fn bool_select_or(
148 tensor: BoolTensor<Self>,
149 dim: usize,
150 indices: IntTensor<Self>,
151 value: BoolTensor<Self>,
152 ) -> BoolTensor<Self> {
153 kernel::select_assign(tensor, dim, indices, value, true)
154 }
155
156 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
157 let dtype = tensor.dtype;
158 kernel::flip(tensor, axes, dtype)
159 }
160
161 fn bool_unfold(
162 tensor: FloatTensor<Self>,
163 dim: usize,
164 size: usize,
165 step: usize,
166 ) -> FloatTensor<Self> {
167 unfold(tensor, dim, size, step)
168 }
169
170 fn bool_mask_where(
171 tensor: BoolTensor<Self>,
172 mask: BoolTensor<Self>,
173 value: BoolTensor<Self>,
174 ) -> BoolTensor<Self> {
175 let dtype = tensor.dtype;
176 kernel::mask_where_auto(tensor, mask, value, dtype)
177 }
178
179 fn bool_mask_fill(
180 tensor: BoolTensor<Self>,
181 mask: BoolTensor<Self>,
182 value: Scalar,
183 ) -> BoolTensor<Self> {
184 let dtype = tensor.dtype;
185 kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype)
186 }
187
188 fn bool_gather(
189 dim: usize,
190 tensor: BoolTensor<Self>,
191 indices: IntTensor<Self>,
192 ) -> BoolTensor<Self> {
193 kernel::gather(dim, tensor, indices)
194 }
195
196 fn bool_scatter_or(
197 dim: usize,
198 tensor: BoolTensor<Self>,
199 indices: IntTensor<Self>,
200 value: BoolTensor<Self>,
201 ) -> BoolTensor<Self> {
202 kernel::scatter(dim, tensor, indices, value, true)
203 }
204
205 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
206 let dtype = lhs.dtype;
207 kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype)
208 }
209}