1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_backend::{Backend, DeviceOps, ExecutionError, TensorData};
3use burn_std::DType;
4use cubecl::{ir::StorageType, server::ComputeServer};
5use std::marker::PhantomData;
6
7#[cfg(not(feature = "fusion"))]
8use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
9#[cfg(not(feature = "fusion"))]
10use burn_ir::{BackendIr, TensorHandle};
11
12#[derive(new)]
14pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
15 _runtime: PhantomData<R>,
16 _float_elem: PhantomData<F>,
17 _int_elem: PhantomData<I>,
18 _bool_elem: PhantomData<BT>,
19}
20
21impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
22where
23 R: CubeRuntime,
24 R::Server: ComputeServer,
25 R::Device: DeviceOps,
26 F: FloatElement,
27 I: IntElement,
28 BT: BoolElement,
29{
30 type Device = R::Device;
31
32 type FloatElem = F;
33 type IntElem = I;
34 type BoolElem = BT;
35
36 type FloatTensorPrimitive = CubeTensor<R>;
37 type IntTensorPrimitive = CubeTensor<R>;
38 type BoolTensorPrimitive = CubeTensor<R>;
39 type QuantizedTensorPrimitive = CubeTensor<R>;
40
41 fn name(device: &Self::Device) -> String {
42 let client = R::client(device);
43 format!("cubecl<{}>", R::name(&client))
44 }
45
46 fn seed(_device: &Self::Device, seed: u64) {
47 cubek::random::seed(seed);
48 }
49
50 fn ad_enabled() -> bool {
51 false
52 }
53
54 fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
55 let client = R::client(device);
56 futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
57 reason: format!("{err}"),
58 })
59 }
60
61 fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
62 device: &Self::Device,
63 input: Input,
64 func: Func,
65 ) -> Output {
66 let client = R::client(device);
67 client.memory_persistent_allocation(input, func)
68 }
69
70 fn memory_cleanup(device: &Self::Device) {
71 let client = R::client(device);
72 client.memory_cleanup();
73 }
74
75 fn staging<'a, Iter>(data: Iter, device: &Self::Device)
76 where
77 Iter: Iterator<Item = &'a mut TensorData>,
78 {
79 let client = R::client(device);
80 client.staging(data.map(|td| &mut td.bytes), false);
81 }
82
83 fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
84 let client = R::client(device);
85
86 let ty: StorageType = dtype.into();
87 client.properties().supports_type(ty.elem_type())
88 }
89}
90
91impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
92 for CubeBackend<R, F, I, BT>
93{
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.write_str("CubeCLBackend")
96 }
97}
98
99impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
100 for CubeBackend<R, F, I, BT>
101{
102 fn clone(&self) -> Self {
103 Self::new()
104 }
105}
106
107impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
108 for CubeBackend<R, F, I, BT>
109{
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl<R: cubecl::Runtime> CubeRuntime for R
116where
117 R::Device: DeviceOps,
118{
119 type CubeDevice = R::Device;
120 type CubeServer = R::Server;
121}
122
123#[cfg(not(feature = "fusion"))]
124impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
125 for CubeBackend<R, F, I, BT>
126{
127 type Handle = CubeTensor<R>;
128
129 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
130 handle.handle
131 }
132
133 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
134 handle.handle
135 }
136
137 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
138 handle.handle
139 }
140
141 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
142 handle.handle
143 }
144
145 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
146 tensor
147 }
148
149 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
150 tensor
151 }
152
153 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
154 tensor
155 }
156
157 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
158 tensor
159 }
160}