lumina_node/store/
redb_store.rs

1use std::fmt::Display;
2use std::ops::RangeInclusive;
3use std::path::Path;
4use std::pin::pin;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use celestia_types::ExtendedHeader;
9use celestia_types::hash::Hash;
10use cid::Cid;
11use libp2p::identity::Keypair;
12use redb::{
13    CommitError, Database, ReadTransaction, ReadableTable, ReadableTableMetadata, StorageError,
14    Table, TableDefinition, TableError, TransactionError, WriteTransaction,
15};
16use tendermint_proto::Protobuf;
17use tokio::sync::Notify;
18use tokio::task::spawn_blocking;
19use tracing::{debug, trace, warn};
20
21use crate::block_ranges::BlockRanges;
22use crate::store::utils::VerifiedExtendedHeaders;
23use crate::store::{Result, SamplingMetadata, Store, StoreError, StoreInsertionError};
24use crate::utils::Counter;
25
26use super::utils::{deserialize_extended_header, deserialize_sampling_metadata};
27
/// Current on-disk schema version. Databases with an older version are
/// migrated when opened; databases with a newer version are rejected.
const SCHEMA_VERSION: u64 = 3;

/// Maps a header hash (raw bytes) to the height of that header.
const HEIGHTS_TABLE: TableDefinition<'static, &[u8], u64> = TableDefinition::new("STORE.HEIGHTS");
/// Maps a height to the protobuf-encoded `ExtendedHeader` stored at it.
const HEADERS_TABLE: TableDefinition<'static, u64, &[u8]> = TableDefinition::new("STORE.HEADERS");
/// Maps a height to its encoded `SamplingMetadata`.
const SAMPLING_METADATA_TABLE: TableDefinition<'static, u64, &[u8]> =
    TableDefinition::new("STORE.SAMPLING_METADATA");
/// Single-row table (unit key) holding the schema version of the database.
const SCHEMA_VERSION_TABLE: TableDefinition<'static, (), u64> =
    TableDefinition::new("STORE.SCHEMA_VERSION");
/// Maps one of the `*_RANGES_KEY` constants below to a list of inclusive
/// `(start, end)` height ranges.
const RANGES_TABLE: TableDefinition<'static, &str, Vec<(u64, u64)>> =
    TableDefinition::new("STORE.RANGES");
/// Single-row table (unit key) holding the node's libp2p keypair in protobuf encoding.
const LIBP2P_IDENTITY_TABLE: TableDefinition<'static, (), &[u8]> =
    TableDefinition::new("LIBP2P.IDENTITY");

// Keys into `RANGES_TABLE`. These strings are persisted, so changing any of
// them requires a schema migration.

