// lumina_node/store/in_memory_store.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::pin::pin;
5
6use async_trait::async_trait;
7use celestia_types::hash::Hash;
8use celestia_types::ExtendedHeader;
9use cid::Cid;
10use tokio::sync::{Notify, RwLock};
11use tracing::debug;
12
13use crate::block_ranges::BlockRanges;
14use crate::store::utils::VerifiedExtendedHeaders;
15use crate::store::{Result, SamplingMetadata, Store, StoreError, StoreInsertionError};
16
17/// A non-persistent in memory [`Store`] implementation.
18#[derive(Debug)]
19pub struct InMemoryStore {
20    /// Mutable part
21    inner: RwLock<InMemoryStoreInner>,
22    /// Notify when a new header is added
23    header_added_notifier: Notify,
24}
25
26#[derive(Debug, Clone)]
27struct InMemoryStoreInner {
28    /// Maps header Hash to the header itself, responsible for actually storing the header data
29    headers: HashMap<Hash, ExtendedHeader>,
30    /// Maps header height to its hash, in case we need to do lookup by height
31    height_to_hash: HashMap<u64, Hash>,
32    /// Source of truth about headers present in the db, used to synchronise inserts
33    header_ranges: BlockRanges,
34    /// Maps header height to the header sampling metadata
35    sampling_data: HashMap<u64, SamplingMetadata>,
36    /// Source of truth about sampled ranges present in the db.
37    sampled_ranges: BlockRanges,
38    /// Source of truth about ranges that were pruned from db.
39    pruned_ranges: BlockRanges,
40}
41
42impl InMemoryStoreInner {
43    fn new() -> Self {
44        Self {
45            headers: HashMap::new(),
46            height_to_hash: HashMap::new(),
47            header_ranges: BlockRanges::default(),
48            sampling_data: HashMap::new(),
49            sampled_ranges: BlockRanges::default(),
50            pruned_ranges: BlockRanges::default(),
51        }
52    }
53}
54
55impl InMemoryStore {
56    /// Create a new store.
57    pub fn new() -> Self {
58        InMemoryStore {
59            inner: RwLock::new(InMemoryStoreInner::new()),
60            header_added_notifier: Notify::new(),
61        }
62    }
63
64    #[inline]
65    async fn get_head_height(&self) -> Result<u64> {
66        self.inner.read().await.get_head_height()
67    }
68
69    async fn get_head(&self) -> Result<ExtendedHeader> {
70        let head_height = self.get_head_height().await?;
71        self.get_by_height(head_height).await
72    }
73
74    async fn contains_hash(&self, hash: &Hash) -> bool {
75        self.inner.read().await.contains_hash(hash)
76    }
77
78    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
79        self.inner.read().await.get_by_hash(hash)
80    }
81
82    async fn contains_height(&self, height: u64) -> bool {
83        self.inner.read().await.contains_height(height)
84    }
85
86    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
87        self.inner.read().await.get_by_height(height)
88    }
89
90    pub(crate) async fn insert<R>(&self, headers: R) -> Result<()>
91    where
92        R: TryInto<VerifiedExtendedHeaders> + Send,
93        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
94    {
95        let headers = headers
96            .try_into()
97            .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;
98
99        self.inner.write().await.insert(headers).await?;
100        self.header_added_notifier.notify_waiters();
101
102        Ok(())
103    }
104
105    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
106        self.inner
107            .write()
108            .await
109            .update_sampling_metadata(height, cids)
110            .await
111    }
112
113    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
114        self.inner.read().await.get_sampling_metadata(height).await
115    }
116
117    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
118        self.inner.write().await.mark_as_sampled(height).await
119    }
120
121    async fn get_stored_ranges(&self) -> BlockRanges {
122        self.inner.read().await.header_ranges.clone()
123    }
124
125    async fn get_sampled_ranges(&self) -> BlockRanges {
126        self.inner.read().await.sampled_ranges.clone()
127    }
128
129    async fn get_pruned_ranges(&self) -> BlockRanges {
130        self.inner.read().await.pruned_ranges.clone()
131    }
132
133    /// Clone the store and all its contents. Async fn due to internal use of async mutex.
134    pub async fn async_clone(&self) -> Self {
135        InMemoryStore {
136            inner: RwLock::new(self.inner.read().await.clone()),
137            header_added_notifier: Notify::new(),
138        }
139    }
140
141    async fn remove_height(&self, height: u64) -> Result<()> {
142        let mut inner = self.inner.write().await;
143        inner.remove_height(height)
144    }
145}
146
147impl InMemoryStoreInner {
148    #[inline]
149    fn get_head_height(&self) -> Result<u64> {
150        self.header_ranges.head().ok_or(StoreError::NotFound)
151    }
152
153    fn contains_hash(&self, hash: &Hash) -> bool {
154        self.headers.contains_key(hash)
155    }
156
157    fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
158        self.headers.get(hash).cloned().ok_or(StoreError::NotFound)
159    }
160
161    fn contains_height(&self, height: u64) -> bool {
162        self.header_ranges.contains(height)
163    }
164
165    fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
166        let Some(hash) = self.height_to_hash.get(&height).copied() else {
167            return Err(StoreError::NotFound);
168        };
169
170        Ok(self
171            .headers
172            .get(&hash)
173            .expect("inconsistent between header hash and header heights")
174            .to_owned())
175    }
176
177    async fn insert(&mut self, headers: VerifiedExtendedHeaders) -> Result<()> {
178        let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last()) else {
179            return Ok(());
180        };
181
182        let headers_range = head.height().value()..=tail.height().value();
183        let (prev_exists, next_exists) = self
184            .header_ranges
185            .check_insertion_constraints(&headers_range)
186            .map_err(StoreInsertionError::ContraintsNotMet)?;
187
188        // header range is already internally verified against itself in `P2p::get_unverified_header_ranges`
189        self.verify_against_neighbours(prev_exists.then_some(head), next_exists.then_some(tail))?;
190
191        for header in headers.into_iter() {
192            let hash = header.hash();
193            let height = header.height().value();
194
195            debug_assert!(
196                !self.height_to_hash.contains_key(&height),
197                "inconsistency between headers table and ranges table"
198            );
199
200            let Entry::Vacant(headers_entry) = self.headers.entry(hash) else {
201                // TODO: Remove this when we implement type-safe validation on insertion.
202                return Err(StoreInsertionError::HashExists(hash).into());
203            };
204
205            debug!("Inserting header {hash} with height {height}");
206            headers_entry.insert(header);
207            self.height_to_hash.insert(height, hash);
208        }
209
210        self.header_ranges
211            .insert_relaxed(&headers_range)
212            .expect("invalid range");
213        self.sampled_ranges
214            .remove_relaxed(&headers_range)
215            .expect("invalid range");
216        self.pruned_ranges
217            .remove_relaxed(&headers_range)
218            .expect("invalid range");
219
220        Ok(())
221    }
222
223    fn verify_against_neighbours(
224        &self,
225        lowest_header: Option<&ExtendedHeader>,
226        highest_header: Option<&ExtendedHeader>,
227    ) -> Result<()> {
228        if let Some(lowest_header) = lowest_header {
229            let prev = self
230                .get_by_height(lowest_header.height().value() - 1)
231                .map_err(|e| match e {
232                    StoreError::NotFound => {
233                        panic!("inconsistency between headers and ranges table")
234                    }
235                    e => e,
236                })?;
237
238            prev.verify(lowest_header)
239                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
240        }
241
242        if let Some(highest_header) = highest_header {
243            let next = self
244                .get_by_height(highest_header.height().value() + 1)
245                .map_err(|e| match e {
246                    StoreError::NotFound => {
247                        panic!("inconsistency between headers and ranges table")
248                    }
249                    e => e,
250                })?;
251
252            highest_header
253                .verify(&next)
254                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
255        }
256
257        Ok(())
258    }
259
260    async fn update_sampling_metadata(&mut self, height: u64, cids: Vec<Cid>) -> Result<()> {
261        if !self.contains_height(height) {
262            return Err(StoreError::NotFound);
263        }
264
265        match self.sampling_data.entry(height) {
266            Entry::Vacant(entry) => {
267                entry.insert(SamplingMetadata { cids });
268            }
269            Entry::Occupied(mut entry) => {
270                let metadata = entry.get_mut();
271
272                for cid in cids {
273                    if !metadata.cids.contains(&cid) {
274                        metadata.cids.push(cid);
275                    }
276                }
277            }
278        }
279
280        Ok(())
281    }
282
283    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
284        if !self.contains_height(height) {
285            return Err(StoreError::NotFound);
286        }
287
288        let Some(metadata) = self.sampling_data.get(&height) else {
289            return Ok(None);
290        };
291
292        Ok(Some(metadata.clone()))
293    }
294
295    async fn mark_as_sampled(&mut self, height: u64) -> Result<()> {
296        if !self.contains_height(height) {
297            return Err(StoreError::NotFound);
298        }
299
300        self.sampled_ranges
301            .insert_relaxed(height..=height)
302            .expect("invalid height");
303
304        Ok(())
305    }
306
307    fn remove_height(&mut self, height: u64) -> Result<()> {
308        if !self.header_ranges.contains(height) {
309            return Err(StoreError::NotFound);
310        }
311
312        let Entry::Occupied(height_to_hash) = self.height_to_hash.entry(height) else {
313            return Err(StoreError::StoredDataError(format!(
314                "inconsistency between ranges and height_to_hash tables, height {height}"
315            )));
316        };
317
318        let hash = height_to_hash.get();
319        let Entry::Occupied(header) = self.headers.entry(*hash) else {
320            return Err(StoreError::StoredDataError(format!(
321                "inconsistency between header and height_to_hash tables, hash {hash}"
322            )));
323        };
324
325        // sampling data may or may not be there
326        self.sampling_data.remove(&height);
327
328        height_to_hash.remove_entry();
329        header.remove_entry();
330
331        self.header_ranges
332            .remove_relaxed(height..=height)
333            .expect("invalid height");
334        self.sampled_ranges
335            .remove_relaxed(height..=height)
336            .expect("invalid height");
337        self.pruned_ranges
338            .insert_relaxed(height..=height)
339            .expect("invalid height");
340
341        Ok(())
342    }
343}
344
345#[async_trait]
346impl Store for InMemoryStore {
347    async fn get_head(&self) -> Result<ExtendedHeader> {
348        self.get_head().await
349    }
350
351    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
352        self.get_by_hash(hash).await
353    }
354
355    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
356        self.get_by_height(height).await
357    }
358
359    async fn wait_new_head(&self) -> u64 {
360        let head = self.get_head_height().await.unwrap_or(0);
361        let mut notifier = pin!(self.header_added_notifier.notified());
362
363        loop {
364            let new_head = self.get_head_height().await.unwrap_or(0);
365
366            if head != new_head {
367                return new_head;
368            }
369
370            // Await for a notification
371            notifier.as_mut().await;
372
373            // Reset notifier
374            notifier.set(self.header_added_notifier.notified());
375        }
376    }
377
378    async fn wait_height(&self, height: u64) -> Result<()> {
379        let mut notifier = pin!(self.header_added_notifier.notified());
380
381        loop {
382            if self.contains_height(height).await {
383                return Ok(());
384            }
385
386            // Await for a notification
387            notifier.as_mut().await;
388
389            // Reset notifier
390            notifier.set(self.header_added_notifier.notified());
391        }
392    }
393
394    async fn head_height(&self) -> Result<u64> {
395        self.get_head_height().await
396    }
397
398    async fn has(&self, hash: &Hash) -> bool {
399        self.contains_hash(hash).await
400    }
401
402    async fn has_at(&self, height: u64) -> bool {
403        self.contains_height(height).await
404    }
405
406    async fn insert<R>(&self, header: R) -> Result<()>
407    where
408        R: TryInto<VerifiedExtendedHeaders> + Send,
409        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
410    {
411        self.insert(header).await
412    }
413
414    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
415        self.update_sampling_metadata(height, cids).await
416    }
417
418    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
419        self.mark_as_sampled(height).await
420    }
421
422    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
423        self.get_sampling_metadata(height).await
424    }
425
426    async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
427        Ok(self.get_stored_ranges().await)
428    }
429
430    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
431        Ok(self.get_sampled_ranges().await)
432    }
433
434    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
435        Ok(self.get_pruned_ranges().await)
436    }
437
438    async fn remove_height(&self, height: u64) -> Result<()> {
439        self.remove_height(height).await
440    }
441
442    async fn close(self) -> Result<()> {
443        Ok(())
444    }
445}
446
447impl Default for InMemoryStore {
448    fn default() -> Self {
449        Self::new()
450    }
451}