use crate::{
ForkRpcClient, StorageCache,
error::{RemoteStorageError, RpcClientError},
models::BlockRow,
};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use subxt::{Metadata, config::substrate::H256, ext::codec::Encode};
/// Number of storage keys requested per page when prefetching a key prefix
/// (both speculative single-page prefetch and full prefix scans).
const DEFAULT_PREFETCH_PAGE_SIZE: u32 = 1000;
/// Minimum key length (bytes) before the key's leading bytes are treated as a
/// storage prefix worth prefetching; 32 bytes covers the twox128(pallet) ++
/// twox128(item) prefix — NOTE(review): assumed layout, confirm against runtime.
const MIN_STORAGE_KEY_PREFIX_LEN: usize = 32;
/// Thread-safe counters recording how storage lookups were served.
///
/// Shared via `Arc` across clones of `RemoteStorageLayer`; counters are purely
/// informational, so all accesses use relaxed atomic ordering.
#[derive(Debug, Default)]
pub struct StorageStats {
    // `get` calls answered directly from the local cache.
    pub cache_hits: AtomicUsize,
    // `get` calls answered from cache immediately after a speculative prefix prefetch.
    pub prefetch_hits: AtomicUsize,
    // `get` calls that fell through to an individual RPC fetch.
    pub rpc_misses: AtomicUsize,
    // `next_key` calls answered from a fully scanned cached prefix.
    pub next_key_cache: AtomicUsize,
    // `next_key` calls that required an RPC round trip.
    pub next_key_rpc: AtomicUsize,
}
/// Point-in-time, plain-integer copy of [`StorageStats`], produced by
/// `RemoteStorageLayer::stats()` so callers can read/compare counters without
/// touching atomics.
#[derive(Debug, Clone, Default)]
pub struct StorageStatsSnapshot {
    // `get` calls served from cache.
    pub cache_hits: usize,
    // `get` calls served from cache after a speculative prefetch.
    pub prefetch_hits: usize,
    // `get` calls that hit the RPC individually.
    pub rpc_misses: usize,
    // `next_key` calls served from cache.
    pub next_key_cache: usize,
    // `next_key` calls that hit the RPC.
    pub next_key_rpc: usize,
}
impl std::fmt::Display for StorageStatsSnapshot {
    /// Renders a one-line summary of the counters, with per-category totals
    /// for `get` and `next_key` lookups.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { cache_hits, prefetch_hits, rpc_misses, next_key_cache, next_key_rpc } = *self;
        let get_total = cache_hits + prefetch_hits + rpc_misses;
        let next_total = next_key_cache + next_key_rpc;
        write!(
            f,
            "get: {get_total} total ({cache_hits} cache, {prefetch_hits} prefetch, {rpc_misses} rpc) | next_key: {next_total} total ({next_key_cache} cache, {next_key_rpc} rpc)"
        )
    }
}
/// Read-through storage layer: serves Substrate state queries from a local
/// [`StorageCache`] and falls back to a remote node via [`ForkRpcClient`] on
/// cache misses, caching the results. Cloning is cheap; clones share the same
/// stats via `Arc`.
#[derive(Clone, Debug)]
pub struct RemoteStorageLayer {
    // RPC connection to the remote node.
    rpc: ForkRpcClient,
    // Persistent cache of storage values, key prefixes and block metadata.
    cache: StorageCache,
    // Shared lookup counters (see `StorageStats`).
    stats: Arc<StorageStats>,
}
impl RemoteStorageLayer {
    /// Creates a layer over `rpc` backed by `cache`, with zeroed stats.
    pub fn new(rpc: ForkRpcClient, cache: StorageCache) -> Self {
        Self { rpc, cache, stats: Arc::new(StorageStats::default()) }
    }

    /// Returns the underlying RPC client.
    pub fn rpc(&self) -> &ForkRpcClient {
        &self.rpc
    }

    /// Returns the underlying storage cache.
    pub fn cache(&self) -> &StorageCache {
        &self.cache
    }

    /// Returns the remote node's endpoint URL.
    pub fn endpoint(&self) -> &url::Url {
        self.rpc.endpoint()
    }

