alloy-provider 2.1.0

Interface with an Ethereum blockchain
Documentation
use crate::{BlockLogs, Provider, ProviderBuilder};
use alloy_consensus::BlockHeader;
use alloy_eips::BlockNumberOrTag;
use alloy_network::BlockResponse as _;
use alloy_network_primitives::HeaderResponse;
use alloy_primitives::{B256, U64};
use alloy_rpc_client::RpcClient;
use alloy_rpc_types_eth::{Block, Filter, Log};
use alloy_transport::{
    layers::{RetryBackoffLayer, RetryPolicy},
    TransportError, TransportFut,
};
use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
    task::Poll,
    time::Duration,
};

/// Mutable state shared by every clone of a [`MockChain`].
///
/// Groups the simulated chain itself, the per-request observations tests can inspect, and the
/// one-shot faults tests can inject.
struct ChainState {
    // --- simulated chain ---
    /// Block number reported by `eth_blockNumber` and resolved for the `latest` tag.
    head: u64,
    /// Blocks of the current chain, keyed by block number.
    blocks: HashMap<u64, Block>,
    /// Logs keyed by the hash of the block they were stored with.
    logs: HashMap<B256, Vec<Log>>,

    // --- recorded request parameters ---
    /// The `full` flag of each `eth_getBlockByNumber` request, in arrival order.
    block_request_full: Vec<bool>,
    /// Whether each `eth_getLogs` request filtered by block hash, in arrival order.
    log_request_block_hash: Vec<bool>,

    // --- injected behavior ---
    /// How many upcoming `eth_getLogs` requests should be answered with an error.
    fail_logs: usize,
    /// Logs returned verbatim by the next range-based (non-hash) `eth_getLogs` request.
    range_logs_override: Option<Vec<Log>>,
    /// Reorg applied right after the next `eth_getLogs` request that returns logs.
    reorg_after_log_success: Option<Vec<(Block, Vec<Log>)>>,
}

/// In-memory mock chain that answers `eth_blockNumber`, `eth_getBlockByNumber`, and
/// `eth_getLogs` requests.
///
/// `Clone` only clones the `Arc`, so every clone (and every provider built from one) observes and
/// mutates the same underlying state.
#[derive(Clone)]
pub(crate) struct MockChain {
    /// Shared, lock-protected chain state.
    state: Arc<RwLock<ChainState>>,
}

impl MockChain {
    /// Creates an empty mock chain: head `0`, no blocks or logs, and no injected faults.
    pub(crate) fn new() -> Self {
        Self {
            state: Arc::new(RwLock::new(ChainState {
                blocks: HashMap::new(),
                logs: HashMap::new(),
                head: 0,
                block_request_full: Vec::new(),
                log_request_block_hash: Vec::new(),
                range_logs_override: None,
                fail_logs: 0,
                reorg_after_log_success: None,
            })),
        }
    }

    /// Adds `blocks` (with their logs) to the chain without removing anything.
    ///
    /// Each block is stored at its number, overwriting any block already there, and its logs are
    /// stored under its hash. The head only moves forward: it becomes the highest of the current
    /// head and the inserted block numbers.
    pub(crate) fn extend(&self, blocks: &[(Block, Vec<Log>)]) {
        let mut state = self.state.write().unwrap();
        for (block, logs) in blocks {
            Self::insert_block(&mut state, block, logs);
            state.head = state.head.max(block.header.inner.number);
        }
    }

    /// Replaces the chain from the lowest block number in `blocks` upward with `blocks`.
    ///
    /// # Panics
    ///
    /// Panics if `blocks` is empty.
    pub(crate) fn reorg(&self, blocks: &[(Block, Vec<Log>)]) {
        let mut state = self.state.write().unwrap();
        Self::apply_reorg(&mut state, blocks);
    }

    /// Stores `block` at its number and `logs` under its hash, replacing previous entries.
    fn insert_block(state: &mut ChainState, block: &Block, logs: &[Log]) {
        state.logs.insert(block.header.hash, logs.to_vec());
        state.blocks.insert(block.header.inner.number, block.clone());
    }

    /// Drops every stored block at or above the lowest number in `blocks` (together with the logs
    /// stored under those blocks' hashes), inserts `blocks`, and moves the head to the highest
    /// replacement block.
    ///
    /// Every block that survives the removal sits below the fork height, so the highest
    /// replacement block is the new tip. Taking the head from it (rather than keeping the old
    /// head) means a reorg onto a shorter chain lowers the head instead of leaving it pointing at
    /// a height that no longer exists.
    ///
    /// # Panics
    ///
    /// Panics if `blocks` is empty.
    fn apply_reorg(state: &mut ChainState, blocks: &[(Block, Vec<Log>)]) {
        let heights = || blocks.iter().map(|(block, _)| block.header.inner.number);
        let fork_height = heights().min().expect("reorg needs blocks");
        let new_head = heights().max().expect("reorg needs blocks");

        // Borrow the two maps separately so stale logs can be dropped while `retain` walks the
        // blocks.
        let ChainState { blocks: stored_blocks, logs: stored_logs, .. } = &mut *state;
        stored_blocks.retain(|&height, block| {
            let keep = height < fork_height;
            if !keep {
                stored_logs.remove(&block.header.hash);
            }
            keep
        });

        for (block, logs) in blocks {
            Self::insert_block(state, block, logs);
        }
        state.head = new_head;
    }

