// burn_cubecl/backend.rs

use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
use burn_backend::{
    Backend, BackendTypes, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData,
};
use burn_std::DType;
use cubecl::{
    features::{MmaConfig, TypeUsage},
    server::ComputeServer,
};
use std::marker::PhantomData;

#[cfg(not(feature = "fusion"))]
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
#[cfg(not(feature = "fusion"))]
use burn_ir::{BackendIr, TensorHandle};

/// Generic tensor backend that can be compiled just-in-time to any shader runtime.
///
/// - `R`: the CubeCL runtime the backend targets.
/// - `F`: the element type backing float tensors.
/// - `I`: the element type backing int tensors.
/// - `BT`: the element type used to store bool tensors.
#[derive(new)]
pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
    _runtime: PhantomData<R>,
    _float_elem: PhantomData<F>,
    _int_elem: PhantomData<I>,
    _bool_elem: PhantomData<BT>,
}
25
impl<R, F, I, BT> BackendTypes for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    R::Server: ComputeServer,
    R::Device: DeviceOps,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    type Device = R::Device;

    type FloatElem = F;
    type IntElem = I;
    type BoolElem = BT;

    // Every tensor kind shares the same primitive; the dtype is carried by
    // the `CubeTensor` itself rather than by the type system.
    type FloatTensorPrimitive = CubeTensor<R>;
    type IntTensorPrimitive = CubeTensor<R>;
    type BoolTensorPrimitive = CubeTensor<R>;
    type QuantizedTensorPrimitive = CubeTensor<R>;
}
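
// A minimal instantiation sketch. `SomeRuntime` and the element choices are
// hypothetical; concrete aliases normally live in the per-runtime crates:
//
//     type MyBackend = CubeBackend<SomeRuntime, f32, i32, u8>;
//
// With such an alias, `MyBackend` can be passed anywhere a `Backend` is expected.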

impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    R::Server: ComputeServer,
    R::Device: DeviceOps,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn name(device: &Self::Device) -> String {
        let client = R::client(device);
        format!("cubecl<{}>", R::name(&client))
    }

    fn seed(_device: &Self::Device, seed: u64) {
        cubek::random::seed(seed);
    }

    fn ad_enabled(_device: &Self::Device) -> bool {
        // Autodiff is provided by a wrapper backend, not by the base backend.
        false
    }

    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        let client = R::client(device);
        // Block the current thread until all queued work has completed.
        futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
            reason: format!("{err}"),
        })
    }
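
    // Usage sketch (with the hypothetical `MyBackend` alias above and a
    // `device` in scope): block until all queued kernels have completed,
    // e.g. before timing or reading back results:
    //
    //     MyBackend::sync(&device)?;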

    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        let client = R::client(device);
        client.memory_persistent_allocation(input, func).unwrap()
    }

    fn memory_cleanup(device: &Self::Device) {
        let client = R::client(device);
        client.memory_cleanup();
    }

    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
        let client = R::client(device);
        client.staging(data.map(|td| &mut td.bytes), false);
    }

    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        let client = R::client(device);

        let type_usage = client.properties().type_usage(dtype.into());
        // Same as `TypeUsage::all_scalar()`, but we make the usage explicit here.
        type_usage.is_superset(
            TypeUsage::Buffer
                | TypeUsage::Conversion
                | TypeUsage::Arithmetic
                | TypeUsage::DotProduct,
        )
    }

    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet {
        let client = R::client(device);

        let props = client.properties();
        let storage = dtype.into();
        let usage = props.type_usage(storage);

        let mut out = DTypeUsageSet::new();

        if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) {
            out |= DTypeUsage::Storage;
        }

        if usage.contains(TypeUsage::Arithmetic) {
            out |= DTypeUsage::Arithmetic;
        }

        // The dtype counts as accelerated if any CMMA/MMA matmul configuration
        // uses it as an input or output operand type.
        let has_mma = |cfg: &MmaConfig| {
            cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage
        };
        if props.features.matmul.cmma.iter().any(has_mma)
            || props.features.matmul.mma.iter().any(has_mma)
        {
            out |= DTypeUsage::Accelerated;
        }

        out
    }
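
    // Usage sketch: check `supports_dtype` for baseline scalar support, then
    // consult `dtype_usage` for finer-grained capabilities (hypothetical
    // `MyBackend` alias and in-scope `device`):
    //
    //     if MyBackend::supports_dtype(&device, DType::F16) {
    //         let usage = MyBackend::dtype_usage(&device, DType::F16);
    //         // e.g. pick an accelerated f16 matmul path when available.
    //     }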

    fn device_count(type_id: u16) -> usize {
        // Query through a client on the default device.
        let client = R::client(&Default::default());
        client.device_count(type_id)
    }
}

impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
    for CubeBackend<R, F, I, BT>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CubeCLBackend")
    }
}

impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
    for CubeBackend<R, F, I, BT>
{
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
    for CubeBackend<R, F, I, BT>
{
    fn default() -> Self {
        Self::new()
    }
}

// Blanket impl: any CubeCL runtime whose device implements `DeviceOps`
// automatically qualifies as a `CubeRuntime`.
impl<R: cubecl::Runtime> CubeRuntime for R
where
    R::Device: DeviceOps,
{
    type CubeDevice = R::Device;
    type CubeServer = R::Server;
}
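
// For example, a WGPU or CUDA runtime from the cubecl runtime crates (exact
// paths assumed here, e.g. something like `cubecl::wgpu::WgpuRuntime`) gets
// this impl for free, making `CubeBackend<WgpuRuntime, f32, i32, u8>` well-formed.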

#[cfg(not(feature = "fusion"))]
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
    for CubeBackend<R, F, I, BT>
{
    // Without fusion, the IR handle is the tensor itself: the
    // tensor-to-handle conversions below are identity functions, and the
    // handle-to-tensor direction simply unwraps the `TensorHandle`.
    type Handle = CubeTensor<R>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        handle.handle
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        handle.handle
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        handle.handle
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        handle.handle
    }

    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        tensor
    }

    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        tensor
    }

    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        tensor
    }

    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        tensor
    }
}