// burn_cubecl/backend.rs

1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_tensor::backend::{Backend, DeviceOps};
3use cubecl::server::ComputeServer;
4use std::marker::PhantomData;
5
6#[cfg(not(feature = "fusion"))]
7use burn_ir::{BackendIr, TensorHandle};
8#[cfg(not(feature = "fusion"))]
9use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
10
11/// Generic tensor backend that can be compiled just-in-time to any shader runtime
12#[derive(new)]
13pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
14    _runtime: PhantomData<R>,
15    _float_elem: PhantomData<F>,
16    _int_elem: PhantomData<I>,
17    _bool_elem: PhantomData<BT>,
18}
19
20impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
21where
22    R: CubeRuntime,
23    R::Server: ComputeServer,
24    R::Device: burn_tensor::backend::DeviceOps,
25    F: FloatElement,
26    I: IntElement,
27    BT: BoolElement,
28{
29    type Device = R::Device;
30
31    type FloatElem = F;
32    type IntElem = I;
33    type BoolElem = BT;
34
35    type FloatTensorPrimitive = CubeTensor<R>;
36    type IntTensorPrimitive = CubeTensor<R>;
37    type BoolTensorPrimitive = CubeTensor<R>;
38    type QuantizedTensorPrimitive = CubeTensor<R>;
39
40    fn name(device: &Self::Device) -> String {
41        let client = R::client(device);
42        format!("cubecl<{}>", R::name(&client))
43    }
44
45    fn seed(_device: &Self::Device, seed: u64) {
46        cubecl::random::seed(seed);
47    }
48
49    fn ad_enabled() -> bool {
50        false
51    }
52
53    fn sync(device: &Self::Device) {
54        let client = R::client(device);
55        futures_lite::future::block_on(client.sync());
56    }
57
58    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
59        device: &Self::Device,
60        input: Input,
61        func: Func,
62    ) -> Output {
63        let client = R::client(device);
64        client.memory_persistent_allocation(input, func)
65    }
66
67    fn memory_cleanup(device: &Self::Device) {
68        let client = R::client(device);
69        client.memory_cleanup();
70    }
71}
72
73impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
74    for CubeBackend<R, F, I, BT>
75{
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.write_str("CubeCLBackend")
78    }
79}
80
81impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
82    for CubeBackend<R, F, I, BT>
83{
84    fn clone(&self) -> Self {
85        Self::new()
86    }
87}
88
89impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
90    for CubeBackend<R, F, I, BT>
91{
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl<R: cubecl::Runtime> CubeRuntime for R
98where
99    R::Device: DeviceOps,
100{
101    type CubeDevice = R::Device;
102    type CubeServer = R::Server;
103}
104
105#[cfg(not(feature = "fusion"))]
106impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
107    for CubeBackend<R, F, I, BT>
108{
109    type Handle = CubeTensor<R>;
110
111    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
112        handle.handle
113    }
114
115    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
116        handle.handle
117    }
118
119    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
120        handle.handle
121    }
122
123    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
124        handle.handle
125    }
126
127    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
128        tensor
129    }
130
131    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
132        tensor
133    }
134
135    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
136        tensor
137    }
138
139    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
140        tensor
141    }
142}