1use crate::{
2 FusionTensor,
3 client::GlobalFusionClient,
4 stream::{Context, OrderedExecution},
5};
6use burn_ir::{BackendIr, OperationIr, TensorHandle};
7use burn_tensor::{
8 Device, Element,
9 backend::{Backend, DeviceOps},
10 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
11};
12use serde::{Serialize, de::DeserializeOwned};
13use std::marker::PhantomData;
14
15pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
17 GlobalFusionClient::load(device)
18}
19
20#[derive(Clone, Debug, Default)]
22pub struct Fusion<B: FusionBackend> {
23 _backend: PhantomData<B>,
24}
25
26impl<B: FusionBackend> Backend for Fusion<B> {
27 type Device = B::Device;
28
29 type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
30
31 type FloatElem = B::FloatElem;
32
33 type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
34
35 type IntElem = B::IntElem;
36
37 type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
38
39 type BoolElem = B::BoolElem;
40
41 type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
42
43 fn name(device: &Self::Device) -> String {
44 format!("fusion<{}>", B::name(device))
45 }
46
47 fn seed(device: &B::Device, seed: u64) {
48 let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
49 client.drain();
50 B::seed(device, seed);
51 }
52
53 fn sync(device: &Self::Device) {
54 let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
55 client.drain();
56 B::sync(device);
57 }
58
59 fn ad_enabled() -> bool {
60 false
61 }
62
63 fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
64 device: &Self::Device,
65 input: Input,
66 func: Func,
67 ) -> Output {
68 B::memory_persistent_allocations(device, input, func)
69 }
70
71 fn memory_cleanup(device: &Self::Device) {
72 B::memory_cleanup(device)
73 }
74}
75
/// Status reported by an [`OptimizationBuilder`] via [`OptimizationBuilder::status`].
///
/// Derives `PartialEq`/`Eq` so callers can compare statuses directly instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationStatus {
    /// The builder is closed; presumably it accepts no further operations — confirm against
    /// builder implementations.
    Closed,
    /// The builder is open; presumably it can still accept more operations.
    Open,
}
84
/// Properties reported by an [`OptimizationBuilder`] via [`OptimizationBuilder::properties`].
///
/// Derives `PartialEq`/`Eq` (like [`OptimizationStatus`]) so values can be compared directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OptimizationProperties {
    /// Score attached to the optimization; how it is ranked (e.g. whether higher is better)
    /// is decided by consumers of the builder — verify at the call sites.
    pub score: u64,
    /// Readiness flag; presumably `true` once the optimization can be executed — confirm.
    pub ready: bool,
}
93
94pub trait OptimizationBuilder<O>: Send {
107 fn register(&mut self, operation: &OperationIr);
109 fn build(&self) -> O;
111 fn reset(&mut self);
113 fn status(&self) -> OptimizationStatus;
115 fn properties(&self) -> OptimizationProperties;
117 fn len(&self) -> usize;
119 fn is_empty(&self) -> bool {
121 self.len() == 0
122 }
123 fn clone_dyn(&self) -> Box<dyn OptimizationBuilder<O>>;
125}
126
/// Reports how many operations a value holds.
pub trait NumOperations: core::fmt::Debug {
    /// Number of operations held.
    fn len(&self) -> usize;

    /// `true` when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }
}
136
137pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
139 fn execute(
141 &mut self,
142 context: &mut Context<'_, R::FusionHandle>,
143 execution: &OrderedExecution<R>,
144 );
145
146 fn to_state(&self) -> R::OptimizationState;
148 fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
150}
151
152pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
154pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
156pub type Client<R> = GlobalFusionClient<R>;
158
159pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {
161 type OptimizationState: Serialize + DeserializeOwned;
163 type Optimization: Optimization<Self>;
165 type FusionHandle: Clone + Send;
167 type FusionDevice: DeviceOps;
169 type BoolRepr: Element;
171
172 fn optimizations(
174 device: Self::FusionDevice,
175 ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
176}
177
178pub trait FusionBackend:
181 BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
182{
183 type FusionRuntime: FusionRuntime;
185
186 fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;
188
189 type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
191}
192
193impl<B: FusionBackend> BackendIr for Fusion<B> {
195 type Handle = FusionTensor<B::FusionRuntime>;
196
197 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
198 handle.handle
199 }
200
201 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
202 handle.handle
203 }
204
205 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
206 handle.handle
207 }
208
209 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
210 handle.handle
211 }
212
213 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
214 tensor
215 }
216
217 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
218 tensor
219 }
220
221 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
222 tensor
223 }
224
225 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
226 tensor
227 }
228}