burn_backend/tensor/ops/bool.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, ExecutionError, TensorData,
6    element::{Element, ElementConversion},
7    ops::TransactionPrimitive,
8    tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11impl<B: Backend> BasicOps<B> for Bool {
12    type Elem = B::BoolElem;
13
14    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
15        if dtype != Self::Elem::dtype() {
16            panic!("Expected bool data type, got {dtype:?}");
17        }
18        B::bool_empty(shape, device)
19    }
20
21    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22        if dtype != Self::Elem::dtype() {
23            panic!("Expected bool data type, got {dtype:?}");
24        }
25        B::bool_zeros(shape, device)
26    }
27    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
28        if dtype != Self::Elem::dtype() {
29            panic!("Expected bool data type, got {dtype:?}");
30        }
31        B::bool_ones(shape, device)
32    }
33
34    fn full<E: ElementConversion>(
35        shape: Shape,
36        fill_value: E,
37        device: &Device<B>,
38        dtype: DType,
39    ) -> Self::Primitive {
40        if dtype != Self::Elem::dtype() {
41            panic!("Expected bool data type, got {dtype:?}");
42        }
43        if fill_value.elem() {
44            B::bool_ones(shape, device)
45        } else {
46            B::bool_zeros(shape, device)
47        }
48    }
49
50    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
51        tr.register_bool(tensor);
52    }
53
54    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
55        B::bool_reshape(tensor, shape)
56    }
57
58    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
59        B::bool_transpose(tensor)
60    }
61
62    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
63        B::bool_swap_dims(tensor, dim1, dim2)
64    }
65
66    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
67        B::bool_slice(tensor, slices)
68    }
69
70    fn slice_assign(
71        tensor: Self::Primitive,
72        slices: &[Slice],
73        value: Self::Primitive,
74    ) -> Self::Primitive {
75        B::bool_slice_assign(tensor, slices, value)
76    }
77
78    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
79        B::bool_select(tensor, dim, indices)
80    }
81
82    fn select_assign(
83        tensor: Self::Primitive,
84        dim: usize,
85        indices: IntTensor<B>,
86        values: Self::Primitive,
87        update: IndexingUpdateOp,
88    ) -> Self::Primitive {
89        match update {
90            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),
91        }
92    }
93
94    fn mask_where(
95        tensor: Self::Primitive,
96        mask: B::BoolTensorPrimitive,
97        source: Self::Primitive,
98    ) -> Self::Primitive {
99        B::bool_mask_where(tensor, mask, source)
100    }
101
102    fn mask_fill(
103        tensor: Self::Primitive,
104        mask: B::BoolTensorPrimitive,
105        value: Self::Elem,
106    ) -> Self::Primitive {
107        B::bool_mask_fill(tensor, mask, value)
108    }
109
110    fn gather(
111        dim: usize,
112        tensor: Self::Primitive,
113        indices: B::IntTensorPrimitive,
114    ) -> Self::Primitive {
115        B::bool_gather(dim, tensor, indices)
116    }
117
118    fn scatter(
119        dim: usize,
120        tensor: Self::Primitive,
121        indices: B::IntTensorPrimitive,
122        values: Self::Primitive,
123        update: IndexingUpdateOp,
124    ) -> Self::Primitive {
125        match update {
126            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),
127        }
128    }
129
130    fn device(tensor: &Self::Primitive) -> Device<B> {
131        B::bool_device(tensor)
132    }
133
134    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
135        B::bool_to_device(tensor, device)
136    }
137
138    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
139        B::bool_into_data(tensor).await
140    }
141
142    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
143        B::bool_from_data(data.convert::<B::BoolElem>(), device)
144    }
145
146    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
147        // Backends only use one bool representation dtype
148        if dtype != B::BoolElem::dtype() {
149            panic!("Expected bool dtype, got {dtype:?}")
150        }
151        B::bool_from_data(data.convert_dtype(dtype), device)
152    }
153
154    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
155        B::bool_repeat_dim(tensor, dim, times)
156    }
157
158    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
159        B::bool_equal(lhs, rhs)
160    }
161
162    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
163        B::bool_not_equal(lhs, rhs)
164    }
165
166    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
167        B::bool_equal_elem(lhs, rhs)
168    }
169
170    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
171        B::bool_not_equal_elem(lhs, rhs)
172    }
173
174    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
175        B::bool_cat(vectors, dim)
176    }
177
178    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
179        B::bool_any(tensor)
180    }
181
182    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
183        B::bool_any_dim(tensor, dim)
184    }
185
186    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
187        B::bool_all(tensor)
188    }
189
190    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
191        B::bool_all_dim(tensor, dim)
192    }
193
194    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
195        B::bool_permute(tensor, axes)
196    }
197
198    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
199        B::bool_expand(tensor, shape)
200    }
201
202    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
203        B::bool_flip(tensor, axes)
204    }
205
206    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
207        B::bool_unfold(tensor, dim, size, step)
208    }
209}
210
211impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
212    type InnerKind = Bool;
213
214    fn inner(
215        tensor: <Self as TensorKind<B>>::Primitive,
216    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
217        B::bool_inner(tensor)
218    }
219
220    fn from_inner(
221        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
222    ) -> <Self as TensorKind<B>>::Primitive {
223        B::bool_from_inner(inner)
224    }
225}