use std::{fmt, future::Future, time::Duration};
use anyhow::Context as _;
use async_trait::async_trait;
use tokio::sync::watch;
use zksync_dal::{ConnectionPool, Core, CoreDal, DalError};
use zksync_health_check::{Health, HealthStatus, HealthUpdater, ReactiveHealthCheck};
use zksync_shared_metrics::{CheckerComponent, EN_METRICS};
use zksync_types::{L1BatchNumber, L2BlockNumber, H256};
use zksync_web3_decl::{
client::{DynClient, L2},
error::{ClientRpcContext, EnrichedClientError, EnrichedClientResult},
namespaces::{EthNamespaceClient, ZksNamespaceClient},
};
#[cfg(test)]
mod tests;
/// Data that was requested from the main node but is not (yet) available there.
#[derive(Debug, thiserror::Error)]
// `Clone` and `Copy` are derived only in test builds.
#[cfg_attr(test, derive(Clone, Copy))]
pub enum MissingData {
    /// The main node returned no data for the requested L2 block.
    #[error("no requested L2 block")]
    L2Block,
    /// The main node returned no data for the requested L1 batch.
    #[error("no requested L1 batch")]
    Batch,
    /// The L1 batch exists on the main node, but its root hash is not set.
    #[error("no root hash for L1 batch")]
    RootHash,
}
/// Error that can occur while comparing local hashes with the ones on the main node.
#[derive(Debug, thiserror::Error)]
pub enum HashMatchError {
    /// An RPC call to the main node failed.
    #[error("RPC error calling main node")]
    Rpc(#[from] EnrichedClientError),
    /// The main node doesn't have the requested data.
    #[error("missing data on main node")]
    MissingData(#[from] MissingData),
    /// Any other error, e.g. a local DB error or data missing in the local DB.
    #[error(transparent)]
    Internal(#[from] anyhow::Error),
}
impl From<DalError> for HashMatchError {
fn from(err: DalError) -> Self {
Self::Internal(err.generalize())
}
}
/// Error returned by the reorg detector.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Error while comparing hashes with the main node (RPC, missing data, or internal).
    #[error(transparent)]
    HashMatch(#[from] HashMatchError),
    /// The earliest L1 batch in the local DB has a root hash different from the main node's.
    #[error(
        "Unrecoverable error: the earliest L1 batch #{0} in the local DB \
has mismatched hash with the main node. Make sure you're connected to the right network; \
if you've recovered from a snapshot, re-check snapshot authenticity. \
Using an earlier snapshot could help."
    )]
    EarliestL1BatchMismatch(L1BatchNumber),
    /// The main node has no data (or no root hash) for the earliest L1 batch in the local DB.
    #[error(
        "Unrecoverable error: the earliest L1 batch #{0} in the local DB \
is truncated on the main node. Make sure you're connected to the right network; \
if you've recovered from a snapshot, re-check snapshot authenticity. \
Using an earlier snapshot could help."
    )]
    EarliestL1BatchTruncated(L1BatchNumber),
    /// A divergence was localized; the payload is the last L1 batch found to match the main node.
    #[error("reorg detected, restart the node to revert to the last correct L1 batch #{0}.")]
    ReorgDetected(L1BatchNumber),
}
impl HashMatchError {
pub fn is_transient(&self) -> bool {
match self {
Self::Rpc(err) => err.is_transient(),
Self::MissingData(_) => true,
Self::Internal(_) => false,
}
}
}
impl Error {
pub fn is_transient(&self) -> bool {
matches!(self, Self::HashMatch(err) if err.is_transient())
}
}
impl From<anyhow::Error> for Error {
    /// Wraps the error as [`HashMatchError::Internal`] via `HashMatchError`'s own conversion.
    fn from(err: anyhow::Error) -> Self {
        Self::HashMatch(HashMatchError::from(err))
    }
}

impl From<DalError> for Error {
    /// Wraps the generalized DAL error as [`HashMatchError::Internal`] via `HashMatchError`'s own conversion.
    fn from(err: DalError) -> Self {
        Self::HashMatch(HashMatchError::from(err))
    }
}