    /// Takes a point-in-time snapshot of the lookup counters.
    pub fn stats(&self) -> StorageStatsSnapshot {
        StorageStatsSnapshot {
            cache_hits: self.stats.cache_hits.load(Ordering::Relaxed),
            prefetch_hits: self.stats.prefetch_hits.load(Ordering::Relaxed),
            rpc_misses: self.stats.rpc_misses.load(Ordering::Relaxed),
            next_key_cache: self.stats.next_key_cache.load(Ordering::Relaxed),
            next_key_rpc: self.stats.next_key_rpc.load(Ordering::Relaxed),
        }
    }

    /// Resets all lookup counters to zero.
    ///
    /// Note: the five stores are individually relaxed, so a concurrent
    /// `stats()` may observe a partially reset snapshot.
    pub fn reset_stats(&self) {
        self.stats.cache_hits.store(0, Ordering::Relaxed);
        self.stats.prefetch_hits.store(0, Ordering::Relaxed);
        self.stats.rpc_misses.store(0, Ordering::Relaxed);
        self.stats.next_key_cache.store(0, Ordering::Relaxed);
        self.stats.next_key_rpc.store(0, Ordering::Relaxed);
    }

    /// Fetches the storage value for `key` at `block_hash`.
    ///
    /// Lookup order:
    /// 1. local cache (absent values are cached too — the cache returns
    ///    `Some(None)` for a known-absent key);
    /// 2. if the key is long enough to carry a storage prefix and that prefix
    ///    has never been scanned, speculatively prefetch one page of the
    ///    prefix and retry the cache (prefetch failure is non-fatal);
    /// 3. individual RPC fetch, with one reconnect-and-retry on error; the
    ///    result (including `None`) is written back to the cache.
    pub async fn get(
        &self,
        block_hash: H256,
        key: &[u8],
    ) -> Result<Option<Vec<u8>>, RemoteStorageError> {
        if let Some(cached) = self.cache.get_storage(block_hash, key).await? {
            self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(cached);
        }
        if key.len() >= MIN_STORAGE_KEY_PREFIX_LEN {
            let prefix = &key[..MIN_STORAGE_KEY_PREFIX_LEN];
            let progress = self.cache.get_prefix_scan_progress(block_hash, prefix).await?;
            // Only prefetch prefixes that were never attempted before.
            if progress.is_none() {
                match self
                    .prefetch_prefix_single_page(block_hash, prefix, DEFAULT_PREFETCH_PAGE_SIZE)
                    .await
                {
                    Ok(_) => {
                        if let Some(cached) = self.cache.get_storage(block_hash, key).await? {
                            self.stats.prefetch_hits.fetch_add(1, Ordering::Relaxed);
                            return Ok(cached);
                        }
                    },
                    Err(e) => {
                        log::debug!(
                            "Speculative prefetch failed (non-fatal), falling through to individual fetch: {e}"
                        );
                    },
                }
            }
        }
        self.stats.rpc_misses.fetch_add(1, Ordering::Relaxed);
        // One reconnect-and-retry on any RPC failure.
        let value = match self.rpc.storage(key, block_hash).await {
            Ok(v) => v,
            Err(_) => {
                self.rpc.reconnect().await?;
                self.rpc.storage(key, block_hash).await?
            },
        };
        self.cache.set_storage(block_hash, key, value.as_deref()).await?;
        Ok(value)
    }

