1use alloc::{sync::Arc, vec::Vec};
2
3use super::RunnerClient;
4use burn_tensor::{
5 repr::{TensorDescription, TensorId, TensorStatus},
6 DType, Shape, TensorData, TensorMetadata,
7};
8
/// A handle to a tensor managed through a router client `C`.
///
/// Clones of a handle share the same `Arc<TensorId>`; the Arc's strong count is what
/// `status()` uses to decide whether this handle is the only one (`TensorStatus::ReadWrite`)
/// or one of several (`TensorStatus::ReadOnly`).
pub struct RouterTensor<C: RunnerClient> {
    /// Identifier of the tensor; shared (not duplicated) between clones of this handle.
    pub(crate) id: Arc<TensorId>,
    /// Dimensions of the tensor (converted into a `Shape` by `TensorMetadata::shape`).
    pub(crate) shape: Vec<usize>,
    /// Element type of the tensor.
    pub(crate) dtype: DType,
    /// Client used to read the tensor, report its device, and register orphaned ids.
    pub(crate) client: C,

    /// When `true`, dropping the last handle to `id` registers the id with `client` as an
    /// orphan. `new` sets it to `true`; `into_description` clears it when it emits a
    /// read-write description so the consumed handle's `Drop` skips the registration.
    pub(crate) is_orphan: bool,
}
22
23impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
24 fn dtype(&self) -> DType {
25 self.dtype
26 }
27
28 fn shape(&self) -> Shape {
29 Shape::from(self.shape.clone())
30 }
31}
32
33impl<C: RunnerClient> RouterTensor<C> {
34 pub fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
36 Self {
37 id,
38 shape,
39 dtype,
40 client,
41 is_orphan: true,
42 }
43 }
44
45 pub(crate) async fn into_data(self) -> TensorData {
46 self.client
47 .clone()
48 .read_tensor(self.into_description())
49 .await
50 }
51
52 pub(crate) fn into_description(mut self) -> TensorDescription {
53 let status = self.status();
54 let mut shape_out = Vec::new();
55 core::mem::swap(&mut self.shape, &mut shape_out);
56
57 if let TensorStatus::ReadWrite = status {
58 self.is_orphan = false;
59 }
60
61 TensorDescription {
62 status,
63 shape: shape_out,
64 id: *self.id.as_ref(),
65 dtype: self.dtype,
66 }
67 }
68
69 pub(crate) fn to_description_out(&self) -> TensorDescription {
70 TensorDescription {
71 status: TensorStatus::NotInit,
72 shape: self.shape.clone(),
73 id: *self.id.as_ref(),
74 dtype: self.dtype,
75 }
76 }
77
78 pub(crate) fn status(&self) -> TensorStatus {
79 if Arc::strong_count(&self.id) <= 1 {
80 TensorStatus::ReadWrite
81 } else {
82 TensorStatus::ReadOnly
83 }
84 }
85}
86
87impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
88 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
89 f.write_str(
90 format!(
91 "{{ id: {:?}, shape: {:?}, dtype: {:?}, should_drop: {:?}, device: {:?} }}",
92 self.id,
93 self.shape,
94 self.dtype,
95 self.is_orphan,
96 self.client.device().clone(),
97 )
98 .as_str(),
99 )
100 }
101}
102
103impl<C: RunnerClient> Clone for RouterTensor<C> {
104 fn clone(&self) -> Self {
105 Self {
106 id: self.id.clone(),
107 shape: self.shape.clone(),
108 client: self.client.clone(),
109 dtype: self.dtype,
110 is_orphan: self.is_orphan,
111 }
112 }
113}
114
115impl<C: RunnerClient> Drop for RouterTensor<C> {
116 fn drop(&mut self) {
117 if !self.is_orphan {
118 return;
119 }
120
121 match self.status() {
122 TensorStatus::ReadWrite => {
123 self.client.register_orphan(&self.id);
124 }
125 TensorStatus::ReadOnly => {}
126 TensorStatus::NotInit => {}
127 }
128 }
129}