// Copyright 2019-2026 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT
// Source path: forest/utils/cache/lru.rs

4use std::{
5    borrow::{Borrow, Cow},
6    fmt::Debug,
7    hash::Hash,
8    num::NonZeroUsize,
9    sync::{
10        Arc,
11        atomic::{AtomicUsize, Ordering},
12    },
13};
14
15use get_size2::GetSize;
16use hashlink::LruCache;
17use parking_lot::RwLock;
18use prometheus_client::{
19    collector::Collector,
20    encoding::{DescriptorEncoder, EncodeMetric},
21    metrics::gauge::Gauge,
22    registry::Unit,
23};
24
25use crate::utils::ShallowClone;
26
27pub trait KeyConstraints:
28    GetSize + Debug + Send + Sync + Hash + PartialEq + Eq + Clone + 'static
29{
30}
31
32impl<T> KeyConstraints for T where
33    T: GetSize + Debug + Send + Sync + Hash + PartialEq + Eq + Clone + 'static
34{
35}
36
37pub trait LruValueConstraints: GetSize + Debug + Send + Sync + Clone + 'static {}
38
39impl<T> LruValueConstraints for T where T: GetSize + Debug + Send + Sync + Clone + 'static {}
40
41#[derive(Debug)]
42pub struct SizeTrackingLruCache<K, V>
43where
44    K: KeyConstraints,
45    V: LruValueConstraints,
46{
47    cache_id: usize,
48    cache_name: Cow<'static, str>,
49    cache: Arc<RwLock<LruCache<K, V>>>,
50}
51
52impl<K, V> ShallowClone for SizeTrackingLruCache<K, V>
53where
54    K: KeyConstraints,
55    V: LruValueConstraints,
56{
57    fn shallow_clone(&self) -> Self {
58        Self {
59            cache_id: self.cache_id,
60            cache_name: self.cache_name.clone(),
61            cache: self.cache.shallow_clone(),
62        }
63    }
64}
65
66impl<K, V> SizeTrackingLruCache<K, V>
67where
68    K: KeyConstraints,
69    V: LruValueConstraints,
70{
71    fn register_metrics(&self) {
72        crate::metrics::register_collector(Box::new(self.shallow_clone()));
73    }
74
75    fn new_inner(cache_name: Cow<'static, str>, capacity: Option<NonZeroUsize>) -> Self {
76        static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
77
78        Self {
79            cache_id: ID_GENERATOR.fetch_add(1, Ordering::Relaxed),
80            cache_name,
81            #[allow(clippy::disallowed_methods)]
82            cache: Arc::new(RwLock::new(
83                capacity
84                    .map(From::from)
85                    .map(LruCache::new)
86                    // For constructing lru cache that is bounded by memory usage instead of length
87                    .unwrap_or_else(LruCache::new_unbounded),
88            )),
89        }
90    }
91
92    pub fn new_without_metrics_registry(
93        cache_name: Cow<'static, str>,
94        capacity: NonZeroUsize,
95    ) -> Self {
96        Self::new_inner(cache_name, Some(capacity))
97    }
98
99    pub fn new_with_metrics(cache_name: Cow<'static, str>, capacity: NonZeroUsize) -> Self {
100        let c = Self::new_without_metrics_registry(cache_name, capacity);
101        c.register_metrics();
102        c
103    }
104
105    pub fn unbounded_without_metrics_registry(cache_name: Cow<'static, str>) -> Self {
106        Self::new_inner(cache_name, None)
107    }
108
109    pub fn unbounded_with_metrics(cache_name: Cow<'static, str>) -> Self {
110        let c = Self::unbounded_without_metrics_registry(cache_name);
111        c.register_metrics();
112        c
113    }
114
115    pub fn cache(&self) -> &Arc<RwLock<LruCache<K, V>>> {
116        &self.cache
117    }
118
119    pub fn remove<Q>(&self, k: &Q) -> Option<V>
120    where
121        K: Borrow<Q>,
122        Q: Hash + Eq + ?Sized,
123    {
124        self.cache.write().remove(k)
125    }
126
127    pub fn push(&self, k: K, v: V) -> Option<V> {
128        self.cache.write().insert(k, v)
129    }
130
131    pub fn get_map<Q, T>(&self, k: &Q, mapper: impl Fn(&V) -> T) -> Option<T>
132    where
133        K: Borrow<Q>,
134        Q: Hash + Eq + ?Sized,
135    {
136        self.cache.write().get(k).map(mapper)
137    }
138
139    pub fn get_cloned<Q>(&self, k: &Q) -> Option<V>
140    where
141        K: Borrow<Q>,
142        Q: Hash + Eq + ?Sized,
143    {
144        self.get_map(k, Clone::clone)
145    }
146
147    pub fn peek_cloned<Q>(&self, k: &Q) -> Option<V>
148    where
149        K: Borrow<Q>,
150        Q: Hash + Eq + ?Sized,
151    {
152        self.cache.read().peek(k).cloned()
153    }
154
155    pub fn pop_lru(&self) -> Option<(K, V)> {
156        self.cache.write().remove_lru()
157    }
158
159    pub fn len(&self) -> usize {
160        self.cache.read().len()
161    }
162
163    pub fn cap(&self) -> usize {
164        self.cache.read().capacity()
165    }
166
167    pub fn clear(&self) {
168        self.cache.write().clear()
169    }
170
171    pub(crate) fn size_in_bytes(&self) -> usize {
172        let mut size = 0_usize;
173        for (k, v) in self.cache.read().iter() {
174            size = size
175                .saturating_add(k.get_size())
176                .saturating_add(v.get_size());
177        }
178        size
179    }
180
181    #[cfg(test)]
182    pub(crate) fn new_mocked() -> Self {
183        Self::new_inner(Cow::Borrowed("mocked_cache"), NonZeroUsize::new(1))
184    }
185}
186
187impl<K, V> Collector for SizeTrackingLruCache<K, V>
188where
189    K: KeyConstraints,
190    V: LruValueConstraints,
191{
192    fn encode(&self, mut encoder: DescriptorEncoder) -> Result<(), std::fmt::Error> {
193        {
194            let size_in_bytes = {
195                let g: Gauge = Default::default();
196                g.set(self.size_in_bytes() as _);
197                g
198            };
199            let size_metric_name = format!("cache_{}_{}_size", self.cache_name, self.cache_id);
200            let size_metric_help = format!(
201                "Size of LruCache {}_{} in bytes",
202                self.cache_name, self.cache_id
203            );
204            let size_metric_encoder = encoder.encode_descriptor(
205                &size_metric_name,
206                &size_metric_help,
207                Some(&Unit::Bytes),
208                size_in_bytes.metric_type(),
209            )?;
210            size_in_bytes.encode(size_metric_encoder)?;
211        }
212        {
213            let len_metric_name = format!("{}_{}_len", self.cache_name, self.cache_id);
214            let len_metric_help =
215                format!("Length of LruCache {}_{}", self.cache_name, self.cache_id);
216            let len: Gauge = Default::default();
217            len.set(self.len() as _);
218            let len_metric_encoder = encoder.encode_descriptor(
219                &len_metric_name,
220                &len_metric_help,
221                None,
222                len.metric_type(),
223            )?;
224            len.encode(len_metric_encoder)?;
225        }
226        {
227            let cap_metric_name = format!("{}_{}_cap", self.cache_name, self.cache_id);
228            let cap_metric_help =
229                format!("Capacity of LruCache {}_{}", self.cache_name, self.cache_id);
230            let cap: Gauge = Default::default();
231            cap.set(self.cap() as _);
232            let cap_metric_encoder = encoder.encode_descriptor(
233                &cap_metric_name,
234                &cap_metric_help,
235                None,
236                cap.metric_type(),
237            )?;
238            cap.encode(cap_metric_encoder)?;
239        }
240
241        Ok(())
242    }
243}