lumina_node/store/
redb_store.rs

1use std::fmt::Display;
2use std::ops::RangeInclusive;
3use std::path::Path;
4use std::pin::pin;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use celestia_types::ExtendedHeader;
9use celestia_types::hash::Hash;
10use cid::Cid;
11use libp2p::identity::Keypair;
12use redb::{
13    CommitError, Database, ReadTransaction, ReadableTable, ReadableTableMetadata, StorageError,
14    Table, TableDefinition, TableError, TransactionError, WriteTransaction,
15};
16use tendermint_proto::Protobuf;
17use tokio::sync::Notify;
18use tokio::task::spawn_blocking;
19use tracing::{debug, trace, warn};
20
21use crate::block_ranges::BlockRanges;
22use crate::store::utils::VerifiedExtendedHeaders;
23use crate::store::{Result, SamplingMetadata, Store, StoreError, StoreInsertionError};
24use crate::utils::Counter;
25
26use super::utils::{deserialize_extended_header, deserialize_sampling_metadata};
27
/// Version of the on-disk schema; bumped whenever a new migration is added.
const SCHEMA_VERSION: u64 = 3;

/// Maps header hash -> block height.
const HEIGHTS_TABLE: TableDefinition<'static, &[u8], u64> = TableDefinition::new("STORE.HEIGHTS");
/// Maps block height -> protobuf-serialized `ExtendedHeader`.
const HEADERS_TABLE: TableDefinition<'static, u64, &[u8]> = TableDefinition::new("STORE.HEADERS");
/// Maps block height -> serialized `SamplingMetadata`.
const SAMPLING_METADATA_TABLE: TableDefinition<'static, u64, &[u8]> =
    TableDefinition::new("STORE.SAMPLING_METADATA");
/// Single-entry table holding the schema version of the database.
const SCHEMA_VERSION_TABLE: TableDefinition<'static, (), u64> =
    TableDefinition::new("STORE.SCHEMA_VERSION");
/// Maps a range-set key (one of the `*_RANGES_KEY` constants below) to a list
/// of inclusive `(start, end)` height ranges.
const RANGES_TABLE: TableDefinition<'static, &str, Vec<(u64, u64)>> =
    TableDefinition::new("STORE.RANGES");
/// Single-entry table holding the node's protobuf-encoded libp2p keypair.
const LIBP2P_IDENTITY_TABLE: TableDefinition<'static, (), &[u8]> =
    TableDefinition::new("LIBP2P.IDENTITY");

