burn_cubecl/fusion.rs

use crate::BoolElement;
use crate::element::CubeElement;
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};

use burn_cubecl_fusion::CubeFusionHandle;
use burn_cubecl_fusion::elemwise::optimization::ElemwiseOptimization;
use burn_cubecl_fusion::matmul::MatmulFallbackFn;
use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
use burn_cubecl_fusion::reduce::optimization::{
    ReduceFallbackFn, ReduceInstruction, ReduceOptimization,
};
use burn_cubecl_fusion::{
    CubeOptimization, CubeOptimizationState, elemwise::builder::ElementWiseBuilder,
};
use burn_fusion::{FusionBackend, FusionRuntime, client::MutexFusionClient};
use burn_ir::{BackendIr, TensorHandle};
use burn_tensor::{DType, Shape};
use core::marker::PhantomData;
use cubecl::flex32;
use cubecl::reduce::instructions::ReduceFnConfig;
use half::{bf16, f16};

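/// Routes each fusion-optimization call (execution, fused-op count, and state
/// (de)serialization) to the concrete kind: element-wise, matmul, or reduce.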
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
where
    R: CubeRuntime,
    BT: BoolElement,
{
    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
        match self {
            Self::ElementWise(op) => op.execute::<BT>(context),
            Self::Matmul(op) => op.execute::<BT>(context),
            Self::Reduce(op) => op.execute::<BT>(context),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::ElementWise(op) => op.num_ops_fused(),
            Self::Matmul(op) => op.num_ops_fused(),
            Self::Reduce(op) => op.num_ops_fused(),
        }
    }

    fn to_state(&self) -> CubeOptimizationState {
        match self {
            Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()),
            Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()),
            Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()),
        }
    }

    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
        match state {
            CubeOptimizationState::ElementWise(state) => {
                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
            }
            CubeOptimizationState::Matmul(state) => Self::Matmul(MatmulOptimization::from_state(
                device,
                state,
                Arc::new(FallbackMatmul),
            )),
            CubeOptimizationState::Reduce(state) => Self::Reduce(ReduceOptimization::from_state(
                device,
                state,
                Arc::new(FallbackReduce),
            )),
        }
    }
}

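// Marker types whose trait implementations below fall back to the regular
// (non-fused) burn_cubecl kernels.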
struct FallbackMatmul;
struct FallbackReduce;

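/// Runs a plain matmul kernel on the unfused tensors, dispatching on the
/// float dtype of the left-hand side.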
impl<R: CubeRuntime> MatmulFallbackFn<R> for FallbackMatmul {
    fn run(
        &self,
        lhs: (CubeFusionHandle<R>, &[usize]),
        rhs: (CubeFusionHandle<R>, &[usize]),
    ) -> CubeFusionHandle<R> {
        match lhs.0.dtype {
            DType::F64 => run_fallback_matmul::<R, f64>(lhs, rhs),
            DType::F32 => run_fallback_matmul::<R, f32>(lhs, rhs),
            DType::Flex32 => run_fallback_matmul::<R, flex32>(lhs, rhs),
            DType::F16 => run_fallback_matmul::<R, f16>(lhs, rhs),
            DType::BF16 => run_fallback_matmul::<R, bf16>(lhs, rhs),
            _ => todo!("Not yet supported"),
        }
    }
}

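/// Runs a plain reduce kernel on the unfused tensor, translating the fusion
/// reduce instruction into the corresponding `ReduceFnConfig`.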
impl<R: CubeRuntime> ReduceFallbackFn<R> for FallbackReduce {
    fn run(
        &self,
        input: CubeFusionHandle<R>,
        shape: &[usize],
        axis: usize,
        inst: &ReduceInstruction,
        dtype_out: &DType,
    ) -> CubeFusionHandle<R> {
        let dtype_in = input.dtype;
        let config = match inst {
            ReduceInstruction::ArgMax => ReduceFnConfig::ArgMax,
            ReduceInstruction::ArgMin => ReduceFnConfig::ArgMin,
            ReduceInstruction::Mean => ReduceFnConfig::Mean,
            ReduceInstruction::Prod => ReduceFnConfig::Prod,
            ReduceInstruction::Sum => ReduceFnConfig::Sum,
            ReduceInstruction::Min => ReduceFnConfig::Min,
            ReduceInstruction::Max => ReduceFnConfig::Max,
            ReduceInstruction::MaxAbs => ReduceFnConfig::MaxAbs,
        };

        reduce_dtype::<R>(input, shape, axis, &dtype_in, dtype_out, config)
    }
}

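/// Executes the non-fused matmul kernel with the default strategy, rebuilding
/// `CubeTensor`s from the fusion handles and returning the result as a handle.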
fn run_fallback_matmul<R: CubeRuntime, EG: FloatElement>(
    lhs: (CubeFusionHandle<R>, &[usize]),
    rhs: (CubeFusionHandle<R>, &[usize]),
) -> CubeFusionHandle<R> {
    let lhs_tensor = into_tensor(
        lhs.0,
        Shape {
            dims: lhs.1.to_vec(),
        },
    );
    let rhs_tensor = into_tensor(
        rhs.0,
        Shape {
            dims: rhs.1.to_vec(),
        },
    );
    let out_tensor = crate::kernel::matmul::matmul::<R, EG>(
        lhs_tensor,
        rhs_tensor,
        None,
        crate::kernel::matmul::MatmulStrategy::default(),
    )
    .unwrap();

    CubeFusionHandle {
        client: out_tensor.client,
        handle: out_tensor.handle,
        device: out_tensor.device,
        dtype: out_tensor.dtype,
        strides: out_tensor.strides,
    }
}

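/// First dispatch level for the reduce fallback: monomorphizes on the input
/// dtype, then defers to `reduce_dtype_output` for the output dtype.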
fn reduce_dtype<R: CubeRuntime>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    dtype_input: &DType,
    dtype_output: &DType,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    match dtype_input {
        DType::F64 => {
            reduce_dtype_output::<R, f64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::F32 | DType::Flex32 => {
            reduce_dtype_output::<R, f32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::F16 => {
            reduce_dtype_output::<R, f16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::BF16 => {
            reduce_dtype_output::<R, bf16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I64 => {
            reduce_dtype_output::<R, i64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I32 => {
            reduce_dtype_output::<R, i32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::I16 => {
            reduce_dtype_output::<R, i16>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U64 => {
            reduce_dtype_output::<R, u64>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U32 => {
            reduce_dtype_output::<R, u32>(input_handle, shape, axis, dtype_output, config)
        }
        DType::U16 => {
            reduce_dtype_output::<R, u16>(input_handle, shape, axis, dtype_output, config)
        }
        _ => todo!("Not yet supported"),
    }
}

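/// Second dispatch level: monomorphizes on the output dtype, then launches the
/// concrete reduce kernel.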
fn reduce_dtype_output<R: CubeRuntime, In: CubeElement>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    dtype_output: &DType,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    match dtype_output {
        DType::F64 => reduce::<R, In, f64>(input_handle, shape, axis, config),
        DType::F32 | DType::Flex32 => reduce::<R, In, f32>(input_handle, shape, axis, config),
        DType::F16 => reduce::<R, In, f16>(input_handle, shape, axis, config),
        DType::BF16 => reduce::<R, In, bf16>(input_handle, shape, axis, config),
        DType::I64 => reduce::<R, In, i64>(input_handle, shape, axis, config),
        DType::I32 => reduce::<R, In, i32>(input_handle, shape, axis, config),
        DType::I16 => reduce::<R, In, i16>(input_handle, shape, axis, config),
        DType::U64 => reduce::<R, In, u64>(input_handle, shape, axis, config),
        DType::U32 => reduce::<R, In, u32>(input_handle, shape, axis, config),
        DType::U16 => reduce::<R, In, u16>(input_handle, shape, axis, config),
        _ => todo!("Not yet supported"),
    }
}

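/// Runs the non-fused `reduce_dim` kernel with the default strategy and wraps
/// the output back into a fusion handle.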
fn reduce<R: CubeRuntime, In: CubeElement, Out: CubeElement>(
    input_handle: CubeFusionHandle<R>,
    shape: &[usize],
    axis: usize,
    config: ReduceFnConfig,
) -> CubeFusionHandle<R> {
    let input_tensor = into_tensor(
        input_handle,
        Shape {
            dims: shape.to_vec(),
        },
    );
    let out_tensor = crate::kernel::reduce::reduce_dim::<R, In, Out>(
        input_tensor,
        axis,
        crate::kernel::reduce::ReduceStrategy::default(),
        config,
    )
    .unwrap();

    CubeFusionHandle {
        client: out_tensor.client,
        handle: out_tensor.handle,
        device: out_tensor.device,
        dtype: out_tensor.dtype,
        strides: out_tensor.strides,
    }
}

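/// Conversions between intermediate-representation tensor handles and concrete
/// `CubeTensor`s for every tensor kind.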
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
    for CubeBackend<R, F, I, BT>
{
    type Handle = CubeFusionHandle<R>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn quantized_tensor(
        handle: TensorHandle<Self::Handle>,
    ) -> burn_tensor::ops::QuantizedTensor<Self> {
        into_tensor(handle.handle, handle.shape)
    }

    fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
        tensor.into()
    }

    fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
        tensor.into()
    }
}

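/// Wires the CubeCL fusion handle, device, and client types into the
/// `burn_fusion` runtime and registers the available optimization builders.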
impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
    type OptimizationState = CubeOptimizationState;
    type Optimization = CubeOptimization<R>;
    type FusionHandle = CubeFusionHandle<R>;
    type FusionDevice = R::CubeDevice;
    type FusionClient = MutexFusionClient<Self>;
    type BoolRepr = BT;

    fn optimizations(
        device: R::Device,
    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
        vec![
            Box::new(ElementWiseBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
            )),
            Box::new(MatmulBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
                Arc::new(FallbackMatmul),
            )),
            Box::new(ReduceBuilder::<R>::new(
                device.clone(),
                BT::as_elem_native_unchecked().into(),
                Arc::new(FallbackReduce),
            )),
        ]
    }
}

/// Fusion runtime for CubeCL runtimes.
#[derive(Debug)]
pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
    _b: PhantomData<R>,
    _bool: PhantomData<BT>,
}

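/// Hooks the backend into fusion and provides float casting between the
/// supported float dtypes.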
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
    for CubeBackend<R, F, I, BT>
{
    type FusionRuntime = FusionCubeRuntime<R, BT>;

    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;

    fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
        fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
            tensor: CubeTensor<R>,
        ) -> CubeFusionHandle<R> {
            CubeFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
        }

        match dtype {
            DType::F32 | DType::Flex32 => cast::<R, F, f32>(tensor),
            DType::F16 => cast::<R, F, f16>(tensor),
            DType::BF16 => cast::<R, F, bf16>(tensor),
            _ => panic!("Casting error: {dtype:?} unsupported."),
        }
    }
}

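/// Reattaches a shape to a fusion handle, yielding a regular `CubeTensor`.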
fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
    CubeTensor {
        client: handle.client,
        handle: handle.handle,
        device: handle.device,
        shape,
        strides: handle.strides,
        dtype: handle.dtype,
    }
}

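/// Inverse of `into_tensor`: drops the shape from a `CubeTensor`, keeping the
/// fields fusion tracks for the allocation.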
impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
    fn from(value: CubeTensor<R>) -> Self {
        Self {
            client: value.client,
            handle: value.handle,
            device: value.device,
            strides: value.strides,
            dtype: value.dtype,
        }
    }
}