Skip to main content

burn_backend/tensor/ops/
bool.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6    element::Element,
7    ops::TransactionPrimitive,
8    tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11impl<B: Backend> BasicOps<B> for Bool {
12    type Elem = B::BoolElem;
13
14    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
15        if dtype != Self::Elem::dtype() {
16            panic!("Expected bool data type, got {dtype:?}");
17        }
18        B::bool_empty(shape, device)
19    }
20
21    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22        if dtype != Self::Elem::dtype() {
23            panic!("Expected bool data type, got {dtype:?}");
24        }
25        B::bool_zeros(shape, device)
26    }
27    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
28        if dtype != Self::Elem::dtype() {
29            panic!("Expected bool data type, got {dtype:?}");
30        }
31        B::bool_ones(shape, device)
32    }
33
34    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
35        if dtype != Self::Elem::dtype() {
36            panic!("Expected bool data type, got {dtype:?}");
37        }
38        if fill_value.elem() {
39            B::bool_ones(shape, device)
40        } else {
41            B::bool_zeros(shape, device)
42        }
43    }
44
45    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
46        tr.register_bool(tensor);
47    }
48
49    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
50        B::bool_reshape(tensor, shape)
51    }
52
53    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
54        B::bool_transpose(tensor)
55    }
56
57    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
58        B::bool_swap_dims(tensor, dim1, dim2)
59    }
60
61    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
62        B::bool_slice(tensor, slices)
63    }
64
65    fn slice_assign(
66        tensor: Self::Primitive,
67        slices: &[Slice],
68        value: Self::Primitive,
69    ) -> Self::Primitive {
70        B::bool_slice_assign(tensor, slices, value)
71    }
72
73    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
74        B::bool_select(tensor, dim, indices)
75    }
76
77    fn select_assign(
78        tensor: Self::Primitive,
79        dim: usize,
80        indices: IntTensor<B>,
81        values: Self::Primitive,
82        update: IndexingUpdateOp,
83    ) -> Self::Primitive {
84        match update {
85            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
86        }
87    }
88
89    fn mask_where(
90        tensor: Self::Primitive,
91        mask: B::BoolTensorPrimitive,
92        source: Self::Primitive,
93    ) -> Self::Primitive {
94        B::bool_mask_where(tensor, mask, source)
95    }
96
97    fn mask_fill(
98        tensor: Self::Primitive,
99        mask: B::BoolTensorPrimitive,
100        value: Scalar,
101    ) -> Self::Primitive {
102        // NOTE: we currently only support one one element type for bool, so the contract reflects that
103        // via `BoolElem<B>` instead of `Scalar` in the trait methods.
104        B::bool_mask_fill(tensor, mask, value.elem())
105    }
106
107    fn gather(
108        dim: usize,
109        tensor: Self::Primitive,
110        indices: B::IntTensorPrimitive,
111    ) -> Self::Primitive {
112        B::bool_gather(dim, tensor, indices)
113    }
114
115    fn scatter(
116        dim: usize,
117        tensor: Self::Primitive,
118        indices: B::IntTensorPrimitive,
119        values: Self::Primitive,
120        update: IndexingUpdateOp,
121    ) -> Self::Primitive {
122        match update {
123            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
124        }
125    }
126
127    fn device(tensor: &Self::Primitive) -> Device<B> {
128        B::bool_device(tensor)
129    }
130
131    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
132        B::bool_to_device(tensor, device)
133    }
134
135    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
136        B::bool_into_data(tensor).await
137    }
138
139    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
140        B::bool_from_data(data.convert::<B::BoolElem>(), device)
141    }
142
143    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
144        // Backends only use one bool representation dtype
145        if dtype != B::BoolElem::dtype() {
146            panic!("Expected bool dtype, got {dtype:?}")
147        }
148        B::bool_from_data(data.convert_dtype(dtype), device)
149    }
150
151    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
152        B::bool_repeat_dim(tensor, dim, times)
153    }
154
155    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
156        B::bool_equal(lhs, rhs)
157    }
158
159    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
160        B::bool_not_equal(lhs, rhs)
161    }
162
163    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
164        B::bool_equal_elem(lhs, rhs.elem())
165    }
166
167    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
168        B::bool_not_equal_elem(lhs, rhs.elem())
169    }
170
171    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
172        B::bool_cat(vectors, dim)
173    }
174
175    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
176        B::bool_any(tensor)
177    }
178
179    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
180        B::bool_any_dim(tensor, dim)
181    }
182
183    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
184        B::bool_all(tensor)
185    }
186
187    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
188        B::bool_all_dim(tensor, dim)
189    }
190
191    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
192        B::bool_permute(tensor, axes)
193    }
194
195    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
196        B::bool_expand(tensor, shape)
197    }
198
199    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
200        B::bool_flip(tensor, axes)
201    }
202
203    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
204        B::bool_unfold(tensor, dim, size, step)
205    }
206}
207
208impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
209    type InnerKind = Bool;
210
211    fn inner(
212        tensor: <Self as TensorKind<B>>::Primitive,
213    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
214        B::bool_inner(tensor)
215    }
216
217    fn from_inner(
218        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
219    ) -> <Self as TensorKind<B>>::Primitive {
220        B::bool_from_inner(inner)
221    }
222}