/// `RANGES_TABLE` key for heights that have been marked as sampled.
const SAMPLED_RANGES_KEY: &str = "KEY.SAMPLED_RANGES";
/// `RANGES_TABLE` key for heights that have a stored header.
const HEADER_RANGES_KEY: &str = "KEY.HEADER_RANGES";
/// `RANGES_TABLE` key for heights whose headers were removed via `remove_height`.
const PRUNED_RANGES_KEY: &str = "KEY.PRUNED_RANGES";
44
/// A [`Store`] implementation based on a [`redb`] database.
#[derive(Debug)]
pub struct RedbStore {
    // Shared state (database handle + notifier) moved into blocking tasks.
    inner: Arc<Inner>,
    // Tracks in-flight `spawn_blocking` tasks so `close` can await them all.
    task_counter: Counter,
}
51
#[derive(Debug)]
struct Inner {
    /// Reference to the entire redb database
    db: Arc<Database>,
    /// Notify when a new header is added; wakes `wait_new_head`/`wait_height`.
    header_added_notifier: Notify,
}
59
60impl RedbStore {
61    /// Open a persistent [`redb`] store.
62    pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
63        let path = path.as_ref().to_owned();
64
65        let db = spawn_blocking(|| Database::create(path))
66            .await?
67            .map_err(|e| StoreError::OpenFailed(e.to_string()))?;
68
69        RedbStore::new(Arc::new(db)).await
70    }
71
72    /// Open an in memory [`redb`] store.
73    pub async fn in_memory() -> Result<Self> {
74        let db = Database::builder()
75            .create_with_backend(redb::backends::InMemoryBackend::new())
76            .map_err(|e| StoreError::OpenFailed(e.to_string()))?;
77
78        RedbStore::new(Arc::new(db)).await
79    }
80
    /// Create new `RedbStore` with an already opened [`redb::Database`].
    ///
    /// Checks/initializes the schema version, runs any pending migrations,
    /// opens every table once up-front (so later read transactions never hit
    /// a missing table), and generates a libp2p identity on first open.
    pub async fn new(db: Arc<Database>) -> Result<Self> {
        let store = RedbStore {
            inner: Arc::new(Inner {
                db,
                header_added_notifier: Notify::new(),
            }),
            task_counter: Counter::new(),
        };

        store
            .write_tx(|tx| {
                let mut schema_version_table = tx.open_table(SCHEMA_VERSION_TABLE)?;
                let schema_version = schema_version_table.get(())?.map(|guard| guard.value());

                match schema_version {
                    Some(schema_version) => {
                        // A database written by a newer version of the
                        // software cannot be downgraded.
                        if schema_version > SCHEMA_VERSION {
                            let e = format!(
                                "Incompatible database schema; found {schema_version}, expected {SCHEMA_VERSION}."
                            );
                            return Err(StoreError::OpenFailed(e));
                        }

                        // Do migrations (each is a no-op if already applied)
                        migrate_v1_to_v2(tx, &mut schema_version_table)?;
                        migrate_v2_to_v3(tx, &mut schema_version_table)?;
                    }
                    None => {
                        // New database
                        schema_version_table.insert((), SCHEMA_VERSION)?;
                    }
                }

                // Force us to write migrations!
                debug_assert_eq!(
                    schema_version_table.get(())?.map(|guard| guard.value()),
                    Some(SCHEMA_VERSION),
                    "Some migrations are missing"
                );

                // create tables, so that reads later don't complain
                let _heights_table = tx.open_table(HEIGHTS_TABLE)?;
                let _headers_table = tx.open_table(HEADERS_TABLE)?;
                let _ranges_table = tx.open_table(RANGES_TABLE)?;
                let _sampling_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
                let mut identity_table = tx.open_table(LIBP2P_IDENTITY_TABLE)?;

                // Generate and persist a keypair the first time this database
                // is opened, giving the node a stable peer id across restarts.
                if identity_table.is_empty()? {
                    let keypair = Keypair::generate_ed25519();

                    let peer_id = keypair.public().to_peer_id();
                    let keypair_bytes = keypair.to_protobuf_encoding()?;
                    debug!("Initialised new identity: {peer_id}");
                    identity_table.insert((), &*keypair_bytes)?;
                }

                Ok(())
            })
            .await
            // Any initialization failure is reported as `OpenFailed`.
            .map_err(|e| match e {
                e @ StoreError::OpenFailed(_) => e,
                e => StoreError::OpenFailed(e.to_string()),
            })?;

        Ok(store)
    }
