1use hashbrown::HashMap;
2
3use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
4
5#[derive(Default)]
8pub struct HandleContainer<H> {
9 handles: HashMap<TensorId, Handle<H>>,
10 counter: u64,
11}
12
13impl<H: Clone> HandleContainer<H> {
14 pub fn fork(&self) -> Self {
16 let mut handles = HashMap::with_capacity(self.handles.len());
17
18 for (id, handle) in self.handles.iter() {
19 handles.insert(*id, handle.clone());
20 }
21
22 Self {
23 handles,
24 counter: self.counter,
25 }
26 }
27}
28
29impl<H> core::fmt::Debug for HandleContainer<H> {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 f.debug_struct("HandleContainer")
32 .field("handles", &self.handles.keys()) .field("counter", &self.counter)
34 .finish()
35 }
36}
37
/// A backend tensor handle wrapped with its creation state.
#[derive(Clone, Debug)]
pub enum Handle<H> {
    /// Placeholder: no backend handle has been created for this tensor yet.
    NotInit,
    /// A backend handle that has been created.
    Existing(H),
}
46
47impl<H: Clone> HandleContainer<H> {
48 pub fn new() -> Self {
50 Self {
51 handles: HashMap::new(),
52 counter: 0,
53 }
54 }
55
56 pub fn register_handle(&mut self, id: TensorId, handle: H) {
58 self.handles.insert(id, Handle::Existing(handle));
59 }
60
61 pub fn has_handle(&mut self, id: &TensorId) -> bool {
63 self.handles.contains_key(id)
64 }
65
66 pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
68 self.handles
69 .get(id)
70 .filter(|h| !matches!(h, Handle::NotInit))
71 .map(|h| match h {
72 Handle::Existing(handle) => handle,
73 Handle::NotInit => unreachable!(),
74 })
75 }
76
77 pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
85 let (id, handle) = self
86 .handles
87 .remove_entry(id)
88 .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
89
90 match handle {
91 Handle::Existing(handle) => match status {
92 TensorStatus::ReadOnly => {
93 self.handles.insert(id, Handle::Existing(handle.clone()));
94 handle
95 }
96 TensorStatus::ReadWrite => handle,
97 TensorStatus::NotInit => panic!(
98 "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
99 ),
100 },
101 Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
102 }
103 }
104
105 pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
107 TensorHandle {
108 handle: self.get_handle(&tensor.id, &tensor.status),
109 shape: tensor.shape.clone(),
110 }
111 }
112
113 pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
116 where
117 B: BackendIr<Handle = H>,
118 {
119 B::float_tensor(self.get_tensor_handle(tensor))
120 }
121
122 pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
125 where
126 B: BackendIr<Handle = H>,
127 {
128 B::int_tensor(self.get_tensor_handle(tensor))
129 }
130
131 pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
134 where
135 B: BackendIr<Handle = H>,
136 {
137 B::bool_tensor(self.get_tensor_handle(tensor))
138 }
139
140 pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
143 where
144 B: BackendIr<Handle = H>,
145 {
146 B::quantized_tensor(self.get_tensor_handle(tensor))
147 }
148
149 pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
151 where
152 B: BackendIr<Handle = H>,
153 {
154 let handle = B::float_tensor_handle(tensor);
155 self.handles.insert(*id, Handle::Existing(handle));
156 }
157
158 pub fn register_quantized_tensor<B>(
160 &mut self,
161 id: &TensorId,
162 tensor: B::QuantizedTensorPrimitive,
163 ) where
164 B: BackendIr<Handle = H>,
165 {
166 let handle = B::quantized_tensor_handle(tensor);
167 self.handles.insert(*id, Handle::Existing(handle));
168 }
169
170 pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
172 where
173 B: BackendIr<Handle = H>,
174 {
175 let handle = B::int_tensor_handle(tensor);
176 self.handles.insert(*id, Handle::Existing(handle));
177 }
178
179 pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
181 where
182 B: BackendIr<Handle = H>,
183 {
184 let handle = B::bool_tensor_handle(tensor);
185 self.handles.insert(*id, Handle::Existing(handle));
186 }
187
188 pub fn create_tensor_uninit(&mut self) -> TensorId {
190 let id = TensorId::new(self.counter);
191 self.counter += 1;
192 self.handles.insert(id, Handle::NotInit);
193 id
194 }
195
196 pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
198 self.handles.remove(&id)
199 }
200
201 pub fn free(&mut self, tensor: &TensorIr) {
203 match tensor.status {
204 TensorStatus::ReadOnly => (),
205 TensorStatus::NotInit => (),
206 TensorStatus::ReadWrite => {
207 self.handles.remove(&tensor.id);
208 }
209 };
210 }
211
212 pub fn num_handles(&self) -> usize {
214 self.handles.len()
215 }
216}