1use crate::{LibTorchDevice, TchElement};
2use burn_tensor::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
3use libc::c_void;
4use std::sync::Arc;
5
/// Reference-counted raw pointer to a tensor's underlying data buffer.
///
/// The `Arc` strong count is used to detect aliasing between tensors that
/// share the same buffer (see `Storage::can_mut`).
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;
11
/// Tracks how a tensor's memory is shared, so that in-place mutation can be
/// gated on uniqueness of the references (see `Storage::can_mut`).
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// The tensor is a view into a larger buffer (created via `TchTensor::partial`).
    View {
        /// Pointer to the start of the whole underlying buffer.
        buffer_ref: StorageRef,
        /// Pointer to the start of this view's data within the buffer.
        view_ref: StorageRef,
    },
    /// The tensor references its whole buffer.
    Owned {
        /// Pointer to the start of the buffer.
        buffer_ref: StorageRef,
    },
}
28
29impl Storage {
30 pub fn can_mut(&self) -> bool {
32 match self {
33 Storage::View {
34 buffer_ref: start_ref,
35 view_ref,
36 } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
37 Storage::Owned {
38 buffer_ref: start_ref,
39 } => Arc::strong_count(start_ref) == 1,
40 }
41 }
42
43 pub fn buffer_ref(&self) -> &StorageRef {
45 match self {
46 Storage::View {
47 buffer_ref: start_ref,
48 view_ref: _,
49 } => start_ref,
50 Storage::Owned {
51 buffer_ref: start_ref,
52 } => start_ref,
53 }
54 }
55}
56
/// A tensor backed by LibTorch (`tch`), paired with storage bookkeeping used
/// to decide whether in-place operations are safe.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// The underlying libtorch tensor.
    pub tensor: tch::Tensor,

    /// Shared-storage metadata for aliasing checks (see `Storage`).
    pub storage: Storage,
}
66
67impl TensorMetadata for TchTensor {
68 fn dtype(&self) -> DType {
69 match self.tensor.kind() {
70 tch::Kind::Uint8 => DType::U8,
71 tch::Kind::Int8 => DType::I8,
72 tch::Kind::Int16 => DType::I16,
73 tch::Kind::Int => DType::I32,
74 tch::Kind::Int64 => DType::I64,
75 tch::Kind::Half => DType::F16,
76 tch::Kind::Float => DType::F32,
77 tch::Kind::Double => DType::F64,
78 tch::Kind::Bool => DType::Bool,
79 tch::Kind::BFloat16 => DType::BF16,
80 _ => unimplemented!(),
82 }
83 }
84
85 fn shape(&self) -> Shape {
86 Shape::from(self.tensor.size())
87 }
88
89 fn rank(&self) -> usize {
90 self.tensor.dim()
91 }
92}
93
impl burn_tensor::quantization::QTensorPrimitive for TchTensor {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        // The libtorch backend has no quantized-tensor support.
        unimplemented!("Quantization is not supported")
    }
}
99
/// Conversion from a burn dtype into the corresponding libtorch `tch::Kind`.
pub(crate) trait IntoKind {
    /// Return the libtorch element kind matching this dtype.
    fn into_kind(self) -> tch::Kind;
}
103
104impl IntoKind for FloatDType {
105 fn into_kind(self) -> tch::Kind {
106 match self {
107 FloatDType::F64 => tch::Kind::Double,
108 FloatDType::F32 => tch::Kind::Float,
109 FloatDType::Flex32 => tch::Kind::Float,
110 FloatDType::F16 => tch::Kind::Half,
111 FloatDType::BF16 => tch::Kind::BFloat16,
112 }
113 }
114}
115
116impl IntoKind for IntDType {
117 fn into_kind(self) -> tch::Kind {
118 match self {
119 IntDType::I64 => tch::Kind::Int64,
120 IntDType::I32 => tch::Kind::Int,
121 IntDType::I16 => tch::Kind::Int16,
122 IntDType::I8 => tch::Kind::Int8,
123 IntDType::U64 => tch::Kind::Uint8,
124 other => panic!("Unsupported dtype {other:?}"),
125 }
126 }
127}
128
129impl TchTensor {
130 pub fn new(tensor: tch::Tensor) -> Self {
136 #[allow(clippy::arc_with_non_send_sync)]
137 let storage = Storage::Owned {
138 buffer_ref: Arc::new(tensor.data_ptr()),
139 };
140
141 Self { tensor, storage }
142 }
143
144 pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
149 let storage_child = tensor.data_ptr();
150 let mut is_a_new_tensor = true;
151
152 match &storage_parent {
153 Storage::View {
154 buffer_ref: start_ref,
155 view_ref,
156 } => {
157 if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
158 is_a_new_tensor = false;
159 }
160 }
161 Storage::Owned {
162 buffer_ref: start_ref,
163 } => {
164 if storage_child == *start_ref.as_ref() {
165 is_a_new_tensor = false;
166 }
167 }
168 };
169
170 let storage = match is_a_new_tensor {
171 true => Storage::Owned {
172 #[allow(clippy::arc_with_non_send_sync)]
173 buffer_ref: Arc::new(storage_child),
174 },
175 false => storage_parent.clone(),
176 };
177
178 Self { tensor, storage }
179 }
180
181 pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
183 let storage = Storage::View {
184 buffer_ref: storage_parent.buffer_ref().clone(),
185 #[allow(clippy::arc_with_non_send_sync)]
186 view_ref: Arc::new(tensor.data_ptr()),
187 };
188 Self { tensor, storage }
189 }
190}
191
// SAFETY: NOTE(review) — soundness is not provable from this file alone. It
// appears to rely on mutation being gated behind `can_mut()` uniqueness
// checks so shared `TchTensor`s are never mutated concurrently; confirm
// against how the backend dispatches operations.
unsafe impl Send for TchTensor {}
// SAFETY: NOTE(review) — same assumption as `Send` above; verify.
unsafe impl Sync for TchTensor {}
197
impl TchTensor {
    /// Whether the tensor's buffer can safely be mutated in place.
    ///
    /// True only when the storage is uniquely referenced AND no stride is
    /// zero (a zero stride means elements alias each other, e.g. after a
    /// broadcast-style expansion).
    pub fn can_mut(&self) -> bool {
        let stride_contains_zero = self.tensor.stride().contains(&0);

        !stride_contains_zero && self.storage.can_mut()
    }

