Skip to main content

burn_std/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5//! # Burn Standard Library
6//!
7//! This library contains core types and utilities shared across Burn, including shapes, indexing,
8//! and data types.
9
10extern crate alloc;
11
12/// Id module contains types for unique identifiers.
13pub mod id;
14
15/// Tensor utilities.
16pub mod tensor;
17pub use tensor::*;
18
19/// Common Errors.
20pub use cubecl_zspace::errors::{self, *};
21
22/// Network utilities.
23#[cfg(feature = "network")]
24pub mod network;
25
26// Re-exported types
27pub use cubecl_common::bytes::*;
28pub use cubecl_common::device_handle::DeviceHandle;
29pub use cubecl_common::*;
30pub use half::{bf16, f16};
31
32#[cfg(feature = "cubecl")]
33pub use cubecl::flex32;
34
/// Conversions from Burn's `DType` into CubeCL IR type descriptors (`cubecl` feature only).
#[cfg(feature = "cubecl")]
mod cube {
    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_common::quant::scheme::QuantScheme;

    use crate::tensor::DType;
    use crate::tensor::quantization::{QuantStore, QuantValue};

    /// Element type used to hold boolean data under the given boolean storage strategy.
    fn bool_elem(store: crate::BoolStore) -> ElemType {
        match store {
            crate::BoolStore::Native => ElemType::Bool,
            crate::BoolStore::U8 => ElemType::UInt(UIntKind::U8),
            crate::BoolStore::U32 => ElemType::UInt(UIntKind::U32),
        }
    }

    /// Element type used to hold the stored values of a quantized tensor.
    ///
    /// # Panics
    ///
    /// - `QuantStore::Native` with a sub-byte value type (`Q4F`, `Q4S`, `Q2F`, `Q2S`, `E2M1`).
    /// - `QuantStore::PackedNative` with any value type: `E2M1` reports the sub-byte message,
    ///   every other value type is reported as not supporting native packing.
    fn quantized_elem(scheme: QuantScheme) -> ElemType {
        match scheme.store {
            // Packed-into-u32 storage maps to `u32` elements whatever the quantized value type.
            QuantStore::PackedU32(_) => ElemType::UInt(UIntKind::U32),
            QuantStore::Native => match scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
                QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
                QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
                QuantValue::Q4F
                | QuantValue::Q4S
                | QuantValue::Q2F
                | QuantValue::Q2S
                | QuantValue::E2M1 => panic!("Can't store native sub-byte values"),
            },
            // Natively packed E2M1 is only expressible as a `StorageType` (special-cased in the
            // `StorageType` conversion below); a single `ElemType` cannot describe it.
            QuantStore::PackedNative(_) => match scheme.value {
                QuantValue::E2M1 => panic!("Can't store native sub-byte values"),
                other => panic!("{other:?} doesn't support native packing"),
            },
        }
    }

    impl From<DType> for ElemType {
        /// Maps a Burn `DType` onto the CubeCL element type that holds its values.
        ///
        /// # Panics
        ///
        /// Panics for quantized dtypes whose storage cannot be described by a single element
        /// type (see `quantized_elem`).
        fn from(dtype: DType) -> Self {
            match dtype {
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),
                DType::Bool(store) => bool_elem(store),
                DType::QFloat(scheme) => quantized_elem(scheme),
            }
        }
    }

    impl From<DType> for StorageType {
        /// Maps a Burn `DType` onto the CubeCL storage type used to hold it.
        ///
        /// Every dtype goes through the `ElemType` conversion except natively packed E2M1
        /// quantization, which becomes a packed E2M1 storage type with a packing factor of `2`.
        ///
        /// # Panics
        ///
        /// Panics wherever the `ElemType` conversion panics (other than the E2M1 case above).
        fn from(dtype: DType) -> Self {
            // The pattern binds nothing, so `dtype` is still usable after this check.
            if let DType::QFloat(QuantScheme {
                store: QuantStore::PackedNative(_),
                value: QuantValue::E2M1,
                ..
            }) = dtype
            {
                return StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2);
            }

            ElemType::from(dtype).into()
        }
    }
}