148
149    /// Returns the raw [`redb::Database`].
150    ///
151    /// This is useful if you want to pass the database handle to any other
152    /// stores (e.g. [`blockstore`]).
153    pub fn raw_db(&self) -> Arc<Database> {
154        self.inner.db.clone()
155    }
156
157    /// Execute a read transaction.
158    async fn read_tx<F, T>(&self, f: F) -> Result<T>
159    where
160        F: FnOnce(&mut ReadTransaction) -> Result<T> + Send + 'static,
161        T: Send + 'static,
162    {
163        let inner = self.inner.clone();
164        let guard = self.task_counter.guard();
165
166        spawn_blocking(move || {
167            let _guard = guard;
168
169            {
170                let mut tx = inner.db.begin_read()?;
171                f(&mut tx)
172            }
173        })
174        .await?
175    }
176
177    /// Execute a write transaction.
178    ///
179    /// If closure returns an error the transaction is aborted, otherwise commited.
180    async fn write_tx<F, T>(&self, f: F) -> Result<T>
181    where
182        F: FnOnce(&mut WriteTransaction) -> Result<T> + Send + 'static,
183        T: Send + 'static,
184    {
185        let inner = self.inner.clone();
186        let guard = self.task_counter.guard();
187
188        spawn_blocking(move || {
189            let _guard = guard;
190
191            {
192                let mut tx = inner.db.begin_write()?;
193                let res = f(&mut tx);
194
195                if res.is_ok() {
196                    tx.commit()?;
197                } else {
198                    tx.abort()?;
199                }
200
201                res
202            }
203        })
204        .await?
205    }
206
207    async fn head_height(&self) -> Result<u64> {
208        self.read_tx(|tx| {
209            let table = tx.open_table(RANGES_TABLE)?;
210            let header_ranges = get_ranges(&table, HEADER_RANGES_KEY)?;
211
212            header_ranges.head().ok_or(StoreError::NotFound)
213        })
214        .await
215    }
216
217    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
218        let hash = *hash;
219
220        self.read_tx(move |tx| {
221            let heights_table = tx.open_table(HEIGHTS_TABLE)?;
222            let headers_table = tx.open_table(HEADERS_TABLE)?;
223
224            let height = get_height(&heights_table, hash.as_bytes())?;
225            get_header(&headers_table, height)
226        })
227        .await
228    }
229
230    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
231        self.read_tx(move |tx| {
232            let table = tx.open_table(HEADERS_TABLE)?;
233            get_header(&table, height)
234        })
235        .await
236    }
237
238    async fn get_head(&self) -> Result<ExtendedHeader> {
239        self.read_tx(|tx| {
240            let ranges_table = tx.open_table(RANGES_TABLE)?;
241            let headers_table = tx.open_table(HEADERS_TABLE)?;
242
243            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
244            let head = header_ranges.head().ok_or(StoreError::NotFound)?;
245
246            get_header(&headers_table, head)
247        })
248        .await
249    }
250
251    async fn contains_hash(&self, hash: &Hash) -> bool {
252        let hash = *hash;
253
254        self.read_tx(move |tx| {
255            let heights_table = tx.open_table(HEIGHTS_TABLE)?;
256            let headers_table = tx.open_table(HEADERS_TABLE)?;
257
258            let height = get_height(&heights_table, hash.as_bytes())?;
259            Ok(headers_table.get(height)?.is_some())
260        })
261        .await
262        .unwrap_or(false)
263    }
264
265    async fn contains_height(&self, height: u64) -> bool {
266        self.read_tx(move |tx| {
267            let headers_table = tx.open_table(HEADERS_TABLE)?;
268            Ok(headers_table.get(height)?.is_some())
269        })
270        .await
271        .unwrap_or(false)
272    }
273
    /// Insert a contiguous, pre-verified batch of headers into the store.
    ///
    /// The whole batch is written in a single transaction: the new range must
    /// satisfy the insertion constraints of the stored ranges, and its edges
    /// are verified against already-stored neighbouring headers before any
    /// write happens.
    async fn insert<R>(&self, headers: R) -> Result<()>
    where
        R: TryInto<VerifiedExtendedHeaders> + Send,
        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
    {
        let headers = headers
            .try_into()
            .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;

        self.write_tx(move |tx| {
            // An empty batch is a no-op.
            let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last())
            else {
                return Ok(());
            };

            let mut heights_table = tx.open_table(HEIGHTS_TABLE)?;
            let mut headers_table = tx.open_table(HEADERS_TABLE)?;
            let mut ranges_table = tx.open_table(RANGES_TABLE)?;

            let mut header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
            let mut pruned_ranges = get_ranges(&ranges_table, PRUNED_RANGES_KEY)?;

            let headers_range = head.height().value()..=tail.height().value();

            // Check the batch can legally be attached to the stored ranges and
            // learn whether a stored neighbour exists on either side.
            let (prev_exists, next_exists) = header_ranges
                .check_insertion_constraints(&headers_range)
                .map_err(StoreInsertionError::ContraintsNotMet)?;

            // Verify the batch edges against the adjacent stored headers.
            verify_against_neighbours(
                &headers_table,
                prev_exists.then_some(head),
                next_exists.then_some(tail),
            )?;

            for header in headers {
                let height = header.height().value();
                let hash = header.hash();
                let serialized_header = header.encode_vec();

                // `insert` returning `Some` means the height was already
                // occupied even though the ranges table said it was free.
                if headers_table
                    .insert(height, &serialized_header[..])?
                    .is_some()
                {
                    return Err(StoreError::StoredDataError(
                        "inconsistency between headers table and ranges table".into(),
                    ));
                }

                if heights_table.insert(hash.as_bytes(), height)?.is_some() {
                    // TODO: Replace this with `StoredDataError` when we implement
                    // type-safe validation on insertion.
                    return Err(StoreInsertionError::HashExists(hash).into());
                }

                trace!("Inserted header {hash} with height {height}");
            }

            // Bookkeeping: the range is now stored, and is no longer
            // considered sampled or pruned.
            header_ranges
                .insert_relaxed(&headers_range)
                .expect("invalid range");
            sampled_ranges
                .remove_relaxed(&headers_range)
                .expect("invalid range");
            pruned_ranges
                .remove_relaxed(&headers_range)
                .expect("invalid range");

            set_ranges(&mut ranges_table, HEADER_RANGES_KEY, &header_ranges)?;
            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
            set_ranges(&mut ranges_table, PRUNED_RANGES_KEY, &pruned_ranges)?;

            debug!("Inserted header range {headers_range:?}",);

            Ok(())
        })
        .await?;

        // Wake any `wait_new_head` / `wait_height` waiters after the commit.
        self.inner.header_added_notifier.notify_waiters();

        Ok(())
    }
