// burn_fusion/backend.rs

1use crate::{
2    FusionTensor,
3    client::GlobalFusionClient,
4    stream::{Context, OrderedExecution},
5};
6use burn_ir::{BackendIr, OperationIr, TensorHandle};
7use burn_tensor::{
8    Device, Element,
9    backend::{Backend, DeviceOps},
10    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
11};
12use serde::{Serialize, de::DeserializeOwned};
13use std::marker::PhantomData;
14
15/// Get the client for the given device.
16pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
17    GlobalFusionClient::load(device)
18}
19
20/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
21#[derive(Clone, Debug, Default)]
22pub struct Fusion<B: FusionBackend> {
23    _backend: PhantomData<B>,
24}
25
26impl<B: FusionBackend> Backend for Fusion<B> {
27    type Device = B::Device;
28
29    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
30
31    type FloatElem = B::FloatElem;
32
33    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
34
35    type IntElem = B::IntElem;
36
37    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
38
39    type BoolElem = B::BoolElem;
40
41    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
42
43    fn name(device: &Self::Device) -> String {
44        format!("fusion<{}>", B::name(device))
45    }
46
47    fn seed(device: &B::Device, seed: u64) {
48        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
49        client.drain();
50        B::seed(device, seed);
51    }
52
53    fn sync(device: &Self::Device) {
54        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
55        client.drain();
56        B::sync(device);
57    }
58
59    fn ad_enabled() -> bool {
60        false
61    }
62
63    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
64        device: &Self::Device,
65        input: Input,
66        func: Func,
67    ) -> Output {
68        B::memory_persistent_allocations(device, input, func)
69    }
70
71    fn memory_cleanup(device: &Self::Device) {
72        B::memory_cleanup(device)
73    }
74}
75
/// The status of a [builder](OptimizationBuilder).
#[derive(Clone, Copy, Debug)]
pub enum OptimizationStatus {
    /// The builder cannot accept any additional operation to fuse.
    Closed,
    /// The builder can still accept operations to fuse.
    Open,
}
84
/// The properties of a [builder](OptimizationBuilder).
#[derive(Debug, Clone, Copy, Default)]
pub struct OptimizationProperties {
    /// The score of the optimization, higher is better.
    pub score: u64,
    /// Whether the optimization is ready to be executed.
    pub ready: bool,
}
93
/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](OperationIr) into one, improving the performance of the backend.
///
/// # Notes
///
/// The implementations are free to execute the registered operations the way they want to improve
/// the speed and efficiency of the computational graph. It doesn't mean that all registered
/// operations should be fused, but that another way of executing them is more efficient.
///
/// Also, it is important to return [OptimizationStatus::Closed] when no more registered operations
/// can improve the performance.
pub trait OptimizationBuilder<O>: Send {
    /// Register a new [tensor operation](OperationIr).
    fn register(&mut self, operation: &OperationIr);
    /// Finish the optimization and create a fusion operation.
    fn build(&self) -> O;
    /// Reset the state.
    fn reset(&mut self);
    /// Return the builder [status](OptimizationStatus).
    fn status(&self) -> OptimizationStatus;
    /// Return the builder [properties](OptimizationProperties).
    fn properties(&self) -> OptimizationProperties;
    /// The number of operations fused.
    fn len(&self) -> usize;
    /// If no operations are fused.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Clone the optimization builder.
    fn clone_dyn(&self) -> Box<dyn OptimizationBuilder<O>>;
}
126
/// Exposes how many operations a data structure currently contains.
pub trait NumOperations: core::fmt::Debug {
    /// The number of registered operations.
    fn len(&self) -> usize;
    /// Whether no operations are registered.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
136
/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
    /// Execute the operation within the given context and ordered execution.
    fn execute(
        &mut self,
        context: &mut Context<'_, R::FusionHandle>,
        execution: &OrderedExecution<R>,
    );

    /// Returns the state that can be serialized.
    fn to_state(&self) -> R::OptimizationState;
    /// Create the optimization from the state.
    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}
151
/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Type alias for the fusion client used by a runtime.
pub type Client<R> = GlobalFusionClient<R>;
158
/// Trait that defines a runtime that will benefit from fused operations.
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {
    /// The state that can be serialized for an optimization.
    type OptimizationState: Serialize + DeserializeOwned;
    /// Optimization type for the backend.
    type Optimization: Optimization<Self>;
    /// Handle used to store tensors dynamically.
    type FusionHandle: Clone + Send;
    /// Device used by the runtime.
    type FusionDevice: DeviceOps;
    /// The type that represents booleans on the backend.
    type BoolRepr: Element;

    /// The list of optimizations that will be used to optimize the computational graph.
    fn optimizations(
        device: Self::FusionDevice,
    ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
}
177
178/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
179/// [operation builder](crate::OptimizationBuilder).
180pub trait FusionBackend:
181    BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
182{
183    /// The runtime used for this backend.
184    type FusionRuntime: FusionRuntime;
185
186    /// Cast a float tensor and returns the resulting handle.
187    fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;
188
189    /// Pointer to the full precision fusion backend.
190    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
191}
192
193// Fusion implements `BackendIr` to enable router backend usage.
194impl<B: FusionBackend> BackendIr for Fusion<B> {
195    type Handle = FusionTensor<B::FusionRuntime>;
196
197    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
198        handle.handle
199    }
200
201    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
202        handle.handle
203    }
204
205    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
206        handle.handle
207    }
208
209    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
210        handle.handle
211    }
212
213    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
214        tensor
215    }
216
217    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
218        tensor
219    }
220
221    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
222        tensor
223    }
224
225    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
226        tensor
227    }
228}