1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6 ops::TransactionPrimitive,
7 tensor::{
8 BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind,
9 TransactionOp,
10 },
11};
12
13impl<B: Backend> TransactionOp<B> for Bool {
14 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
15 tr.register_bool(tensor);
16 }
17}
18
19impl<B: Backend> BasicOps<B> for Bool {
20 type Elem = B::BoolElem;
21
22 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23 if !dtype.is_bool() {
24 panic!("Expected bool data type, got {dtype:?}");
25 }
26 B::bool_empty(shape, device, dtype.into())
27 }
28
29 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30 if !dtype.is_bool() {
31 panic!("Expected bool data type, got {dtype:?}");
32 }
33 B::bool_zeros(shape, device, dtype.into())
34 }
35 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
36 if !dtype.is_bool() {
37 panic!("Expected bool data type, got {dtype:?}");
38 }
39 B::bool_ones(shape, device, dtype.into())
40 }
41
42 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
43 if !dtype.is_bool() {
44 panic!("Expected bool data type, got {dtype:?}");
45 }
46 if fill_value.elem() {
47 B::bool_ones(shape, device, dtype.into())
48 } else {
49 B::bool_zeros(shape, device, dtype.into())
50 }
51 }
52
53 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
54 B::bool_reshape(tensor, shape)
55 }
56
57 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
58 B::bool_transpose(tensor)
59 }
60
61 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
62 B::bool_swap_dims(tensor, dim1, dim2)
63 }
64
65 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
66 B::bool_slice(tensor, slices)
67 }
68
69 fn slice_assign(
70 tensor: Self::Primitive,
71 slices: &[Slice],
72 value: Self::Primitive,
73 ) -> Self::Primitive {
74 B::bool_slice_assign(tensor, slices, value)
75 }
76
77 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
78 B::bool_select(tensor, dim, indices)
79 }
80
81 fn select_assign(
82 tensor: Self::Primitive,
83 dim: usize,
84 indices: IntTensor<B>,
85 values: Self::Primitive,
86 update: IndexingUpdateOp,
87 ) -> Self::Primitive {
88 match update {
89 IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
90 _ => unimplemented!(),
91 }
92 }
93
94 fn mask_where(
95 tensor: Self::Primitive,
96 mask: B::BoolTensorPrimitive,
97 source: Self::Primitive,
98 ) -> Self::Primitive {
99 B::bool_mask_where(tensor, mask, source)
100 }
101
102 fn mask_fill(
103 tensor: Self::Primitive,
104 mask: B::BoolTensorPrimitive,
105 value: Scalar,
106 ) -> Self::Primitive {
107 B::bool_mask_fill(tensor, mask, value)
108 }
109
110 fn gather(
111 dim: usize,
112 tensor: Self::Primitive,
113 indices: B::IntTensorPrimitive,
114 ) -> Self::Primitive {
115 B::bool_gather(dim, tensor, indices)
116 }
117
118 fn scatter(
119 dim: usize,
120 tensor: Self::Primitive,
121 indices: B::IntTensorPrimitive,
122 values: Self::Primitive,
123 update: IndexingUpdateOp,
124 ) -> Self::Primitive {
125 match update {
126 IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
127 _ => unimplemented!(),
128 }
129 }
130
131 fn scatter_nd(
132 _data: Self::Primitive,
133 _indices: IntTensor<B>,
134 _values: Self::Primitive,
135 _reduction: IndexingUpdateOp,
136 ) -> Self::Primitive {
137 panic!("scatter_nd is not supported for bool tensors")
138 }
139
140 fn gather_nd(_data: Self::Primitive, _indices: IntTensor<B>) -> Self::Primitive {
141 panic!("gather_nd is not supported for bool tensors")
142 }
143
144 fn device(tensor: &Self::Primitive) -> Device<B> {
145 B::bool_device(tensor)
146 }
147
148 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
149 B::bool_to_device(tensor, device)
150 }
151
152 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
153 B::bool_into_data(tensor).await
154 }
155
156 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
157 B::bool_from_data(data.convert_dtype(dtype), device)
161 }
162
163 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
164 B::bool_repeat_dim(tensor, dim, times)
165 }
166
167 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
168 B::bool_equal(lhs, rhs)
169 }
170
171 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
172 B::bool_not_equal(lhs, rhs)
173 }
174
175 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
176 B::bool_equal_elem(lhs, rhs)
177 }
178
179 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
180 B::bool_not_equal_elem(lhs, rhs)
181 }
182
183 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
184 B::bool_cat(vectors, dim)
185 }
186
187 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
188 B::bool_any(tensor)
189 }
190
191 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
192 B::bool_any_dim(tensor, dim)
193 }
194
195 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
196 B::bool_all(tensor)
197 }
198
199 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
200 B::bool_all_dim(tensor, dim)
201 }
202
203 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
204 B::bool_permute(tensor, axes)
205 }
206
207 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
208 B::bool_expand(tensor, shape)
209 }
210
211 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
212 B::bool_flip(tensor, axes)
213 }
214
215 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
216 B::bool_unfold(tensor, dim, size, step)
217 }
218}
219
220impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
221 type InnerKind = Bool;
222
223 fn inner(
224 tensor: <Self as TensorKind<B>>::Primitive,
225 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
226 B::bool_inner(tensor)
227 }
228
229 fn from_inner(
230 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
231 ) -> <Self as TensorKind<B>>::Primitive {
232 B::bool_from_inner(inner)
233 }
234}