use std::sync::atomic::{AtomicU64, Ordering};
use async_trait::async_trait;
use celestia_types::hash::Hash;
use celestia_types::ExtendedHeader;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use tracing::debug;
use crate::store::{Result, Store, StoreError};
/// A header store kept entirely in memory.
///
/// Headers are indexed both by hash and by height. The maps are `DashMap`s and
/// the head height is an atomic, so all operations take `&self`.
#[derive(Debug)]
pub struct InMemoryStore {
    /// Stored headers, keyed by their hash.
    headers: DashMap<Hash, ExtendedHeader>,
    /// Hash of the stored header at each height.
    height_to_hash: DashMap<u64, Hash>,
    /// Height of the most recently appended header; `0` means the store is empty.
    head_height: AtomicU64,
}
impl InMemoryStore {
    /// Creates an empty store.
    ///
    /// A head height of `0` is the "no headers yet" sentinel; stored heights
    /// start at `1`.
    pub fn new() -> Self {
        InMemoryStore {
            headers: DashMap::new(),
            height_to_hash: DashMap::new(),
            head_height: AtomicU64::new(0),
        }
    }

    /// Returns the height of the latest stored header.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::NotFound`] if the store is empty.
    #[inline]
    fn get_head_height(&self) -> Result<u64> {
        let height = self.head_height.load(Ordering::Acquire);
        if height == 0 {
            Err(StoreError::NotFound)
        } else {
            Ok(height)
        }
    }

    /// Appends `header` as the new head.
    ///
    /// No checks are performed beyond height continuity and hash/height
    /// uniqueness.
    ///
    /// # Errors
    ///
    /// - [`StoreError::HeightExists`] if the store is non-empty and the header's
    ///   height is at or below the current head, or if that height is already
    ///   occupied.
    /// - [`StoreError::NonContinuousAppend`] `(head, height)` if the height is not
    ///   exactly `head + 1`. An empty store reports its head as `0`, so the first
    ///   header must have height `1`.
    /// - [`StoreError::HashExists`] if a header with the same hash is already stored.
    pub(crate) fn append_single_unchecked(&self, header: ExtendedHeader) -> Result<()> {
        let hash = header.hash();
        let height = header.height().value();
        // Treat an empty store as head 0 so the continuity check below
        // requires the first header to be at height 1.
        let head_height = self.get_head_height().unwrap_or(0);

        if head_height > 0 && height <= head_height {
            return Err(StoreError::HeightExists(height));
        }

        if head_height + 1 != height {
            return Err(StoreError::NonContinuousAppend(head_height, height));
        }

        // Both entries are obtained before either insert. DashMap entries keep
        // their shard locked, so the occupancy checks and the inserts below
        // are not interleaved with other writers to the same shards.
        let hash_entry = self.headers.entry(hash);
        let height_entry = self.height_to_hash.entry(height);

        if matches!(hash_entry, Entry::Occupied(_)) {
            return Err(StoreError::HashExists(hash));
        }

        if matches!(height_entry, Entry::Occupied(_)) {
            return Err(StoreError::HeightExists(height));
        }

        debug!("Inserting header {hash} with height {height}");
        hash_entry.insert(header);
        height_entry.insert(hash);

        // Published last with Release: readers that observe the new head via
        // an Acquire load also observe the map entries inserted above.
        self.head_height.store(height, Ordering::Release);

        Ok(())
    }

    /// Returns the latest stored header.
    ///
    /// # Errors
    ///
    /// [`StoreError::NotFound`] if the store is empty; otherwise the same
    /// errors as [`Self::get_by_height`].
    fn get_head(&self) -> Result<ExtendedHeader> {
        let head_height = self.get_head_height()?;
        self.get_by_height(head_height)
    }

    /// Returns `true` if a header with `hash` is stored.
    fn contains_hash(&self, hash: &Hash) -> bool {
        self.headers.contains_key(hash)
    }

