// burn_fusion/backend.rs

use crate::{
    FusionTensor,
    client::GlobalFusionClient,
    stream::{Context, OrderedExecution},
};
use burn_backend::{
    Backend, BackendTypes, DType, DeviceOps, ExecutionError,
    tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor},
};
use burn_ir::{BackendIr, OperationIr, TensorHandle};
use serde::{Serialize, de::DeserializeOwned};
use std::marker::PhantomData;

/// Get the client for the given device.
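///
/// # Example
///
/// A minimal usage sketch; `MyBackend` is a hypothetical type implementing
/// [`FusionBackend`]:
///
/// ```ignore
/// let device = Default::default();
/// // Fetch (or lazily initialize) the fusion client bound to this device.
/// let client = get_client::<MyBackend>(&device);
/// ```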
pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
    GlobalFusionClient::load(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
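///
/// # Example
///
/// A sketch of opting into fusion, assuming a hypothetical `MyBackend` that
/// implements [`FusionBackend`]:
///
/// ```ignore
/// // Wrapping the inner backend routes every tensor operation through the
/// // fusion stream, where fusers may combine them into fused kernels.
/// type Fused = Fusion<MyBackend>;
///
/// let device = Default::default();
/// assert!(Fused::name(&device).starts_with("fusion<"));
/// ```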
#[derive(Clone, Debug, Default)]
pub struct Fusion<B: FusionBackend> {
    _backend: PhantomData<B>,
}

impl<B: FusionBackend> BackendTypes for Fusion<B> {
    type Device = B::Device;

    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type FloatElem = B::FloatElem;

    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type IntElem = B::IntElem;

    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type BoolElem = B::BoolElem;

    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
}

impl<B: FusionBackend> Backend for Fusion<B> {
    fn name(device: &Self::Device) -> String {
        format!("fusion<{}>", B::name(device))
    }

    fn seed(device: &B::Device, seed: u64) {
        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
        let device = device.clone();
        client.sync(move || B::seed(&device, seed));
    }

    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
        let device = device.clone();
        client.sync(move || B::sync(&device))
    }

    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        B::memory_persistent_allocations(device, input, func)
    }

    fn memory_cleanup(device: &Self::Device) {
        B::memory_cleanup(device)
    }

    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
    {
        B::staging(data, device);
    }

    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        B::supports_dtype(device, dtype)
    }

    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        B::dtype_usage(device, dtype)
    }

    fn device_count(type_id: u16) -> usize {
        B::device_count(type_id)
    }
}

/// The status of a [fuser](OperationFuser).
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum FuserStatus {
    /// No more operations can be fused.
    Closed,
    /// More operations can be fused.
    Open,
}

/// The properties of a [fuser](OperationFuser).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FuserProperties {
    /// The score of the optimization; higher is better.
    pub score: u64,
    /// Whether the optimization is ready to be executed.
    pub ready: bool,
}

/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](OperationIr) into one, improving the performance of the backend.
///
/// # Notes
///
/// Implementations are free to execute the registered operations however they see fit to improve
/// the speed and efficiency of the computational graph. This doesn't mean that all registered
/// operations must be fused; an implementation may simply know a more efficient way to execute
/// them.
///
/// Also, it is important to return [FuserStatus::Closed] when no additional registered operation
/// can improve the performance.
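///
/// # Example
///
/// A minimal sketch of a fuser that only accepts element-wise operations;
/// `ElemwiseOptimization` and `is_elemwise` are hypothetical:
///
/// ```ignore
/// struct ElemwiseFuser {
///     ops: Vec<OperationIr>,
///     status: FuserStatus,
/// }
///
/// impl OperationFuser<ElemwiseOptimization> for ElemwiseFuser {
///     fn fuse(&mut self, operation: &OperationIr) {
///         // Close the fuser as soon as an operation cannot be fused;
///         // otherwise, record it for compilation in `finish`.
///         if is_elemwise(operation) {
///             self.ops.push(operation.clone());
///         } else {
///             self.status = FuserStatus::Closed;
///         }
///     }
///
///     fn status(&self) -> FuserStatus {
///         self.status
///     }
///
///     fn properties(&self) -> FuserProperties {
///         // Score by the number of fused operations; worth executing
///         // once at least two operations can be combined.
///         FuserProperties {
///             score: self.ops.len() as u64,
///             ready: self.ops.len() > 1,
///         }
///     }
///
///     // `finish`, `reset`, `len` and `clone_dyn` are omitted for brevity.
/// }
/// ```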
pub trait OperationFuser<O>: Send {
    /// Register a new [tensor operation](OperationIr).
    fn fuse(&mut self, operation: &OperationIr);
    /// Finish the optimization and create a fusion operation.
    fn finish(&mut self) -> O;
    /// Reset the state.
    fn reset(&mut self);
    /// Return the fuser [status](FuserStatus).
    fn status(&self) -> FuserStatus;
    /// Return the fuser [properties](FuserProperties).
    fn properties(&self) -> FuserProperties;
    /// The number of operations fused.
    fn len(&self) -> usize;
    /// If no operations are fused.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Clone the fuser.
    fn clone_dyn(&self) -> Box<dyn OperationFuser<O>>;
}

/// The number of operations contained in the data structure.
pub trait NumOperations: core::fmt::Debug {
    /// The number of registered operations.
    fn len(&self) -> usize;
    /// If the current optimization is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// The name of the optimization.
    fn name(&self) -> &'static str;
}

/// The optimization created from a [fuser](OperationFuser).
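///
/// A sketch of caching a compiled optimization, with hypothetical `opt`,
/// `device` and runtime `R`:
///
/// ```ignore
/// // Serialize the compiled optimization so it can be cached.
/// let state: R::OptimizationState = opt.to_state();
/// // Later, rebuild it for the same device without re-running the fuser.
/// let opt = <R::Optimization as Optimization<R>>::from_state(&device, state);
/// ```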
pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
    /// Execute the optimization.
    fn execute(&mut self, context: &mut Context<R::FusionHandle>, execution: &OrderedExecution<R>);

    /// Returns the state that can be serialized.
    fn to_state(&self) -> R::OptimizationState;
    /// Create the optimization from the state.
    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}

/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Client alias.
pub type Client<R> = GlobalFusionClient<R>;

/// Trait that defines a runtime that will benefit from fused operations.
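///
/// A sketch of wiring fusers into a runtime; `MyRuntime`, `ElemwiseFuser` and
/// the associated types are hypothetical:
///
/// ```ignore
/// impl FusionRuntime for MyRuntime {
///     type OptimizationState = ElemwiseState;
///     type Optimization = ElemwiseOptimization;
///     type FusionHandle = MyHandle;
///     type FusionDevice = MyDevice;
///
///     fn fusers(
///         _device: Self::FusionDevice,
///     ) -> Vec<Box<dyn OperationFuser<Self::Optimization>>> {
///         // Each fuser is given a chance to claim incoming operations.
///         vec![Box::new(ElemwiseFuser::default())]
///     }
/// }
/// ```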
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {
    /// The state that can be serialized for an optimization.
    type OptimizationState: Serialize + DeserializeOwned;
    /// Optimization type for the backend.
    type Optimization: Optimization<Self>;
    /// Handle used to store tensors dynamically.
    type FusionHandle: Clone + Send;
    /// Device used by the runtime.
    type FusionDevice: DeviceOps;

    /// The list of fusers that will be used to optimize the computational graph.
    fn fusers(device: Self::FusionDevice) -> Vec<Box<dyn OperationFuser<Self::Optimization>>>;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation fuser](crate::OperationFuser).
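///
/// A sketch of the implementor side; `MyBackend`, `MyFullPrecisionBackend`,
/// `MyRuntime`, `my_cast` and `my_into_handle` are hypothetical:
///
/// ```ignore
/// impl FusionBackend for MyBackend {
///     type FusionRuntime = MyRuntime;
///     // A sibling backend sharing the same runtime, used when full
///     // float precision is required.
///     type FullPrecisionBackend = MyFullPrecisionBackend;
///
///     fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
///         // Cast with the inner backend, then expose the raw handle so the
///         // fusion stream can keep tracking the tensor.
///         let casted = my_cast(tensor, dtype);
///         my_into_handle(casted)
///     }
/// }
/// ```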
pub trait FusionBackend:
    BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
    /// The runtime used for this backend.
    type FusionRuntime: FusionRuntime;

    /// Cast a float tensor and return the resulting handle.
    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle;

    /// Pointer to the full precision fusion backend.
    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}

// Fusion implements `BackendIr` to enable router backend usage.
impl<B: FusionBackend> BackendIr for Fusion<B> {
    type Handle = FusionTensor<B::FusionRuntime>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        handle.handle
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        handle.handle
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        handle.handle
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        handle.handle
    }

    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        tensor
    }

    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        tensor
    }

    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        tensor
    }

    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        tensor
    }
}