1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_tensor::{
3 TensorData,
4 backend::{Backend, DeviceOps, SyncError},
5};
6use cubecl::server::ComputeServer;
7use std::marker::PhantomData;
8
9#[cfg(not(feature = "fusion"))]
10use burn_ir::{BackendIr, TensorHandle};
11#[cfg(not(feature = "fusion"))]
12use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
13
14#[derive(new)]
16pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
17 _runtime: PhantomData<R>,
18 _float_elem: PhantomData<F>,
19 _int_elem: PhantomData<I>,
20 _bool_elem: PhantomData<BT>,
21}
22
23impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
24where
25 R: CubeRuntime,
26 R::Server: ComputeServer,
27 R::Device: burn_tensor::backend::DeviceOps,
28 F: FloatElement,
29 I: IntElement,
30 BT: BoolElement,
31{
32 type Device = R::Device;
33
34 type FloatElem = F;
35 type IntElem = I;
36 type BoolElem = BT;
37
38 type FloatTensorPrimitive = CubeTensor<R>;
39 type IntTensorPrimitive = CubeTensor<R>;
40 type BoolTensorPrimitive = CubeTensor<R>;
41 type QuantizedTensorPrimitive = CubeTensor<R>;
42
43 fn name(device: &Self::Device) -> String {
44 let client = R::client(device);
45 format!("cubecl<{}>", R::name(&client))
46 }
47
48 fn seed(_device: &Self::Device, seed: u64) {
49 cubecl::random::seed(seed);
50 }
51
52 fn ad_enabled() -> bool {
53 false
54 }
55
56 fn sync(device: &Self::Device) -> Result<(), SyncError> {
57 let client = R::client(device);
58 futures_lite::future::block_on(client.sync()).map_err(|err| SyncError::Generic {
59 context: format!("{err:?}"),
60 })
61 }
62
63 fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
64 device: &Self::Device,
65 input: Input,
66 func: Func,
67 ) -> Output {
68 let client = R::client(device);
69 client.memory_persistent_allocation(input, func)
70 }
71
72 fn memory_cleanup(device: &Self::Device) {
73 let client = R::client(device);
74 client.memory_cleanup();
75 }
76
77 fn staging<'a, Iter>(data: Iter, device: &Self::Device)
78 where
79 Iter: Iterator<Item = &'a mut TensorData>,
80 {
81 let client = R::client(device);
82 client.staging(data.map(|td| &mut td.bytes), false);
83 }
84}
85
86impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
87 for CubeBackend<R, F, I, BT>
88{
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.write_str("CubeCLBackend")
91 }
92}
93
94impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
95 for CubeBackend<R, F, I, BT>
96{
97 fn clone(&self) -> Self {
98 Self::new()
99 }
100}
101
102impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
103 for CubeBackend<R, F, I, BT>
104{
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl<R: cubecl::Runtime> CubeRuntime for R
111where
112 R::Device: DeviceOps,
113{
114 type CubeDevice = R::Device;
115 type CubeServer = R::Server;
116}
117
118#[cfg(not(feature = "fusion"))]
119impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
120 for CubeBackend<R, F, I, BT>
121{
122 type Handle = CubeTensor<R>;
123
124 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
125 handle.handle
126 }
127
128 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
129 handle.handle
130 }
131
132 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
133 handle.handle
134 }
135
136 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
137 handle.handle
138 }
139
140 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
141 tensor
142 }
143
144 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
145 tensor
146 }
147
148 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
149 tensor
150 }
151
152 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
153 tensor
154 }
155}