1use crate::{
2 Client, FusionBackend, FusionRuntime,
3 stream::{Operation, OperationStreams, StreamId},
4};
5use burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus};
6use burn_tensor::{
7 DType, Shape, TensorData, TensorMetadata,
8 quantization::{QTensorPrimitive, QuantScheme},
9};
10use std::sync::{
11 Arc,
12 atomic::{AtomicU32, Ordering},
13};
14
/// Handle to a tensor managed by the fusion runtime.
///
/// Clones of a handle share `count`: `Clone` increments it, `Drop` decrements it, and
/// `FusionTensor::new` starts it at one. The count decides whether a consumer of the handle
/// sees the tensor as `ReadWrite` (no other handle) or `ReadOnly`.
pub struct FusionTensor<R: FusionRuntime> {
    /// Identifier of the tensor.
    pub id: TensorId,
    /// Shape of the tensor.
    pub shape: Shape,
    /// Client used to register operations (e.g. the drop operation) and to read tensor data.
    pub client: Client<R>,
    /// Data type of the tensor's elements.
    pub dtype: DType,
    /// Stream associated with the tensor; forwarded to the client when reading data back.
    pub stream: StreamId,
    /// Number of live handles sharing this tensor.
    pub(crate) count: Arc<AtomicU32>,
}
29
30impl<R: FusionRuntime> Clone for FusionTensor<R> {
31 fn clone(&self) -> Self {
32 self.count.fetch_add(1, Ordering::Relaxed);
33
34 Self {
35 id: self.id,
36 shape: self.shape.clone(),
37 client: self.client.clone(),
38 dtype: self.dtype,
39 stream: self.stream,
40 count: self.count.clone(),
41 }
42 }
43}
44
45impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.write_str(
48 format!(
49 "{{ id: {:?}, shape: {:?}, device: {:?} }}",
50 self.id,
51 self.shape,
52 self.client.device().clone(),
53 )
54 .as_str(),
55 )
56 }
57}
58
59impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
60 fn dtype(&self) -> DType {
61 self.dtype
62 }
63
64 fn shape(&self) -> Shape {
65 self.shape.clone()
66 }
67
68 fn rank(&self) -> usize {
69 self.shape.num_dims()
70 }
71}
72
73impl<R: FusionRuntime> FusionTensor<R> {
74 pub(crate) fn new(
75 id: TensorId,
76 shape: Shape,
77 dtype: DType,
78 client: Client<R>,
79 stream: StreamId,
80 ) -> Self {
81 Self {
82 id,
83 shape,
84 client,
85 dtype,
86 stream,
87 count: Arc::new(AtomicU32::new(1)),
88 }
89 }
90
91 fn status(&self, count: u32) -> TensorStatus {
92 if count <= 1 {
93 TensorStatus::ReadWrite
94 } else {
95 TensorStatus::ReadOnly
96 }
97 }
98
99 pub fn to_ir_out(&self) -> TensorIr {
101 TensorIr {
102 status: TensorStatus::NotInit,
103 shape: self.shape.clone(),
104 id: self.id,
105 dtype: self.dtype,
106 }
107 }
108
109 pub fn into_ir(mut self) -> TensorIr {
111 let count = self.count.load(Ordering::Relaxed);
112 let status = self.status(count);
113
114 let mut shape_out = Shape::from(Vec::<usize>::new());
115 core::mem::swap(&mut self.shape, &mut shape_out);
116
117 if let TensorStatus::ReadWrite = status {
118 self.count.fetch_add(1, Ordering::Relaxed);
123 }
124
125 TensorIr {
126 status,
127 shape: shape_out,
128 id: self.id,
129 dtype: self.dtype,
130 }
131 }
132
133 pub(crate) async fn into_data<B>(self) -> TensorData
134 where
135 B: FusionBackend<FusionRuntime = R>,
136 {
137 let id = self.stream;
138 let client = self.client.clone();
139 let desc = self.into_ir();
140 client.read_tensor_float::<B>(desc, id).await
141 }
142
143 pub(crate) async fn q_into_data<B>(self) -> TensorData
144 where
145 B: FusionBackend<FusionRuntime = R>,
146 {
147 if let DType::QFloat(_scheme) = self.dtype {
148 let id = self.stream;
149 let client = self.client.clone();
150 let desc = self.into_ir();
151 client.read_tensor_quantized::<B>(desc, id).await
152 } else {
153 panic!("Expected quantized float dtype, got {:?}", self.dtype)
154 }
155 }
156
157 pub(crate) async fn int_into_data<B>(self) -> TensorData
158 where
159 B: FusionBackend<FusionRuntime = R>,
160 {
161 let id = self.stream;
162 let client = self.client.clone();
163 let desc = self.into_ir();
164 client.read_tensor_int::<B>(desc, id).await
165 }
166
167 pub(crate) async fn bool_into_data<B>(self) -> TensorData
168 where
169 B: FusionBackend<FusionRuntime = R>,
170 {
171 let id = self.stream;
172 let client = self.client.clone();
173 let desc = self.into_ir();
174 client.read_tensor_bool::<B>(desc, id).await
175 }
176}
177
/// Operation that removes the handle of tensor `id` from the handle container.
///
/// Registered by `FusionTensor`'s `Drop` impl when the dropped handle was the last one for
/// its tensor.
#[derive(new, Debug)]
pub(crate) struct DropOp {
    /// Tensor whose handle is removed when the operation executes.
    pub(crate) id: TensorId,
}
182
183impl<RO: FusionRuntime> Operation<RO> for DropOp {
184 fn execute(&self, handles: &mut burn_ir::HandleContainer<RO::FusionHandle>) {
185 handles.remove_handle(self.id);
186 }
187}
188
189impl<R: FusionRuntime> Drop for FusionTensor<R> {
190 fn drop(&mut self) {
191 let count = self.count.fetch_sub(1, Ordering::Relaxed);
192
193 if std::thread::panicking() {
195 return;
196 }
197
198 match self.status(count) {
199 TensorStatus::ReadWrite => {
200 let mut shape = Shape::from(Vec::<usize>::new());
201 core::mem::swap(&mut shape, &mut self.shape);
202
203 let ir = TensorIr {
204 id: self.id,
205 shape,
206 status: TensorStatus::ReadWrite,
207 dtype: self.dtype,
208 };
209 let mut streams = OperationStreams::default();
210 streams.tensor(self);
211
212 self.client
213 .register(streams, OperationIr::Drop(ir), DropOp { id: self.id });
214 }
215 TensorStatus::ReadOnly => {}
216 TensorStatus::NotInit => {}
217 }
218 }
219}
220
221impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
222 fn scheme(&self) -> &QuantScheme {
223 if let DType::QFloat(scheme) = &self.dtype {
224 scheme
225 } else {
226 panic!(
227 "Quantization scheme is not valid for dtype {:?}",
228 self.dtype,
229 )
230 }
231 }
232}