1use core::sync::atomic::{AtomicU32, Ordering};
2
3use alloc::{sync::Arc, vec::Vec};
4
5use super::RunnerClient;
6use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
7use burn_ir::{TensorId, TensorIr, TensorStatus};
8
/// A tensor handle routed through a [`RunnerClient`].
///
/// Handles are reference counted: `Clone` bumps `count`, `Drop` decrements it,
/// and the last handle to go away registers a drop operation with the client
/// (see the `Drop` impl in this file).
pub struct RouterTensor<C: RunnerClient> {
    // Identifier of the underlying tensor on the runner side.
    pub(crate) id: TensorId,
    // Shape of the tensor; moved out (leaving an empty shape) when the
    // handle is converted into an IR or dropped.
    pub(crate) shape: Shape,
    // Element data type.
    pub(crate) dtype: DType,
    /// Client used to communicate with the runner that owns the tensor.
    pub client: C,
    // Number of live handles sharing this tensor; shared across clones.
    pub(crate) count: Arc<AtomicU32>,
}
18
19impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
20 fn dtype(&self) -> DType {
21 self.dtype
22 }
23
24 fn shape(&self) -> Shape {
25 self.shape.clone()
26 }
27
28 fn rank(&self) -> usize {
29 self.shape.num_dims()
30 }
31}
32
33impl<C: RunnerClient> RouterTensor<C> {
34 pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
36 Self {
37 id,
38 shape,
39 dtype,
40 client,
41 count: Arc::new(AtomicU32::new(1)),
42 }
43 }
44
45 pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
46 self.client.clone().read_tensor_async(self.into_ir()).await
47 }
48
49 pub fn into_ir(mut self) -> TensorIr {
51 let count = self.count.load(Ordering::Relaxed);
52 let status = self.status(count);
53 let mut shape_out = Shape::from(Vec::<usize>::new());
54 core::mem::swap(&mut self.shape, &mut shape_out);
55
56 if let TensorStatus::ReadWrite = status {
57 self.count.fetch_add(1, Ordering::Relaxed);
62 }
63
64 TensorIr {
65 status,
66 shape: shape_out,
67 id: self.id,
68 dtype: self.dtype,
69 }
70 }
71
72 pub(crate) fn to_ir_out(&self) -> TensorIr {
73 TensorIr {
74 status: TensorStatus::NotInit,
75 shape: self.shape.clone(),
76 id: self.id,
77 dtype: self.dtype,
78 }
79 }
80
81 pub(crate) fn status(&self, count: u32) -> TensorStatus {
82 if count <= 1 {
83 TensorStatus::ReadWrite
84 } else {
85 TensorStatus::ReadOnly
86 }
87 }
88}
89
90impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
91 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
92 f.write_str(
93 format!(
94 "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
95 self.id,
96 self.shape,
97 self.dtype,
98 self.client.device().clone(),
99 )
100 .as_str(),
101 )
102 }
103}
104
105impl<C: RunnerClient> Clone for RouterTensor<C> {
106 fn clone(&self) -> Self {
107 self.count.fetch_add(1, Ordering::Relaxed);
108
109 Self {
110 id: self.id,
111 shape: self.shape.clone(),
112 client: self.client.clone(),
113 dtype: self.dtype,
114 count: self.count.clone(),
115 }
116 }
117}
118
119impl<C: RunnerClient> Drop for RouterTensor<C> {
120 fn drop(&mut self) {
121 let count = self.count.fetch_sub(1, Ordering::Relaxed);
122
123 match self.status(count) {
124 TensorStatus::ReadWrite => {
125 let id = self.id;
126 let mut shape = Shape::from(Vec::<usize>::new());
127 core::mem::swap(&mut shape, &mut self.shape);
128
129 let ir = TensorIr {
130 id,
131 shape,
132 status: TensorStatus::ReadWrite,
133 dtype: self.dtype,
134 };
135 self.client.register_op(burn_ir::OperationIr::Drop(ir));
136 }
137 TensorStatus::ReadOnly => {}
138 TensorStatus::NotInit => {}
139 }
140 }
141}