use crate::{
error::rpc::RpcClientError,
strings::rpc::{methods, storage_keys},
};
use scale::{Decode, Encode};
use std::sync::Arc;
use subxt::{
Metadata, SubstrateConfig,
backend::{
legacy::{LegacyRpcMethods, rpc_methods::Block},
rpc::RpcClient,
},
config::substrate::H256,
};
use tokio::sync::{Mutex, RwLock, Semaphore};
use url::Url;
/// Oldest metadata version requested explicitly when the default metadata cannot be decoded.
const METADATA_V14: u32 = 14;
/// Newest metadata version requested explicitly; versions are tried from this one downwards.
const METADATA_LATEST: u32 = 15;
/// Number of permits in the shared upstream semaphore, i.e. how many batched storage reads
/// and paged key scans may be in flight at once across all clones of one client.
const MAX_CONCURRENT_UPSTREAM_CALLS: usize = 4;
/// Reconnectable client for a single WebSocket RPC endpoint, built on subxt's legacy RPC
/// method set.
///
/// Every piece of shared state sits behind an `Arc`, so clones are cheap and share one
/// connection, one concurrency limit and one reconnect lock.
#[derive(Clone)]
pub struct ForkRpcClient {
    /// URL used for the initial connection and for every reconnect.
    endpoint: Url,
    /// Current connection; swapped out in place by `reconnect`.
    legacy: Arc<RwLock<LegacyRpcMethods<SubstrateConfig>>>,
    /// Serialises reconnect attempts so concurrent callers do not each open a connection.
    reconnect_lock: Arc<Mutex<()>>,
    /// Limits how many batched / paged upstream requests run concurrently.
    upstream_semaphore: Arc<Semaphore>,
}
// Only the endpoint is printed; the connection handle, semaphore and lock are omitted.
impl std::fmt::Debug for ForkRpcClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ForkRpcClient");
        builder.field("endpoint", &self.endpoint);
        builder.finish()
    }
}
impl ForkRpcClient {
    /// Opens a WebSocket connection to `endpoint` and returns a client wrapping it.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::ConnectionFailed`] if the connection cannot be established.
    pub async fn connect(endpoint: &Url) -> Result<Self, RpcClientError> {
        let legacy = Self::create_connection(endpoint).await?;
        Ok(Self {
            legacy: Arc::new(RwLock::new(legacy)),
            endpoint: endpoint.clone(),
            upstream_semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_UPSTREAM_CALLS)),
            reconnect_lock: Arc::new(Mutex::new(())),
        })
    }

    /// Builds a new jsonrpsee WebSocket client for `endpoint` and wraps it in subxt's legacy
    /// RPC method set.
    async fn create_connection(
        endpoint: &Url,
    ) -> Result<LegacyRpcMethods<SubstrateConfig>, RpcClientError> {
        use jsonrpsee::ws_client::WsClientBuilder;
        // Responses such as the runtime code blob can be large, so the response size cap is
        // raised to its maximum.
        let client = WsClientBuilder::default()
            .max_response_size(u32::MAX)
            .build(endpoint.as_str())
            .await
            .map_err(|e| RpcClientError::ConnectionFailed {
                endpoint: endpoint.to_string(),
                message: e.to_string(),
            })?;
        Ok(LegacyRpcMethods::new(RpcClient::new(client)))
    }

    /// Returns a handle to the current connection, releasing the lock before returning.
    ///
    /// The handle is cloned out of the `RwLock` so that no guard is held while an RPC is
    /// awaited. Tokio's `RwLock` is fair: once `reconnect` is waiting for the write lock, new
    /// readers queue behind it. If guards were held across network I/O, one stalled request
    /// would block the reconnect and every request issued after it.
    async fn rpc(&self) -> LegacyRpcMethods<SubstrateConfig> {
        self.legacy.read().await.clone()
    }

    /// Replaces the connection with a fresh one unless the current one still responds.
    ///
    /// Concurrent callers are serialised on `reconnect_lock`. Each caller first probes the
    /// current connection with `system_chain`, so a connection already replaced by an earlier
    /// caller is kept.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::ConnectionFailed`] if a new connection cannot be opened.
    pub async fn reconnect(&self) -> Result<(), RpcClientError> {
        let _guard = self.reconnect_lock.lock().await;
        if self.rpc().await.system_chain().await.is_ok() {
            return Ok(());
        }
        let new_legacy = Self::create_connection(&self.endpoint).await?;
        *self.legacy.write().await = new_legacy;
        Ok(())
    }

    /// The URL this client connects to.
    pub fn endpoint(&self) -> &Url {
        &self.endpoint
    }

    /// Fetches the hash of the latest finalized block.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn finalized_head(&self) -> Result<H256, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.chain_get_finalized_head().await.map_err(|e| RpcClientError::RequestFailed {
            method: methods::CHAIN_GET_FINALIZED_HEAD,
            message: e.to_string(),
        })
    }

    /// Fetches the header of the block with the given `hash`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails, or
    /// [`RpcClientError::InvalidResponse`] if the node returns no header for `hash`.
    pub async fn header(
        &self,
        hash: H256,
    ) -> Result<<SubstrateConfig as subxt::Config>::Header, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.chain_get_header(Some(hash))
            .await
            .map_err(|e| RpcClientError::RequestFailed {
                method: methods::CHAIN_GET_HEADER,
                message: e.to_string(),
            })?
            .ok_or_else(|| RpcClientError::InvalidResponse(format!("No header found for {hash:?}")))
    }

    /// Looks up the hash of the block at `block_number`; `Ok(None)` if the node has none.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn block_hash_at(&self, block_number: u32) -> Result<Option<H256>, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.chain_get_block_hash(Some(block_number.into())).await.map_err(|e| {
            RpcClientError::RequestFailed {
                method: methods::CHAIN_GET_BLOCK_HASH,
                message: e.to_string(),
            }
        })
    }

    /// Fetches the block at `block_number` together with its hash.
    ///
    /// Returns `Ok(None)` when no hash is known for that number or the block itself is missing.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if either underlying RPC call fails.
    pub async fn block_by_number(
        &self,
        block_number: u32,
    ) -> Result<Option<(H256, Block<SubstrateConfig>)>, RpcClientError> {
        let Some(block_hash) = self.block_hash_at(block_number).await? else {
            return Ok(None);
        };
        let block = self.block_by_hash(block_hash).await?;
        Ok(block.map(|block| (block_hash, block)))
    }

    /// Fetches the block with the given hash; `Ok(None)` if the node does not return it.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn block_by_hash(
        &self,
        block_hash: H256,
    ) -> Result<Option<Block<SubstrateConfig>>, RpcClientError> {
        let rpc = self.rpc().await;
        let block = rpc.chain_get_block(Some(block_hash)).await.map_err(|e| {
            RpcClientError::RequestFailed {
                method: methods::CHAIN_GET_BLOCK,
                message: e.to_string(),
            }
        })?;
        Ok(block.map(|b| b.block))
    }

    /// Reads the raw storage value under `key` at block `at`; `Ok(None)` if unset.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn storage(&self, key: &[u8], at: H256) -> Result<Option<Vec<u8>>, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.state_get_storage(key, Some(at)).await.map_err(|e| RpcClientError::RequestFailed {
            method: methods::STATE_GET_STORAGE,
            message: e.to_string(),
        })
    }

    /// Reads several storage values at block `at` with a single `state_queryStorageAt` call.
    ///
    /// The result has one entry per key, in the same order as `keys`. An entry is `None` when
    /// the key is absent from the response or has no value. An empty `keys` slice returns
    /// immediately without contacting the node. The call holds a permit from the shared
    /// upstream semaphore while it runs.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn storage_batch(
        &self,
        keys: &[&[u8]],
        at: H256,
    ) -> Result<Vec<Option<Vec<u8>>>, RpcClientError> {
        if keys.is_empty() {
            return Ok(vec![]);
        }
        // `upstream_semaphore` is private and never closed, so `acquire` cannot fail.
        let _permit = self.upstream_semaphore.acquire().await.expect("semaphore closed");
        let rpc = self.rpc().await;
        let result = rpc
            .state_query_storage_at(keys.iter().copied(), Some(at))
            .await
            .map_err(|e| RpcClientError::RequestFailed {
                method: methods::STATE_QUERY_STORAGE_AT,
                message: e.to_string(),
            })?;
        let changes: std::collections::HashMap<Vec<u8>, Option<Vec<u8>>> = result
            .into_iter()
            .flat_map(|change_set| {
                change_set
                    .changes
                    .into_iter()
                    .map(|(k, v)| (k.0.to_vec(), v.map(|v| v.0.to_vec())))
            })
            .collect();
        // Look values up rather than removing them so duplicate keys all resolve.
        let values = keys.iter().map(|key| changes.get(*key).cloned().flatten()).collect();
        Ok(values)
    }

    /// Lists up to `count` storage keys starting with `prefix` at block `at`.
    ///
    /// If `start_key` is given, it is passed to `state_getKeysPaged` as the paging cursor. The
    /// call holds a permit from the shared upstream semaphore while it runs.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn storage_keys_paged(
        &self,
        prefix: &[u8],
        count: u32,
        start_key: Option<&[u8]>,
        at: H256,
    ) -> Result<Vec<Vec<u8>>, RpcClientError> {
        // `upstream_semaphore` is private and never closed, so `acquire` cannot fail.
        let _permit = self.upstream_semaphore.acquire().await.expect("semaphore closed");
        let rpc = self.rpc().await;
        rpc.state_get_keys_paged(prefix, count, start_key, Some(at)).await.map_err(|e| {
            RpcClientError::RequestFailed {
                method: methods::STATE_GET_KEYS_PAGED,
                message: e.to_string(),
            }
        })
    }

    /// Fetches and decodes the runtime metadata at block `at`.
    ///
    /// The response of `state_getMetadata` is decoded first. If that fails, the metadata is
    /// requested explicitly through `Metadata_metadata_at_version`, from `METADATA_LATEST`
    /// down to `METADATA_V14`, and the first version that decodes is returned.
    ///
    /// # Errors
    ///
    /// - [`RpcClientError::RequestFailed`] if `state_getMetadata` or a fallback `state_call`
    ///   fails.
    /// - [`RpcClientError::InvalidResponse`] if a fallback response is malformed.
    /// - [`RpcClientError::MetadataDecodingFailed`], carrying the original decode error, if no
    ///   version could be decoded.
    pub async fn metadata(&self, at: H256) -> Result<Metadata, RpcClientError> {
        let rpc = self.rpc().await;
        let raw = rpc.state_get_metadata(Some(at)).await.map_err(|e| {
            RpcClientError::RequestFailed {
                method: methods::STATE_GET_METADATA,
                message: e.to_string(),
            }
        })?;
        let raw_bytes = raw.into_raw();
        let default_err = match Metadata::decode(&mut raw_bytes.as_slice()) {
            Ok(metadata) => return Ok(metadata),
            Err(err) => err,
        };
        for version in (METADATA_V14..=METADATA_LATEST).rev() {
            if let Some(bytes) = self.metadata_at_version(version, at).await? &&
                let Ok(metadata) = Metadata::decode(&mut bytes.as_slice())
            {
                return Ok(metadata);
            }
        }
        Err(RpcClientError::MetadataDecodingFailed(default_err.to_string()))
    }

    /// Calls the `Metadata_metadata_at_version` runtime API for `version` at block `at`.
    ///
    /// The response is SCALE-decoded as `Option<Vec<u8>>`. `None` presumably means the runtime
    /// does not provide that version.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the call fails, or
    /// [`RpcClientError::InvalidResponse`] if the response does not decode.
    async fn metadata_at_version(
        &self,
        version: u32,
        at: H256,
    ) -> Result<Option<Vec<u8>>, RpcClientError> {
        let rpc = self.rpc().await;
        let result = rpc
            .state_call("Metadata_metadata_at_version", Some(&version.encode()), Some(at))
            .await
            .map_err(|e| RpcClientError::RequestFailed {
                method: methods::STATE_CALL,
                message: e.to_string(),
            })?;
        let opaque: Option<Vec<u8>> = Decode::decode(&mut result.as_slice()).map_err(|e| {
            RpcClientError::InvalidResponse(format!(
                "Failed to decode metadata_at_version response: {e}"
            ))
        })?;
        Ok(opaque)
    }

    /// Reads the runtime code stored under the well-known `:code` storage key at block `at`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the read fails, or
    /// [`RpcClientError::StorageNotFound`] if the key is unset.
    pub async fn runtime_code(&self, at: H256) -> Result<Vec<u8>, RpcClientError> {
        let code_key = sp_core::storage::well_known_keys::CODE;
        self.storage(code_key, at)
            .await?
            .ok_or_else(|| RpcClientError::StorageNotFound(storage_keys::CODE.to_string()))
    }

    /// Fetches the chain name reported by `system_chain`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn system_chain(&self) -> Result<String, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.system_chain().await.map_err(|e| RpcClientError::RequestFailed {
            method: methods::SYSTEM_CHAIN,
            message: e.to_string(),
        })
    }

    /// Invokes runtime API `function` with SCALE-encoded `call_parameters`.
    ///
    /// The call is made at block `at`, or at the node's default block when `at` is `None`.
    /// The raw response bytes are returned.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn state_call(
        &self,
        function: &str,
        call_parameters: &[u8],
        at: Option<H256>,
    ) -> Result<Vec<u8>, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.state_call(function, Some(call_parameters), at).await.map_err(|e| {
            RpcClientError::RequestFailed { method: methods::STATE_CALL, message: e.to_string() }
        })
    }

    /// Fetches the chain properties reported by `system_properties`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::RequestFailed`] if the RPC call fails.
    pub async fn system_properties(
        &self,
    ) -> Result<subxt::backend::legacy::rpc_methods::SystemProperties, RpcClientError> {
        let rpc = self.rpc().await;
        rpc.system_properties().await.map_err(|e| RpcClientError::RequestFailed {
            method: methods::SYSTEM_PROPERTIES,
            message: e.to_string(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display_connection_failed() {
        let err = RpcClientError::ConnectionFailed {
            endpoint: "wss://example.com".to_string(),
            message: "connection refused".to_string(),
        };
        assert_eq!(err.to_string(), "Failed to connect to wss://example.com: connection refused");
    }

    #[test]
    fn error_display_request_failed() {
        let err = RpcClientError::RequestFailed {
            method: methods::STATE_GET_STORAGE,
            message: "connection reset".to_string(),
        };
        assert_eq!(
            err.to_string(),
            format!("RPC request `{}` failed: connection reset", methods::STATE_GET_STORAGE)
        );
    }

    #[test]
    fn error_display_timeout() {
        let err = RpcClientError::Timeout { method: methods::STATE_GET_METADATA };
        assert_eq!(
            err.to_string(),
            format!("RPC request `{}` timed out", methods::STATE_GET_METADATA)
        );
    }

    #[test]
    fn error_display_invalid_response() {
        let err = RpcClientError::InvalidResponse("missing field".to_string());
        assert_eq!(err.to_string(), "Invalid RPC response: missing field");
    }

    #[test]
    fn error_display_storage_not_found() {
        let err = RpcClientError::StorageNotFound(storage_keys::CODE.to_string());
        assert_eq!(
            err.to_string(),
            format!("Required storage key not found: {}", storage_keys::CODE)
        );
    }

    /// Connecting to a port with (presumably) nothing listening must surface `ConnectionFailed`.
    #[tokio::test]
    async fn connect_to_invalid_endpoint_fails() {
        let endpoint: Url = "ws://127.0.0.1:19999".parse().unwrap();
        let result = ForkRpcClient::connect(&endpoint).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, RpcClientError::ConnectionFailed { .. }),
            "Expected ConnectionFailed, got: {err:?}"
        );
    }

    /// Tests that talk to a live node provided by `TestContext`.
    mod sequential {
        use super::*;
        use crate::testing::{
            TestContext,
            constants::{SYSTEM_NUMBER_KEY, SYSTEM_PALLET_PREFIX, SYSTEM_PARENT_HASH_KEY},
        };
        use std::time::Duration;

        #[tokio::test]
        async fn connect_to_node() {
            let ctx = TestContext::for_rpc_client().await;
            assert_eq!(ctx.rpc().endpoint(), &ctx.endpoint);
        }

        #[tokio::test]
        async fn fetch_finalized_head() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            assert_eq!(hash.as_bytes().len(), 32);
        }

        #[tokio::test]
        async fn fetch_header() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let header = ctx.rpc().header(hash).await.unwrap();
            assert_eq!(header.state_root.as_bytes().len(), 32);
        }

        #[tokio::test]
        async fn fetch_storage() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let key = hex::decode(SYSTEM_NUMBER_KEY).unwrap();
            let value = ctx.rpc().storage(&key, hash).await.unwrap();
            assert!(value.is_some());
        }

        #[tokio::test]
        async fn fetch_metadata() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let metadata = ctx.rpc().metadata(hash).await.unwrap();
            assert!(metadata.pallets().len() > 0);
        }

        #[tokio::test]
        async fn fetch_runtime_code() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let code = ctx.rpc().runtime_code(hash).await.unwrap();
            assert!(
                code.len() > 10_000,
                "Runtime code should be substantial, got {} bytes",
                code.len()
            );
        }

        #[tokio::test]
        async fn fetch_storage_keys_paged() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let prefix = hex::decode(SYSTEM_PALLET_PREFIX).unwrap();
            let keys = ctx.rpc().storage_keys_paged(&prefix, 10, None, hash).await.unwrap();
            assert!(!keys.is_empty());
            for key in &keys {
                assert!(key.starts_with(&prefix));
            }
        }

        #[tokio::test]
        async fn fetch_storage_batch() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let keys = [
                hex::decode(SYSTEM_NUMBER_KEY).unwrap(),
                hex::decode(SYSTEM_PARENT_HASH_KEY).unwrap(),
            ];
            let key_refs: Vec<&[u8]> = keys.iter().map(|k| k.as_slice()).collect();
            let values = ctx.rpc().storage_batch(&key_refs, hash).await.unwrap();
            assert_eq!(values.len(), 2);
            assert!(values[0].is_some());
            assert!(values[1].is_some());
        }

        #[tokio::test]
        async fn fetch_system_chain() {
            let ctx = TestContext::for_rpc_client().await;
            let chain_name = ctx.rpc().system_chain().await.unwrap();
            assert!(!chain_name.is_empty());
        }

        #[tokio::test]
        async fn fetch_system_properties() {
            let ctx = TestContext::for_rpc_client().await;
            let _properties = ctx.rpc().system_properties().await.unwrap();
        }

        #[tokio::test]
        async fn fetch_header_non_existent_block_fails() {
            let ctx = TestContext::for_rpc_client().await;
            let non_existent_hash = H256::from([0xde; 32]);
            let result = ctx.rpc().header(non_existent_hash).await;
            assert!(result.is_err());
            let err = result.unwrap_err();
            assert!(
                matches!(err, RpcClientError::InvalidResponse(_)),
                "Expected InvalidResponse for non-existent block, got: {err:?}"
            );
        }

        #[tokio::test]
        async fn fetch_storage_non_existent_key_returns_none() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let non_existent_key = vec![0xff; 32];
            let result = ctx.rpc().storage(&non_existent_key, hash).await.unwrap();
            assert!(result.is_none());
        }

        #[tokio::test]
        async fn fetch_storage_batch_with_mixed_keys() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            // One real key (System::Number) and one fabricated key that should not exist.
            let keys = [hex::decode(SYSTEM_NUMBER_KEY).unwrap(), vec![0xff; 32]];
            let key_refs: Vec<&[u8]> = keys.iter().map(|k| k.as_slice()).collect();
            let values = ctx.rpc().storage_batch(&key_refs, hash).await.unwrap();
            assert_eq!(values.len(), 2);
            assert!(values[0].is_some(), "System::Number should exist");
            assert!(values[1].is_none(), "Fabricated key should not exist");
        }

        #[tokio::test]
        async fn fetch_storage_batch_empty_keys() {
            let ctx = TestContext::for_rpc_client().await;
            let hash = ctx.rpc().finalized_head().await.unwrap();
            let values = ctx.rpc().storage_batch(&[], hash).await.unwrap();
            assert!(values.is_empty());
        }

        #[tokio::test]
        async fn fetch_block_by_number_returns_block() {
            let ctx = TestContext::for_rpc_client().await;
            let finalized_hash = ctx.rpc().finalized_head().await.unwrap();
            let finalized_header = ctx.rpc().header(finalized_hash).await.unwrap();
            let finalized_number = finalized_header.number;
            let result = ctx.rpc().block_by_number(finalized_number).await.unwrap();
            assert!(result.is_some(), "Finalized block should exist");
            let (hash, block) = result.unwrap();
            assert_eq!(hash, finalized_hash, "Block hash should match finalized head");
            assert_eq!(
                block.header.number, finalized_number,
                "Block header number should match requested number"
            );
        }

        #[tokio::test]
        async fn fetch_block_by_number_non_existent_returns_none() {
            let ctx = TestContext::for_rpc_client().await;
            let non_existent_number = u32::MAX;
            let result = ctx.rpc().block_by_number(non_existent_number).await.unwrap();
            assert!(result.is_none(), "Non-existent block should return None");
        }

        #[tokio::test]
        async fn fetch_block_by_number_multiple_blocks() {
            let ctx = TestContext::for_rpc_client().await;
            // Presumably gives the node time to finalize several blocks so the parent-hash
            // chain below covers more than one block. An async sleep is used because
            // `std::thread::sleep` would block the runtime thread and stall the WebSocket
            // client's background task.
            tokio::time::sleep(Duration::from_secs(30)).await;
            let finalized_hash = ctx.rpc().finalized_head().await.unwrap();
            let finalized_header = ctx.rpc().header(finalized_hash).await.unwrap();
            let finalized_number = finalized_header.number;
            let mut previous_hash = None;
            for block_num in 0..=finalized_number.min(5) {
                let result = ctx.rpc().block_by_number(block_num).await.unwrap();
                assert!(
                    result.is_some(),
                    "Block {} should exist (finalized is {})",
                    block_num,
                    finalized_number
                );
                let (hash, block) = result.unwrap();
                assert_eq!(block.header.number, block_num);
                if let Some(prev) = previous_hash {
                    assert_eq!(
                        block.header.parent_hash, prev,
                        "Block {} parent hash should match previous block hash",
                        block_num
                    );
                }
                previous_hash = Some(hash);
            }
        }
    }
}