use std::{
mem,
sync::{Arc, RwLock},
time::Duration,
};
use anyhow::Context as _;
use backon::{BlockingRetryable, ConstantBuilder};
use tokio::{
runtime::Handle,
sync::{
mpsc::{self, UnboundedReceiver},
watch,
},
};
use zksync_dal::{Connection, ConnectionPool, Core, CoreDal};
use zksync_types::{L1BatchNumber, L2BlockNumber, StorageKey, StorageValue, H256};
use self::metrics::{Method, ValuesUpdateStage, CACHE_METRICS, STORAGE_METRICS};
use crate::{
cache::{lru_cache::LruCache, CacheValue},
ReadStorage,
};
mod metrics;
#[cfg(test)]
mod tests;
/// Contract bytecode (factory dependency) together with the L2 block at which it appeared.
///
/// The timestamp allows readers pinned to an older block to skip deps that did not exist yet
/// at that block (see `load_factory_dep`, which filters on `inserted_at`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct TimestampedFactoryDep {
    // Raw bytecode of the dependency.
    bytecode: Vec<u8>,
    // L2 block number at which this bytecode was inserted into storage.
    inserted_at: L2BlockNumber,
}

/// LRU cache mapping bytecode hashes to timestamped factory deps.
type FactoryDepsCache = LruCache<H256, TimestampedFactoryDep>;
impl CacheValue<H256> for TimestampedFactoryDep {
    /// Cache weight of an entry: the bytecode length plus the fixed-size insertion timestamp.
    fn cache_weight(&self) -> u32 {
        let weight = self.bytecode.len() + mem::size_of::<L2BlockNumber>();
        u32::try_from(weight).expect("Cached bytes are too large")
    }
}
/// LRU cache mapping hashed storage keys to the L1 batch of their initial write.
type InitialWritesCache = LruCache<H256, L1BatchNumber>;

impl CacheValue<H256> for L1BatchNumber {
    /// Cache weight of an entry: value size plus key size.
    fn cache_weight(&self) -> u32 {
        // Both types are small and fixed-size, so the conversion cannot fail in practice.
        // Using a checked conversion instead of a truncating `as` cast removes the need
        // for `#[allow(clippy::cast_possible_truncation)]`.
        let weight = mem::size_of::<L1BatchNumber>() + mem::size_of::<H256>();
        u32::try_from(weight).expect("cache weight overflows u32")
    }
}
/// Storage value together with the L2 block at which it was loaded from Postgres.
///
/// `loaded_at` lets [`ValuesCache::get()`] reject entries that were loaded at a later block
/// than the one a reader is pinned to.
#[derive(Debug, Clone, Copy)]
struct TimestampedStorageValue {
    // The cached storage slot value.
    value: StorageValue,
    // L2 block at which the value was read from the DB.
    loaded_at: L2BlockNumber,
}
impl CacheValue<H256> for TimestampedStorageValue {
    /// Cache weight of an entry: value size plus key size.
    fn cache_weight(&self) -> u32 {
        // Both types are small and fixed-size, so the conversion cannot fail in practice.
        // Using a checked conversion instead of a truncating `as` cast removes the need
        // for `#[allow(clippy::cast_possible_truncation)]`.
        let weight = mem::size_of::<TimestampedStorageValue>() + mem::size_of::<H256>();
        u32::try_from(weight).expect("cache weight overflows u32")
    }
}
/// Mutable state behind [`ValuesCache`].
#[derive(Debug)]
struct ValuesCacheInner {
    // L2 block number the cached values are consistent with. Entries must not be served
    // for blocks past this point.
    valid_for: L2BlockNumber,
    // Hashed storage key -> timestamped value.
    values: LruCache<H256, TimestampedStorageValue>,
}