356
    /// Merge `cids` into the sampling metadata stored for `height`.
    ///
    /// Returns [`StoreError::NotFound`] if no header is stored for `height`.
    /// CIDs already recorded for that height are not duplicated.
    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
        self.write_tx(move |tx| {
            let mut sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
            let ranges_table = tx.open_table(RANGES_TABLE)?;

            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;

            // Metadata may only be attached to heights we have a header for.
            if !header_ranges.contains(height) {
                return Err(StoreError::NotFound);
            }

            let previous = get_sampling_metadata(&sampling_metadata_table, height)?;

            let entry = match previous {
                Some(mut previous) => {
                    // Append only CIDs that are not already recorded,
                    // preserving the existing order.
                    for cid in cids {
                        if !previous.cids.contains(&cid) {
                            previous.cids.push(cid);
                        }
                    }

                    previous
                }
                None => SamplingMetadata { cids },
            };

            let serialized = entry.encode_vec();
            sampling_metadata_table.insert(height, &serialized[..])?;

            Ok(())
        })
        .await
    }
390
391    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
392        self.write_tx(move |tx| {
393            let mut ranges_table = tx.open_table(RANGES_TABLE)?;
394            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
395            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
396
397            if !header_ranges.contains(height) {
398                return Err(StoreError::NotFound);
399            }
400
401            sampled_ranges
402                .insert_relaxed(height..=height)
403                .expect("invalid height");
404
405            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
406
407            Ok(())
408        })
409        .await
410    }
411
412    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
413        self.read_tx(move |tx| {
414            let headers_table = tx.open_table(HEADERS_TABLE)?;
415            let sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
416
417            if headers_table.get(height)?.is_none() {
418                return Err(StoreError::NotFound);
419            }
420
421            get_sampling_metadata(&sampling_metadata_table, height)
422        })
423        .await
424    }
425
426    async fn get_stored_ranges(&self) -> Result<BlockRanges> {
427        self.read_tx(|tx| {
428            let table = tx.open_table(RANGES_TABLE)?;
429            get_ranges(&table, HEADER_RANGES_KEY)
430        })
431        .await
432    }
433
434    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
435        self.read_tx(|tx| {
436            let table = tx.open_table(RANGES_TABLE)?;
437            get_ranges(&table, SAMPLED_RANGES_KEY)
438        })
439        .await
440    }
441
442    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
443        self.read_tx(|tx| {
444            let table = tx.open_table(RANGES_TABLE)?;
445            get_ranges(&table, PRUNED_RANGES_KEY)
446        })
447        .await
448    }
449
    /// Remove the header at `height`, its hash-index entry and its sampling
    /// metadata, and record the height as pruned.
    ///
    /// Returns [`StoreError::NotFound`] if the height is not stored, and
    /// [`StoreError::StoredDataError`] if the tables disagree with each other.
    async fn remove_height(&self, height: u64) -> Result<()> {
        self.write_tx(move |tx| {
            let mut heights_table = tx.open_table(HEIGHTS_TABLE)?;
            let mut headers_table = tx.open_table(HEADERS_TABLE)?;
            let mut ranges_table = tx.open_table(RANGES_TABLE)?;
            let mut sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;

            let mut header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
            let mut pruned_ranges = get_ranges(&ranges_table, PRUNED_RANGES_KEY)?;

            if !header_ranges.contains(height) {
                return Err(StoreError::NotFound);
            }

            // The ranges table says the height exists, so a missing header
            // is an internal inconsistency.
            let Some(header) = headers_table.remove(height)? else {
                return Err(StoreError::StoredDataError(format!(
                    "inconsistency between ranges and height_to_hash tables, height {height}"
                )));
            };

            // Decode the removed header to learn its hash, so the
            // hash -> height index entry can be removed too.
            let hash = ExtendedHeader::decode(header.value())
                .map_err(|e| StoreError::StoredDataError(e.to_string()))?
                .hash();

            if heights_table.remove(hash.as_bytes())?.is_none() {
                return Err(StoreError::StoredDataError(format!(
                    "inconsistency between header and height_to_hash tables, hash {hash}"
                )));
            }

            // Sampling metadata may legitimately be absent; ignore the result.
            sampling_metadata_table.remove(height)?;

            // Bookkeeping: the height is no longer stored nor sampled,
            // and is now pruned.
            header_ranges
                .remove_relaxed(height..=height)
                .expect("valid range never fails");
            sampled_ranges
                .remove_relaxed(height..=height)
                .expect("valid range never fails");
            pruned_ranges
                .insert_relaxed(height..=height)
                .expect("valid range never fails");

            set_ranges(&mut ranges_table, HEADER_RANGES_KEY, &header_ranges)?;
            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
            set_ranges(&mut ranges_table, PRUNED_RANGES_KEY, &pruned_ranges)?;

            Ok(())
        })
        .await
    }
