burn_cubecl/
fusion.rs

1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
4use burn_backend::{DType, Shape};
5use burn_cubecl_fusion::{
6    CubeFusionHandle, FallbackOperation,
7    optim::{
8        CubeOptimization, CubeOptimizationState,
9        elemwise::{ElementWiseFuser, ElemwiseOptimization},
10        matmul::{MatmulFuser, MatmulOptimization},
11        reduce::{ReduceFuser, ReduceOptimization},
12    },
13};
14use burn_fusion::{
15    FusionBackend, FusionRuntime,
16    stream::{Operation, OrderedExecution},
17};
18use burn_ir::{BackendIr, TensorHandle};
19use core::marker::PhantomData;
20use std::sync::Arc;
21
22impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
23where
24    R: CubeRuntime,
25    BT: BoolElement,
26{
27    fn execute(
28        &mut self,
29        context: &mut burn_fusion::stream::Context<
30            '_,
31            <FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
32        >,
33        execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
34    ) {
35        match self {
36            Self::ElementWise(op) => op.execute::<BT>(context),
37            Self::Matmul(op) => op.execute::<BT>(context, |index| {
38                let operation = execution.operation_within_optimization(index);
39                Box::new(FallbackOperationWrapper::new(operation))
40            }),
41            Self::Reduce(op) => op.execute::<BT>(context, |index| {
42                let operation = execution.operation_within_optimization(index);
43                Box::new(FallbackOperationWrapper::new(operation))
44            }),
45        }
46    }
47
48    fn to_state(&self) -> CubeOptimizationState {
49        self.to_opt_state()
50    }
51
52    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
53        match state {
54            CubeOptimizationState::ElementWise(state) => {
55                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
56            }
57            CubeOptimizationState::Matmul(state) => {
58                Self::Matmul(MatmulOptimization::from_state(device, state))
59            }
60            CubeOptimizationState::Reduce(state) => {
61                Self::Reduce(ReduceOptimization::from_state(device, state))
62            }
63        }
64    }
65}
66
/// Thin wrapper around an operation value.
///
/// The wrapper places no requirements on `O`; trait bounds belong on the `impl` blocks
/// that actually need them. In this file it implements [`FallbackOperation`] when `O` is
/// an `Arc<dyn Operation<..>>`.
struct FallbackOperationWrapper<O> {
    /// The wrapped operation.
    operation: O,
}

impl<O> FallbackOperationWrapper<O> {
    /// Wraps `op` as-is.
    fn new(op: O) -> Self {
        Self { operation: op }
    }
}
76
77impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
78    for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
79{
80    fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
81        self.operation.as_ref().execute(context.handles);
82    }
83}
84
85impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
86    for CubeBackend<R, F, I, BT>
87{
88    type Handle = CubeFusionHandle<R>;
89
90    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
91        into_tensor(handle.handle, handle.shape)
92    }
93
94    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
95        into_tensor(handle.handle, handle.shape)
96    }
97
98    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
99        into_tensor(handle.handle, handle.shape)
100    }
101
102    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
103        into_tensor(handle.handle, handle.shape)
104    }
105
106    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
107        tensor.into()
108    }
109
110    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
111        tensor.into()
112    }
113
114    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
115        tensor.into()
116    }
117
118    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
119        tensor.into()
120    }
121}
122
123impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
124    type OptimizationState = CubeOptimizationState;
125    type Optimization = CubeOptimization<R>;
126    type FusionHandle = CubeFusionHandle<R>;
127    type FusionDevice = R::CubeDevice;
128    type BoolRepr = BT;
129
130    fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
131        vec![
132            Box::new(ElementWiseFuser::new(
133                device.clone(),
134                BT::as_type_native_unchecked().into(),
135            )),
136            Box::new(MatmulFuser::new(
137                device.clone(),
138                BT::as_type_native_unchecked().into(),
139            )),
140            Box::new(ReduceFuser::new(
141                device.clone(),
142                BT::as_type_native_unchecked().into(),
143            )),
144        ]
145    }
146}
147
148/// Fusion runtime for JIT runtimes.
149#[derive(Debug)]
150pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
151    _b: PhantomData<R>,
152    _bool: PhantomData<BT>,
153}
154
155impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
156    for CubeBackend<R, F, I, BT>
157{
158    type FusionRuntime = FusionCubeRuntime<R, BT>;
159
160    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
161
162    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
163        kernel::cast(tensor, dtype).into()
164    }
165}
166
167fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
168    CubeTensor {
169        client: handle.client,
170        handle: handle.handle,
171        device: handle.device,
172        shape,
173        strides: handle.strides,
174        dtype: handle.dtype,
175        qparams: handle.qparams,
176    }
177}
178
179impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
180    fn from(value: CubeTensor<R>) -> Self {
181        Self {
182            client: value.client,
183            handle: value.handle,
184            device: value.device,
185            strides: value.strides,
186            dtype: value.dtype,
187            qparams: value.qparams,
188        }
189    }
190}