1use crate::{
2 FusionClientLocator, FusionTensor,
3 client::FusionClient,
4 stream::{Context, OrderedExecution},
5};
6use burn_ir::{BackendIr, OperationIr, TensorHandle};
7use burn_tensor::{
8 Device, Element,
9 backend::{Backend, DeviceOps},
10 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
11};
12use serde::{Serialize, de::DeserializeOwned};
13use std::marker::PhantomData;
14
15pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
16
17pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
18 CLIENTS.client::<B::FusionRuntime>(device)
19}
20
21#[derive(Clone, Debug, Default)]
23pub struct Fusion<B: FusionBackend> {
24 _backend: PhantomData<B>,
25}
26
27impl<B: FusionBackend> Backend for Fusion<B> {
28 type Device = B::Device;
29
30 type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
31
32 type FloatElem = B::FloatElem;
33
34 type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
35
36 type IntElem = B::IntElem;
37
38 type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
39
40 type BoolElem = B::BoolElem;
41
42 type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
43
44 type QuantizedEncoding = B::QuantizedEncoding;
45
46 fn name(device: &Self::Device) -> String {
47 format!("fusion<{}>", B::name(device))
48 }
49
50 fn seed(seed: u64) {
51 B::seed(seed);
52 }
53
54 fn sync(device: &Self::Device) {
55 let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
56 client.drain();
57 B::sync(device);
58 }
59
60 fn ad_enabled() -> bool {
61 false
62 }
63}
64
/// Status reported by an [`OptimizationBuilder`].
///
/// `PartialEq`/`Eq` are derived so statuses can be compared with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OptimizationStatus {
    /// NOTE(review): presumably no further operations can be registered; confirm.
    Closed,
    /// NOTE(review): presumably more operations can still be registered; confirm.
    Open,
}
73
/// Properties reported by an [`OptimizationBuilder`].
///
/// `PartialEq`/`Eq` are derived so property snapshots can be compared with `==`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OptimizationProperties {
    /// Score of the optimization. NOTE(review): presumably used to rank competing
    /// builders; confirm the ordering convention with callers.
    pub score: u64,
    /// NOTE(review): presumably whether the builder can currently produce an
    /// optimization; confirm.
    pub ready: bool,
}
82
83pub trait OptimizationBuilder<O>: Send {
96 fn register(&mut self, operation: &OperationIr);
98 fn build(&self) -> O;
100 fn reset(&mut self);
102 fn status(&self) -> OptimizationStatus;
104 fn properties(&self) -> OptimizationProperties;
106 fn len(&self) -> usize;
108 fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112 fn clone_dyn(&self) -> Box<dyn OptimizationBuilder<O>>;
114}
115
/// Exposes a count of operations, e.g. for an [`Optimization`].
pub trait NumOperations: core::fmt::Debug {
    /// Returns `true` when [`len`](NumOperations::len) is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Number of operations.
    fn len(&self) -> usize;
}
125
126pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
128 fn execute(
130 &mut self,
131 context: &mut Context<'_, R::FusionHandle>,
132 execution: &OrderedExecution<R>,
133 );
134
135 fn to_state(&self) -> R::OptimizationState;
137 fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
139}
140
141pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
143pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
145pub type Client<R> = <R as FusionRuntime>::FusionClient;
147
148pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
150 type OptimizationState: Serialize + DeserializeOwned;
152 type Optimization: Optimization<Self>;
154 type FusionHandle: Clone + Send;
156 type FusionDevice: DeviceOps;
158 type FusionClient: FusionClient<Self>;
160 type BoolRepr: Element;
162
163 fn optimizations(
165 device: Self::FusionDevice,
166 ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
167}
168
169pub trait FusionBackend:
172 BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
173{
174 type FusionRuntime: FusionRuntime;
176
177 fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;
179
180 type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
182}
183
184impl<B: FusionBackend> BackendIr for Fusion<B> {
186 type Handle = FusionTensor<B::FusionRuntime>;
187
188 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
189 handle.handle
190 }
191
192 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
193 handle.handle
194 }
195
196 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
197 handle.handle
198 }
199
200 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
201 handle.handle
202 }
203
204 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
205 tensor
206 }
207
208 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
209 tensor
210 }
211
212 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
213 tensor
214 }
215
216 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
217 tensor
218 }
219}