burn_router/
tensor.rs

1use alloc::{sync::Arc, vec::Vec};
2
3use super::RunnerClient;
4use burn_tensor::{
5    repr::{TensorDescription, TensorId, TensorStatus},
6    DType, Shape, TensorData, TensorMetadata,
7};
8
/// Tensor primitive for the [router backend](crate::BackendRouter).
///
/// Cloning a `RouterTensor` shares the same [`TensorId`] through the `Arc`. The number of live
/// handles decides whether the tensor is described as `ReadWrite` (only handle) or `ReadOnly`
/// (shared).
pub struct RouterTensor<C: RunnerClient> {
    /// Shared tensor id; its `Arc` strong count is used to compute the tensor status.
    pub(crate) id: Arc<TensorId>,
    /// Dimensions of the tensor.
    pub(crate) shape: Vec<usize>,
    /// Element data type of the tensor.
    pub(crate) dtype: DType,
    /// Client used to read the tensor and to register it as an orphan on drop.
    pub(crate) client: C,

    /// `true` while the tensor has not been consumed into a `ReadWrite` description.
    ///
    /// A tensor stays an orphan until it is converted into a description while being the only
    /// handle to its id (`ReadWrite`). If it is dropped while still an orphan and still the only
    /// handle, it must be registered with the client as an orphan to avoid a memory leak.
    pub(crate) is_orphan: bool,
}
22
23impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
24    fn dtype(&self) -> DType {
25        self.dtype
26    }
27
28    fn shape(&self) -> Shape {
29        Shape::from(self.shape.clone())
30    }
31}
32
33impl<C: RunnerClient> RouterTensor<C> {
34    /// Create a new router tensor.
35    pub fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
36        Self {
37            id,
38            shape,
39            dtype,
40            client,
41            is_orphan: true,
42        }
43    }
44
45    pub(crate) async fn into_data(self) -> TensorData {
46        self.client
47            .clone()
48            .read_tensor(self.into_description())
49            .await
50    }
51
52    pub(crate) fn into_description(mut self) -> TensorDescription {
53        let status = self.status();
54        let mut shape_out = Vec::new();
55        core::mem::swap(&mut self.shape, &mut shape_out);
56
57        if let TensorStatus::ReadWrite = status {
58            self.is_orphan = false;
59        }
60
61        TensorDescription {
62            status,
63            shape: shape_out,
64            id: *self.id.as_ref(),
65            dtype: self.dtype,
66        }
67    }
68
69    pub(crate) fn to_description_out(&self) -> TensorDescription {
70        TensorDescription {
71            status: TensorStatus::NotInit,
72            shape: self.shape.clone(),
73            id: *self.id.as_ref(),
74            dtype: self.dtype,
75        }
76    }
77
78    pub(crate) fn status(&self) -> TensorStatus {
79        if Arc::strong_count(&self.id) <= 1 {
80            TensorStatus::ReadWrite
81        } else {
82            TensorStatus::ReadOnly
83        }
84    }
85}
86
87impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
88    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
89        f.write_str(
90            format!(
91                "{{ id: {:?}, shape: {:?}, dtype: {:?}, should_drop: {:?}, device: {:?} }}",
92                self.id,
93                self.shape,
94                self.dtype,
95                self.is_orphan,
96                self.client.device().clone(),
97            )
98            .as_str(),
99        )
100    }
101}
102
103impl<C: RunnerClient> Clone for RouterTensor<C> {
104    fn clone(&self) -> Self {
105        Self {
106            id: self.id.clone(),
107            shape: self.shape.clone(),
108            client: self.client.clone(),
109            dtype: self.dtype,
110            is_orphan: self.is_orphan,
111        }
112    }
113}
114
115impl<C: RunnerClient> Drop for RouterTensor<C> {
116    fn drop(&mut self) {
117        if !self.is_orphan {
118            return;
119        }
120
121        match self.status() {
122            TensorStatus::ReadWrite => {
123                self.client.register_orphan(&self.id);
124            }
125            TensorStatus::ReadOnly => {}
126            TensorStatus::NotInit => {}
127        }
128    }
129}