burn_std/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5//! # Burn Standard Library
6//!
7//! This library contains core types and utilities shared across Burn, including shapes, indexing,
8//! and data types.
9
10extern crate alloc;
11
12/// Id module contains types for unique identifiers.
13pub mod id;
14
15/// Tensor utilities.
16pub mod tensor;
17pub use tensor::*;
18
19// Re-exported types
20pub use cubecl_common::bytes::*;
21pub use cubecl_common::*;
22pub use half::{bf16, f16};
23
24#[cfg(feature = "cubecl")]
25pub use cubecl::flex32;
26
#[cfg(feature = "cubecl")]
mod cube {
    //! Conversions from Burn's [`DType`] into CubeCL IR types ([`ElemType`], [`StorageType`]).
    //!
    //! Only compiled when the `cubecl` feature is enabled.

    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_quant::scheme::QuantScheme;

    use crate::tensor::DType;
    use crate::tensor::quantization::{QuantStore, QuantValue};

    /// Maps a [`DType`] to the CubeCL [`ElemType`] used for its values.
    ///
    /// Plain float, integer, unsigned and boolean dtypes map one-to-one onto the matching
    /// `ElemType` variant. For `DType::QFloat(scheme)` the result depends on `scheme.store`:
    ///
    /// - `QuantStore::U32`: always `u32`, independent of `scheme.value`.
    /// - `QuantStore::Native`: `Q8F`/`Q8S` map to `i8`; `E4M3`/`E5M2` map to the float
    ///   kinds of the same name.
    ///
    /// # Panics
    ///
    /// Panics with `"Can't store native sub-byte values"` when `scheme.store` is
    /// `QuantStore::Native` and `scheme.value` is one of `Q4F`, `Q4S`, `Q2F`, `Q2S` or `E2M1`.
    impl From<DType> for ElemType {
        fn from(dtype: DType) -> Self {
            match dtype {
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),
                DType::Bool => Self::Bool,
                DType::QFloat(scheme) => match scheme.store {
                    QuantStore::Native => match scheme.value {
                        QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
                        QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
                        QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
                        // These values have no single-element `ElemType`. Natively-stored
                        // E2M1 is instead handled by the `StorageType` conversion below.
                        // NOTE(review): `From` is documented as infallible; a `TryFrom` impl
                        // would let callers handle this, but that would change the interface.
                        QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E2M1 => {
                            panic!("Can't store native sub-byte values")
                        }
                    },
                    // With `U32` storage the backing element is `u32` whatever the value kind.
                    QuantStore::U32 => Self::UInt(UIntKind::U32),
                },
            }
        }
    }

    /// Maps a [`DType`] to the CubeCL [`StorageType`] used to hold its values.
    ///
    /// Natively-stored `E2M1` quantized floats become a packed storage of `E2M1` elements
    /// with a packing factor of 2. Every other dtype is converted through the
    /// [`ElemType`] conversion above.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as the `ElemType` conversion, except for natively-stored
    /// `E2M1`, which is handled here. Natively-stored `Q4F`, `Q4S`, `Q2F` and `Q2S` still panic.
    impl From<DType> for StorageType {
        fn from(dtype: DType) -> Self {
            match dtype {
                // E2M1 cannot be expressed as a single `ElemType`, so it is stored packed.
                DType::QFloat(QuantScheme {
                    store: QuantStore::Native,
                    value: QuantValue::E2M1,
                    ..
                }) => Self::Packed(ElemType::Float(FloatKind::E2M1), 2),
                other => ElemType::from(other).into(),
            }
        }
    }
}