1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_backend::{
3 Backend, BackendTypes, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData,
4};
5use burn_std::DType;
6use cubecl::{
7 features::{MmaConfig, TypeUsage},
8 server::ComputeServer,
9};
10use std::marker::PhantomData;
11
12#[cfg(not(feature = "fusion"))]
13use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
14#[cfg(not(feature = "fusion"))]
15use burn_ir::{BackendIr, TensorHandle};
16
/// Generic tensor backend built on a CubeCL runtime `R`, parameterized over
/// the element types used for float (`F`), int (`I`), and bool (`BT`) tensors.
///
/// The type carries no runtime state: every field is a zero-sized
/// `PhantomData` marker that only pins the generic parameters.
/// `#[derive(new)]` generates the `CubeBackend::new()` constructor used by
/// the `Clone` and `Default` impls below.
#[derive(new)]
pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
    // Zero-sized markers tying the generic parameters to the type.
    _runtime: PhantomData<R>,
    _float_elem: PhantomData<F>,
    _int_elem: PhantomData<I>,
    _bool_elem: PhantomData<BT>,
}
25
/// Maps the backend's abstract type slots onto concrete CubeCL types.
impl<R, F, I, BT> BackendTypes for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    R::Server: ComputeServer,
    R::Device: DeviceOps,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    // The device type comes straight from the runtime.
    type Device = R::Device;

    // Element types are supplied by the generic parameters.
    type FloatElem = F;
    type IntElem = I;
    type BoolElem = BT;

    // All four tensor kinds are represented by the same `CubeTensor<R>`
    // primitive; nothing in the type distinguishes them here.
    type FloatTensorPrimitive = CubeTensor<R>;
    type IntTensorPrimitive = CubeTensor<R>;
    type BoolTensorPrimitive = CubeTensor<R>;
    type QuantizedTensorPrimitive = CubeTensor<R>;
}
46
47impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
48where
49 R: CubeRuntime,
50 R::Server: ComputeServer,
51 R::Device: DeviceOps,
52 F: FloatElement,
53 I: IntElement,
54 BT: BoolElement,
55{
56 fn name(device: &Self::Device) -> String {
57 let client = R::client(device);
58 format!("cubecl<{}>", R::name(&client))
59 }
60
61 fn seed(_device: &Self::Device, seed: u64) {
62 cubek::random::seed(seed);
63 }
64
65 fn ad_enabled(_device: &Self::Device) -> bool {
66 false
67 }
68
69 fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
70 let client = R::client(device);
71 futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
72 reason: format!("{err}"),
73 })
74 }
75
76 fn memory_persistent_allocations<
77 Output: Send,
78 Input: Send,
79 Func: Fn(Input) -> Output + Send,
80 >(
81 device: &Self::Device,
82 input: Input,
83 func: Func,
84 ) -> Output {
85 let client = R::client(device);
86 client.memory_persistent_allocation(input, func).unwrap()
87 }
88
89 fn memory_cleanup(device: &Self::Device) {
90 let client = R::client(device);
91 client.memory_cleanup();
92 }
93
94 fn staging<'a, Iter>(data: Iter, device: &Self::Device)
95 where
96 Iter: Iterator<Item = &'a mut TensorData>,
97 {
98 let client = R::client(device);
99 client.staging(data.map(|td| &mut td.bytes), false);
100 }
101
102 fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
103 let client = R::client(device);
104
105 let type_usage = client.properties().type_usage(dtype.into());
106 type_usage.is_superset(
108 TypeUsage::Buffer
109 | TypeUsage::Conversion
110 | TypeUsage::Arithmetic
111 | TypeUsage::DotProduct,
112 )
113 }
114
115 fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet {
116 let client = R::client(device);
117
118 let props = client.properties();
119 let storage = dtype.into();
120 let usage = props.type_usage(storage);
121
122 let mut out = DTypeUsageSet::new();
123
124 if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) {
125 out |= DTypeUsage::Storage;
126 }
127
128 if usage.contains(TypeUsage::Arithmetic) {
129 out |= DTypeUsage::Arithmetic;
130 }
131
132 let has_mma = |cfg: &MmaConfig| {
133 cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage
134 };
135 if props.features.matmul.cmma.iter().any(has_mma)
136 || props.features.matmul.mma.iter().any(has_mma)
137 {
138 out |= DTypeUsage::Accelerated;
139 }
140
141 out
142 }
143
144 fn device_count(type_id: u16) -> usize {
145 let client = R::client(&Default::default());
146 client.device_count(type_id)
147 }
148}
149
150impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
151 for CubeBackend<R, F, I, BT>
152{
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.write_str("CubeCLBackend")
155 }
156}
157
158impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
159 for CubeBackend<R, F, I, BT>
160{
161 fn clone(&self) -> Self {
162 Self::new()
163 }
164}
165
166impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
167 for CubeBackend<R, F, I, BT>
168{
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
/// Blanket impl: any `cubecl::Runtime` whose device implements `DeviceOps`
/// is usable as a `CubeRuntime`; its device and server types are forwarded
/// unchanged.
impl<R: cubecl::Runtime> CubeRuntime for R
where
    R::Device: DeviceOps,
{
    type CubeDevice = R::Device;
    type CubeServer = R::Server;
}
181
/// Intermediate-representation bridge used when fusion is disabled: the IR
/// handle type is the backend tensor itself, so every conversion below is an
/// identity move (unwrap the `TensorHandle` wrapper, or pass the tensor
/// through unchanged).
#[cfg(not(feature = "fusion"))]
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
    for CubeBackend<R, F, I, BT>
{
    type Handle = CubeTensor<R>;

    // Handle -> tensor: extract the inner tensor from the wrapper.
    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        handle.handle
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        handle.handle
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        handle.handle
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        handle.handle
    }

    // Tensor -> handle: the tensor already is the handle; pass it through.
    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        tensor
    }

    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        tensor
    }

    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        tensor
    }

    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        tensor
    }
}