1use crate::{
2 FusionTensor,
3 client::GlobalFusionClient,
4 stream::{Context, OrderedExecution},
5};
6use burn_backend::{
7 Backend, BackendTypes, DType, DeviceOps, ExecutionError,
8 tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor},
9};
10use burn_ir::{BackendIr, OperationIr, TensorHandle};
11use serde::{Serialize, de::DeserializeOwned};
12use std::marker::PhantomData;
13
14pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
16 GlobalFusionClient::load(device)
17}
18
19#[derive(Clone, Debug, Default)]
21pub struct Fusion<B: FusionBackend> {
22 _backend: PhantomData<B>,
23}
24
25impl<B: FusionBackend> BackendTypes for Fusion<B> {
26 type Device = B::Device;
27
28 type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
29
30 type FloatElem = B::FloatElem;
31
32 type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
33
34 type IntElem = B::IntElem;
35
36 type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
37
38 type BoolElem = B::BoolElem;
39
40 type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
41}
42
43impl<B: FusionBackend> Backend for Fusion<B> {
44 fn name(device: &Self::Device) -> String {
45 format!("fusion<{}>", B::name(device))
46 }
47
48 fn seed(device: &B::Device, seed: u64) {
49 let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
50 let device = device.clone();
51 client.sync(move || B::seed(&device, seed));
52 }
53
54 fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
55 let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
56 let device = device.clone();
57 client.sync(move || B::sync(&device))
58 }
59
60 fn ad_enabled(_device: &Self::Device) -> bool {
61 false
62 }
63
64 fn memory_persistent_allocations<
65 Output: Send,
66 Input: Send,
67 Func: Fn(Input) -> Output + Send,
68 >(
69 device: &Self::Device,
70 input: Input,
71 func: Func,
72 ) -> Output {
73 B::memory_persistent_allocations(device, input, func)
74 }
75
76 fn memory_cleanup(device: &Self::Device) {
77 B::memory_cleanup(device)
78 }
79
80 fn staging<'a, Iter>(data: Iter, device: &Self::Device)
81 where
82 Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
83 {
84 B::staging(data, device);
85 }
86
87 fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
88 B::supports_dtype(device, dtype)
89 }
90
91 fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
92 B::dtype_usage(device, dtype)
93 }
94
95 fn device_count(type_id: u16) -> usize {
96 B::device_count(type_id)
97 }
98}
99
/// Status reported by an operation fuser (see `OperationFuser::status`).
///
/// NOTE(review): the variant names suggest whether the fuser can still accept
/// operations; the exact contract is defined by the implementors — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FuserStatus {
    /// The fuser is closed.
    Closed,
    /// The fuser is open.
    Open,
}
108
/// Properties reported by an operation fuser (see `OperationFuser::properties`).
///
/// The derived `Default` yields `score: 0` and `ready: false`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct FuserProperties {
    /// Score attached to the fuser's current state.
    ///
    /// NOTE(review): presumably a higher score means a more valuable
    /// optimization — confirm with the code that compares scores.
    pub score: u64,
    /// Readiness flag.
    ///
    /// NOTE(review): presumably tells whether the optimization can be
    /// executed as-is — confirm with the code that reads this flag.
    pub ready: bool,
}
117
118pub trait OperationFuser<O>: Send {
131 fn fuse(&mut self, operation: &OperationIr);
133 fn finish(&mut self) -> O;
135 fn reset(&mut self);
137 fn status(&self) -> FuserStatus;
139 fn properties(&self) -> FuserProperties;
141 fn len(&self) -> usize;
143 fn is_empty(&self) -> bool {
145 self.len() == 0
146 }
147 fn clone_dyn(&self) -> Box<dyn OperationFuser<O>>;
149}
150
/// Describes something that has a length and a static name.
///
/// NOTE(review): going by the trait name, the length is presumably a number
/// of operations — confirm with the implementors.
pub trait NumOperations: core::fmt::Debug {
    /// Length reported by the implementor.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Static name of the implementor.
    fn name(&self) -> &'static str;
}
162
163pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
165 fn execute(&mut self, context: &mut Context<R::FusionHandle>, execution: &OrderedExecution<R>);
167
168 fn to_state(&self) -> R::OptimizationState;
170 fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
172}
173
174pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
176pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
178pub type Client<R> = GlobalFusionClient<R>;
180
181pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {
183 type OptimizationState: Serialize + DeserializeOwned;
185 type Optimization: Optimization<Self>;
187 type FusionHandle: Clone + Send;
189 type FusionDevice: DeviceOps;
191
192 fn fusers(device: Self::FusionDevice) -> Vec<Box<dyn OperationFuser<Self::Optimization>>>;
194}
195
196pub trait FusionBackend:
199 BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
200{
201 type FusionRuntime: FusionRuntime;
203
204 fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle;
206
207 type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
209}
210
211impl<B: FusionBackend> BackendIr for Fusion<B> {
213 type Handle = FusionTensor<B::FusionRuntime>;
214
215 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
216 handle.handle
217 }
218
219 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
220 handle.handle
221 }
222
223 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
224 handle.handle
225 }
226
227 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
228 handle.handle
229 }
230
231 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
232 tensor
233 }
234
235 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
236 tensor
237 }
238
239 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
240 tensor
241 }
242
243 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
244 tensor
245 }
246}