/// Shared cache of storage values, valid up to a specific L2 block and advanced
/// incrementally (or reset) by [`ValuesCache::update()`].
#[derive(Debug, Clone)]
struct ValuesCache(Arc<RwLock<ValuesCacheInner>>);
impl ValuesCache {
    /// Creates a values cache with the given weight capacity. The cache is initially
    /// valid for L2 block #0 and must be advanced via [`Self::update()`].
    fn new(capacity: u64) -> Self {
        let inner = ValuesCacheInner {
            valid_for: L2BlockNumber(0),
            values: LruCache::new("values_cache", capacity),
        };
        Self(Arc::new(RwLock::new(inner)))
    }

    /// L2 block number up to which cached values are valid.
    fn valid_for(&self) -> L2BlockNumber {
        self.0.read().expect("values cache is poisoned").valid_for
    }

    /// Returns the cached value for `hashed_key` at `l2_block_number`, or `None` on a miss.
    fn get(&self, l2_block_number: L2BlockNumber, hashed_key: H256) -> Option<StorageValue> {
        let lock = self.0.read().expect("values cache is poisoned");
        if lock.valid_for < l2_block_number {
            // The cache lags behind the requested block; its contents cannot be trusted.
            return None;
        }
        let timestamped_value = lock.values.get(&hashed_key)?;
        if timestamped_value.loaded_at <= l2_block_number {
            Some(timestamped_value.value)
        } else {
            // Entry was loaded at a block later than the requested one; don't serve it.
            None
        }
    }

    /// Caches `value` for `hashed_key`, but only when the cache is valid for exactly
    /// `l2_block_number`; otherwise the value is dropped and a staleness metric is bumped.
    fn insert(&self, l2_block_number: L2BlockNumber, hashed_key: H256, value: StorageValue) {
        let lock = self.0.read().expect("values cache is poisoned");
        if lock.valid_for == l2_block_number {
            // NOTE(review): insertion happens through a *read* guard, so `LruCache::insert`
            // presumably relies on interior mutability — confirm against its definition.
            lock.values.insert(
                hashed_key,
                TimestampedStorageValue {
                    value,
                    loaded_at: l2_block_number,
                },
            );
        } else {
            CACHE_METRICS.stale_values.inc();
        }
    }

