burn_tch/ops/
bool_tensor.rs1use super::TchOps;
2use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
3use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData, TensorMetadata};
4use std::ops::Range;
5
6impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
7 fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
8 TchTensor::from_data::<bool>(data, (*device).into())
9 }
10
11 fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
12 TchOps::repeat_dim(tensor, dim, times)
13 }
14
15 async fn bool_into_data(tensor: TchTensor) -> TensorData {
16 let shape = tensor.shape();
17 let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
18 let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
19 TensorData::new(values.unwrap(), shape)
20 }
21
22 fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
23 TchOps::to_device(tensor, device)
24 }
25
26 fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
27 TchOps::reshape(tensor, shape)
28 }
29
30 fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
31 tensor.tensor.device().into()
32 }
33
34 fn bool_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
35 let tensor = tch::Tensor::empty(
36 TchShape::from(shape).dims,
37 (tch::Kind::Bool, (*device).into()),
38 );
39
40 TchTensor::new(tensor)
41 }
42
43 fn bool_slice(tensor: TchTensor, ranges: &[Range<usize>]) -> TchTensor {
44 TchOps::slice(tensor, ranges)
45 }
46
47 fn bool_slice_assign(
48 tensor: TchTensor,
49 ranges: &[Range<usize>],
50 value: TchTensor,
51 ) -> TchTensor {
52 TchOps::slice_assign(tensor, ranges, value)
53 }
54
55 fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
56 TchOps::cat(tensors, dim)
57 }
58
59 fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
60 TchOps::equal(lhs, rhs)
61 }
62
63 fn bool_not(tensor: TchTensor) -> TchTensor {
64 tensor.unary_ops(
65 |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
66 |tensor| tensor.eq(0),
67 )
68 }
69
70 fn bool_into_int(tensor: TchTensor) -> TchTensor {
71 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
72 TchTensor::new(tensor)
73 }
74
75 fn bool_into_float(tensor: TchTensor) -> TchTensor {
76 let tensor = tensor.tensor.to_kind(E::KIND);
77 TchTensor::new(tensor)
78 }
79
80 fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
81 TchOps::swap_dims(tensor, dim1, dim2)
82 }
83
84 fn bool_narrow(tensor: TchTensor, dim: usize, start: usize, length: usize) -> TchTensor {
85 TchOps::narrow(tensor, dim, start, length)
86 }
87
88 fn bool_chunk(tensor: TchTensor, chunks: usize, dim: usize) -> Vec<TchTensor> {
89 TchOps::chunk(tensor, chunks, dim)
90 }
91
92 fn bool_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
93 TchOps::split(tensor, split_size, dim)
94 }
95
96 fn bool_split_with_sizes(
97 tensor: TchTensor,
98 split_sizes: Vec<usize>,
99 dim: usize,
100 ) -> Vec<TchTensor> {
101 TchOps::split_with_sizes(tensor, split_sizes, dim)
102 }
103
104 fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
105 TchOps::permute(tensor, axes)
106 }
107
108 fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
109 TchOps::flip(tensor, axes)
110 }
111
112 async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
113 TchTensor::new(tensor.tensor.argwhere())
114 }
115
116 async fn bool_nonzero(tensor: TchTensor) -> Vec<TchTensor> {
117 tensor
118 .tensor
119 .nonzero_numpy()
120 .into_iter()
121 .filter_map(|t| if t.numel() > 0 { Some(t) } else { None })
123 .map(TchTensor::new)
124 .collect()
125 }
126
127 fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
128 TchOps::expand(tensor, shape)
129 }
130}