1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6 element::Element,
7 ops::TransactionPrimitive,
8 tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11impl<B: Backend> BasicOps<B> for Bool {
12 type Elem = B::BoolElem;
13
14 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
15 if dtype != Self::Elem::dtype() {
16 panic!("Expected bool data type, got {dtype:?}");
17 }
18 B::bool_empty(shape, device)
19 }
20
21 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22 if dtype != Self::Elem::dtype() {
23 panic!("Expected bool data type, got {dtype:?}");
24 }
25 B::bool_zeros(shape, device)
26 }
27 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
28 if dtype != Self::Elem::dtype() {
29 panic!("Expected bool data type, got {dtype:?}");
30 }
31 B::bool_ones(shape, device)
32 }
33
34 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
35 if dtype != Self::Elem::dtype() {
36 panic!("Expected bool data type, got {dtype:?}");
37 }
38 if fill_value.elem() {
39 B::bool_ones(shape, device)
40 } else {
41 B::bool_zeros(shape, device)
42 }
43 }
44
45 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
46 tr.register_bool(tensor);
47 }
48
49 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
50 B::bool_reshape(tensor, shape)
51 }
52
53 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
54 B::bool_transpose(tensor)
55 }
56
57 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
58 B::bool_swap_dims(tensor, dim1, dim2)
59 }
60
61 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
62 B::bool_slice(tensor, slices)
63 }
64
65 fn slice_assign(
66 tensor: Self::Primitive,
67 slices: &[Slice],
68 value: Self::Primitive,
69 ) -> Self::Primitive {
70 B::bool_slice_assign(tensor, slices, value)
71 }
72
73 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
74 B::bool_select(tensor, dim, indices)
75 }
76
77 fn select_assign(
78 tensor: Self::Primitive,
79 dim: usize,
80 indices: IntTensor<B>,
81 values: Self::Primitive,
82 update: IndexingUpdateOp,
83 ) -> Self::Primitive {
84 match update {
85 IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
86 }
87 }
88
89 fn mask_where(
90 tensor: Self::Primitive,
91 mask: B::BoolTensorPrimitive,
92 source: Self::Primitive,
93 ) -> Self::Primitive {
94 B::bool_mask_where(tensor, mask, source)
95 }
96
97 fn mask_fill(
98 tensor: Self::Primitive,
99 mask: B::BoolTensorPrimitive,
100 value: Scalar,
101 ) -> Self::Primitive {
102 B::bool_mask_fill(tensor, mask, value.elem())
105 }
106
107 fn gather(
108 dim: usize,
109 tensor: Self::Primitive,
110 indices: B::IntTensorPrimitive,
111 ) -> Self::Primitive {
112 B::bool_gather(dim, tensor, indices)
113 }
114
115 fn scatter(
116 dim: usize,
117 tensor: Self::Primitive,
118 indices: B::IntTensorPrimitive,
119 values: Self::Primitive,
120 update: IndexingUpdateOp,
121 ) -> Self::Primitive {
122 match update {
123 IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
124 }
125 }
126
127 fn device(tensor: &Self::Primitive) -> Device<B> {
128 B::bool_device(tensor)
129 }
130
131 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
132 B::bool_to_device(tensor, device)
133 }
134
135 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
136 B::bool_into_data(tensor).await
137 }
138
139 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
140 B::bool_from_data(data.convert::<B::BoolElem>(), device)
141 }
142
143 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
144 if dtype != B::BoolElem::dtype() {
146 panic!("Expected bool dtype, got {dtype:?}")
147 }
148 B::bool_from_data(data.convert_dtype(dtype), device)
149 }
150
151 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
152 B::bool_repeat_dim(tensor, dim, times)
153 }
154
155 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
156 B::bool_equal(lhs, rhs)
157 }
158
159 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
160 B::bool_not_equal(lhs, rhs)
161 }
162
163 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
164 B::bool_equal_elem(lhs, rhs.elem())
165 }
166
167 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
168 B::bool_not_equal_elem(lhs, rhs.elem())
169 }
170
171 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
172 B::bool_cat(vectors, dim)
173 }
174
175 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
176 B::bool_any(tensor)
177 }
178
179 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
180 B::bool_any_dim(tensor, dim)
181 }
182
183 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
184 B::bool_all(tensor)
185 }
186
187 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
188 B::bool_all_dim(tensor, dim)
189 }
190
191 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
192 B::bool_permute(tensor, axes)
193 }
194
195 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
196 B::bool_expand(tensor, shape)
197 }
198
199 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
200 B::bool_flip(tensor, axes)
201 }
202
203 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
204 B::bool_unfold(tensor, dim, size, step)
205 }
206}
207
208impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
209 type InnerKind = Bool;
210
211 fn inner(
212 tensor: <Self as TensorKind<B>>::Primitive,
213 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
214 B::bool_inner(tensor)
215 }
216
217 fn from_inner(
218 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
219 ) -> <Self as TensorKind<B>>::Primitive {
220 B::bool_from_inner(inner)
221 }
222}