    /// Advances the cache from `from_l2_block` to `to_l2_block`.
    ///
    /// If the requested jump exceeds `MAX_L2_BLOCKS_LAG` blocks, the cache is cleared and
    /// revalidated wholesale; otherwise only the keys modified in the intervening blocks
    /// (loaded from Postgres) are evicted.
    ///
    /// # Errors
    ///
    /// Returns an error on a DB failure, a poisoned lock, or when the cache is no longer
    /// valid for `from_l2_block` (sanity check).
    async fn update(
        &self,
        from_l2_block: L2BlockNumber,
        to_l2_block: L2BlockNumber,
        connection: &mut Connection<'_, Core>,
    ) -> anyhow::Result<()> {
        // Max number of blocks to catch up incrementally before resetting the cache instead.
        const MAX_L2_BLOCKS_LAG: u32 = 5;

        tracing::debug!(
            "Updating storage values cache from L2 block {from_l2_block} to {to_l2_block}"
        );

        if to_l2_block.0 - from_l2_block.0 > MAX_L2_BLOCKS_LAG {
            // Too far behind: clearing everything is cheaper than enumerating modified keys.
            tracing::info!(
                "Storage values cache is too far behind (current L2 block is {from_l2_block}; \
                 requested update to {to_l2_block}); resetting the cache"
            );
            let mut lock = self
                .0
                .write()
                .map_err(|_| anyhow::anyhow!("values cache is poisoned"))?;
            anyhow::ensure!(
                lock.valid_for == from_l2_block,
                "sanity check failed: values cache was expected to be valid for L2 block #{from_l2_block}, but it's actually \
                 valid for L2 block #{}",
                lock.valid_for
            );
            lock.valid_for = to_l2_block;
            lock.values.clear();
            CACHE_METRICS.values_emptied.inc();
        } else {
            // Incremental path: load the keys touched in (from, to] while NOT holding the
            // write lock, then evict only those keys.
            let update_latency = CACHE_METRICS.values_update[&ValuesUpdateStage::LoadKeys].start();
            let l2_blocks = (from_l2_block + 1)..=to_l2_block;
            let modified_keys = connection
                .storage_logs_dal()
                .modified_keys_in_l2_blocks(l2_blocks.clone())
                .await?;

            let elapsed = update_latency.observe();
            CACHE_METRICS
                .values_update_modified_keys
                .observe(modified_keys.len());
            tracing::debug!(
                "Loaded {modified_keys_len} modified storage keys from L2 blocks {l2_blocks:?}; \
                 took {elapsed:?}",
                modified_keys_len = modified_keys.len()
            );

            let update_latency =
                CACHE_METRICS.values_update[&ValuesUpdateStage::RemoveStaleKeys].start();
            let mut lock = self
                .0
                .write()
                .map_err(|_| anyhow::anyhow!("values cache is poisoned"))?;
            // Guard against concurrent updates having moved the cache in the meantime.
            anyhow::ensure!(
                lock.valid_for == from_l2_block,
                "sanity check failed: values cache was expected to be valid for L2 block #{from_l2_block}, but it's actually \
                 valid for L2 block #{}",
                lock.valid_for
            );
            lock.valid_for = to_l2_block;
            for modified_key in &modified_keys {
                lock.values.remove(modified_key);
            }
            lock.values.report_size();
            // Release the write lock before observing latency.
            drop(lock);
            update_latency.observe();
        }

        CACHE_METRICS
            .values_valid_for_miniblock
            .set(u64::from(to_l2_block.0));
        Ok(())
    }
}
/// Values cache paired with the channel used to send update commands to the background
/// [`PostgresStorageCachesTask`].
#[derive(Debug, Clone)]
struct ValuesCacheAndUpdater {
    cache: ValuesCache,
    // Sends "update up to this L2 block" commands to the updater task.
    command_sender: mpsc::UnboundedSender<L2BlockNumber>,
}

/// Set of caches used by [`PostgresStorage`].
#[derive(Debug, Clone)]
pub struct PostgresStorageCaches {
    // Bytecode hash -> factory dep.
    factory_deps: FactoryDepsCache,
    // Hashed key -> L1 batch of the key's initial write (positive cache).
    initial_writes: InitialWritesCache,
    // Hashed key -> pending L1 batch at lookup time, for keys with *no* initial write yet
    // (negative cache).
    negative_initial_writes: InitialWritesCache,
    // Values cache; `None` unless configured via `configure_storage_values_cache()`.
    values: Option<ValuesCacheAndUpdater>,
}
impl PostgresStorageCaches {
    /// Creates caches with the specified weight capacities (in bytes). The initial-writes
    /// capacity is split evenly between the positive and negative caches. The values cache
    /// is not enabled; use [`Self::configure_storage_values_cache()`] for that.
    pub fn new(factory_deps_capacity: u64, initial_writes_capacity: u64) -> Self {
        tracing::debug!(
            "Initialized VM execution cache with {factory_deps_capacity}B capacity for factory deps, \
             {initial_writes_capacity}B capacity for initial writes"
        );

        let writes_capacity = initial_writes_capacity / 2;
        Self {
            factory_deps: FactoryDepsCache::new("factory_deps_cache", factory_deps_capacity),
            initial_writes: InitialWritesCache::new("initial_writes_cache", writes_capacity),
            negative_initial_writes: InitialWritesCache::new(
                "negative_initial_writes_cache",
                writes_capacity,
            ),
            values: None,
        }
    }

