// burn_fusion/tensor.rs

use crate::{
    Client, FusionBackend, FusionRuntime,
    stream::{Operation, OperationStreams, StreamId},
};
use burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus};
use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    quantization::{QTensorPrimitive, QuantScheme},
};
use std::sync::{
    Arc,
    atomic::{AtomicU32, Ordering},
};

/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all tensor kinds.
pub struct FusionTensor<R: FusionRuntime> {
    /// Tensor id.
    pub id: TensorId,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The fusion client.
    pub client: Client<R>,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// The current stream id this tensor is on.
    pub stream: StreamId,
    /// Number of live handles for this tensor; all clones share the same counter.
    pub(crate) count: Arc<AtomicU32>,
}

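// Illustrative sketch of how the shared counter affects tensor status: every clone
// shares `count`, so a tensor with more than one live handle is converted to a
// read-only IR node, while the last handle keeps read-write status.
//
//     let a = /* FusionTensor with count == 1 */;
//     let b = a.clone();      // count == 2
//     let ir_b = b.into_ir(); // ReadOnly: another handle is still alive
//     let ir_a = a.into_ir(); // ReadWrite: last handle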
impl<R: FusionRuntime> Clone for FusionTensor<R> {
    fn clone(&self) -> Self {
        self.count.fetch_add(1, Ordering::Relaxed);

        Self {
            id: self.id,
            shape: self.shape.clone(),
            client: self.client.clone(),
            dtype: self.dtype,
            stream: self.stream,
            count: self.count.clone(),
        }
    }
}

impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(
            format!(
                "{{ id: {:?}, shape: {:?}, device: {:?} }}",
                self.id,
                self.shape,
                self.client.device().clone(),
            )
            .as_str(),
        )
    }
}

impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}

impl<R: FusionRuntime> FusionTensor<R> {
    pub(crate) fn new(
        id: TensorId,
        shape: Shape,
        dtype: DType,
        client: Client<R>,
        stream: StreamId,
    ) -> Self {
        Self {
            id,
            shape,
            client,
            dtype,
            stream,
            // A fresh tensor starts with a single live handle.
            count: Arc::new(AtomicU32::new(1)),
        }
    }

    /// The current status of the tensor given the number of live handles.
    fn status(&self, count: u32) -> TensorStatus {
        if count <= 1 {
            TensorStatus::ReadWrite
        } else {
            TensorStatus::ReadOnly
        }
    }

    /// Intermediate representation to be used when using an uninitialized tensor as output.
    pub fn to_ir_out(&self) -> TensorIr {
        TensorIr {
            status: TensorStatus::NotInit,
            shape: self.shape.clone(),
            id: self.id,
            dtype: self.dtype,
        }
    }

    /// Intermediate representation to be used when using an initialized tensor as input.
    pub fn into_ir(mut self) -> TensorIr {
        let count = self.count.load(Ordering::Relaxed);
        let status = self.status(count);

        let mut shape_out = Shape::from(Vec::<usize>::new());
        core::mem::swap(&mut self.shape, &mut shape_out);

        if let TensorStatus::ReadWrite = status {
            // Avoids an unwanted drop on the same thread.
            //
            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor
            // was consumed with a `ReadWrite` status.
            self.count.fetch_add(1, Ordering::Relaxed);
        }

        TensorIr {
            status,
            shape: shape_out,
            id: self.id,
            dtype: self.dtype,
        }
    }
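    // Illustrative example: consuming the only handle produces a read-write IR node
    // without enqueueing a drop, because the counter is bumped above before `self`
    // is dropped at the end of `into_ir`.
    //
    //     let t = /* FusionTensor with count == 1 */;
    //     let ir = t.into_ir();
    //     assert!(matches!(ir.status, TensorStatus::ReadWrite));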

    pub(crate) async fn into_data<B>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let id = self.stream;
        let client = self.client.clone();
        let desc = self.into_ir();
        client.read_tensor_float::<B>(desc, id).await
    }

    pub(crate) async fn q_into_data<B>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        if let DType::QFloat(_scheme) = self.dtype {
            let id = self.stream;
            let client = self.client.clone();
            let desc = self.into_ir();
            client.read_tensor_quantized::<B>(desc, id).await
        } else {
            panic!("Expected quantized float dtype, got {:?}", self.dtype)
        }
    }

    pub(crate) async fn int_into_data<B>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let id = self.stream;
        let client = self.client.clone();
        let desc = self.into_ir();
        client.read_tensor_int::<B>(desc, id).await
    }

    pub(crate) async fn bool_into_data<B>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let id = self.stream;
        let client = self.client.clone();
        let desc = self.into_ir();
        client.read_tensor_bool::<B>(desc, id).await
    }
}
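// Usage sketch (illustrative; assumes an async context, a concrete backend `B` with
// `B: FusionBackend<FusionRuntime = R>`, and a float tensor whose elements are `f32`):
//
//     let data: TensorData = tensor.into_data::<B>().await;
//     let values = data.to_vec::<f32>().expect("f32 data");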

/// Operation registered when the last handle to a tensor is dropped.
#[derive(new, Debug)]
pub(crate) struct DropOp {
    pub(crate) id: TensorId,
}

impl<RO: FusionRuntime> Operation<RO> for DropOp {
    fn execute(&self, handles: &mut burn_ir::HandleContainer<RO::FusionHandle>) {
        // Remove the tensor's handle from the container once the stream executes the drop.
        handles.remove_handle(self.id);
    }
}

impl<R: FusionRuntime> Drop for FusionTensor<R> {
    fn drop(&mut self) {
        let count = self.count.fetch_sub(1, Ordering::Relaxed);

        // Workaround to prevent segfaults when an operation panics
        if std::thread::panicking() {
            return;
        }

        match self.status(count) {
            TensorStatus::ReadWrite => {
                let mut shape = Shape::from(Vec::<usize>::new());
                core::mem::swap(&mut shape, &mut self.shape);

                let ir = TensorIr {
                    id: self.id,
                    shape,
                    status: TensorStatus::ReadWrite,
                    dtype: self.dtype,
                };
                let mut streams = OperationStreams::default();
                streams.tensor(self);

                self.client
                    .register(streams, OperationIr::Drop(ir), DropOp { id: self.id });
            }
            // Another handle is still alive; the last one will register the drop.
            TensorStatus::ReadOnly => {}
            // `status` never returns `NotInit`; this arm only keeps the match exhaustive.
            TensorStatus::NotInit => {}
        }
    }
}
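// Putting it together (illustrative): only dropping the last live handle registers a
// `DropOp`, so the backend handle is released exactly once.
//
//     let a = /* FusionTensor with count == 1 */;
//     let b = a.clone(); // count == 2
//     drop(b);           // previous count == 2 -> ReadOnly  -> nothing registered
//     drop(a);           // previous count == 1 -> ReadWrite -> DropOp registered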

impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
    fn scheme(&self) -> &QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}