ethers_utils/
middleware.rs

1use crate::chain::*;
2use async_trait::async_trait;
3use ethers::{providers::Middleware, types::Chain};
4
5#[async_trait]
6pub trait MiddlewareExt: Middleware {
7    async fn get_chain_variant(&self) -> Result<ChainVariant, Self::Error> {
8        let chain_id = self.get_chainid().await?;
9
10        match Chain::try_from(chain_id) {
11            Ok(chain) => Ok(ChainVariant::Chain(chain)),
12            Err(_) => Ok(ChainVariant::UnknownChain(chain_id)),
13        }
14    }
15}
16
17impl<M: Middleware> MiddlewareExt for M {}
18
#[cfg(test)]
mod tests {
    use super::*;
    use ethers::{
        providers::{Http, Provider},
        utils::Anvil,
    };

    /// Public Ethereum mainnet RPC endpoint. The mainnet test needs outbound
    /// network access to this host.
    const ETHEREUM_URL: &str = "https://ethereum-mainnet-rpc.allthatnode.com";

    /// Chain id used for a local Anvil node. It is intended not to match any
    /// chain known to `ethers::types::Chain`.
    const FAKE_CHAIN_ID: u64 = 9999999999;

    /// A node reporting mainnet's chain id is classified as `Chain::Mainnet`.
    #[tokio::test]
    async fn get_chain_variant_recognises_mainnet() {
        let provider = Provider::<Http>::try_from(ETHEREUM_URL).unwrap();

        let chain = provider.get_chain_variant().await.unwrap();

        assert_eq!(chain, ChainVariant::Chain(Chain::Mainnet));
    }

    /// A node reporting an unrecognised chain id yields `UnknownChain` with
    /// that raw id. This test needs the `anvil` binary to be installed.
    #[tokio::test]
    async fn get_chain_variant_reports_unknown_chain() {
        // `anvil` must stay alive for the rest of the test: dropping the
        // handle shuts the node down.
        let anvil = Anvil::new().chain_id(FAKE_CHAIN_ID).spawn();
        let provider = Provider::<Http>::try_from(anvil.endpoint().as_str()).unwrap();

        let chain = provider.get_chain_variant().await.unwrap();

        assert_eq!(chain, ChainVariant::UnknownChain(FAKE_CHAIN_ID.into()));
    }
}