// lumina_node/store/in_memory_store.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::fmt::Display;
4use std::pin::pin;
5
6use async_trait::async_trait;
7use celestia_types::ExtendedHeader;
8use celestia_types::hash::Hash;
9use cid::Cid;
10use libp2p::identity::Keypair;
11use tokio::sync::{Notify, RwLock};
12use tracing::debug;
13
14use crate::block_ranges::BlockRanges;
15use crate::store::utils::VerifiedExtendedHeaders;
16use crate::store::{Result, SamplingMetadata, Store, StoreError, StoreInsertionError};
17
18/// A non-persistent in memory [`Store`] implementation.
19#[derive(Debug)]
20pub struct InMemoryStore {
21    /// Mutable part
22    inner: RwLock<InMemoryStoreInner>,
23    /// Notify when a new header is added
24    header_added_notifier: Notify,
25    /// Node network identity
26    libp2p_identity: Keypair,
27}
28
29#[derive(Debug, Clone)]
30struct InMemoryStoreInner {
31    /// Maps header Hash to the header itself, responsible for actually storing the header data
32    headers: HashMap<Hash, ExtendedHeader>,
33    /// Maps header height to its hash, in case we need to do lookup by height
34    height_to_hash: HashMap<u64, Hash>,
35    /// Source of truth about headers present in the db, used to synchronise inserts
36    header_ranges: BlockRanges,
37    /// Maps header height to the header sampling metadata
38    sampling_data: HashMap<u64, SamplingMetadata>,
39    /// Source of truth about sampled ranges present in the db.
40    sampled_ranges: BlockRanges,
41    /// Source of truth about ranges that were pruned from db.
42    pruned_ranges: BlockRanges,
43}
44
45impl InMemoryStoreInner {
46    fn new() -> Self {
47        Self {
48            headers: HashMap::new(),
49            height_to_hash: HashMap::new(),
50            header_ranges: BlockRanges::default(),
51            sampling_data: HashMap::new(),
52            sampled_ranges: BlockRanges::default(),
53            pruned_ranges: BlockRanges::default(),
54        }
55    }
56}
57
58impl InMemoryStore {
59    /// Create a new store.
60    pub fn new() -> Self {
61        InMemoryStore {
62            inner: RwLock::new(InMemoryStoreInner::new()),
63            header_added_notifier: Notify::new(),
64            libp2p_identity: Keypair::generate_ed25519(),
65        }
66    }
67
68    #[inline]
69    async fn get_head_height(&self) -> Result<u64> {
70        self.inner.read().await.get_head_height()
71    }
72
73    async fn get_head(&self) -> Result<ExtendedHeader> {
74        let head_height = self.get_head_height().await?;
75        self.get_by_height(head_height).await
76    }
77
78    async fn contains_hash(&self, hash: &Hash) -> bool {
79        self.inner.read().await.contains_hash(hash)
80    }
81
82    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
83        self.inner.read().await.get_by_hash(hash)
84    }
85
86    async fn contains_height(&self, height: u64) -> bool {
87        self.inner.read().await.contains_height(height)
88    }
89
90    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
91        self.inner.read().await.get_by_height(height)
92    }
93
94    pub(crate) async fn insert<R>(&self, headers: R) -> Result<()>
95    where
96        R: TryInto<VerifiedExtendedHeaders> + Send,
97        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
98    {
99        let headers = headers
100            .try_into()
101            .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;
102
103        self.inner.write().await.insert(headers).await?;
104        self.header_added_notifier.notify_waiters();
105
106        Ok(())
107    }
108
109    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
110        self.inner
111            .write()
112            .await
113            .update_sampling_metadata(height, cids)
114            .await
115    }
116
117    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
118        self.inner.read().await.get_sampling_metadata(height).await
119    }
120
121    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
122        self.inner.write().await.mark_as_sampled(height).await
123    }
124
125    async fn get_stored_ranges(&self) -> BlockRanges {
126        self.inner.read().await.header_ranges.clone()
127    }
128
129    async fn get_sampled_ranges(&self) -> BlockRanges {
130        self.inner.read().await.sampled_ranges.clone()
131    }
132
133    async fn get_pruned_ranges(&self) -> BlockRanges {
134        self.inner.read().await.pruned_ranges.clone()
135    }
136
137    /// Clone the store and all its contents, except for libp2p identity, which is re-generated.
138    /// Async fn due to internal use of async mutex.
139    pub async fn async_clone(&self) -> Self {
140        InMemoryStore {
141            inner: RwLock::new(self.inner.read().await.clone()),
142            header_added_notifier: Notify::new(),
143            libp2p_identity: Keypair::generate_ed25519(),
144        }
145    }
146
147    async fn remove_height(&self, height: u64) -> Result<()> {
148        let mut inner = self.inner.write().await;
149        inner.remove_height(height)
150    }
151
152    async fn get_identity(&self) -> Result<Keypair> {
153        Ok(self.libp2p_identity.clone())
154    }
155}
156
157impl InMemoryStoreInner {
158    #[inline]
159    fn get_head_height(&self) -> Result<u64> {
160        self.header_ranges.head().ok_or(StoreError::NotFound)
161    }
162
163    fn contains_hash(&self, hash: &Hash) -> bool {
164        self.headers.contains_key(hash)
165    }
166
167    fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
168        self.headers.get(hash).cloned().ok_or(StoreError::NotFound)
169    }
170
171    fn contains_height(&self, height: u64) -> bool {
172        self.header_ranges.contains(height)
173    }
174
175    fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
176        let Some(hash) = self.height_to_hash.get(&height).copied() else {
177            return Err(StoreError::NotFound);
178        };
179
180        Ok(self
181            .headers
182            .get(&hash)
183            .expect("inconsistent between header hash and header heights")
184            .to_owned())
185    }
186
187    async fn insert(&mut self, headers: VerifiedExtendedHeaders) -> Result<()> {
188        let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last()) else {
189            return Ok(());
190        };
191
192        let headers_range = head.height().value()..=tail.height().value();
193        let (prev_exists, next_exists) = self
194            .header_ranges
195            .check_insertion_constraints(&headers_range)
196            .map_err(StoreInsertionError::ContraintsNotMet)?;
197
198        // header range is already internally verified against itself in `P2p::get_unverified_header_ranges`
199        self.verify_against_neighbours(prev_exists.then_some(head), next_exists.then_some(tail))?;
200
201        for header in headers.into_iter() {
202            let hash = header.hash();
203            let height = header.height().value();
204
205            debug_assert!(
206                !self.height_to_hash.contains_key(&height),
207                "inconsistency between headers table and ranges table"
208            );
209
210            let Entry::Vacant(headers_entry) = self.headers.entry(hash) else {
211                // TODO: Remove this when we implement type-safe validation on insertion.
212                return Err(StoreInsertionError::HashExists(hash).into());
213            };
214
215            debug!("Inserting header {hash} with height {height}");
216            headers_entry.insert(header);
217            self.height_to_hash.insert(height, hash);
218        }
219
220        self.header_ranges
221            .insert_relaxed(&headers_range)
222            .expect("invalid range");
223        self.sampled_ranges
224            .remove_relaxed(&headers_range)
225            .expect("invalid range");
226        self.pruned_ranges
227            .remove_relaxed(&headers_range)
228            .expect("invalid range");
229
230        Ok(())
231    }
232
233    fn verify_against_neighbours(
234        &self,
235        lowest_header: Option<&ExtendedHeader>,
236        highest_header: Option<&ExtendedHeader>,
237    ) -> Result<()> {
238        if let Some(lowest_header) = lowest_header {
239            let prev = self
240                .get_by_height(lowest_header.height().value() - 1)
241                .map_err(|e| match e {
242                    StoreError::NotFound => {
243                        panic!("inconsistency between headers and ranges table")
244                    }
245                    e => e,
246                })?;
247
248            prev.verify(lowest_header)
249                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
250        }
251
252        if let Some(highest_header) = highest_header {
253            let next = self
254                .get_by_height(highest_header.height().value() + 1)
255                .map_err(|e| match e {
256                    StoreError::NotFound => {
257                        panic!("inconsistency between headers and ranges table")
258                    }
259                    e => e,
260                })?;
261
262            highest_header
263                .verify(&next)
264                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
265        }
266
267        Ok(())
268    }
269
270    async fn update_sampling_metadata(&mut self, height: u64, cids: Vec<Cid>) -> Result<()> {
271        if !self.contains_height(height) {
272            return Err(StoreError::NotFound);
273        }
274
275        match self.sampling_data.entry(height) {
276            Entry::Vacant(entry) => {
277                entry.insert(SamplingMetadata { cids });
278            }
279            Entry::Occupied(mut entry) => {
280                let metadata = entry.get_mut();
281
282                for cid in cids {
283                    if !metadata.cids.contains(&cid) {
284                        metadata.cids.push(cid);
285                    }
286                }
287            }
288        }
289
290        Ok(())
291    }
292
293    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
294        if !self.contains_height(height) {
295            return Err(StoreError::NotFound);
296        }
297
298        let Some(metadata) = self.sampling_data.get(&height) else {
299            return Ok(None);
300        };
301
302        Ok(Some(metadata.clone()))
303    }
304
305    async fn mark_as_sampled(&mut self, height: u64) -> Result<()> {
306        if !self.contains_height(height) {
307            return Err(StoreError::NotFound);
308        }
309
310        self.sampled_ranges
311            .insert_relaxed(height..=height)
312            .expect("invalid height");
313
314        Ok(())
315    }
316
317    fn remove_height(&mut self, height: u64) -> Result<()> {
318        if !self.header_ranges.contains(height) {
319            return Err(StoreError::NotFound);
320        }
321
322        let Entry::Occupied(height_to_hash) = self.height_to_hash.entry(height) else {
323            return Err(StoreError::StoredDataError(format!(
324                "inconsistency between ranges and height_to_hash tables, height {height}"
325            )));
326        };
327
328        let hash = height_to_hash.get();
329        let Entry::Occupied(header) = self.headers.entry(*hash) else {
330            return Err(StoreError::StoredDataError(format!(
331                "inconsistency between header and height_to_hash tables, hash {hash}"
332            )));
333        };
334
335        // sampling data may or may not be there
336        self.sampling_data.remove(&height);
337
338        height_to_hash.remove_entry();
339        header.remove_entry();
340
341        self.header_ranges
342            .remove_relaxed(height..=height)
343            .expect("invalid height");
344        self.sampled_ranges
345            .remove_relaxed(height..=height)
346            .expect("invalid height");
347        self.pruned_ranges
348            .insert_relaxed(height..=height)
349            .expect("invalid height");
350
351        Ok(())
352    }
353}
354
355#[async_trait]
356impl Store for InMemoryStore {
357    async fn get_head(&self) -> Result<ExtendedHeader> {
358        self.get_head().await
359    }
360
361    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
362        self.get_by_hash(hash).await
363    }
364
365    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
366        self.get_by_height(height).await
367    }
368
369    async fn wait_new_head(&self) -> u64 {
370        let head = self.get_head_height().await.unwrap_or(0);
371        let mut notifier = pin!(self.header_added_notifier.notified());
372
373        loop {
374            let new_head = self.get_head_height().await.unwrap_or(0);
375
376            if head != new_head {
377                return new_head;
378            }
379
380            // Await for a notification
381            notifier.as_mut().await;
382
383            // Reset notifier
384            notifier.set(self.header_added_notifier.notified());
385        }
386    }
387
388    async fn wait_height(&self, height: u64) -> Result<()> {
389        let mut notifier = pin!(self.header_added_notifier.notified());
390
391        loop {
392            if self.contains_height(height).await {
393                return Ok(());
394            }
395
396            // Await for a notification
397            notifier.as_mut().await;
398
399            // Reset notifier
400            notifier.set(self.header_added_notifier.notified());
401        }
402    }
403
404    async fn head_height(&self) -> Result<u64> {
405        self.get_head_height().await
406    }
407
408    async fn has(&self, hash: &Hash) -> bool {
409        self.contains_hash(hash).await
410    }
411
412    async fn has_at(&self, height: u64) -> bool {
413        self.contains_height(height).await
414    }
415
416    async fn insert<R>(&self, header: R) -> Result<()>
417    where
418        R: TryInto<VerifiedExtendedHeaders> + Send,
419        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
420    {
421        self.insert(header).await
422    }
423
424    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
425        self.update_sampling_metadata(height, cids).await
426    }
427
428    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
429        self.mark_as_sampled(height).await
430    }
431
432    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
433        self.get_sampling_metadata(height).await
434    }
435
436    async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
437        Ok(self.get_stored_ranges().await)
438    }
439
440    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
441        Ok(self.get_sampled_ranges().await)
442    }
443
444    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
445        Ok(self.get_pruned_ranges().await)
446    }
447
448    async fn remove_height(&self, height: u64) -> Result<()> {
449        self.remove_height(height).await
450    }
451
452    async fn close(self) -> Result<()> {
453        Ok(())
454    }
455
456    async fn get_identity(&self) -> Result<Keypair> {
457        self.get_identity().await
458    }
459}
460
461impl Default for InMemoryStore {
462    fn default() -> Self {
463        Self::new()
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use lumina_utils::test_utils::async_test as test;

    #[test]
    async fn identity_regen_on_clone() {
        let original = InMemoryStore::new();
        let before_clone = original.get_identity().await.unwrap().public();

        let cloned = original.async_clone().await;
        let after_clone = original.get_identity().await.unwrap().public();
        let cloned_identity = cloned.get_identity().await.unwrap().public();

        // Cloning must not change the original store's identity...
        assert_eq!(before_clone, after_clone);
        // ...while the clone gets a freshly generated one.
        assert_ne!(before_clone, cloned_identity)
    }
}