1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3
4use burn_cubecl_fusion::elemwise::optimization::ElemwiseOptimization;
5use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
6use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
7use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
8use burn_cubecl_fusion::reduce::optimization::ReduceOptimization;
9use burn_cubecl_fusion::{CubeFusionHandle, FallbackOperation};
10use burn_cubecl_fusion::{
11 CubeOptimization, CubeOptimizationState, elemwise::builder::ElementWiseBuilder,
12};
13use burn_fusion::stream::{Operation, OrderedExecution};
14use burn_fusion::{FusionBackend, FusionRuntime};
15use burn_ir::{BackendIr, TensorHandle};
16use burn_tensor::{DType, Shape};
17use core::marker::PhantomData;
18use half::{bf16, f16};
19use std::sync::Arc;
20
/// Dispatches execution and state (de)serialization for each kind of fused
/// optimization supported by the CubeCL fusion runtime: element-wise, matmul,
/// and reduce.
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
where
    R: CubeRuntime,
    BT: BoolElement,
{
    fn execute(
        &mut self,
        context: &mut burn_fusion::stream::Context<
            '_,
            <FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
        >,
        execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
    ) {
        match self {
            Self::ElementWise(op) => op.execute::<BT>(context),
            // Matmul and reduce optimizations receive a closure that, given an
            // operation index within this optimization, yields the unfused
            // operation wrapped as a `FallbackOperation` so they can run it
            // instead of the fused kernel when needed.
            Self::Matmul(op) => op.execute::<BT>(context, |index| {
                let operation = execution.operation_within_optimization(index);
                Box::new(FallbackOperationWrapper::new(operation))
            }),
            Self::Reduce(op) => op.execute::<BT>(context, |index| {
                let operation = execution.operation_within_optimization(index);
                Box::new(FallbackOperationWrapper::new(operation))
            }),
        }
    }

    fn to_state(&self) -> CubeOptimizationState {
        // Serialization is delegated to the inherent helper on `CubeOptimization`.
        self.to_opt_state()
    }

    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
        // Rebuild the concrete optimization on `device` from its serialized state.
        match state {
            CubeOptimizationState::ElementWise(state) => {
                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
            }
            CubeOptimizationState::Matmul(state) => {
                Self::Matmul(MatmulOptimization::from_state(device, state))
            }
            CubeOptimizationState::Reduce(state) => {
                Self::Reduce(ReduceOptimization::from_state(device, state))
            }
        }
    }
}
65
/// Adapts an operation taken from the fusion stream so it can be handed to an
/// optimization as a fallback (see the `FallbackOperation` impl below).
///
/// The previous `O: Clone` bound on the struct and on the inherent impl was
/// never used — neither the field nor `new` requires it — so it has been
/// removed (bounds belong on the impls that need them, and none do).
struct FallbackOperationWrapper<O> {
    operation: O,
}

impl<O> FallbackOperationWrapper<O> {
    /// Wraps `op` without any further processing.
    fn new(op: O) -> Self {
        Self { operation: op }
    }
}
75
76impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
77 for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
78{
79 fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
80 self.operation.as_ref().execute(context.handles);
81 }
82}
83
84impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
85 for CubeBackend<R, F, I, BT>
86{
87 type Handle = CubeFusionHandle<R>;
88
89 fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
90 into_tensor(handle.handle, handle.shape)
91 }
92
93 fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
94 into_tensor(handle.handle, handle.shape)
95 }
96
97 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
98 into_tensor(handle.handle, handle.shape)
99 }
100
101 fn quantized_tensor(
102 handle: TensorHandle<Self::Handle>,
103 ) -> burn_tensor::ops::QuantizedTensor<Self> {
104 into_tensor(handle.handle, handle.shape)
105 }
106
107 fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
108 tensor.into()
109 }
110
111 fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
112 tensor.into()
113 }
114
115 fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
116 tensor.into()
117 }
118
119 fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
120 tensor.into()
121 }
122}
123
124impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
125 type OptimizationState = CubeOptimizationState;
126 type Optimization = CubeOptimization<R>;
127 type FusionHandle = CubeFusionHandle<R>;
128 type FusionDevice = R::CubeDevice;
129 type BoolRepr = BT;
130
131 fn optimizations(
132 device: R::Device,
133 ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
134 vec![
135 Box::new(ElementWiseBuilder::<R>::new(
136 device.clone(),
137 BT::as_type_native_unchecked().into(),
138 )),
139 Box::new(MatmulBuilder::<R>::new(
140 device.clone(),
141 BT::as_type_native_unchecked().into(),
142 )),
143 Box::new(ReduceBuilder::<R>::new(
144 device.clone(),
145 BT::as_type_native_unchecked().into(),
146 )),
147 ]
148 }
149}
150
/// Fusion runtime for CubeCL-based backends.
///
/// Zero-sized type: both fields are `PhantomData` markers tying the CubeCL
/// runtime `R` and the boolean element representation `BT` to the type.
#[derive(Debug)]
pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
    // Marker for the CubeCL runtime.
    _b: PhantomData<R>,
    // Marker for the element type used to represent booleans.
    _bool: PhantomData<BT>,
}
157
158impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
159 for CubeBackend<R, F, I, BT>
160{
161 type FusionRuntime = FusionCubeRuntime<R, BT>;
162
163 type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
164
165 fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
166 fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
167 tensor: CubeTensor<R>,
168 ) -> CubeFusionHandle<R> {
169 CubeFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
170 }
171
172 match dtype {
173 DType::F32 | DType::Flex32 => cast::<R, F, f32>(tensor),
174 DType::F16 => cast::<R, F, f16>(tensor),
175 DType::BF16 => cast::<R, F, bf16>(tensor),
176 _ => panic!("Casting error: {dtype:?} unsupported."),
177 }
178 }
179}
180
181fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
182 CubeTensor {
183 client: handle.client,
184 handle: handle.handle,
185 device: handle.device,
186 shape,
187 strides: handle.strides,
188 dtype: handle.dtype,
189 qparams: handle.qparams,
190 }
191}
192
193impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
194 fn from(value: CubeTensor<R>) -> Self {
195 Self {
196 client: value.client,
197 handle: value.handle,
198 device: value.device,
199 strides: value.strides,
200 dtype: value.dtype,
201 qparams: value.qparams,
202 }
203 }
204}