// burn_fusion/backend.rs

use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor};
use burn_tensor::{
    backend::{Backend, DeviceOps},
    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
    repr::{OperationDescription, ReprBackend, TensorHandle},
    Device, Element,
};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;

pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
    CLIENTS.client::<B::FusionRuntime>(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
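///
/// # Example
///
/// A minimal sketch, assuming a concrete backend named `MyBackend` that implements
/// [fusion backend](crate::FusionBackend); the backend name and tensor sizes are
/// illustrative only.
///
/// ```rust,ignore
/// use burn_fusion::Fusion;
///
/// // Wrap the inner backend: operations recorded through `Fused` are grouped
/// // and optimized before being executed on `MyBackend`.
/// type Fused = Fusion<MyBackend>;
///
/// let device = Default::default();
/// let x = burn_tensor::Tensor::<Fused, 2>::ones([32, 32], &device);
/// let y = (x.clone() + x).exp(); // element-wise chain, a candidate for fusion
/// ```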
#[derive(Clone, Debug, Default)]
pub struct Fusion<B: FusionBackend> {
    _backend: PhantomData<B>,
}

impl<B: FusionBackend> Backend for Fusion<B> {
    type Device = B::Device;

    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type FloatElem = B::FloatElem;

    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type IntElem = B::IntElem;

    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type BoolElem = B::BoolElem;

    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;

    type QuantizedEncoding = B::QuantizedEncoding;

    fn name() -> String {
        format!("fusion<{}>", B::name())
    }

    fn seed(seed: u64) {
        B::seed(seed);
    }

    fn sync(device: &Self::Device) {
        let client = CLIENTS.client::<B::FusionRuntime>(device);
        client.drain();
        B::sync(device);
    }

    fn ad_enabled() -> bool {
        false
    }
}

/// The status of a [builder](OptimizationBuilder).
#[derive(Clone, Debug, Copy)]
pub enum OptimizationStatus {
    /// No more operations can be fused.
    Closed,
    /// More operations can be fused.
    Open,
}

/// The properties of a [builder](OptimizationBuilder).
#[derive(Debug, Clone, Copy, Default)]
pub struct OptimizationProperties {
    /// The score of the optimization, higher is better.
    pub score: u64,
    /// If the optimization is ready to be executed.
    pub ready: bool,
}

/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](OperationDescription) into one, improving the performance of the backend.
///
/// # Notes
///
/// Implementations are free to execute the registered operations however they see fit to improve
/// the speed and efficiency of the computational graph. This doesn't mean that all registered
/// operations must be fused; only that executing them another way is more efficient.
///
/// It is also important to return [OptimizationStatus::Closed] when registering more operations
/// can no longer improve performance.
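///
/// # Example
///
/// A minimal sketch of how a scheduler might drive a builder; `builder` and `operations` are
/// illustrative placeholders, not items from this crate:
///
/// ```rust,ignore
/// // Feed operations to the builder until it closes itself.
/// for operation in operations.iter() {
///     if let OptimizationStatus::Closed = builder.status() {
///         break;
///     }
///     builder.register(operation);
/// }
///
/// // Only build when the optimization is worth executing.
/// let properties = builder.properties();
/// if properties.ready && !builder.is_empty() {
///     let optimization = builder.build();
///     // ... execute `optimization`, then recycle the builder.
///     builder.reset();
/// }
/// ```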
pub trait OptimizationBuilder<O>: Send {
    /// Register a new [tensor operation](OperationDescription).
    fn register(&mut self, operation: &OperationDescription);
    /// Finish the optimization and create a fusion operation.
    fn build(&self) -> O;
    /// Reset the state.
    fn reset(&mut self);
    /// Return the builder [status](OptimizationStatus).
    fn status(&self) -> OptimizationStatus;
    /// Return the builder [properties](OptimizationProperties).
    fn properties(&self) -> OptimizationProperties;
    /// The number of operations fused.
    fn len(&self) -> usize;
    /// If no operations are fused.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// The operation created from the [builder](OptimizationBuilder).
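///
/// # Example
///
/// A minimal sketch of persisting an optimization and rebuilding it later; `opt`, `device`, and
/// `MyOptimization` are illustrative, and any serde format could stand in for `serde_json`:
///
/// ```rust,ignore
/// // Serialize the optimization state so it can be cached.
/// let state = opt.to_state();
/// let bytes = serde_json::to_vec(&state).unwrap();
///
/// // Later, rebuild the optimization for the same device from the cached state.
/// let state = serde_json::from_slice(&bytes).unwrap();
/// let opt = MyOptimization::from_state(&device, state);
/// ```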
pub trait Optimization<R: FusionRuntime>: Send {
    /// Execute the operation.
    fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>);
    /// The number of registered operations in this optimization.
    fn len(&self) -> usize;
    /// If the current optimization is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns the state that can be serialized.
    fn to_state(&self) -> R::OptimizationState;
    /// Create the optimization from the state.
    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}

/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Type alias for `<R as FusionRuntime>::FusionClient`.
pub type Client<R> = <R as FusionRuntime>::FusionClient;

/// Trait that defines a runtime that will benefit from fused operations.
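///
/// # Example
///
/// A minimal sketch of what an implementation might return from
/// [optimizations](FusionRuntime::optimizations); `MyRuntime` and `ElementWiseBuilder` are
/// illustrative names:
///
/// ```rust,ignore
/// impl FusionRuntime for MyRuntime {
///     // ... associated types elided ...
///
///     fn optimizations(
///         device: Self::FusionDevice,
///     ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>> {
///         // Each builder recognizes one family of fusable patterns.
///         vec![Box::new(ElementWiseBuilder::new(device))]
///     }
/// }
/// ```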
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
    /// The state that can be serialized for an optimization.
    type OptimizationState: Serialize + DeserializeOwned;
    /// Optimization type for the backend.
    type Optimization: Optimization<Self>;
    /// Handle used to store tensors dynamically.
    type FusionHandle: Clone + Send;
    /// Device used by the runtime.
    type FusionDevice: DeviceOps;
    /// The client to interact with the runtime.
    type FusionClient: FusionClient<Self>;
    /// The type that represents booleans on the backend.
    type BoolRepr: Element;

    /// The list of optimizations that will be used to optimize the computational graph.
    fn optimizations(
        device: Self::FusionDevice,
    ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
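///
/// # Example
///
/// A minimal sketch of wiring a hypothetical backend into fusion; `MyBackend`, `MyRuntime`,
/// `MyFullPrecisionBackend`, and the `cast_float` body are illustrative placeholders:
///
/// ```rust,ignore
/// impl FusionBackend for MyBackend {
///     // Runtime shared by this backend and its full-precision variant.
///     type FusionRuntime = MyRuntime;
///     // The same backend family instantiated with full-precision floats.
///     type FullPrecisionBackend = MyFullPrecisionBackend;
///
///     fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle {
///         // Delegate to the backend's own cast routine and return the raw handle.
///         my_cast(tensor, dtype).into_handle()
///     }
/// }
/// ```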
pub trait FusionBackend:
    ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
    /// The runtime used for this backend.
    type FusionRuntime: FusionRuntime;

    /// Cast a float tensor and return the resulting handle.
    fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;

    /// Pointer to the full precision fusion backend.
    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}

// Fusion implements `ReprBackend` to enable router backend usage.
impl<B: FusionBackend> ReprBackend for Fusion<B> {
    type Handle = FusionTensor<B::FusionRuntime>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        handle.handle
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        handle.handle
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        handle.handle
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        handle.handle
    }

    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        tensor
    }

    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        tensor
    }

    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        tensor
    }

    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        tensor
    }
}