1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5extern crate alloc;
11
12pub mod id;
14
15pub mod tensor;
17pub use tensor::*;
18
19#[cfg(feature = "network")]
21pub mod network;
22
23pub use cubecl_common::bytes::*;
25pub use cubecl_common::*;
26pub use half::{bf16, f16};
27
28#[cfg(feature = "cubecl")]
29pub use cubecl::flex32;
30
/// Conversions from [`DType`](crate::tensor::DType) into the CubeCL IR type descriptors.
///
/// This module is only compiled when the `cubecl` feature is enabled.
#[cfg(feature = "cubecl")]
mod cube {
    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
    use cubecl_quant::scheme::QuantScheme;

    use crate::tensor::DType;
    use crate::tensor::quantization::{QuantStore, QuantValue};

    /// Maps a tensor [`DType`] onto the matching CubeCL [`ElemType`].
    ///
    /// Plain float, integer, unsigned and boolean data types map one-to-one onto the
    /// element type of the same name. For quantized data types the result depends on
    /// the quantization scheme: `QuantStore::U32` always yields a `u32` element, while
    /// `QuantStore::Native` yields the element type selected by the quantized value.
    ///
    /// # Panics
    ///
    /// Panics when the scheme uses `QuantStore::Native` together with one of the
    /// `Q4F`, `Q4S`, `Q2F`, `Q2S` or `E2M1` values, which this conversion cannot
    /// represent as a native element type.
    impl From<DType> for cubecl::ir::ElemType {
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating point types.
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),

                // Signed integer types.
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),

                // Unsigned integer types.
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),

                DType::Bool => Self::Bool,

                // Quantized types: the storage mode decides first, then the value kind.
                DType::QFloat(scheme) => match (scheme.store, scheme.value) {
                    (QuantStore::U32, _) => Self::UInt(UIntKind::U32),
                    (QuantStore::Native, QuantValue::Q8F | QuantValue::Q8S) => {
                        Self::Int(IntKind::I8)
                    }
                    (QuantStore::Native, QuantValue::E4M3) => Self::Float(FloatKind::E4M3),
                    (QuantStore::Native, QuantValue::E5M2) => Self::Float(FloatKind::E5M2),
                    (
                        QuantStore::Native,
                        QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E2M1,
                    ) => panic!("Can't store native sub-byte values"),
                },
            }
        }
    }

    /// Maps a tensor [`DType`] onto the matching CubeCL [`StorageType`].
    ///
    /// A quantized data type stored natively with `E2M1` values becomes a
    /// `StorageType::Packed` of `E2M1` floats. Every other data type is first converted
    /// to an [`ElemType`] and then into its storage type, so the panics documented on
    /// that conversion apply here as well.
    impl From<DType> for cubecl::ir::StorageType {
        fn from(dtype: DType) -> cubecl::ir::StorageType {
            let is_native_e2m1 = matches!(
                dtype,
                DType::QFloat(QuantScheme {
                    store: QuantStore::Native,
                    value: QuantValue::E2M1,
                    ..
                })
            );

            if is_native_e2m1 {
                // NOTE(review): the `2` is presumably the number of values packed into
                // one storage element — confirm against the cubecl `StorageType` docs.
                StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
            } else {
                ElemType::from(dtype).into()
            }
        }
    }
}