Skip to main content

burn_backend/tensor/ops/
bool.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,
6    ops::TransactionPrimitive,
7    tensor::{
8        BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind,
9        TransactionOp,
10    },
11};
12
13impl<B: Backend> TransactionOp<B> for Bool {
14    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
15        tr.register_bool(tensor);
16    }
17}
18
19impl<B: Backend> BasicOps<B> for Bool {
20    type Elem = B::BoolElem;
21
22    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23        if !dtype.is_bool() {
24            panic!("Expected bool data type, got {dtype:?}");
25        }
26        B::bool_empty(shape, device, dtype.into())
27    }
28
29    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30        if !dtype.is_bool() {
31            panic!("Expected bool data type, got {dtype:?}");
32        }
33        B::bool_zeros(shape, device, dtype.into())
34    }
35    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
36        if !dtype.is_bool() {
37            panic!("Expected bool data type, got {dtype:?}");
38        }
39        B::bool_ones(shape, device, dtype.into())
40    }
41
42    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
43        if !dtype.is_bool() {
44            panic!("Expected bool data type, got {dtype:?}");
45        }
46        if fill_value.elem() {
47            B::bool_ones(shape, device, dtype.into())
48        } else {
49            B::bool_zeros(shape, device, dtype.into())
50        }
51    }
52
53    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
54        B::bool_reshape(tensor, shape)
55    }
56
57    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
58        B::bool_transpose(tensor)
59    }
60
61    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
62        B::bool_swap_dims(tensor, dim1, dim2)
63    }
64
65    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
66        B::bool_slice(tensor, slices)
67    }
68
69    fn slice_assign(
70        tensor: Self::Primitive,
71        slices: &[Slice],
72        value: Self::Primitive,
73    ) -> Self::Primitive {
74        B::bool_slice_assign(tensor, slices, value)
75    }
76
77    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
78        B::bool_select(tensor, dim, indices)
79    }
80
81    fn select_assign(
82        tensor: Self::Primitive,
83        dim: usize,
84        indices: IntTensor<B>,
85        values: Self::Primitive,
86        update: IndexingUpdateOp,
87    ) -> Self::Primitive {
88        match update {
89            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
90            _ => unimplemented!(),
91        }
92    }
93
94    fn mask_where(
95        tensor: Self::Primitive,
96        mask: B::BoolTensorPrimitive,
97        source: Self::Primitive,
98    ) -> Self::Primitive {
99        B::bool_mask_where(tensor, mask, source)
100    }
101
102    fn mask_fill(
103        tensor: Self::Primitive,
104        mask: B::BoolTensorPrimitive,
105        value: Scalar,
106    ) -> Self::Primitive {
107        B::bool_mask_fill(tensor, mask, value)
108    }
109
110    fn gather(
111        dim: usize,
112        tensor: Self::Primitive,
113        indices: B::IntTensorPrimitive,
114    ) -> Self::Primitive {
115        B::bool_gather(dim, tensor, indices)
116    }
117
118    fn scatter(
119        dim: usize,
120        tensor: Self::Primitive,
121        indices: B::IntTensorPrimitive,
122        values: Self::Primitive,
123        update: IndexingUpdateOp,
124    ) -> Self::Primitive {
125        match update {
126            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
127            _ => unimplemented!(),
128        }
129    }
130
131    fn scatter_nd(
132        _data: Self::Primitive,
133        _indices: IntTensor<B>,
134        _values: Self::Primitive,
135        _reduction: IndexingUpdateOp,
136    ) -> Self::Primitive {
137        panic!("scatter_nd is not supported for bool tensors")
138    }
139
140    fn gather_nd(_data: Self::Primitive, _indices: IntTensor<B>) -> Self::Primitive {
141        panic!("gather_nd is not supported for bool tensors")
142    }
143
144    fn device(tensor: &Self::Primitive) -> Device<B> {
145        B::bool_device(tensor)
146    }
147
148    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
149        B::bool_to_device(tensor, device)
150    }
151
152    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
153        B::bool_into_data(tensor).await
154    }
155
156    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
157        // Bool tensors have exactly one representation per backend, so the
158        // requested dtype should have been resolved to the default bool dtype with the
159        // tensor creation options.
160        B::bool_from_data(data.convert_dtype(dtype), device)
161    }
162
163    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
164        B::bool_repeat_dim(tensor, dim, times)
165    }
166
167    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
168        B::bool_equal(lhs, rhs)
169    }
170
171    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
172        B::bool_not_equal(lhs, rhs)
173    }
174
175    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
176        B::bool_equal_elem(lhs, rhs)
177    }
178
179    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
180        B::bool_not_equal_elem(lhs, rhs)
181    }
182
183    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
184        B::bool_cat(vectors, dim)
185    }
186
187    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
188        B::bool_any(tensor)
189    }
190
191    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
192        B::bool_any_dim(tensor, dim)
193    }
194
195    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
196        B::bool_all(tensor)
197    }
198
199    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
200        B::bool_all_dim(tensor, dim)
201    }
202
203    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
204        B::bool_permute(tensor, axes)
205    }
206
207    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
208        B::bool_expand(tensor, shape)
209    }
210
211    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
212        B::bool_flip(tensor, axes)
213    }
214
215    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
216        B::bool_unfold(tensor, dim, size, step)
217    }
218}
219
220impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
221    type InnerKind = Bool;
222
223    fn inner(
224        tensor: <Self as TensorKind<B>>::Primitive,
225    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
226        B::bool_inner(tensor)
227    }
228
229    fn from_inner(
230        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
231    ) -> <Self as TensorKind<B>>::Primitive {
232        B::bool_from_inner(inner)
233    }
234}