1use crate::{element::BoolElement, tensor::JitTensor, FloatElement, IntElement, JitRuntime};
2use burn_tensor::backend::{Backend, DeviceOps};
3use cubecl::server::ComputeServer;
4use rand::{rngs::StdRng, SeedableRng};
5use std::{marker::PhantomData, sync::Mutex};
6
7#[cfg(not(feature = "fusion"))]
8use burn_tensor::{
9 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
10 repr::{ReprBackend, TensorHandle},
11};
12
13pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
14
15#[derive(new)]
17pub struct JitBackend<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
18 _runtime: PhantomData<R>,
19 _float_elem: PhantomData<F>,
20 _int_elem: PhantomData<I>,
21 _bool_elem: PhantomData<BT>,
22}
23
24impl<R, F, I, BT> Backend for JitBackend<R, F, I, BT>
25where
26 R: JitRuntime,
27 R::Server: ComputeServer,
28 R::Device: burn_tensor::backend::DeviceOps,
29 F: FloatElement,
30 I: IntElement,
31 BT: BoolElement,
32{
33 type Device = R::Device;
34
35 type FloatElem = F;
36 type IntElem = I;
37 type BoolElem = BT;
38
39 type FloatTensorPrimitive = JitTensor<R>;
40 type IntTensorPrimitive = JitTensor<R>;
41 type BoolTensorPrimitive = JitTensor<R>;
42 type QuantizedTensorPrimitive = JitTensor<R>;
43 type QuantizedEncoding = u32;
44
45 fn name() -> String {
46 format!("jit<{}>", R::name())
47 }
48
49 fn seed(seed: u64) {
50 let rng = StdRng::seed_from_u64(seed);
51 let mut seed = SEED.lock().unwrap();
52 *seed = Some(rng);
53 }
54
55 fn ad_enabled() -> bool {
56 false
57 }
58
59 fn sync(device: &Self::Device) {
60 let client = R::client(device);
61 futures_lite::future::block_on(client.sync());
62 }
63}
64
65impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
66 for JitBackend<R, F, I, BT>
67{
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name()))
70 }
71}
72
73impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
74 for JitBackend<R, F, I, BT>
75{
76 fn clone(&self) -> Self {
77 Self::new()
78 }
79}
80
81impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
82 for JitBackend<R, F, I, BT>
83{
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl<R: cubecl::Runtime> JitRuntime for R
90where
91 R::Device: DeviceOps,
92{
93 type JitDevice = R::Device;
94 type JitServer = R::Server;
95}
96
97#[cfg(not(feature = "fusion"))]
98impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> ReprBackend
99 for JitBackend<R, F, I, BT>
100{
101 type Handle = JitTensor<R>;
102
103 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
104 handle.handle
105 }
106
107 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
108 handle.handle
109 }
110
111 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
112 handle.handle
113 }
114
115 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
116 handle.handle
117 }
118
119 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
120 tensor
121 }
122
123 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
124 tensor
125 }
126
127 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
128 tensor
129 }
130
131 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
132 tensor
133 }
134}