// burn_fusion/tensor.rs

1use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
2use burn_tensor::{
3    quantization::{QTensorPrimitive, QuantizationScheme},
4    repr::{TensorDescription, TensorId, TensorStatus},
5    DType, Shape, TensorData, TensorMetadata,
6};
7use std::{future::Future, sync::Arc};
8
9/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind.
10pub struct FusionTensor<R: FusionRuntime> {
11    /// Tensor id.
12    pub id: Arc<TensorId>,
13    /// The shape of the tensor.
14    pub shape: Vec<usize>,
15    /// The [fusion client](FusionClient).
16    pub client: Client<R>,
17    /// The datatype of the tensor.
18    pub dtype: DType,
19    /// The current stream id this tensor is on.
20    pub stream: StreamId,
21    // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`.
22    //
23    // When a tensor is dropped and is still an orphan, we need to register it as such to avoid
24    // memory leak. Otherwise, the cleanup is going to happen during a graph execution.
25    pub(crate) is_orphan: bool,
26}
27
28impl<R: FusionRuntime> Clone for FusionTensor<R> {
29    fn clone(&self) -> Self {
30        Self {
31            id: self.id.clone(),
32            shape: self.shape.clone(),
33            client: self.client.clone(),
34            dtype: self.dtype,
35            is_orphan: self.is_orphan,
36            stream: self.stream,
37        }
38    }
39}
40
41impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.write_str(
44            format!(
45                "{{ id: {:?}, shape: {:?}, should_drop: {:?}, device: {:?} }}",
46                self.id,
47                self.shape,
48                self.is_orphan,
49                self.client.device().clone(),
50            )
51            .as_str(),
52        )
53    }
54}
55
56impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
57    fn dtype(&self) -> DType {
58        self.dtype
59    }
60
61    fn shape(&self) -> Shape {
62        Shape::from(self.shape.clone())
63    }
64}
65
66impl<R: FusionRuntime> FusionTensor<R> {
67    pub(crate) fn new(
68        id: Arc<TensorId>,
69        shape: Vec<usize>,
70        dtype: DType,
71        client: Client<R>,
72        stream: StreamId,
73    ) -> Self {
74        Self {
75            id,
76            shape,
77            client,
78            dtype,
79            is_orphan: true,
80            stream,
81        }
82    }
83
84    fn status(&self) -> TensorStatus {
85        if Arc::strong_count(&self.id) <= 1 {
86            TensorStatus::ReadWrite
87        } else {
88            TensorStatus::ReadOnly
89        }
90    }
91
92    /// Description to be used when using an uninitialized tensor as output.
93    pub fn to_description_out(&self) -> TensorDescription {
94        TensorDescription {
95            status: TensorStatus::NotInit,
96            shape: self.shape.clone(),
97            id: *self.id.as_ref(),
98            dtype: self.dtype,
99        }
100    }
101
102    /// Description to be used when using an initialized tensor used as input.
103    pub fn into_description(mut self) -> TensorDescription {
104        let status = self.status();
105        let mut shape_out = Vec::new();
106        core::mem::swap(&mut self.shape, &mut shape_out);
107
108        if let TensorStatus::ReadWrite = status {
109            self.is_orphan = false;
110        }
111
112        TensorDescription {
113            status,
114            shape: shape_out,
115            id: *self.id.as_ref(),
116            dtype: self.dtype,
117        }
118    }
119
120    pub(crate) fn into_data<B>(self) -> impl Future<Output = TensorData>
121    where
122        B: FusionBackend<FusionRuntime = R>,
123    {
124        let id = self.stream;
125        let client = self.client.clone();
126        let desc = self.into_description();
127        client.read_tensor_float::<B>(desc, id)
128    }
129
130    pub(crate) fn q_into_data<B>(self) -> impl Future<Output = TensorData>
131    where
132        B: FusionBackend<FusionRuntime = R>,
133    {
134        if let DType::QFloat(_scheme) = self.dtype {
135            let id = self.stream;
136            let client = self.client.clone();
137            let desc = self.into_description();
138            client.read_tensor_quantized::<B>(desc, id)
139        } else {
140            panic!("Expected quantized float dtype, got {:?}", self.dtype)
141        }
142    }
143
144    pub(crate) fn int_into_data<B>(self) -> impl Future<Output = TensorData>
145    where
146        B: FusionBackend<FusionRuntime = R>,
147    {
148        let id = self.stream;
149        let client = self.client.clone();
150        let desc = self.into_description();
151        client.read_tensor_int::<B>(desc, id)
152    }
153
154    pub(crate) fn bool_into_data<B>(self) -> impl Future<Output = TensorData>
155    where
156        B: FusionBackend<FusionRuntime = R>,
157    {
158        let id = self.stream;
159        let client = self.client.clone();
160        let desc = self.into_description();
161        client.read_tensor_bool::<B>(desc, id)
162    }
163}
164
165impl<R: FusionRuntime> Drop for FusionTensor<R> {
166    fn drop(&mut self) {
167        if !self.is_orphan {
168            return;
169        }
170
171        match self.status() {
172            TensorStatus::ReadWrite => {
173                self.client.register_orphan(&self.id);
174            }
175            TensorStatus::ReadOnly => {}
176            TensorStatus::NotInit => {}
177        }
178    }
179}
180
181impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
182    fn scheme(&self) -> &QuantizationScheme {
183        if let DType::QFloat(scheme) = &self.dtype {
184            scheme
185        } else {
186            panic!(
187                "Quantization scheme is not valid for dtype {:?}",
188                self.dtype,
189            )
190        }
191    }
192}