burn_router/
tensor.rs

1use core::sync::atomic::{AtomicU32, Ordering};
2
3use alloc::{sync::Arc, vec::Vec};
4
5use super::RunnerClient;
6use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
7use burn_ir::{TensorId, TensorIr, TensorStatus};
8
/// Tensor primitive for the [router backend](crate::BackendRouter).
pub struct RouterTensor<C: RunnerClient> {
    // Identifier of the tensor on the runner side; forwarded in every `TensorIr`.
    pub(crate) id: TensorId,
    // Shape of the tensor; moved out (swapped for an empty shape) when the
    // handle is consumed by `into_ir` or dropped with `ReadWrite` status.
    pub(crate) shape: Shape,
    // Element data type.
    pub(crate) dtype: DType,
    /// The client that has this tensor
    pub client: C,
    // Number of live handles sharing this tensor. Incremented on `Clone`,
    // decremented on `Drop`; a count <= 1 means this handle is the last owner
    // (see `status`), which triggers drop registration on the client.
    pub(crate) count: Arc<AtomicU32>,
}
18
19impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
20    fn dtype(&self) -> DType {
21        self.dtype
22    }
23
24    fn shape(&self) -> Shape {
25        self.shape.clone()
26    }
27
28    fn rank(&self) -> usize {
29        self.shape.num_dims()
30    }
31}
32
33impl<C: RunnerClient> RouterTensor<C> {
34    /// Create a new router tensor.
35    pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
36        Self {
37            id,
38            shape,
39            dtype,
40            client,
41            count: Arc::new(AtomicU32::new(1)),
42        }
43    }
44
45    pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
46        self.client.clone().read_tensor_async(self.into_ir()).await
47    }
48
49    /// Get the ir for this tensor
50    pub fn into_ir(mut self) -> TensorIr {
51        let count = self.count.load(Ordering::Relaxed);
52        let status = self.status(count);
53        let mut shape_out = Shape::from(Vec::<usize>::new());
54        core::mem::swap(&mut self.shape, &mut shape_out);
55
56        if let TensorStatus::ReadWrite = status {
57            // Avoids an unwanted drop on the same thread.
58            //
59            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor
60            // was consumed with a `ReadWrite` status.
61            self.count.fetch_add(1, Ordering::Relaxed);
62        }
63
64        TensorIr {
65            status,
66            shape: shape_out,
67            id: self.id,
68            dtype: self.dtype,
69        }
70    }
71
72    pub(crate) fn to_ir_out(&self) -> TensorIr {
73        TensorIr {
74            status: TensorStatus::NotInit,
75            shape: self.shape.clone(),
76            id: self.id,
77            dtype: self.dtype,
78        }
79    }
80
81    pub(crate) fn status(&self, count: u32) -> TensorStatus {
82        if count <= 1 {
83            TensorStatus::ReadWrite
84        } else {
85            TensorStatus::ReadOnly
86        }
87    }
88}
89
90impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
91    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
92        f.write_str(
93            format!(
94                "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
95                self.id,
96                self.shape,
97                self.dtype,
98                self.client.device().clone(),
99            )
100            .as_str(),
101        )
102    }
103}
104
105impl<C: RunnerClient> Clone for RouterTensor<C> {
106    fn clone(&self) -> Self {
107        self.count.fetch_add(1, Ordering::Relaxed);
108
109        Self {
110            id: self.id,
111            shape: self.shape.clone(),
112            client: self.client.clone(),
113            dtype: self.dtype,
114            count: self.count.clone(),
115        }
116    }
117}
118
119impl<C: RunnerClient> Drop for RouterTensor<C> {
120    fn drop(&mut self) {
121        let count = self.count.fetch_sub(1, Ordering::Relaxed);
122
123        match self.status(count) {
124            TensorStatus::ReadWrite => {
125                let id = self.id;
126                let mut shape = Shape::from(Vec::<usize>::new());
127                core::mem::swap(&mut shape, &mut self.shape);
128
129                let ir = TensorIr {
130                    id,
131                    shape,
132                    status: TensorStatus::ReadWrite,
133                    dtype: self.dtype,
134                };
135                self.client.register_op(burn_ir::OperationIr::Drop(ir));
136            }
137            TensorStatus::ReadOnly => {}
138            TensorStatus::NotInit => {}
139        }
140    }
141}