    /// Fetches many storage values at once, preserving the order of `keys`.
    ///
    /// Cached entries are served locally; the remaining keys are fetched in a
    /// single batched RPC (with one reconnect-and-retry) and written back to
    /// the cache. Returns one `Option<Vec<u8>>` per input key.
    pub async fn get_batch(
        &self,
        block_hash: H256,
        keys: &[&[u8]],
    ) -> Result<Vec<Option<Vec<u8>>>, RemoteStorageError> {
        if keys.is_empty() {
            return Ok(vec![]);
        }
        let cached_results = self.cache.get_storage_batch(block_hash, keys).await?;
        // Collect positions and keys of cache misses (outer `None` = not cached).
        let mut uncached_indices: Vec<usize> = Vec::new();
        let mut uncached_keys: Vec<&[u8]> = Vec::new();
        for (i, cached) in cached_results.iter().enumerate() {
            if cached.is_none() {
                uncached_indices.push(i);
                uncached_keys.push(keys[i]);
            }
        }
        if uncached_keys.is_empty() {
            // Fully served from cache; flatten `Some(Option<_>)` -> `Option<_>`.
            return Ok(cached_results.into_iter().map(|c| c.flatten()).collect());
        }
        let fetched_values = match self.rpc.storage_batch(&uncached_keys, block_hash).await {
            Ok(v) => v,
            Err(_) => {
                self.rpc.reconnect().await?;
                self.rpc.storage_batch(&uncached_keys, block_hash).await?
            },
        };
        let cache_entries: Vec<(&[u8], Option<&[u8]>)> = uncached_keys
            .iter()
            .zip(fetched_values.iter())
            .map(|(k, v)| (*k, v.as_deref()))
            .collect();
        if !cache_entries.is_empty() {
            self.cache.set_storage_batch(block_hash, &cache_entries).await?;
        }
        // Merge fetched values back into their original positions.
        let mut results: Vec<Option<Vec<u8>>> =
            cached_results.into_iter().map(|c| c.flatten()).collect();
        for (i, idx) in uncached_indices.into_iter().enumerate() {
            results[idx] = fetched_values[i].clone();
        }
        Ok(results)
    }

    /// Scans and caches ALL keys (and values) under `prefix` at `block_hash`,
    /// paging through the remote node `page_size` keys at a time.
    ///
    /// Resumes from previously recorded scan progress and records progress
    /// after each page, so an interrupted scan continues where it left off.
    /// Returns the number of cached keys under the prefix.
    pub async fn prefetch_prefix(
        &self,
        block_hash: H256,
        prefix: &[u8],
        page_size: u32,
    ) -> Result<usize, RemoteStorageError> {
        let progress = self.cache.get_prefix_scan_progress(block_hash, prefix).await?;
        if let Some(ref p) = progress &&
            p.is_complete
        {
            // Already fully scanned; answer from cache.
            return Ok(self.cache.count_keys_by_prefix(block_hash, prefix).await?);
        }
        // Resume pagination after the last key recorded by a prior partial scan.
        let mut start_key = progress.and_then(|p| p.last_scanned_key);
        loop {
            let keys = match self
                .rpc
                .storage_keys_paged(prefix, page_size, start_key.as_deref(), block_hash)
                .await
            {
                Ok(v) => v,
                Err(_) => {
                    self.rpc.reconnect().await?;
                    self.rpc
                        .storage_keys_paged(prefix, page_size, start_key.as_deref(), block_hash)
                        .await?
                },
            };
            if keys.is_empty() {
                if start_key.is_none() {
                    // Prefix has no keys at all: mark scan complete using the
                    // prefix itself as the sentinel last-scanned key.
                    self.cache.update_prefix_scan(block_hash, prefix, prefix, true).await?;
                }
                break;
            }
            // A short page means the node has no more keys after this page.
            let is_last_page = keys.len() < page_size as usize;
            let key_refs: Vec<&[u8]> = keys.iter().map(|k| k.as_slice()).collect();
            let values = match self.rpc.storage_batch(&key_refs, block_hash).await {
                Ok(v) => v,
                Err(_) => {
                    self.rpc.reconnect().await?;
                    self.rpc.storage_batch(&key_refs, block_hash).await?
                },
            };
            let cache_entries: Vec<(&[u8], Option<&[u8]>)> =
                key_refs.iter().zip(values.iter()).map(|(k, v)| (*k, v.as_deref())).collect();
            self.cache.set_storage_batch(block_hash, &cache_entries).await?;
            let last_key = keys.into_iter().last();
            if let Some(ref key) = last_key {
                // Persist progress after every page so interruptions can resume.
                self.cache.update_prefix_scan(block_hash, prefix, key, is_last_page).await?;
            }
            if is_last_page {
                break;
            }
            start_key = last_key;
        }
        Ok(self.cache.count_keys_by_prefix(block_hash, prefix).await?)
    }

