burn_tensor/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5//! This library provides the core abstractions required to run tensor operations with Burn.
6//! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations.
7//! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait.
8
9#[macro_use]
10extern crate derive_new;
11
12extern crate alloc;
13
14mod tensor;
15
16#[cfg(feature = "export_tests")]
17#[allow(missing_docs)]
18pub mod tests;
19
20#[cfg(feature = "export_tests")]
21// Re-export the might_panic proc macro for easy access
22pub use burn_tensor_testgen::might_panic;
23
24pub use half::{bf16, f16};
25pub(crate) use tensor::check::macros::check;
26pub use tensor::*;
27
28pub use burn_common::stream_id::StreamId;
29
30pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency.
31
32#[cfg(feature = "cubecl")]
33pub use cubecl::flex32;
34
#[cfg(feature = "cubecl")]
mod cube {
    //! Conversions from Burn's `DType` into CubeCL IR element and storage types.

    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_quant::scheme::QuantScheme;

    use crate::DType;
    use crate::quantization::{QuantStore, QuantValue};

    /// Element type used for a quantized value stored with `QuantStore::Native`.
    ///
    /// # Panics
    ///
    /// Panics for the sub-byte values `Q4F`, `Q4S`, `Q2F`, `Q2S` and `E2M1`, which have
    /// no native single-element representation in this mapping.
    fn native_quant_elem(value: QuantValue) -> ElemType {
        match value {
            QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
            QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
            QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
            QuantValue::Q4F
            | QuantValue::Q4S
            | QuantValue::Q2F
            | QuantValue::Q2S
            | QuantValue::E2M1 => panic!("Can't store native sub-byte values"),
        }
    }

    impl From<DType> for ElemType {
        /// Maps each Burn dtype onto the CubeCL element type of the same kind and width.
        ///
        /// For `DType::QFloat`, the result is the element type of the stored data:
        /// `QuantStore::U32` maps to `u32`, and `QuantStore::Native` maps to the native
        /// element type of the quantized value.
        ///
        /// # Panics
        ///
        /// Panics for natively stored sub-byte quantized values (`Q4F`, `Q4S`, `Q2F`,
        /// `Q2S`, `E2M1`).
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating point.
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),
                // Signed integers.
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),
                // Unsigned integers.
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),
                DType::Bool => Self::Bool,
                // Quantized floats resolve to the element type of their storage.
                DType::QFloat(scheme) => match scheme.store {
                    QuantStore::Native => native_quant_elem(scheme.value),
                    QuantStore::U32 => Self::UInt(UIntKind::U32),
                },
            }
        }
    }

    impl From<DType> for StorageType {
        /// Maps a Burn dtype onto a CubeCL storage type.
        ///
        /// Natively stored `E2M1` quantized values become a `Packed` storage of `E2M1`
        /// elements with a packing factor of 2. E2M1 is presumably a 4-bit format, so two
        /// values share one packed element. Every other dtype goes through the
        /// `ElemType` conversion and inherits its panics.
        fn from(dtype: DType) -> Self {
            if let DType::QFloat(QuantScheme {
                store: QuantStore::Native,
                value: QuantValue::E2M1,
                ..
            }) = dtype
            {
                return StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2);
            }

            ElemType::from(dtype).into()
        }
    }
}
94
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    //! Implements Burn's `DeviceOps` for CubeCL's WGPU device.
    //!
    //! The impl body is empty, so it relies entirely on the trait's default items.

    impl crate::backend::DeviceOps for cubecl::wgpu::WgpuDevice {}
}
102
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    //! Implements Burn's `DeviceOps` for CubeCL's CUDA device.
    //!
    //! The impl body is empty, so it relies entirely on the trait's default items.

    impl crate::backend::DeviceOps for cubecl::cuda::CudaDevice {}
}
110
#[cfg(all(feature = "cubecl-cpu", target_os = "linux"))]
mod cube_cpu {
    //! Implements Burn's `DeviceOps` for CubeCL's CPU device.
    //!
    //! Only compiled on Linux with the `cubecl-cpu` feature, per the `cfg` gate above.
    //! The impl body is empty, so it relies entirely on the trait's default items.

    impl crate::backend::DeviceOps for cubecl::cpu::CpuDevice {}
}
118
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    //! Implements Burn's `DeviceOps` for CubeCL's HIP (`AmdDevice`) device.
    //!
    //! The impl body is empty, so it relies entirely on the trait's default items.

    impl crate::backend::DeviceOps for cubecl::hip::AmdDevice {}
}