impl From<EnrichedClientError> for Error {
    /// Wraps the RPC error as [`HashMatchError::Rpc`] via `HashMatchError`'s own conversion.
    fn from(err: EnrichedClientError) -> Self {
        Self::HashMatch(HashMatchError::from(err))
    }
}
/// Main node RPC calls used by the reorg detector.
#[async_trait]
trait MainNodeClient: fmt::Debug + Send + Sync {
    /// Returns the latest L2 block number reported by the main node.
    async fn sealed_l2_block_number(&self) -> EnrichedClientResult<L2BlockNumber>;
    /// Returns the latest L1 batch number reported by the main node.
    async fn sealed_l1_batch_number(&self) -> EnrichedClientResult<L1BatchNumber>;
    /// Returns the hash of the specified L2 block, or `None` if the main node doesn't return the block.
    async fn l2_block_hash(&self, number: L2BlockNumber) -> EnrichedClientResult<Option<H256>>;
    /// Returns the root hash of the specified L1 batch.
    ///
    /// The outer `Result` reports RPC failures. The inner `Err` says which data is missing on the main node.
    async fn l1_batch_root_hash(
        &self,
        number: L1BatchNumber,
    ) -> EnrichedClientResult<Result<H256, MissingData>>;
}
#[async_trait]
impl MainNodeClient for Box<DynClient<L2>> {
async fn sealed_l2_block_number(&self) -> EnrichedClientResult<L2BlockNumber> {
let number = self
.get_block_number()
.rpc_context("sealed_l2_block_number")
.await?;
let number = u32::try_from(number).map_err(|err| {
EnrichedClientError::custom(err, "u32::try_from").with_arg("number", &number)
})?;
Ok(L2BlockNumber(number))
}
async fn sealed_l1_batch_number(&self) -> EnrichedClientResult<L1BatchNumber> {
let number = self
.get_l1_batch_number()
.rpc_context("sealed_l1_batch_number")
.await?;
let number = u32::try_from(number).map_err(|err| {
EnrichedClientError::custom(err, "u32::try_from").with_arg("number", &number)
})?;
Ok(L1BatchNumber(number))
}
async fn l2_block_hash(&self, number: L2BlockNumber) -> EnrichedClientResult<Option<H256>> {
Ok(self
.get_block_by_number(number.0.into(), false)
.rpc_context("l2_block_hash")
.with_arg("number", &number)
.await?
.map(|block| block.hash))
}
async fn l1_batch_root_hash(
&self,
number: L1BatchNumber,
) -> EnrichedClientResult<Result<H256, MissingData>> {
let Some(batch) = self
.get_l1_batch_details(number)
.rpc_context("l1_batch_root_hash")
.with_arg("number", &number)
.await?
else {
return Ok(Err(MissingData::Batch));
};
Ok(batch.base.root_hash.ok_or(MissingData::RootHash))
}
}
/// Receives events produced by [`ReorgDetector`].
trait HandleReorgDetectorEvent: fmt::Debug + Send + Sync {
    /// Called by `ReorgDetector::run()` before the check loop starts.
    fn initialize(&mut self);
    /// Called after a check found that the given L2 block and L1 batch match the main node.
    fn update_correct_block(
        &mut self,
        last_correct_l2_block: L2BlockNumber,
        last_correct_l1_batch: L1BatchNumber,
    );
    /// Called when a check finds a mismatch.
    ///
    /// `diverged_l1_batch` is the first L1 batch suspected to diverge.
    fn report_divergence(&mut self, diverged_l1_batch: L1BatchNumber);
    /// Called by `ReorgDetector::run()` after the check loop exits without an error.
    fn start_shutting_down(&mut self);
}
impl HandleReorgDetectorEvent for HealthUpdater {
fn initialize(&mut self) {
self.update(Health::from(HealthStatus::Ready));
}
fn update_correct_block(
&mut self,
last_correct_l2_block: L2BlockNumber,
last_correct_l1_batch: L1BatchNumber,
) {
let last_correct_l2_block = last_correct_l2_block.0.into();
let prev_checked_l2_block = EN_METRICS.last_correct_l2_block
[&CheckerComponent::ReorgDetector]
.set(last_correct_l2_block);
if prev_checked_l2_block != last_correct_l2_block {
tracing::debug!("No reorg at L2 block #{last_correct_l2_block}");
}
let last_correct_l1_batch = last_correct_l1_batch.0.into();
let prev_checked_l1_batch = EN_METRICS.last_correct_batch[&CheckerComponent::ReorgDetector]
.set(last_correct_l1_batch);
if prev_checked_l1_batch != last_correct_l1_batch {
tracing::debug!("No reorg at L1 batch #{last_correct_l1_batch}");
}
let health_details = serde_json::json!({
"last_correct_l2_block": last_correct_l2_block,
"last_correct_l1_batch": last_correct_l1_batch,
});
self.update(Health::from(HealthStatus::Ready).with_details(health_details));
}
fn report_divergence(&mut self, diverged_l1_batch: L1BatchNumber) {
let health_details = serde_json::json!({
"diverged_l1_batch": diverged_l1_batch,
});
self.update(Health::from(HealthStatus::Affected).with_details(health_details));
}
fn start_shutting_down(&mut self) {
self.update(HealthStatus::ShuttingDown.into());
}
}
/// Periodically checks that the local L1 batch root hashes and L2 block hashes match the main node.
/// On a mismatch, it localizes the last matching L1 batch.
#[derive(Debug)]
pub struct ReorgDetector {
    /// Client used to query hashes and latest numbers from the main node.
    client: Box<dyn MainNodeClient>,
    /// Receives check outcomes; `new()` uses the health updater paired with `health_check`.
    event_handler: Box<dyn HandleReorgDetectorEvent>,
    /// Connection pool for the local DB.
    pool: ConnectionPool<Core>,
    /// Base delay between consecutive checks.
    sleep_interval: Duration,
    /// Health check exposed via `health_check()`.
    health_check: ReactiveHealthCheck,
}
impl ReorgDetector {
    /// Default delay between consecutive consistency checks.
    const DEFAULT_SLEEP_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a detector that queries the main node via `client` and the local DB via `pool`.
    ///
    /// Check outcomes are reported to the health check returned by [`Self::health_check()`].
    pub fn new(client: Box<DynClient<L2>>, pool: ConnectionPool<Core>) -> Self {
        // The updater half becomes the event handler; the check half is kept for callers.
        let (health_check, health_updater) = ReactiveHealthCheck::new("reorg_detector");
        Self {
            client: Box::new(client.for_component("reorg_detector")),
            event_handler: Box::new(health_updater),
            pool,
            sleep_interval: Self::DEFAULT_SLEEP_INTERVAL,
            health_check,
        }
    }

