Skip to main content

burn_cubecl/ops/
bool_tensor.rs

1use crate::{
2    CubeBackend, CubeRuntime, FloatElement, IntElement,
3    element::BoolElement,
4    kernel::{self, AndOp, OrOp},
5};
6use burn_backend::{
7    ExecutionError, Slice,
8    ops::BoolTensorOps,
9    tensor::{BoolTensor, Device, FloatTensor, IntTensor},
10};
11use burn_backend::{Scalar, Shape, TensorData};
12use burn_std::{BoolDType, BoolStore, DType, FloatDType, IntDType};
13use cubecl::prelude::InputScalar;
14use std::ops::Range;
15
16use super::{expand, numeric, permute, unfold};
17
18impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
19where
20    R: CubeRuntime,
21    F: FloatElement,
22    I: IntElement,
23    BT: BoolElement,
24{
25    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
26        super::empty(shape, device, dtype.into())
27    }
28
29    fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
30        numeric::zeros(device.clone(), shape, dtype.into())
31    }
32
33    fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
34        numeric::ones(device.clone(), shape, dtype.into())
35    }
36
37    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
38        super::into_data(tensor).await
39    }
40
41    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
42        if !matches!(
43            data.dtype,
44            DType::Bool(BoolStore::U8) | DType::Bool(BoolStore::U32)
45        ) {
46            unimplemented!("Unsupported dtype for `bool_from_data` {:?}", data.dtype);
47        }
48        super::from_data(data, device)
49    }
50
51    fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
52        kernel::bool_cast(tensor, out_dtype.into())
53    }
54
55    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
56        tensor.device.clone()
57    }
58
59    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
60        super::to_device(tensor, device)
61    }
62
63    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
64        super::reshape(tensor, shape)
65    }
66
67    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
68        // Check if all steps are 1
69        let all_steps_one = slices.iter().all(|info| info.step == 1);
70
71        if all_steps_one {
72            // Use optimized slice for step=1
73            let simple_ranges: Vec<Range<usize>> = slices
74                .iter()
75                .enumerate()
76                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
77                .collect();
78
79            kernel::slice(tensor, &simple_ranges)
80        } else {
81            // Use slice with steps kernel
82            kernel::slice_with_steps(tensor, slices)
83        }
84    }
85
86    fn bool_slice_assign(
87        tensor: BoolTensor<Self>,
88        ranges: &[Slice],
89        value: BoolTensor<Self>,
90    ) -> BoolTensor<Self> {
91        kernel::slice_assign(tensor, ranges, value)
92    }
93
94    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
95        let dtype = lhs.dtype;
96        kernel::equal(lhs, rhs, dtype)
97    }
98
99    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
100        let dtype = tensor.dtype;
101        let scalar = match dtype {
102            DType::Bool(BoolStore::U32) => InputScalar::new(u32::false_val(), dtype),
103            DType::Bool(BoolStore::U8) => InputScalar::new(u8::false_val(), dtype),
104            other => unimplemented!("Unsupported dtype for `bool_from_data` {other:?}"),
105        };
106        kernel::equal_elem(tensor, scalar, dtype)
107    }
108
109    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
110        kernel::launch_binop::<R, AndOp>(lhs, rhs)
111    }
112
113    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
114        kernel::launch_binop::<R, OrOp>(lhs, rhs)
115    }
116
117    fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
118        kernel::bool_cast(tensor, out_dtype.into())
119    }
120
121    fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
122        tensor.meta.swap(dim1, dim2);
123
124        tensor
125    }
126
127    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
128        kernel::repeat_dim(tensor, dim, times)
129    }
130
131    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
132        permute(tensor, axes)
133    }
134
135    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
136        expand(tensor, shape)
137    }
138
139    fn bool_select(
140        tensor: BoolTensor<Self>,
141        dim: usize,
142        indices: IntTensor<Self>,
143    ) -> BoolTensor<Self> {
144        kernel::select(tensor, dim, indices)
145    }
146
147    fn bool_select_or(
148        tensor: BoolTensor<Self>,
149        dim: usize,
150        indices: IntTensor<Self>,
151        value: BoolTensor<Self>,
152    ) -> BoolTensor<Self> {
153        kernel::select_assign(tensor, dim, indices, value, true)
154    }
155
156    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
157        let dtype = tensor.dtype;
158        kernel::flip(tensor, axes, dtype)
159    }
160
161    fn bool_unfold(
162        tensor: FloatTensor<Self>,
163        dim: usize,
164        size: usize,
165        step: usize,
166    ) -> FloatTensor<Self> {
167        unfold(tensor, dim, size, step)
168    }
169
170    fn bool_mask_where(
171        tensor: BoolTensor<Self>,
172        mask: BoolTensor<Self>,
173        value: BoolTensor<Self>,
174    ) -> BoolTensor<Self> {
175        let dtype = tensor.dtype;
176        kernel::mask_where_auto(tensor, mask, value, dtype)
177    }
178
179    fn bool_mask_fill(
180        tensor: BoolTensor<Self>,
181        mask: BoolTensor<Self>,
182        value: Scalar,
183    ) -> BoolTensor<Self> {
184        let dtype = tensor.dtype;
185        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype)
186    }
187
188    fn bool_gather(
189        dim: usize,
190        tensor: BoolTensor<Self>,
191        indices: IntTensor<Self>,
192    ) -> BoolTensor<Self> {
193        kernel::gather(dim, tensor, indices)
194    }
195
196    fn bool_scatter_or(
197        dim: usize,
198        tensor: BoolTensor<Self>,
199        indices: IntTensor<Self>,
200        value: BoolTensor<Self>,
201    ) -> BoolTensor<Self> {
202        kernel::scatter(dim, tensor, indices, value, true)
203    }
204
205    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
206        let dtype = lhs.dtype;
207        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype)
208    }
209}