use crate::{
BlockData, ColdReceipt, ColdResult, ColdStorageBackend, ColdStorageError, Confirmed, Filter,
HeaderSpecifier, LogStream, ReceiptSpecifier, RpcLog, SignetEventsSpecifier, StreamParams,
TransactionSpecifier, ZenithHeaderSpecifier, cache::ColdCache, metrics,
};
use alloy::primitives::{B256, BlockNumber};
use parking_lot::Mutex;
use signet_storage_types::{DbSignetEvent, DbZenithHeader, RecoveredTx, SealedHeader};
use std::{
sync::{Arc, Weak},
time::Duration,
};
use tokio::{
sync::{Semaphore, mpsc},
time::Instant,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::{
sync::{CancellationToken, DropGuard},
task::TaskTracker,
};
use tracing::Instrument;
/// Default value for `Inner::max_stream_deadline`, which caps the caller-supplied deadline
/// of a log stream in `stream_logs`.
const DEFAULT_MAX_STREAM_DEADLINE: Duration = Duration::from_secs(60);
/// Fallback timeout for the latest-block lookup performed while setting up a log stream,
/// used when the backend's `read_timeout()` returns `None`.
const DEFAULT_STREAM_SETUP_TIMEOUT: Duration = Duration::from_millis(500);
/// Emits a warning when a successful write took longer than the configured write timeout.
///
/// `elapsed` is the caller-measured duration of the write (the log message describes it as
/// queue + drain + commit). Nothing is logged when `threshold` is `None`, when the write
/// failed (`is_ok == false`), or when `elapsed` does not exceed `threshold`.
fn warn_on_write_overrun(
    op: &'static str,
    elapsed: Duration,
    threshold: Option<Duration>,
    is_ok: bool,
) {
    // Keep the threshold only if this write both succeeded and overran it.
    let overrun_limit = threshold.filter(|limit| is_ok && elapsed > *limit);
    if let Some(limit) = overrun_limit {
        tracing::warn!(
            op,
            elapsed_ms = elapsed.as_millis() as u64,
            threshold_ms = limit.as_millis() as u64,
            "cold write exceeded end-to-end write timeout (queue + drain + commit)",
        );
    }
}
/// Logs why a spawned cold-storage task failed to join.
///
/// A panic is reported at `error` level with the join error attached; a cancellation is
/// reported at `debug` level. A join error that is neither is not logged.
fn log_join_error(op: &'static str, e: &tokio::task::JoinError) {
    if e.is_panic() {
        tracing::error!(op, error = %e, "cold storage spawned task panicked");
        return;
    }
    if e.is_cancelled() {
        tracing::debug!(op, "cold storage spawned task cancelled");
    }
}
/// Number of permits in the read semaphore. Each read task holds one permit; a write
/// acquires all of them before it runs.
const MAX_CONCURRENT_READERS: usize = 64;
/// Number of permits in the write semaphore; with a single permit, writes run one at a time.
const MAX_CONCURRENT_WRITES: usize = 1;
/// Number of permits in the stream semaphore, i.e. how many log-stream producer tasks may
/// run at once.
const MAX_CONCURRENT_STREAMS: usize = 8;
/// Capacity of the bounded channel between a log-stream producer task and its consumer.
const STREAM_CHANNEL_BUFFER: usize = 256;
/// State shared by every clone of a `ColdStorage` handle.
pub(crate) struct Inner<B> {
/// The storage backend that all read, write and stream operations delegate to.
pub(crate) backend: B,
/// In-memory cache consulted by the single-item header, transaction and receipt lookups
/// and invalidated by `truncate_above` / `drain_above`.
pub(crate) cache: Mutex<ColdCache>,
/// Upper bound applied to the caller-supplied deadline in `stream_logs`.
pub(crate) max_stream_deadline: Duration,
/// Gates read tasks (one permit each); a write acquires every permit before running.
pub(crate) read_sem: Arc<Semaphore>,
/// Gates write tasks.
pub(crate) write_sem: Arc<Semaphore>,
/// Gates log-stream producer tasks.
pub(crate) stream_sem: Arc<Semaphore>,
/// Tracks every spawned operation so `wait_shutdown` can wait for them to finish.
pub(crate) tracker: TaskTracker,
/// Cancels the shutdown token when this struct is dropped, which lets the watcher task
/// spawned in `ColdStorage::new` exit.
_shutdown_guard: DropGuard,
}
/// Handle to cold storage backed by a [`ColdStorageBackend`].
///
/// All state lives behind a shared `Arc`, so clones of a handle operate on the same
/// backend, cache and semaphores.
pub struct ColdStorage<B: ColdStorageBackend> {
// Shared state; see `Inner`.
inner: Arc<Inner<B>>,
}
// Implemented by hand rather than derived so that cloning does not require `B: Clone`:
// only the `Arc` is cloned, never the backend.
impl<B: ColdStorageBackend> Clone for ColdStorage<B> {
    /// Returns another handle to the same shared state.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
// Opaque `Debug` output: the type name is printed and all fields are elided, so no
// `Debug` bound is needed on the backend.
impl<B: ColdStorageBackend> std::fmt::Debug for ColdStorage<B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("ColdStorage");
        repr.finish_non_exhaustive()
    }
}
impl<B: ColdStorageBackend> ColdStorage<B> {
/// Creates a cold-storage handle over `backend`.
///
/// A child of `cancel` serves as the shutdown token. Once that token is cancelled, a
/// background watcher task closes the read, write and stream semaphores and the task
/// tracker, provided the shared state is still alive. Dropping the last handle also cancels
/// the token (through the stored drop guard), which lets the watcher task exit.
///
/// # Panics
///
/// Spawns a task with `tokio::spawn`, which panics when called outside a Tokio runtime.
pub fn new(backend: B, cancel: CancellationToken) -> Self {
    let shutdown_token = cancel.child_token();
    let guard = shutdown_token.clone().drop_guard();

    let cache = Mutex::new(ColdCache::new());
    let read_sem = Arc::new(Semaphore::new(MAX_CONCURRENT_READERS));
    let write_sem = Arc::new(Semaphore::new(MAX_CONCURRENT_WRITES));
    let stream_sem = Arc::new(Semaphore::new(MAX_CONCURRENT_STREAMS));
    let inner = Arc::new(Inner {
        backend,
        cache,
        max_stream_deadline: DEFAULT_MAX_STREAM_DEADLINE,
        read_sem,
        write_sem,
        stream_sem,
        tracker: TaskTracker::new(),
        _shutdown_guard: guard,
    });

    // The watcher only holds a weak reference so it never keeps the shared state alive.
    let watched: Weak<Inner<B>> = Arc::downgrade(&inner);
    tokio::spawn(async move {
        shutdown_token.cancelled().await;
        if let Some(state) = watched.upgrade() {
            state.read_sem.close();
            state.write_sem.close();
            state.stream_sem.close();
            state.tracker.close();
        }
    });

    Self { inner }
}
/// Closes the task tracker and waits until every task spawned on it has finished.
pub async fn wait_shutdown(&self) {
    let tracker = &self.inner.tracker;
    tracker.close();
    tracker.wait().await;
}
async fn spawn_read<T, F, Fut>(&self, op: &'static str, f: F) -> ColdResult<T>
where
T: Send + 'static,
F: FnOnce(Arc<Inner<B>>) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ColdResult<T>> + Send,
{
let wait = Instant::now();
let permit = self
.inner
.read_sem
.clone()
.acquire_owned()
.await
.map_err(|_| ColdStorageError::TaskTerminated)?;
metrics::record_permit_wait("read", wait.elapsed());
let inner = Arc::clone(&self.inner);
self.inner
.tracker
.spawn(
async move {
let _p = permit;
let _guard = metrics::InFlightGuard::new("read");
let start = Instant::now();
let result = f(inner).await;
metrics::record_op_duration(op, start.elapsed());
if let Err(ref e) = result {
metrics::record_op_error(op, e.kind());
}
result
}
.in_current_span(),
)
.await
.map_err(|e| {
log_join_error(op, &e);
ColdStorageError::TaskTerminated
})?
}
async fn spawn_write<T, F, Fut>(&self, op: &'static str, f: F) -> ColdResult<T>
where
T: Send + 'static,
F: FnOnce(Arc<Inner<B>>) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ColdResult<T>> + Send,
{
let e2e_start = Instant::now();
let threshold = self.inner.backend.write_timeout();
let write_permit = self
.inner
.write_sem
.clone()
.acquire_owned()
.await
.map_err(|_| ColdStorageError::TaskTerminated)?;
metrics::record_permit_wait("write", e2e_start.elapsed());
let drain_wait = Instant::now();
let drain = self
.inner
.read_sem
.clone()
.acquire_many_owned(MAX_CONCURRENT_READERS as u32)
.await
.map_err(|_| ColdStorageError::TaskTerminated)?;
metrics::record_permit_wait("drain", drain_wait.elapsed());
let inner = Arc::clone(&self.inner);
self.inner
.tracker
.spawn(
async move {
let _w = write_permit;
let _d = drain;
let _guard = metrics::InFlightGuard::new("write");
let start = Instant::now();
let result = f(inner).await;
metrics::record_op_duration(op, start.elapsed());
if let Err(ref e) = result {
metrics::record_op_error(op, e.kind());
}
warn_on_write_overrun(op, e2e_start.elapsed(), threshold, result.is_ok());
result
}
.in_current_span(),
)
.await
.map_err(|e| {
log_join_error(op, &e);
ColdStorageError::TaskTerminated
})?
}
#[tracing::instrument(skip(self, spec), fields(op = "get_header"))]
pub async fn get_header(&self, spec: HeaderSpecifier) -> ColdResult<Option<SealedHeader>> {
let op_start = Instant::now();
if let HeaderSpecifier::Number(n) = &spec
&& let Some(hit) = self.inner.cache.lock().get_header(n)
{
metrics::record_op_duration("get_header", op_start.elapsed());
return Ok(Some(hit));
}
self.spawn_read("get_header", move |inner| async move {
let result = inner.backend.get_header(spec).await;
if let Ok(Some(ref h)) = result {
inner.cache.lock().put_header(h.number, h.clone());
}
result
})
.await
}
/// Fetches the header of block number `block`.
///
/// Shorthand for [`Self::get_header`] with [`HeaderSpecifier::Number`].
pub async fn get_header_by_number(
    &self,
    block: BlockNumber,
) -> ColdResult<Option<SealedHeader>> {
    let spec = HeaderSpecifier::Number(block);
    self.get_header(spec).await
}
/// Fetches the header identified by `hash`.
///
/// Shorthand for [`Self::get_header`] with [`HeaderSpecifier::Hash`].
pub async fn get_header_by_hash(&self, hash: B256) -> ColdResult<Option<SealedHeader>> {
    let spec = HeaderSpecifier::Hash(hash);
    self.get_header(spec).await
}
#[tracing::instrument(skip(self, specs), fields(op = "get_headers"))]
pub async fn get_headers(
&self,
specs: Vec<HeaderSpecifier>,
) -> ColdResult<Vec<Option<SealedHeader>>> {
self.spawn_read("get_headers", move |inner| async move {
inner.backend.get_headers(specs).await
})
.await
}
#[tracing::instrument(skip(self, spec), fields(op = "get_transaction"))]
pub async fn get_transaction(
&self,
spec: TransactionSpecifier,
) -> ColdResult<Option<Confirmed<RecoveredTx>>> {
let op_start = Instant::now();
if let TransactionSpecifier::BlockAndIndex { block, index } = &spec
&& let Some(hit) = self.inner.cache.lock().get_tx(&(*block, *index))
{
metrics::record_op_duration("get_transaction", op_start.elapsed());
return Ok(Some(hit));
}
self.spawn_read("get_transaction", move |inner| async move {
let result = inner.backend.get_transaction(spec).await;
if let Ok(Some(ref c)) = result {
let meta = c.meta();
inner
.cache
.lock()
.put_tx((meta.block_number(), meta.transaction_index()), c.clone());
}
result
})
.await
}
/// Fetches the confirmed transaction with hash `hash`.
///
/// Shorthand for [`Self::get_transaction`] with [`TransactionSpecifier::Hash`].
pub async fn get_tx_by_hash(&self, hash: B256) -> ColdResult<Option<Confirmed<RecoveredTx>>> {
    let spec = TransactionSpecifier::Hash(hash);
    self.get_transaction(spec).await
}
/// Fetches the confirmed transaction at position `index` of block number `block`.
///
/// Shorthand for [`Self::get_transaction`] with [`TransactionSpecifier::BlockAndIndex`].
pub async fn get_tx_by_block_and_index(
    &self,
    block: BlockNumber,
    index: u64,
) -> ColdResult<Option<Confirmed<RecoveredTx>>> {
    let spec = TransactionSpecifier::BlockAndIndex { block, index };
    self.get_transaction(spec).await
}
/// Fetches the confirmed transaction at position `index` of the block with hash
/// `block_hash`.
///
/// Shorthand for [`Self::get_transaction`] with
/// [`TransactionSpecifier::BlockHashAndIndex`].
pub async fn get_tx_by_block_hash_and_index(
    &self,
    block_hash: B256,
    index: u64,
) -> ColdResult<Option<Confirmed<RecoveredTx>>> {
    let spec = TransactionSpecifier::BlockHashAndIndex { block_hash, index };
    self.get_transaction(spec).await
}
#[tracing::instrument(skip(self), fields(op = "get_transactions_in_block"))]
pub async fn get_transactions_in_block(
&self,
block: BlockNumber,
) -> ColdResult<Vec<RecoveredTx>> {
self.spawn_read("get_transactions_in_block", move |inner| async move {
inner.backend.get_transactions_in_block(block).await
})
.await
}
#[tracing::instrument(skip(self), fields(op = "get_transaction_count"))]
pub async fn get_transaction_count(&self, block: BlockNumber) -> ColdResult<u64> {
self.spawn_read("get_transaction_count", move |inner| async move {
inner.backend.get_transaction_count(block).await
})
.await
}
#[tracing::instrument(skip(self, spec), fields(op = "get_receipt"))]
pub async fn get_receipt(&self, spec: ReceiptSpecifier) -> ColdResult<Option<ColdReceipt>> {
let op_start = Instant::now();
if let ReceiptSpecifier::BlockAndIndex { block, index } = &spec
&& let Some(hit) = self.inner.cache.lock().get_receipt(&(*block, *index))
{
metrics::record_op_duration("get_receipt", op_start.elapsed());
return Ok(Some(hit));
}
self.spawn_read("get_receipt", move |inner| async move {
let result = inner.backend.get_receipt(spec).await;
if let Ok(Some(ref c)) = result {
inner.cache.lock().put_receipt((c.block_number, c.transaction_index), c.clone());
}
result
})
.await
}
/// Fetches the receipt of the transaction with hash `hash`.
///
/// Shorthand for [`Self::get_receipt`] with [`ReceiptSpecifier::TxHash`].
pub async fn get_receipt_by_tx_hash(&self, hash: B256) -> ColdResult<Option<ColdReceipt>> {
    let spec = ReceiptSpecifier::TxHash(hash);
    self.get_receipt(spec).await
}
/// Fetches the receipt at position `index` of block number `block`.
///
/// Shorthand for [`Self::get_receipt`] with [`ReceiptSpecifier::BlockAndIndex`].
pub async fn get_receipt_by_block_and_index(
    &self,
    block: BlockNumber,
    index: u64,
) -> ColdResult<Option<ColdReceipt>> {
    let spec = ReceiptSpecifier::BlockAndIndex { block, index };
    self.get_receipt(spec).await
}
#[tracing::instrument(skip(self), fields(op = "get_receipts_in_block"))]
pub async fn get_receipts_in_block(&self, block: BlockNumber) -> ColdResult<Vec<ColdReceipt>> {
self.spawn_read("get_receipts_in_block", move |inner| async move {
inner.backend.get_receipts_in_block(block).await
})
.await
}
#[tracing::instrument(skip(self, spec), fields(op = "get_signet_events"))]
pub async fn get_signet_events(
&self,
spec: SignetEventsSpecifier,
) -> ColdResult<Vec<DbSignetEvent>> {
self.spawn_read("get_signet_events", move |inner| async move {
inner.backend.get_signet_events(spec).await
})
.await
}
/// Fetches the signet events of block number `block`.
///
/// Shorthand for [`Self::get_signet_events`] with [`SignetEventsSpecifier::Block`].
pub async fn get_signet_events_in_block(
    &self,
    block: BlockNumber,
) -> ColdResult<Vec<DbSignetEvent>> {
    let spec = SignetEventsSpecifier::Block(block);
    self.get_signet_events(spec).await
}
/// Fetches the signet events for the block range described by `start` and `end`.
///
/// Shorthand for [`Self::get_signet_events`] with [`SignetEventsSpecifier::BlockRange`];
/// how the bounds are interpreted is up to the backend.
pub async fn get_signet_events_in_range(
    &self,
    start: BlockNumber,
    end: BlockNumber,
) -> ColdResult<Vec<DbSignetEvent>> {
    let spec = SignetEventsSpecifier::BlockRange { start, end };
    self.get_signet_events(spec).await
}
/// Fetches the zenith header for block number `block`.
///
/// Wraps `block` in [`ZenithHeaderSpecifier::Number`] and performs the lookup as a read
/// task.
pub async fn get_zenith_header(
    &self,
    block: BlockNumber,
) -> ColdResult<Option<DbZenithHeader>> {
    let spec = ZenithHeaderSpecifier::Number(block);
    self.get_zenith_header_by_spec(spec).await
}
#[tracing::instrument(skip(self, spec), fields(op = "get_zenith_header_by_spec"))]
async fn get_zenith_header_by_spec(
&self,
spec: ZenithHeaderSpecifier,
) -> ColdResult<Option<DbZenithHeader>> {
self.spawn_read("get_zenith_header_by_spec", move |inner| async move {
inner.backend.get_zenith_header(spec).await
})
.await
}
#[tracing::instrument(skip(self, spec), fields(op = "get_zenith_headers"))]
pub async fn get_zenith_headers(
&self,
spec: ZenithHeaderSpecifier,
) -> ColdResult<Vec<DbZenithHeader>> {
self.spawn_read("get_zenith_headers", move |inner| async move {
inner.backend.get_zenith_headers(spec).await
})
.await
}
/// Fetches the zenith headers for the block range described by `start` and `end`.
///
/// Shorthand for [`Self::get_zenith_headers`] with [`ZenithHeaderSpecifier::Range`]; how
/// the bounds are interpreted is up to the backend.
pub async fn get_zenith_headers_in_range(
    &self,
    start: BlockNumber,
    end: BlockNumber,
) -> ColdResult<Vec<DbZenithHeader>> {
    let spec = ZenithHeaderSpecifier::Range { start, end };
    self.get_zenith_headers(spec).await
}
#[tracing::instrument(skip(self, filter), fields(op = "get_logs"))]
pub async fn get_logs(&self, filter: Filter, max_logs: usize) -> ColdResult<Vec<RpcLog>> {
self.spawn_read("get_logs", move |inner| async move {
inner.backend.get_logs(&filter, max_logs).await
})
.await
}
#[tracing::instrument(skip(self, filter), fields(op = "stream_logs"))]
pub async fn stream_logs(
&self,
filter: Filter,
max_logs: usize,
deadline: Duration,
) -> ColdResult<LogStream> {
let from = filter.get_from_block().unwrap_or(0);
let to = match filter.get_to_block() {
Some(to) => to,
None => {
let setup_to =
self.inner.backend.read_timeout().unwrap_or(DEFAULT_STREAM_SETUP_TIMEOUT);
let latest = tokio::time::timeout(setup_to, self.inner.backend.get_latest_block())
.await
.map_err(|_| ColdStorageError::DeadlineExceeded(setup_to))??;
match latest {
Some(latest) => latest,
None => {
let (_tx, rx) = mpsc::channel(1);
return Ok(ReceiverStream::new(rx));
}
}
}
};
let wait = Instant::now();
let permit = self
.inner
.stream_sem
.clone()
.acquire_owned()
.await
.map_err(|_| ColdStorageError::TaskTerminated)?;
metrics::record_permit_wait("stream", wait.elapsed());
let effective = deadline.min(self.inner.max_stream_deadline);
let deadline_instant = Instant::now() + effective;
let (sender, rx) = mpsc::channel(STREAM_CHANNEL_BUFFER);
let inner = Arc::clone(&self.inner);
let started = Instant::now();
self.inner.tracker.spawn(
async move {
let _p = permit;
let _guard = metrics::InFlightGuard::new("stream");
let params =
StreamParams { from, to, max_logs, sender, deadline: deadline_instant };
inner.backend.produce_log_stream(&filter, params).await;
metrics::record_stream_lifetime(started.elapsed());
}
.in_current_span(),
);
Ok(ReceiverStream::new(rx))
}
#[tracing::instrument(skip(self), fields(op = "get_latest_block"))]
pub async fn get_latest_block(&self) -> ColdResult<Option<BlockNumber>> {
self.spawn_read("get_latest_block", move |inner| async move {
inner.backend.get_latest_block().await
})
.await
}
#[tracing::instrument(skip(self, data), fields(op = "append_block"))]
pub async fn append_block(&self, data: BlockData) -> ColdResult<()> {
self.spawn_write("append_block", move |inner| async move {
inner.backend.append_block(data).await
})
.await
}
#[tracing::instrument(skip(self, data), fields(op = "append_blocks"))]
pub async fn append_blocks(&self, data: Vec<BlockData>) -> ColdResult<()> {
self.spawn_write("append_blocks", move |inner| async move {
inner.backend.append_blocks(data).await
})
.await
}
#[tracing::instrument(skip(self), fields(op = "truncate_above"))]
pub async fn truncate_above(&self, block: BlockNumber) -> ColdResult<()> {
self.spawn_write("truncate_above", move |inner| async move {
let result = inner.backend.truncate_above(block).await;
if result.is_ok() {
inner.cache.lock().invalidate_above(block);
}
result
})
.await
}
#[tracing::instrument(skip(self), fields(op = "drain_above"))]
pub async fn drain_above(&self, block: BlockNumber) -> ColdResult<Vec<Vec<ColdReceipt>>> {
self.spawn_write("drain_above", move |inner| async move {
let result = inner.backend.drain_above(block).await;
if result.is_ok() {
inner.cache.lock().invalidate_above(block);
}
result
})
.await
}
}