1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
4use burn_backend::{DType, Shape};
5use burn_cubecl_fusion::{
6 CubeFusionHandle, FallbackOperation,
7 optim::{
8 CubeOptimization, CubeOptimizationState,
9 elemwise::{ElementWiseFuser, ElemwiseOptimization},
10 matmul::{MatmulFuser, MatmulOptimization},
11 reduce::{ReduceFuser, ReduceOptimization},
12 },
13};
14use burn_fusion::{
15 FusionBackend, FusionRuntime,
16 stream::{Operation, OrderedExecution},
17};
18use burn_ir::{BackendIr, TensorHandle};
19use core::marker::PhantomData;
20use std::sync::Arc;
21
22impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
23where
24 R: CubeRuntime,
25 BT: BoolElement,
26{
27 fn execute(
28 &mut self,
29 context: &mut burn_fusion::stream::Context<
30 '_,
31 <FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
32 >,
33 execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
34 ) {
35 match self {
36 Self::ElementWise(op) => op.execute::<BT>(context),
37 Self::Matmul(op) => op.execute::<BT>(context, |index| {
38 let operation = execution.operation_within_optimization(index);
39 Box::new(FallbackOperationWrapper::new(operation))
40 }),
41 Self::Reduce(op) => op.execute::<BT>(context, |index| {
42 let operation = execution.operation_within_optimization(index);
43 Box::new(FallbackOperationWrapper::new(operation))
44 }),
45 }
46 }
47
48 fn to_state(&self) -> CubeOptimizationState {
49 self.to_opt_state()
50 }
51
52 fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
53 match state {
54 CubeOptimizationState::ElementWise(state) => {
55 Self::ElementWise(ElemwiseOptimization::from_state(device, state))
56 }
57 CubeOptimizationState::Matmul(state) => {
58 Self::Matmul(MatmulOptimization::from_state(device, state))
59 }
60 CubeOptimizationState::Reduce(state) => {
61 Self::Reduce(ReduceOptimization::from_state(device, state))
62 }
63 }
64 }
65}
66
/// Adapts an operation so it can be handed to an optimization as a fallback.
///
/// The struct itself places no requirements on `O`. Trait bounds live on the impls that
/// actually need them, so the wrapper can hold any operation type.
struct FallbackOperationWrapper<O> {
    /// The wrapped operation; it is executed as-is when the fallback runs.
    operation: O,
}
70
71impl<O: Clone> FallbackOperationWrapper<O> {
72 fn new(op: O) -> Self {
73 Self { operation: op }
74 }
75}
76
77impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
78 for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
79{
80 fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
81 self.operation.as_ref().execute(context.handles);
82 }
83}
84
85impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
86 for CubeBackend<R, F, I, BT>
87{
88 type Handle = CubeFusionHandle<R>;
89
90 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
91 into_tensor(handle.handle, handle.shape)
92 }
93
94 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
95 into_tensor(handle.handle, handle.shape)
96 }
97
98 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
99 into_tensor(handle.handle, handle.shape)
100 }
101
102 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
103 into_tensor(handle.handle, handle.shape)
104 }
105
106 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
107 tensor.into()
108 }
109
110 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
111 tensor.into()
112 }
113
114 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
115 tensor.into()
116 }
117
118 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
119 tensor.into()
120 }
121}
122
123impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
124 type OptimizationState = CubeOptimizationState;
125 type Optimization = CubeOptimization<R>;
126 type FusionHandle = CubeFusionHandle<R>;
127 type FusionDevice = R::CubeDevice;
128 type BoolRepr = BT;
129
130 fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
131 vec![
132 Box::new(ElementWiseFuser::new(
133 device.clone(),
134 BT::as_type_native_unchecked().into(),
135 )),
136 Box::new(MatmulFuser::new(
137 device.clone(),
138 BT::as_type_native_unchecked().into(),
139 )),
140 Box::new(ReduceFuser::new(
141 device.clone(),
142 BT::as_type_native_unchecked().into(),
143 )),
144 ]
145 }
146}
147
148#[derive(Debug)]
150pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
151 _b: PhantomData<R>,
152 _bool: PhantomData<BT>,
153}
154
155impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
156 for CubeBackend<R, F, I, BT>
157{
158 type FusionRuntime = FusionCubeRuntime<R, BT>;
159
160 type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
161
162 fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
163 kernel::cast(tensor, dtype).into()
164 }
165}
166
167fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
168 CubeTensor {
169 client: handle.client,
170 handle: handle.handle,
171 device: handle.device,
172 shape,
173 strides: handle.strides,
174 dtype: handle.dtype,
175 qparams: handle.qparams,
176 }
177}
178
179impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
180 fn from(value: CubeTensor<R>) -> Self {
181 Self {
182 client: value.client,
183 handle: value.handle,
184 device: value.device,
185 strides: value.strides,
186 dtype: value.dtype,
187 qparams: value.qparams,
188 }
189 }
190}