// lumina_node/store/in_memory_store.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::pin::pin;
5
6use async_trait::async_trait;
7use celestia_types::hash::Hash;
8use celestia_types::ExtendedHeader;
9use cid::Cid;
10use tokio::sync::{Notify, RwLock};
11use tracing::debug;
12
13use crate::block_ranges::BlockRanges;
14use crate::store::utils::VerifiedExtendedHeaders;
15use crate::store::{
16    Result, SamplingMetadata, SamplingStatus, Store, StoreError, StoreInsertionError,
17};
18
19/// A non-persistent in memory [`Store`] implementation.
20#[derive(Debug)]
21pub struct InMemoryStore {
22    /// Mutable part
23    inner: RwLock<InMemoryStoreInner>,
24    /// Notify when a new header is added
25    header_added_notifier: Notify,
26}
27
28#[derive(Debug, Clone)]
29struct InMemoryStoreInner {
30    /// Maps header Hash to the header itself, responsible for actually storing the header data
31    headers: HashMap<Hash, ExtendedHeader>,
32    /// Maps header height to its hash, in case we need to do lookup by height
33    height_to_hash: HashMap<u64, Hash>,
34    /// Source of truth about headers present in the db, used to synchronise inserts
35    header_ranges: BlockRanges,
36    /// Maps header height to the header sampling metadata
37    sampling_data: HashMap<u64, SamplingMetadata>,
38    /// Source of truth about accepted sampling ranges present in the db.
39    accepted_sampling_ranges: BlockRanges,
40}
41
42impl InMemoryStoreInner {
43    fn new() -> Self {
44        Self {
45            headers: HashMap::new(),
46            height_to_hash: HashMap::new(),
47            header_ranges: BlockRanges::default(),
48            sampling_data: HashMap::new(),
49            accepted_sampling_ranges: BlockRanges::default(),
50        }
51    }
52}
53
54impl InMemoryStore {
55    /// Create a new store.
56    pub fn new() -> Self {
57        InMemoryStore {
58            inner: RwLock::new(InMemoryStoreInner::new()),
59            header_added_notifier: Notify::new(),
60        }
61    }
62
63    #[inline]
64    async fn get_head_height(&self) -> Result<u64> {
65        self.inner.read().await.get_head_height()
66    }
67
68    async fn get_head(&self) -> Result<ExtendedHeader> {
69        let head_height = self.get_head_height().await?;
70        self.get_by_height(head_height).await
71    }
72
73    async fn contains_hash(&self, hash: &Hash) -> bool {
74        self.inner.read().await.contains_hash(hash)
75    }
76
77    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
78        self.inner.read().await.get_by_hash(hash)
79    }
80
81    async fn contains_height(&self, height: u64) -> bool {
82        self.inner.read().await.contains_height(height)
83    }
84
85    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
86        self.inner.read().await.get_by_height(height)
87    }
88
89    pub(crate) async fn insert<R>(&self, headers: R) -> Result<()>
90    where
91        R: TryInto<VerifiedExtendedHeaders> + Send,
92        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
93    {
94        let headers = headers
95            .try_into()
96            .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;
97
98        self.inner.write().await.insert(headers).await?;
99        self.header_added_notifier.notify_waiters();
100
101        Ok(())
102    }
103
104    async fn update_sampling_metadata(
105        &self,
106        height: u64,
107        status: SamplingStatus,
108        cids: Vec<Cid>,
109    ) -> Result<()> {
110        self.inner
111            .write()
112            .await
113            .update_sampling_metadata(height, status, cids)
114            .await
115    }
116
117    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
118        self.inner.read().await.get_sampling_metadata(height).await
119    }
120
121    async fn get_stored_ranges(&self) -> BlockRanges {
122        self.inner.read().await.get_stored_ranges()
123    }
124
125    async fn get_accepted_sampling_ranges(&self) -> BlockRanges {
126        self.inner.read().await.get_accepted_sampling_ranges()
127    }
128
129    /// Clone the store and all its contents. Async fn due to internal use of async mutex.
130    pub async fn async_clone(&self) -> Self {
131        InMemoryStore {
132            inner: RwLock::new(self.inner.read().await.clone()),
133            header_added_notifier: Notify::new(),
134        }
135    }
136
137    async fn remove_height(&self, height: u64) -> Result<()> {
138        let mut inner = self.inner.write().await;
139        inner.remove_height(height)
140    }
141}
142
143impl InMemoryStoreInner {
144    fn get_stored_ranges(&self) -> BlockRanges {
145        self.header_ranges.clone()
146    }
147
148    fn get_accepted_sampling_ranges(&self) -> BlockRanges {
149        self.accepted_sampling_ranges.clone()
150    }
151
152    #[inline]
153    fn get_head_height(&self) -> Result<u64> {
154        self.header_ranges.head().ok_or(StoreError::NotFound)
155    }
156
157    fn contains_hash(&self, hash: &Hash) -> bool {
158        self.headers.contains_key(hash)
159    }
160
161    fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
162        self.headers.get(hash).cloned().ok_or(StoreError::NotFound)
163    }
164
165    fn contains_height(&self, height: u64) -> bool {
166        self.header_ranges.contains(height)
167    }
168
169    fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
170        let Some(hash) = self.height_to_hash.get(&height).copied() else {
171            return Err(StoreError::NotFound);
172        };
173
174        Ok(self
175            .headers
176            .get(&hash)
177            .expect("inconsistent between header hash and header heights")
178            .to_owned())
179    }
180
181    async fn insert(&mut self, headers: VerifiedExtendedHeaders) -> Result<()> {
182        let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last()) else {
183            return Ok(());
184        };
185
186        let headers_range = head.height().value()..=tail.height().value();
187        let (prev_exists, next_exists) = self
188            .header_ranges
189            .check_insertion_constraints(&headers_range)
190            .map_err(StoreInsertionError::ContraintsNotMet)?;
191
192        // header range is already internally verified against itself in `P2p::get_unverified_header_ranges`
193        self.verify_against_neighbours(prev_exists.then_some(head), next_exists.then_some(tail))?;
194
195        for header in headers.into_iter() {
196            let hash = header.hash();
197            let height = header.height().value();
198
199            debug_assert!(
200                !self.height_to_hash.contains_key(&height),
201                "inconsistency between headers table and ranges table"
202            );
203
204            let Entry::Vacant(headers_entry) = self.headers.entry(hash) else {
205                // TODO: Remove this when we implement type-safe validation on insertion.
206                return Err(StoreInsertionError::HashExists(hash).into());
207            };
208
209            debug!("Inserting header {hash} with height {height}");
210            headers_entry.insert(header);
211            self.height_to_hash.insert(height, hash);
212        }
213
214        self.header_ranges
215            .insert_relaxed(headers_range)
216            .expect("invalid range");
217
218        Ok(())
219    }
220
221    fn verify_against_neighbours(
222        &self,
223        lowest_header: Option<&ExtendedHeader>,
224        highest_header: Option<&ExtendedHeader>,
225    ) -> Result<()> {
226        if let Some(lowest_header) = lowest_header {
227            let prev = self
228                .get_by_height(lowest_header.height().value() - 1)
229                .map_err(|e| match e {
230                    StoreError::NotFound => {
231                        panic!("inconsistency between headers and ranges table")
232                    }
233                    e => e,
234                })?;
235
236            prev.verify(lowest_header)
237                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
238        }
239
240        if let Some(highest_header) = highest_header {
241            let next = self
242                .get_by_height(highest_header.height().value() + 1)
243                .map_err(|e| match e {
244                    StoreError::NotFound => {
245                        panic!("inconsistency between headers and ranges table")
246                    }
247                    e => e,
248                })?;
249
250            highest_header
251                .verify(&next)
252                .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
253        }
254
255        Ok(())
256    }
257
258    async fn update_sampling_metadata(
259        &mut self,
260        height: u64,
261        status: SamplingStatus,
262        cids: Vec<Cid>,
263    ) -> Result<()> {
264        if !self.contains_height(height) {
265            return Err(StoreError::NotFound);
266        }
267
268        match self.sampling_data.entry(height) {
269            Entry::Vacant(entry) => {
270                entry.insert(SamplingMetadata { status, cids });
271            }
272            Entry::Occupied(mut entry) => {
273                let metadata = entry.get_mut();
274                metadata.status = status;
275
276                for cid in cids {
277                    if !metadata.cids.contains(&cid) {
278                        metadata.cids.push(cid);
279                    }
280                }
281            }
282        }
283
284        match status {
285            SamplingStatus::Accepted => self
286                .accepted_sampling_ranges
287                .insert_relaxed(height..=height)
288                .expect("invalid height"),
289            _ => self
290                .accepted_sampling_ranges
291                .remove_relaxed(height..=height)
292                .expect("invalid height"),
293        }
294
295        Ok(())
296    }
297
298    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
299        if !self.contains_height(height) {
300            return Err(StoreError::NotFound);
301        }
302
303        let Some(metadata) = self.sampling_data.get(&height) else {
304            return Ok(None);
305        };
306
307        Ok(Some(metadata.clone()))
308    }
309
310    fn remove_height(&mut self, height: u64) -> Result<()> {
311        if !self.header_ranges.contains(height) {
312            return Err(StoreError::NotFound);
313        }
314
315        let Entry::Occupied(height_to_hash) = self.height_to_hash.entry(height) else {
316            return Err(StoreError::StoredDataError(format!(
317                "inconsistency between ranges and height_to_hash tables, height {height}"
318            )));
319        };
320
321        let hash = height_to_hash.get();
322        let Entry::Occupied(header) = self.headers.entry(*hash) else {
323            return Err(StoreError::StoredDataError(format!(
324                "inconsistency between header and height_to_hash tables, hash {hash}"
325            )));
326        };
327
328        // sampling data may or may not be there
329        self.sampling_data.remove(&height);
330
331        height_to_hash.remove_entry();
332        header.remove_entry();
333
334        self.header_ranges
335            .remove_relaxed(height..=height)
336            .expect("valid range never fails");
337
338        Ok(())
339    }
340}
341
342#[async_trait]
343impl Store for InMemoryStore {
344    async fn get_head(&self) -> Result<ExtendedHeader> {
345        self.get_head().await
346    }
347
348    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
349        self.get_by_hash(hash).await
350    }
351
352    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
353        self.get_by_height(height).await
354    }
355
356    async fn wait_new_head(&self) -> u64 {
357        let head = self.get_head_height().await.unwrap_or(0);
358        let mut notifier = pin!(self.header_added_notifier.notified());
359
360        loop {
361            let new_head = self.get_head_height().await.unwrap_or(0);
362
363            if head != new_head {
364                return new_head;
365            }
366
367            // Await for a notification
368            notifier.as_mut().await;
369
370            // Reset notifier
371            notifier.set(self.header_added_notifier.notified());
372        }
373    }
374
375    async fn wait_height(&self, height: u64) -> Result<()> {
376        let mut notifier = pin!(self.header_added_notifier.notified());
377
378        loop {
379            if self.contains_height(height).await {
380                return Ok(());
381            }
382
383            // Await for a notification
384            notifier.as_mut().await;
385
386            // Reset notifier
387            notifier.set(self.header_added_notifier.notified());
388        }
389    }
390
391    async fn head_height(&self) -> Result<u64> {
392        self.get_head_height().await
393    }
394
395    async fn has(&self, hash: &Hash) -> bool {
396        self.contains_hash(hash).await
397    }
398
399    async fn has_at(&self, height: u64) -> bool {
400        self.contains_height(height).await
401    }
402
403    async fn insert<R>(&self, header: R) -> Result<()>
404    where
405        R: TryInto<VerifiedExtendedHeaders> + Send,
406        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
407    {
408        self.insert(header).await
409    }
410
411    async fn update_sampling_metadata(
412        &self,
413        height: u64,
414        status: SamplingStatus,
415        cids: Vec<Cid>,
416    ) -> Result<()> {
417        self.update_sampling_metadata(height, status, cids).await
418    }
419
420    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
421        self.get_sampling_metadata(height).await
422    }
423
424    async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
425        Ok(self.get_stored_ranges().await)
426    }
427
428    async fn get_accepted_sampling_ranges(&self) -> Result<BlockRanges> {
429        Ok(self.get_accepted_sampling_ranges().await)
430    }
431
432    async fn remove_height(&self, height: u64) -> Result<()> {
433        self.remove_height(height).await
434    }
435
436    async fn close(self) -> Result<()> {
437        Ok(())
438    }
439}
440
441impl Default for InMemoryStore {
442    fn default() -> Self {
443        Self::new()
444    }
445}