// burn_tensor/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4
5//! This library provides the core abstractions required to run tensor operations with Burn.
6//! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations.
//! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait.
8
9#[macro_use]
10extern crate derive_new;
11
12extern crate alloc;
13
14mod tensor;
15
16#[cfg(feature = "export_tests")]
17#[allow(missing_docs)]
18pub mod tests;
19
20#[cfg(feature = "export_tests")]
21// Re-export the might_panic proc macro for easy access
22pub use burn_tensor_testgen::might_panic;
23
24pub use half::{bf16, f16};
25pub(crate) use tensor::check::macros::check;
26pub use tensor::*;
27
28pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency.
29
30#[cfg(feature = "cubecl")]
31pub use cubecl::flex32;
32
#[cfg(feature = "cubecl")]
mod cube {
    use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};

    use crate::DType;

    /// Maps a burn [`DType`] onto the corresponding cubecl IR element type.
    ///
    /// Quantized dtypes have no cubecl IR equivalent yet, so `QFloat` panics.
    impl From<DType> for cubecl::ir::Elem {
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating-point kinds (including the cubecl-specific `Flex32`).
                DType::F64 => Elem::Float(FloatKind::F64),
                DType::F32 => Elem::Float(FloatKind::F32),
                DType::Flex32 => Elem::Float(FloatKind::Flex32),
                DType::F16 => Elem::Float(FloatKind::F16),
                DType::BF16 => Elem::Float(FloatKind::BF16),
                // Signed integer kinds.
                DType::I64 => Elem::Int(IntKind::I64),
                DType::I32 => Elem::Int(IntKind::I32),
                DType::I16 => Elem::Int(IntKind::I16),
                DType::I8 => Elem::Int(IntKind::I8),
                // Unsigned integer kinds.
                DType::U64 => Elem::UInt(UIntKind::U64),
                DType::U32 => Elem::UInt(UIntKind::U32),
                DType::U16 => Elem::UInt(UIntKind::U16),
                DType::U8 => Elem::UInt(UIntKind::U8),
                DType::Bool => Elem::Bool,
                DType::QFloat(_) => panic!("quantized type is not supported yet."),
            }
        }
    }
}
59
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::wgpu::WgpuDevice;

    // `WgpuDevice::BestAvailable` is deprecated upstream but must still be
    // matched here; allow the lint rather than drop the variant.
    #[allow(deprecated)]
    impl DeviceOps for WgpuDevice {
        /// Derives a stable [`DeviceId`] from the wgpu device kind and index.
        fn id(&self) -> DeviceId {
            // Each device kind gets its own type id; the second component is
            // the per-kind index (0 when the kind has no index).
            let (type_id, index_id) = match self {
                WgpuDevice::DiscreteGpu(index) => (0, *index as u32),
                WgpuDevice::IntegratedGpu(index) => (1, *index as u32),
                WgpuDevice::VirtualGpu(index) => (2, *index as u32),
                WgpuDevice::Cpu => (3, 0),
                WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => (4, 0),
                WgpuDevice::Existing(id) => (5, *id),
            };
            DeviceId::new(type_id, index_id)
        }
    }
}
80
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::cuda::CudaDevice;

    impl DeviceOps for CudaDevice {
        /// Builds a [`DeviceId`] from the CUDA device index; all CUDA devices
        /// share type id 0.
        fn id(&self) -> DeviceId {
            let index_id = self.index as u32;
            DeviceId::new(0, index_id)
        }
    }
}
92
#[cfg(target_os = "linux")]
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::hip::HipDevice;

    impl DeviceOps for HipDevice {
        /// Builds a [`DeviceId`] from the HIP device index; all HIP devices
        /// share type id 0.
        fn id(&self) -> DeviceId {
            let index_id = self.index as u32;
            DeviceId::new(0, index_id)
        }
    }
}