Skip to main content

burn_backend/tensor/ops/
bool.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6    ops::TransactionPrimitive,
7    tensor::{
8        BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind,
9        TransactionOp,
10    },
11};
12
13impl<B: Backend> TransactionOp<B> for Bool {
14    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
15        tr.register_bool(tensor);
16    }
17}
18
19impl<B: Backend> BasicOps<B> for Bool {
20    type Elem = B::BoolElem;
21
22    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23        if !dtype.is_bool() {
24            panic!("Expected bool data type, got {dtype:?}");
25        }
26        B::bool_empty(shape, device, dtype.into())
27    }
28
29    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30        if !dtype.is_bool() {
31            panic!("Expected bool data type, got {dtype:?}");
32        }
33        B::bool_zeros(shape, device, dtype.into())
34    }
35    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
36        if !dtype.is_bool() {
37            panic!("Expected bool data type, got {dtype:?}");
38        }
39        B::bool_ones(shape, device, dtype.into())
40    }
41
42    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
43        if !dtype.is_bool() {
44            panic!("Expected bool data type, got {dtype:?}");
45        }
46        if fill_value.elem() {
47            B::bool_ones(shape, device, dtype.into())
48        } else {
49            B::bool_zeros(shape, device, dtype.into())
50        }
51    }
52
53    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
54        B::bool_reshape(tensor, shape)
55    }
56
57    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
58        B::bool_transpose(tensor)
59    }
60
61    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
62        B::bool_swap_dims(tensor, dim1, dim2)
63    }
64
65    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
66        B::bool_slice(tensor, slices)
67    }
68
69    fn slice_assign(
70        tensor: Self::Primitive,
71        slices: &[Slice],
72        value: Self::Primitive,
73    ) -> Self::Primitive {
74        B::bool_slice_assign(tensor, slices, value)
75    }
76
77    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
78        B::bool_select(tensor, dim, indices)
79    }
80
81    fn select_assign(
82        tensor: Self::Primitive,
83        dim: usize,
84        indices: IntTensor<B>,
85        values: Self::Primitive,
86        update: IndexingUpdateOp,
87    ) -> Self::Primitive {
88        match update {
89            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
90        }
91    }
92
93    fn mask_where(
94        tensor: Self::Primitive,
95        mask: B::BoolTensorPrimitive,
96        source: Self::Primitive,
97    ) -> Self::Primitive {
98        B::bool_mask_where(tensor, mask, source)
99    }
100
101    fn mask_fill(
102        tensor: Self::Primitive,
103        mask: B::BoolTensorPrimitive,
104        value: Scalar,
105    ) -> Self::Primitive {
106        B::bool_mask_fill(tensor, mask, value)
107    }
108
109    fn gather(
110        dim: usize,
111        tensor: Self::Primitive,
112        indices: B::IntTensorPrimitive,
113    ) -> Self::Primitive {
114        B::bool_gather(dim, tensor, indices)
115    }
116
117    fn scatter(
118        dim: usize,
119        tensor: Self::Primitive,
120        indices: B::IntTensorPrimitive,
121        values: Self::Primitive,
122        update: IndexingUpdateOp,
123    ) -> Self::Primitive {
124        match update {
125            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
126        }
127    }
128
129    fn device(tensor: &Self::Primitive) -> Device<B> {
130        B::bool_device(tensor)
131    }
132
133    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
134        B::bool_to_device(tensor, device)
135    }
136
137    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
138        B::bool_into_data(tensor).await
139    }
140
141    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
142        // Bool tensors have exactly one representation per backend, so the
143        // requested dtype should have been resolved to the default bool dtype with the
144        // tensor creation options.
145        B::bool_from_data(data.convert_dtype(dtype), device)
146    }
147
148    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
149        B::bool_repeat_dim(tensor, dim, times)
150    }
151
152    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
153        B::bool_equal(lhs, rhs)
154    }
155
156    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
157        B::bool_not_equal(lhs, rhs)
158    }
159
160    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
161        B::bool_equal_elem(lhs, rhs)
162    }
163
164    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
165        B::bool_not_equal_elem(lhs, rhs)
166    }
167
168    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
169        B::bool_cat(vectors, dim)
170    }
171
172    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
173        B::bool_any(tensor)
174    }
175
176    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
177        B::bool_any_dim(tensor, dim)
178    }
179
180    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
181        B::bool_all(tensor)
182    }
183
184    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
185        B::bool_all_dim(tensor, dim)
186    }
187
188    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
189        B::bool_permute(tensor, axes)
190    }
191
192    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
193        B::bool_expand(tensor, shape)
194    }
195
196    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
197        B::bool_flip(tensor, axes)
198    }
199
200    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
201        B::bool_unfold(tensor, dim, size, step)
202    }
203}
204
205impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
206    type InnerKind = Bool;
207
208    fn inner(
209        tensor: <Self as TensorKind<B>>::Primitive,
210    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
211        B::bool_inner(tensor)
212    }
213
214    fn from_inner(
215        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
216    ) -> <Self as TensorKind<B>>::Primitive {
217        B::bool_from_inner(inner)
218    }
219}