1use indexmap::IndexMap;
2use kaspa_utils::mem_size::{MemMode, MemSizeEstimator};
3use parking_lot::RwLock;
4use rand::Rng;
5use std::{collections::hash_map::RandomState, hash::BuildHasher, sync::Arc};
6
7#[derive(Debug, Clone, Copy)]
8pub enum CachePolicy {
9 Empty,
11 Count(usize),
13 Tracked { max_size: usize, min_items: usize, mem_mode: MemMode },
17}
18
19#[derive(Clone)]
20struct CachePolicyInner {
21 tracked: bool,
23 max_size: usize,
26 min_items: usize,
28 mem_mode: MemMode,
30}
31
32impl From<CachePolicy> for CachePolicyInner {
33 fn from(policy: CachePolicy) -> Self {
34 match policy {
35 CachePolicy::Empty => CachePolicyInner { tracked: false, max_size: 0, min_items: 0, mem_mode: MemMode::Undefined },
36 CachePolicy::Count(max_size) => CachePolicyInner { tracked: false, max_size, min_items: 0, mem_mode: MemMode::Undefined },
37 CachePolicy::Tracked { max_size, min_items, mem_mode } => {
38 CachePolicyInner { tracked: true, max_size, min_items, mem_mode }
39 }
40 }
41 }
42}
43
44struct Inner<TKey, TData, S = RandomState>
45where
46 TKey: Clone + std::hash::Hash + Eq + Send + Sync,
47 TData: Clone + Send + Sync + MemSizeEstimator,
48{
49 map: IndexMap<TKey, TData, S>,
51 tracked_size: usize,
52}
53
54impl<TKey, TData, S> Inner<TKey, TData, S>
55where
56 TKey: Clone + std::hash::Hash + Eq + Send + Sync,
57 TData: Clone + Send + Sync + MemSizeEstimator,
58 S: BuildHasher + Default,
59{
60 fn tracked_evict(&mut self, policy: &CachePolicyInner) {
62 while self.tracked_size > policy.max_size && self.map.len() > policy.min_items {
64 if let Some((_, v)) = self.map.swap_remove_index(rand::thread_rng().gen_range(0..self.map.len())) {
65 self.tracked_size -= v.estimate_size(policy.mem_mode)
66 }
67 }
68 }
69
70 fn insert(&mut self, policy: &CachePolicyInner, key: TKey, data: TData) {
71 if policy.tracked {
72 let new_data_size = data.estimate_size(policy.mem_mode);
73 self.tracked_size += new_data_size;
74 if let Some(removed) = self.map.insert(key, data) {
75 self.tracked_size -= removed.estimate_size(policy.mem_mode);
76 }
77 self.tracked_evict(policy);
78 } else {
79 if self.map.len() == policy.max_size {
80 self.map.swap_remove_index(rand::thread_rng().gen_range(0..policy.max_size));
81 }
82 self.map.insert(key, data);
83 }
84 }
85
86 fn update_if_entry_exists<F>(&mut self, policy: &CachePolicyInner, key: TKey, op: F)
87 where
88 F: Fn(&mut TData),
89 {
90 if let Some(data) = self.map.get_mut(&key) {
91 if policy.tracked {
92 self.tracked_size -= data.estimate_size(policy.mem_mode);
93 op(data);
94 self.tracked_size += data.estimate_size(policy.mem_mode);
95 self.tracked_evict(policy);
96 } else {
97 op(data);
98 }
99 }
100 }
101
102 fn remove(&mut self, policy: &CachePolicyInner, key: &TKey) -> Option<TData> {
103 match self.map.swap_remove(key) {
104 Some(data) => {
105 if policy.tracked {
106 self.tracked_size -= data.estimate_size(policy.mem_mode);
107 }
108 Some(data)
109 }
110 None => None,
111 }
112 }
113}
114
115impl<TKey, TData, S> Inner<TKey, TData, S>
116where
117 TKey: Clone + std::hash::Hash + Eq + Send + Sync,
118 TData: Clone + Send + Sync + MemSizeEstimator,
119 S: BuildHasher + Default,
120{
121 pub fn new(prealloc_size: usize) -> Self {
122 Self { map: IndexMap::with_capacity_and_hasher(prealloc_size, S::default()), tracked_size: 0 }
123 }
124}
125
126#[derive(Clone)]
127pub struct Cache<TKey, TData, S = RandomState>
128where
129 TKey: Clone + std::hash::Hash + Eq + Send + Sync,
130 TData: Clone + Send + Sync + MemSizeEstimator,
131{
132 inner: Arc<RwLock<Inner<TKey, TData, S>>>,
133 policy: CachePolicyInner,
134}
135
136impl<TKey, TData, S> Cache<TKey, TData, S>
137where
138 TKey: Clone + std::hash::Hash + Eq + Send + Sync,
139 TData: Clone + Send + Sync + MemSizeEstimator,
140 S: BuildHasher + Default,
141{
142 pub fn new(policy: CachePolicy) -> Self {
143 let policy: CachePolicyInner = policy.into();
144 let prealloc_size = if policy.tracked { 0 } else { policy.max_size }; Self { inner: Arc::new(RwLock::new(Inner::new(prealloc_size))), policy }
146 }
147
148 pub fn get(&self, key: &TKey) -> Option<TData> {
149 self.inner.read().map.get(key).cloned()
150 }
151
152 pub fn contains_key(&self, key: &TKey) -> bool {
153 self.inner.read().map.contains_key(key)
154 }
155
156 pub fn insert(&self, key: TKey, data: TData) {
157 if self.policy.max_size == 0 {
158 return;
159 }
160
161 self.inner.write().insert(&self.policy, key, data);
162 }
163
164 pub fn insert_many(&self, iter: &mut impl Iterator<Item = (TKey, TData)>) {
165 if self.policy.max_size == 0 {
166 return;
167 }
168 let mut inner = self.inner.write();
169 for (key, data) in iter {
170 inner.insert(&self.policy, key, data);
171 }
172 }
173
174 pub fn update_if_entry_exists<F>(&self, key: TKey, op: F)
175 where
176 F: Fn(&mut TData),
177 {
178 if self.policy.max_size == 0 {
179 return;
180 }
181 self.inner.write().update_if_entry_exists(&self.policy, key, op);
182 }
183
184 pub fn remove(&self, key: &TKey) -> Option<TData> {
185 if self.policy.max_size == 0 {
186 return None;
187 }
188 self.inner.write().remove(&self.policy, key)
189 }
190
191 pub fn remove_many(&self, key_iter: &mut impl Iterator<Item = TKey>) {
192 if self.policy.max_size == 0 {
193 return;
194 }
195 let mut inner = self.inner.write();
196 for key in key_iter {
197 inner.remove(&self.policy, &key);
198 }
199 }
200
201 pub fn remove_all(&self) {
202 if self.policy.max_size == 0 {
203 return;
204 }
205 let mut inner = self.inner.write();
206 inner.map.clear();
207 if self.policy.tracked {
208 inner.tracked_size = 0;
209 }
210 }
211}