1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5extern crate alloc;
11
12pub mod id;
14
15pub mod tensor;
17pub use tensor::*;
18
19pub mod errors;
21pub use errors::*;
22
23#[cfg(feature = "network")]
25pub mod network;
26
27pub use cubecl_common::bytes::*;
29pub use cubecl_common::*;
30pub use half::{bf16, f16};
31
32#[cfg(feature = "cubecl")]
33pub use cubecl::flex32;
34
/// CubeCL interop: conversions from [`DType`](crate::tensor::DType) into CubeCL IR types.
#[cfg(feature = "cubecl")]
mod cube {
    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_common::quant::scheme::QuantScheme;

    use crate::tensor::DType;
    use crate::tensor::quantization::{QuantStore, QuantValue};

    /// Picks the CubeCL element type that holds the quantized values described by `scheme`.
    ///
    /// - `PackedU32` storage always uses `u32`, regardless of the quantized value kind.
    /// - `Native` storage uses the value's own type: `i8` for `Q8F`/`Q8S`, and the
    ///   `E4M3`/`E5M2` float kinds for `E4M3`/`E5M2`.
    ///
    /// # Panics
    ///
    /// - `Native` storage of `Q4F`, `Q4S`, `Q2F`, `Q2S` or `E2M1` values.
    /// - Any `PackedNative` storage. `E2M1` reports the sub-byte message (it is only
    ///   representable as a packed [`StorageType`], not as an [`ElemType`]). Every other
    ///   value reports that it does not support native packing.
    fn quantized_elem(scheme: QuantScheme) -> ElemType {
        // Match on borrows so nothing is moved out of `scheme`; `other` below is a
        // `&QuantValue`, whose `Debug` output is identical to the value's.
        match (&scheme.store, &scheme.value) {
            (QuantStore::PackedU32(_), _) => ElemType::UInt(UIntKind::U32),
            (QuantStore::Native, QuantValue::Q8F | QuantValue::Q8S) => ElemType::Int(IntKind::I8),
            (QuantStore::Native, QuantValue::E4M3) => ElemType::Float(FloatKind::E4M3),
            (QuantStore::Native, QuantValue::E5M2) => ElemType::Float(FloatKind::E5M2),
            (
                QuantStore::Native,
                QuantValue::Q4F
                | QuantValue::Q4S
                | QuantValue::Q2F
                | QuantValue::Q2S
                | QuantValue::E2M1,
            )
            | (QuantStore::PackedNative(_), QuantValue::E2M1) => {
                panic!("Can't store native sub-byte values")
            }
            (QuantStore::PackedNative(_), other) => {
                panic!("{other:?} doesn't support native packing")
            }
        }
    }

    /// Maps a [`DType`] onto the matching CubeCL [`ElemType`].
    ///
    /// Scalar dtypes map one-to-one. Quantized dtypes resolve to the element type
    /// of their storage, as decided by `quantized_elem`.
    ///
    /// # Panics
    ///
    /// For quantization schemes that have no single-element representation (see
    /// `quantized_elem` for the exact cases).
    impl From<DType> for ElemType {
        fn from(dtype: DType) -> Self {
            match dtype {
                DType::Bool => ElemType::Bool,

                DType::U8 => ElemType::UInt(UIntKind::U8),
                DType::U16 => ElemType::UInt(UIntKind::U16),
                DType::U32 => ElemType::UInt(UIntKind::U32),
                DType::U64 => ElemType::UInt(UIntKind::U64),

                DType::I8 => ElemType::Int(IntKind::I8),
                DType::I16 => ElemType::Int(IntKind::I16),
                DType::I32 => ElemType::Int(IntKind::I32),
                DType::I64 => ElemType::Int(IntKind::I64),

                DType::BF16 => ElemType::Float(FloatKind::BF16),
                DType::F16 => ElemType::Float(FloatKind::F16),
                DType::Flex32 => ElemType::Float(FloatKind::Flex32),
                DType::F32 => ElemType::Float(FloatKind::F32),
                DType::F64 => ElemType::Float(FloatKind::F64),

                DType::QFloat(scheme) => quantized_elem(scheme),
            }
        }
    }

    /// Maps a [`DType`] onto the CubeCL [`StorageType`] used to hold it.
    ///
    /// `PackedNative` quantization of `E2M1` values is the only case that bypasses
    /// [`ElemType`]: it becomes `StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)`.
    /// Every other dtype is first converted to its [`ElemType`] and then lifted into a
    /// storage type, so it panics in exactly the same situations as that conversion.
    impl From<DType> for StorageType {
        fn from(dtype: DType) -> StorageType {
            if let DType::QFloat(QuantScheme {
                store: QuantStore::PackedNative(_),
                value: QuantValue::E2M1,
                ..
            }) = dtype
            {
                // NOTE(review): the `2` is presumably the number of E2M1 values per packed
                // element; confirm against cubecl's `StorageType::Packed` documentation.
                return StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2);
            }

            ElemType::from(dtype).into()
        }
    }
}