1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6 ops::TransactionPrimitive,
7 tensor::{
8 BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind,
9 TransactionOp,
10 },
11};
12
13impl<B: Backend> TransactionOp<B> for Bool {
14 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
15 tr.register_bool(tensor);
16 }
17}
18
19impl<B: Backend> BasicOps<B> for Bool {
20 type Elem = B::BoolElem;
21
22 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23 if !dtype.is_bool() {
24 panic!("Expected bool data type, got {dtype:?}");
25 }
26 B::bool_empty(shape, device, dtype.into())
27 }
28
29 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30 if !dtype.is_bool() {
31 panic!("Expected bool data type, got {dtype:?}");
32 }
33 B::bool_zeros(shape, device, dtype.into())
34 }
35 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
36 if !dtype.is_bool() {
37 panic!("Expected bool data type, got {dtype:?}");
38 }
39 B::bool_ones(shape, device, dtype.into())
40 }
41
42 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
43 if !dtype.is_bool() {
44 panic!("Expected bool data type, got {dtype:?}");
45 }
46 if fill_value.elem() {
47 B::bool_ones(shape, device, dtype.into())
48 } else {
49 B::bool_zeros(shape, device, dtype.into())
50 }
51 }
52
53 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
54 B::bool_reshape(tensor, shape)
55 }
56
57 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
58 B::bool_transpose(tensor)
59 }
60
61 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
62 B::bool_swap_dims(tensor, dim1, dim2)
63 }
64
65 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
66 B::bool_slice(tensor, slices)
67 }
68
69 fn slice_assign(
70 tensor: Self::Primitive,
71 slices: &[Slice],
72 value: Self::Primitive,
73 ) -> Self::Primitive {
74 B::bool_slice_assign(tensor, slices, value)
75 }
76
77 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
78 B::bool_select(tensor, dim, indices)
79 }
80
81 fn select_assign(
82 tensor: Self::Primitive,
83 dim: usize,
84 indices: IntTensor<B>,
85 values: Self::Primitive,
86 update: IndexingUpdateOp,
87 ) -> Self::Primitive {
88 match update {
89 IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
90 }
91 }
92
93 fn mask_where(
94 tensor: Self::Primitive,
95 mask: B::BoolTensorPrimitive,
96 source: Self::Primitive,
97 ) -> Self::Primitive {
98 B::bool_mask_where(tensor, mask, source)
99 }
100
101 fn mask_fill(
102 tensor: Self::Primitive,
103 mask: B::BoolTensorPrimitive,
104 value: Scalar,
105 ) -> Self::Primitive {
106 B::bool_mask_fill(tensor, mask, value)
107 }
108
109 fn gather(
110 dim: usize,
111 tensor: Self::Primitive,
112 indices: B::IntTensorPrimitive,
113 ) -> Self::Primitive {
114 B::bool_gather(dim, tensor, indices)
115 }
116
117 fn scatter(
118 dim: usize,
119 tensor: Self::Primitive,
120 indices: B::IntTensorPrimitive,
121 values: Self::Primitive,
122 update: IndexingUpdateOp,
123 ) -> Self::Primitive {
124 match update {
125 IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
126 }
127 }
128
129 fn device(tensor: &Self::Primitive) -> Device<B> {
130 B::bool_device(tensor)
131 }
132
133 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
134 B::bool_to_device(tensor, device)
135 }
136
137 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
138 B::bool_into_data(tensor).await
139 }
140
141 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
142 B::bool_from_data(data.convert_dtype(dtype), device)
146 }
147
148 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
149 B::bool_repeat_dim(tensor, dim, times)
150 }
151
152 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
153 B::bool_equal(lhs, rhs)
154 }
155
156 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
157 B::bool_not_equal(lhs, rhs)
158 }
159
160 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
161 B::bool_equal_elem(lhs, rhs)
162 }
163
164 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
165 B::bool_not_equal_elem(lhs, rhs)
166 }
167
168 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
169 B::bool_cat(vectors, dim)
170 }
171
172 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
173 B::bool_any(tensor)
174 }
175
176 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
177 B::bool_any_dim(tensor, dim)
178 }
179
180 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
181 B::bool_all(tensor)
182 }
183
184 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
185 B::bool_all_dim(tensor, dim)
186 }
187
188 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
189 B::bool_permute(tensor, axes)
190 }
191
192 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
193 B::bool_expand(tensor, shape)
194 }
195
196 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
197 B::bool_flip(tensor, axes)
198 }
199
200 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
201 B::bool_unfold(tensor, dim, size, step)
202 }
203}
204
205impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
206 type InnerKind = Bool;
207
208 fn inner(
209 tensor: <Self as TensorKind<B>>::Primitive,
210 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
211 B::bool_inner(tensor)
212 }
213
214 fn from_inner(
215 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
216 ) -> <Self as TensorKind<B>>::Primitive {
217 B::bool_from_inner(inner)
218 }
219}