    /// Speculatively prefetches at most ONE page of keys/values under `prefix`.
    ///
    /// If a scan was already attempted (complete or partial) this does no RPC
    /// work: a complete scan returns the cached key count, a partial one
    /// returns `0`. Otherwise fetches one page, caches it, records scan
    /// progress, and returns the number of keys fetched.
    pub async fn prefetch_prefix_single_page(
        &self,
        block_hash: H256,
        prefix: &[u8],
        page_size: u32,
    ) -> Result<usize, RemoteStorageError> {
        let progress = self.cache.get_prefix_scan_progress(block_hash, prefix).await?;
        if let Some(ref p) = progress {
            if p.is_complete {
                return Ok(self.cache.count_keys_by_prefix(block_hash, prefix).await?);
            }
            // A partial scan exists; don't repeat the speculative page.
            return Ok(0);
        }
        let keys = match self.rpc.storage_keys_paged(prefix, page_size, None, block_hash).await {
            Ok(v) => v,
            Err(_) => {
                self.rpc.reconnect().await?;
                self.rpc.storage_keys_paged(prefix, page_size, None, block_hash).await?
            },
        };
        if keys.is_empty() {
            // Empty prefix: mark complete with the prefix as sentinel key.
            self.cache.update_prefix_scan(block_hash, prefix, prefix, true).await?;
            return Ok(0);
        }
        // If the page came back short, this single page covered the whole prefix.
        let is_last_page = keys.len() < page_size as usize;
        let key_refs: Vec<&[u8]> = keys.iter().map(|k| k.as_slice()).collect();
        let values = match self.rpc.storage_batch(&key_refs, block_hash).await {
            Ok(v) => v,
            Err(_) => {
                self.rpc.reconnect().await?;
                self.rpc.storage_batch(&key_refs, block_hash).await?
            },
        };
        let cache_entries: Vec<(&[u8], Option<&[u8]>)> =
            key_refs.iter().zip(values.iter()).map(|(k, v)| (*k, v.as_deref())).collect();
        self.cache.set_storage_batch(block_hash, &cache_entries).await?;
        let count = keys.len();
        if let Some(last_key) = keys.into_iter().last() {
            self.cache
                .update_prefix_scan(block_hash, prefix, &last_key, is_last_page)
                .await?;
        }
        Ok(count)
    }

    /// Returns all storage keys under `prefix` at `block_hash`, running a full
    /// prefix scan first so the cache is authoritative.
    pub async fn get_keys(
        &self,
        block_hash: H256,
        prefix: &[u8],
    ) -> Result<Vec<Vec<u8>>, RemoteStorageError> {
        self.prefetch_prefix(block_hash, prefix, DEFAULT_PREFETCH_PAGE_SIZE).await?;
        Ok(self.cache.get_keys_by_prefix(block_hash, prefix).await?)
    }

    /// Fetches the block at `block_number` from the node, caches its metadata
    /// (hash, number, parent hash, SCALE-encoded header), and returns it as a
    /// [`BlockRow`]. Returns `Ok(None)` when the node has no such block.
    pub async fn fetch_and_cache_block_by_number(
        &self,
        block_number: u32,
    ) -> Result<Option<BlockRow>, RemoteStorageError> {
        let (block_hash, block) = match self.rpc.block_by_number(block_number).await? {
            Some((hash, block)) => (hash, block),
            None => return Ok(None),
        };
        let header = block.header;
        let parent_hash = header.parent_hash;
        let header_encoded = header.encode();
        self.cache
            .cache_block(block_hash, block_number, parent_hash, &header_encoded)
            .await?;
        Ok(Some(BlockRow {
            hash: block_hash.as_bytes().to_vec(),
            number: block_number as i64,
            parent_hash: parent_hash.as_bytes().to_vec(),
            header: header_encoded,
        }))
    }

