1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4
5#[macro_use]
10extern crate derive_new;
11
12extern crate alloc;
13
14mod tensor;
15
16#[cfg(feature = "export_tests")]
17#[allow(missing_docs)]
18pub mod tests;
19
20#[cfg(feature = "export_tests")]
21pub use burn_tensor_testgen::might_panic;
23
24pub use half::{bf16, f16};
25pub(crate) use tensor::check::macros::check;
26pub use tensor::*;
27
28pub use burn_common::reader::*; #[cfg(feature = "cubecl")]
31pub use cubecl::flex32;
32
#[cfg(feature = "cubecl")]
mod cube {
    use crate::DType;
    use cubecl::ir::{Elem, FloatKind, IntKind, UIntKind};

    /// Translates a burn `DType` into the matching CubeCL IR element type.
    ///
    /// Floating-point dtypes become `Elem::Float`, signed integers become
    /// `Elem::Int`, unsigned integers become `Elem::UInt`, and `Bool` becomes
    /// `Elem::Bool`.
    ///
    /// # Panics
    ///
    /// Panics on `DType::QFloat(_)`: quantized dtypes have no mapping here yet.
    impl From<DType> for Elem {
        fn from(dtype: DType) -> Self {
            match dtype {
                // Floating-point family.
                DType::F64 => Self::Float(FloatKind::F64),
                DType::F32 => Self::Float(FloatKind::F32),
                DType::Flex32 => Self::Float(FloatKind::Flex32),
                DType::F16 => Self::Float(FloatKind::F16),
                DType::BF16 => Self::Float(FloatKind::BF16),
                // Signed integer family.
                DType::I64 => Self::Int(IntKind::I64),
                DType::I32 => Self::Int(IntKind::I32),
                DType::I16 => Self::Int(IntKind::I16),
                DType::I8 => Self::Int(IntKind::I8),
                // Unsigned integer family.
                DType::U64 => Self::UInt(UIntKind::U64),
                DType::U32 => Self::UInt(UIntKind::U32),
                DType::U16 => Self::UInt(UIntKind::U16),
                DType::U8 => Self::UInt(UIntKind::U8),
                // Everything else.
                DType::Bool => Self::Bool,
                DType::QFloat(_) => panic!("quantized type is not supported yet."),
            }
        }
    }
}
59
#[cfg(feature = "cubecl-wgpu")]
mod cube_wgpu {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::wgpu::WgpuDevice;

    // NOTE(review): the allow is presumably required because one of the matched
    // `WgpuDevice` variants (likely `DefaultDevice`) is deprecated upstream —
    // confirm, and drop this attribute once that variant is gone.
    #[allow(deprecated)]
    impl DeviceOps for WgpuDevice {
        /// Encodes this device as a `DeviceId` of `(kind, index)`.
        ///
        /// Kinds: 0 = discrete GPU, 1 = integrated GPU, 2 = virtual GPU,
        /// 3 = CPU, 4 = best-available / default device, 5 = existing device.
        /// Variants without an index use index 0. GPU indices are converted
        /// with `as u32` (no range check).
        fn id(&self) -> DeviceId {
            let (kind, index_id) = match self {
                Self::DiscreteGpu(index) => (0, *index as u32),
                Self::IntegratedGpu(index) => (1, *index as u32),
                Self::VirtualGpu(index) => (2, *index as u32),
                Self::Cpu => (3, 0),
                Self::BestAvailable | Self::DefaultDevice => (4, 0),
                Self::Existing(id) => (5, *id),
            };
            DeviceId::new(kind, index_id)
        }
    }
}
80
#[cfg(feature = "cubecl-cuda")]
mod cube_cuda {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::cuda::CudaDevice;

    impl DeviceOps for CudaDevice {
        /// Identifies a CUDA device by its index; every CUDA device uses kind `0`.
        fn id(&self) -> DeviceId {
            // Converted with `as u32`, i.e. without a range check.
            let device_index = self.index as u32;
            DeviceId::new(0, device_index)
        }
    }
}
92
#[cfg(target_os = "linux")]
#[cfg(feature = "cubecl-hip")]
mod cube_hip {
    use crate::backend::{DeviceId, DeviceOps};
    use cubecl::hip::HipDevice;

    impl DeviceOps for HipDevice {
        /// Identifies a HIP device by its index; every HIP device uses kind `0`.
        fn id(&self) -> DeviceId {
            // Converted with `as u32`, i.e. without a range check.
            let device_index = self.index as u32;
            DeviceId::new(0, device_index)
        }
    }
}