use std::time::Duration;
use futures::FutureExt;
use tokio::sync::watch;
use zksync_eth_client::EthInterface;
use zksync_types::{L1ChainId, L2ChainId};
use zksync_web3_decl::{
client::{DynClient, L1, L2},
error::ClientRpcContext,
namespaces::{EthNamespaceClient, ZksNamespaceClient},
};
/// Task validating the chain IDs from the local configuration against remote nodes.
///
/// Three checks are performed concurrently:
///
/// - the chain ID returned by the Ethereum (L1) client equals the configured L1 chain ID;
/// - the L1 chain ID reported by the main node (`l1_chain_id` RPC method) equals the configured
///   L1 chain ID;
/// - the chain ID reported by the main node (`chain_id` RPC method) equals the configured
///   L2 chain ID.
#[derive(Debug)]
pub struct ValidateChainIdsTask {
    /// Expected L1 chain ID (the local config value).
    l1_chain_id: L1ChainId,
    /// Expected L2 chain ID (the local config value).
    l2_chain_id: L2ChainId,
    /// Client used to fetch the chain ID from the Ethereum node.
    eth_client: Box<DynClient<L1>>,
    /// Client used to fetch both the L1 and the L2 chain IDs from the main node.
    main_node_client: Box<DynClient<L2>>,
}

impl ValidateChainIdsTask {
    /// Delay before retrying a chain ID request that failed with a retried error.
    const BACKOFF_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a task that validates `l1_chain_id` and `l2_chain_id`.
    ///
    /// Both clients are passed through `for_component("chain_ids_validation")` before being
    /// stored.
    pub fn new(
        l1_chain_id: L1ChainId,
        l2_chain_id: L2ChainId,
        eth_client: Box<DynClient<L1>>,
        main_node_client: Box<DynClient<L2>>,
    ) -> Self {
        Self {
            l1_chain_id,
            l2_chain_id,
            eth_client: eth_client.for_component("chain_ids_validation"),
            main_node_client: main_node_client.for_component("chain_ids_validation"),
        }
    }

    /// Checks that the Ethereum client reports the `expected` L1 chain ID.
    ///
    /// *Every* error while fetching the chain ID is logged and retried after
    /// [`Self::BACKOFF_INTERVAL`], so this future only completes once a chain ID was obtained.
    ///
    /// # Errors
    ///
    /// Returns an error if the chain ID returned by the Ethereum client differs from `expected`.
    async fn check_eth_client(
        eth_client: Box<DynClient<L1>>,
        expected: L1ChainId,
    ) -> anyhow::Result<()> {
        loop {
            match eth_client.fetch_chain_id().await {
                Ok(chain_id) => {
                    anyhow::ensure!(
                        expected == chain_id,
                        "Configured L1 chain ID doesn't match the one from Ethereum node. \
                         Make sure your configuration is correct and you are connected to the right Ethereum node. \
                         Eth node chain ID: {chain_id}. Local config value: {expected}"
                    );
                    tracing::info!(
                        "Checked that L1 chain ID {chain_id} is returned by Ethereum client"
                    );
                    return Ok(());
                }
                Err(err) => {
                    tracing::warn!("Error getting L1 chain ID from Ethereum client: {err}");
                    tokio::time::sleep(Self::BACKOFF_INTERVAL).await;
                }
            }
        }
    }

    /// Checks that the main node reports the `expected` L1 chain ID.
    ///
    /// Errors classified as transient (`err.is_transient()`) are logged and retried after
    /// [`Self::BACKOFF_INTERVAL`].
    ///
    /// # Errors
    ///
    /// Returns an error if the main node's L1 chain ID differs from `expected`. Also returns an
    /// error if the RPC call fails with a non-transient error.
    async fn check_l1_chain_using_main_node(
        main_node_client: Box<DynClient<L2>>,
        expected: L1ChainId,
    ) -> anyhow::Result<()> {
        loop {
            match main_node_client
                .l1_chain_id()
                .rpc_context("l1_chain_id")
                .await
            {
                Ok(chain_id) => {
                    let chain_id = L1ChainId(chain_id.as_u64());
                    anyhow::ensure!(
                        expected == chain_id,
                        "Configured L1 chain ID doesn't match the one from main node. \
                         Make sure your configuration is correct and you are connected to the right main node. \
                         Main node L1 chain ID: {chain_id}. Local config value: {expected}"
                    );
                    tracing::info!(
                        "Checked that L1 chain ID {chain_id} is returned by main node client"
                    );
                    return Ok(());
                }
                Err(err) if err.is_transient() => {
                    tracing::warn!(
                        "Transient error getting L1 chain ID from main node client, will retry in {:?}: {err}",
                        Self::BACKOFF_INTERVAL
                    );
                    tokio::time::sleep(Self::BACKOFF_INTERVAL).await;
                }
                Err(err) => {
                    tracing::error!("Error getting L1 chain ID from main node client: {err}");
                    return Err(err.into());
                }
            }
        }
    }

