1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5#[macro_use]
10extern crate derive_new;
11
12extern crate alloc;
13
14mod tensor;
15
16#[cfg(feature = "export_tests")]
17#[allow(missing_docs)]
18pub mod tests;
19
20#[cfg(feature = "export_tests")]
21pub use burn_tensor_testgen::might_panic;
23
24pub use half::{bf16, f16};
25pub(crate) use tensor::check::macros::check;
26pub use tensor::*;
27
28pub use burn_common::stream_id::StreamId;
29
30pub use burn_common::reader::*; #[cfg(feature = "cubecl")]
33pub use cubecl::flex32;
34
#[cfg(feature = "cubecl")]
mod cube {
    //! Conversions from burn's `DType` into cubecl IR element and storage types.

    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_quant::scheme::QuantScheme;

    use crate::DType;
    use crate::quantization::{QuantStore, QuantValue};

    impl From<DType> for ElemType {
        /// Maps a burn dtype onto the cubecl element type that holds its values.
        ///
        /// # Panics
        ///
        /// Panics for quantized dtypes that request native storage of a
        /// sub-byte value (`Q4F`, `Q4S`, `Q2F`, `Q2S`, `E2M1`).
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating point.
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),
                // Signed integers.
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),
                // Unsigned integers.
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),
                DType::Bool => Self::Bool,
                DType::QFloat(scheme) => quantized_elem(scheme),
            }
        }
    }

    impl From<DType> for StorageType {
        /// Maps a burn dtype onto the cubecl storage type used to hold it.
        ///
        /// Natively stored `E2M1` quantized values become a `Packed` E2M1
        /// storage type with a factor of 2 (presumably two 4-bit values per
        /// byte; confirm against cubecl). Every other dtype goes through the
        /// `ElemType` conversion above, so it panics in the same cases.
        fn from(dtype: DType) -> StorageType {
            let is_native_e2m1 = matches!(
                dtype,
                DType::QFloat(QuantScheme {
                    store: QuantStore::Native,
                    value: QuantValue::E2M1,
                    ..
                })
            );

            if is_native_e2m1 {
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
            } else {
                ElemType::from(dtype).into()
            }
        }
    }

    /// Element type backing a quantized tensor, chosen from its storage mode and value kind.
    ///
    /// `QuantStore::U32` always maps to a `u32` element, whatever the value kind.
    /// `QuantStore::Native` maps byte-sized value kinds to their own element type
    /// and panics on sub-byte value kinds.
    fn quantized_elem(scheme: QuantScheme) -> ElemType {
        match (scheme.store, scheme.value) {
            (QuantStore::U32, _) => ElemType::UInt(UIntKind::U32),
            (QuantStore::Native, QuantValue::Q8F | QuantValue::Q8S) => ElemType::Int(IntKind::I8),
            (QuantStore::Native, QuantValue::E4M3) => ElemType::Float(FloatKind::E4M3),
            (QuantStore::Native, QuantValue::E5M2) => ElemType::Float(FloatKind::E5M2),
            (
                QuantStore::Native,
                QuantValue::Q4F
                | QuantValue::Q4S
                | QuantValue::Q2F
                | QuantValue::Q2S
                | QuantValue::E2M1,
            ) => panic!("Can't store native sub-byte values"),
        }
    }
}
94
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    //! `DeviceOps` for cubecl's wgpu device.
    //!
    //! The impl is empty, so it relies entirely on the trait's provided items.
    impl crate::backend::DeviceOps for cubecl::wgpu::WgpuDevice {}
}
102
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    //! `DeviceOps` for cubecl's CUDA device.
    //!
    //! The impl is empty, so it relies entirely on the trait's provided items.
    impl crate::backend::DeviceOps for cubecl::cuda::CudaDevice {}
}
110
#[cfg(all(feature = "cubecl-cpu", target_os = "linux"))]
mod cube_cpu {
    //! `DeviceOps` for cubecl's CPU device.
    //!
    //! This module is compiled only when the `cubecl-cpu` feature is on and the
    //! target OS is Linux. The impl is empty, so it relies entirely on the
    //! trait's provided items.
    impl crate::backend::DeviceOps for cubecl::cpu::CpuDevice {}
}
118
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    //! `DeviceOps` for cubecl's HIP (AMD) device.
    //!
    //! The impl is empty, so it relies entirely on the trait's provided items.
    impl crate::backend::DeviceOps for cubecl::hip::AmdDevice {}
}