burn_cubecl/
backend.rs

1use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
2use burn_tensor::{
3    TensorData,
4    backend::{Backend, DeviceOps},
5};
6use cubecl::server::ComputeServer;
7use std::marker::PhantomData;
8
9#[cfg(not(feature = "fusion"))]
10use burn_ir::{BackendIr, TensorHandle};
11#[cfg(not(feature = "fusion"))]
12use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
13
14/// Generic tensor backend that can be compiled just-in-time to any shader runtime
15#[derive(new)]
16pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
17    _runtime: PhantomData<R>,
18    _float_elem: PhantomData<F>,
19    _int_elem: PhantomData<I>,
20    _bool_elem: PhantomData<BT>,
21}
22
23impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
24where
25    R: CubeRuntime,
26    R::Server: ComputeServer,
27    R::Device: burn_tensor::backend::DeviceOps,
28    F: FloatElement,
29    I: IntElement,
30    BT: BoolElement,
31{
32    type Device = R::Device;
33
34    type FloatElem = F;
35    type IntElem = I;
36    type BoolElem = BT;
37
38    type FloatTensorPrimitive = CubeTensor<R>;
39    type IntTensorPrimitive = CubeTensor<R>;
40    type BoolTensorPrimitive = CubeTensor<R>;
41    type QuantizedTensorPrimitive = CubeTensor<R>;
42
43    fn name(device: &Self::Device) -> String {
44        let client = R::client(device);
45        format!("cubecl<{}>", R::name(&client))
46    }
47
48    fn seed(_device: &Self::Device, seed: u64) {
49        cubecl::random::seed(seed);
50    }
51
52    fn ad_enabled() -> bool {
53        false
54    }
55
56    fn sync(device: &Self::Device) {
57        let client = R::client(device);
58        futures_lite::future::block_on(client.sync());
59    }
60
61    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
62        device: &Self::Device,
63        input: Input,
64        func: Func,
65    ) -> Output {
66        let client = R::client(device);
67        client.memory_persistent_allocation(input, func)
68    }
69
70    fn memory_cleanup(device: &Self::Device) {
71        let client = R::client(device);
72        client.memory_cleanup();
73    }
74
75    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
76    where
77        Iter: Iterator<Item = &'a mut TensorData>,
78    {
79        let client = R::client(device);
80        client.staging(data.map(|td| &mut td.bytes), false);
81    }
82}
83
84impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
85    for CubeBackend<R, F, I, BT>
86{
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.write_str("CubeCLBackend")
89    }
90}
91
92impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
93    for CubeBackend<R, F, I, BT>
94{
95    fn clone(&self) -> Self {
96        Self::new()
97    }
98}
99
100impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
101    for CubeBackend<R, F, I, BT>
102{
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl<R: cubecl::Runtime> CubeRuntime for R
109where
110    R::Device: DeviceOps,
111{
112    type CubeDevice = R::Device;
113    type CubeServer = R::Server;
114}
115
116#[cfg(not(feature = "fusion"))]
117impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
118    for CubeBackend<R, F, I, BT>
119{
120    type Handle = CubeTensor<R>;
121
122    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
123        handle.handle
124    }
125
126    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
127        handle.handle
128    }
129
130    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
131        handle.handle
132    }
133
134    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
135        handle.handle
136    }
137
138    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
139        tensor
140    }
141
142    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
143        tensor
144    }
145
146    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
147        tensor
148    }
149
150    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
151        tensor
152    }
153}