use crate::chain::*;
use async_trait::async_trait;
use ethers::{providers::Middleware, types::Chain};

5#[async_trait]
6pub trait MiddlewareExt: Middleware {
7 async fn get_chain_variant(&self) -> Result<ChainVariant, Self::Error> {
8 let chain_id = self.get_chainid().await?;
9
10 match Chain::try_from(chain_id) {
11 Ok(chain) => Ok(ChainVariant::Chain(chain)),
12 Err(_) => Ok(ChainVariant::UnknownChain(chain_id)),
13 }
14 }
15}
16
17impl<M: Middleware> MiddlewareExt for M {}
18
#[cfg(test)]
mod tests {
    use super::*;
    use ethers::types::Chain;
    use ethers::{providers::Provider, utils::Anvil};
    use url::Url;

    /// Public Ethereum mainnet RPC endpoint used to exercise the known-chain path.
    const ETHEREUM_URL: &str = "https://ethereum-mainnet-rpc.allthatnode.com";

    /// Chain id passed to the local anvil node for the unknown-chain case.
    /// NOTE(review): assumed not to correspond to any `Chain` variant — confirm
    /// if the ethers `Chain` list is updated.
    const FAKE_CHAIN_ID: u64 = 9999999999;

    /// A node reporting Ethereum mainnet's id is classified as `ChainVariant::Chain(Chain::Mainnet)`.
    ///
    /// Ignored by default because it depends on a third-party public RPC being
    /// reachable; run with `cargo test -- --ignored` when network access is available.
    #[tokio::test]
    #[ignore = "requires network access to a public Ethereum RPC endpoint"]
    async fn test_get_chain_variant_known_chain() {
        let transport = ethers::providers::Http::new(Url::parse(ETHEREUM_URL).unwrap());
        let provider = Provider::new(transport);
        let chain = provider.get_chain_variant().await.unwrap();
        assert_eq!(chain, ChainVariant::Chain(Chain::Mainnet));
    }

    /// A node reporting an id with no `Chain` variant yields `ChainVariant::UnknownChain`
    /// carrying that raw id.
    #[tokio::test]
    async fn test_get_chain_variant_unknown_chain() {
        // Keep `anvil` bound for the whole test: ethers' `AnvilInstance` stops the
        // spawned node when dropped, and the request below needs it running.
        let anvil = Anvil::new().chain_id(FAKE_CHAIN_ID).spawn();
        let transport = ethers::providers::Http::new(Url::parse(&anvil.endpoint()).unwrap());
        let provider = Provider::new(transport);
        let chain = provider.get_chain_variant().await.unwrap();
        assert_eq!(chain, ChainVariant::UnknownChain(FAKE_CHAIN_ID.into()));
    }
}