1use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor};
2use burn_tensor::{
3 backend::{Backend, DeviceOps},
4 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
5 repr::{OperationDescription, ReprBackend, TensorHandle},
6 Device, Element,
7};
8use serde::{de::DeserializeOwned, Serialize};
9use std::marker::PhantomData;
10
11pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
12
13pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
14 CLIENTS.client::<B::FusionRuntime>(device)
15}
16
17#[derive(Clone, Debug, Default)]
19pub struct Fusion<B: FusionBackend> {
20 _backend: PhantomData<B>,
21}
22
23impl<B: FusionBackend> Backend for Fusion<B> {
24 type Device = B::Device;
25
26 type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
27
28 type FloatElem = B::FloatElem;
29
30 type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
31
32 type IntElem = B::IntElem;
33
34 type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
35
36 type BoolElem = B::BoolElem;
37
38 type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
39
40 type QuantizedEncoding = B::QuantizedEncoding;
41
42 fn name() -> String {
43 format!("fusion<{}>", B::name())
44 }
45
46 fn seed(seed: u64) {
47 B::seed(seed);
48 }
49
50 fn sync(device: &Self::Device) {
51 let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
52 client.drain();
53 B::sync(device);
54 }
55
56 fn ad_enabled() -> bool {
57 false
58 }
59}
60
/// Status reported by an `OptimizationBuilder` through its `status` method.
#[derive(Copy, Clone, Debug)]
pub enum OptimizationStatus {
    /// The builder is closed. Judging by the name, it presumably accepts no further
    /// operations; confirm against builder implementations.
    Closed,
    /// The builder is open. Judging by the name, it presumably can still take more
    /// operations; confirm against builder implementations.
    Open,
}
69
/// Properties reported by an `OptimizationBuilder` through its `properties` method.
///
/// `Default` yields `score: 0` and `ready: false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct OptimizationProperties {
    /// Score reported by the builder. How scores are compared is up to the caller; presumably
    /// higher is better (confirm).
    pub score: u64,
    /// Readiness flag reported by the builder. Its exact meaning (e.g. "can be executed now")
    /// is defined by the builders; confirm.
    pub ready: bool,
}
78
79pub trait OptimizationBuilder<O>: Send {
92 fn register(&mut self, operation: &OperationDescription);
94 fn build(&self) -> O;
96 fn reset(&mut self);
98 fn status(&self) -> OptimizationStatus;
100 fn properties(&self) -> OptimizationProperties;
102 fn len(&self) -> usize;
104 fn is_empty(&self) -> bool {
106 self.len() == 0
107 }
108}
109
110pub trait Optimization<R: FusionRuntime>: Send {
112 fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>);
114 fn len(&self) -> usize;
116 fn is_empty(&self) -> bool {
118 self.len() == 0
119 }
120 fn to_state(&self) -> R::OptimizationState;
122 fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
124}
125
// The aliases use fully-qualified `<R as FusionRuntime>::…` paths. rustc does not enforce bounds
// on type-alias generics, so the `R::Assoc` shorthand cannot be relied on here.

/// Shorthand for the device type of fusion runtime `R`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Shorthand for the tensor handle type of fusion runtime `R`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Shorthand for the client type of fusion runtime `R`.
pub type Client<R> = <R as FusionRuntime>::FusionClient;
132
133pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
135 type OptimizationState: Serialize + DeserializeOwned;
137 type Optimization: Optimization<Self>;
139 type FusionHandle: Clone + Send;
141 type FusionDevice: DeviceOps;
143 type FusionClient: FusionClient<Self>;
145 type BoolRepr: Element;
147
148 fn optimizations(
150 device: Self::FusionDevice,
151 ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
152}
153
154pub trait FusionBackend:
157 ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
158{
159 type FusionRuntime: FusionRuntime;
161
162 fn cast_float(tensor: FloatTensor<Self>, dtype: burn_tensor::DType) -> Self::Handle;
164
165 type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
167}
168
169impl<B: FusionBackend> ReprBackend for Fusion<B> {
171 type Handle = FusionTensor<B::FusionRuntime>;
172
173 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
174 handle.handle
175 }
176
177 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
178 handle.handle
179 }
180
181 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
182 handle.handle
183 }
184
185 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
186 handle.handle
187 }
188
189 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
190 tensor
191 }
192
193 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
194 tensor
195 }
196
197 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
198 tensor
199 }
200
201 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
202 tensor
203 }
204}