    /// Enables the storage values cache with the given weight capacity (in bytes) and
    /// returns the background task that keeps the cache up to date. The task must be run
    /// for the cache to be useful.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn configure_storage_values_cache(
        &mut self,
        capacity: u64,
        connection_pool: ConnectionPool<Core>,
    ) -> PostgresStorageCachesTask {
        assert!(
            capacity > 0,
            "Storage values cache capacity must be positive"
        );
        tracing::debug!("Initializing VM storage values cache with {capacity}B capacity");

        let cache = ValuesCache::new(capacity);
        let (command_sender, command_receiver) = mpsc::unbounded_channel();
        self.values = Some(ValuesCacheAndUpdater {
            cache: cache.clone(),
            command_sender,
        });
        PostgresStorageCachesTask {
            connection_pool,
            values_cache: cache,
            command_receiver,
        }
    }

    /// Asks the background task to advance the values cache up to `to_l2_block`.
    /// No-op when the values cache is not configured or is already fresh enough.
    pub fn schedule_values_update(&self, to_l2_block: L2BlockNumber) {
        if let Some(values) = &self.values {
            if values.cache.valid_for() < to_l2_block {
                // The updater task may have shut down; a failed send is fine to ignore.
                values.command_sender.send(to_l2_block).ok();
            }
        }
    }
}
/// Background task that applies update commands to the values cache. Created by
/// [`PostgresStorageCaches::configure_storage_values_cache()`]; drive it via [`Self::run()`].
#[derive(Debug)]
pub struct PostgresStorageCachesTask {
    connection_pool: ConnectionPool<Core>,
    values_cache: ValuesCache,
    // Receives target L2 block numbers from `PostgresStorageCaches::schedule_values_update()`.
    command_receiver: UnboundedReceiver<L2BlockNumber>,
}
impl PostgresStorageCachesTask {
    /// Runs the values cache updater until the stop signal changes.
    ///
    /// Processes update commands in order, skipping those that would not move the cache
    /// forward.
    ///
    /// # Errors
    ///
    /// Propagates DB connection / cache update errors, and errors from awaiting the stop
    /// signal after all command senders are dropped.
    pub async fn run(mut self, mut stop_receiver: watch::Receiver<bool>) -> anyhow::Result<()> {
        let mut current_l2_block = self.values_cache.valid_for();
        loop {
            tokio::select! {
                _ = stop_receiver.changed() => {
                    break;
                }
                Some(to_l2_block) = self.command_receiver.recv() => {
                    if to_l2_block <= current_l2_block {
                        // Stale command; the cache is already at or past this block.
                        continue;
                    }
                    let mut connection = self
                        .connection_pool
                        .connection_tagged("values_cache_updater")
                        .await?;
                    self.values_cache
                        .update(current_l2_block, to_l2_block, &mut connection)
                        .await?;
                    current_l2_block = to_l2_block;
                }
                else => {
                    // All command senders were dropped; nothing left to do but wait for
                    // the stop signal so the task shuts down cleanly.
                    stop_receiver.changed().await?;
                    break;
                }
            }
        }
        Ok(())
    }
}
/// [`ReadStorage`] implementation backed by Postgres, pinned to a specific L2 block.
///
/// Synchronous trait methods execute async DAL queries by blocking on `rt_handle`.
#[derive(Debug)]
pub struct PostgresStorage<'a> {
    // Tokio runtime handle used to `block_on` async DAL calls.
    rt_handle: Handle,
    connection: Connection<'a, Core>,
    // L2 block this storage reads state at.
    l2_block_number: L2BlockNumber,
    // L1 batch the pinned L2 block belongs to (or is expected to belong to).
    l1_batch_number_for_l2_block: L1BatchNumber,
    // First L1 batch that is not yet sealed, used for negative initial-write caching.
    pending_l1_batch_number: L1BatchNumber,
    // Whether writes in the block's own L1 batch are visible (see `write_counts`).
    consider_new_l1_batch: bool,
    // Optional shared caches; set via `with_caches()`.
    caches: Option<PostgresStorageCaches>,
}
impl<'a> PostgresStorage<'a> {
    /// Creates a storage reader for the given L2 block by blocking on [`Self::new_async()`].
    ///
    /// # Panics
    ///
    /// Panics if the L1 batch number for `block_number` cannot be resolved (e.g. on a DB
    /// error).
    pub fn new(
        rt_handle: Handle,
        connection: Connection<'a, Core>,
        block_number: L2BlockNumber,
        consider_new_l1_batch: bool,
    ) -> Self {
        rt_handle
            .clone()
            .block_on(Self::new_async(
                rt_handle,
                connection,
                block_number,
                consider_new_l1_batch,
            ))
            // Previously a bare `unwrap()`; `expect` states the invariant and keeps the
            // underlying `anyhow` context in the panic message.
            .expect("Failed initializing `PostgresStorage`")
    }

