// burn_tensor/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4// Allow deprecated `Data` and `DataSerialize`
5#![allow(deprecated)]
6
7//! This library provides multiple tensor implementations hidden behind an easy to use API
8//! that supports reverse mode automatic differentiation.
9
10#[macro_use]
11extern crate derive_new;
12
13extern crate alloc;
14
15mod tensor;
16
/// Burn Tensor representation
18#[cfg(feature = "repr")]
19pub mod repr;
20
21#[cfg(feature = "export_tests")]
22#[allow(missing_docs)]
23pub mod tests;
24
25pub use half::{bf16, f16};
26pub(crate) use tensor::check::macros::check;
27pub use tensor::*;
28
29pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency.
30
#[cfg(feature = "cubecl")]
mod cube {
    use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};

    use crate::DType;

    /// Bridges Burn's runtime dtype descriptor to CubeCL's IR element type.
    impl From<DType> for Elem {
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating-point widths map one-to-one onto CubeCL float kinds.
                DType::F64 => Elem::Float(FloatKind::F64),
                DType::F32 => Elem::Float(FloatKind::F32),
                DType::F16 => Elem::Float(FloatKind::F16),
                DType::BF16 => Elem::Float(FloatKind::BF16),
                // Signed integers.
                DType::I64 => Elem::Int(IntKind::I64),
                DType::I32 => Elem::Int(IntKind::I32),
                DType::I16 => Elem::Int(IntKind::I16),
                DType::I8 => Elem::Int(IntKind::I8),
                // Unsigned integers.
                DType::U64 => Elem::UInt(UIntKind::U64),
                DType::U32 => Elem::UInt(UIntKind::U32),
                DType::U16 => Elem::UInt(UIntKind::U16),
                DType::U8 => Elem::UInt(UIntKind::U8),
                DType::Bool => Elem::Bool,
                // Quantized tensors have no CubeCL IR element yet; this is a
                // deliberate hard failure rather than a silent mis-mapping.
                DType::QFloat(_) => panic!("quantized type is not supported yet."),
            }
        }
    }
}
56
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::wgpu::WgpuDevice;

    /// Gives every wgpu device a stable identifier by assigning one
    /// `type_id` per variant and reusing the variant's own index when it has one.
    impl DeviceOps for WgpuDevice {
        fn id(&self) -> DeviceId {
            let (type_id, index_id) = match self {
                WgpuDevice::DiscreteGpu(index) => (0, *index as u32),
                WgpuDevice::IntegratedGpu(index) => (1, *index as u32),
                WgpuDevice::VirtualGpu(index) => (2, *index as u32),
                WgpuDevice::Cpu => (3, 0),
                // Both "pick for me" variants intentionally share one id.
                WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => (4, 0),
                WgpuDevice::Existing(id) => (5, *id),
            };
            DeviceId::new(type_id, index_id)
        }
    }
}
75
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::cuda::CudaDevice;

    /// CUDA devices are identified solely by their ordinal, so the
    /// `type_id` is always 0 and the index carries the device number.
    impl DeviceOps for CudaDevice {
        fn id(&self) -> DeviceId {
            let index_id = self.index as u32;
            DeviceId::new(0, index_id)
        }
    }
}
87
// HIP support is only compiled on Linux targets.
#[cfg(target_os = "linux")]
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::hip::HipDevice;

    /// Mirrors the CUDA mapping: `type_id` 0, device ordinal as the index.
    impl DeviceOps for HipDevice {
        fn id(&self) -> DeviceId {
            let index_id = self.index as u32;
            DeviceId::new(0, index_id)
        }
    }
}