    /// Returns the first storage key after `key` under `prefix`, or `None`.
    ///
    /// If any cached prefix scan whose prefix covers `prefix` (the full prefix
    /// or its 32/16-byte head) is complete, the answer comes from the cache;
    /// otherwise a one-key paged RPC query is issued (with reconnect-retry).
    pub async fn next_key(
        &self,
        block_hash: H256,
        prefix: &[u8],
        key: &[u8],
    ) -> Result<Option<Vec<u8>>, RemoteStorageError> {
        // Candidate prefix lengths a prior scan might have been recorded under.
        let candidate_lengths: &[usize] = &[prefix.len(), 32, 16];
        for &len in candidate_lengths {
            if len > prefix.len() {
                continue;
            }
            let candidate = &prefix[..len];
            if let Some(progress) =
                self.cache.get_prefix_scan_progress(block_hash, candidate).await? &&
                progress.is_complete
            {
                self.stats.next_key_cache.fetch_add(1, Ordering::Relaxed);
                return Ok(self.cache.next_key_from_cache(block_hash, prefix, key).await?);
            }
        }
        self.stats.next_key_rpc.fetch_add(1, Ordering::Relaxed);
        let keys = match self.rpc.storage_keys_paged(prefix, 1, Some(key), block_hash).await {
            Ok(v) => v,
            Err(_) => {
                self.rpc.reconnect().await?;
                self.rpc.storage_keys_paged(prefix, 1, Some(key), block_hash).await?
            },
        };
        Ok(keys.into_iter().next())
    }

    /// Returns the encoded extrinsics of the block at `hash`, or `None` if
    /// the node does not know the block. Not cached.
    pub async fn block_body(&self, hash: H256) -> Result<Option<Vec<Vec<u8>>>, RemoteStorageError> {
        match self.rpc.block_by_hash(hash).await? {
            Some(block) => {
                let extrinsics = block.extrinsics.into_iter().map(|ext| ext.0.to_vec()).collect();
                Ok(Some(extrinsics))
            },
            None => Ok(None),
        }
    }

    /// Returns the SCALE-encoded header of the block at `hash`.
    ///
    /// An `InvalidResponse` RPC error is treated as "block unknown" and mapped
    /// to `Ok(None)`; other errors propagate.
    pub async fn block_header(&self, hash: H256) -> Result<Option<Vec<u8>>, RemoteStorageError> {
        match self.rpc.header(hash).await {
            Ok(header) => Ok(Some(header.encode())),
            Err(RpcClientError::InvalidResponse(_)) => Ok(None),
            Err(e) => Err(e.into()),
        }
    }

    /// Returns the hash of the block at `block_number`, straight from the node.
    pub async fn block_hash_by_number(
        &self,
        block_number: u32,
    ) -> Result<Option<H256>, RemoteStorageError> {
        Ok(self.rpc.block_hash_at(block_number).await?)
    }

    /// Returns the number of the block at `hash`, consulting the block cache
    /// first and caching the block metadata on an RPC hit.
    pub async fn block_number_by_hash(
        &self,
        hash: H256,
    ) -> Result<Option<u32>, RemoteStorageError> {
        if let Some(block) = self.cache.get_block(hash).await? {
            return Ok(Some(block.number as u32));
        }
        match self.rpc.block_by_hash(hash).await? {
            Some(block) => {
                let number = block.header.number;
                let parent_hash = block.header.parent_hash;
                let header_encoded = block.header.encode();
                self.cache.cache_block(hash, number, parent_hash, &header_encoded).await?;
                Ok(Some(number))
            },
            None => Ok(None),
        }
    }

    /// Returns the parent hash of the block at `hash`, consulting the block
    /// cache first and caching the block metadata on an RPC hit.
    pub async fn parent_hash(&self, hash: H256) -> Result<Option<H256>, RemoteStorageError> {
        if let Some(block) = self.cache.get_block(hash).await? {
            let parent_hash = H256::from_slice(&block.parent_hash);
            return Ok(Some(parent_hash));
        }
        match self.rpc.block_by_hash(hash).await? {
            Some(block) => {
                let number = block.header.number;
                let parent_hash = block.header.parent_hash;
                let header_encoded = block.header.encode();
                self.cache.cache_block(hash, number, parent_hash, &header_encoded).await?;
                Ok(Some(parent_hash))
            },
            None => Ok(None),
        }
    }

    /// Returns `(hash, block)` for `block_number` straight from the node
    /// (no caching).
    pub async fn block_by_number(
        &self,
        block_number: u32,
    ) -> Result<
        Option<(H256, subxt::backend::legacy::rpc_methods::Block<subxt::SubstrateConfig>)>,
        RemoteStorageError,
    > {
        Ok(self.rpc.block_by_number(block_number).await?)
    }

    /// Returns the node's current finalized head hash.
    pub async fn finalized_head(&self) -> Result<H256, RemoteStorageError> {
        Ok(self.rpc.finalized_head().await?)
    }

