burn_tch/ops/
bool_tensor.rs

1use super::TchOps;
2use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_tensor::ops::IntTensor;
4use burn_tensor::{Shape, TensorData, TensorMetadata, backend::Backend, ops::BoolTensorOps};
5
6impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
7    fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
8        match data.dtype {
9            burn_tensor::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
10            _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
11        }
12    }
13
14    fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
15        TchOps::repeat_dim(tensor, dim, times)
16    }
17
18    async fn bool_into_data(tensor: TchTensor) -> TensorData {
19        let shape = tensor.shape();
20        let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
21        let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
22        TensorData::new(values.unwrap(), shape)
23    }
24
25    fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
26        TchOps::to_device(tensor, device)
27    }
28
29    fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
30        TchOps::reshape(tensor, shape)
31    }
32
33    fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
34        tensor.tensor.device().into()
35    }
36
37    fn bool_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
38        let tensor = tch::Tensor::empty(
39            TchShape::from(shape).dims,
40            (tch::Kind::Bool, (*device).into()),
41        );
42
43        TchTensor::new(tensor)
44    }
45
46    fn bool_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
47        let tensor = tch::Tensor::zeros(
48            TchShape::from(shape).dims,
49            (tch::Kind::Bool, (*device).into()),
50        );
51
52        TchTensor::new(tensor)
53    }
54
55    fn bool_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
56        let tensor = tch::Tensor::ones(
57            TchShape::from(shape).dims,
58            (tch::Kind::Bool, (*device).into()),
59        );
60
61        TchTensor::new(tensor)
62    }
63
64    fn bool_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
65        TchOps::slice_with_steps(tensor, slices)
66    }
67
68    fn bool_slice_assign(
69        tensor: TchTensor,
70        slices: &[burn_tensor::Slice],
71        value: TchTensor,
72    ) -> TchTensor {
73        TchOps::slice_assign(tensor, slices, value)
74    }
75
76    fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
77        TchOps::cat(tensors, dim)
78    }
79
80    fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
81        TchOps::equal(lhs, rhs)
82    }
83
84    fn bool_not(tensor: TchTensor) -> TchTensor {
85        tensor.unary_ops(
86            |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
87            |tensor| tensor.eq(0),
88        )
89    }
90
91    fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
92        TchTensor::binary_ops_tensor(
93            lhs,
94            rhs,
95            |lhs, rhs| lhs.logical_and_(rhs),
96            |lhs, rhs| rhs.logical_and_(lhs),
97            |lhs, rhs| lhs.logical_and(rhs),
98        )
99    }
100
101    fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
102        TchTensor::binary_ops_tensor(
103            lhs,
104            rhs,
105            |lhs, rhs| lhs.logical_or_(rhs),
106            |lhs, rhs| rhs.logical_or_(lhs),
107            |lhs, rhs| lhs.logical_or(rhs),
108        )
109    }
110
111    fn bool_into_int(tensor: TchTensor) -> TchTensor {
112        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
113        TchTensor::new(tensor)
114    }
115
116    fn bool_into_float(tensor: TchTensor) -> TchTensor {
117        let tensor = tensor.tensor.to_kind(E::KIND);
118        TchTensor::new(tensor)
119    }
120
121    fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
122        TchOps::swap_dims(tensor, dim1, dim2)
123    }
124
125    fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
126        TchOps::permute(tensor, axes)
127    }
128
129    fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
130        TchOps::flip(tensor, axes)
131    }
132
133    async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
134        TchTensor::new(tensor.tensor.argwhere())
135    }
136
137    fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
138        TchOps::index_select_dim(tensor, dim, indices)
139    }
140
141    fn bool_select_assign(
142        tensor: TchTensor,
143        dim: usize,
144        indices: TchTensor,
145        value: TchTensor,
146    ) -> TchTensor {
147        TchOps::select_assign(tensor, dim, indices, value)
148    }
149
150    fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
151        TchOps::expand(tensor, shape)
152    }
153
154    fn bool_unfold(
155        tensor: IntTensor<Self>,
156        dim: usize,
157        size: usize,
158        step: usize,
159    ) -> IntTensor<Self> {
160        TchOps::unfold(tensor, dim, size, step)
161    }
162}