    /// Returns the health check updated by this detector.
    pub fn health_check(&self) -> &ReactiveHealthCheck {
        &self.health_check
    }

    /// Performs a single consistency check against the main node.
    ///
    /// The checked heights are the lower of the local and remote numbers, for both the L1 batch
    /// with tree data and the L2 block. At those heights it compares the L1 batch root hash and
    /// the L2 block hash. Returns `Ok(())` if both match, or if the local DB has no L1 batch with
    /// tree data or no sealed L2 block yet.
    ///
    /// # Errors
    ///
    /// - `Error::EarliestL1BatchMismatch` / `Error::EarliestL1BatchTruncated` if a mismatch was
    ///   found and the earliest local L1 batch with metadata also doesn't match, or is missing on
    ///   the main node.
    /// - `Error::ReorgDetected` with the last matching L1 batch, found via binary search.
    /// - `Error::HashMatch` for RPC, missing-data and internal errors.
    async fn check_consistency(&mut self) -> Result<(), Error> {
        let mut storage = self.pool.connection().await?;
        let Some(local_l1_batch) = storage
            .blocks_dal()
            .get_last_l1_batch_number_with_tree_data()
            .await?
        else {
            return Ok(());
        };
        let Some(local_l2_block) = storage.blocks_dal().get_sealed_l2_block_number().await? else {
            return Ok(());
        };
        // Release the DB connection so it isn't held while waiting on the main node.
        drop(storage);
        let remote_l1_batch = self.client.sealed_l1_batch_number().await?;
        let remote_l2_block = self.client.sealed_l2_block_number().await?;
        // Only heights present on both nodes can be compared.
        let checked_l1_batch = local_l1_batch.min(remote_l1_batch);
        let checked_l2_block = local_l2_block.min(remote_l2_block);
        let root_hashes_match = self.root_hashes_match(checked_l1_batch).await?;
        let l2_block_hashes_match = self.l2_block_hashes_match(checked_l2_block).await?;
        if root_hashes_match && l2_block_hashes_match {
            self.event_handler
                .update_correct_block(checked_l2_block, checked_l1_batch);
            return Ok(());
        }
        // If the batch root hash matched, the mismatch came from the L2 block check. The checked
        // batch is then treated as correct, and the next batch is the first suspected diverged one.
        let diverged_l1_batch = checked_l1_batch + (root_hashes_match as u32);
        self.event_handler.report_divergence(diverged_l1_batch);
        // The earliest local L1 batch serves as the known-correct lower bound for the search
        // below, so it must match the main node.
        let mut storage = self.pool.connection().await?;
        let first_l1_batch = storage
            .blocks_dal()
            .get_earliest_l1_batch_number_with_metadata()
            .await?
            .context("all L1 batches disappeared")?;
        drop(storage);
        match self.root_hashes_match(first_l1_batch).await {
            Ok(true) => {}
            Ok(false) => return Err(Error::EarliestL1BatchMismatch(first_l1_batch)),
            // The main node has no such batch or no root hash for it.
            Err(HashMatchError::MissingData(_)) => {
                return Err(Error::EarliestL1BatchTruncated(first_l1_batch));
            }
            Err(err) => return Err(err.into()),
        }
        tracing::info!("Searching for the first diverged L1 batch");
        let last_correct_l1_batch = self.detect_reorg(first_l1_batch, diverged_l1_batch).await?;
        tracing::info!("Reorg localized: last correct L1 batch is #{last_correct_l1_batch}");
        Err(Error::ReorgDetected(last_correct_l1_batch))
    }

