// burn_cubecl/backend.rs

use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
use burn_tensor::backend::{Backend, DeviceOps};
use cubecl::server::ComputeServer;
use rand::{SeedableRng, rngs::StdRng};
use std::{
    marker::PhantomData,
    sync::{Mutex, PoisonError},
};

#[cfg(not(feature = "fusion"))]
use burn_ir::{BackendIr, TensorHandle};
#[cfg(not(feature = "fusion"))]
use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};

12pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
13
14/// Generic tensor backend that can be compiled just-in-time to any shader runtime
15#[derive(new)]
16pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
17    _runtime: PhantomData<R>,
18    _float_elem: PhantomData<F>,
19    _int_elem: PhantomData<I>,
20    _bool_elem: PhantomData<BT>,
21}
22
23impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
24where
25    R: CubeRuntime,
26    R::Server: ComputeServer,
27    R::Device: burn_tensor::backend::DeviceOps,
28    F: FloatElement,
29    I: IntElement,
30    BT: BoolElement,
31{
32    type Device = R::Device;
33
34    type FloatElem = F;
35    type IntElem = I;
36    type BoolElem = BT;
37
38    type FloatTensorPrimitive = CubeTensor<R>;
39    type IntTensorPrimitive = CubeTensor<R>;
40    type BoolTensorPrimitive = CubeTensor<R>;
41    type QuantizedTensorPrimitive = CubeTensor<R>;
42    type QuantizedEncoding = u32;
43
44    fn name(device: &Self::Device) -> String {
45        let client = R::client(device);
46        format!("cubecl<{}>", R::name(&client))
47    }
48
49    fn seed(seed: u64) {
50        let rng = StdRng::seed_from_u64(seed);
51        let mut seed = SEED.lock().unwrap();
52        *seed = Some(rng);
53    }
54
55    fn ad_enabled() -> bool {
56        false
57    }
58
59    fn sync(device: &Self::Device) {
60        let client = R::client(device);
61        futures_lite::future::block_on(client.sync());
62    }
63}
64
65impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
66    for CubeBackend<R, F, I, BT>
67{
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.write_str("CubeCLBackend")
70    }
71}
72
73impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
74    for CubeBackend<R, F, I, BT>
75{
76    fn clone(&self) -> Self {
77        Self::new()
78    }
79}
80
81impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
82    for CubeBackend<R, F, I, BT>
83{
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl<R: cubecl::Runtime> CubeRuntime for R
90where
91    R::Device: DeviceOps,
92{
93    type CubeDevice = R::Device;
94    type CubeServer = R::Server;
95}
96
97#[cfg(not(feature = "fusion"))]
98impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
99    for CubeBackend<R, F, I, BT>
100{
101    type Handle = CubeTensor<R>;
102
103    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
104        handle.handle
105    }
106
107    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
108        handle.handle
109    }
110
111    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
112        handle.handle
113    }
114
115    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
116        handle.handle
117    }
118
119    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
120        tensor
121    }
122
123    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
124        tensor
125    }
126
127    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
128        tensor
129    }
130
131    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
132        tensor
133    }
134}