1use super::TchOps;
2use crate::IntoKind;
3use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
4use burn_backend::BoolStore;
5use burn_backend::ExecutionError;
6use burn_backend::IntDType;
7use burn_backend::Scalar;
8use burn_backend::tensor::BoolTensor;
9use burn_backend::tensor::IntTensor;
10use burn_backend::{BoolDType, FloatDType};
11use burn_backend::{Shape, TensorData, TensorMetadata, ops::BoolTensorOps};
12
13impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
14 fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
15 match data.dtype {
16 burn_backend::DType::Bool(BoolStore::Native) => {
17 TchTensor::from_data::<bool>(data, (*device).into())
18 }
19 _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
20 }
21 }
22
23 fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
24 TchOps::repeat_dim(tensor, dim, times)
25 }
26
27 async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
28 let shape = tensor.shape();
29 let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
30 let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
31 Ok(TensorData::new(values.unwrap(), shape))
32 }
33
34 fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
35 TchOps::to_device(tensor, device)
36 }
37
38 fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
39 TchOps::reshape(tensor, shape)
40 }
41
42 fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
43 tensor.tensor.device().into()
44 }
45
46 fn bool_empty(shape: Shape, device: &LibTorchDevice, _dtype: BoolDType) -> TchTensor {
47 let tensor = tch::Tensor::empty(
48 TchShape::from(shape).dims,
49 (tch::Kind::Bool, (*device).into()),
50 );
51
52 TchTensor::new(tensor)
53 }
54
55 fn bool_zeros(shape: Shape, device: &LibTorchDevice, _dtype: BoolDType) -> TchTensor {
56 let tensor = tch::Tensor::zeros(
57 TchShape::from(shape).dims,
58 (tch::Kind::Bool, (*device).into()),
59 );
60
61 TchTensor::new(tensor)
62 }
63
64 fn bool_ones(shape: Shape, device: &LibTorchDevice, _dtype: BoolDType) -> TchTensor {
65 let tensor = tch::Tensor::ones(
66 TchShape::from(shape).dims,
67 (tch::Kind::Bool, (*device).into()),
68 );
69
70 TchTensor::new(tensor)
71 }
72
73 fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
74 TchOps::slice_with_steps(tensor, slices)
75 }
76
77 fn bool_slice_assign(
78 tensor: TchTensor,
79 slices: &[burn_backend::Slice],
80 value: TchTensor,
81 ) -> TchTensor {
82 TchOps::slice_assign(tensor, slices, value)
83 }
84
85 fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
86 TchOps::cat(tensors, dim)
87 }
88
89 fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
90 TchOps::equal(lhs, rhs)
91 }
92
93 fn bool_not(tensor: TchTensor) -> TchTensor {
94 tensor.unary_ops(
95 |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
96 |tensor| tensor.eq(0),
97 )
98 }
99
100 fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
101 TchTensor::binary_ops_tensor(
102 lhs,
103 rhs,
104 |lhs, rhs| lhs.logical_and_(rhs),
105 |lhs, rhs| rhs.logical_and_(lhs),
106 |lhs, rhs| lhs.logical_and(rhs),
107 )
108 }
109
110 fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
111 TchTensor::binary_ops_tensor(
112 lhs,
113 rhs,
114 |lhs, rhs| lhs.logical_or_(rhs),
115 |lhs, rhs| rhs.logical_or_(lhs),
116 |lhs, rhs| lhs.logical_or(rhs),
117 )
118 }
119
120 fn bool_into_int(tensor: TchTensor, out_dtype: IntDType) -> TchTensor {
121 let tensor = tensor.tensor.to_kind(out_dtype.into_kind());
122 TchTensor::new(tensor)
123 }
124
125 fn bool_into_float(tensor: TchTensor, out_dtype: FloatDType) -> TchTensor {
126 let tensor = tensor.tensor.to_kind(out_dtype.into_kind());
127 TchTensor::new(tensor)
128 }
129
130 fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
131 TchOps::swap_dims(tensor, dim1, dim2)
132 }
133
134 fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
135 TchOps::permute(tensor, axes)
136 }
137
138 fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
139 TchOps::flip(tensor, axes)
140 }
141
142 async fn bool_argwhere(tensor: TchTensor, out_dtype: IntDType) -> TchTensor {
143 TchTensor::new(tensor.tensor.argwhere().to_kind(out_dtype.into_kind()))
144 }
145
146 fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
147 TchOps::index_select_dim(tensor, dim, indices)
148 }
149
150 fn bool_select_or(
151 tensor: TchTensor,
152 dim: usize,
153 indices: TchTensor,
154 value: TchTensor,
155 ) -> TchTensor {
156 TchOps::select_assign(tensor, dim, indices, value)
157 }
158
159 fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
160 TchOps::expand(tensor, shape)
161 }
162
163 fn bool_unfold(
164 tensor: IntTensor<Self>,
165 dim: usize,
166 size: usize,
167 step: usize,
168 ) -> IntTensor<Self> {
169 TchOps::unfold(tensor, dim, size, step)
170 }
171
172 fn bool_mask_where(
173 tensor: BoolTensor<Self>,
174 mask: BoolTensor<Self>,
175 value: BoolTensor<Self>,
176 ) -> BoolTensor<Self> {
177 TchTensor::binary_ops_tensor(
178 tensor,
179 value,
180 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
181 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
182 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
183 )
184 }
185
186 fn bool_mask_fill(
187 tensor: BoolTensor<Self>,
188 mask: BoolTensor<Self>,
189 value: Scalar,
190 ) -> BoolTensor<Self> {
191 tensor.unary_ops(
192 |mut tensor| {
193 tensor
194 .f_masked_fill_(&mask.tensor, value.elem::<i64>())
195 .unwrap()
196 },
197 |tensor| {
198 tensor
199 .f_masked_fill(&mask.tensor, value.elem::<i64>())
200 .unwrap()
201 },
202 )
203 }
204
205 fn bool_gather(
206 dim: usize,
207 tensor: BoolTensor<Self>,
208 indices: IntTensor<Self>,
209 ) -> BoolTensor<Self> {
210 TchOps::gather(dim, tensor, indices)
211 }
212
213 fn bool_scatter_or(
214 dim: usize,
215 tensor: BoolTensor<Self>,
216 indices: IntTensor<Self>,
217 value: BoolTensor<Self>,
218 ) -> BoolTensor<Self> {
219 TchOps::scatter(dim, tensor, indices, value)
220 }
221
222 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
223 TchOps::equal_elem(lhs, rhs.elem::<i64>())
224 }
225}