burn_tch/ops/
bool_tensor.rs

1use super::TchOps;
2use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_backend::ElementConversion;
4use burn_backend::ExecutionError;
5use burn_backend::tensor::BoolElem;
6use burn_backend::tensor::BoolTensor;
7use burn_backend::tensor::IntTensor;
8use burn_backend::{Backend, Shape, TensorData, TensorMetadata, ops::BoolTensorOps};
9
10impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
11    fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
12        match data.dtype {
13            burn_backend::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
14            _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
15        }
16    }
17
18    fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
19        TchOps::repeat_dim(tensor, dim, times)
20    }
21
22    async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
23        let shape = tensor.shape();
24        let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
25        let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
26        Ok(TensorData::new(values.unwrap(), shape))
27    }
28
29    fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
30        TchOps::to_device(tensor, device)
31    }
32
33    fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
34        TchOps::reshape(tensor, shape)
35    }
36
37    fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
38        tensor.tensor.device().into()
39    }
40
41    fn bool_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
42        let tensor = tch::Tensor::empty(
43            TchShape::from(shape).dims,
44            (tch::Kind::Bool, (*device).into()),
45        );
46
47        TchTensor::new(tensor)
48    }
49
50    fn bool_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
51        let tensor = tch::Tensor::zeros(
52            TchShape::from(shape).dims,
53            (tch::Kind::Bool, (*device).into()),
54        );
55
56        TchTensor::new(tensor)
57    }
58
59    fn bool_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
60        let tensor = tch::Tensor::ones(
61            TchShape::from(shape).dims,
62            (tch::Kind::Bool, (*device).into()),
63        );
64
65        TchTensor::new(tensor)
66    }
67
68    fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
69        TchOps::slice_with_steps(tensor, slices)
70    }
71
72    fn bool_slice_assign(
73        tensor: TchTensor,
74        slices: &[burn_backend::Slice],
75        value: TchTensor,
76    ) -> TchTensor {
77        TchOps::slice_assign(tensor, slices, value)
78    }
79
80    fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
81        TchOps::cat(tensors, dim)
82    }
83
84    fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
85        TchOps::equal(lhs, rhs)
86    }
87
88    fn bool_not(tensor: TchTensor) -> TchTensor {
89        tensor.unary_ops(
90            |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
91            |tensor| tensor.eq(0),
92        )
93    }
94
95    fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
96        TchTensor::binary_ops_tensor(
97            lhs,
98            rhs,
99            |lhs, rhs| lhs.logical_and_(rhs),
100            |lhs, rhs| rhs.logical_and_(lhs),
101            |lhs, rhs| lhs.logical_and(rhs),
102        )
103    }
104
105    fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
106        TchTensor::binary_ops_tensor(
107            lhs,
108            rhs,
109            |lhs, rhs| lhs.logical_or_(rhs),
110            |lhs, rhs| rhs.logical_or_(lhs),
111            |lhs, rhs| lhs.logical_or(rhs),
112        )
113    }
114
115    fn bool_into_int(tensor: TchTensor) -> TchTensor {
116        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
117        TchTensor::new(tensor)
118    }
119
120    fn bool_into_float(tensor: TchTensor) -> TchTensor {
121        let tensor = tensor.tensor.to_kind(E::kind());
122        TchTensor::new(tensor)
123    }
124
125    fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
126        TchOps::swap_dims(tensor, dim1, dim2)
127    }
128
129    fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
130        TchOps::permute(tensor, axes)
131    }
132
133    fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
134        TchOps::flip(tensor, axes)
135    }
136
137    async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
138        TchTensor::new(tensor.tensor.argwhere())
139    }
140
141    fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
142        TchOps::index_select_dim(tensor, dim, indices)
143    }
144
145    fn bool_select_or(
146        tensor: TchTensor,
147        dim: usize,
148        indices: TchTensor,
149        value: TchTensor,
150    ) -> TchTensor {
151        TchOps::select_assign(tensor, dim, indices, value)
152    }
153
154    fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
155        TchOps::expand(tensor, shape)
156    }
157
158    fn bool_unfold(
159        tensor: IntTensor<Self>,
160        dim: usize,
161        size: usize,
162        step: usize,
163    ) -> IntTensor<Self> {
164        TchOps::unfold(tensor, dim, size, step)
165    }
166
167    fn bool_mask_where(
168        tensor: BoolTensor<Self>,
169        mask: BoolTensor<Self>,
170        value: BoolTensor<Self>,
171    ) -> BoolTensor<Self> {
172        TchTensor::binary_ops_tensor(
173            tensor,
174            value,
175            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
176            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
177            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
178        )
179    }
180
181    fn bool_mask_fill(
182        tensor: BoolTensor<Self>,
183        mask: BoolTensor<Self>,
184        value: BoolElem<Self>,
185    ) -> BoolTensor<Self> {
186        tensor.unary_ops(
187            |mut tensor| {
188                tensor
189                    .f_masked_fill_(&mask.tensor, value.elem::<i64>())
190                    .unwrap()
191            },
192            |tensor| {
193                tensor
194                    .f_masked_fill(&mask.tensor, value.elem::<i64>())
195                    .unwrap()
196            },
197        )
198    }
199
200    fn bool_gather(
201        dim: usize,
202        tensor: BoolTensor<Self>,
203        indices: IntTensor<Self>,
204    ) -> BoolTensor<Self> {
205        TchOps::gather(dim, tensor, indices)
206    }
207
208    fn bool_scatter_or(
209        dim: usize,
210        tensor: BoolTensor<Self>,
211        indices: IntTensor<Self>,
212        value: BoolTensor<Self>,
213    ) -> BoolTensor<Self> {
214        TchOps::scatter(dim, tensor, indices, value)
215    }
216
217    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: BoolElem<Self>) -> BoolTensor<Self> {
218        TchOps::equal_elem(lhs, rhs.elem::<i64>())
219    }
220}