Skip to main content

intern_mint/
pool.rs

1use std::{ops::Deref, sync::LazyLock};
2
3use hashbrown::HashTable;
4use parking_lot::{Mutex, MutexGuard};
5use triomphe::Arc;
6
7type LockedShard = HashTable<Arc<[u8]>>;
8type Shard = Mutex<LockedShard>;
9
/// Size snapshot of the interning pool.
///
/// Produced by [`get_memory_usage`], which adds up the values of every shard.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MemoryUsage {
    /// Number of stored entries, summed over all shards.
    pub len: usize,
    /// Sum of the shard tables' capacities.
    pub capacity: usize,
}
15
16pub(crate) struct ShardedSet {
17    pub(crate) shift: usize,
18    pub(crate) hash_builder: ahash::RandomState,
19    pub(crate) shards: Box<[Shard]>,
20}
21
22impl ShardedSet {
23    fn get_hash_and_shard(&self, value: &[u8]) -> (u64, MutexGuard<'_, LockedShard>) {
24        // hash before locking
25        let hash = self.hash_builder.hash_one(value);
26        // copied from https://github.com/xacrimon/dashmap/blob/366ce7e7872866a06de66eb95002fa6cf2c117a7/src/lib.rs#L419
27        let idx = ((hash << 7) >> self.shift) as usize;
28        let shard = self.shards[idx].lock();
29        (hash, shard)
30    }
31
32    fn hasher(&self, value: &Arc<[u8]>) -> u64 {
33        self.hash_builder.hash_one(value.deref())
34    }
35
36    pub(crate) fn get_from_existing_ref(&self, value: &[u8]) -> Option<Arc<[u8]>> {
37        let (hash, shard) = self.get_hash_and_shard(value);
38        shard
39            .find(hash, |o| std::ptr::addr_eq(o.as_ptr(), value.as_ptr()))
40            .cloned()
41    }
42
43    pub(crate) fn get_or_insert(&self, value: &[u8]) -> Arc<[u8]> {
44        let (hash, mut shard) = self.get_hash_and_shard(value);
45
46        shard
47            .entry(hash, |o| o.deref() == value, |o| self.hasher(o))
48            .or_insert_with(|| Arc::from(value))
49            .get()
50            .clone()
51    }
52
53    /// Only try to remove values from the pool when the reference count is two
54    /// one for the given [value] and another for the reference in the pool
55    pub(crate) fn remove_if_needed(&self, value: &Arc<[u8]>) {
56        // one count for `value` and one for the entry in our pool
57        const MINIMUM_STRONG_COUNT: usize = 2;
58
59        if Arc::strong_count(value) > MINIMUM_STRONG_COUNT {
60            return;
61        }
62
63        let (hash, mut shard) = self.get_hash_and_shard(value);
64
65        let Ok(entry) = shard.find_entry(hash, |o| std::ptr::addr_eq(o.as_ptr(), value.as_ptr()))
66        else {
67            return;
68        };
69
70        // check again in case the value has been cloned
71        if Arc::strong_count(entry.get()) > MINIMUM_STRONG_COUNT {
72            return;
73        }
74
75        entry.remove();
76    }
77
78    pub(crate) fn is_empty(&self) -> bool {
79        self.len() == 0
80    }
81
82    pub(crate) fn len(&self) -> usize {
83        self.shards.iter().map(|o| o.lock().len()).sum()
84    }
85
86    pub(crate) fn capacity(&self) -> usize {
87        self.shards.iter().map(|o| o.lock().capacity()).sum()
88    }
89
90    pub(crate) fn get_memory_usage(&self) -> MemoryUsage {
91        self.shards
92            .iter()
93            .map(|o| {
94                let o = o.lock();
95                MemoryUsage {
96                    len: o.len(),
97                    capacity: o.capacity(),
98                }
99            })
100            .reduce(|acc, o| MemoryUsage {
101                len: acc.len + o.len,
102                capacity: acc.capacity + o.capacity,
103            })
104            .unwrap_or_default()
105    }
106
107    pub(crate) fn shrink_to_fit(&self) {
108        for shard in self.shards.iter() {
109            shard.lock().shrink_to_fit(|o| self.hasher(o));
110        }
111    }
112}
113
114impl Default for ShardedSet {
115    fn default() -> Self {
116        // copied from https://github.com/xacrimon/dashmap/blob/366ce7e7872866a06de66eb95002fa6cf2c117a7/src/lib.rs#L63
117        static DEFAULT_SHARDS_COUNT: LazyLock<usize> = LazyLock::new(|| {
118            (std::thread::available_parallelism().map_or(1, usize::from) * 4).next_power_of_two()
119        });
120
121        // copied from https://github.com/xacrimon/dashmap/blob/366ce7e7872866a06de66eb95002fa6cf2c117a7/src/lib.rs#L269
122        let shift =
123            (std::mem::size_of::<usize>() * 8) - DEFAULT_SHARDS_COUNT.trailing_zeros() as usize;
124
125        Self {
126            shift,
127            hash_builder: Default::default(),
128            shards: (0..*DEFAULT_SHARDS_COUNT)
129                .map(|_| Default::default())
130                .collect(),
131        }
132    }
133}
134
135pub(crate) static POOL: LazyLock<ShardedSet> = LazyLock::new(Default::default);
136
137pub fn is_empty() -> bool {
138    POOL.is_empty()
139}
140
141pub fn len() -> usize {
142    POOL.len()
143}
144
145pub fn capacity() -> usize {
146    POOL.capacity()
147}
148
149pub fn get_memory_usage() -> MemoryUsage {
150    POOL.get_memory_usage()
151}
152
153pub fn shrink_to_fit() {
154    POOL.shrink_to_fit();
155}