burn_tensor/repr/
handle.rs1use crate::{
2 repr::{
3 backend::ReprBackend,
4 tensor::{TensorDescription, TensorId, TensorStatus},
5 },
6 Shape,
7};
8use alloc::vec::Vec;
9use hashbrown::HashMap;
10
11#[cfg(target_has_atomic = "ptr")]
12use alloc::sync::Arc;
13
14#[cfg(not(target_has_atomic = "ptr"))]
15use portable_atomic_util::Arc;
16
17use super::TensorHandle;
18
19#[derive(Default)]
22pub struct HandleContainer<H> {
23 handles: HashMap<TensorId, Handle<H>>,
24 counter: u64,
25 pub handles_orphan: Vec<TensorId>,
27}
28
29impl<H> core::fmt::Debug for HandleContainer<H> {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 f.debug_struct("HandleContainer")
32 .field("handles", &self.handles.keys()) .field("counter", &self.counter)
34 .field("handles_orphan", &self.handles_orphan)
35 .finish()
36 }
37}
38
/// Wrapper around a tensor handle of type `H` that records whether it has been created.
pub enum Handle<H> {
    /// The tensor id is reserved but no handle has been stored for it yet.
    NotInit,
    /// A handle has been stored for the tensor.
    Existing(H),
}
46
47impl<H: Clone> HandleContainer<H> {
48 pub fn new() -> Self {
50 Self {
51 handles: HashMap::new(),
52 handles_orphan: Vec::new(),
53 counter: 0,
54 }
55 }
56
57 pub fn register_handle(&mut self, id: TensorId, handle: H) {
59 self.handles.insert(id, Handle::Existing(handle));
60 }
61
62 pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
70 let (id, handle) = self
71 .handles
72 .remove_entry(id)
73 .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", id));
74
75 match handle {
76 Handle::Existing(handle) => match status {
77 TensorStatus::ReadOnly => {
78 self.handles.insert(id, Handle::Existing(handle.clone()));
79 handle
80 }
81 TensorStatus::ReadWrite => handle,
82 TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
83 },
84 Handle::NotInit => panic!("Cannot get uninitialized handle."),
85 }
86 }
87
88 pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
90 TensorHandle {
91 handle: self.get_handle(&tensor.id, &tensor.status),
92 shape: Shape::from(&tensor.shape),
93 }
94 }
95
96 pub fn get_float_tensor<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
99 where
100 B: ReprBackend<Handle = H>,
101 {
102 B::float_tensor(self.get_tensor_handle(tensor))
103 }
104
105 pub fn get_int_tensor<B>(&mut self, tensor: &TensorDescription) -> B::IntTensorPrimitive
108 where
109 B: ReprBackend<Handle = H>,
110 {
111 B::int_tensor(self.get_tensor_handle(tensor))
112 }
113
114 pub fn get_bool_tensor<B>(&mut self, tensor: &TensorDescription) -> B::BoolTensorPrimitive
117 where
118 B: ReprBackend<Handle = H>,
119 {
120 B::bool_tensor(self.get_tensor_handle(tensor))
121 }
122
123 pub fn get_quantized_tensor<B>(
126 &mut self,
127 tensor: &TensorDescription,
128 ) -> B::QuantizedTensorPrimitive
129 where
130 B: ReprBackend<Handle = H>,
131 {
132 B::quantized_tensor(self.get_tensor_handle(tensor))
133 }
134
135 pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
137 where
138 B: ReprBackend<Handle = H>,
139 {
140 let handle = B::float_tensor_handle(tensor);
141 self.handles.insert(*id, Handle::Existing(handle));
142 }
143
144 pub fn register_quantized_tensor<B>(
146 &mut self,
147 id: &TensorId,
148 tensor: B::QuantizedTensorPrimitive,
149 ) where
150 B: ReprBackend<Handle = H>,
151 {
152 let handle = B::quantized_tensor_handle(tensor);
153 self.handles.insert(*id, Handle::Existing(handle));
154 }
155
156 pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
158 where
159 B: ReprBackend<Handle = H>,
160 {
161 let handle = B::int_tensor_handle(tensor);
162 self.handles.insert(*id, Handle::Existing(handle));
163 }
164
165 pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
167 where
168 B: ReprBackend<Handle = H>,
169 {
170 let handle = B::bool_tensor_handle(tensor);
171 self.handles.insert(*id, Handle::Existing(handle));
172 }
173
174 pub fn create_tensor_uninit(&mut self) -> Arc<TensorId> {
176 let id = TensorId::new(self.counter);
177 self.counter += 1;
178 self.handles.insert(id, Handle::NotInit);
179
180 Arc::new(id)
181 }
182
183 pub fn free(&mut self, tensor: &TensorDescription) {
185 match tensor.status {
186 TensorStatus::ReadOnly => (),
187 TensorStatus::NotInit => (),
188 TensorStatus::ReadWrite => {
189 self.handles.remove(&tensor.id);
190 }
191 }
192 }
193
194 pub fn free_orphans(&mut self, remaining: &[&TensorId]) {
196 let mut handles_orphan = Vec::new();
197
198 for id in self.handles_orphan.drain(..) {
200 if remaining.contains(&&id) {
201 handles_orphan.push(id);
202 } else {
203 self.handles.remove(&id);
204 }
205 }
206
207 self.handles_orphan = handles_orphan;
208 }
209}