kaspa_database/
cache.rs

1use indexmap::IndexMap;
2use kaspa_utils::mem_size::{MemMode, MemSizeEstimator};
3use parking_lot::RwLock;
4use rand::Rng;
5use std::{collections::hash_map::RandomState, hash::BuildHasher, sync::Arc};
6
/// Policy controlling how a cache bounds its contents (by count, by tracked size, or not at all).
#[derive(Debug, Clone, Copy)]
pub enum CachePolicy {
    /// An empty cache (avoids acquiring locks etc so considered perf-free)
    Empty,
    /// The cache bounds the number of items it holds w/o tracking their inner size
    Count(usize),
    /// Items are tracked by size with a `max_size` limit overall. The cache will pass this limit
    /// if there are no more than `min_items` items in the cache. `mem_mode` determines whether
    /// items are tracked by bytes or by units
    Tracked { max_size: usize, min_items: usize, mem_mode: MemMode },
}
18
/// Internal, flattened representation of [`CachePolicy`] used by the cache implementation.
#[derive(Clone)]
struct CachePolicyInner {
    /// Indicates if this cache was set to be tracked.
    tracked: bool,
    /// The max size of this cache. Size units are bytes or a logical unit depending on `mem_mode`.
    /// The implementation of `MemSizeEstimator` is expected to support the provided mode.
    /// In non-tracked mode this is the maximum item count; a value of 0 disables the cache entirely.
    max_size: usize,
    /// Minimum number of items to keep in the cache even if passing tracked size limit.
    min_items: usize,
    /// Indicates whether tracking is in bytes mode, units mode or undefined
    mem_mode: MemMode,
}
31
32impl From<CachePolicy> for CachePolicyInner {
33    fn from(policy: CachePolicy) -> Self {
34        match policy {
35            CachePolicy::Empty => CachePolicyInner { tracked: false, max_size: 0, min_items: 0, mem_mode: MemMode::Undefined },
36            CachePolicy::Count(max_size) => CachePolicyInner { tracked: false, max_size, min_items: 0, mem_mode: MemMode::Undefined },
37            CachePolicy::Tracked { max_size, min_items, mem_mode } => {
38                CachePolicyInner { tracked: true, max_size, min_items, mem_mode }
39            }
40        }
41    }
42}
43
/// Inner state of the cache: the backing map plus the tracked-size accumulator.
struct Inner<TKey, TData, S = RandomState>
where
    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
    TData: Clone + Send + Sync + MemSizeEstimator,
{
    // We use IndexMap and not HashMap because it makes it cheaper to remove a random element when the cache is full.
    map: IndexMap<TKey, TData, S>,
    // Accumulated estimated size of all entries; only maintained in tracked mode (remains 0 otherwise).
    tracked_size: usize,
}
53
54impl<TKey, TData, S> Inner<TKey, TData, S>
55where
56    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
57    TData: Clone + Send + Sync + MemSizeEstimator,
58    S: BuildHasher + Default,
59{
60    /// Evicts items until meeting cache policy requirements (in tracked mode)
61    fn tracked_evict(&mut self, policy: &CachePolicyInner) {
62        // We allow passing tracked size limit as long as there are no more than min_items items
63        while self.tracked_size > policy.max_size && self.map.len() > policy.min_items {
64            if let Some((_, v)) = self.map.swap_remove_index(rand::thread_rng().gen_range(0..self.map.len())) {
65                self.tracked_size -= v.estimate_size(policy.mem_mode)
66            }
67        }
68    }
69
70    fn insert(&mut self, policy: &CachePolicyInner, key: TKey, data: TData) {
71        if policy.tracked {
72            let new_data_size = data.estimate_size(policy.mem_mode);
73            self.tracked_size += new_data_size;
74            if let Some(removed) = self.map.insert(key, data) {
75                self.tracked_size -= removed.estimate_size(policy.mem_mode);
76            }
77            self.tracked_evict(policy);
78        } else {
79            if self.map.len() == policy.max_size {
80                self.map.swap_remove_index(rand::thread_rng().gen_range(0..policy.max_size));
81            }
82            self.map.insert(key, data);
83        }
84    }
85
86    fn update_if_entry_exists<F>(&mut self, policy: &CachePolicyInner, key: TKey, op: F)
87    where
88        F: Fn(&mut TData),
89    {
90        if let Some(data) = self.map.get_mut(&key) {
91            if policy.tracked {
92                self.tracked_size -= data.estimate_size(policy.mem_mode);
93                op(data);
94                self.tracked_size += data.estimate_size(policy.mem_mode);
95                self.tracked_evict(policy);
96            } else {
97                op(data);
98            }
99        }
100    }
101
102    fn remove(&mut self, policy: &CachePolicyInner, key: &TKey) -> Option<TData> {
103        match self.map.swap_remove(key) {
104            Some(data) => {
105                if policy.tracked {
106                    self.tracked_size -= data.estimate_size(policy.mem_mode);
107                }
108                Some(data)
109            }
110            None => None,
111        }
112    }
113}
114
115impl<TKey, TData, S> Inner<TKey, TData, S>
116where
117    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
118    TData: Clone + Send + Sync + MemSizeEstimator,
119    S: BuildHasher + Default,
120{
121    pub fn new(prealloc_size: usize) -> Self {
122        Self { map: IndexMap::with_capacity_and_hasher(prealloc_size, S::default()), tracked_size: 0 }
123    }
124}
125
/// A thread-safe cache with random-eviction semantics, bounded per [`CachePolicy`].
/// Cloning is cheap: clones share the underlying storage via `Arc`.
#[derive(Clone)]
pub struct Cache<TKey, TData, S = RandomState>
where
    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
    TData: Clone + Send + Sync + MemSizeEstimator,
{
    inner: Arc<RwLock<Inner<TKey, TData, S>>>,
    // Immutable after construction; copied into each clone while `inner` stays shared.
    policy: CachePolicyInner,
}
135
136impl<TKey, TData, S> Cache<TKey, TData, S>
137where
138    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
139    TData: Clone + Send + Sync + MemSizeEstimator,
140    S: BuildHasher + Default,
141{
142    pub fn new(policy: CachePolicy) -> Self {
143        let policy: CachePolicyInner = policy.into();
144        let prealloc_size = if policy.tracked { 0 } else { policy.max_size }; // TODO: estimate prealloc also in tracked mode
145        Self { inner: Arc::new(RwLock::new(Inner::new(prealloc_size))), policy }
146    }
147
148    pub fn get(&self, key: &TKey) -> Option<TData> {
149        self.inner.read().map.get(key).cloned()
150    }
151
152    pub fn contains_key(&self, key: &TKey) -> bool {
153        self.inner.read().map.contains_key(key)
154    }
155
156    pub fn insert(&self, key: TKey, data: TData) {
157        if self.policy.max_size == 0 {
158            return;
159        }
160
161        self.inner.write().insert(&self.policy, key, data);
162    }
163
164    pub fn insert_many(&self, iter: &mut impl Iterator<Item = (TKey, TData)>) {
165        if self.policy.max_size == 0 {
166            return;
167        }
168        let mut inner = self.inner.write();
169        for (key, data) in iter {
170            inner.insert(&self.policy, key, data);
171        }
172    }
173
174    pub fn update_if_entry_exists<F>(&self, key: TKey, op: F)
175    where
176        F: Fn(&mut TData),
177    {
178        if self.policy.max_size == 0 {
179            return;
180        }
181        self.inner.write().update_if_entry_exists(&self.policy, key, op);
182    }
183
184    pub fn remove(&self, key: &TKey) -> Option<TData> {
185        if self.policy.max_size == 0 {
186            return None;
187        }
188        self.inner.write().remove(&self.policy, key)
189    }
190
191    pub fn remove_many(&self, key_iter: &mut impl Iterator<Item = TKey>) {
192        if self.policy.max_size == 0 {
193            return;
194        }
195        let mut inner = self.inner.write();
196        for key in key_iter {
197            inner.remove(&self.policy, &key);
198        }
199    }
200
201    pub fn remove_all(&self) {
202        if self.policy.max_size == 0 {
203            return;
204        }
205        let mut inner = self.inner.write();
206        inner.map.clear();
207        if self.policy.tracked {
208            inner.tracked_size = 0;
209        }
210    }
211}