1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
4use burn_backend::{DType, Shape};
5use burn_cubecl_fusion::optim::reduce::ReduceSettings;
6use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;
7use burn_cubecl_fusion::{
8 CubeFusionHandle, FallbackOperation,
9 optim::{
10 CubeOptimization, CubeOptimizationState,
11 elemwise::{ElementWiseFuser, ElemwiseOptimization},
12 matmul::{MatmulFuser, MatmulOptimization},
13 reduce::{ReduceFuser, ReduceOptimization},
14 reduce_broadcasted::ReduceBroadcastedOptimization,
15 },
16};
17use burn_fusion::UnfusedOp;
18use burn_fusion::{
19 FusionBackend, FusionRuntime,
20 stream::{Operation, OrderedExecution},
21};
22use burn_ir::{BackendIr, TensorHandle};
23use burn_std::Metadata;
24use core::marker::PhantomData;
25use std::sync::Arc;
26
27impl<R> burn_fusion::Optimization<FusionCubeRuntime<R>> for CubeOptimization<R>
28where
29 R: CubeRuntime,
30{
31 fn execute(
32 &mut self,
33 context: &mut burn_fusion::stream::Context<
34 <FusionCubeRuntime<R> as FusionRuntime>::FusionHandle,
35 >,
36 execution: &OrderedExecution<FusionCubeRuntime<R>>,
37 ) {
38 match self {
39 Self::ElementWise(op) => op.execute(context),
40 Self::Matmul(op) => op.execute(context, |index| {
41 let operation = execution.operation_within_optimization(index);
42 Box::new(FallbackOperationWrapper::new(operation))
43 }),
44 Self::Reduce(op) => op.execute(context, |index| {
45 let operation = execution.operation_within_optimization(index);
46 Box::new(FallbackOperationWrapper::new(operation))
47 }),
48 Self::ReduceBroadcasted(op) => op.execute(context, |index| {
49 let operation = execution.operation_within_optimization(index);
50 Box::new(FallbackOperationWrapper::new(operation))
51 }),
52 }
53 }
54
55 fn to_state(&self) -> CubeOptimizationState {
56 self.to_opt_state()
57 }
58
59 fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
60 match state {
61 CubeOptimizationState::ElementWise(state) => {
62 Self::ElementWise(ElemwiseOptimization::from_state(device, state))
63 }
64 CubeOptimizationState::Matmul(state) => {
65 Self::Matmul(MatmulOptimization::from_state(device, state))
66 }
67 CubeOptimizationState::Reduce(state) => {
68 Self::Reduce(ReduceOptimization::from_state(device, state))
69 }
70 CubeOptimizationState::ReduceBroadcasted(state) => {
71 Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))
72 }
73 }
74 }
75}
76
/// Thin adapter holding an operation of type `O`.
///
/// The wrapper itself places no requirement on `O`; the behavior is supplied by
/// the trait implementations written for specific choices of `O`.
struct FallbackOperationWrapper<O> {
    /// The wrapped operation.
    operation: O,
}

impl<O> FallbackOperationWrapper<O> {
    /// Wraps `op` without modifying it.
    fn new(op: O) -> Self {
        Self { operation: op }
    }
}
86
87impl<R: CubeRuntime> FallbackOperation<R>
88 for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R>>>>
89{
90 fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
91 self.operation.as_ref().execute(&mut context.handles);
92 }
93}
94
95impl<R: CubeRuntime> FallbackOperation<R>
96 for FallbackOperationWrapper<UnfusedOp<FusionCubeRuntime<R>>>
97{
98 fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
99 self.operation.execute(&mut context.handles);
100 }
101}
102
103impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
104 for CubeBackend<R, F, I, BT>
105{
106 type Handle = CubeFusionHandle<R>;
107
108 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
109 into_tensor(handle.handle, handle.shape)
110 }
111
112 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
113 into_tensor(handle.handle, handle.shape)
114 }
115
116 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
117 into_tensor(handle.handle, handle.shape)
118 }
119
120 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
121 into_tensor(handle.handle, handle.shape)
122 }
123
124 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
125 tensor.into()
126 }
127
128 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
129 tensor.into()
130 }
131
132 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
133 tensor.into()
134 }
135
136 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
137 tensor.into()
138 }
139}
140
141impl<R: CubeRuntime> FusionRuntime for FusionCubeRuntime<R> {
142 type OptimizationState = CubeOptimizationState;
143 type Optimization = CubeOptimization<R>;
144 type FusionHandle = CubeFusionHandle<R>;
145 type FusionDevice = R::CubeDevice;
146
147 fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
148 vec![
149 Box::new(ElementWiseFuser::new(device.clone())),
150 Box::new(MatmulFuser::new(device.clone())),
151 Box::new(ReduceFuser::new(device.clone(), ReduceSettings::Always)),
152 Box::new(ReduceBroadcastedFuser::new(device.clone())),
153 ]
154 }
155}
156
157#[derive(Debug)]
159pub struct FusionCubeRuntime<R: CubeRuntime> {
160 _b: PhantomData<R>,
161}
162
163impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
164 for CubeBackend<R, F, I, BT>
165{
166 type FusionRuntime = FusionCubeRuntime<R>;
167
168 type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
169
170 fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
171 kernel::cast(tensor, dtype).into()
172 }
173}
174
175fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
176 CubeTensor {
177 client: handle.client.clone(),
178 handle: handle.handle.clone(),
179 device: handle.device.clone(),
180 meta: Box::new(Metadata::new(shape, handle.strides.clone())),
181 dtype: handle.dtype,
182 qparams: handle.qparams.clone(),
183 }
184}
185
186impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
187 fn from(value: CubeTensor<R>) -> Self {
188 Self {
189 client: value.client.clone(),
190 handle: value.handle.clone(),
191 device: value.device.clone(),
192 strides: value.meta.strides.clone(),
193 dtype: value.dtype,
194 qparams: value.qparams.clone(),
195 }
196 }
197}