1use super::TchOps;
2use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_tensor::ops::IntTensor;
4use burn_tensor::{Shape, TensorData, TensorMetadata, backend::Backend, ops::BoolTensorOps};
5
6impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
7 fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
8 match data.dtype {
9 burn_tensor::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
10 _ => unimplemented!("Unsupported dtype for `bool_from_data`"),
11 }
12 }
13
14 fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
15 TchOps::repeat_dim(tensor, dim, times)
16 }
17
18 async fn bool_into_data(tensor: TchTensor) -> TensorData {
19 let shape = tensor.shape();
20 let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
21 let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
22 TensorData::new(values.unwrap(), shape)
23 }
24
25 fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
26 TchOps::to_device(tensor, device)
27 }
28
29 fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
30 TchOps::reshape(tensor, shape)
31 }
32
33 fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
34 tensor.tensor.device().into()
35 }
36
37 fn bool_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
38 let tensor = tch::Tensor::empty(
39 TchShape::from(shape).dims,
40 (tch::Kind::Bool, (*device).into()),
41 );
42
43 TchTensor::new(tensor)
44 }
45
46 fn bool_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
47 let tensor = tch::Tensor::zeros(
48 TchShape::from(shape).dims,
49 (tch::Kind::Bool, (*device).into()),
50 );
51
52 TchTensor::new(tensor)
53 }
54
55 fn bool_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
56 let tensor = tch::Tensor::ones(
57 TchShape::from(shape).dims,
58 (tch::Kind::Bool, (*device).into()),
59 );
60
61 TchTensor::new(tensor)
62 }
63
64 fn bool_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
65 TchOps::slice_with_steps(tensor, slices)
66 }
67
68 fn bool_slice_assign(
69 tensor: TchTensor,
70 slices: &[burn_tensor::Slice],
71 value: TchTensor,
72 ) -> TchTensor {
73 TchOps::slice_assign(tensor, slices, value)
74 }
75
76 fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
77 TchOps::cat(tensors, dim)
78 }
79
80 fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
81 TchOps::equal(lhs, rhs)
82 }
83
84 fn bool_not(tensor: TchTensor) -> TchTensor {
85 tensor.unary_ops(
86 |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
87 |tensor| tensor.eq(0),
88 )
89 }
90
91 fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
92 TchTensor::binary_ops_tensor(
93 lhs,
94 rhs,
95 |lhs, rhs| lhs.logical_and_(rhs),
96 |lhs, rhs| rhs.logical_and_(lhs),
97 |lhs, rhs| lhs.logical_and(rhs),
98 )
99 }
100
101 fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
102 TchTensor::binary_ops_tensor(
103 lhs,
104 rhs,
105 |lhs, rhs| lhs.logical_or_(rhs),
106 |lhs, rhs| rhs.logical_or_(lhs),
107 |lhs, rhs| lhs.logical_or(rhs),
108 )
109 }
110
111 fn bool_into_int(tensor: TchTensor) -> TchTensor {
112 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
113 TchTensor::new(tensor)
114 }
115
116 fn bool_into_float(tensor: TchTensor) -> TchTensor {
117 let tensor = tensor.tensor.to_kind(E::KIND);
118 TchTensor::new(tensor)
119 }
120
121 fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
122 TchOps::swap_dims(tensor, dim1, dim2)
123 }
124
125 fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
126 TchOps::permute(tensor, axes)
127 }
128
129 fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
130 TchOps::flip(tensor, axes)
131 }
132
133 async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
134 TchTensor::new(tensor.tensor.argwhere())
135 }
136
137 fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
138 TchOps::index_select_dim(tensor, dim, indices)
139 }
140
141 fn bool_select_assign(
142 tensor: TchTensor,
143 dim: usize,
144 indices: TchTensor,
145 value: TchTensor,
146 ) -> TchTensor {
147 TchOps::select_assign(tensor, dim, indices, value)
148 }
149
150 fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
151 TchOps::expand(tensor, shape)
152 }
153
154 fn bool_unfold(
155 tensor: IntTensor<Self>,
156 dim: usize,
157 size: usize,
158 step: usize,
159 ) -> IntTensor<Self> {
160 TchOps::unfold(tensor, dim, size, step)
161 }
162}