// burn_cubecl/fusion.rs

1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
4use burn_backend::{DType, Shape};
5use burn_cubecl_fusion::optim::reduce::ReduceSettings;
6use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;
7use burn_cubecl_fusion::{
8    CubeFusionHandle, FallbackOperation,
9    optim::{
10        CubeOptimization, CubeOptimizationState,
11        elemwise::{ElementWiseFuser, ElemwiseOptimization},
12        matmul::{MatmulFuser, MatmulOptimization},
13        reduce::{ReduceFuser, ReduceOptimization},
14        reduce_broadcasted::ReduceBroadcastedOptimization,
15    },
16};
17use burn_fusion::UnfusedOp;
18use burn_fusion::{
19    FusionBackend, FusionRuntime,
20    stream::{Operation, OrderedExecution},
21};
22use burn_ir::{BackendIr, TensorHandle};
23use burn_std::Metadata;
24use core::marker::PhantomData;
25use std::sync::Arc;
26
27impl<R> burn_fusion::Optimization<FusionCubeRuntime<R>> for CubeOptimization<R>
28where
29    R: CubeRuntime,
30{
31    fn execute(
32        &mut self,
33        context: &mut burn_fusion::stream::Context<
34            <FusionCubeRuntime<R> as FusionRuntime>::FusionHandle,
35        >,
36        execution: &OrderedExecution<FusionCubeRuntime<R>>,
37    ) {
38        match self {
39            Self::ElementWise(op) => op.execute(context),
40            Self::Matmul(op) => op.execute(context, |index| {
41                let operation = execution.operation_within_optimization(index);
42                Box::new(FallbackOperationWrapper::new(operation))
43            }),
44            Self::Reduce(op) => op.execute(context, |index| {
45                let operation = execution.operation_within_optimization(index);
46                Box::new(FallbackOperationWrapper::new(operation))
47            }),
48            Self::ReduceBroadcasted(op) => op.execute(context, |index| {
49                let operation = execution.operation_within_optimization(index);
50                Box::new(FallbackOperationWrapper::new(operation))
51            }),
52        }
53    }
54
55    fn to_state(&self) -> CubeOptimizationState {
56        self.to_opt_state()
57    }
58
59    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
60        match state {
61            CubeOptimizationState::ElementWise(state) => {
62                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
63            }
64            CubeOptimizationState::Matmul(state) => {
65                Self::Matmul(MatmulOptimization::from_state(device, state))
66            }
67            CubeOptimizationState::Reduce(state) => {
68                Self::Reduce(ReduceOptimization::from_state(device, state))
69            }
70            CubeOptimizationState::ReduceBroadcasted(state) => {
71                Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))
72            }
73        }
74    }
75}
76
/// Adapts an operation taken from the fusion stream so an optimization can run it
/// unfused as a [`FallbackOperation`].
///
/// Storing the operation needs no trait bound on `O`; the `FallbackOperation` impls
/// decide which concrete operation types can actually be run.
struct FallbackOperationWrapper<O> {
    /// The wrapped operation, executed by `FallbackOperation::run`.
    operation: O,
}

impl<O> FallbackOperationWrapper<O> {
    /// Wraps `op` without executing it.
    fn new(op: O) -> Self {
        Self { operation: op }
    }
}
86
87impl<R: CubeRuntime> FallbackOperation<R>
88    for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R>>>>
89{
90    fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
91        self.operation.as_ref().execute(&mut context.handles);
92    }
93}
94
95impl<R: CubeRuntime> FallbackOperation<R>
96    for FallbackOperationWrapper<UnfusedOp<FusionCubeRuntime<R>>>
97{
98    fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
99        self.operation.execute(&mut context.handles);
100    }
101}
102
103impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
104    for CubeBackend<R, F, I, BT>
105{
106    type Handle = CubeFusionHandle<R>;
107
108    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
109        into_tensor(handle.handle, handle.shape)
110    }
111
112    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
113        into_tensor(handle.handle, handle.shape)
114    }
115
116    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
117        into_tensor(handle.handle, handle.shape)
118    }
119
120    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
121        into_tensor(handle.handle, handle.shape)
122    }
123
124    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
125        tensor.into()
126    }
127
128    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
129        tensor.into()
130    }
131
132    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
133        tensor.into()
134    }
135
136    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
137        tensor.into()
138    }
139}
140
141impl<R: CubeRuntime> FusionRuntime for FusionCubeRuntime<R> {
142    type OptimizationState = CubeOptimizationState;
143    type Optimization = CubeOptimization<R>;
144    type FusionHandle = CubeFusionHandle<R>;
145    type FusionDevice = R::CubeDevice;
146
147    fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
148        vec![
149            Box::new(ElementWiseFuser::new(device.clone())),
150            Box::new(MatmulFuser::new(device.clone())),
151            Box::new(ReduceFuser::new(device.clone(), ReduceSettings::Always)),
152            Box::new(ReduceBroadcastedFuser::new(device.clone())),
153        ]
154    }
155}
156
157/// Fusion runtime for JIT runtimes.
158#[derive(Debug)]
159pub struct FusionCubeRuntime<R: CubeRuntime> {
160    _b: PhantomData<R>,
161}
162
163impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
164    for CubeBackend<R, F, I, BT>
165{
166    type FusionRuntime = FusionCubeRuntime<R>;
167
168    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
169
170    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
171        kernel::cast(tensor, dtype).into()
172    }
173}
174
175fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
176    CubeTensor {
177        client: handle.client.clone(),
178        handle: handle.handle.clone(),
179        device: handle.device.clone(),
180        meta: Box::new(Metadata::new(shape, handle.strides.clone())),
181        dtype: handle.dtype,
182        qparams: handle.qparams.clone(),
183    }
184}
185
186impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
187    fn from(value: CubeTensor<R>) -> Self {
188        Self {
189            client: value.client.clone(),
190            handle: value.handle.clone(),
191            device: value.device.clone(),
192            strides: value.meta.strides.clone(),
193            dtype: value.dtype,
194            qparams: value.qparams.clone(),
195        }
196    }
197}