burn_fusion/backend.rs

use crate::{
    FusionClientLocator, FusionTensor,
    client::FusionClient,
    stream::{Context, OrderedExecution},
};
use burn_ir::{BackendIr, OperationIr, TensorHandle};
use burn_tensor::{
    Device, Element,
    backend::{Backend, DeviceOps},
    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
};
use serde::{Serialize, de::DeserializeOwned};
use std::marker::PhantomData;

pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
    CLIENTS.client::<B::FusionRuntime>(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
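/// # Example
///
/// A minimal usage sketch; `MyBackend` stands in for any backend that implements
/// [fusion backend](crate::FusionBackend) and is illustrative only:
///
/// ```ignore
/// // Wrap the inner backend to enable dynamic operation fusion.
/// type Fused = Fusion<MyBackend>;
///
/// // `Fused` is then used like any other `Backend` implementation.
/// let device = <Fused as Backend>::Device::default();
/// println!("{}", <Fused as Backend>::name(&device));
/// ```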
#[derive(Clone, Debug, Default)]
pub struct Fusion<B: FusionBackend> {
    _backend: PhantomData<B>,
}

impl<B: FusionBackend> Backend for Fusion<B> {
    type Device = B::Device;

    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type FloatElem = B::FloatElem;

    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type IntElem = B::IntElem;

    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type BoolElem = B::BoolElem;

    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type QuantizedEncoding = B::QuantizedEncoding;

    fn name(device: &Self::Device) -> String {
        format!("fusion<{}>", B::name(device))
    }

    fn seed(seed: u64) {
        B::seed(seed);
    }

    fn sync(device: &Self::Device) {
        let client = CLIENTS.client::<B::FusionRuntime>(device);
        client.drain();
        B::sync(device);
    }

    fn ad_enabled() -> bool {
        false
    }
}

/// The status of a [builder](OptimizationBuilder).
#[derive(Clone, Debug, Copy)]
pub enum OptimizationStatus {
    /// No more operations can be fused.
    Closed,
    /// More operations can be fused.
    Open,
}

/// The properties of a [builder](OptimizationBuilder).
#[derive(Debug, Clone, Copy, Default)]
pub struct OptimizationProperties {
    /// The score of the optimization, higher is better.
    pub score: u64,
    /// If the optimization is ready to be executed.
    pub ready: bool,
}

/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](OperationIr) into one, improving the performance of the backend.
///
/// # Notes
///
/// Implementations are free to execute the registered operations however they want in order to
/// improve the speed and efficiency of the computational graph. This doesn't mean that all
/// registered operations must be fused, only that another way of executing them is more efficient.
///
/// It is also important to return [`OptimizationStatus::Closed`] when no additional registered
/// operation can improve the performance.
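///
/// # Example
///
/// A minimal sketch of a builder that fuses consecutive element-wise operations. The type
/// `ElemwiseOptimization` and the helper `is_elemwise` are hypothetical and only illustrate the
/// expected control flow:
///
/// ```ignore
/// struct ElemwiseBuilder {
///     ops: Vec<OperationIr>,
///     status: OptimizationStatus,
/// }
///
/// impl OptimizationBuilder<ElemwiseOptimization> for ElemwiseBuilder {
///     fn register(&mut self, operation: &OperationIr) {
///         if let OptimizationStatus::Closed = self.status {
///             return;
///         }
///
///         if is_elemwise(operation) {
///             self.ops.push(operation.clone());
///         } else {
///             // A non element-wise operation ends the current fusion window.
///             self.status = OptimizationStatus::Closed;
///         }
///     }
///
///     fn build(&self) -> ElemwiseOptimization {
///         ElemwiseOptimization::new(&self.ops)
///     }
///
///     fn reset(&mut self) {
///         self.ops.clear();
///         self.status = OptimizationStatus::Open;
///     }
///
///     fn status(&self) -> OptimizationStatus {
///         self.status
///     }
///
///     fn properties(&self) -> OptimizationProperties {
///         OptimizationProperties {
///             // More fused operations means a better score.
///             score: self.ops.len() as u64,
///             ready: !self.ops.is_empty(),
///         }
///     }
///
///     fn len(&self) -> usize {
///         self.ops.len()
///     }
///
///     fn clone_dyn(&self) -> Box<dyn OptimizationBuilder<ElemwiseOptimization>> {
///         Box::new(Self {
///             ops: self.ops.clone(),
///             status: self.status,
///         })
///     }
/// }
/// ```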
pub trait OptimizationBuilder<O>: Send {
    /// Register a new [tensor operation](OperationIr).
    fn register(&mut self, operation: &OperationIr);
    /// Finish the optimization and create a fusion operation.
    fn build(&self) -> O;
    /// Reset the state.
    fn reset(&mut self);
    /// Return the builder [status](OptimizationStatus).
    fn status(&self) -> OptimizationStatus;
    /// Return the builder [properties](OptimizationProperties).
    fn properties(&self) -> OptimizationProperties;
    /// The number of operations fused.
    fn len(&self) -> usize;
    /// If no operations are fused.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Clone the optimization builder.
    fn clone_dyn(&self) -> Box<dyn OptimizationBuilder<O>>;
}

/// The number of operations contained in the data structure.
pub trait NumOperations: core::fmt::Debug {
    /// The number of registered operations.
    fn len(&self) -> usize;
    /// If the current optimization is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
    /// Execute the operation.
    fn execute(
        &mut self,
        context: &mut Context<'_, R::FusionHandle>,
        execution: &OrderedExecution<R>,
    );

    /// Returns the state that can be serialized.
    fn to_state(&self) -> R::OptimizationState;
    /// Create the optimization from the state.
    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}

/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Type alias for `<R as FusionRuntime>::FusionClient`.
pub type Client<R> = <R as FusionRuntime>::FusionClient;

/// Trait that defines a runtime that will benefit from fused operations.
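///
/// # Example
///
/// A minimal sketch of the [`optimizations`](FusionRuntime::optimizations) entry point, reusing
/// the hypothetical `ElemwiseBuilder` from the [`OptimizationBuilder`] example:
///
/// ```ignore
/// fn optimizations(
///     _device: Self::FusionDevice,
/// ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>> {
///     // Register every builder that should compete to fuse the operation stream.
///     vec![Box::new(ElemwiseBuilder::new())]
/// }
/// ```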
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
    /// The state that can be serialized for an optimization.
    type OptimizationState: Serialize + DeserializeOwned;
    /// Optimization type for the backend.
    type Optimization: Optimization<Self>;
    /// Handle used to store tensors dynamically.
    type FusionHandle: Clone + Send;
    /// Device used by the runtime.
    type FusionDevice: DeviceOps;
    /// The client to interact with the runtime.
    type FusionClient: FusionClient<Self>;
    /// The type that represents booleans on the backend.
    type BoolRepr: Element;

    /// The list of optimizations that will be used to optimize the computational graph.
    fn optimizations(
        device: Self::FusionDevice,
    ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend:
    BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
    /// The runtime used for this backend.
    type FusionRuntime: FusionRuntime;

    /// Cast a float tensor and return the resulting handle.
    fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;

    /// Pointer to the full precision fusion backend.
    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}

// Fusion implements `BackendIr` to enable router backend usage.
impl<B: FusionBackend> BackendIr for Fusion<B> {
    type Handle = FusionTensor<B::FusionRuntime>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        handle.handle
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        handle.handle
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        handle.handle
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        handle.handle
    }

    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        tensor
    }

    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        tensor
    }

    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        tensor
    }

    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        tensor
    }
}