    /// Makes the next `count` `eth_getLogs` requests fail with an internal error.
    ///
    /// Failures accumulate with any still pending from earlier calls.
    pub(crate) fn fail_next_logs(&self, count: usize) {
        self.state.write().unwrap().fail_logs += count;
    }

    /// Schedules `blocks` to be applied as a reorg right after the next `eth_getLogs` request that
    /// returns logs. That response still carries the pre-reorg logs.
    ///
    /// Replaces any reorg previously scheduled this way.
    pub(crate) fn reorg_after_next_log_success(&self, blocks: Vec<(Block, Vec<Log>)>) {
        self.state.write().unwrap().reorg_after_log_success = Some(blocks);
    }

    /// Makes the next non-failing `eth_getLogs` request that filters by block range (not by block
    /// hash) return `logs` verbatim.
    pub(crate) fn override_next_range_logs(&self, logs: Vec<Log>) {
        self.state.write().unwrap().range_logs_override = Some(logs);
    }

    /// Returns the `full` flag of every `eth_getBlockByNumber` request received so far, in order.
    pub(crate) fn block_request_full_flags(&self) -> Vec<bool> {
        self.state.read().unwrap().block_request_full.clone()
    }

    /// Returns, for every `eth_getLogs` request received so far (including failed ones), whether
    /// it filtered by block hash, in order.
    pub(crate) fn log_request_block_hash_flags(&self) -> Vec<bool> {
        self.state.read().unwrap().log_request_block_hash.clone()
    }

    /// Builds a provider backed directly by this mock chain, without any retry layer.
    pub(crate) fn provider(&self) -> impl Provider {
        let transport = MockChainTransport { chain: self.clone() };
        ProviderBuilder::new().connect_client(RpcClient::new(transport, true))
    }

    /// Builds a provider backed by this mock chain behind a `RetryBackoffLayer` whose policy
    /// retries every error and gives no backoff hint.
    pub(crate) fn provider_with_retry(&self) -> impl Provider {
        #[derive(Clone, Debug)]
        struct AlwaysRetryPolicy;

        impl RetryPolicy for AlwaysRetryPolicy {
            fn should_retry(&self, _error: &TransportError) -> bool {
                true
            }

            fn backoff_hint(&self, _error: &TransportError) -> Option<Duration> {
                None
            }
        }

        // Arguments per `RetryBackoffLayer::new_with_policy`: max retries `1`, initial backoff
        // `0`, `10_000` compute units per second.
        let retry_layer = RetryBackoffLayer::new_with_policy(1, 0, 10_000, AlwaysRetryPolicy);
        let transport = MockChainTransport { chain: self.clone() };
        let client = RpcClient::builder().layer(retry_layer).transport(transport, true);
        ProviderBuilder::new().connect_client(client)
    }