/// Ranges of heights that have been marked as sampled.
const SAMPLED_RANGES_KEY: &str = "KEY.SAMPLED_RANGES";
/// Ranges of heights whose headers are currently stored.
const HEADER_RANGES_KEY: &str = "KEY.HEADER_RANGES";
/// Ranges of heights whose headers were removed via `remove_height`.
const PRUNED_RANGES_KEY: &str = "KEY.PRUNED_RANGES";
44
/// A [`Store`] implementation based on a [`redb`] database.
#[derive(Debug)]
pub struct RedbStore {
    /// Shared state, cloned into every blocking database task.
    inner: Arc<Inner>,
    /// Tracks in-flight `spawn_blocking` database tasks so `close` can wait for them.
    task_counter: Counter,
}
51
/// State shared between a [`RedbStore`] handle and its blocking database tasks.
#[derive(Debug)]
struct Inner {
    /// Reference to the entire redb database
    db: Arc<Database>,
    /// Notified (via `notify_waiters`) after every successful header insertion;
    /// `wait_new_head` and `wait_height` wait on it.
    header_added_notifier: Notify,
}
59
60impl RedbStore {
61    /// Open a persistent [`redb`] store.
62    pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
63        let path = path.as_ref().to_owned();
64
65        let db = spawn_blocking(|| Database::create(path))
66            .await?
67            .map_err(|e| StoreError::OpenFailed(e.to_string()))?;
68
69        RedbStore::new(Arc::new(db)).await
70    }
71
72    /// Open an in-memory [`redb`] store.
73    pub async fn in_memory() -> Result<Self> {
74        let db = Database::builder()
75            .create_with_backend(redb::backends::InMemoryBackend::new())
76            .map_err(|e| StoreError::OpenFailed(e.to_string()))?;
77
78        RedbStore::new(Arc::new(db)).await
79    }
80
81    /// Create new `RedbStore` with an already opened [`redb::Database`].
82    pub async fn new(db: Arc<Database>) -> Result<Self> {
83        let store = RedbStore {
84            inner: Arc::new(Inner {
85                db,
86                header_added_notifier: Notify::new(),
87            }),
88            task_counter: Counter::new(),
89        };
90
91        store
92            .write_tx(|tx| {
93                let mut schema_version_table = tx.open_table(SCHEMA_VERSION_TABLE)?;
94                let schema_version = schema_version_table.get(())?.map(|guard| guard.value());
95
96                match schema_version {
97                    Some(schema_version) => {
98                        if schema_version > SCHEMA_VERSION {
99                            let e = format!(
100                                "Incompatible database schema; found {schema_version}, expected {SCHEMA_VERSION}."
101                            );
102                            return Err(StoreError::OpenFailed(e));
103                        }
104
105                        // Do migrations
106                        migrate_v1_to_v2(tx, &mut schema_version_table)?;
107                        migrate_v2_to_v3(tx, &mut schema_version_table)?;
108                    }
109                    None => {
110                        // New database
111                        schema_version_table.insert((), SCHEMA_VERSION)?;
112                    }
113                }
114
115                // Force us to write migrations!
116                debug_assert_eq!(
117                    schema_version_table.get(())?.map(|guard| guard.value()),
118                    Some(SCHEMA_VERSION),
119                    "Some migrations are missing"
120                );
121
122                // create tables, so that reads later don't complain
123                let _heights_table = tx.open_table(HEIGHTS_TABLE)?;
124                let _headers_table = tx.open_table(HEADERS_TABLE)?;
125                let _ranges_table = tx.open_table(RANGES_TABLE)?;
126                let _sampling_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
127                let mut identity_table = tx.open_table(LIBP2P_IDENTITY_TABLE)?;
128
129                if identity_table.is_empty()? {
130                    let keypair = Keypair::generate_ed25519();
131
132                    let peer_id = keypair.public().to_peer_id();
133                    let keypair_bytes = keypair.to_protobuf_encoding()?;
134                    debug!("Initialised new identity: {peer_id}");
135                    identity_table.insert((), &*keypair_bytes)?;
136                }
137
138                Ok(())
139            })
140            .await
141            .map_err(|e| match e {
142                e @ StoreError::OpenFailed(_) => e,
143                e => StoreError::OpenFailed(e.to_string()),
144            })?;
145
146        Ok(store)
147    }
148
149    /// Returns the raw [`redb::Database`].
150    ///
151    /// This is useful if you want to pass the database handle to any other
152    /// stores (e.g. [`blockstore`]).
153    pub fn raw_db(&self) -> Arc<Database> {
154        self.inner.db.clone()
155    }
156
157    /// Execute a read transaction.
158    async fn read_tx<F, T>(&self, f: F) -> Result<T>
159    where
160        F: FnOnce(&mut ReadTransaction) -> Result<T> + Send + 'static,
161        T: Send + 'static,
162    {
163        let inner = self.inner.clone();
164        let guard = self.task_counter.guard();
165
166        spawn_blocking(move || {
167            let _guard = guard;
168
169            {
170                let mut tx = inner.db.begin_read()?;
171                f(&mut tx)
172            }
173        })
174        .await?
175    }
176
177    /// Execute a write transaction.
178    ///
179    /// If closure returns an error the transaction is aborted, otherwise committed.
180    async fn write_tx<F, T>(&self, f: F) -> Result<T>
181    where
182        F: FnOnce(&mut WriteTransaction) -> Result<T> + Send + 'static,
183        T: Send + 'static,
184    {
185        let inner = self.inner.clone();
186        let guard = self.task_counter.guard();
187
188        spawn_blocking(move || {
189            let _guard = guard;
190
191            {
192                let mut tx = inner.db.begin_write()?;
193                let res = f(&mut tx);
194
195                if res.is_ok() {
196                    tx.commit()?;
197                } else {
198                    tx.abort()?;
199                }
200
201                res
202            }
203        })
204        .await?
205    }
206
207    async fn head_height(&self) -> Result<u64> {
208        self.read_tx(|tx| {
209            let table = tx.open_table(RANGES_TABLE)?;
210            let header_ranges = get_ranges(&table, HEADER_RANGES_KEY)?;
211
212            header_ranges.head().ok_or(StoreError::NotFound)
213        })
214        .await
215    }
216
217    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
218        let hash = *hash;
219
220        self.read_tx(move |tx| {
221            let heights_table = tx.open_table(HEIGHTS_TABLE)?;
222            let headers_table = tx.open_table(HEADERS_TABLE)?;
223
224            let height = get_height(&heights_table, hash.as_bytes())?;
225            get_header(&headers_table, height)
226        })
227        .await
228    }
229
230    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
231        self.read_tx(move |tx| {
232            let table = tx.open_table(HEADERS_TABLE)?;
233            get_header(&table, height)
234        })
235        .await
236    }
237
238    async fn get_head(&self) -> Result<ExtendedHeader> {
239        self.read_tx(|tx| {
240            let ranges_table = tx.open_table(RANGES_TABLE)?;
241            let headers_table = tx.open_table(HEADERS_TABLE)?;
242
243            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
244            let head = header_ranges.head().ok_or(StoreError::NotFound)?;
245
246            get_header(&headers_table, head)
247        })
248        .await
249    }
250
251    async fn contains_hash(&self, hash: &Hash) -> bool {
252        let hash = *hash;
253
254        self.read_tx(move |tx| {
255            let heights_table = tx.open_table(HEIGHTS_TABLE)?;
256            let headers_table = tx.open_table(HEADERS_TABLE)?;
257
258            let height = get_height(&heights_table, hash.as_bytes())?;
259            Ok(headers_table.get(height)?.is_some())
260        })
261        .await
262        .unwrap_or(false)
263    }
264
265    async fn contains_height(&self, height: u64) -> bool {
266        self.read_tx(move |tx| {
267            let headers_table = tx.open_table(HEADERS_TABLE)?;
268            Ok(headers_table.get(height)?.is_some())
269        })
270        .await
271        .unwrap_or(false)
272    }
273
274    async fn insert<R>(&self, headers: R) -> Result<()>
275    where
276        R: TryInto<VerifiedExtendedHeaders> + Send,
277        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
278    {
279        let headers = headers
280            .try_into()
281            .map_err(|e| StoreInsertionError::HeadersVerificationFailed(e.to_string()))?;
282
283        self.write_tx(move |tx| {
284            let (Some(head), Some(tail)) = (headers.as_ref().first(), headers.as_ref().last())
285            else {
286                return Ok(());
287            };
288            let mut heights_table = tx.open_table(HEIGHTS_TABLE)?;
289            let mut headers_table = tx.open_table(HEADERS_TABLE)?;
290            let mut ranges_table = tx.open_table(RANGES_TABLE)?;
291
292            let mut header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
293            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
294            let mut pruned_ranges = get_ranges(&ranges_table, PRUNED_RANGES_KEY)?;
295
296            let headers_range = head.height()..=tail.height();
297
298            let (prev_exists, next_exists) = header_ranges
299                .check_insertion_constraints(&headers_range)
300                .map_err(StoreInsertionError::ConstraintsNotMet)?;
301
302            verify_against_neighbours(
303                &headers_table,
304                prev_exists.then_some(head),
305                next_exists.then_some(tail),
306            )?;
307
308            for header in headers {
309                let height = header.height();
310                let hash = header.hash();
311                let serialized_header = header.encode_vec();
312
313                if headers_table
314                    .insert(height, &serialized_header[..])?
315                    .is_some()
316                {
317                    return Err(StoreError::StoredDataError(
318                        "inconsistency between headers table and ranges table".into(),
319                    ));
320                }
321
322                if heights_table.insert(hash.as_bytes(), height)?.is_some() {
323                    // TODO: Replace this with `StoredDataError` when we implement
324                    // type-safe validation on insertion.
325                    return Err(StoreInsertionError::HashExists(hash).into());
326                }
327
328                trace!("Inserted header {hash} with height {height}");
329            }
330
331            header_ranges
332                .insert_relaxed(&headers_range)
333                .expect("invalid range");
334            sampled_ranges
335                .remove_relaxed(&headers_range)
336                .expect("invalid range");
337            pruned_ranges
338                .remove_relaxed(&headers_range)
339                .expect("invalid range");
340
341            set_ranges(&mut ranges_table, HEADER_RANGES_KEY, &header_ranges)?;
342            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
343            set_ranges(&mut ranges_table, PRUNED_RANGES_KEY, &pruned_ranges)?;
344
345            debug!("Inserted header range {headers_range:?}",);
346
347            Ok(())
348        })
349        .await?;
350
351        self.inner.header_added_notifier.notify_waiters();
352
353        Ok(())
354    }
355
356    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
357        self.write_tx(move |tx| {
358            let mut sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
359            let ranges_table = tx.open_table(RANGES_TABLE)?;
360
361            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
362
363            if !header_ranges.contains(height) {
364                return Err(StoreError::NotFound);
365            }
366
367            let previous = get_sampling_metadata(&sampling_metadata_table, height)?;
368
369            let entry = match previous {
370                Some(mut previous) => {
371                    for cid in cids {
372                        if !previous.cids.contains(&cid) {
373                            previous.cids.push(cid);
374                        }
375                    }
376
377                    previous
378                }
379                None => SamplingMetadata { cids },
380            };
381
382            let serialized = entry.encode_vec();
383            sampling_metadata_table.insert(height, &serialized[..])?;
384
385            Ok(())
386        })
387        .await
388    }
389
390    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
391        self.write_tx(move |tx| {
392            let mut ranges_table = tx.open_table(RANGES_TABLE)?;
393            let header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
394            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
395
396            if !header_ranges.contains(height) {
397                return Err(StoreError::NotFound);
398            }
399
400            sampled_ranges
401                .insert_relaxed(height..=height)
402                .expect("invalid height");
403
404            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
405
406            Ok(())
407        })
408        .await
409    }
410
411    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
412        self.read_tx(move |tx| {
413            let headers_table = tx.open_table(HEADERS_TABLE)?;
414            let sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
415
416            if headers_table.get(height)?.is_none() {
417                return Err(StoreError::NotFound);
418            }
419
420            get_sampling_metadata(&sampling_metadata_table, height)
421        })
422        .await
423    }
424
425    async fn get_stored_ranges(&self) -> Result<BlockRanges> {
426        self.read_tx(|tx| {
427            let table = tx.open_table(RANGES_TABLE)?;
428            get_ranges(&table, HEADER_RANGES_KEY)
429        })
430        .await
431    }
432
433    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
434        self.read_tx(|tx| {
435            let table = tx.open_table(RANGES_TABLE)?;
436            get_ranges(&table, SAMPLED_RANGES_KEY)
437        })
438        .await
439    }
440
441    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
442        self.read_tx(|tx| {
443            let table = tx.open_table(RANGES_TABLE)?;
444            get_ranges(&table, PRUNED_RANGES_KEY)
445        })
446        .await
447    }
448
449    async fn remove_height(&self, height: u64) -> Result<()> {
450        self.write_tx(move |tx| {
451            let mut heights_table = tx.open_table(HEIGHTS_TABLE)?;
452            let mut headers_table = tx.open_table(HEADERS_TABLE)?;
453            let mut ranges_table = tx.open_table(RANGES_TABLE)?;
454            let mut sampling_metadata_table = tx.open_table(SAMPLING_METADATA_TABLE)?;
455
456            let mut header_ranges = get_ranges(&ranges_table, HEADER_RANGES_KEY)?;
457            let mut sampled_ranges = get_ranges(&ranges_table, SAMPLED_RANGES_KEY)?;
458            let mut pruned_ranges = get_ranges(&ranges_table, PRUNED_RANGES_KEY)?;
459
460            if !header_ranges.contains(height) {
461                return Err(StoreError::NotFound);
462            }
463
464            let Some(header) = headers_table.remove(height)? else {
465                return Err(StoreError::StoredDataError(format!(
466                    "inconsistency between ranges and height_to_hash tables, height {height}"
467                )));
468            };
469
470            let hash = ExtendedHeader::decode(header.value())
471                .map_err(|e| StoreError::StoredDataError(e.to_string()))?
472                .hash();
473
474            if heights_table.remove(hash.as_bytes())?.is_none() {
475                return Err(StoreError::StoredDataError(format!(
476                    "inconsistency between header and height_to_hash tables, hash {hash}"
477                )));
478            }
479
480            sampling_metadata_table.remove(height)?;
481
482            header_ranges
483                .remove_relaxed(height..=height)
484                .expect("valid range never fails");
485            sampled_ranges
486                .remove_relaxed(height..=height)
487                .expect("valid range never fails");
488            pruned_ranges
489                .insert_relaxed(height..=height)
490                .expect("valid range never fails");
491
492            set_ranges(&mut ranges_table, HEADER_RANGES_KEY, &header_ranges)?;
493            set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
494            set_ranges(&mut ranges_table, PRUNED_RANGES_KEY, &pruned_ranges)?;
495
496            Ok(())
497        })
498        .await
499    }
500
501    async fn get_identity(&self) -> Result<Keypair> {
502        self.read_tx(move |tx| {
503            let identity_table = tx.open_table(LIBP2P_IDENTITY_TABLE)?;
504
505            let (_, key_bytes) = identity_table
506                .first()?
507                .expect("identity_table should be non empty");
508
509            Ok(Keypair::from_protobuf_encoding(key_bytes.value())?)
510        })
511        .await
512    }
513}
514
515#[async_trait]
516impl Store for RedbStore {
517    async fn get_head(&self) -> Result<ExtendedHeader> {
518        self.get_head().await
519    }
520
521    async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
522        self.get_by_hash(hash).await
523    }
524
525    async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
526        self.get_by_height(height).await
527    }
528
529    async fn wait_new_head(&self) -> u64 {
530        let head = self.head_height().await.unwrap_or(0);
531        let mut notifier = pin!(self.inner.header_added_notifier.notified());
532
533        loop {
534            let new_head = self.head_height().await.unwrap_or(0);
535
536            if head != new_head {
537                return new_head;
538            }
539
540            // Await for a notification
541            notifier.as_mut().await;
542
543            // Reset notifier
544            notifier.set(self.inner.header_added_notifier.notified());
545        }
546    }
547
548    async fn wait_height(&self, height: u64) -> Result<()> {
549        let mut notifier = pin!(self.inner.header_added_notifier.notified());
550
551        loop {
552            if self.contains_height(height).await {
553                return Ok(());
554            }
555
556            // Await for a notification
557            notifier.as_mut().await;
558
559            // Reset notifier
560            notifier.set(self.inner.header_added_notifier.notified());
561        }
562    }
563
564    async fn head_height(&self) -> Result<u64> {
565        self.head_height().await
566    }
567
568    async fn has(&self, hash: &Hash) -> bool {
569        self.contains_hash(hash).await
570    }
571
572    async fn has_at(&self, height: u64) -> bool {
573        self.contains_height(height).await
574    }
575
576    async fn insert<R>(&self, headers: R) -> Result<()>
577    where
578        R: TryInto<VerifiedExtendedHeaders> + Send,
579        <R as TryInto<VerifiedExtendedHeaders>>::Error: Display,
580    {
581        self.insert(headers).await
582    }
583
584    async fn update_sampling_metadata(&self, height: u64, cids: Vec<Cid>) -> Result<()> {
585        self.update_sampling_metadata(height, cids).await
586    }
587
588    async fn mark_as_sampled(&self, height: u64) -> Result<()> {
589        self.mark_as_sampled(height).await
590    }
591
592    async fn get_sampling_metadata(&self, height: u64) -> Result<Option<SamplingMetadata>> {
593        self.get_sampling_metadata(height).await
594    }
595
596    async fn get_stored_header_ranges(&self) -> Result<BlockRanges> {
597        Ok(self.get_stored_ranges().await?)
598    }
599
600    async fn get_sampled_ranges(&self) -> Result<BlockRanges> {
601        self.get_sampled_ranges().await
602    }
603
604    async fn get_pruned_ranges(&self) -> Result<BlockRanges> {
605        self.get_pruned_ranges().await
606    }
607
608    async fn remove_height(&self, height: u64) -> Result<()> {
609        self.remove_height(height).await
610    }
611
612    async fn get_identity(&self) -> Result<Keypair> {
613        self.get_identity().await
614    }
615
616    async fn close(mut self) -> Result<()> {
617        // Wait all ongoing `spawn_blocking` tasks to finish.
618        self.task_counter.wait_guards().await;
619        Ok(())
620    }
621}
622
623fn verify_against_neighbours<R>(
624    headers_table: &R,
625    lowest_header: Option<&ExtendedHeader>,
626    highest_header: Option<&ExtendedHeader>,
627) -> Result<()>
628where
629    R: ReadableTable<u64, &'static [u8]>,
630{
631    if let Some(lowest_header) = lowest_header {
632        let prev = get_header(headers_table, lowest_header.height() - 1).map_err(|e| {
633            if let StoreError::NotFound = e {
634                StoreError::StoredDataError("inconsistency between headers and ranges table".into())
635            } else {
636                e
637            }
638        })?;
639
640        prev.verify(lowest_header)
641            .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
642    }
643
644    if let Some(highest_header) = highest_header {
645        let next = get_header(headers_table, highest_header.height() + 1).map_err(|e| {
646            if let StoreError::NotFound = e {
647                StoreError::StoredDataError("inconsistency between headers and ranges table".into())
648            } else {
649                e
650            }
651        })?;
652
653        highest_header
654            .verify(&next)
655            .map_err(|e| StoreInsertionError::NeighborsVerificationFailed(e.to_string()))?;
656    }
657
658    Ok(())
659}
660
661fn get_ranges<R>(ranges_table: &R, name: &str) -> Result<BlockRanges>
662where
663    R: ReadableTable<&'static str, Vec<(u64, u64)>>,
664{
665    let raw_ranges = ranges_table
666        .get(name)?
667        .map(|guard| {
668            guard
669                .value()
670                .iter()
671                .map(|(start, end)| *start..=*end)
672                .collect()
673        })
674        .unwrap_or_default();
675
676    BlockRanges::from_vec(raw_ranges).map_err(|e| {
677        let s = format!("Stored BlockRanges for {name} are invalid: {e}");
678        StoreError::StoredDataError(s)
679    })
680}
681
682fn set_ranges(
683    ranges_table: &mut Table<&str, Vec<(u64, u64)>>,
684    name: &str,
685    ranges: &BlockRanges,
686) -> Result<()> {
687    let raw_ranges: &[RangeInclusive<u64>] = ranges.as_ref();
688    let raw_ranges = raw_ranges
689        .iter()
690        .map(|range| (*range.start(), *range.end()))
691        .collect::<Vec<_>>();
692
693    ranges_table.insert(name, raw_ranges)?;
694
695    Ok(())
696}
697
698#[inline]
699fn get_height<R>(heights_table: &R, key: &[u8]) -> Result<u64>
700where
701    R: ReadableTable<&'static [u8], u64>,
702{
703    heights_table
704        .get(key)?
705        .map(|guard| guard.value())
706        .ok_or(StoreError::NotFound)
707}
708
709#[inline]
710fn get_header<R>(headers_table: &R, key: u64) -> Result<ExtendedHeader>
711where
712    R: ReadableTable<u64, &'static [u8]>,
713{
714    let serialized = headers_table.get(key)?.ok_or(StoreError::NotFound)?;
715    deserialize_extended_header(serialized.value())
716}
717
718#[inline]
719fn get_sampling_metadata<R>(
720    sampling_metadata_table: &R,
721    key: u64,
722) -> Result<Option<SamplingMetadata>>
723where
724    R: ReadableTable<u64, &'static [u8]>,
725{
726    sampling_metadata_table
727        .get(key)?
728        .map(|guard| deserialize_sampling_metadata(guard.value()))
729        .transpose()
730}
731
732impl From<TransactionError> for StoreError {
733    fn from(e: TransactionError) -> Self {
734        match e {
735            TransactionError::ReadTransactionStillInUse(_) => {
736                unreachable!("redb::ReadTransaction::close is never used")
737            }
738            e => StoreError::FatalDatabaseError(format!("TransactionError: {e}")),
739        }
740    }
741}
742
743impl From<TableError> for StoreError {
744    fn from(e: TableError) -> Self {
745        match e {
746            TableError::Storage(e) => e.into(),
747            TableError::TableAlreadyOpen(table, location) => {
748                panic!("Table {table} already opened from: {location}")
749            }
750            TableError::TableDoesNotExist(table) => {
751                panic!("Table {table} was not created on initialization")
752            }
753            e => StoreError::StoredDataError(format!("TableError: {e}")),
754        }
755    }
756}
757
758impl From<StorageError> for StoreError {
759    fn from(e: StorageError) -> Self {
760        match e {
761            StorageError::ValueTooLarge(_) => {
762                unreachable!("redb::Table::insert_reserve is never used")
763            }
764            e => StoreError::FatalDatabaseError(format!("StorageError: {e}")),
765        }
766    }
767}
768
769impl From<CommitError> for StoreError {
770    fn from(e: CommitError) -> Self {
771        StoreError::FatalDatabaseError(format!("CommitError: {e}"))
772    }
773}
774
/// Migrates a v1 database to schema v2. This is a no-op if the stored version is already >= 2.
///
/// v1 kept header ranges as rows of a dedicated `STORE.HEIGHT_RANGES` table.
/// v2 stores them as a single list under [`HEADER_RANGES_KEY`] in [`RANGES_TABLE`].
/// The old table is deleted and the schema version is set to 2.
///
/// # Panics
///
/// Panics if no schema version is stored; this must only run on an existing database.
fn migrate_v1_to_v2(
    tx: &WriteTransaction,
    schema_version_table: &mut Table<(), u64>,
) -> Result<()> {
    // v1-only table; defined locally because nothing else uses it.
    const HEADER_HEIGHT_RANGES: TableDefinition<'static, u64, (u64, u64)> =
        TableDefinition::new("STORE.HEIGHT_RANGES");

    let version = schema_version_table
        .get(())?
        .map(|guard| guard.value())
        .expect("migrations never run on new db");

    if version >= 2 {
        // Nothing to migrate.
        return Ok(());
    }

    debug_assert_eq!(version, 1);
    warn!("Migrating DB schema from v1 to v2");

    let header_ranges_table = tx.open_table(HEADER_HEIGHT_RANGES)?;
    let mut ranges_table = tx.open_table(RANGES_TABLE)?;

    // Keep only the (start, end) value of each v1 row; the row keys are discarded.
    let raw_ranges = header_ranges_table
        .iter()?
        .map(|range_guard| {
            let range = range_guard?.1.value();
            Ok((range.0, range.1))
        })
        .collect::<Result<Vec<_>>>()?;

    tx.delete_table(header_ranges_table)?;
    ranges_table.insert(HEADER_RANGES_KEY, raw_ranges)?;

    // Migrated to v2
    schema_version_table.insert((), 2)?;

    Ok(())
}
814
/// Migrates a v2 database to schema v3. This is a no-op if the stored version is already >= 3.
///
/// Moves the sampled ranges from the v2 key ([`v2::SAMPLED_RANGES_KEY`]) to
/// [`SAMPLED_RANGES_KEY`], then sets the schema version to 3.
///
/// # Panics
///
/// Panics if no schema version is stored; this must only run on an existing database.
fn migrate_v2_to_v3(
    tx: &WriteTransaction,
    schema_version_table: &mut Table<(), u64>,
) -> Result<()> {
    let version = schema_version_table
        .get(())?
        .map(|guard| guard.value())
        .expect("migrations never run on new db");

    if version >= 3 {
        // Nothing to migrate.
        return Ok(());
    }

    debug_assert_eq!(version, 2);
    warn!("Migrating DB schema from v2 to v3");

    // There are two changes in v3:
    //
    // * Removal of `SamplingStatus` in `SamplingMetadata`
    // * Rename of sampled ranges database key
    //
    // For the first one we don't need to take any actions because it will
    // be ignored by the deserializer.
    let mut ranges_table = tx.open_table(RANGES_TABLE)?;
    // A missing v2 key reads as empty ranges, so empty ranges get written under the new key.
    let sampled_ranges = get_ranges(&ranges_table, v2::SAMPLED_RANGES_KEY)?;
    set_ranges(&mut ranges_table, SAMPLED_RANGES_KEY, &sampled_ranges)?;
    ranges_table.remove(v2::SAMPLED_RANGES_KEY)?;

    // Migrated to v3
    schema_version_table.insert((), 3)?;

    Ok(())
}
849
/// Schema v2 identifiers, kept only so migrations can read old databases.
mod v2 {
    /// v2 key of the sampled ranges in `RANGES_TABLE`. The misspelling
    /// ("SAMPING") is the key persisted by v2 databases and must not be corrected.
    pub(super) const SAMPLED_RANGES_KEY: &str = "KEY.ACCEPTED_SAMPING_RANGES";
}
853
854#[cfg(test)]
855pub mod tests {
856    use super::*;
857    use crate::test_utils::ExtendedHeaderGeneratorExt;
858    use celestia_types::test_utils::ExtendedHeaderGenerator;
859    use std::path::Path;
860    use tempfile::TempDir;
861
862    #[tokio::test]
863    async fn test_store_persistence() {
864        let db_dir = TempDir::with_prefix("lumina.store.test").unwrap();
865        let db = db_dir.path().join("db");
866
867        let (original_store, mut generator) = gen_filled_store(0, Some(&db)).await;
868        let mut original_headers = generator.next_many(20);
869
870        original_store
871            .insert(original_headers.clone())
872            .await
873            .expect("inserting test data failed");
874        drop(original_store);
875
876        let reopened_store = create_store(Some(&db)).await;
877
878        assert_eq!(
879            original_headers.last().unwrap().height(),
880            reopened_store.head_height().await.unwrap()
881        );
882        for original_header in &original_headers {
883            let stored_header = reopened_store
884                .get_by_height(original_header.height())
885                .await
886                .unwrap();
887            assert_eq!(original_header, &stored_header);
888        }
889
890        let mut new_headers = generator.next_many(10);
891        reopened_store
892            .insert(new_headers.clone())
893            .await
894            .expect("failed to insert data");
895        drop(reopened_store);
896
897        original_headers.append(&mut new_headers);
898
899        let reopened_store = create_store(Some(&db)).await;
900        assert_eq!(
901            original_headers.last().unwrap().height(),
902            reopened_store.head_height().await.unwrap()
903        );
904        for original_header in &original_headers {
905            let stored_header = reopened_store
906                .get_by_height(original_header.height())
907                .await
908                .unwrap();
909            assert_eq!(original_header, &stored_header);
910        }
911    }
912
913    #[tokio::test]
914    async fn test_separate_stores() {
915        let (store0, mut generator0) = gen_filled_store(0, None).await;
916        let store1 = create_store(None).await;
917
918        let headers = generator0.next_many(10);
919        store0.insert(headers.clone()).await.unwrap();
920        store1.insert(headers).await.unwrap();
921
922        let mut generator1 = generator0.fork();
923
924        store0
925            .insert(generator0.next_many_verified(5))
926            .await
927            .unwrap();
928        store1
929            .insert(generator1.next_many_verified(6))
930            .await
931            .unwrap();
932
933        assert_eq!(
934            store0.get_by_height(10).await.unwrap(),
935            store1.get_by_height(10).await.unwrap()
936        );
937        assert_ne!(
938            store0.get_by_height(11).await.unwrap(),
939            store1.get_by_height(11).await.unwrap()
940        );
941
942        assert_eq!(store0.head_height().await.unwrap(), 15);
943        assert_eq!(store1.head_height().await.unwrap(), 16);
944    }
945
946    #[tokio::test]
947    async fn test_identity_persistance() {
948        let db_dir = TempDir::with_prefix("lumina.store.test").unwrap();
949        let db = db_dir.path().join("db");
950
951        let original_store = create_store(Some(&db)).await;
952        let original_identity = original_store.get_identity().await.unwrap().public();
953        drop(original_store);
954
955        let reopened_store = create_store(Some(&db)).await;
956        let reopened_identity = reopened_store.get_identity().await.unwrap().public();
957
958        assert_eq!(original_identity, reopened_identity);
959    }
960
961    #[tokio::test]
962    async fn migration_from_v2() {
963        const SCHEMA_VERSION_TABLE: TableDefinition<'static, (), u64> =
964            TableDefinition::new("STORE.SCHEMA_VERSION");
965        const RANGES_TABLE: TableDefinition<'static, &str, Vec<(u64, u64)>> =
966            TableDefinition::new("STORE.RANGES");
967
968        let db = Database::builder()
969            .create_with_backend(redb::backends::InMemoryBackend::new())
970            .unwrap();
971        let db = Arc::new(db);
972
973        // Prepare a v2 db
974        tokio::task::spawn_blocking({
975            let db = db.clone();
976            move || {
977                let tx = db.begin_write().unwrap();
978
979                {
980                    let mut schema_version_table = tx.open_table(SCHEMA_VERSION_TABLE).unwrap();
981                    schema_version_table.insert((), 2).unwrap();
982
983                    let mut ranges_table = tx.open_table(RANGES_TABLE).unwrap();
984                    ranges_table
985                        .insert(v2::SAMPLED_RANGES_KEY, vec![(123, 124)])
986                        .unwrap();
987                }
988
989                tx.commit().unwrap();
990            }
991        })
992        .await
993        .unwrap();
994
995        // Migrate and check store
996        let store = RedbStore::new(db.clone()).await.unwrap();
997        let ranges = store.get_sampled_ranges().await.unwrap();
998        assert_eq!(ranges, BlockRanges::try_from([123..=124]).unwrap());
999        store.close().await.unwrap();
1000
1001        // Check that old ranges were deleted
1002        tokio::task::spawn_blocking({
1003            let db = db.clone();
1004            move || {
1005                let tx = db.begin_read().unwrap();
1006                let ranges_table = tx.open_table(RANGES_TABLE).unwrap();
1007                assert!(ranges_table.get(v2::SAMPLED_RANGES_KEY).unwrap().is_none());
1008            }
1009        })
1010        .await
1011        .unwrap();
1012    }
1013
1014    pub async fn create_store(path: Option<&Path>) -> RedbStore {
1015        match path {
1016            Some(path) => RedbStore::open(path).await.unwrap(),
1017            None => RedbStore::in_memory().await.unwrap(),
1018        }
1019    }
1020
1021    pub async fn gen_filled_store(
1022        amount: u64,
1023        path: Option<&Path>,
1024    ) -> (RedbStore, ExtendedHeaderGenerator) {
1025        let s = create_store(path).await;
1026        let mut generator = ExtendedHeaderGenerator::new();
1027        let headers = generator.next_many(amount);
1028
1029        s.insert(headers).await.expect("inserting test data failed");
1030
1031        (s, generator)
1032    }
1033}