    /// Checks whether the local hash of `l2_block` equals the one on the main node.
    ///
    /// Logs a warning on a mismatch. Returns a `MissingData::L2Block` error if the main node
    /// doesn't return the block. Returns an internal error if the block header is missing locally.
    async fn l2_block_hashes_match(&self, l2_block: L2BlockNumber) -> Result<bool, HashMatchError> {
        let mut storage = self.pool.connection().await?;
        let local_hash = storage
            .blocks_dal()
            .get_l2_block_header(l2_block)
            .await?
            .with_context(|| format!("Header does not exist for local L2 block #{l2_block}"))?
            .hash;
        drop(storage);
        let Some(remote_hash) = self.client.l2_block_hash(l2_block).await? else {
            tracing::info!("Remote L2 block #{l2_block} is missing");
            return Err(MissingData::L2Block.into());
        };
        if remote_hash != local_hash {
            tracing::warn!(
                "Reorg detected: local hash {local_hash:?} doesn't match the hash from \
main node {remote_hash:?} (L2 block #{l2_block})"
            );
        }
        Ok(remote_hash == local_hash)
    }

    /// Checks whether the local state root hash of `l1_batch` equals the one on the main node.
    ///
    /// Logs a warning on a mismatch. Returns a `MissingData` error if the main node lacks the
    /// batch or its root hash. Returns an internal error if the root hash is missing locally.
    async fn root_hashes_match(&self, l1_batch: L1BatchNumber) -> Result<bool, HashMatchError> {
        let mut storage = self.pool.connection().await?;
        let local_hash = storage
            .blocks_dal()
            .get_l1_batch_state_root(l1_batch)
            .await?
            .with_context(|| format!("Root hash does not exist for local batch #{l1_batch}"))?;
        drop(storage);
        // The outer `?` propagates RPC errors; the inner one propagates `MissingData`.
        let remote_hash = self.client.l1_batch_root_hash(l1_batch).await??;
        if remote_hash != local_hash {
            tracing::warn!(
                "Reorg detected: local root hash {local_hash:?} doesn't match the state hash from \
main node {remote_hash:?} (L1 batch #{l1_batch})"
            );
        }
        Ok(remote_hash == local_hash)
    }

    /// Checks the root hash of `l1_batch`. If it matches, also checks the hash of the last L2
    /// block in that batch.
    async fn root_hashes_and_contents_match(
        &self,
        l1_batch: L1BatchNumber,
    ) -> Result<bool, HashMatchError> {
        let root_hashes_match = self.root_hashes_match(l1_batch).await?;
        if !root_hashes_match {
            return Ok(false);
        }
        let mut storage = self.pool.connection().await?;
        let (_, last_l2_block_in_batch) = storage
            .blocks_dal()
            .get_l2_block_range_of_l1_batch(l1_batch)
            .await?
            .with_context(|| format!("L1 batch #{l1_batch} does not have L2 blocks"))?;
        drop(storage);
        self.l2_block_hashes_match(last_l2_block_in_batch).await
    }

    /// Binary-searches for the last L1 batch whose root hash and last L2 block hash match the
    /// main node.
    ///
    /// The search assumes `known_valid_l1_batch` matches and `diverged_l1_batch` doesn't;
    /// neither endpoint is re-checked. Missing data on the main node counts as a match.
    /// NOTE(review): presumably because the main node may lag behind the local node; confirm.
    async fn detect_reorg(
        &self,
        known_valid_l1_batch: L1BatchNumber,
        diverged_l1_batch: L1BatchNumber,
    ) -> Result<L1BatchNumber, HashMatchError> {
        binary_search_with(
            known_valid_l1_batch.0,
            diverged_l1_batch.0,
            |number| async move {
                match self
                    .root_hashes_and_contents_match(L1BatchNumber(number))
                    .await
                {
                    Err(HashMatchError::MissingData(_)) => Ok(true),
                    res => res,
                }
            },
        )
        .await
        .map(L1BatchNumber)
    }

