burn_cubecl/
backend.rs

1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_backend::{Backend, DeviceOps, ExecutionError, TensorData};
3use burn_std::DType;
4use cubecl::{ir::StorageType, server::ComputeServer};
5use std::marker::PhantomData;
6
7#[cfg(not(feature = "fusion"))]
8use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
9#[cfg(not(feature = "fusion"))]
10use burn_ir::{BackendIr, TensorHandle};
11
12/// Generic tensor backend that can be compiled just-in-time to any shader runtime
13#[derive(new)]
14pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
15    _runtime: PhantomData<R>,
16    _float_elem: PhantomData<F>,
17    _int_elem: PhantomData<I>,
18    _bool_elem: PhantomData<BT>,
19}
20
21impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
22where
23    R: CubeRuntime,
24    R::Server: ComputeServer,
25    R::Device: DeviceOps,
26    F: FloatElement,
27    I: IntElement,
28    BT: BoolElement,
29{
30    type Device = R::Device;
31
32    type FloatElem = F;
33    type IntElem = I;
34    type BoolElem = BT;
35
36    type FloatTensorPrimitive = CubeTensor<R>;
37    type IntTensorPrimitive = CubeTensor<R>;
38    type BoolTensorPrimitive = CubeTensor<R>;
39    type QuantizedTensorPrimitive = CubeTensor<R>;
40
41    fn name(device: &Self::Device) -> String {
42        let client = R::client(device);
43        format!("cubecl<{}>", R::name(&client))
44    }
45
46    fn seed(_device: &Self::Device, seed: u64) {
47        cubek::random::seed(seed);
48    }
49
50    fn ad_enabled() -> bool {
51        false
52    }
53
54    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
55        let client = R::client(device);
56        futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
57            reason: format!("{err}"),
58        })
59    }
60
61    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
62        device: &Self::Device,
63        input: Input,
64        func: Func,
65    ) -> Output {
66        let client = R::client(device);
67        client.memory_persistent_allocation(input, func)
68    }
69
70    fn memory_cleanup(device: &Self::Device) {
71        let client = R::client(device);
72        client.memory_cleanup();
73    }
74
75    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
76    where
77        Iter: Iterator<Item = &'a mut TensorData>,
78    {
79        let client = R::client(device);
80        client.staging(data.map(|td| &mut td.bytes), false);
81    }
82
83    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
84        let client = R::client(device);
85
86        let ty: StorageType = dtype.into();
87        client.properties().supports_type(ty.elem_type())
88    }
89}
90
91impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
92    for CubeBackend<R, F, I, BT>
93{
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.write_str("CubeCLBackend")
96    }
97}
98
99impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
100    for CubeBackend<R, F, I, BT>
101{
102    fn clone(&self) -> Self {
103        Self::new()
104    }
105}
106
107impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
108    for CubeBackend<R, F, I, BT>
109{
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl<R: cubecl::Runtime> CubeRuntime for R
116where
117    R::Device: DeviceOps,
118{
119    type CubeDevice = R::Device;
120    type CubeServer = R::Server;
121}
122
123#[cfg(not(feature = "fusion"))]
124impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
125    for CubeBackend<R, F, I, BT>
126{
127    type Handle = CubeTensor<R>;
128
129    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
130        handle.handle
131    }
132
133    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
134        handle.handle
135    }
136
137    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
138        handle.handle
139    }
140
141    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
142        handle.handle
143    }
144
145    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
146        tensor
147    }
148
149    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
150        tensor
151    }
152
153    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
154        tensor
155    }
156
157    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
158        tensor
159    }
160}