    /// Run `func` on the tensor in place when safe.
    ///
    /// Returns `None` when in-place mutation is not allowed (shared storage
    /// or zero stride), leaving the tensor untouched.
    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
        &mut self,
        func: F,
    ) -> Option<TchTensor> {
        if !self.can_mut() {
            return None;
        }

        // Clone the storage handle first so the result can share it.
        let data = self.storage.clone();
        Some(TchTensor::from_existing(func(&mut self.tensor), data))
    }

    /// Apply a unary operation, using the owning variant `fown` when
    /// in-place mutation is safe and the borrowing variant `fref` otherwise.
    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
    where
        FOwn: Fn(tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor) -> tch::Tensor,
    {
        if !self.can_mut() {
            return TchTensor::from_existing(fref(&self.tensor), self.storage);
        }

        TchTensor::from_existing(fown(self.tensor), self.storage)
    }

    /// Apply a binary operation, preferring an in-place implementation on
    /// whichever operand already has the broadcast output shape and is
    /// safely mutable; falls back to the allocating `fref` otherwise.
    ///
    /// NOTE(review): `rhs_shape` is indexed up to `lhs`'s rank, so both
    /// operands are assumed to have the same number of dimensions — confirm
    /// callers align ranks before calling.
    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
        mut lhs: Self,
        mut rhs: Self,
        flmut: FLMut,
        frmut: FRMut,
        fref: FRef,
    ) -> TchTensor
    where
        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
    {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();

        // Element-wise broadcast output shape: per-dimension maximum.
        let d_out = lhs_shape.num_dims();
        let mut out_shape = Shape::from(vec![1usize; d_out]);

        for i in 0..d_out {
            out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
        }

        let num_elements_out = out_shape.num_elements();

        // An operand can host the result in place only if it already has the
        // output's element count (i.e. it is not being broadcast).
        if lhs_shape.num_elements() == num_elements_out
            && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
        {
            return output;
        }

        if rhs_shape.num_elements() == num_elements_out
            && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
        {
            return output;
        }

        // Neither side is mutable in place: allocate a new result, keeping
        // lhs's storage as the parent for reuse detection.
        let storage = lhs.storage;
        let tensor = fref(&lhs.tensor, &rhs.tensor);

        TchTensor::from_existing(tensor, storage)
    }
}
281
282impl Clone for TchTensor {
283 fn clone(&self) -> Self {
284 Self {
285 tensor: self.tensor.shallow_clone(),
286 storage: self.storage.clone(),
287 }
288 }
289}
290
/// A tensor shape in libtorch's representation (`i64` dimensions).
#[derive(Debug)]
pub struct TchShape {
    /// Dimension sizes, in order.
    pub dims: Vec<i64>,
}
297
298impl From<Shape> for TchShape {
299 fn from(shape: Shape) -> Self {
300 TchShape {
301 dims: shape.into_iter().map(|d| d as i64).collect(),
302 }
303 }
304}
305
306impl From<&[usize]> for TchShape {
307 fn from(shape: &[usize]) -> Self {
308 TchShape {
309 dims: shape.iter().map(|d| *d as i64).collect(),
310 }
311 }
312}
313
impl TchTensor {
    /// Create a tensor on `device` from the given data.
    ///
    /// NOTE(review): `as_slice::<E>().unwrap()` panics when `data`'s element
    /// type does not match `E` — confirm callers guarantee the dtype.
    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
        let shape_tch = TchShape::from(data.shape.as_slice());
        // Upload the flat data first, then restore the shape and kind.
        let tensor = tch::Tensor::from_slice(data.as_slice::<E>().unwrap()).to(device);
        let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);

        Self::new(tensor)
    }
}
333
334impl TchTensor {
335 pub fn empty<E: tch::kind::Element>(shape: Shape, device: LibTorchDevice) -> Self {
346 let shape_tch = TchShape::from(shape);
347 let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
348
349 Self::new(tensor)
350 }
351}
352
#[cfg(test)]
mod tests {
    use crate::LibTorch;

    use super::*;
    use burn_tensor::{Distribution, Tensor, TensorPrimitive};
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    // Round-trip: random 1-D data survives conversion to a TchTensor and back.
    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Round-trip: same as above for a 2-D shape.
    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 2>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Aliasing guard: an op on a reshaped clone must not mutate the original
    // tensor's buffer in place (the shared storage should force a copy).
    #[test]
    fn should_not_update_inplace_after_reshape() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    // Aliasing guard: same as above for a slice view over shared storage.
    #[test]
    fn should_not_update_inplace_after_slice() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }
}
418
unsafe extern "C" {
    // NOTE(review): presumably a symbol exported by the torch CUDA library;
    // referencing it here forces the linker to keep the CUDA dependency.
    // Confirm against the crate's build configuration.
    pub fn dummy_cuda_dependency();
}

// `#[used]` prevents the linker from stripping the reference above even
// though nothing calls it.
#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];