use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};
use crate::fusion::elemwise::builder::ElementWiseBuilder;
use crate::fusion::matmul::builder::MatmulBuilder;
use crate::BoolElement;
use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime};

use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_tensor::repr::TensorHandle;
use burn_tensor::DType;
use burn_tensor::{repr::ReprBackend, Shape};
use core::marker::PhantomData;
use cubecl::client::ComputeClient;
use cubecl::prelude::{TensorArg, TensorHandleRef};
use half::{bf16, f16};
use serde::{Deserialize, Serialize};

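/// Fusion optimization type for JIT.
///
/// More optimization variants should be added to this enum.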
pub enum JitOptimization<R: JitRuntime> {
    /// Element-wise optimization.
    ElementWise(ElemwiseOptimization<R>),
    /// Matrix multiplication optimization.
    Matmul(MatmulOptimization<R>),
}

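/// Fusion optimization state type for JIT.
///
/// More optimization variants should be added to this enum.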
#[derive(Serialize, Deserialize)]
pub enum JitOptimizationState {
    /// Element-wise optimization state.
    ElementWise(ElemwiseOptimizationState),
    /// Matrix multiplication optimization state.
    Matmul(MatmulOptimizationState),
}

impl<R, BT> burn_fusion::Optimization<FusionJitRuntime<R, BT>> for JitOptimization<R>
where
    R: JitRuntime,
    BT: BoolElement,
{
    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
        match self {
            Self::ElementWise(op) => op.execute::<BT>(context),
            Self::Matmul(op) => op.execute::<BT>(context),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::ElementWise(op) => op.num_ops_fused(),
            Self::Matmul(op) => op.num_ops_fused(),
        }
    }

    fn to_state(&self) -> JitOptimizationState {
        match self {
            Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()),
            Self::Matmul(value) => JitOptimizationState::Matmul(value.to_state()),
        }
    }

    fn from_state(device: &R::Device, state: JitOptimizationState) -> Self {
        match state {
            JitOptimizationState::ElementWise(state) => {
                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
            }
            JitOptimizationState::Matmul(state) => {
                Self::Matmul(MatmulOptimization::from_state(device, state))
            }
        }
    }
}

impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> ReprBackend
    for JitBackend<R, F, I, BT>
{
    type Handle = JitFusionHandle<R>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
        handle.handle.into_tensor(handle.shape)
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
        handle.handle.into_tensor(handle.shape)
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
        handle.handle.into_tensor(handle.shape)
    }

    fn quantized_tensor(
        handle: TensorHandle<Self::Handle>,
    ) -> burn_tensor::ops::QuantizedTensor<Self> {
        handle.handle.into_tensor(handle.shape)
    }

    fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
        tensor.into()
    }
}

impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
    type OptimizationState = JitOptimizationState;
    type Optimization = JitOptimization<R>;
    type FusionHandle = JitFusionHandle<R>;
    type FusionDevice = R::JitDevice;
    type FusionClient = MutexFusionClient<Self>;
    type BoolRepr = BT;

    fn optimizations(
        device: R::Device,
    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
        let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
            vec![Box::new(ElementWiseBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
            ))];

        // Matmul fusion is only registered when the experimental feature is enabled.
        if cfg!(feature = "fusion-experimental") {
            optimizations.push(Box::new(MatmulBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
            )));
        }

        optimizations
    }
}

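/// Fusion runtime for JIT runtimes.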
#[derive(Debug)]
pub struct FusionJitRuntime<R: JitRuntime, BT: BoolElement> {
    _b: PhantomData<R>,
    _bool: PhantomData<BT>,
}

impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
    for JitBackend<R, F, I, BT>
{
    type FusionRuntime = FusionJitRuntime<R, BT>;

    type FullPrecisionBackend = JitBackend<R, f32, i32, BT>;

    fn cast_float(
        tensor: burn_tensor::ops::FloatTensor<Self>,
        dtype: burn_tensor::DType,
    ) -> Self::Handle {
        fn cast<R: JitRuntime, F: FloatElement, FTarget: FloatElement>(
            tensor: JitTensor<R>,
        ) -> JitFusionHandle<R> {
            JitFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
        }

        match dtype {
            burn_tensor::DType::F32 => cast::<R, F, f32>(tensor),
            burn_tensor::DType::F16 => cast::<R, F, f16>(tensor),
            burn_tensor::DType::BF16 => cast::<R, F, bf16>(tensor),
            _ => panic!("Casting error: {dtype:?} unsupported."),
        }
    }
}

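/// Computes the strides of a contiguous (row-major) tensor for a shape of any rank.
///
/// For example, a shape of `[2, 3, 4]` yields strides `[12, 4, 1]`.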
pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];

    let mut current = 1;
    shape.iter().enumerate().rev().for_each(|(index, val)| {
        strides[index] = current;
        current *= val;
    });

    strides
}

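/// Handle to be used when fusing operations.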
pub struct JitFusionHandle<R: JitRuntime> {
    /// Compute client for the JIT runtime.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the data are stored.
    pub handle: cubecl::server::Handle,
    /// The device of the current tensor.
    pub device: R::Device,
    pub(crate) dtype: DType,
    pub(crate) strides: Vec<usize>,
}

impl<R: JitRuntime> core::fmt::Debug for JitFusionHandle<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "JitFusionHandle {{ device: {:?}, runtime: {}}}",
            self.device,
            R::name(),
        ))
    }
}

impl<R: JitRuntime> Clone for JitFusionHandle<R> {
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
        }
    }
}

unsafe impl<R: JitRuntime> Send for JitFusionHandle<R> {}
unsafe impl<R: JitRuntime> Sync for JitFusionHandle<R> {}

impl<R: JitRuntime> JitFusionHandle<R> {
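    /// Converts the handle into a [`JitTensor`] with the given shape.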
    pub(crate) fn into_tensor(self, shape: Shape) -> JitTensor<R> {
        JitTensor {
            client: self.client,
            handle: self.handle,
            device: self.device,
            shape,
            strides: self.strides,
            dtype: self.dtype,
        }
    }
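
    /// Returns a [`TensorHandleRef`] pointing to this handle's buffer for the given shape.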
    pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape,
            runtime: PhantomData,
            elem_size: self.dtype.size(),
        }
    }
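
    /// Returns a [`TensorArg`] for kernel launch, using the given vectorisation factor.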
    pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape);

        // SAFETY: the raw parts are taken from a valid `TensorHandleRef` built just above.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
                self.dtype.size(),
            )
        }
    }
}

impl<R: JitRuntime> From<JitTensor<R>> for JitFusionHandle<R> {
    fn from(value: JitTensor<R>) -> Self {
        Self {
            client: value.client,
            handle: value.handle,
            device: value.device,
            strides: value.strides,
            dtype: value.dtype,
        }
    }
}