use crate::{LibTorchDevice, TchElement};
use burn_backend::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
use libc::c_void;
use std::sync::Arc;

/// A reference-counted handle to the raw data pointer backing a tensor's buffer.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;

/// Tracks how a tensor relates to the LibTorch buffer it reads from and writes to.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// The tensor is a view into another tensor's buffer.
    View {
        /// Pointer to the start of the parent buffer.
        buffer_ref: StorageRef,
        /// Pointer to the start of the view inside that buffer.
        view_ref: StorageRef,
    },
    /// The tensor owns its whole buffer.
    Owned {
        /// Pointer to the start of the buffer.
        buffer_ref: StorageRef,
    },
}

impl Storage {
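    /// Whether the underlying buffer can be mutated in place: true only when no
    /// other live tensor shares the same buffer (and, for a view, the same view
    /// pointer), as tracked by the `Arc` strong counts.
    ///
    /// A minimal sketch of the copy-on-write rule, assuming a freshly created CPU
    /// tensor wrapped with [`TchTensor::new`] (all names are from this module):
    ///
    /// ```rust,ignore
    /// let a = TchTensor::new(tch::Tensor::zeros(&[2, 3], (tch::Kind::Float, tch::Device::Cpu)));
    /// assert!(a.storage.can_mut()); // single owner of the buffer
    ///
    /// let b = a.clone(); // shallow clone shares the same `StorageRef`
    /// assert!(!a.storage.can_mut()); // two strong references now alias the buffer
    /// assert!(!b.storage.can_mut());
    /// ```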
    pub fn can_mut(&self) -> bool {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
            Storage::Owned {
                buffer_ref: start_ref,
            } => Arc::strong_count(start_ref) == 1,
        }
    }

    /// Returns the reference to the start of the underlying buffer.
    pub fn buffer_ref(&self) -> &StorageRef {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref: _,
            } => start_ref,
            Storage::Owned {
                buffer_ref: start_ref,
            } => start_ref,
        }
    }
}

/// A tensor backed by LibTorch through the `tch` crate.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// Handle to the underlying LibTorch tensor.
    pub tensor: tch::Tensor,

    /// The storage the tensor reads from and writes to.
    pub storage: Storage,
}

impl TensorMetadata for TchTensor {
    fn dtype(&self) -> DType {
        match self.tensor.kind() {
            tch::Kind::Uint8 => DType::U8,
            tch::Kind::Int8 => DType::I8,
            tch::Kind::Int16 => DType::I16,
            tch::Kind::Int => DType::I32,
            tch::Kind::Int64 => DType::I64,
            tch::Kind::Half => DType::F16,
            tch::Kind::Float => DType::F32,
            tch::Kind::Double => DType::F64,
            tch::Kind::Bool => DType::Bool,
            tch::Kind::BFloat16 => DType::BF16,
            kind => unimplemented!("Unsupported kind {kind:?}"),
        }
    }

    fn shape(&self) -> Shape {
        Shape::from(self.tensor.size())
    }

    fn rank(&self) -> usize {
        self.tensor.dim()
    }
}

impl burn_backend::QTensorPrimitive for TchTensor {
    fn scheme(&self) -> &burn_backend::quantization::QuantScheme {
        unimplemented!("Quantization is not supported")
    }
}

impl core::fmt::Display for TchTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.tensor)
    }
}

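/// Conversion from Burn dtypes to the corresponding LibTorch [`tch::Kind`].
///
/// A small sketch of the mapping; `into_kind` panics on dtypes LibTorch has no
/// kind for, while `try_into_kind` reports them as an error:
///
/// ```rust,ignore
/// assert_eq!(DType::F32.into_kind(), tch::Kind::Float);
/// assert_eq!(DType::I64.into_kind(), tch::Kind::Int64);
/// assert!(DType::U32.try_into_kind().is_err()); // no unsigned 32-bit kind in LibTorch
/// ```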
pub(crate) trait IntoKind {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError>;
    fn into_kind(self) -> tch::Kind
    where
        Self: Sized,
    {
        self.try_into_kind().unwrap()
    }
}

impl IntoKind for IntDType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        let dtype: DType = self.into();
        dtype.try_into_kind()
    }
}

impl IntoKind for FloatDType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        let dtype: DType = self.into();
        dtype.try_into_kind()
    }
}

impl IntoKind for DType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        match self {
            DType::F64 => Ok(tch::Kind::Double),
            DType::F32 => Ok(tch::Kind::Float),
            DType::Flex32 => Ok(tch::Kind::Float),
            DType::F16 => Ok(tch::Kind::Half),
            DType::BF16 => Ok(tch::Kind::BFloat16),
            DType::I64 => Ok(tch::Kind::Int64),
            DType::I32 => Ok(tch::Kind::Int),
            DType::I16 => Ok(tch::Kind::Int16),
            DType::I8 => Ok(tch::Kind::Int8),
            DType::U8 => Ok(tch::Kind::Uint8),
            DType::Bool => Ok(tch::Kind::Bool),
            other => Err(tch::TchError::Kind(format!("Unsupported dtype {other:?}"))),
        }
    }
}

impl TchTensor {
    /// Creates a new tensor that owns its storage.
    pub fn new(tensor: tch::Tensor) -> Self {
        #[allow(clippy::arc_with_non_send_sync)]
        let storage = Storage::Owned {
            buffer_ref: Arc::new(tensor.data_ptr()),
        };

        Self { tensor, storage }
    }

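    /// Wraps the output of a LibTorch op, reusing the parent storage when the new
    /// tensor still points into the parent's buffer and allocating a fresh
    /// [`Storage::Owned`] otherwise.
    ///
    /// A minimal sketch, assuming a CPU float tensor; `transpose` returns a view
    /// over the same data pointer, so the child keeps the parent's storage:
    ///
    /// ```rust,ignore
    /// let parent = TchTensor::new(tch::Tensor::zeros(&[2, 3], (tch::Kind::Float, tch::Device::Cpu)));
    /// let child = TchTensor::from_existing(parent.tensor.transpose(0, 1), parent.storage.clone());
    /// assert_eq!(child.storage, parent.storage);
    /// ```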
    pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage_child = tensor.data_ptr();
        let mut is_a_new_tensor = true;

        match &storage_parent {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => {
                if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
            Storage::Owned {
                buffer_ref: start_ref,
            } => {
                if storage_child == *start_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
        };

        let storage = match is_a_new_tensor {
            true => Storage::Owned {
                #[allow(clippy::arc_with_non_send_sync)]
                buffer_ref: Arc::new(storage_child),
            },
            false => storage_parent.clone(),
        };

        Self { tensor, storage }
    }

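    /// Wraps a view of a parent tensor, keeping a reference to both the parent
    /// buffer and the view's own data pointer.
    ///
    /// A minimal sketch, assuming a CPU float tensor; `slice` starts at a different
    /// address than the buffer, so the result is tracked as a [`Storage::View`]:
    ///
    /// ```rust,ignore
    /// let base = TchTensor::new(tch::Tensor::zeros(&[4, 4], (tch::Kind::Float, tch::Device::Cpu)));
    /// let row = TchTensor::partial(base.tensor.slice(0, 2, 3, 1), base.storage.clone());
    /// assert_eq!(row.storage.buffer_ref(), base.storage.buffer_ref());
    /// ```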
    pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage = Storage::View {
            buffer_ref: storage_parent.buffer_ref().clone(),
            #[allow(clippy::arc_with_non_send_sync)]
            view_ref: Arc::new(tensor.data_ptr()),
        };
        Self { tensor, storage }
    }
}

// SAFETY: the raw pointer held in `StorageRef` is only compared for identity and
// reference counted, never dereferenced by this backend, so the handle is manually
// marked as safe to share across threads.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}

impl TchTensor {
    /// Checks if the tensor can be mutated in place.
    ///
    /// A stride of zero means the tensor is an expanded (broadcast) view whose
    /// elements alias each other, so in-place mutation is refused even when the
    /// storage itself is uniquely owned.
    pub fn can_mut(&self) -> bool {
        let stride_contains_zero = self.tensor.stride().contains(&0);

        !stride_contains_zero && self.storage.can_mut()
    }

    /// Runs `func` on the inner tensor when in-place mutation is allowed, returning
    /// `None` otherwise so the caller can fall back to an allocating path.
    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
        &mut self,
        func: F,
    ) -> Option<TchTensor> {
        if !self.can_mut() {
            return None;
        }

        let data = self.storage.clone();
        Some(TchTensor::from_existing(func(&mut self.tensor), data))
    }

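    /// Applies a unary op, using `fown` (which may consume or mutate the tensor)
    /// when in-place mutation is allowed and the allocating `fref` otherwise.
    ///
    /// A minimal sketch with a negation, assuming the usual in-place/out-of-place
    /// pair generated by `tch`:
    ///
    /// ```rust,ignore
    /// let out = tensor.unary_ops(|mut t| t.neg_(), |t| t.neg());
    /// ```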
    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
    where
        FOwn: Fn(tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor) -> tch::Tensor,
    {
        if !self.can_mut() {
            return TchTensor::from_existing(fref(&self.tensor), self.storage);
        }

        TchTensor::from_existing(fown(self.tensor), self.storage)
    }

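    /// Applies a binary op, reusing the `lhs` buffer in place when it already holds
    /// the broadcast output's element count and is uniquely owned, then trying
    /// `rhs`, and falling back to the allocating `fref` otherwise.
    ///
    /// A minimal sketch for an element-wise addition, assuming `lhs` and `rhs` are
    /// `TchTensor`s of the same rank and using the in-place/out-of-place pairs
    /// generated by `tch`:
    ///
    /// ```rust,ignore
    /// let out = TchTensor::binary_ops_tensor(
    ///     lhs,
    ///     rhs,
    ///     |l, r| l.f_add_(r).unwrap(),
    ///     |l, r| r.f_add_(l).unwrap(),
    ///     |l, r| l.f_add(r).unwrap(),
    /// );
    /// ```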
    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
        mut lhs: Self,
        mut rhs: Self,
        flmut: FLMut,
        frmut: FRMut,
        fref: FRef,
    ) -> TchTensor
    where
        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
    {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();

        // Compute the broadcast output shape (both operands are assumed to have the same rank).
        let d_out = lhs_shape.num_dims();
        let mut out_shape = Shape::from(vec![1usize; d_out]);

        for i in 0..d_out {
            out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
        }

        let num_elements_out = out_shape.num_elements();

        // Try to reuse the lhs buffer in place.
        if lhs_shape.num_elements() == num_elements_out
            && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
        {
            return output;
        }

        // Try to reuse the rhs buffer in place.
        if rhs_shape.num_elements() == num_elements_out
            && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
        {
            return output;
        }

        let storage = lhs.storage;
        let tensor = fref(&lhs.tensor, &rhs.tensor);

        TchTensor::from_existing(tensor, storage)
    }
}

impl Clone for TchTensor {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.shallow_clone(),
            storage: self.storage.clone(),
        }
    }
}

/// A shape expressed with the `i64` dimensions LibTorch expects.
#[derive(Debug)]
pub struct TchShape {
    /// The shape's dimensions.
    pub dims: Vec<i64>,
}

impl From<Shape> for TchShape {
    fn from(shape: Shape) -> Self {
        TchShape {
            dims: shape.dims.into_iter().map(|d| d as i64).collect(),
        }
    }
}

impl From<&[usize]> for TchShape {
    fn from(shape: &[usize]) -> Self {
        TchShape {
            dims: shape.iter().map(|d| *d as i64).collect(),
        }
    }
}

impl TchTensor {
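    /// Creates a tensor on the given device from Burn [`TensorData`].
    ///
    /// A minimal sketch on CPU (hedged: it assumes `Shape` can be built from a
    /// fixed-size array, as elsewhere in Burn):
    ///
    /// ```rust,ignore
    /// let data = TensorData::from([[1.0f32, 2.0], [3.0, 4.0]]);
    /// let tensor = TchTensor::from_data::<f32>(data, tch::Device::Cpu);
    /// assert_eq!(tensor.shape(), Shape::from([2, 2]));
    /// ```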
    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
        let shape_tch = TchShape::from(data.shape.as_slice());
        let tensor =
            tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device);

        Self::new(tensor)
    }
}

impl TchTensor {
    /// Creates an uninitialized tensor with the given shape on the given device.
    pub fn empty<E: TchElement>(shape: Shape, device: LibTorchDevice) -> Self {
        let shape_tch = TchShape::from(shape);
        let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into()));

        Self::new(tensor)
    }
}

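// Reads a 1-D tensor back into a `Vec`. A minimal usage sketch (hedged: it assumes a
// 1-D float tensor; the implementation converts to the requested element kind before
// copying the raw data out):
//
//     let tensor = TchTensor::new(tch::Tensor::zeros(&[3], (tch::Kind::Float, tch::Device::Cpu)));
//     let values: Vec<f32> = Vec::try_from(&tensor).unwrap();
//     assert_eq!(values, vec![0.0; 3]);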
impl<T: TchElement + Copy> TryFrom<&TchTensor> for Vec<T> {
    type Error = tch::TchError;

    fn try_from(tensor: &TchTensor) -> Result<Self, Self::Error> {
        let tensor = &tensor.tensor;
        let size = tensor.size();
        if size.len() != 1 {
            return Err(tch::TchError::Convert(format!(
                "Attempting to convert a tensor with {} dimensions to a flat vector",
                size.len()
            )));
        }
        let numel = size[0] as usize;
        let mut vec = vec![T::ZERO; numel];
        // Convert to the target kind before copying the raw data out.
        f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?;
        Ok(vec)
    }
}

/// Takes ownership of a heap-allocated C string, converts it to a Rust `String`,
/// and frees it with `libc::free`. Returns `None` for a null pointer.
unsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option<String> {
    if !ptr.is_null() {
        unsafe {
            let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
            libc::free(ptr as *mut libc::c_void);
            Some(str)
        }
    } else {
        None
    }
}

/// Copies `numel` elements from `tensor` into `dst` using LibTorch's raw copy routine.
fn f_copy_data<T: TchElement>(
    tensor: &mut tch::Tensor,
    dst: &mut [T],
    numel: usize,
) -> Result<(), tch::TchError> {
    let kind = tensor.f_kind()?;
    if T::kind() != kind {
        return Err(tch::TchError::Kind(format!(
            "incoherent elt kind, {:?} != {:?}",
            kind,
            T::kind()
        )));
    }
    if dst.len() < numel {
        return Err(tch::TchError::Shape(format!("slice len < {numel}")));
    }

    unsafe {
        torch_sys::at_copy_data(
            tensor.as_mut_ptr(),
            dst.as_mut_ptr() as *const c_void,
            numel,
            T::kind().elt_size_in_bytes(),
        );
        match ptr_to_string(torch_sys::get_and_reset_last_err()) {
            None => Ok(()),
            Some(c_error) => Err(tch::TchError::Torch(c_error)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::ops::FloatTensorOps;
    use burn_backend::{Backend, quantization::QuantScheme, read_sync};

    type B = crate::LibTorch<f32>;

    #[test]
    fn should_have_bf16_kind() {
        let data = TensorData::from([4.0, 4.0]);
        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
        let tensor_2 = B::float_cast(tensor_1, DType::BF16.into());

        assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16);

        let out = read_sync(B::float_into_data(tensor_2)).unwrap();

        out.assert_eq(&TensorData::from([4.0, 4.0]), false);
    }

    #[test]
    fn should_support_dtypes() {
        let device = Default::default();

        assert!(B::supports_dtype(&device, DType::F64));
        assert!(B::supports_dtype(&device, DType::F32));
        assert!(B::supports_dtype(&device, DType::Flex32));
        assert!(B::supports_dtype(&device, DType::F16));
        assert!(B::supports_dtype(&device, DType::BF16));
        assert!(B::supports_dtype(&device, DType::I64));
        assert!(B::supports_dtype(&device, DType::I32));
        assert!(B::supports_dtype(&device, DType::I16));
        assert!(B::supports_dtype(&device, DType::I8));
        assert!(B::supports_dtype(&device, DType::U8));
        assert!(B::supports_dtype(&device, DType::Bool));

        assert!(!B::supports_dtype(&device, DType::U64));
        assert!(!B::supports_dtype(&device, DType::U32));
        assert!(!B::supports_dtype(&device, DType::U16));
        assert!(!B::supports_dtype(
            &device,
            DType::QFloat(QuantScheme::default())
        ));
    }

    #[test]
    fn should_support_from_bf16() {
        let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16);
        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
        let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16);
        let tensor_2 = B::float_from_data(data, &Default::default());

        let tensor_3 = B::float_add(tensor_1, tensor_2);

        assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16);

        let out = read_sync(B::float_into_data(tensor_3)).unwrap();

        out.assert_eq(&TensorData::from([[3.0], [3.0]]), false);
    }
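
    #[test]
    fn cloned_tensor_should_not_be_mutable_in_place() {
        // Sketch of the copy-on-write rule: cloning shares the storage, so
        // in-place mutation must be refused for both handles afterwards.
        let data = TensorData::from([1.0, 2.0]);
        let tensor: TchTensor = B::float_from_data(data, &Default::default());
        assert!(tensor.can_mut());

        let other = tensor.clone();
        assert!(!tensor.can_mut());
        assert!(!other.can_mut());
    }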
}

// Referencing this symbol from the LibTorch bindings keeps the CUDA-related
// objects from being stripped by the linker.
unsafe extern "C" {
    pub fn dummy_cuda_dependency();
}

#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];