1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, ExecutionError, TensorData,
6 element::{Element, ElementConversion},
7 ops::TransactionPrimitive,
8 tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11impl<B: Backend> BasicOps<B> for Bool {
12 type Elem = B::BoolElem;
13
14 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
15 if dtype != Self::Elem::dtype() {
16 panic!("Expected bool data type, got {dtype:?}");
17 }
18 B::bool_empty(shape, device)
19 }
20
21 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22 if dtype != Self::Elem::dtype() {
23 panic!("Expected bool data type, got {dtype:?}");
24 }
25 B::bool_zeros(shape, device)
26 }
27 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
28 if dtype != Self::Elem::dtype() {
29 panic!("Expected bool data type, got {dtype:?}");
30 }
31 B::bool_ones(shape, device)
32 }
33
34 fn full<E: ElementConversion>(
35 shape: Shape,
36 fill_value: E,
37 device: &Device<B>,
38 dtype: DType,
39 ) -> Self::Primitive {
40 if dtype != Self::Elem::dtype() {
41 panic!("Expected bool data type, got {dtype:?}");
42 }
43 if fill_value.elem() {
44 B::bool_ones(shape, device)
45 } else {
46 B::bool_zeros(shape, device)
47 }
48 }
49
50 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
51 tr.register_bool(tensor);
52 }
53
54 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
55 B::bool_reshape(tensor, shape)
56 }
57
58 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
59 B::bool_transpose(tensor)
60 }
61
62 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
63 B::bool_swap_dims(tensor, dim1, dim2)
64 }
65
66 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
67 B::bool_slice(tensor, slices)
68 }
69
70 fn slice_assign(
71 tensor: Self::Primitive,
72 slices: &[Slice],
73 value: Self::Primitive,
74 ) -> Self::Primitive {
75 B::bool_slice_assign(tensor, slices, value)
76 }
77
78 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
79 B::bool_select(tensor, dim, indices)
80 }
81
82 fn select_assign(
83 tensor: Self::Primitive,
84 dim: usize,
85 indices: IntTensor<B>,
86 values: Self::Primitive,
87 update: IndexingUpdateOp,
88 ) -> Self::Primitive {
89 match update {
90 IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
91 }
92 }
93
94 fn mask_where(
95 tensor: Self::Primitive,
96 mask: B::BoolTensorPrimitive,
97 source: Self::Primitive,
98 ) -> Self::Primitive {
99 B::bool_mask_where(tensor, mask, source)
100 }
101
102 fn mask_fill(
103 tensor: Self::Primitive,
104 mask: B::BoolTensorPrimitive,
105 value: Self::Elem,
106 ) -> Self::Primitive {
107 B::bool_mask_fill(tensor, mask, value)
108 }
109
110 fn gather(
111 dim: usize,
112 tensor: Self::Primitive,
113 indices: B::IntTensorPrimitive,
114 ) -> Self::Primitive {
115 B::bool_gather(dim, tensor, indices)
116 }
117
118 fn scatter(
119 dim: usize,
120 tensor: Self::Primitive,
121 indices: B::IntTensorPrimitive,
122 values: Self::Primitive,
123 update: IndexingUpdateOp,
124 ) -> Self::Primitive {
125 match update {
126 IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
127 }
128 }
129
130 fn device(tensor: &Self::Primitive) -> Device<B> {
131 B::bool_device(tensor)
132 }
133
134 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
135 B::bool_to_device(tensor, device)
136 }
137
138 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
139 B::bool_into_data(tensor).await
140 }
141
142 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
143 B::bool_from_data(data.convert::<B::BoolElem>(), device)
144 }
145
146 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
147 if dtype != B::BoolElem::dtype() {
149 panic!("Expected bool dtype, got {dtype:?}")
150 }
151 B::bool_from_data(data.convert_dtype(dtype), device)
152 }
153
154 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
155 B::bool_repeat_dim(tensor, dim, times)
156 }
157
158 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
159 B::bool_equal(lhs, rhs)
160 }
161
162 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
163 B::bool_not_equal(lhs, rhs)
164 }
165
166 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
167 B::bool_equal_elem(lhs, rhs)
168 }
169
170 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
171 B::bool_not_equal_elem(lhs, rhs)
172 }
173
174 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
175 B::bool_cat(vectors, dim)
176 }
177
178 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
179 B::bool_any(tensor)
180 }
181
182 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
183 B::bool_any_dim(tensor, dim)
184 }
185
186 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
187 B::bool_all(tensor)
188 }
189
190 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
191 B::bool_all_dim(tensor, dim)
192 }
193
194 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
195 B::bool_permute(tensor, axes)
196 }
197
198 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
199 B::bool_expand(tensor, shape)
200 }
201
202 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
203 B::bool_flip(tensor, axes)
204 }
205
206 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
207 B::bool_unfold(tensor, dim, size, step)
208 }
209}
210
211impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
212 type InnerKind = Bool;
213
214 fn inner(
215 tensor: <Self as TensorKind<B>>::Primitive,
216 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
217 B::bool_inner(tensor)
218 }
219
220 fn from_inner(
221 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
222 ) -> <Self as TensorKind<B>>::Primitive {
223 B::bool_from_inner(inner)
224 }
225}