//! Fusion support for JIT backends — source file: burn_jit/fusion/base.rs
1use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
2use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};
3use crate::fusion::elemwise::builder::ElementWiseBuilder;
4use crate::fusion::matmul::builder::MatmulBuilder;
5use crate::BoolElement;
6use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
7
8use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
9use burn_tensor::repr::TensorHandle;
10use burn_tensor::DType;
11use burn_tensor::{repr::ReprBackend, Shape};
12use core::marker::PhantomData;
13use cubecl::client::ComputeClient;
14use cubecl::prelude::{TensorArg, TensorHandleRef};
15use half::{bf16, f16};
16use serde::{Deserialize, Serialize};
17
/// Fusion optimization type for JIT.
///
/// Each variant wraps one kind of fused kernel; its serializable counterpart
/// lives in [`JitOptimizationState`].
///
/// More optimization variants should be added here.
pub enum JitOptimization<R: JitRuntime> {
    /// Element wise optimization.
    ElementWise(ElemwiseOptimization<R>),
    /// Matrix multiplication optimization.
    Matmul(MatmulOptimization<R>),
}
27
/// Fusion optimization state type for JIT.
///
/// Serializable mirror of [`JitOptimization`], used to persist and restore
/// optimizations (see `to_state` / `from_state` on the `Optimization` impl).
///
/// More optimization variants should be added here.
#[derive(Serialize, Deserialize)]
pub enum JitOptimizationState {
    /// Element wise state.
    ElementWise(ElemwiseOptimizationState),
    /// Matrix multiplication optimization state.
    Matmul(MatmulOptimizationState),
}
38
39impl<R, BT> burn_fusion::Optimization<FusionJitRuntime<R, BT>> for JitOptimization<R>
40where
41    R: JitRuntime,
42    BT: BoolElement,
43{
44    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
45        match self {
46            Self::ElementWise(op) => op.execute::<BT>(context),
47            Self::Matmul(op) => op.execute::<BT>(context),
48        }
49    }
50
51    fn len(&self) -> usize {
52        match self {
53            Self::ElementWise(op) => op.num_ops_fused(),
54            Self::Matmul(op) => op.num_ops_fused(),
55        }
56    }
57
58    fn to_state(&self) -> JitOptimizationState {
59        match self {
60            Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()),
61            Self::Matmul(value) => JitOptimizationState::Matmul(value.to_state()),
62        }
63    }
64
65    fn from_state(device: &R::Device, state: JitOptimizationState) -> Self {
66        match state {
67            JitOptimizationState::ElementWise(state) => {
68                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
69            }
70            JitOptimizationState::Matmul(state) => {
71                Self::Matmul(MatmulOptimization::from_state(device, state))
72            }
73        }
74    }
75}
76
77impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> ReprBackend
78    for JitBackend<R, F, I, BT>
79{
80    type Handle = JitFusionHandle<R>;
81
82    fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
83        handle.handle.into_tensor(handle.shape)
84    }
85
86    fn int_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::IntTensor<Self> {
87        handle.handle.into_tensor(handle.shape)
88    }
89
90    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::BoolTensor<Self> {
91        handle.handle.into_tensor(handle.shape)
92    }
93
94    fn quantized_tensor(
95        handle: TensorHandle<Self::Handle>,
96    ) -> burn_tensor::ops::QuantizedTensor<Self> {
97        handle.handle.into_tensor(handle.shape)
98    }
99
100    fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
101        tensor.into()
102    }
103
104    fn int_tensor_handle(tensor: burn_tensor::ops::IntTensor<Self>) -> Self::Handle {
105        tensor.into()
106    }
107
108    fn bool_tensor_handle(tensor: burn_tensor::ops::BoolTensor<Self>) -> Self::Handle {
109        tensor.into()
110    }
111
112    fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
113        tensor.into()
114    }
115}
116
117impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
118    type OptimizationState = JitOptimizationState;
119    type Optimization = JitOptimization<R>;
120    type FusionHandle = JitFusionHandle<R>;
121    type FusionDevice = R::JitDevice;
122    type FusionClient = MutexFusionClient<Self>;
123    type BoolRepr = BT;
124
125    fn optimizations(
126        device: R::Device,
127    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
128        let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
129            vec![Box::new(ElementWiseBuilder::<R>::new(
130                device.clone(),
131                BT::as_elem_native_unchecked().into(),
132            ))];
133
134        if cfg!(feature = "fusion-experimental") {
135            optimizations.push(Box::new(MatmulBuilder::<R>::new(
136                device.clone(),
137                BT::as_elem_native_unchecked().into(),
138            )));
139        }
140
141        optimizations
142    }
143}
144
/// Fusion runtime for JIT runtimes.
///
/// Zero-sized marker type: the `PhantomData` fields only pin the runtime (`R`)
/// and boolean element representation (`BT`) type parameters; no data is stored.
#[derive(Debug)]
pub struct FusionJitRuntime<R: JitRuntime, BT: BoolElement> {
    // Pins the JIT runtime type parameter.
    _b: PhantomData<R>,
    // Pins the boolean element representation type parameter.
    _bool: PhantomData<BT>,
}
151
152impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
153    for JitBackend<R, F, I, BT>
154{
155    type FusionRuntime = FusionJitRuntime<R, BT>;
156
157    type FullPrecisionBackend = JitBackend<R, f32, i32, BT>;
158
159    fn cast_float(
160        tensor: burn_tensor::ops::FloatTensor<Self>,
161        dtype: burn_tensor::DType,
162    ) -> Self::Handle {
163        fn cast<R: JitRuntime, F: FloatElement, FTarget: FloatElement>(
164            tensor: JitTensor<R>,
165        ) -> JitFusionHandle<R> {
166            JitFusionHandle::from(kernel::cast::<R, F, FTarget>(tensor))
167        }
168
169        match dtype {
170            burn_tensor::DType::F32 => cast::<R, F, f32>(tensor),
171            burn_tensor::DType::F16 => cast::<R, F, f16>(tensor),
172            burn_tensor::DType::BF16 => cast::<R, F, bf16>(tensor),
173            _ => panic!("Casting error: {dtype:?} unsupported."),
174        }
175    }
176}
177
/// Compute contiguous strides for `shape`: the last dimension has stride 1
/// and each earlier stride is the product of all later dimension sizes.
pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());

    // Walk dimensions from innermost to outermost, accumulating the product,
    // then flip the result back into outermost-first order.
    let mut running = 1;
    for &dim in shape.iter().rev() {
        strides.push(running);
        running *= dim;
    }
    strides.reverse();

    strides
}
189
/// Handle to be used when fusing operations.
///
/// Unlike a full tensor, no shape is stored here; callers supply one when
/// converting back (see `into_tensor` / `as_handle_ref`).
pub struct JitFusionHandle<R: JitRuntime> {
    /// Compute client for jit.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the data are stored.
    pub handle: cubecl::server::Handle,
    /// The device of the current tensor.
    pub device: R::Device,
    // Element data type of the buffer contents.
    pub(crate) dtype: DType,
    // Strides of the buffer; paired with a caller-provided shape on conversion.
    pub(crate) strides: Vec<usize>,
}
201
202impl<R: JitRuntime> core::fmt::Debug for JitFusionHandle<R> {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.write_fmt(format_args!(
205            "JitFusionHandle {{ device: {:?}, runtime: {}}}",
206            self.device,
207            R::name(),
208        ))
209    }
210}
211
// Manual impl: `#[derive(Clone)]` would add an `R: Clone` bound that
// `JitRuntime` implementors may not satisfy — presumably; TODO confirm.
impl<R: JitRuntime> Clone for JitFusionHandle<R> {
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            // `DType` is `Copy`-compatible here: no `.clone()` needed.
            dtype: self.dtype,
        }
    }
}
223
// SAFETY: assumes the compute client, server handle, and device type are all
// safe to move/share across threads. NOTE(review): not provable from this
// file — confirm against the cubecl runtime's thread-safety guarantees.
unsafe impl<R: JitRuntime> Send for JitFusionHandle<R> {}
unsafe impl<R: JitRuntime> Sync for JitFusionHandle<R> {}
226
impl<R: JitRuntime> JitFusionHandle<R> {
    /// Convert this handle into a concrete [`JitTensor`], attaching the given
    /// `shape`; client, buffer, device, strides, and dtype are moved as-is.
    pub(crate) fn into_tensor(self, shape: Shape) -> JitTensor<R> {
        JitTensor {
            client: self.client,
            handle: self.handle,
            device: self.device,
            shape,
            strides: self.strides,
            dtype: self.dtype,
        }
    }
    /// Return the reference to a tensor handle.
    ///
    /// Borrows the buffer and strides from `self`; `shape` is supplied by the
    /// caller since fusion handles do not store one.
    pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape,
            runtime: PhantomData,
            elem_size: self.dtype.size(),
        }
    }
    /// Return the reference to a tensor argument.
    ///
    /// `vectorisation` is the vectorization factor forwarded to the kernel —
    /// assumed to be a line size the runtime supports; TODO confirm caller contract.
    pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape);

        // SAFETY: the raw parts come from `as_handle_ref` on `self`, so handle,
        // strides, and shape describe the same buffer, and the element size
        // passed here matches the one in the handle ref (`self.dtype.size()`).
        // Lifetime `'a` ties the argument to `&self`, so the borrow stays live.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
                self.dtype.size(),
            )
        }
    }
}
263
264impl<R: JitRuntime> From<JitTensor<R>> for JitFusionHandle<R> {
265    fn from(value: JitTensor<R>) -> Self {
266        Self {
267            client: value.client,
268            handle: value.handle,
269            device: value.device,
270            strides: value.strides,
271            dtype: value.dtype,
272        }
273    }
274}