burn_cubecl/ops/bool_tensor.rs

1use crate::{
2    CubeBackend, CubeRuntime, FloatElement, IntElement,
3    element::BoolElement,
4    kernel::{self, AndOp, OrOp},
5};
6use burn_backend::{
7    ExecutionError, Slice,
8    ops::BoolTensorOps,
9    tensor::{BoolTensor, Device, FloatTensor, IntTensor},
10};
11use burn_backend::{Shape, TensorData, tensor::BoolElem};
12use cubecl::prelude::InputScalar;
13use std::ops::Range;
14
15use super::{expand, numeric, permute, unfold};
16
17impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
18where
19    R: CubeRuntime,
20    F: FloatElement,
21    I: IntElement,
22    BT: BoolElement,
23{
24    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
25        super::empty(shape, device, BT::dtype())
26    }
27
28    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
29        numeric::zeros(device.clone(), shape, BT::dtype())
30    }
31
32    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
33        numeric::ones(device.clone(), shape, BT::dtype())
34    }
35
36    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
37        super::into_data(tensor).await
38    }
39
40    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
41        if data.dtype != BT::dtype() {
42            unimplemented!("Unsupported dtype for `bool_from_data`")
43        }
44        super::from_data(data, device)
45    }
46
47    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
48        kernel::bool_cast::<R, BT, I>(tensor)
49    }
50
51    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
52        tensor.device.clone()
53    }
54
55    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
56        super::to_device(tensor, device)
57    }
58
59    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
60        super::reshape(tensor, shape)
61    }
62
63    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
64        // Check if all steps are 1
65        let all_steps_one = slices.iter().all(|info| info.step == 1);
66
67        if all_steps_one {
68            // Use optimized slice for step=1
69            let simple_ranges: Vec<Range<usize>> = slices
70                .iter()
71                .enumerate()
72                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
73                .collect();
74
75            kernel::slice(tensor, &simple_ranges)
76        } else {
77            // Use slice with steps kernel
78            kernel::slice_with_steps(tensor, slices)
79        }
80    }
81
82    fn bool_slice_assign(
83        tensor: BoolTensor<Self>,
84        ranges: &[Slice],
85        value: BoolTensor<Self>,
86    ) -> BoolTensor<Self> {
87        kernel::slice_assign(tensor, ranges, value)
88    }
89
90    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
91        kernel::equal(lhs, rhs, BT::dtype())
92    }
93
94    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
95        kernel::equal_elem(
96            tensor,
97            InputScalar::new(BT::false_val(), BT::dtype()),
98            BT::dtype(),
99        )
100    }
101
102    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
103        kernel::launch_binop::<R, AndOp>(lhs, rhs)
104    }
105
106    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
107        kernel::launch_binop::<R, OrOp>(lhs, rhs)
108    }
109
110    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
111        kernel::bool_cast::<R, BT, F>(tensor)
112    }
113
114    fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
115        tensor.strides.swap(dim1, dim2);
116        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
117
118        tensor
119    }
120
121    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
122        kernel::repeat_dim(tensor, dim, times)
123    }
124
125    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
126        permute(tensor, axes)
127    }
128
129    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
130        expand(tensor, shape)
131    }
132
133    fn bool_select(
134        tensor: BoolTensor<Self>,
135        dim: usize,
136        indices: IntTensor<Self>,
137    ) -> BoolTensor<Self> {
138        kernel::select(tensor, dim, indices)
139    }
140
141    fn bool_select_or(
142        tensor: BoolTensor<Self>,
143        dim: usize,
144        indices: IntTensor<Self>,
145        value: BoolTensor<Self>,
146    ) -> BoolTensor<Self> {
147        kernel::select_assign(tensor, dim, indices, value, true)
148    }
149
150    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
151        kernel::flip(tensor, axes, BT::dtype())
152    }
153
154    fn bool_unfold(
155        tensor: FloatTensor<Self>,
156        dim: usize,
157        size: usize,
158        step: usize,
159    ) -> FloatTensor<Self> {
160        unfold(tensor, dim, size, step)
161    }
162
163    fn bool_mask_where(
164        tensor: BoolTensor<Self>,
165        mask: BoolTensor<Self>,
166        value: BoolTensor<Self>,
167    ) -> BoolTensor<Self> {
168        kernel::mask_where_auto(tensor, mask, value, BT::dtype())
169    }
170
171    fn bool_mask_fill(
172        tensor: BoolTensor<Self>,
173        mask: BoolTensor<Self>,
174        value: BoolElem<Self>,
175    ) -> BoolTensor<Self> {
176        let dtype = tensor.dtype;
177        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype)
178    }
179
180    fn bool_gather(
181        dim: usize,
182        tensor: BoolTensor<Self>,
183        indices: IntTensor<Self>,
184    ) -> BoolTensor<Self> {
185        kernel::gather(dim, tensor, indices)
186    }
187
188    fn bool_scatter_or(
189        dim: usize,
190        tensor: BoolTensor<Self>,
191        indices: IntTensor<Self>,
192        value: BoolTensor<Self>,
193    ) -> BoolTensor<Self> {
194        kernel::scatter(dim, tensor, indices, value, true)
195    }
196
197    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
198        let dtype = lhs.dtype;
199        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype)
200    }
201}