hpt_allocator/storage/
mod.rs1pub mod cpu;
2#[cfg(feature = "cuda")]
3pub mod cuda;
4
5use std::collections::HashMap;
6
7use crate::ptr::SafePtr;
8
9pub trait Storage {
10 fn increment_ref(&mut self, ptr: SafePtr);
11 fn decrement_ref(&mut self, ptr: SafePtr) -> bool;
12}
13
14#[derive(Debug)]
15pub struct CommonStorage {
16 pub(crate) storage: HashMap<SafePtr, usize>,
17}
18
19impl CommonStorage {
20 pub fn new() -> Self {
21 CommonStorage {
22 storage: HashMap::new(),
23 }
24 }
25}
26
27impl Storage for CommonStorage {
28 fn increment_ref(&mut self, ptr: SafePtr) {
29 if let Some(cnt) = self.storage.get_mut(&ptr) {
30 *cnt = match cnt.checked_add(1) {
31 Some(cnt) => cnt,
32 None => {
33 panic!(
34 "Reference count overflow for ptr {:p} in cpu storage",
35 ptr.ptr
36 );
37 }
38 };
39 } else {
40 self.storage.insert(ptr, 1);
41 }
42 }
43
44 fn decrement_ref(&mut self, ptr: SafePtr) -> bool {
45 if let Some(cnt) = self.storage.get_mut(&ptr) {
46 *cnt = cnt.checked_sub(1).expect("Reference count underflow");
47 if *cnt == 0 {
48 self.storage.remove(&ptr);
49 true
50 } else {
51 false
52 }
53 } else {
54 false
55 }
56 }
57}
58
59pub fn clone_storage(ptr: *mut u8, device_id: usize, map: &mut HashMap<usize, CommonStorage>) {
63 if let Some(storage) = map.get_mut(&device_id) {
64 storage.increment_ref(SafePtr { ptr });
65 } else {
66 map.insert(device_id, CommonStorage::new());
67 map.get_mut(&device_id)
68 .unwrap()
69 .increment_ref(SafePtr { ptr });
70 }
71}