use std::fmt::Debug;
use std::io;
use std::ops::{Bound, RangeBounds, RangeInclusive};
use async_trait::async_trait;
use celestia_types::hash::Hash;
use celestia_types::ExtendedHeader;
use thiserror::Error;
pub use in_memory_store::InMemoryStore;
#[cfg(target_arch = "wasm32")]
pub use indexed_db_store::IndexedDbStore;
#[cfg(not(target_arch = "wasm32"))]
pub use sled_store::SledStore;
mod in_memory_store;
#[cfg(target_arch = "wasm32")]
mod indexed_db_store;
#[cfg(not(target_arch = "wasm32"))]
mod sled_store;
use crate::utils::validate_headers;
type Result<T, E = StoreError> = std::result::Result<T, E>;
#[async_trait]
pub trait Store: Send + Sync + Debug {
async fn get_head(&self) -> Result<ExtendedHeader>;
async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader>;
async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader>;
async fn get_range<R>(&self, range: R) -> Result<Vec<ExtendedHeader>>
where
R: RangeBounds<u64> + Send,
{
let head_height = self.head_height().await?;
let range = to_headers_range(range, head_height)?;
let amount = if range.is_empty() {
0
} else {
range.end() - range.start() + 1 };
let mut headers = Vec::with_capacity(
amount
.try_into()
.map_err(|_| StoreError::InvalidHeadersRange)?,
);
for height in range {
let header = self.get_by_height(height).await?;
headers.push(header);
}
Ok(headers)
}
async fn head_height(&self) -> Result<u64>;
async fn has(&self, hash: &Hash) -> bool;
async fn has_at(&self, height: u64) -> bool;
async fn append_single_unchecked(&self, header: ExtendedHeader) -> Result<()>;
async fn append_unchecked(&self, headers: Vec<ExtendedHeader>) -> Result<()> {
for header in headers.into_iter() {
self.append_single_unchecked(header).await?;
}
Ok(())
}
async fn append_single(&self, header: ExtendedHeader) -> Result<()> {
header.validate()?;
match self.get_head().await {
Ok(head) => {
head.verify(&header)?;
}
Err(StoreError::NotFound) => {}
Err(e) => return Err(e),
}
self.append_single_unchecked(header).await
}
async fn append(&self, headers: Vec<ExtendedHeader>) -> Result<()> {
validate_headers(&headers).await?;
match self.get_head().await {
Ok(head) => {
head.verify_adjacent_range(&headers)?;
}
Err(StoreError::NotFound) => {}
Err(e) => return Err(e),
}
self.append_unchecked(headers).await
}
}
#[derive(Error, Debug)]
pub enum StoreError {
#[error("Hash {0} already exists in store")]
HashExists(Hash),
#[error("Height {0} already exists in store")]
HeightExists(u64),
#[error("Failed to append header at height {1}, current head {0}")]
NonContinuousAppend(u64, u64),
#[error("Failed to validate header at height {0}")]
HeaderChecksError(u64),
#[error("Header not found in store")]
NotFound,
#[error("Store in inconsistent state; height {0} within known range, but missing header")]
LostHeight(u64),
#[error("Store in inconsistent state; height->hash mapping exists, {0} missing")]
LostHash(Hash),
#[error(transparent)]
CelestiaTypes(#[from] celestia_types::Error),
#[error("Stored data in inconsistent state, try reseting the store: {0}")]
StoredDataError(String),
#[error("Persistent storage reported unrecoverable error: {0}")]
BackingStoreError(String),
#[error("Received error from executor: {0}")]
ExecutorError(String),
#[error("Received io error from persistent storage: {0}")]
IoError(#[from] io::Error),
#[error("Error opening store: {0}")]
OpenFailed(String),
#[error("Invalid headers range")]
InvalidHeadersRange,
}
fn to_headers_range(bounds: impl RangeBounds<u64>, last_index: u64) -> Result<RangeInclusive<u64>> {
let start = match bounds.start_bound() {
Bound::Unbounded => 1,
Bound::Included(&x) if x > last_index || x == 0 => return Err(StoreError::NotFound),
Bound::Excluded(&x) if x >= last_index => return Err(StoreError::NotFound),
Bound::Included(&x) => x,
Bound::Excluded(&x) => x + 1, };
let end = match bounds.end_bound() {
Bound::Unbounded => last_index,
Bound::Included(&x) if x > last_index => return Err(StoreError::NotFound),
Bound::Excluded(&x) if x > last_index + 1 => return Err(StoreError::NotFound),
Bound::Excluded(&0) => 0,
Bound::Included(&x) => x,
Bound::Excluded(&x) => x - 1,
};
Ok(start..=end)
}
#[cfg(test)]
mod tests {
use std::ops::Bound;
use super::to_headers_range;
#[test]
fn converts_bounded_ranges() {
assert_eq!(1..=15, to_headers_range(1..16, 100).unwrap());
assert_eq!(1..=15, to_headers_range(1..=15, 100).unwrap());
assert_eq!(300..=400, to_headers_range(300..401, 500).unwrap());
assert_eq!(300..=400, to_headers_range(300..=400, 500).unwrap());
}
#[test]
fn starts_from_one_when_unbounded_start() {
assert_eq!(&1, to_headers_range(..=10, 100).unwrap().start());
assert_eq!(&1, to_headers_range(..10, 100).unwrap().start());
assert_eq!(&1, to_headers_range(.., 100).unwrap().start());
}
#[test]
fn ends_on_last_index_when_unbounded_end() {
assert_eq!(&10, to_headers_range(1.., 10).unwrap().end());
assert_eq!(&11, to_headers_range(1.., 11).unwrap().end());
assert_eq!(&10, to_headers_range(.., 10).unwrap().end());
}
#[test]
fn handle_ranges_ending_precisely_at_last_index() {
let last_index = 10;
let bounds_ending_at_last_index = [
(Bound::Unbounded, Bound::Included(last_index)),
(Bound::Unbounded, Bound::Excluded(last_index + 1)),
];
for bound in bounds_ending_at_last_index {
let range = to_headers_range(bound, last_index).unwrap();
assert_eq!(*range.end(), last_index);
}
}
#[test]
fn handle_ranges_ending_after_last_index() {
let last_index = 10;
let bounds_ending_after_last_index = [
(Bound::Unbounded, Bound::Included(last_index + 1)),
(Bound::Unbounded, Bound::Excluded(last_index + 2)),
];
for bound in bounds_ending_after_last_index {
to_headers_range(bound, last_index).unwrap_err();
}
}
#[test]
fn errors_if_zero_heigth_is_included() {
let includes_zero_height = 0..5;
to_headers_range(includes_zero_height, 10).unwrap_err();
}
#[test]
fn handle_ranges_starting_precisely_at_last_index() {
let last_index = 10;
let bounds_starting_at_last_index = [
(Bound::Included(last_index), Bound::Unbounded),
(Bound::Excluded(last_index - 1), Bound::Unbounded),
];
for bound in bounds_starting_at_last_index {
let range = to_headers_range(bound, last_index).unwrap();
assert_eq!(*range.start(), last_index);
}
}
#[test]
fn handle_ranges_starting_after_last_index() {
let last_index = 10;
let bounds_starting_after_last_index = [
(Bound::Included(last_index + 1), Bound::Unbounded),
(Bound::Excluded(last_index), Bound::Unbounded),
];
for bound in bounds_starting_after_last_index {
to_headers_range(bound, last_index).unwrap_err();
}
}
#[test]
fn handle_ranges_that_lead_to_empty_ranges() {
let last_index = 10;
let bounds_leading_to_empty_range = [
(Bound::Unbounded, Bound::Excluded(0)),
(Bound::Included(3), Bound::Excluded(3)),
(Bound::Included(3), Bound::Included(2)),
(Bound::Excluded(2), Bound::Included(2)),
];
for bound in bounds_leading_to_empty_range {
assert!(to_headers_range(bound, last_index).unwrap().is_empty());
}
}
}