1use hashbrown::HashMap;
2
3use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
4
5pub struct HandleContainer<H> {
8 handles: HashMap<TensorId, Handle<H>>,
9 counter: u64,
10}
11
12impl<H> Default for HandleContainer<H> {
14 fn default() -> Self {
15 Self {
16 handles: HashMap::new(),
17 counter: 0,
18 }
19 }
20}
21
22impl<H: Clone> HandleContainer<H> {
23 pub fn fork(&self) -> Self {
25 let mut handles = HashMap::with_capacity(self.handles.len());
26
27 for (id, handle) in self.handles.iter() {
28 handles.insert(*id, handle.clone());
29 }
30
31 Self {
32 handles,
33 counter: self.counter,
34 }
35 }
36}
37
38impl<H> core::fmt::Debug for HandleContainer<H> {
39 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40 f.debug_struct("HandleContainer")
41 .field("handles", &self.handles.keys()) .field("counter", &self.counter)
43 .finish()
44 }
45}
46
/// State of one entry stored in a `HandleContainer`.
#[derive(Clone, Debug)]
pub enum Handle<H> {
    /// An entry with no backend handle attached; the container refuses to hand it out.
    NotInit,
    /// An entry holding a usable backend handle.
    Existing(H),
}
55
56impl<H: Clone> HandleContainer<H> {
57 pub fn new() -> Self {
59 Self {
60 handles: HashMap::new(),
61 counter: 0,
62 }
63 }
64
65 pub fn register_handle(&mut self, id: TensorId, handle: H) {
67 self.handles.insert(id, Handle::Existing(handle));
68 }
69
70 pub fn has_handle(&self, id: &TensorId) -> bool {
72 self.handles.contains_key(id)
73 }
74
75 pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
77 self.handles
78 .get(id)
79 .filter(|h| !matches!(h, Handle::NotInit))
80 .map(|h| match h {
81 Handle::Existing(handle) => handle,
82 Handle::NotInit => unreachable!(),
83 })
84 }
85
86 pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
94 let (id, handle) = self
95 .handles
96 .remove_entry(id)
97 .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
98
99 match handle {
100 Handle::Existing(handle) => match status {
101 TensorStatus::ReadOnly => {
102 self.handles.insert(id, Handle::Existing(handle.clone()));
103 handle
104 }
105 TensorStatus::ReadWrite => handle,
106 TensorStatus::NotInit => panic!(
107 "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
108 ),
109 },
110 Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
111 }
112 }
113
114 pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
116 TensorHandle {
117 handle: self.get_handle(&tensor.id, &tensor.status),
118 shape: tensor.shape.clone(),
119 }
120 }
121
122 pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
125 where
126 B: BackendIr<Handle = H>,
127 {
128 B::float_tensor(self.get_tensor_handle(tensor))
129 }
130
131 pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
134 where
135 B: BackendIr<Handle = H>,
136 {
137 B::int_tensor(self.get_tensor_handle(tensor))
138 }
139
140 pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
143 where
144 B: BackendIr<Handle = H>,
145 {
146 B::bool_tensor(self.get_tensor_handle(tensor))
147 }
148
149 pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
152 where
153 B: BackendIr<Handle = H>,
154 {
155 B::quantized_tensor(self.get_tensor_handle(tensor))
156 }
157
158 pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
160 where
161 B: BackendIr<Handle = H>,
162 {
163 let handle = B::float_tensor_handle(tensor);
164 self.handles.insert(*id, Handle::Existing(handle));
165 }
166
167 pub fn register_quantized_tensor<B>(
169 &mut self,
170 id: &TensorId,
171 tensor: B::QuantizedTensorPrimitive,
172 ) where
173 B: BackendIr<Handle = H>,
174 {
175 let handle = B::quantized_tensor_handle(tensor);
176 self.handles.insert(*id, Handle::Existing(handle));
177 }
178
179 pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
181 where
182 B: BackendIr<Handle = H>,
183 {
184 let handle = B::int_tensor_handle(tensor);
185 self.handles.insert(*id, Handle::Existing(handle));
186 }
187
188 pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
190 where
191 B: BackendIr<Handle = H>,
192 {
193 let handle = B::bool_tensor_handle(tensor);
194 self.handles.insert(*id, Handle::Existing(handle));
195 }
196
197 pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
199 self.handles.remove(&id)
200 }
201
202 pub fn free(&mut self, tensor: &TensorIr) {
204 match tensor.status {
205 TensorStatus::ReadOnly => (),
206 TensorStatus::NotInit => (),
207 TensorStatus::ReadWrite => {
208 self.handles.remove(&tensor.id);
209 }
210 };
211 }
212
213 pub fn num_handles(&self) -> usize {
215 self.handles.len()
216 }
217
218 pub fn handle_ids(&self) -> impl Iterator<Item = &'_ TensorId> {
224 self.handles.keys()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorId;

    /// Shorthand for building a `TensorId` from a raw value.
    fn tid(value: u64) -> TensorId {
        TensorId::new(value)
    }

    /// Builds a `String`-handled container pre-populated with `entries`.
    fn container_with(entries: &[(u64, &str)]) -> HandleContainer<String> {
        let mut container = HandleContainer::new();
        for &(raw_id, value) in entries {
            container.register_handle(tid(raw_id), value.to_string());
        }
        container
    }

    #[test]
    fn fork_clones_existing_handles() {
        let fork = container_with(&[(1, "input_a"), (2, "input_b")]).fork();

        assert_eq!(fork.num_handles(), 2);
        for raw_id in [1, 2] {
            assert!(fork.get_handle_ref(&tid(raw_id)).is_some());
        }
    }

    #[test]
    fn fork_is_isolated_from_original() {
        let container = container_with(&[(1, "input_a")]);
        let mut fork = container.fork();

        fork.register_handle(tid(100), "output_x".to_string());
        fork.register_handle(tid(101), "output_y".to_string());

        assert_eq!(fork.num_handles(), 3);
        assert_eq!(container.num_handles(), 1);
        for raw_id in [100, 101] {
            assert!(fork.get_handle_ref(&tid(raw_id)).is_some());
            assert!(container.get_handle_ref(&tid(raw_id)).is_none());
        }
    }

    #[test]
    fn fork_mutations_do_not_affect_original() {
        let container = container_with(&[(1, "original_value")]);
        let mut fork = container.fork();

        fork.register_handle(tid(1), "modified_in_fork".to_string());

        assert_eq!(
            container.get_handle_ref(&tid(1)).map(String::as_str),
            Some("original_value")
        );
        assert_eq!(
            fork.get_handle_ref(&tid(1)).map(String::as_str),
            Some("modified_in_fork")
        );
    }

    #[test]
    fn double_fork_is_fully_isolated() {
        let container = container_with(&[(1, "input")]);
        let first = container.fork();
        let mut second = first.fork();

        second.register_handle(tid(200), "deep_output".to_string());

        assert!(second.get_handle_ref(&tid(200)).is_some());
        assert!(first.get_handle_ref(&tid(200)).is_none());
        assert!(container.get_handle_ref(&tid(200)).is_none());
    }
}