    /// Runs checks until the first successful one, or until a stop signal is received.
    ///
    /// Unlike [`Self::run()`], it doesn't call `initialize()` or `start_shutting_down()` on the
    /// event handler.
    pub async fn run_once(&mut self, stop_receiver: watch::Receiver<bool>) -> Result<(), Error> {
        self.run_inner(true, stop_receiver).await
    }

    /// Runs checks until a stop signal is received.
    ///
    /// Initializes the event handler first and reports shutdown to it after the loop exits.
    /// If the loop returns an error, that error is returned without reporting shutdown.
    pub async fn run(mut self, stop_receiver: watch::Receiver<bool>) -> Result<(), Error> {
        self.event_handler.initialize();
        self.run_inner(false, stop_receiver).await?;
        self.event_handler.start_shutting_down();
        tracing::info!("Shutting down reorg detector");
        Ok(())
    }

    /// Check loop shared by [`Self::run()`] and [`Self::run_once()`].
    ///
    /// A missing root hash on the main node is retried after `sleep_interval / 10`. Other
    /// transient errors are logged and retried after `sleep_interval`. Non-transient errors are
    /// returned. If `stop_after_success` is set, it returns after the first successful check.
    async fn run_inner(
        &mut self,
        stop_after_success: bool,
        mut stop_receiver: watch::Receiver<bool>,
    ) -> Result<(), Error> {
        while !*stop_receiver.borrow_and_update() {
            let sleep_interval = match self.check_consistency().await {
                // Must precede the generic transient arm, since `MissingData` is also transient.
                Err(Error::HashMatch(HashMatchError::MissingData(MissingData::RootHash))) => {
                    tracing::debug!("Last L1 batch on the main node doesn't have a state root hash; waiting until it is computed");
                    self.sleep_interval / 10
                }
                Err(err) if err.is_transient() => {
                    tracing::warn!("Following transient error occurred: {err}");
                    tracing::info!("Trying again after a delay");
                    self.sleep_interval
                }
                Err(err) => return Err(err),
                Ok(()) if stop_after_success => return Ok(()),
                Ok(()) => self.sleep_interval,
            };
            // Exit as soon as `changed()` completes within the delay. This includes completing
            // with an error, e.g. when the sender is dropped.
            if tokio::time::timeout(sleep_interval, stop_receiver.changed())
                .await
                .is_ok()
            {
                break;
            }
        }
        Ok(())
    }
}
/// Asynchronous, fallible predicate evaluated by `binary_search_with()`.
#[async_trait]
trait BinarySearchPredicate: Send {
    /// Error that can be returned when evaluating the predicate.
    type Error;

    /// Evaluates the predicate for `argument`.
    async fn eval(&mut self, argument: u32) -> Result<bool, Self::Error>;
}
/// Lets any `Send` closure `FnMut(u32) -> Fut` act as a predicate, where `Fut` is a `Send`
/// future resolving to `Result<bool, E>`.
#[async_trait]
impl<F, Fut, E> BinarySearchPredicate for F
where
    F: Send + FnMut(u32) -> Fut,
    Fut: Send + Future<Output = Result<bool, E>>,
{
    type Error = E;

    async fn eval(&mut self, argument: u32) -> Result<bool, Self::Error> {
        // Invoke the closure, then drive the returned future to completion.
        let evaluation = self(argument);
        evaluation.await
    }
}
/// Binary-searches between `left` and `right` for the last argument satisfying `predicate`.
///
/// The search assumes `predicate(left)` holds and `predicate(right)` doesn't. Only arguments
/// strictly between the two are evaluated. Returns `left` unchanged if there are no such
/// arguments, including when `right <= left`.
///
/// # Errors
///
/// Propagates the first error returned by `predicate`.
async fn binary_search_with<P: BinarySearchPredicate>(
    mut left: u32,
    mut right: u32,
    mut predicate: P,
) -> Result<u32, P::Error> {
    // Written without `left + 1` / `left + right`, which overflow `u32` near `u32::MAX`
    // (a panic in debug builds, a wrap-around in release builds).
    while right.saturating_sub(left) > 1 {
        let middle = left + (right - left) / 2;
        if predicate.eval(middle).await? {
            left = middle;
        } else {
            right = middle;
        }
    }
    Ok(left)
}