use async_trait::async_trait;
use ethers_core::types::{
transaction::{eip2718::TypedTransaction, eip2930::AccessListWithGasUsed},
Address, BlockId, Bytes, Chain, Signature, TransactionRequest, U256,
};
use ethers_providers::{maybe, Middleware, MiddlewareError, PendingTransaction};
use ethers_signers::Signer;
use thiserror::Error;
/// Middleware that pairs an inner [`Middleware`] with a [`Signer`].
///
/// As implemented in this file, a transaction whose `from` is unset or equal
/// to the signer's address is signed locally and submitted as a raw
/// transaction; a transaction with any other `from` is forwarded unsigned to
/// the inner middleware.
#[derive(Clone, Debug)]
pub struct SignerMiddleware<M, S> {
/// The wrapped middleware that requests are delegated to.
pub(crate) inner: M,
/// The signer used to sign transactions and messages.
pub(crate) signer: S,
/// The signer's address, captured from `signer.address()` whenever the
/// middleware is constructed or its signer is replaced.
pub(crate) address: Address,
}
/// Errors returned by [`SignerMiddleware`].
///
/// Generic over the inner middleware `M` and the signer `S` so that both of
/// their error types can be carried without boxing.
#[derive(Error, Debug)]
pub enum SignerMiddlewareError<M: Middleware, S: Signer> {
/// An error produced by the signer `S`; displayed as the wrapped error.
#[error("{0}")]
SignerError(S::Error),
/// An error produced by the inner middleware `M`; displayed as the wrapped error.
#[error("{0}")]
MiddlewareError(M::Error),
/// No nonce was specified. Not constructed by the code visible in this file.
#[error("no nonce was specified")]
NonceMissing,
/// No gas price was specified. Not constructed by the code visible in this file.
#[error("no gas price was specified")]
GasPriceMissing,
/// No gas was specified. Not constructed by the code visible in this file.
#[error("no gas was specified")]
GasMissing,
/// The specified `from` address is not the signer's address. Not constructed
/// by the code visible in this file.
#[error("specified from address is not signer")]
WrongSigner,
/// The transaction carries a `chain_id` that differs from the signer's;
/// returned when signing a transaction locally.
#[error("specified chain_id is different than the signer's chain_id")]
DifferentChainID,
}
/// Lets errors from the inner middleware be lifted into, and recovered from,
/// [`SignerMiddlewareError`].
impl<M: Middleware, S: Signer> MiddlewareError for SignerMiddlewareError<M, S> {
    type Inner = M::Error;

    /// Wraps an inner-middleware error in the `MiddlewareError` variant.
    fn from_err(src: M::Error) -> Self {
        Self::MiddlewareError(src)
    }