501
502    async fn get_identity(&self) -> Result<Keypair> {
503        self.read_tx(move |tx| {
504            let identity_table = tx.open_table(LIBP2P_IDENTITY_TABLE)?;
505
506            let (_, key_bytes) = identity_table
507                .first()?
508                .expect("identity_table should be non empty");
509
510            Ok(Keypair::from_protobuf_encoding(key_bytes.value())?)
511        })
512        .await
513    }
514}
515
516#[async_trait]
517impl Store for RedbStore {
518    async fn get_head(&self) -> Result<ExtendedHeader> {
519        self.get_head().await
520    }
521
522    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
523        self.get_by_hash(hash).await
524    }
525
526    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
527        self.get_by_height(height).await
528    }
529
530    async fn wait_new_head(&self) -> u64 {
531        let head = self.head_height().await.unwrap_or(0);
532        let mut notifier = pin!(self.inner.header_added_notifier.notified());
533
534        loop {
535            let new_head = self.head_height().await.unwrap_or(0);
536
537            if head != new_head {
538                return new_head;
539            }
540
541            // Await for a notification
542            notifier.as_mut().await;
543
544            // Reset notifier
545            notifier.set(self.inner.header_added_notifier.notified());
546        }
547    }
548
549    async fn wait_height(&self, height: u64) -> Result<()> {
550        let mut notifier = pin!(self.inner.header_added_notifier.notified());
551
552        loop {
553            if self.contains_height(height).await {
554                return Ok(());
555            }
556
557            // Await for a notification
558            notifier.as_mut().await;
559
560            // Reset notifier
561            notifier.set(self.inner.header_added_notifier.notified());
562        }
563    }
564
565    async fn head_height(&self) -> Result<u64> {
566        self.head_height().await
567    }
568
569    async fn has(&self, hash: &Hash) -> bool {
570        self.contains_hash(hash).await
571    }
572
573    async fn has_at(&self, height: u64) -> bool {
574        self.contains_height(height).await
575    }
576
577    async fn insert<R>(&self, headers: R) -> Result<()>
578    where
579        R: TryInto<VerifiedExtendedHeaders> + Send,
580        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
581    {
582        self.insert(headers).await
583    }
584
585    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
586        self.update_sampling_metadata(height, cids).await
587    }
588
589    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
590        self.mark_as_sampled(height).await
591    }
592
593    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
594        self.get_sampling_metadata(height).await
595    }
596
597    async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
598        Ok(self.get_stored_ranges().await?)
599    }
600
601    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
602        self.get_sampled_ranges().await
603    }
604
605    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
606        self.get_pruned_ranges().await
607    }
608
609    async fn remove_height(&self, height: u64) -> Result<()> {
610        self.remove_height(height).await
611    }
612
613    async fn get_identity(&self) -> Result<Keypair> {
614        self.get_identity().await
615    }
616
617    async fn close(mut self) -> Result<()> {
618        // Wait all ongoing `spawn_blocking` tasks to finish.
619        self.task_counter.wait_guards().await;
620        Ok(())
621    }
622}
623
624fn verify_against_neighbours<R>(
625    headers_table: &R,
626    lowest_header: Option<&ExtendedHeader>,
627    highest_header: Option<&ExtendedHeader>,
628) -> Result<()>
629where
630    R: ReadableTable<u64, &'static [u8]>,
631{
632    if let Some(lowest_header) = lowest_header {
633        let prev = get_header(headers_table, lowest_header.height().value() - 1).map_err(|e| {
634            if let StoreError::NotFound = e {
635                StoreError::StoredDataError("inconsistency between headers and ranges table".into())
636            } else {
637                e
638            }
639        })?;
640
641        prev.verify(lowest_header)
642            .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
643    }
644
645    if let Some(highest_header) = highest_header {
646        let next = get_header(headers_table, highest_header.height().value() + 1).map_err(|e| {
647            if let StoreError::NotFound = e {
648                StoreError::StoredDataError("inconsistency between headers and ranges table".into())
649            } else {
650                e
651            }
652        })?;
653
654        highest_header
655            .verify(&next)
656            .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
657    }
658
659    Ok(())
660}
661
662fn get_ranges<R>(ranges_table: &R, name: &str) -> Result<BlockRanges>
663where
664    R: ReadableTable<&'static str, Vec<(u64, u64)>>,
665{
666    let raw_ranges = ranges_table
667        .get(name)?
668        .map(|guard| {
669            guard
670                .value()
671                .iter()
672                .map(|(start, end)| *start..=*end)
673                .collect()
674        })
675        .unwrap_or_default();
676
677    BlockRanges::from_vec(raw_ranges).map_err(|e| {
678        let s = format!("Stored BlockRanges for {name} are invalid: {e}");
679        StoreError::StoredDataError(s)
680    })
681}
682
683fn set_ranges(
684    ranges_table: &mut Table<&str, Vec<(u64, u64)>>,
685    name: &str,
686    ranges: &BlockRanges,
687) -> Result<()> {
688    let raw_ranges: &[RangeInclusive<u64>] = ranges.as_ref();
689    let raw_ranges = raw_ranges
690        .iter()
691        .map(|range| (*range.start(), *range.end()))
692        .collect::<Vec<_>>();
693
694    ranges_table.insert(name, raw_ranges)?;
695
696    Ok(())
697}
698
699#[inline]
700fn get_height<R>(heights_table: &R, key: &[u8]) -> Result<u64>
701where
702    R: ReadableTable<&'static [u8], u64>,
703{
704    heights_table
705        .get(key)?
706        .map(|guard| guard.value())
707        .ok_or(StoreError::NotFound)
708}
709
710#[inline]
711fn get_header<R>(headers_table: &R, key: u64) -> Result<ExtendedHeader>
712where
713    R: ReadableTable<u64, &'static [u8]>,
714{
715    let serialized = headers_table.get(key)?.ok_or(StoreError::NotFound)?;
716    deserialize_extended_header(serialized.value())
717}
718
719#[inline]
720fn get_sampling_metadata<R>(
721    sampling_metadata_table: &R,
722    key: u64,
723) -> Result<Option<SamplingMetadata>>
724where
725    R: ReadableTable<u64, &'static [u8]>,
726{
727    sampling_metadata_table
728        .get(key)?
729        .map(|guard| deserialize_sampling_metadata(guard.value()))
730        .transpose()
731}
732
// Transaction-level failures are fatal; the one recoverable variant cannot
// occur because this store never calls `ReadTransaction::close`.
impl From<TransactionError> for StoreError {
    fn from(e: TransactionError) -> Self {
        match e {
            TransactionError::ReadTransactionStillInUse(_) => {
                unreachable!("redb::ReadTransaction::close is never used")
            }
            e => StoreError::FatalDatabaseError(format!("TransactionError: {e}")),
        }
    }
}
743
// Table-open failures indicate programming errors (all tables are created in
// `RedbStore::new`), so the "impossible" variants panic instead of erroring.
impl From<TableError> for StoreError {
    fn from(e: TableError) -> Self {
        match e {
            TableError::Storage(e) => e.into(),
            TableError::TableAlreadyOpen(table, location) => {
                panic!("Table {table} already opened from: {location}")
            }
            TableError::TableDoesNotExist(table) => {
                panic!("Table {table} was not created on initialization")
            }
            e => StoreError::StoredDataError(format!("TableError: {e}")),
        }
    }
}
758
// Storage failures are fatal; `ValueTooLarge` cannot occur because this
// store never calls `Table::insert_reserve`.
impl From<StorageError> for StoreError {
    fn from(e: StorageError) -> Self {
        match e {
            StorageError::ValueTooLarge(_) => {
                unreachable!("redb::Table::insert_reserve is never used")
            }
            e => StoreError::FatalDatabaseError(format!("StorageError: {e}")),
        }
    }
}
769
// A failed commit leaves the database in an unknown state, so it is fatal.
impl From<CommitError> for StoreError {
    fn from(e: CommitError) -> Self {
        StoreError::FatalDatabaseError(format!("CommitError: {e}"))
    }
}
775
/// Migrate schema v1 -> v2: move the header height ranges out of the
/// dedicated `STORE.HEIGHT_RANGES` table into the generic `RANGES_TABLE`
/// and drop the old table.
///
/// No-op when the database is already at v2 or later.
fn migrate_v1_to_v2(
    tx: &WriteTransaction,
    schema_version_table: &mut Table<(), u64>,
) -> Result<()> {
    // v1-only table, defined locally so it cannot be used elsewhere.
    const HEADER_HEIGHT_RANGES: TableDefinition<'static, u64, (u64, u64)> =
        TableDefinition::new("STORE.HEIGHT_RANGES");

    let version = schema_version_table
        .get(())?
        .map(|guard| guard.value())
        .expect("migrations never run on new db");

    if version >= 2 {
        // Nothing to migrate.
        return Ok(());
    }

    debug_assert_eq!(version, 1);
    warn!("Migrating DB schema from v1 to v2");

    let header_ranges_table = tx.open_table(HEADER_HEIGHT_RANGES)?;
    let mut ranges_table = tx.open_table(RANGES_TABLE)?;

    // Collect all `(start, end)` pairs from the old table.
    let raw_ranges = header_ranges_table
        .iter()?
        .map(|range_guard| {
            let range = range_guard?.1.value();
            Ok((range.0, range.1))
        })
        .collect::<Result<Vec<_>>>()?;

    tx.delete_table(header_ranges_table)?;
    ranges_table.insert(HEADER_RANGES_KEY, raw_ranges)?;

    // Migrated to v2
    schema_version_table.insert((), 2)?;

    Ok(())
}
815
/// Migrate schema v2 -> v3: rename the sampled-ranges key in `RANGES_TABLE`.
///
/// No-op when the database is already at v3 or later.
fn migrate_v2_to_v3(
    tx: &WriteTransaction,
    schema_version_table: &mut Table<(), u64>,
) -> Result<()> {
    let version = schema_version_table
        .get(())?
        .map(|guard| guard.value())
        .expect("migrations never run on new db");

    if version >= 3 {
        // Nothing to migrate.
        return Ok(());
    }

    debug_assert_eq!(version, 2);
    warn!("Migrating DB schema from v2 to v3");

    // There are two changes in v3:
    //
    // * Removal of `SamplingStatus` in `SamplingMetadata`
    // * Rename of sampled ranges database key
    //
    // For the first one we don't need to take any actions because it will
    // be ignored by the deserializer.
    let mut ranges_table = tx.open_table(RANGES_TABLE)?;
    let sampled_ranges = get_ranges(&ranges_table, v2::SAMPLED_RANGES_KEY)?;
    set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
    ranges_table.remove(v2::SAMPLED_RANGES_KEY)?;

    // Migrated to v3
    schema_version_table.insert((), 3)?;

    Ok(())
}
850
/// Database keys used by schema v2, kept only for migrations.
mod v2 {
    // NOTE: the misspelling ("SAMPING") is intentional — this is the key that
    // v2 databases actually contain on disk, so it must not be "fixed".
    pub(super) const SAMPLED_RANGES_KEY: &str = "KEY.ACCEPTED_SAMPING_RANGES";
}
854
855#[cfg(test)]
856pub mod tests {
857    use super::*;
858    use crate::test_utils::ExtendedHeaderGeneratorExt;
859    use celestia_types::test_utils::ExtendedHeaderGenerator;
860    use std::path::Path;
861    use tempfile::TempDir;
862
    // Headers inserted into an on-disk store must survive reopening the
    // database, including across multiple open/insert/close cycles.
    #[tokio::test]
    async fn test_store_persistence() {
        let db_dir = TempDir::with_prefix("lumina.store.test").unwrap();
        let db = db_dir.path().join("db");

        let (original_store, mut generator) = gen_filled_store(0, Some(&db)).await;
        let mut original_headers = generator.next_many(20);

        original_store
            .insert(original_headers.clone())
            .await
            .expect("inserting test data failed");
        drop(original_store);

        // Reopen and verify head height and every header round-tripped.
        let reopened_store = create_store(Some(&db)).await;

        assert_eq!(
            original_headers.last().unwrap().height().value(),
            reopened_store.head_height().await.unwrap()
        );
        for original_header in &original_headers {
            let stored_header = reopened_store
                .get_by_height(original_header.height().value())
                .await
                .unwrap();
            assert_eq!(original_header, &stored_header);
        }

        // Inserting into a reopened store must also persist.
        let mut new_headers = generator.next_many(10);
        reopened_store
            .insert(new_headers.clone())
            .await
            .expect("failed to insert data");
        drop(reopened_store);

        original_headers.append(&mut new_headers);

        let reopened_store = create_store(Some(&db)).await;
        assert_eq!(
            original_headers.last().unwrap().height().value(),
            reopened_store.head_height().await.unwrap()
        );
        for original_header in &original_headers {
            let stored_header = reopened_store
                .get_by_height(original_header.height().value())
                .await
                .unwrap();
            assert_eq!(original_header, &stored_header);
        }
    }
913
    // Two independent stores must not share state: after feeding them the
    // same prefix and then diverging (forked) chains, shared heights agree
    // while diverged heights differ.
    #[tokio::test]
    async fn test_separate_stores() {
        let (store0, mut generator0) = gen_filled_store(0, None).await;
        let store1 = create_store(None).await;

        // Same first 10 headers in both stores.
        let headers = generator0.next_many(10);
        store0.insert(headers.clone()).await.unwrap();
        store1.insert(headers).await.unwrap();

        // Fork the chain: each store continues on a different branch.
        let mut generator1 = generator0.fork();

        store0
            .insert(generator0.next_many_verified(5))
            .await
            .unwrap();
        store1
            .insert(generator1.next_many_verified(6))
            .await
            .unwrap();

        // Height 10 is from the shared prefix, height 11 is post-fork.
        assert_eq!(
            store0.get_by_height(10).await.unwrap(),
            store1.get_by_height(10).await.unwrap()
        );
        assert_ne!(
            store0.get_by_height(11).await.unwrap(),
            store1.get_by_height(11).await.unwrap()
        );

        assert_eq!(store0.head_height().await.unwrap(), 15);
        assert_eq!(store1.head_height().await.unwrap(), 16);
    }
946
947    #[tokio::test]
948    async fn test_identity_persistance() {
949        let db_dir = TempDir::with_prefix("lumina.store.test").unwrap();
950        let db = db_dir.path().join("db");
951
952        let original_store = create_store(Some(&db)).await;
953        let original_identity = original_store.get_identity().await.unwrap().public();
954        drop(original_store);
955
956        let reopened_store = create_store(Some(&db)).await;
957        let reopened_identity = reopened_store.get_identity().await.unwrap().public();
958
959        assert_eq!(original_identity, reopened_identity);
960    }
961
962    #[tokio::test]
963    async fn migration_from_v2() {
964        const SCHEMA_VERSION_TABLE: TableDefinition<'static, (), u64> =
965            TableDefinition::new("STORE.SCHEMA_VERSION");
966        const RANGES_TABLE: TableDefinition<'static, &str, Vec<(u64, u64)>> =
967            TableDefinition::new("STORE.RANGES");
968
969        let db = Database::builder()
970            .create_with_backend(redb::backends::InMemoryBackend::new())
971            .unwrap();
972        let db = Arc::new(db);
973
974        // Prepare a v2 db
975        tokio::task::spawn_blocking({
976            let db = db.clone();
977            move || {
978                let tx = db.begin_write().unwrap();
979
980                {
981                    let mut schema_version_table = tx.open_table(SCHEMA_VERSION_TABLE).unwrap();
982                    schema_version_table.insert((), 2).unwrap();
983
984                    let mut ranges_table = tx.open_table(RANGES_TABLE).unwrap();
985                    ranges_table
986                        .insert(v2::SAMPLED_RANGES_KEY, vec![(123, 124)])
987                        .unwrap();
988                }
989
990                tx.commit().unwrap();
991            }
992        })
993        .await
994        .unwrap();
995
996        // Migrate and check store
997        let store = RedbStore::new(db.clone()).await.unwrap();
998        let ranges = store.get_sampled_ranges().await.unwrap();
999        assert_eq!(ranges, BlockRanges::try_from([123..=124]).unwrap());
1000        store.close().await.unwrap();
1001
1002        // Check that old ranges were deleted
1003        tokio::task::spawn_blocking({
1004            let db = db.clone();
1005            move || {
1006                let tx = db.begin_read().unwrap();
1007                let ranges_table = tx.open_table(RANGES_TABLE).unwrap();
1008                assert!(ranges_table.get(v2::SAMPLED_RANGES_KEY).unwrap().is_none());
1009            }
1010        })
1011        .await
1012        .unwrap();
1013    }
1014
1015    pub async fn create_store(path: Option<&Path>) -> RedbStore {
1016        match path {
1017            Some(path) => RedbStore::open(path).await.unwrap(),
1018            None => RedbStore::in_memory().await.unwrap(),
1019        }
1020    }
1021
1022    pub async fn gen_filled_store(
1023        amount: u64,
1024        path: Option<&Path>,
1025    ) -> (RedbStore, ExtendedHeaderGenerator) {
1026        let s = create_store(path).await;
1027        let mut generator = ExtendedHeaderGenerator::new();
1028        let headers = generator.next_many(amount);
1029
1030        s.insert(headers).await.expect("inserting test data failed");
1031
1032        (s, generator)
1033    }
1034}