1use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
2use burn_tensor::{
3 quantization::{QTensorPrimitive, QuantizationScheme},
4 repr::{TensorDescription, TensorId, TensorStatus},
5 DType, Shape, TensorData, TensorMetadata,
6};
7use std::{future::Future, sync::Arc};
8
9pub struct FusionTensor<R: FusionRuntime> {
11 pub id: Arc<TensorId>,
13 pub shape: Vec<usize>,
15 pub client: Client<R>,
17 pub dtype: DType,
19 pub stream: StreamId,
21 pub(crate) is_orphan: bool,
26}
27
28impl<R: FusionRuntime> Clone for FusionTensor<R> {
29 fn clone(&self) -> Self {
30 Self {
31 id: self.id.clone(),
32 shape: self.shape.clone(),
33 client: self.client.clone(),
34 dtype: self.dtype,
35 is_orphan: self.is_orphan,
36 stream: self.stream,
37 }
38 }
39}
40
41impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.write_str(
44 format!(
45 "{{ id: {:?}, shape: {:?}, should_drop: {:?}, device: {:?} }}",
46 self.id,
47 self.shape,
48 self.is_orphan,
49 self.client.device().clone(),
50 )
51 .as_str(),
52 )
53 }
54}
55
56impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
57 fn dtype(&self) -> DType {
58 self.dtype
59 }
60
61 fn shape(&self) -> Shape {
62 Shape::from(self.shape.clone())
63 }
64}
65
66impl<R: FusionRuntime> FusionTensor<R> {
67 pub(crate) fn new(
68 id: Arc<TensorId>,
69 shape: Vec<usize>,
70 dtype: DType,
71 client: Client<R>,
72 stream: StreamId,
73 ) -> Self {
74 Self {
75 id,
76 shape,
77 client,
78 dtype,
79 is_orphan: true,
80 stream,
81 }
82 }
83
84 fn status(&self) -> TensorStatus {
85 if Arc::strong_count(&self.id) <= 1 {
86 TensorStatus::ReadWrite
87 } else {
88 TensorStatus::ReadOnly
89 }
90 }
91
92 pub fn to_description_out(&self) -> TensorDescription {
94 TensorDescription {
95 status: TensorStatus::NotInit,
96 shape: self.shape.clone(),
97 id: *self.id.as_ref(),
98 dtype: self.dtype,
99 }
100 }
101
102 pub fn into_description(mut self) -> TensorDescription {
104 let status = self.status();
105 let mut shape_out = Vec::new();
106 core::mem::swap(&mut self.shape, &mut shape_out);
107
108 if let TensorStatus::ReadWrite = status {
109 self.is_orphan = false;
110 }
111
112 TensorDescription {
113 status,
114 shape: shape_out,
115 id: *self.id.as_ref(),
116 dtype: self.dtype,
117 }
118 }
119
120 pub(crate) fn into_data<B>(self) -> impl Future<Output = TensorData>
121 where
122 B: FusionBackend<FusionRuntime = R>,
123 {
124 let id = self.stream;
125 let client = self.client.clone();
126 let desc = self.into_description();
127 client.read_tensor_float::<B>(desc, id)
128 }
129
130 pub(crate) fn q_into_data<B>(self) -> impl Future<Output = TensorData>
131 where
132 B: FusionBackend<FusionRuntime = R>,
133 {
134 if let DType::QFloat(_scheme) = self.dtype {
135 let id = self.stream;
136 let client = self.client.clone();
137 let desc = self.into_description();
138 client.read_tensor_quantized::<B>(desc, id)
139 } else {
140 panic!("Expected quantized float dtype, got {:?}", self.dtype)
141 }
142 }
143
144 pub(crate) fn int_into_data<B>(self) -> impl Future<Output = TensorData>
145 where
146 B: FusionBackend<FusionRuntime = R>,
147 {
148 let id = self.stream;
149 let client = self.client.clone();
150 let desc = self.into_description();
151 client.read_tensor_int::<B>(desc, id)
152 }
153
154 pub(crate) fn bool_into_data<B>(self) -> impl Future<Output = TensorData>
155 where
156 B: FusionBackend<FusionRuntime = R>,
157 {
158 let id = self.stream;
159 let client = self.client.clone();
160 let desc = self.into_description();
161 client.read_tensor_bool::<B>(desc, id)
162 }
163}
164
165impl<R: FusionRuntime> Drop for FusionTensor<R> {
166 fn drop(&mut self) {
167 if !self.is_orphan {
168 return;
169 }
170
171 match self.status() {
172 TensorStatus::ReadWrite => {
173 self.client.register_orphan(&self.id);
174 }
175 TensorStatus::ReadOnly => {}
176 TensorStatus::NotInit => {}
177 }
178 }
179}
180
181impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
182 fn scheme(&self) -> &QuantizationScheme {
183 if let DType::QFloat(scheme) = &self.dtype {
184 scheme
185 } else {
186 panic!(
187 "Quantization scheme is not valid for dtype {:?}",
188 self.dtype,
189 )
190 }
191 }
192}