    /// Checks that the main node reports the `expected` L2 chain ID.
    ///
    /// Errors classified as transient (`err.is_transient()`) are logged and retried after
    /// [`Self::BACKOFF_INTERVAL`].
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// - the main node returns a value that cannot be converted into an [`L2ChainId`];
    /// - the converted chain ID differs from `expected`;
    /// - the RPC call fails with a non-transient error.
    async fn check_l2_chain_using_main_node(
        main_node_client: Box<DynClient<L2>>,
        expected: L2ChainId,
    ) -> anyhow::Result<()> {
        loop {
            match main_node_client.chain_id().rpc_context("chain_id").await {
                Ok(chain_id) => {
                    let chain_id = L2ChainId::try_from(chain_id.as_u64()).map_err(|err| {
                        anyhow::anyhow!("invalid chain ID supplied by main node: {err}")
                    })?;
                    anyhow::ensure!(
                        expected == chain_id,
                        "Configured L2 chain ID doesn't match the one from main node. \
                         Make sure your configuration is correct and you are connected to the right main node. \
                         Main node L2 chain ID: {chain_id:?}. Local config value: {expected:?}"
                    );
                    tracing::info!(
                        "Checked that L2 chain ID {chain_id:?} is returned by main node client"
                    );
                    return Ok(());
                }
                Err(err) if err.is_transient() => {
                    tracing::warn!(
                        "Transient error getting L2 chain ID from main node client, will retry in {:?}: {err}",
                        Self::BACKOFF_INTERVAL
                    );
                    tokio::time::sleep(Self::BACKOFF_INTERVAL).await;
                }
                Err(err) => {
                    tracing::error!("Error getting L2 chain ID from main node client: {err}");
                    return Err(err.into());
                }
            }
        }
    }

    /// Runs all three checks concurrently and returns as soon as one of these happens:
    ///
    /// - all checks succeed → `Ok(())`;
    /// - any check fails → the first error; the remaining checks are dropped (`try_join3`);
    /// - `stop_receiver.changed()` resolves (a new value, or the sender being dropped)
    ///   → `Ok(())`.
    pub async fn run_once(self, mut stop_receiver: watch::Receiver<bool>) -> anyhow::Result<()> {
        let eth_client_check = Self::check_eth_client(self.eth_client, self.l1_chain_id);
        let main_node_l1_check =
            Self::check_l1_chain_using_main_node(self.main_node_client.clone(), self.l1_chain_id);
        let main_node_l2_check =
            Self::check_l2_chain_using_main_node(self.main_node_client, self.l2_chain_id);
        let joined_futures =
            futures::future::try_join3(eth_client_check, main_node_l1_check, main_node_l2_check)
                .fuse();
        tokio::select! {
            res = joined_futures => res.map(drop),
            _ = stop_receiver.changed() => Ok(()),
        }
    }

