// hpt_allocator/storage/mod.rs

pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use crate::ptr::SafePtr;
8
9pub trait Storage {
10    fn increment_ref(&mut self, ptr: SafePtr);
11    fn decrement_ref(&mut self, ptr: SafePtr) -> bool;
12}
13
14#[derive(Debug)]
15pub struct CommonStorage {
16    pub(crate) storage: HashMap<SafePtr, usize>,
17}
18
19impl CommonStorage {
20    pub fn new() -> Self {
21        CommonStorage {
22            storage: HashMap::new(),
23        }
24    }
25}
26
27impl Storage for CommonStorage {
28    fn increment_ref(&mut self, ptr: SafePtr) {
29        if let Some(cnt) = self.storage.get_mut(&ptr) {
30            *cnt = match cnt.checked_add(1) {
31                Some(cnt) => cnt,
32                None => {
33                    panic!(
34                        "Reference count overflow for ptr {:p} in cpu storage",
35                        ptr.ptr
36                    );
37                }
38            };
39        } else {
40            self.storage.insert(ptr, 1);
41        }
42    }
43
44    fn decrement_ref(&mut self, ptr: SafePtr) -> bool {
45        if let Some(cnt) = self.storage.get_mut(&ptr) {
46            *cnt = cnt.checked_sub(1).expect("Reference count underflow");
47            if *cnt == 0 {
48                self.storage.remove(&ptr);
49                true
50            } else {
51                false
52            }
53        } else {
54            false
55        }
56    }
57}
58
59/// # Clone Storage
60///
61/// increment the reference count of the ptr in the storage
62pub fn clone_storage(ptr: *mut u8, device_id: usize, map: &mut HashMap<usize, CommonStorage>) {
63    if let Some(storage) = map.get_mut(&device_id) {
64        storage.increment_ref(SafePtr { ptr });
65    } else {
66        map.insert(device_id, CommonStorage::new());
67        map.get_mut(&device_id)
68            .unwrap()
69            .increment_ref(SafePtr { ptr });
70    }
71}