    /// Answers one JSON-RPC request from the shared state.
    ///
    /// Supported methods:
    /// - `eth_blockNumber`: the current head.
    /// - `eth_getBlockByNumber`: a numeric or `latest` tag. Unknown heights yield JSON `null`.
    /// - `eth_getLogs`: by block hash (errors with "block not found" when unknown), or by an
    ///   exact single-block range (empty when the block is unknown). Injected failures and a
    ///   pending range override are honored first.
    ///
    /// # Panics
    ///
    /// Panics on any other method, on an unsupported block tag, on params that do not decode, or
    /// on a range `eth_getLogs` filter that is not a single block while no override is pending.
    fn handle_request(&self, req: &alloy_json_rpc::SerializedRequest) -> alloy_json_rpc::Response {
        let mut state = self.state.write().unwrap();
        let payload = match req.method() {
            "eth_blockNumber" => alloy_json_rpc::ResponsePayload::Success(
                serde_json::value::to_raw_value(&U64::from(state.head)).unwrap(),
            ),
            "eth_getBlockByNumber" => {
                let params = req.params().expect("eth_getBlockByNumber requires params");
                let (tag, full): (BlockNumberOrTag, bool) =
                    serde_json::from_str(params.get()).unwrap();
                state.block_request_full.push(full);
                let number = match tag {
                    BlockNumberOrTag::Number(n) => n,
                    BlockNumberOrTag::Latest => state.head,
                    _ => unimplemented!("unsupported block tag in MockChain: {tag:?}"),
                };
                // `None` for unknown heights serializes as JSON `null`.
                let block = state.blocks.get(&number);
                alloy_json_rpc::ResponsePayload::Success(
                    serde_json::value::to_raw_value(&block).unwrap(),
                )
            }
            "eth_getLogs" => {
                let params = req.params().expect("eth_getLogs requires params");
                let (filter,): (Filter,) = serde_json::from_str(params.get()).unwrap();
                let block_hash = filter.get_block_hash();
                // Recorded before the failure check, so failed requests are observable too.
                state.log_request_block_hash.push(block_hash.is_some());
                if state.fail_logs > 0 {
                    state.fail_logs -= 1;
                    alloy_json_rpc::ResponsePayload::internal_error_message(
                        "temporary log error".into(),
                    )
                } else {
                    let logs = if let Some(hash) = block_hash {
                        state.logs.get(&hash).cloned()
                    } else if let Some(logs) = state.range_logs_override.take() {
                        Some(logs)
                    } else {
                        let (from_block, to_block) = filter.extract_block_range();
                        let number = match (from_block, to_block) {
                            (Some(from), Some(to)) if from == to => from,
                            other => panic!("logs are queried by exact block, got {other:?}"),
                        };
                        // Range queries for unknown blocks succeed with no logs.
                        Some(
                            state
                                .blocks
                                .get(&number)
                                .and_then(|block| state.logs.get(&block.header.hash))
                                .cloned()
                                .unwrap_or_default(),
                        )
                    };

                    match logs {
                        Some(logs) => {
                            // `logs` is already an owned snapshot, so the response reflects the
                            // chain as it was before the scheduled reorg.
                            if let Some(blocks) = state.reorg_after_log_success.take() {
                                Self::apply_reorg(&mut state, &blocks);
                            }
                            alloy_json_rpc::ResponsePayload::Success(
                                serde_json::value::to_raw_value(&logs).unwrap(),
                            )
                        }
                        None => alloy_json_rpc::ResponsePayload::internal_error_message(
                            "block not found".into(),
                        ),
                    }
                }
            }
            other => panic!("MockChain: unexpected RPC method `{other}`"),
        };
        alloy_json_rpc::Response { id: req.id().clone(), payload }
    }
}

/// `tower::Service` transport that answers every request from a [`MockChain`].
#[derive(Clone)]
struct MockChainTransport {
    /// Chain whose shared state serves each request.
    chain: MockChain,
}

impl tower::Service<alloy_json_rpc::RequestPacket> for MockChainTransport {
    type Response = alloy_json_rpc::ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    /// Always ready: there is no connection or capacity to wait on.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Answers a single request, or each request of a batch in order, from the mock chain.
    ///
    /// The returned future never resolves to `Err`; RPC-level failures travel inside the
    /// response payloads instead.
    fn call(&mut self, packet: alloy_json_rpc::RequestPacket) -> Self::Future {
        let chain = self.chain.clone();
        Box::pin(async move {
            let response = match packet {
                alloy_json_rpc::RequestPacket::Single(single) => {
                    alloy_json_rpc::ResponsePacket::Single(chain.handle_request(&single))
                }
                alloy_json_rpc::RequestPacket::Batch(batch) => {
                    let responses = batch.iter().map(|item| chain.handle_request(item)).collect();
                    alloy_json_rpc::ResponsePacket::Batch(responses)
                }
            };
            Ok(response)
        })
    }
}

/// Builds a default [`Block`] with only its number, hash, and parent hash filled in.
///
/// Both hashes are `B256::with_last_byte(..)` of the given byte, which keeps fixtures short while
/// still letting a test chain blocks together by parent hash.
pub(crate) fn block(number: u64, hash_last_byte: u8, parent_hash_last_byte: u8) -> Block {
    let mut built: Block = Block::default();
    let header = &mut built.header;
    header.hash = B256::with_last_byte(hash_last_byte);
    header.inner.parent_hash = B256::with_last_byte(parent_hash_last_byte);
    header.inner.number = number;
    built
}

pub(crate) fn log(number: u64, hash_last_byte: u8, index: u64) -> Log {
    Log {
        block_hash: Some(B256::with_last_byte(hash_last_byte)),
        block_number: Some(number),
        log_index: Some(index),
        ..Default::default()
    }
}

/// Asserts that `block_logs` describes block `number` with hash
/// `B256::with_last_byte(hash_last_byte)` and holds exactly `log_count` logs. It also asserts
/// that each log points back at that block and has its `removed` flag equal to `removed`.
///
/// # Panics
///
/// Panics on the first mismatch.
pub(crate) fn assert_batch(
    block_logs: &BlockLogs<alloy_network::Ethereum>,
    number: u64,
    hash_last_byte: u8,
    removed: bool,
    log_count: usize,
) {
    let expected_hash = B256::with_last_byte(hash_last_byte);
    let header = block_logs.block.header();
    assert_eq!(header.number(), number);
    assert_eq!(header.hash(), expected_hash);
    assert_eq!(block_logs.logs.len(), log_count);

    for entry in &block_logs.logs {
        assert_eq!(entry.block_number, Some(number));
        assert_eq!(entry.block_hash, Some(expected_hash));
        assert_eq!(entry.removed, removed);
    }
}