    /// Returns the runtime metadata at `block_hash` from the node.
    pub async fn metadata(&self, block_hash: H256) -> Result<Metadata, RemoteStorageError> {
        Ok(self.rpc.metadata(block_hash).await?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display_rpc() {
        use crate::error::RpcClientError;
        let inner = RpcClientError::InvalidResponse("test".to_string());
        let err = RemoteStorageError::Rpc(inner);
        assert!(err.to_string().contains("RPC error"));
    }

    #[test]
    fn error_display_cache() {
        use crate::error::CacheError;
        let inner = CacheError::DataCorruption("test".to_string());
        let err = RemoteStorageError::Cache(inner);
        assert!(err.to_string().contains("Cache error"));
    }

    /// Integration tests against a live dev node via `TestContext`; must run
    /// sequentially because they share the node's chain state.
    mod sequential {
        use crate::testing::{
            TestContext,
            constants::{SYSTEM_NUMBER_KEY, SYSTEM_PALLET_PREFIX, SYSTEM_PARENT_HASH_KEY},
        };
        use std::time::Duration;

        #[tokio::test(flavor = "multi_thread")]
        async fn get_fetches_and_caches() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let key = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let value1 = layer.get(block_hash, &key).await.unwrap();
            assert!(value1.is_some(), "System::Number should exist");
            let cached = layer.cache().get_storage(block_hash, &key).await.unwrap();
            assert!(cached.is_some(), "Value should be cached after first get");
            assert_eq!(cached.unwrap(), value1);
            let value2 = layer.get(block_hash, &key).await.unwrap();
            assert_eq!(value1, value2);
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_caches_empty_values() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let nonexistent_key = b"this_key_definitely_does_not_exist_12345";
            let value = layer.get(block_hash, nonexistent_key).await.unwrap();
            assert!(value.is_none(), "Nonexistent key should return None");
            let cached = layer.cache().get_storage(block_hash, nonexistent_key).await.unwrap();
            assert_eq!(cached, Some(None), "Empty value should be cached as Some(None)");
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_batch_fetches_mixed() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let key1 = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let key2 = hex::decode(SYSTEM_PARENT_HASH_KEY).unwrap();
            let key3 = b"nonexistent_key".to_vec();
            let keys: Vec<&[u8]> = vec![key1.as_slice(), key2.as_slice(), key3.as_slice()];
            let results = layer.get_batch(block_hash, &keys).await.unwrap();
            assert_eq!(results.len(), 3);
            assert!(results[0].is_some(), "System::Number should exist");
            assert!(results[1].is_some(), "System::ParentHash should exist");
            assert!(results[2].is_none(), "Nonexistent key should be None");
            // Every key — including the absent one — must be cached afterwards.
            for (i, key) in keys.iter().enumerate() {
                let cached = layer.cache().get_storage(block_hash, key).await.unwrap();
                assert!(cached.is_some(), "Key {} should be cached", i);
            }
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_batch_uses_cache() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let key1 = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let key2 = hex::decode(SYSTEM_PARENT_HASH_KEY).unwrap();
            // Prime the cache with the first key only.
            let value1 = layer.get(block_hash, &key1).await.unwrap();
            let keys: Vec<&[u8]> = vec![key1.as_slice(), key2.as_slice()];
            let results = layer.get_batch(block_hash, &keys).await.unwrap();
            assert_eq!(results.len(), 2);
            assert_eq!(results[0], value1, "Cached value should match");
            assert!(results[1].is_some(), "Uncached value should be fetched");
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn prefetch_prefix() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let prefix = hex::decode(SYSTEM_PALLET_PREFIX).unwrap();
            // Deliberately small page size to exercise pagination.
            let count = layer.prefetch_prefix(block_hash, &prefix, 5).await.unwrap();
            assert!(count > 0, "Should have prefetched some keys");
            let key = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let cached = layer.cache().get_storage(block_hash, &key).await.unwrap();
            assert!(cached.is_some(), "Prefetched key should be cached");
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn layer_is_cloneable() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            let layer2 = layer.clone();
            let key = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let value1 = layer.get(block_hash, &key).await.unwrap();
            let value2 = layer2.get(block_hash, &key).await.unwrap();
            assert_eq!(value1, value2);
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn accessor_methods() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_hash = ctx.block_hash();
            assert!(!block_hash.is_zero());
            assert!(layer.rpc().endpoint().as_str().starts_with("ws://"));
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn fetch_and_cache_block_by_number_caches_block() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let finalized_hash = layer.rpc().finalized_head().await.unwrap();
            let finalized_header = layer.rpc().header(finalized_hash).await.unwrap();
            let finalized_number = finalized_header.number;
            let cached = layer.cache().get_block_by_number(finalized_number).await.unwrap();
            assert!(cached.is_none());
            let result = layer.fetch_and_cache_block_by_number(finalized_number).await.unwrap();
            assert!(result.is_some());
            let block_row = result.unwrap();
            assert_eq!(block_row.number, finalized_number as i64);
            assert_eq!(block_row.hash.len(), 32);
            assert_eq!(block_row.parent_hash.len(), 32);
            assert!(!block_row.header.is_empty());
            let cached = layer.cache().get_block_by_number(finalized_number).await.unwrap();
            assert!(cached.is_some());
            let cached_block = cached.unwrap();
            assert_eq!(cached_block.number, block_row.number);
            assert_eq!(cached_block.hash, block_row.hash);
            assert_eq!(cached_block.parent_hash, block_row.parent_hash);
            assert_eq!(cached_block.header, block_row.header);
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn fetch_and_cache_block_by_number_non_existent() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let non_existent_number = u32::MAX;
            let result = layer.fetch_and_cache_block_by_number(non_existent_number).await.unwrap();
            assert!(result.is_none(), "Non-existent block should return None");
            let cached = layer.cache().get_block_by_number(non_existent_number).await.unwrap();
            assert!(cached.is_none(), "Non-existent block should not be cached");
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn fetch_and_cache_block_by_number_multiple_blocks() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            // Give the dev node time to finalize a few blocks. Use the async
            // sleep: `std::thread::sleep` would block a tokio worker thread.
            tokio::time::sleep(Duration::from_secs(30)).await;
            let finalized_hash = layer.rpc().finalized_head().await.unwrap();
            let finalized_header = layer.rpc().header(finalized_hash).await.unwrap();
            let finalized_number = finalized_header.number;
            let max_blocks = finalized_number.min(3);
            for block_num in 0..=max_blocks {
                let result =
                    layer.fetch_and_cache_block_by_number(block_num).await.unwrap().unwrap();
                assert_eq!(result.number, block_num as i64);
                let cached = layer.cache().get_block_by_number(block_num).await.unwrap().unwrap();
                assert_eq!(cached.number, result.number);
                assert_eq!(cached.hash, result.hash);
            }
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn fetch_and_cache_block_by_number_idempotent() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            let block_number = 0u32;
            let result1 =
                layer.fetch_and_cache_block_by_number(block_number).await.unwrap().unwrap();
            let result2 =
                layer.fetch_and_cache_block_by_number(block_number).await.unwrap().unwrap();
            assert_eq!(result1.number, result2.number);
            assert_eq!(result1.hash, result2.hash);
            assert_eq!(result1.parent_hash, result2.parent_hash);
            assert_eq!(result1.header, result2.header);
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn fetch_and_cache_block_by_number_verifies_parent_chain() {
            let ctx = TestContext::for_remote().await;
            let layer = ctx.remote();
            // Async sleep (not std::thread::sleep) so the runtime stays responsive.
            tokio::time::sleep(Duration::from_secs(30)).await;
            let finalized_hash = layer.rpc().finalized_head().await.unwrap();
            let finalized_header = layer.rpc().header(finalized_hash).await.unwrap();
            let finalized_number = finalized_header.number;
            let max_blocks = finalized_number.min(3);
            let mut previous_hash: Option<Vec<u8>> = None;
            for block_num in 0..=max_blocks {
                let block_row =
                    layer.fetch_and_cache_block_by_number(block_num).await.unwrap().unwrap();
                if let Some(prev_hash) = previous_hash {
                    assert_eq!(
                        block_row.parent_hash,
                        prev_hash,
                        "Block {} parent hash should match block {} hash",
                        block_num,
                        block_num - 1
                    );
                }
                previous_hash = Some(block_row.hash.clone());
            }
        }
    }
}