    /// Returns the wrapped inner-middleware error, or `None` when `self` is
    /// any other variant.
    fn as_inner(&self) -> Option<&Self::Inner> {
        if let Self::MiddlewareError(inner_err) = self {
            Some(inner_err)
        } else {
            None
        }
    }
}
impl<M, S> SignerMiddleware<M, S>
where
    M: Middleware,
    S: Signer,
{
    /// Creates a new middleware from an inner middleware and a signer.
    ///
    /// The signer's address is read once and cached. The signer's chain id is
    /// used as-is; see [`Self::new_with_provider_chain`] to take the chain id
    /// from the inner middleware instead.
    pub fn new(inner: M, signer: S) -> Self {
        let address = signer.address();
        SignerMiddleware { inner, signer, address }
    }

    /// Signs `tx` with the signer and returns the RLP encoding of the signed
    /// transaction.
    ///
    /// If `tx` has no chain id, the signer's chain id is written into it
    /// before signing.
    ///
    /// # Errors
    ///
    /// * [`SignerMiddlewareError::DifferentChainID`] if `tx` already carries a
    ///   chain id that differs from the signer's.
    /// * [`SignerMiddlewareError::SignerError`] if the signer fails.
    async fn sign_transaction(
        &self,
        mut tx: TypedTransaction,
    ) -> Result<Bytes, SignerMiddlewareError<M, S>> {
        let chain_id = self.signer.chain_id();
        match tx.chain_id() {
            // An explicit chain id must agree with the signer's.
            Some(id) if id.as_u64() != chain_id => {
                return Err(SignerMiddlewareError::DifferentChainID)
            }
            // No chain id on the transaction: adopt the signer's.
            None => {
                tx.set_chain_id(chain_id);
            }
            _ => {}
        }

        let signature =
            self.signer.sign_transaction(&tx).await.map_err(SignerMiddlewareError::SignerError)?;

        Ok(tx.rlp_signed(&signature))
    }

    /// Returns the cached address of the signer.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Returns a reference to the signer.
    pub fn signer(&self) -> &S {
        &self.signer
    }

    /// Returns a copy of this middleware that uses `signer` (and its address)
    /// in place of the current signer. The inner middleware is cloned.
    #[must_use]
    pub fn with_signer(&self, signer: S) -> Self
    where
        // `S: Clone` is no longer needed by the body but is kept so the public
        // signature stays unchanged for existing callers.
        S: Clone,
        M: Clone,
    {
        // Build the new value directly rather than cloning `self`, which would
        // clone the current signer only to throw it away.
        let address = signer.address();
        SignerMiddleware { inner: self.inner.clone(), signer, address }
    }

    /// Creates a new middleware whose signer is configured with the chain id
    /// reported by the inner middleware (`get_chainid`).
    ///
    /// # Errors
    ///
    /// Returns [`SignerMiddlewareError::MiddlewareError`] if querying the
    /// chain id fails.
    ///
    /// # Panics
    ///
    /// NOTE(review): `U256::as_u64` is documented to panic when the value does
    /// not fit in a `u64`, so a provider reporting such a chain id would panic
    /// here — confirm whether that needs handling.
    pub async fn new_with_provider_chain(
        inner: M,
        signer: S,
    ) -> Result<Self, SignerMiddlewareError<M, S>> {
        let address = signer.address();
        let chain_id =
            inner.get_chainid().await.map_err(SignerMiddlewareError::MiddlewareError)?;
        let signer = signer.with_chain_id(chain_id.as_u64());
        Ok(SignerMiddleware { inner, signer, address })
    }

    /// Returns a clone of `tx` whose `from` field is set to the signer's
    /// address when the original left it unset; an existing `from` is kept.
    fn set_tx_from_if_none(&self, tx: &TypedTransaction) -> TypedTransaction {
        let mut tx = tx.clone();
        if tx.from().is_none() {
            tx.set_from(self.address);
        }
        tx
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<M, S> Middleware for SignerMiddleware<M, S>
where
    M: Middleware,
    S: Signer,
{
    type Error = SignerMiddlewareError<M, S>;
    type Provider = M::Provider;
    type Inner = M;

    /// Returns the wrapped middleware.
    fn inner(&self) -> &M {
        &self.inner
    }

    /// Returns the signer's address as the default sender.
    fn default_sender(&self) -> Option<Address> {
        Some(self.address)
    }

    /// Always `true`: this middleware signs locally.
    async fn is_signer(&self) -> bool {
        true
    }

    /// Signs `tx` with the signer and returns the signature.
    ///
    /// The address argument is ignored; the configured signer is always used.
    /// Unlike the inherent `sign_transaction`, this does not touch or check
    /// the transaction's chain id.
    ///
    /// # Errors
    ///
    /// Returns [`SignerMiddlewareError::SignerError`] if the signer fails.
    async fn sign_transaction(
        &self,
        tx: &TypedTransaction,
        _: Address,
    ) -> Result<Signature, Self::Error> {
        self.signer.sign_transaction(tx).await.map_err(SignerMiddlewareError::SignerError)
    }

    /// Fills in the fields of `tx` that this middleware is responsible for,
    /// then lets the inner middleware fill the rest.
    ///
    /// In order: sets `from` (keeping an existing value, otherwise the
    /// signer's address), sets the chain id from the signer if unset, converts
    /// an EIP-1559 transaction to a legacy one when the chain reports
    /// `is_legacy()`, and sets the nonce (keeping an existing value, otherwise
    /// the transaction count of `from` at `block`).
    ///
    /// # Errors
    ///
    /// Propagates errors from fetching the transaction count and from the
    /// inner middleware's `fill_transaction`.
    async fn fill_transaction(
        &self,
        tx: &mut TypedTransaction,
        block: Option<BlockId>,
    ) -> Result<(), Self::Error> {
        // Keep an explicit sender; fall back to the signer's address.
        let from = tx.from().copied().unwrap_or(self.address);
        tx.set_from(from);

        // Use the signer's chain id when the transaction does not set one.
        if tx.chain_id().is_none() {
            tx.set_chain_id(self.signer.chain_id());
        }

        // Downgrade EIP-1559 transactions on chains flagged as legacy. Chain
        // ids that `Chain::try_from` rejects fall back to `Chain::default()`.
        if let Some(chain_id) = tx.chain_id() {
            let chain = Chain::try_from(chain_id.as_u64());
            if chain.unwrap_or_default().is_legacy() {
                if let TypedTransaction::Eip1559(inner) = tx {
                    let tx_req: TransactionRequest = inner.clone().into();
                    *tx = TypedTransaction::Legacy(tx_req);
                }
            }
        }

        // Only query the transaction count when no nonce was supplied.
        let nonce = maybe(tx.nonce().cloned(), self.get_transaction_count(from, block)).await?;
        tx.set_nonce(nonce);

        self.inner
            .fill_transaction(tx, block)
            .await
            .map_err(SignerMiddlewareError::MiddlewareError)?;
        Ok(())
    }

    /// Fills `tx`, then either signs it locally and submits the raw
    /// transaction, or — when `from` is an address other than the signer's —
    /// forwards it unsigned to the inner middleware's `send_transaction`.
    ///
    /// # Errors
    ///
    /// Propagates errors from filling, signing (including
    /// [`SignerMiddlewareError::DifferentChainID`]) and from the inner
    /// middleware.
    async fn send_transaction<T: Into<TypedTransaction> + Send + Sync>(
        &self,
        tx: T,
        block: Option<BlockId>,
    ) -> Result<PendingTransaction<'_, Self::Provider>, Self::Error> {
        let mut tx: TypedTransaction = tx.into();
        self.fill_transaction(&mut tx, block).await?;

        // A sender other than our signer cannot be signed for here; hand the
        // transaction to the inner middleware as-is.
        if matches!(tx.from(), Some(from) if *from != self.address) {
            return self
                .inner
                .send_transaction(tx, block)
                .await
                .map_err(SignerMiddlewareError::MiddlewareError)
        }

        let signed_tx = self.sign_transaction(tx).await?;
        self.inner
            .send_raw_transaction(signed_tx)
            .await
            .map_err(SignerMiddlewareError::MiddlewareError)
    }

    /// Signs `data` as a message with the signer.
    ///
    /// The address argument is ignored; the configured signer is always used.
    ///
    /// # Errors
    ///
    /// Returns [`SignerMiddlewareError::SignerError`] if the signer fails.
    async fn sign<T: Into<Bytes> + Send + Sync>(
        &self,
        data: T,
        _: &Address,
    ) -> Result<Signature, Self::Error> {
        self.signer.sign_message(data.into()).await.map_err(SignerMiddlewareError::SignerError)
    }

    /// Estimates gas via the inner middleware, defaulting an unset `from` to
    /// the signer's address first.
    async fn estimate_gas(
        &self,
        tx: &TypedTransaction,
        block: Option<BlockId>,
    ) -> Result<U256, Self::Error> {
        let tx = self.set_tx_from_if_none(tx);
        self.inner.estimate_gas(&tx, block).await.map_err(SignerMiddlewareError::MiddlewareError)
    }

    /// Creates an access list via the inner middleware, defaulting an unset
    /// `from` to the signer's address first.
    async fn create_access_list(
        &self,
        tx: &TypedTransaction,
        block: Option<BlockId>,
    ) -> Result<AccessListWithGasUsed, Self::Error> {
        let tx = self.set_tx_from_if_none(tx);
        self.inner
            .create_access_list(&tx, block)
            .await
            .map_err(SignerMiddlewareError::MiddlewareError)
    }

    /// Executes a call via the inner middleware, defaulting an unset `from`
    /// to the signer's address first.
    async fn call(
        &self,
        tx: &TypedTransaction,
        block: Option<BlockId>,
    ) -> Result<Bytes, Self::Error> {
        let tx = self.set_tx_from_if_none(tx);
        self.inner.call(&tx, block).await.map_err(SignerMiddlewareError::MiddlewareError)
    }
}
// Tests for `SignerMiddleware`. Each test spawns a local Anvil instance, so
// they are integration-style and are compiled out when the `celo` feature is on.
#[cfg(all(test, not(feature = "celo")))]
mod tests {
use super::*;
use ethers_core::{
types::{Eip1559TransactionRequest, TransactionRequest},
utils::{self, keccak256, Anvil},
};
use ethers_providers::Provider;
use ethers_signers::LocalWallet;
use std::convert::TryFrom;
// Signs a legacy transaction that has no chain id set, using a wallet
// configured for chain id 1, and pins both the keccak256 hash and the exact
// RLP bytes of the signed result.
#[tokio::test]
async fn signs_tx() {
let tx = TransactionRequest {
from: None,
to: Some("F0109fC8DF283027b6285cc889F5aA624EaC1F55".parse::<Address>().unwrap().into()),
value: Some(1_000_000_000.into()),
gas: Some(2_000_000.into()),
nonce: Some(0.into()),
gas_price: Some(21_000_000_000u128.into()),
data: None,
chain_id: None,
}
.into();
let chain_id = 1u64;
let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let key = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318"
.parse::<LocalWallet>()
.unwrap()
.with_chain_id(chain_id);
let client = SignerMiddleware::new(provider, key);
// One-argument call: resolves to the inherent `sign_transaction`, which
// returns the signed RLP bytes.
let tx = client.sign_transaction(tx).await.unwrap();
assert_eq!(
keccak256(&tx)[..],
hex::decode("de8db924885b0803d2edc335f745b2b8750c8848744905684c20b987443a9593")
.unwrap()
);
let expected_rlp = Bytes::from(hex::decode("f869808504e3b29200831e848094f0109fc8df283027b6285cc889f5aa624eac1f55843b9aca008025a0c9cf86333bcb065d140032ecaab5d9281bde80f21b9687b3e94161de42d51895a0727a108a0b8d101465414033c3f705a9c7b826e596766046ee1183dbc8aeaa68").unwrap());
assert_eq!(tx, expected_rlp);
}
// Same request as `signs_tx` but with the wallet configured for chain id
// 1337; the transaction's missing chain id is taken from the signer, and the
// exact RLP bytes are pinned.
#[tokio::test]
async fn signs_tx_none_chainid() {
let tx = TransactionRequest {
from: None,
to: Some("F0109fC8DF283027b6285cc889F5aA624EaC1F55".parse::<Address>().unwrap().into()),
value: Some(1_000_000_000.into()),
gas: Some(2_000_000.into()),
nonce: Some(U256::zero()),
gas_price: Some(21_000_000_000u128.into()),
data: None,
chain_id: None,
}
.into();
let chain_id = 1337u64;
let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let key = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318"
.parse::<LocalWallet>()
.unwrap()
.with_chain_id(chain_id);
let client = SignerMiddleware::new(provider, key);
let tx = client.sign_transaction(tx).await.unwrap();
let expected_rlp = Bytes::from(hex::decode("f86b808504e3b29200831e848094f0109fc8df283027b6285cc889f5aa624eac1f55843b9aca0080820a95a08290324bae25ca0490077e0d1f4098730333088f6a500793fa420243f35c6b23a06aca42876cd28fdf614a4641e64222fee586391bb3f4061ed5dfefac006be850").unwrap());
assert_eq!(tx, expected_rlp);
}
// With a default Anvil instance (asserted here to report chain id 31337),
// `new_with_provider_chain` must leave the provider, the middleware and the
// signer all agreeing on the chain id.
#[tokio::test]
async fn anvil_consistent_chainid() {
let anvil = Anvil::new().spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let chain_id = provider.get_chainid().await.unwrap();
assert_eq!(chain_id, U256::from(31337));
let key = LocalWallet::new(&mut rand::thread_rng());
let client = SignerMiddleware::new_with_provider_chain(provider, key).await.unwrap();
let middleware_chainid = client.get_chainid().await.unwrap();
assert_eq!(chain_id, middleware_chainid);
let signer = client.signer();
let signer_chainid = signer.chain_id();
assert_eq!(chain_id.as_u64(), signer_chainid);
}
// Same consistency check as above, with Anvil started on the non-default
// chain id 13371337.
#[tokio::test]
async fn anvil_consistent_chainid_not_default() {
let anvil = Anvil::new().args(vec!["--chain-id", "13371337"]).spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let chain_id = provider.get_chainid().await.unwrap();
assert_eq!(chain_id, U256::from(13371337));
let key = LocalWallet::new(&mut rand::thread_rng());
let client = SignerMiddleware::new_with_provider_chain(provider, key).await.unwrap();
let middleware_chainid = client.get_chainid().await.unwrap();
assert_eq!(chain_id, middleware_chainid);
let signer = client.signer();
let signer_chainid = signer.chain_id();
assert_eq!(chain_id.as_u64(), signer_chainid);
}
// Exercises the three `from` cases of `send_transaction`: unset, equal to the
// signer's address, and equal to another account (Anvil's first address).
#[tokio::test]
async fn handles_tx_from_field() {
let anvil = Anvil::new().spawn();
let acc = anvil.addresses()[0];
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let key = LocalWallet::new(&mut rand::thread_rng()).with_chain_id(1u32);
// Fund the freshly generated wallet with 1 ether from `acc` so it can pay
// for the transactions below.
provider
.send_transaction(
TransactionRequest::pay(key.address(), utils::parse_ether(1u64).unwrap()).from(acc),
None,
)
.await
.unwrap()
.await
.unwrap()
.unwrap();
let client = SignerMiddleware::new_with_provider_chain(provider, key).await.unwrap();
let request = TransactionRequest::new();
// `from` unset: the mined transaction must come from the signer.
let request_from_none = request.clone();
let hash = *client.send_transaction(request_from_none, None).await.unwrap();
let tx = client.get_transaction(hash).await.unwrap().unwrap();
assert_eq!(tx.from, client.address());
// `from` set to the signer's own address: same expectation.
let request_from_signer = request.clone().from(client.address());
let hash = *client.send_transaction(request_from_signer, None).await.unwrap();
let tx = client.get_transaction(hash).await.unwrap().unwrap();
assert_eq!(tx.from, client.address());
// `from` set to another account: the mined transaction must come from
// that account.
let request_from_other = request.from(acc);
let hash = *client.send_transaction(request_from_other, None).await.unwrap();
let tx = client.get_transaction(hash).await.unwrap().unwrap();
assert_eq!(tx.from, acc);
}
// On chain id 324, `fill_transaction` is expected to turn an EIP-1559 request
// into a legacy one.
// NOTE(review): presumably 324 maps to a `Chain` whose `is_legacy()` is true —
// confirm against `ethers_core::types::Chain`.
#[tokio::test]
async fn converts_tx_to_legacy_to_match_chain() {
let eip1559 = Eip1559TransactionRequest {
from: None,
to: Some("F0109fC8DF283027b6285cc889F5aA624EaC1F55".parse::<Address>().unwrap().into()),
value: Some(1_000_000_000.into()),
gas: Some(2_000_000.into()),
nonce: Some(U256::zero()),
access_list: Default::default(),
max_priority_fee_per_gas: None,
data: None,
chain_id: None,
max_fee_per_gas: None,
};
let mut tx = TypedTransaction::Eip1559(eip1559);
let chain_id = 324u64;
let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let key = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318"
.parse::<LocalWallet>()
.unwrap()
.with_chain_id(chain_id);
let client = SignerMiddleware::new(provider, key);
client.fill_transaction(&mut tx, None).await.unwrap();
assert!(tx.as_eip1559_ref().is_none());
assert_eq!(tx, TypedTransaction::Legacy(tx.as_legacy_ref().unwrap().clone()));
}
// On chain id 1, `fill_transaction` must leave an EIP-1559 request as
// EIP-1559.
#[tokio::test]
async fn does_not_convert_to_legacy_for_eip1559_chain() {
let eip1559 = Eip1559TransactionRequest {
from: None,
to: Some("F0109fC8DF283027b6285cc889F5aA624EaC1F55".parse::<Address>().unwrap().into()),
value: Some(1_000_000_000.into()),
gas: Some(2_000_000.into()),
nonce: Some(U256::zero()),
access_list: Default::default(),
max_priority_fee_per_gas: None,
data: None,
chain_id: None,
max_fee_per_gas: None,
};
let mut tx = TypedTransaction::Eip1559(eip1559);
let chain_id = 1u64;
let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
let provider = Provider::try_from(anvil.endpoint()).unwrap();
let key = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318"
.parse::<LocalWallet>()
.unwrap()
.with_chain_id(chain_id);
let client = SignerMiddleware::new(provider, key);
client.fill_transaction(&mut tx, None).await.unwrap();
assert!(tx.as_legacy_ref().is_none());
assert_eq!(tx, TypedTransaction::Eip1559(tx.as_eip1559_ref().unwrap().clone()));
}
}