    /// Returns a copy of the header with the given `hash`.
    ///
    /// # Errors
    ///
    /// [`StoreError::NotFound`] if no such header is stored.
    fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
        self.headers
            .get(hash)
            .as_deref()
            .cloned()
            .ok_or(StoreError::NotFound)
    }

    /// Returns `true` if `height` lies in the stored range `1..=head`.
    ///
    /// Height `0` is the empty-store sentinel and is never stored, so it is
    /// always reported as absent.
    fn contains_height(&self, height: u64) -> bool {
        let Ok(head_height) = self.get_head_height() else {
            return false;
        };
        // Lower bound of 1 excludes the sentinel height 0, which the previous
        // `height <= head_height` check wrongly accepted.
        (1..=head_height).contains(&height)
    }

    /// Returns a copy of the header stored at `height`.
    ///
    /// # Errors
    ///
    /// - [`StoreError::NotFound`] if `height` is outside the stored range `1..=head`.
    /// - [`StoreError::LostHeight`] if the height is in range but has no hash entry.
    /// - [`StoreError::LostHash`] if the height's hash has no header entry.
    fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
        if !self.contains_height(height) {
            return Err(StoreError::NotFound);
        }

        let Some(hash) = self.height_to_hash.get(&height).as_deref().copied() else {
            return Err(StoreError::LostHeight(height));
        };

        self.headers
            .get(&hash)
            .as_deref()
            .cloned()
            .ok_or(StoreError::LostHash(hash))
    }
}
#[async_trait]
impl Store for InMemoryStore {
async fn get_head(&self) -> Result<ExtendedHeader> {
self.get_head()
}
async fn get_by_hash(&self, hash: &Hash) -> Result<ExtendedHeader> {
self.get_by_hash(hash)
}
async fn get_by_height(&self, height: u64) -> Result<ExtendedHeader> {
self.get_by_height(height)
}
async fn head_height(&self) -> Result<u64> {
self.get_head_height()
}
async fn has(&self, hash: &Hash) -> bool {
self.contains_hash(hash)
}
async fn has_at(&self, height: u64) -> bool {
self.contains_height(height)
}
async fn append_single_unchecked(&self, header: ExtendedHeader) -> Result<()> {
self.append_single_unchecked(header)
}
}
impl Default for InMemoryStore {
fn default() -> Self {
Self::new()
}
}
impl Clone for InMemoryStore {
    /// Creates an independent deep copy of the store.
    ///
    /// The head height is snapshotted *before* the maps are copied, so every
    /// entry at or below it is present in the copy. Afterwards, any entries
    /// above the snapshot are dropped. Those could only come from an append
    /// running concurrently with the copy. This keeps the clone's head and
    /// maps consistent: no `LostHeight`/`LostHash` on reads, and no leftover
    /// entries that would block the clone's next append.
    fn clone(&self) -> Self {
        // Acquire pairs with the Release store of `head_height` performed after
        // inserting into both maps, so entries for heights <= head_height are
        // visible to the map copies below.
        let head_height = self.head_height.load(Ordering::Acquire);

        let headers = self.headers.clone();
        let height_to_hash = self.height_to_hash.clone();

        // Discard entries an in-flight append inserted after the head snapshot.
        headers.retain(|_, header| header.height().value() <= head_height);
        height_to_hash.retain(|height, _| *height <= head_height);

        InMemoryStore {
            headers,
            height_to_hash,
            head_height: AtomicU64::new(head_height),
        }
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use celestia_types::test_utils::ExtendedHeaderGenerator;
    use celestia_types::Height;

    // `async_test` is tokio's test macro on native targets and
    // wasm-bindgen-test's macro on wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test as async_test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as async_test;

    /// Every lookup on a fresh store reports `NotFound`.
    #[test]
    fn test_empty_store() {
        let s = InMemoryStore::new();
        assert!(matches!(s.get_head_height(), Err(StoreError::NotFound)));
        assert!(matches!(s.get_head(), Err(StoreError::NotFound)));
        assert!(matches!(s.get_by_height(1), Err(StoreError::NotFound)));
        assert!(matches!(
            s.get_by_hash(&Hash::Sha256([0; 32])),
            Err(StoreError::NotFound)
        ));
    }

    /// A single appended header becomes the head at height 1 and is reachable
    /// by height and by hash.
    #[test]
    fn test_read_write() {
        let s = InMemoryStore::new();
        let mut gen = ExtendedHeaderGenerator::new();
        let header = gen.next();
        s.append_single_unchecked(header.clone()).unwrap();
        assert_eq!(s.get_head_height().unwrap(), 1);
        assert_eq!(s.get_head().unwrap(), header);
        assert_eq!(s.get_by_height(1).unwrap(), header);
        assert_eq!(s.get_by_hash(&header.hash()).unwrap(), header);
    }

    /// Lookups against a store pre-filled with 100 headers.
    #[test]
    fn test_pregenerated_data() {
        let (s, _) = gen_filled_store(100);
        assert_eq!(s.get_head_height().unwrap(), 100);

        let head = s.get_head().unwrap();
        assert_eq!(s.get_by_height(100).unwrap(), head);

        // One past the head is not stored.
        assert!(matches!(s.get_by_height(101), Err(StoreError::NotFound)));

        // A mid-range header is reachable by its hash as well.
        let header = s.get_by_height(54).unwrap();
        assert_eq!(s.get_by_hash(&header.hash()).unwrap(), header);
    }

    /// Re-appending the header that was just appended is rejected by the
    /// height check.
    #[test]
    fn test_duplicate_insert() {
        let (s, mut gen) = gen_filled_store(100);
        let header101 = gen.next();
        s.append_single_unchecked(header101.clone()).unwrap();
        assert!(matches!(
            s.append_single_unchecked(header101.clone()),
            Err(StoreError::HeightExists(101))
        ));
    }

    /// Appending a header at a height below the current head is rejected.
    #[test]
    fn test_overwrite_height() {
        let (s, gen) = gen_filled_store(100);
        let header29 = s.get_by_height(29).unwrap();
        // NOTE(review): presumably `next_of` builds a header following
        // `header29` (i.e. at height 30); the expected error relies on that.
        let header30 = gen.next_of(&header29);
        let insert_existing_result = s.append_single_unchecked(header30);
        assert!(matches!(
            insert_existing_result,
            Err(StoreError::HeightExists(30))
        ));
    }

    /// A header whose hash is already stored is rejected even at the next
    /// free height.
    #[test]
    fn test_overwrite_hash() {
        let (s, _) = gen_filled_store(100);
        let mut dup_header = s.get_by_height(33).unwrap();
        // Move the copy to the next free height so only the hash check can fail.
        // NOTE(review): relies on `hash()` not changing when `header.height`
        // is edited — confirm against `ExtendedHeader::hash`.
        dup_header.header.height = Height::from(101u32);
        let insert_existing_result = s.append_single_unchecked(dup_header.clone());
        assert!(matches!(
            insert_existing_result,
            Err(StoreError::HashExists(_))
        ));
    }

    /// A contiguous batch continuing from the head can be appended.
    /// `append_unchecked` is not defined in this file — presumably a batch
    /// method provided by the `Store` trait.
    #[async_test]
    async fn test_append_range() {
        let (s, mut gen) = gen_filled_store(10);
        let hs = gen.next_many(4);
        s.append_unchecked(hs).await.unwrap();
        s.get_by_height(14).unwrap();
    }

    /// Skipping a height (appending 12 on top of head 10) is rejected.
    #[async_test]
    async fn test_append_gap_between_head() {
        let (s, mut gen) = gen_filled_store(10);

        // Generate and discard height 11 to create the gap.
        gen.next();
        let upcoming_head = gen.next();

        let insert_with_gap_result = s.append_single_unchecked(upcoming_head);
        assert!(matches!(
            insert_with_gap_result,
            Err(StoreError::NonContinuousAppend(10, 12))
        ));
    }

    /// A batch with a hole (height 14 removed from 11..=16) fails at the hole.
    #[async_test]
    async fn test_non_continuous_append() {
        let (s, mut gen) = gen_filled_store(10);
        let mut hs = gen.next_many(6);
        // Index 3 of the batch is height 14.
        hs.remove(3);

        let insert_existing_result = s.append_unchecked(hs).await;
        assert!(matches!(
            insert_existing_result,
            Err(StoreError::NonContinuousAppend(13, 15))
        ));
    }

    /// An empty store only accepts height 1 as its first header.
    #[test]
    fn test_genesis_with_height() {
        let mut gen = ExtendedHeaderGenerator::new_from_height(5);
        let header5 = gen.next();
        let s = InMemoryStore::new();

        assert!(matches!(
            s.append_single_unchecked(header5),
            Err(StoreError::NonContinuousAppend(0, 5))
        ));
    }

    /// Builds a store holding `amount` generated headers (heights `1..=amount`).
    /// Also returns the generator, so callers can keep extending the same chain.
    pub fn gen_filled_store(amount: u64) -> (InMemoryStore, ExtendedHeaderGenerator) {
        let s = InMemoryStore::new();
        let mut gen = ExtendedHeaderGenerator::new();

        let headers = gen.next_many(amount);

        for header in headers {
            s.append_single_unchecked(header)
                .expect("inserting test data failed");
        }

        (s, gen)
    }
}