burn_cubecl/
fusion.rs

1use crate::BoolElement;
2use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
3
4use burn_cubecl_fusion::elemwise::optimization::ElemwiseOptimization;
5use burn_cubecl_fusion::matmul::builder::MatmulBuilder;
6use burn_cubecl_fusion::matmul::optimization::MatmulOptimization;
7use burn_cubecl_fusion::reduce::builder::ReduceBuilder;
8use burn_cubecl_fusion::reduce::optimization::ReduceOptimization;
9use burn_cubecl_fusion::{CubeFusionHandle, FallbackOperation};
10use burn_cubecl_fusion::{
11    CubeOptimization, CubeOptimizationState, elemwise::builder::ElementWiseBuilder,
12};
13use burn_fusion::stream::{Operation, OrderedExecution};
14use burn_fusion::{FusionBackend, FusionRuntime};
15use burn_ir::{BackendIr, TensorHandle};
16use burn_tensor::{DType, Shape};
17use core::marker::PhantomData;
18use half::{bf16, f16};
19use std::sync::Arc;
20
21impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
22where
23    R: CubeRuntime,
24    BT: BoolElement,
25{
26    fn execute(
27        &mut self,
28        context: &mut burn_fusion::stream::Context<
29            '_,
30            <FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
31        >,
32        execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
33    ) {
34        match self {
35            Self::ElementWise(op) => op.execute::<BT>(context),
36            Self::Matmul(op) => op.execute::<BT>(context, |index| {
37                let operation = execution.operation_within_optimization(index);
38                Box::new(FallbackOperationWrapper::new(operation))
39            }),
40            Self::Reduce(op) => op.execute::<BT>(context, |index| {
41                let operation = execution.operation_within_optimization(index);
42                Box::new(FallbackOperationWrapper::new(operation))
43            }),
44        }
45    }
46
47    fn to_state(&self) -> CubeOptimizationState {
48        self.to_opt_state()
49    }
50
51    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
52        match state {
53            CubeOptimizationState::ElementWise(state) => {
54                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
55            }
56            CubeOptimizationState::Matmul(state) => {
57                Self::Matmul(MatmulOptimization::from_state(device, state))
58            }
59            CubeOptimizationState::Reduce(state) => {
60                Self::Reduce(ReduceOptimization::from_state(device, state))
61            }
62        }
63    }
64}
65
/// Adapts an operation recorded in the fusion stream so that fused optimizations can run it
/// as a fallback.
///
/// The wrapper places no trait bounds on `O`: it only stores the value. Any requirements on
/// the wrapped type belong on the trait impls that actually use it, such as the
/// `FallbackOperation` impl for `Arc<dyn Operation<..>>`.
struct FallbackOperationWrapper<O> {
    /// The wrapped operation, executed when the fallback runs.
    operation: O,
}

impl<O> FallbackOperationWrapper<O> {
    /// Wraps `op` as-is, without cloning or inspecting it.
    fn new(op: O) -> Self {
        Self { operation: op }
    }
}
75
76impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
77    for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
78{
79    fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
80        self.operation.as_ref().execute(context.handles);
81    }
82}
83
84impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
85    for CubeBackend<R, F, I, BT>
86{
87    type Handle = CubeFusionHandle<R>;
88
89    fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
90        into_tensor(handle.handle, handle.shape)
91    }
92
93    fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
94        into_tensor(handle.handle, handle.shape)
95    }
96
97    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
98        into_tensor(handle.handle, handle.shape)
99    }
100
101    fn quantized_tensor(
102        handle: TensorHandle<Self::Handle>,
103    ) -> burn_tensor::ops::QuantizedTensor<Self> {
104        into_tensor(handle.handle, handle.shape)
105    }
106
107    fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
108        tensor.into()
109    }
110
111    fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
112        tensor.into()
113    }
114
115    fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
116        tensor.into()
117    }
118
119    fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
120        tensor.into()
121    }
122}
123
124impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
125    type OptimizationState = CubeOptimizationState;
126    type Optimization = CubeOptimization<R>;
127    type FusionHandle = CubeFusionHandle<R>;
128    type FusionDevice = R::CubeDevice;
129    type BoolRepr = BT;
130
131    fn optimizations(
132        device: R::Device,
133    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
134        vec![
135            Box::new(ElementWiseBuilder::<R>::new(
136                device.clone(),
137                BT::as_type_native_unchecked().into(),
138            )),
139            Box::new(MatmulBuilder::<R>::new(
140                device.clone(),
141                BT::as_type_native_unchecked().into(),
142            )),
143            Box::new(ReduceBuilder::<R>::new(
144                device.clone(),
145                BT::as_type_native_unchecked().into(),
146            )),
147        ]
148    }
149}
150
/// Fusion runtime for CubeCL runtimes.
///
/// This is a data-less, type-level marker. It only ties together the CubeCL runtime `R` and
/// the bool element type `BT`, so that `FusionRuntime` can be implemented for that pair.
/// Trait bounds are stated on the impls that need them rather than on the struct itself.
#[derive(Debug)]
pub struct FusionCubeRuntime<R, BT> {
    _b: PhantomData<R>,
    _bool: PhantomData<BT>,
}
157
158impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
159    for CubeBackend<R, F, I, BT>
160{
161    type FusionRuntime = FusionCubeRuntime<R, BT>;
162
163    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
164
165    fn cast_float(tensor: burn_tensor::ops::FloatTensor<Self>, dtype: DType) -> Self::Handle {
166        fn cast<R: CubeRuntime, F: FloatElement, FTarget: FloatElement>(
167            tensor: CubeTensor<R>,
168        ) -> CubeFusionHandle<R> {
169            CubeFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
170        }
171
172        match dtype {
173            DType::F32 | DType::Flex32 => cast::<R, F, f32>(tensor),
174            DType::F16 => cast::<R, F, f16>(tensor),
175            DType::BF16 => cast::<R, F, bf16>(tensor),
176            _ => panic!("Casting error: {dtype:?} unsupported."),
177        }
178    }
179}
180
181fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
182    CubeTensor {
183        client: handle.client,
184        handle: handle.handle,
185        device: handle.device,
186        shape,
187        strides: handle.strides,
188        dtype: handle.dtype,
189        qparams: handle.qparams,
190    }
191}
192
193impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
194    fn from(value: CubeTensor<R>) -> Self {
195        Self {
196            client: value.client,
197            handle: value.handle,
198            device: value.device,
199            strides: value.strides,
200            dtype: value.dtype,
201            qparams: value.qparams,
202        }
203    }
204}