    /// Asynchronously creates a storage reader for the given L2 block.
    ///
    /// # Errors
    ///
    /// Returns an error if the L1 batch number corresponding to `block_number` cannot be
    /// resolved from Postgres.
    pub async fn new_async(
        rt_handle: Handle,
        mut connection: Connection<'a, Core>,
        block_number: L2BlockNumber,
        consider_new_l1_batch: bool,
    ) -> anyhow::Result<PostgresStorage<'a>> {
        let resolved = connection
            .storage_web3_dal()
            .resolve_l1_batch_number_of_l2_block(block_number)
            .await
            .with_context(|| {
                format!("failed resolving L1 batch number for L2 block #{block_number}")
            })?;

        Ok(Self {
            rt_handle,
            connection,
            l2_block_number: block_number,
            l1_batch_number_for_l2_block: resolved.expected_l1_batch(),
            pending_l1_batch_number: resolved.pending_l1_batch,
            consider_new_l1_batch,
            caches: None,
        })
    }

    /// Attaches the provided caches to this storage instance.
    #[must_use]
    pub fn with_caches(self, caches: PostgresStorageCaches) -> Self {
        Self {
            caches: Some(caches),
            ..self
        }
    }

    /// Returns whether a write performed in `write_l1_batch_number` is visible at this
    /// storage's block. Writes in the block's own batch count only when
    /// `consider_new_l1_batch` is set.
    fn write_counts(&self, write_l1_batch_number: L1BatchNumber) -> bool {
        if self.consider_new_l1_batch {
            self.l1_batch_number_for_l2_block >= write_l1_batch_number
        } else {
            self.l1_batch_number_for_l2_block > write_l1_batch_number
        }
    }

    /// Returns the configured values cache, if any.
    fn values_cache(&self) -> Option<&ValuesCache> {
        Some(&self.caches.as_ref()?.values.as_ref()?.cache)
    }
}
impl ReadStorage for PostgresStorage<'_> {
    /// Reads a storage slot at the pinned L2 block, consulting the values cache first.
    fn read_value(&mut self, key: &StorageKey) -> StorageValue {
        let hashed_key = key.hashed_key();
        let latency = STORAGE_METRICS.storage[&Method::ReadValue].start();
        let values_cache = self.values_cache();
        let cached_value =
            values_cache.and_then(|cache| cache.get(self.l2_block_number, hashed_key));

        let value = cached_value.unwrap_or_else(|| {
            // On a cache miss, load from Postgres, retrying only on statement timeouts.
            const RETRY_INTERVAL: Duration = Duration::from_millis(500);
            const MAX_TRIES: usize = 20;

            let mut dal = self.connection.storage_web3_dal();
            let value = (|| {
                self.rt_handle
                    .block_on(dal.get_historical_value_unchecked(hashed_key, self.l2_block_number))
            })
            .retry(
                &ConstantBuilder::default()
                    .with_delay(RETRY_INTERVAL)
                    .with_max_times(MAX_TRIES),
            )
            // Retry only Postgres statement-timeout errors; everything else fails fast.
            .when(|e| {
                e.inner()
                    .as_database_error()
                    .is_some_and(|e| e.message() == "canceling statement due to statement timeout")
            })
            .call()
            .expect("Failed executing `read_value`");
            if let Some(cache) = self.values_cache() {
                cache.insert(self.l2_block_number, hashed_key, value);
            }
            value
        });

        latency.observe();
        value
    }

    /// Returns whether `key` has no initial write visible at this storage's block,
    /// using positive and negative initial-write caches to avoid DB round trips.
    fn is_write_initial(&mut self, key: &StorageKey) -> bool {
        let hashed_key = key.hashed_key();
        let latency = STORAGE_METRICS.storage[&Method::IsWriteInitial].start();
        let caches = self.caches.as_ref();
        // Positive cache: L1 batch of the key's known initial write.
        let cached_value = caches.and_then(|caches| caches.initial_writes.get(&hashed_key));

        if cached_value.is_none() {
            // Negative cache: the key had no initial write as of this L1 batch.
            // NOTE: this shadows `cached_value` only inside this block; the outer
            // (positive-cache) value is what's used below.
            let cached_value =
                caches.and_then(|caches| caches.negative_initial_writes.get(&hashed_key));
            if let Some(min_l1_batch_for_initial_write) = cached_value {
                if !self.write_counts(min_l1_batch_for_initial_write) {
                    // Any initial write can only be at or after the cached batch, which is
                    // not visible at our block — so the write is still initial.
                    CACHE_METRICS.effective_values.inc();
                    return true;
                }
            }
        }

        let l1_batch_number = cached_value.or_else(|| {
            // Cache miss: query Postgres and refresh both caches accordingly.
            let mut dal = self.connection.storage_web3_dal();
            let value = self
                .rt_handle
                .block_on(dal.get_l1_batch_number_for_initial_write(hashed_key))
                .expect("Failed executing `is_write_initial`");

            if let Some(caches) = &self.caches {
                if let Some(l1_batch_number) = value {
                    // The key now has an initial write: promote it to the positive cache.
                    caches.negative_initial_writes.remove(&hashed_key);
                    caches.initial_writes.insert(hashed_key, l1_batch_number);
                } else {
                    // Still no initial write as of the pending batch.
                    caches
                        .negative_initial_writes
                        .insert(hashed_key, self.pending_l1_batch_number);
                }
            }
            value
        });

        latency.observe();
        let contains_key = l1_batch_number.map_or(false, |initial_write_l1_batch_number| {
            self.write_counts(initial_write_l1_batch_number)
        });
        !contains_key
    }

    /// Loads a factory dep by bytecode hash, hiding deps inserted after the pinned block.
    fn load_factory_dep(&mut self, hash: H256) -> Option<Vec<u8>> {
        let latency = STORAGE_METRICS.storage[&Method::LoadFactoryDep].start();

        let cached_value = self
            .caches
            .as_ref()
            .and_then(|caches| caches.factory_deps.get(&hash));

        let value = cached_value.or_else(|| {
            let mut dal = self.connection.storage_web3_dal();
            let value = self
                .rt_handle
                .block_on(dal.get_factory_dep(hash))
                .expect("Failed executing `load_factory_dep`")
                .map(|(bytecode, inserted_at)| TimestampedFactoryDep {
                    bytecode,
                    inserted_at,
                });
            // Cache the dep regardless of `inserted_at`; block-visibility filtering
            // happens below, so readers at other blocks can reuse the entry.
            if let Some(caches) = &self.caches {
                if let Some(value) = value.clone() {
                    caches.factory_deps.insert(hash, value);
                }
            };
            value
        });

        latency.observe();
        // Hide deps that did not exist yet at the pinned block.
        Some(
            value
                .filter(|dep| dep.inserted_at <= self.l2_block_number)?
                .bytecode,
        )
    }

    /// Returns the enumeration index of `key` within this storage's L1 batch (uncached).
    fn get_enumeration_index(&mut self, key: &StorageKey) -> Option<u64> {
        let hashed_key = key.hashed_key();
        let mut dal = self.connection.storage_logs_dedup_dal();
        let value = self.rt_handle.block_on(
            dal.get_enumeration_index_in_l1_batch(hashed_key, self.l1_batch_number_for_l2_block),
        );
        value.expect("failed getting enumeration index for key")
    }
}