1use super::TchOps;
2use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_backend::ElementConversion;
4use burn_backend::ExecutionError;
5use burn_backend::tensor::BoolElem;
6use burn_backend::tensor::BoolTensor;
7use burn_backend::tensor::IntTensor;
8use burn_backend::{Backend, Shape, TensorData, TensorMetadata, ops::BoolTensorOps};
9
10impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
11 fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
12 match data.dtype {
13 burn_backend::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
14 _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
15 }
16 }
17
18 fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
19 TchOps::repeat_dim(tensor, dim, times)
20 }
21
22 async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
23 let shape = tensor.shape();
24 let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
25 let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
26 Ok(TensorData::new(values.unwrap(), shape))
27 }
28
29 fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
30 TchOps::to_device(tensor, device)
31 }
32
33 fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
34 TchOps::reshape(tensor, shape)
35 }
36
37 fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
38 tensor.tensor.device().into()
39 }
40
41 fn bool_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
42 let tensor = tch::Tensor::empty(
43 TchShape::from(shape).dims,
44 (tch::Kind::Bool, (*device).into()),
45 );
46
47 TchTensor::new(tensor)
48 }
49
50 fn bool_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
51 let tensor = tch::Tensor::zeros(
52 TchShape::from(shape).dims,
53 (tch::Kind::Bool, (*device).into()),
54 );
55
56 TchTensor::new(tensor)
57 }
58
59 fn bool_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
60 let tensor = tch::Tensor::ones(
61 TchShape::from(shape).dims,
62 (tch::Kind::Bool, (*device).into()),
63 );
64
65 TchTensor::new(tensor)
66 }
67
68 fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
69 TchOps::slice_with_steps(tensor, slices)
70 }
71
72 fn bool_slice_assign(
73 tensor: TchTensor,
74 slices: &[burn_backend::Slice],
75 value: TchTensor,
76 ) -> TchTensor {
77 TchOps::slice_assign(tensor, slices, value)
78 }
79
80 fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
81 TchOps::cat(tensors, dim)
82 }
83
84 fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
85 TchOps::equal(lhs, rhs)
86 }
87
88 fn bool_not(tensor: TchTensor) -> TchTensor {
89 tensor.unary_ops(
90 |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
91 |tensor| tensor.eq(0),
92 )
93 }
94
95 fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
96 TchTensor::binary_ops_tensor(
97 lhs,
98 rhs,
99 |lhs, rhs| lhs.logical_and_(rhs),
100 |lhs, rhs| rhs.logical_and_(lhs),
101 |lhs, rhs| lhs.logical_and(rhs),
102 )
103 }
104
105 fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
106 TchTensor::binary_ops_tensor(
107 lhs,
108 rhs,
109 |lhs, rhs| lhs.logical_or_(rhs),
110 |lhs, rhs| rhs.logical_or_(lhs),
111 |lhs, rhs| lhs.logical_or(rhs),
112 )
113 }
114
115 fn bool_into_int(tensor: TchTensor) -> TchTensor {
116 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
117 TchTensor::new(tensor)
118 }
119
120 fn bool_into_float(tensor: TchTensor) -> TchTensor {
121 let tensor = tensor.tensor.to_kind(E::kind());
122 TchTensor::new(tensor)
123 }
124
125 fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
126 TchOps::swap_dims(tensor, dim1, dim2)
127 }
128
129 fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
130 TchOps::permute(tensor, axes)
131 }
132
133 fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
134 TchOps::flip(tensor, axes)
135 }
136
137 async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
138 TchTensor::new(tensor.tensor.argwhere())
139 }
140
141 fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
142 TchOps::index_select_dim(tensor, dim, indices)
143 }
144
145 fn bool_select_or(
146 tensor: TchTensor,
147 dim: usize,
148 indices: TchTensor,
149 value: TchTensor,
150 ) -> TchTensor {
151 TchOps::select_assign(tensor, dim, indices, value)
152 }
153
154 fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
155 TchOps::expand(tensor, shape)
156 }
157
158 fn bool_unfold(
159 tensor: IntTensor<Self>,
160 dim: usize,
161 size: usize,
162 step: usize,
163 ) -> IntTensor<Self> {
164 TchOps::unfold(tensor, dim, size, step)
165 }
166
167 fn bool_mask_where(
168 tensor: BoolTensor<Self>,
169 mask: BoolTensor<Self>,
170 value: BoolTensor<Self>,
171 ) -> BoolTensor<Self> {
172 TchTensor::binary_ops_tensor(
173 tensor,
174 value,
175 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
176 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
177 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
178 )
179 }
180
181 fn bool_mask_fill(
182 tensor: BoolTensor<Self>,
183 mask: BoolTensor<Self>,
184 value: BoolElem<Self>,
185 ) -> BoolTensor<Self> {
186 tensor.unary_ops(
187 |mut tensor| {
188 tensor
189 .f_masked_fill_(&mask.tensor, value.elem::<i64>())
190 .unwrap()
191 },
192 |tensor| {
193 tensor
194 .f_masked_fill(&mask.tensor, value.elem::<i64>())
195 .unwrap()
196 },
197 )
198 }
199
200 fn bool_gather(
201 dim: usize,
202 tensor: BoolTensor<Self>,
203 indices: IntTensor<Self>,
204 ) -> BoolTensor<Self> {
205 TchOps::gather(dim, tensor, indices)
206 }
207
208 fn bool_scatter_or(
209 dim: usize,
210 tensor: BoolTensor<Self>,
211 indices: IntTensor<Self>,
212 value: BoolTensor<Self>,
213 ) -> BoolTensor<Self> {
214 TchOps::scatter(dim, tensor, indices, value)
215 }
216
217 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
218 TchOps::equal_elem(lhs, rhs.elem::<i64>())
219 }
220}