burn_cubecl/ops/
bool_ops.rs1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, AndOp, OrOp},
5};
6use burn_tensor::ops::{BoolTensor, BoolTensorOps, Device, FloatTensor, IntTensor};
7use burn_tensor::{Shape, TensorData};
8use std::ops::Range;
9
10use super::{expand, permute};
11
12impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
13where
14 R: CubeRuntime,
15 F: FloatElement,
16 I: IntElement,
17 BT: BoolElement,
18{
19 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
20 super::empty::<R, BT>(shape, device)
21 }
22
23 async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
24 super::into_data::<R, BT>(tensor).await
25 }
26
27 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
28 if data.dtype != BT::dtype() {
29 unimplemented!("Unsupported dtype for `bool_from_data`")
30 }
31 super::from_data::<R>(data, device)
32 }
33
34 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
35 kernel::bool_cast::<R, BT, I>(tensor)
36 }
37
38 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
39 tensor.device.clone()
40 }
41
42 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
43 super::to_device(tensor, device)
44 }
45
46 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
47 super::reshape(tensor, shape)
48 }
49
50 fn bool_slice(tensor: BoolTensor<Self>, ranges: &[Range<usize>]) -> BoolTensor<Self> {
51 kernel::slice::<R, BT>(tensor, ranges)
52 }
53
54 fn bool_slice_assign(
55 tensor: BoolTensor<Self>,
56 ranges: &[Range<usize>],
57 value: BoolTensor<Self>,
58 ) -> BoolTensor<Self> {
59 kernel::slice_assign::<R, BT>(tensor, ranges, value)
60 }
61
62 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
63 kernel::equal::<R, BT, BT>(lhs, rhs)
64 }
65
66 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
67 kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
68 }
69
70 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
71 kernel::launch_binop::<R, BT, AndOp>(lhs, rhs)
72 }
73
74 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
75 kernel::launch_binop::<R, BT, OrOp>(lhs, rhs)
76 }
77
78 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
79 kernel::bool_cast::<R, BT, F>(tensor)
80 }
81
82 fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
83 tensor.strides.swap(dim1, dim2);
84 tensor.shape.dims.swap(dim1, dim2);
85
86 tensor
87 }
88
89 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
90 kernel::repeat_dim::<R, BT>(tensor, dim, times)
91 }
92
93 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
94 permute(tensor, axes)
95 }
96
97 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
98 expand(tensor, shape)
99 }
100
101 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
102 kernel::flip::<R, BT, BT>(tensor, axes)
103 }
104}