burn_cubecl/ops/
bool_ops.rs

1use crate::{
2    CubeBackend, CubeRuntime, FloatElement, IntElement,
3    element::BoolElement,
4    kernel::{self, AndOp, OrOp},
5};
6use burn_tensor::ops::{BoolTensor, BoolTensorOps, Device, FloatTensor, IntTensor};
7use burn_tensor::{Shape, TensorData};
8use std::ops::Range;
9
10use super::{expand, numeric, permute, unfold};
11
12impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
13where
14    R: CubeRuntime,
15    F: FloatElement,
16    I: IntElement,
17    BT: BoolElement,
18{
19    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
20        super::empty::<R, BT>(shape, device)
21    }
22
23    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
24        numeric::zeros::<R, BT>(shape, device)
25    }
26
27    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
28        numeric::ones::<R, BT>(shape, device)
29    }
30
31    async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
32        super::into_data::<R, BT>(tensor).await
33    }
34
35    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
36        if data.dtype != BT::dtype() {
37            unimplemented!("Unsupported dtype for `bool_from_data`")
38        }
39        super::from_data::<R>(data, device)
40    }
41
42    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
43        kernel::bool_cast::<R, BT, I>(tensor)
44    }
45
46    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
47        tensor.device.clone()
48    }
49
50    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
51        super::to_device(tensor, device)
52    }
53
54    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
55        super::reshape(tensor, shape)
56    }
57
58    fn bool_slice(tensor: BoolTensor<Self>, slices: &[burn_tensor::Slice]) -> BoolTensor<Self> {
59        // Check if all steps are 1
60        let all_steps_one = slices.iter().all(|info| info.step == 1);
61
62        if all_steps_one {
63            // Use optimized slice for step=1
64            let simple_ranges: Vec<Range<usize>> = slices
65                .iter()
66                .enumerate()
67                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
68                .collect();
69
70            kernel::slice::<R, BT>(tensor, &simple_ranges)
71        } else {
72            // Use slice with steps kernel
73            kernel::slice_with_steps::<R, BT>(tensor, slices)
74        }
75    }
76
77    fn bool_slice_assign(
78        tensor: BoolTensor<Self>,
79        ranges: &[burn_tensor::Slice],
80        value: BoolTensor<Self>,
81    ) -> BoolTensor<Self> {
82        kernel::slice_assign::<R, BT>(tensor, ranges, value)
83    }
84
85    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
86        kernel::equal::<R, BT, BT>(lhs, rhs)
87    }
88
89    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
90        kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
91    }
92
93    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
94        kernel::launch_binop::<R, BT, AndOp>(lhs, rhs)
95    }
96
97    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
98        kernel::launch_binop::<R, BT, OrOp>(lhs, rhs)
99    }
100
101    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
102        kernel::bool_cast::<R, BT, F>(tensor)
103    }
104
105    fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
106        tensor.strides.swap(dim1, dim2);
107        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
108
109        tensor
110    }
111
112    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
113        kernel::repeat_dim::<R, BT>(tensor, dim, times)
114    }
115
116    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
117        permute(tensor, axes)
118    }
119
120    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
121        expand(tensor, shape)
122    }
123
124    fn bool_select(
125        tensor: BoolTensor<Self>,
126        dim: usize,
127        indices: IntTensor<Self>,
128    ) -> BoolTensor<Self> {
129        kernel::select::<R, BT, I>(tensor, dim, indices)
130    }
131
132    fn bool_select_assign(
133        tensor: BoolTensor<Self>,
134        dim: usize,
135        indices: IntTensor<Self>,
136        value: BoolTensor<Self>,
137    ) -> BoolTensor<Self> {
138        kernel::select_assign::<R, BT, I>(tensor, dim, indices, value, true)
139    }
140
141    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
142        kernel::flip::<R, BT, BT>(tensor, axes)
143    }
144
145    fn bool_unfold(
146        tensor: FloatTensor<Self>,
147        dim: usize,
148        size: usize,
149        step: usize,
150    ) -> FloatTensor<Self> {
151        unfold(tensor, dim, size, step)
152    }
153}