// forest/state_manager/cache.rs
// Copyright 2019-2026 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT
use crate::blocks::TipsetKey;
use crate::state_manager::DEFAULT_TIPSET_CACHE_SIZE;
use crate::utils::ShallowClone;
use crate::utils::cache::{LruValueConstraints, SizeTrackingLruCache};
use ahash::{HashMap, HashMapExt as _};
use parking_lot::RwLock as SyncRwLock;
use std::future::Future;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
13
14struct TipsetStateCacheInner<V: LruValueConstraints> {
15    values: SizeTrackingLruCache<TipsetKey, V>,
16    pending: HashMap<TipsetKey, Arc<TokioMutex<()>>>,
17}
18
19impl<V: LruValueConstraints> TipsetStateCacheInner<V> {
20    pub fn with_size(cache_identifier: &str, cache_size: NonZeroUsize) -> Self {
21        Self {
22            values: SizeTrackingLruCache::new_with_metrics(
23                Self::cache_name(cache_identifier).into(),
24                cache_size,
25            ),
26            pending: HashMap::with_capacity(8),
27        }
28    }
29
30    fn cache_name(cache_identifier: &str) -> String {
31        format!("tipset_state_{cache_identifier}")
32    }
33}
34
35/// A generic cache that handles concurrent access and computation for tipset-related data.
36pub(crate) struct TipsetStateCache<V: LruValueConstraints> {
37    cache: Arc<SyncRwLock<TipsetStateCacheInner<V>>>,
38}
39
40enum CacheLookupStatus<V> {
41    Exist(V),
42    Empty(Arc<TokioMutex<()>>),
43}
44
45impl<V: LruValueConstraints> TipsetStateCache<V> {
46    pub fn new(cache_identifier: &str) -> Self {
47        Self::with_size(cache_identifier, DEFAULT_TIPSET_CACHE_SIZE)
48    }
49
50    pub fn with_size(cache_identifier: &str, cache_size: NonZeroUsize) -> Self {
51        Self {
52            cache: Arc::new(SyncRwLock::new(TipsetStateCacheInner::with_size(
53                cache_identifier,
54                cache_size,
55            ))),
56        }
57    }
58
59    fn get_or_insert<F1, F2, T>(&self, get_func: F1, or_insert_func: F2) -> T
60    where
61        F1: FnOnce(&TipsetStateCacheInner<V>) -> Option<T>,
62        F2: FnOnce(&mut TipsetStateCacheInner<V>) -> T,
63    {
64        if let Some(v) = get_func(&self.cache.read()) {
65            v
66        } else {
67            or_insert_func(&mut self.cache.write())
68        }
69    }
70
71    pub async fn get_or_else<F, Fut>(&self, key: &TipsetKey, compute: F) -> anyhow::Result<V>
72    where
73        F: FnOnce() -> Fut,
74        Fut: Future<Output = anyhow::Result<V>> + Send,
75        V: Send + Sync + 'static,
76    {
77        let status = self.get_or_insert(
78            |inner| inner.values.get_cloned(key).map(CacheLookupStatus::Exist),
79            |inner| {
80                let mutex = inner
81                    .pending
82                    .entry(key.clone())
83                    .or_insert_with(|| Arc::new(TokioMutex::new(())))
84                    .shallow_clone();
85                CacheLookupStatus::Empty(mutex)
86            },
87        );
88        match status {
89            CacheLookupStatus::Exist(x) => {
90                crate::metrics::LRU_CACHE_HIT
91                    .get_or_create(&crate::metrics::values::STATE_MANAGER_TIPSET)
92                    .inc();
93                Ok(x)
94            }
95            CacheLookupStatus::Empty(mtx) => {
96                let _guard = mtx.lock().await;
97                match self.get(key) {
98                    Some(v) => {
99                        // While locking someone else computed the pending task
100                        crate::metrics::LRU_CACHE_HIT
101                            .get_or_create(&crate::metrics::values::STATE_MANAGER_TIPSET)
102                            .inc();
103
104                        Ok(v)
105                    }
106                    None => {
107                        // Entry does not have state computed yet, compute value and fill the cache
108                        crate::metrics::LRU_CACHE_MISS
109                            .get_or_create(&crate::metrics::values::STATE_MANAGER_TIPSET)
110                            .inc();
111                        let value = compute().await?;
112                        // Write back to cache, release lock and return value
113                        self.insert(key.clone(), value.clone());
114                        Ok(value)
115                    }
116                }
117            }
118        }
119    }
120
121    pub fn get_map<T>(&self, key: &TipsetKey, mapper: impl Fn(&V) -> T) -> Option<T> {
122        self.cache.read().values.get_map(key, mapper)
123    }
124
125    pub fn get(&self, key: &TipsetKey) -> Option<V> {
126        self.get_map(key, Clone::clone)
127    }
128
129    pub fn insert(&self, key: TipsetKey, value: V) {
130        let mut cache = self.cache.write();
131        cache.pending.retain(|k, _| k != &key);
132        cache.values.push(key, value);
133    }
134
135    pub fn remove(&self, key: &TipsetKey) {
136        let mut cache = self.cache.write();
137        cache.pending.retain(|k, _| k != key);
138        cache.values.remove(key);
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blocks::TipsetKey;
    use cid::Cid;
    use fvm_ipld_encoding::DAG_CBOR;
    use multihash_derive::MultihashDigest;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::time::Duration;

    /// Builds a deterministic single-CID [`TipsetKey`] from an integer seed.
    fn create_test_tipset_key(i: u64) -> TipsetKey {
        let seed = i.to_le_bytes().to_vec();
        let digest = crate::utils::multihash::MultihashCode::Blake2b256.digest(&seed);
        TipsetKey::from(nunny::vec![Cid::new_v1(DAG_CBOR, digest)])
    }

    #[tokio::test]
    async fn test_tipset_cache_basic_functionality() {
        let cache: TipsetStateCache<String> = TipsetStateCache::new("test");
        let key = create_test_tipset_key(1);

        // First access misses and runs the compute closure.
        let first = cache
            .get_or_else(&key, || async { Ok("computed_value".to_string()) })
            .await
            .unwrap();
        assert_eq!(first, "computed_value");

        // Second access hits the cache; the closure must not run.
        let second = cache
            .get_or_else(&key, || async { Ok("should_not_compute".to_string()) })
            .await
            .unwrap();
        assert_eq!(second, "computed_value");
    }

    #[tokio::test]
    async fn test_concurrent_same_key_computation() {
        let cache: Arc<TipsetStateCache<String>> = Arc::new(TipsetStateCache::new("test"));
        let key = create_test_tipset_key(1);
        let calls = Arc::new(AtomicU8::new(0));

        // Race ten tasks on the same key; the per-key mutex should allow only
        // one of them to actually run the computation.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                let key = key.clone();
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    cache
                        .get_or_else(&key, || {
                            let calls = Arc::clone(&calls);
                            async move {
                                calls.fetch_add(1, Ordering::SeqCst);
                                // Give the other tasks time to pile up on the lock.
                                tokio::time::sleep(Duration::from_millis(10)).await;
                                Ok(format!("computed_value_{i}"))
                            }
                        })
                        .await
                })
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        // Exactly one task performed the computation...
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // ...so every task observed the same cached value.
        let expected = results[0].as_ref().unwrap();
        for outcome in &results {
            assert_eq!(outcome.as_ref().unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_concurrent_different_keys() {
        let cache: Arc<TipsetStateCache<String>> = Arc::new(TipsetStateCache::new("test"));
        let calls = Arc::new(AtomicU8::new(0));

        // Distinct keys must not serialize on each other: each key gets its
        // own computation.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                let key = create_test_tipset_key(i);
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    cache
                        .get_or_else(&key, || {
                            let calls = Arc::clone(&calls);
                            async move {
                                calls.fetch_add(1, Ordering::SeqCst);
                                tokio::time::sleep(Duration::from_millis(5)).await;
                                Ok(format!("value_{i}"))
                            }
                        })
                        .await
                })
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        // One computation per key.
        assert_eq!(calls.load(Ordering::SeqCst), 10);

        // Each task got the value computed for its own key.
        for (i, outcome) in results.iter().enumerate() {
            assert_eq!(outcome.as_ref().unwrap(), &format!("value_{i}"));
        }
    }
}