1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, AndOp, OrOp},
5};
6use burn_tensor::ops::{BoolTensor, BoolTensorOps, Device, FloatTensor, IntTensor};
7use burn_tensor::{Shape, TensorData};
8use std::ops::Range;
9
10use super::{expand, numeric, permute, unfold};
11
12impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
13where
14 R: CubeRuntime,
15 F: FloatElement,
16 I: IntElement,
17 BT: BoolElement,
18{
19 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
20 super::empty::<R, BT>(shape, device)
21 }
22
23 fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
24 numeric::zeros::<R, BT>(shape, device)
25 }
26
27 fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
28 numeric::ones::<R, BT>(shape, device)
29 }
30
31 async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
32 super::into_data::<R, BT>(tensor).await
33 }
34
35 fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
36 if data.dtype != BT::dtype() {
37 unimplemented!("Unsupported dtype for `bool_from_data`")
38 }
39 super::from_data::<R>(data, device)
40 }
41
42 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
43 kernel::bool_cast::<R, BT, I>(tensor)
44 }
45
46 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
47 tensor.device.clone()
48 }
49
50 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
51 super::to_device(tensor, device)
52 }
53
54 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
55 super::reshape(tensor, shape)
56 }
57
58 fn bool_slice(tensor: BoolTensor<Self>, slices: &[burn_tensor::Slice]) -> BoolTensor<Self> {
59 let all_steps_one = slices.iter().all(|info| info.step == 1);
61
62 if all_steps_one {
63 let simple_ranges: Vec<Range<usize>> = slices
65 .iter()
66 .enumerate()
67 .map(|(i, slice)| slice.to_range(tensor.shape[i]))
68 .collect();
69
70 kernel::slice::<R, BT>(tensor, &simple_ranges)
71 } else {
72 kernel::slice_with_steps::<R, BT>(tensor, slices)
74 }
75 }
76
77 fn bool_slice_assign(
78 tensor: BoolTensor<Self>,
79 ranges: &[burn_tensor::Slice],
80 value: BoolTensor<Self>,
81 ) -> BoolTensor<Self> {
82 kernel::slice_assign::<R, BT>(tensor, ranges, value)
83 }
84
85 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
86 kernel::equal::<R, BT, BT>(lhs, rhs)
87 }
88
89 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
90 kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
91 }
92
93 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
94 kernel::launch_binop::<R, BT, AndOp>(lhs, rhs)
95 }
96
97 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
98 kernel::launch_binop::<R, BT, OrOp>(lhs, rhs)
99 }
100
101 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
102 kernel::bool_cast::<R, BT, F>(tensor)
103 }
104
105 fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
106 tensor.strides.swap(dim1, dim2);
107 tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
108
109 tensor
110 }
111
112 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
113 kernel::repeat_dim::<R, BT>(tensor, dim, times)
114 }
115
116 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
117 permute(tensor, axes)
118 }
119
120 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
121 expand(tensor, shape)
122 }
123
124 fn bool_select(
125 tensor: BoolTensor<Self>,
126 dim: usize,
127 indices: IntTensor<Self>,
128 ) -> BoolTensor<Self> {
129 kernel::select::<R, BT, I>(tensor, dim, indices)
130 }
131
132 fn bool_select_assign(
133 tensor: BoolTensor<Self>,
134 dim: usize,
135 indices: IntTensor<Self>,
136 value: BoolTensor<Self>,
137 ) -> BoolTensor<Self> {
138 kernel::select_assign::<R, BT, I>(tensor, dim, indices, value, true)
139 }
140
141 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
142 kernel::flip::<R, BT, BT>(tensor, axes)
143 }
144
145 fn bool_unfold(
146 tensor: FloatTensor<Self>,
147 dim: usize,
148 size: usize,
149 step: usize,
150 ) -> FloatTensor<Self> {
151 unfold(tensor, dim, size, step)
152 }
153}