use crate::BoolElement;
use crate::element::CubeElement;
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};

use burn_cubecl_fusion::CubeFusionHandle;
use burn_cubecl_fusion::elemwise::optimization::ElemwiseOptimization;
use burn_cubecl_fusion::matmul::MatmulFallbackFn;
use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
use burn_cubecl_fusion::reduce::optimization::{
    ReduceFallbackFn, ReduceInstruction, ReduceOptimization,
};
use burn_cubecl_fusion::{
    CubeOptimization, CubeOptimizationState, elemwise::builder::ElementWiseBuilder,
};
use burn_fusion::{FusionBackend, FusionRuntime, client::MutexFusionClient};
use burn_ir::{BackendIr, TensorHandle};
use burn_tensor::{DType, Shape};
use core::marker::PhantomData;
use cubecl::flex32;
use cubecl::reduce::instructions::ReduceFnConfig;
use half::{bf16, f16};
use std::sync::Arc;

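// Bridges `burn_fusion::Optimization` to the concrete CubeCL optimizations,
// dispatching each hook to the element-wise, matmul, or reduce variant.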
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
where
    R: CubeRuntime,
    BT: BoolElement,
{
    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
        match self {
            Self::ElementWise(op) => op.execute::<BT>(context),
            Self::Matmul(op) => op.execute::<BT>(context),
            Self::Reduce(op) => op.execute::<BT>(context),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::ElementWise(op) => op.num_ops_fused(),
            Self::Matmul(op) => op.num_ops_fused(),
            Self::Reduce(op) => op.num_ops_fused(),
        }
    }

    fn to_state(&self) -> CubeOptimizationState {
        match self {
            Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()),
            Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()),
            Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()),
        }
    }

    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
        match state {
            CubeOptimizationState::ElementWise(state) => {
                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
            }
            CubeOptimizationState::Matmul(state) => Self::Matmul(MatmulOptimization::from_state(
                device,
                state,
                Arc::new(FallbackMatmul),
            )),
            CubeOptimizationState::Reduce(state) => Self::Reduce(ReduceOptimization::from_state(
                device,
                state,
                Arc::new(FallbackReduce),
            )),
        }
    }
}

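/// Fallback for the fused matmul optimization: runs the plain matmul kernel instead.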
struct FallbackMatmul;
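/// Fallback for the fused reduce optimization: runs the plain reduce kernel instead.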
struct FallbackReduce;

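// Dispatches on the lhs element type; both operands are expected to share a dtype.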
impl<R: CubeRuntime> MatmulFallbackFn<R> for FallbackMatmul {
    fn run(
        &self,
        lhs: (CubeFusionHandle<R>, &[usize]),
        rhs: (CubeFusionHandle<R>, &[usize]),
    ) -> CubeFusionHandle<R> {
        match lhs.0.dtype {
            DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
            DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
            DType::Flex32 => run_fallback_matmul::<R, flex32>(lhs, rhs),
            DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
            DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
            _ => todo!("Not yet supported"),
        }
    }
}

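// Translates the fused reduce instruction into a `ReduceFnConfig` before dispatching on dtypes.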
impl<R: CubeRuntime> ReduceFallbackFn<R> for FallbackReduce {
    fn run(
        &self,
        input: CubeFusionHandle<R>,
        shape: &[usize],
        axis: usize,
        inst: &ReduceInstruction,
        d_o: &DType,
    ) -> CubeFusionHandle<R> {
        let d_i = input.dtype;
        let config = match inst {
            ReduceInstruction::ArgMax => ReduceFnConfig::ArgMax,
            ReduceInstruction::ArgMin => ReduceFnConfig::ArgMin,
            ReduceInstruction::Mean => ReduceFnConfig::Mean,
            ReduceInstruction::Prod => ReduceFnConfig::Prod,
            ReduceInstruction::Sum => ReduceFnConfig::Sum,
            ReduceInstruction::Min => ReduceFnConfig::Min,
            ReduceInstruction::Max => ReduceFnConfig::Max,
            ReduceInstruction::MaxAbs => ReduceFnConfig::MaxAbs,
        };

        reduce_dtype::<R>(input, shape, axis, &d_i, d_o, config)
    }
}

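/// Runs the non-fused matmul kernel with `EG` as the concrete element type.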
fn run_fallback_matmul<R: CubeRuntime, EG: FloatElement>(
    lhs: (CubeFusionHandle<R>, &[usize]),
    rhs: (CubeFusionHandle<R>, &[usize]),
) -> CubeFusionHandle<R> {
    let lhs_tensor = into_tensor(
        lhs.0,
        Shape {
            dims: lhs.1.to_vec(),
        },
    );
    let rhs_tensor = into_tensor(
        rhs.0,
        Shape {
            dims: rhs.1.to_vec(),
        },
    );
    let out_tensor = crate::kernel::matmul::matmul::<R, EG>(
        lhs_tensor,
        rhs_tensor,
        None,
        crate::kernel::matmul::MatmulStrategy::default(),
    )
    .unwrap();

    CubeFusionHandle {
        client: out_tensor.client,
        handle: out_tensor.handle,
        device: out_tensor.device,
        dtype: out_tensor.dtype,
        strides: out_tensor.strides,
    }
}

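/// Resolves the input dtype to a concrete element type, then dispatches on the output dtype.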
fn reduce_dtype<R: CubeRuntime>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    dtype_input: &DType,
    dtype_output: &DType,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    match dtype_input {
        DType::F64 => {
            reduce_dtype_output::<R, f64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::F32 | DType::Flex32 => {
            reduce_dtype_output::<R, f32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::F16 => {
            reduce_dtype_output::<R, f16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::BF16 => {
            reduce_dtype_output::<R, bf16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I64 => {
            reduce_dtype_output::<R, i64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I32 => {
            reduce_dtype_output::<R, i32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I16 => {
            reduce_dtype_output::<R, i16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U64 => {
            reduce_dtype_output::<R, u64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U32 => {
            reduce_dtype_output::<R, u32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U16 => {
            reduce_dtype_output::<R, u16>(input_handle, shape, axis, dtype_output, config)
        }
        _ => todo!("Not yet supported"),
    }
}

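/// Resolves the output dtype to a concrete element type, then runs the reduction.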
fn reduce_dtype_output<R: CubeRuntime, In: CubeElement>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    dtype_output: &DType,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    match dtype_output {
        DType::F64 => reduce::<R, In, f64>(input_handle, shape, axis, config),
        DType::F32 | DType::Flex32 => reduce::<R, In, f32>(input_handle, shape, axis, config),
        DType::F16 => reduce::<R, In, f16>(input_handle, shape, axis, config),
        DType::BF16 => reduce::<R, In, bf16>(input_handle, shape, axis, config),
        DType::I64 => reduce::<R, In, i64>(input_handle, shape, axis, config),
        DType::I32 => reduce::<R, In, i32>(input_handle, shape, axis, config),
        DType::I16 => reduce::<R, In, i16>(input_handle, shape, axis, config),
        DType::U64 => reduce::<R, In, u64>(input_handle, shape, axis, config),
        DType::U32 => reduce::<R, In, u32>(input_handle, shape, axis, config),
        DType::U16 => reduce::<R, In, u16>(input_handle, shape, axis, config),
        _ => todo!("Not yet supported"),
    }
}

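/// Runs the non-fused reduce kernel along `axis` with the default strategy.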
fn reduce<R: CubeRuntime, In: CubeElement, Out: CubeElement>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    let input_tensor = into_tensor(
        input_handle,
        Shape {
            dims: shape.to_vec(),
        },
    );
    let out_tensor = crate::kernel::reduce::reduce_dim::<R, In, Out>(
        input_tensor,
        axis,
        crate::kernel::reduce::ReduceStrategy::default(),
        config,
    )
    .unwrap();

    CubeFusionHandle {
        client: out_tensor.client,
        handle: out_tensor.handle,
        device: out_tensor.device,
        dtype: out_tensor.dtype,
        strides: out_tensor.strides,
    }
}

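// Maps intermediate-representation tensor handles to and from concrete CubeCL tensors.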
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
    for CubeBackend<R, F, I, BT>
{
    type Handle = CubeFusionHandle<R>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn quantized_tensor(
        handle: TensorHandle<Self::Handle>,
    ) -> burn_tensor::ops::QuantizedTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
        tensor.into()
    }
}

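// Registers the optimization builders (element-wise, matmul, reduce) available to the
// fusion runtime.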
impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
    type OptimizationState = CubeOptimizationState;
    type Optimization = CubeOptimization<R>;
    type FusionHandle = CubeFusionHandle<R>;
    type FusionDevice = R::CubeDevice;
    type FusionClient = MutexFusionClient<Self>;
    type BoolRepr = BT;

    fn optimizations(
        device: R::Device,
    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
        vec![
            Box::new(ElementWiseBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
            )),
            Box::new(MatmulBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
                Arc::new(FallbackMatmul),
            )),
            Box::new(ReduceBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
                Arc::new(FallbackReduce),
            )),
        ]
    }
}

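/// Fusion runtime for CubeCL backends, parameterized over the runtime and the boolean
/// element representation.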
#[derive(Debug)]
pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
    _b: PhantomData<R>,
    _bool: PhantomData<BT>,
}

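// Hooks the CubeCL backend into its fusion runtime; `cast_float` casts float tensors
// to the requested float dtype.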
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
    for CubeBackend<R, F, I, BT>
{
    type FusionRuntime = FusionCubeRuntime<R, BT>;

    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;

    fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
        fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
            tensor: CubeTensor<R>,
        ) -> CubeFusionHandle<R> {
            CubeFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
        }

        match dtype {
            DType::F32 | DType::Flex32 => cast::<R, F, f32>(tensor),
            DType::F16 => cast::<R, F, f16>(tensor),
            DType::BF16 => cast::<R, F, bf16>(tensor),
            _ => panic!("Casting error: {dtype:?} unsupported."),
        }
    }
}

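/// Reinterprets a fusion handle as a `CubeTensor` with the provided shape.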
fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
    CubeTensor {
        client: handle.client,
        handle: handle.handle,
        device: handle.device,
        shape,
        strides: handle.strides,
        dtype: handle.dtype,
    }
}

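// The reverse conversion: the tensor's shape is dropped, keeping only the handle metadata.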
impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
    fn from(value: CubeTensor<R>) -> Self {
        Self {
            client: value.client,
            handle: value.handle,
            device: value.device,
            strides: value.strides,
            dtype: value.dtype,
        }
    }
}