    /// Runs all three checks concurrently; unlike [`Self::run_once`], success does not end it.
    ///
    /// A check that completes with `Ok(())` does not match its `Err(err)` pattern, so
    /// `tokio::select!` disables that branch and keeps waiting. Consequently this returns only:
    ///
    /// - with the first check error; or
    /// - with `Ok(())` once `stop_receiver.changed()` resolves (a new value, or the sender
    ///   being dropped).
    pub async fn run(self, mut stop_receiver: watch::Receiver<bool>) -> anyhow::Result<()> {
        let eth_client_check = Self::check_eth_client(self.eth_client, self.l1_chain_id).fuse();
        let main_node_l1_check =
            Self::check_l1_chain_using_main_node(self.main_node_client.clone(), self.l1_chain_id)
                .fuse();
        let main_node_l2_check =
            Self::check_l2_chain_using_main_node(self.main_node_client, self.l2_chain_id).fuse();
        tokio::select! {
            Err(err) = eth_client_check => Err(err),
            Err(err) = main_node_l1_check => Err(err),
            Err(err) = main_node_l2_check => Err(err),
            _ = stop_receiver.changed() => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use zksync_types::U64;
    use zksync_web3_decl::client::{MockClient, L1};

    use super::*;

    /// Builds a mock Ethereum client whose `eth_chainId` returns `chain_id`.
    fn mock_eth_client(chain_id: u64) -> Box<DynClient<L1>> {
        let client = MockClient::builder(L1::default())
            .method("eth_chainId", move || Ok(U64::from(chain_id)))
            .build();
        Box::new(client)
    }

    /// Builds a mock main node client reporting `l2_chain_id` (`eth_chainId`) and
    /// `l1_chain_id` (`zks_L1ChainId`).
    fn mock_main_node_client(l2_chain_id: u64, l1_chain_id: u64) -> Box<DynClient<L2>> {
        let client = MockClient::builder(L2::default())
            .method("eth_chainId", move || Ok(U64::from(l2_chain_id)))
            .method("zks_L1ChainId", move || Ok(U64::from(l1_chain_id)))
            .build();
        Box::new(client)
    }

    #[tokio::test]
    async fn validating_chain_ids_errors() {
        let (_stop_sender, stop_receiver) = watch::channel(false);

        // The configured L1 chain ID differs from the one reported by the Ethereum node.
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(3),
            L2ChainId::default(),
            mock_eth_client(9),
            mock_main_node_client(270, 3),
        );
        let err = validation_task
            .run(stop_receiver.clone())
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("L1 chain ID") && err.contains("Ethereum node"),
            "{err}"
        );

        // The configured L1 chain ID differs from the one reported by the main node.
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(9),
            L2ChainId::from(270),
            mock_eth_client(9),
            mock_main_node_client(270, 3),
        );
        let err = validation_task
            .run(stop_receiver.clone())
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("L1 chain ID") && err.contains("main node"),
            "{err}"
        );

        // The configured L2 chain ID differs from the one reported by the main node.
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(9),
            L2ChainId::from(271),
            mock_eth_client(9),
            mock_main_node_client(270, 9),
        );
        let err = validation_task
            .run(stop_receiver)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("L2 chain ID") && err.contains("main node"),
            "{err}"
        );
    }

    #[tokio::test]
    async fn validating_chain_ids_success() {
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(9),
            L2ChainId::default(),
            mock_eth_client(9),
            mock_main_node_client(270, 9),
        );
        let (stop_sender, stop_receiver) = watch::channel(false);
        let task = tokio::spawn(validation_task.run(stop_receiver));
        tokio::time::sleep(Duration::from_millis(50)).await;
        // `run()` must keep running after all checks have succeeded.
        assert!(!task.is_finished());
        stop_sender.send_replace(true);
        task.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn validating_chain_ids_once_success() {
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(9),
            L2ChainId::default(),
            mock_eth_client(9),
            mock_main_node_client(270, 9),
        );
        // Keep the sender alive so that only the checks themselves can finish `run_once()`.
        let (_stop_sender, stop_receiver) = watch::channel(false);
        tokio::time::timeout(
            Duration::from_secs(10),
            validation_task.run_once(stop_receiver),
        )
        .await
        .expect("`run_once()` must return once all checks succeed")
        .unwrap();
    }

    #[tokio::test]
    async fn validating_chain_ids_once_error() {
        let validation_task = ValidateChainIdsTask::new(
            L1ChainId(9),
            L2ChainId::from(271),
            mock_eth_client(9),
            mock_main_node_client(270, 9),
        );
        let (_stop_sender, stop_receiver) = watch::channel(false);
        let err = validation_task
            .run_once(stop_receiver)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("L2 chain ID") && err.contains("main node"),
            "{err}"
        );
    }
}