1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4#![allow(deprecated)]
6
7#[macro_use]
11extern crate derive_new;
12
13extern crate alloc;
14
15mod tensor;
16
17#[cfg(feature = "repr")]
19pub mod repr;
20
21#[cfg(feature = "export_tests")]
22#[allow(missing_docs)]
23pub mod tests;
24
25pub use half::{bf16, f16};
26pub(crate) use tensor::check::macros::check;
27pub use tensor::*;
28
29pub use burn_common::reader::*; #[cfg(feature = "cubecl")]
32mod cube {
33 use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};
34
35 impl From<crate::DType> for cubecl::ir::Elem {
36 fn from(dtype: crate::DType) -> Self {
37 match dtype {
38 crate::DType::F64 => Elem::Float(FloatKind::F64),
39 crate::DType::F32 => Elem::Float(FloatKind::F32),
40 crate::DType::F16 => Elem::Float(FloatKind::F16),
41 crate::DType::BF16 => Elem::Float(FloatKind::BF16),
42 crate::DType::I64 => Elem::Int(IntKind::I64),
43 crate::DType::I32 => Elem::Int(IntKind::I32),
44 crate::DType::I16 => Elem::Int(IntKind::I16),
45 crate::DType::I8 => Elem::Int(IntKind::I8),
46 crate::DType::U64 => Elem::UInt(UIntKind::U64),
47 crate::DType::U32 => Elem::UInt(UIntKind::U32),
48 crate::DType::U16 => Elem::UInt(UIntKind::U16),
49 crate::DType::U8 => Elem::UInt(UIntKind::U8),
50 crate::DType::Bool => Elem::Bool,
51 crate::DType::QFloat(_) => panic!("quantized type is not supported yet."),
52 }
53 }
54 }
55}
56
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::wgpu::WgpuDevice;

    impl DeviceOps for WgpuDevice {
        /// Builds a [`DeviceId`] from the device variant and its index.
        ///
        /// The first component encodes the variant:
        /// - 0: discrete GPU
        /// - 1: integrated GPU
        /// - 2: virtual GPU
        /// - 3: CPU
        /// - 4: best-available or default device (these two share an id)
        /// - 5: existing device
        ///
        /// The second component is the index or id the variant carries, or 0 for
        /// variants that carry none.
        fn id(&self) -> DeviceId {
            let (type_id, index_id) = match self {
                Self::DiscreteGpu(index) => (0, *index as u32),
                Self::IntegratedGpu(index) => (1, *index as u32),
                Self::VirtualGpu(index) => (2, *index as u32),
                Self::Cpu => (3, 0),
                Self::BestAvailable | Self::DefaultDevice => (4, 0),
                Self::Existing(id) => (5, *id),
            };
            DeviceId::new(type_id, index_id)
        }
    }
}
75
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::cuda::CudaDevice;

    impl DeviceOps for CudaDevice {
        /// Builds a [`DeviceId`] with type id 0 and the CUDA device index as the index id.
        fn id(&self) -> DeviceId {
            // `as` cast: would truncate an index above `u32::MAX`.
            // Presumably unreachable for real device counts — confirm.
            let index_id = self.index as u32;
            DeviceId::new(0, index_id)
        }
    }
}
87
#[cfg(target_os = "linux")]
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::hip::HipDevice;

    impl DeviceOps for HipDevice {
        /// Builds a [`DeviceId`] with type id 0 and the HIP device index as the index id.
        fn id(&self) -> DeviceId {
            let HipDevice { index, .. } = self;
            // `as` cast: would truncate an index above `u32::MAX`.
            // Presumably unreachable for real device counts — confirm.
            DeviceId::new(0, *index as u32)
        }
    }
}