use ethers_core::{
types::{
Address, BlockNumber, Bytes, NameOrAddress, Signature, Transaction, TransactionRequest,
U256,
},
utils::keccak256,
};
use ethers_providers::{FromErr, Middleware, PendingTransaction};
use ethers_signers::Signer;
use async_trait::async_trait;
use futures_util::{future::ok, join};
use std::future::Future;
use thiserror::Error;
/// Middleware that signs transactions and messages with a local `signer`
/// and forwards everything else to the wrapped `inner` middleware.
#[derive(Clone, Debug)]
pub struct SignerMiddleware<M, S> {
    /// The wrapped middleware, used for lookups (gas price, gas estimate,
    /// nonce, ENS) and for broadcasting raw signed transactions.
    pub(crate) inner: M,
    /// The signer used for transactions and messages.
    pub(crate) signer: S,
    /// Cached `signer.address()`, captured when the client is constructed.
    pub(crate) address: Address,
}
impl<M: Middleware, S: Signer> FromErr<M::Error> for SignerMiddlewareError<M, S> {
    /// Lifts an error from the inner middleware into this middleware's error type.
    fn from(err: M::Error) -> Self {
        Self::MiddlewareError(err)
    }
}
/// Errors produced by [`SignerMiddleware`].
#[derive(Error, Debug)]
pub enum SignerMiddlewareError<M: Middleware, S: Signer> {
    /// The signer failed to sign a transaction or message.
    #[error("{0}")]
    SignerError(S::Error),
    /// An error returned by the inner middleware.
    #[error("{0}")]
    MiddlewareError(M::Error),
    /// The transaction handed to the signing step had no nonce.
    #[error("no nonce was specified")]
    NonceMissing,
    /// The transaction handed to the signing step had no gas price.
    #[error("no gas price was specified")]
    GasPriceMissing,
    /// The transaction handed to the signing step had no gas limit.
    #[error("no gas was specified")]
    GasMissing,
}
impl<M, S> SignerMiddleware<M, S>
where
    M: Middleware,
    S: Signer,
{
    /// Creates a new client that wraps `inner` and signs with `signer`.
    ///
    /// The signer's address is queried once here and cached; it is used as
    /// the default `from` and as the account whose nonce is looked up.
    pub fn new(inner: M, signer: S) -> Self {
        let address = signer.address();
        SignerMiddleware {
            inner,
            signer,
            address,
        }
    }

    /// Signs a fully populated `tx` and assembles the resulting [`Transaction`].
    ///
    /// `from` is always set to this client's address, and `block_hash`,
    /// `block_number` and `transaction_index` are left as `None`.
    ///
    /// # Errors
    ///
    /// - [`SignerMiddlewareError::NonceMissing`], [`SignerMiddlewareError::GasPriceMissing`]
    ///   or [`SignerMiddlewareError::GasMissing`] if the corresponding field is unset
    ///   (checked in that order, before any signing happens).
    /// - [`SignerMiddlewareError::SignerError`] if the signer fails.
    ///
    /// # Panics
    ///
    /// Panics if `tx.to` is still an ENS name. It must be resolved to an
    /// address first, as `send_transaction` does.
    async fn sign_transaction(
        &self,
        tx: TransactionRequest,
    ) -> Result<Transaction, SignerMiddlewareError<M, S>> {
        let nonce = tx.nonce.ok_or(SignerMiddlewareError::NonceMissing)?;
        let gas_price = tx.gas_price.ok_or(SignerMiddlewareError::GasPriceMissing)?;
        let gas = tx.gas.ok_or(SignerMiddlewareError::GasMissing)?;

        let signature = self
            .signer
            .sign_transaction(&tx)
            .await
            .map_err(SignerMiddlewareError::SignerError)?;

        // The transaction hash is the keccak256 of the signed RLP encoding.
        let rlp = tx.rlp_signed(&signature);
        let hash = keccak256(&rlp.as_ref());

        let to = tx.to.map(|to| match to {
            NameOrAddress::Address(inner) => inner,
            NameOrAddress::Name(_) => {
                panic!("Expected `to` to be an Ethereum Address, not an ENS name")
            }
        });

        Ok(Transaction {
            hash: hash.into(),
            nonce,
            from: self.address(),
            to,
            value: tx.value.unwrap_or_default(),
            gas_price,
            gas,
            input: tx.data.unwrap_or_default(),
            v: signature.v.into(),
            r: U256::from_big_endian(signature.r.as_bytes()),
            s: U256::from_big_endian(signature.s.as_bytes()),
            block_hash: None,
            block_number: None,
            transaction_index: None,
            #[cfg(feature = "celo")]
            fee_currency: tx.fee_currency,
            #[cfg(feature = "celo")]
            gateway_fee: tx.gateway_fee,
            #[cfg(feature = "celo")]
            gateway_fee_recipient: tx.gateway_fee_recipient,
        })
    }

    /// Fills in any missing `from`, `gas_price`, `gas` and `nonce` on `tx`.
    ///
    /// `from` defaults to the signer's address and is set first, before the
    /// gas estimate runs. The other fields are queried concurrently from the
    /// inner middleware, but only for fields that are still unset. The nonce
    /// is the transaction count of the signer's address at `block`.
    ///
    /// # Errors
    ///
    /// Returns [`SignerMiddlewareError::MiddlewareError`] if any lookup fails.
    /// Errors are reported with precedence gas price, then gas, then nonce.
    /// On error, `gas_price`, `gas` and `nonce` are left unchanged; only the
    /// defaulted `from` may already have been written.
    async fn fill_transaction(
        &self,
        tx: &mut TransactionRequest,
        block: Option<BlockNumber>,
    ) -> Result<(), SignerMiddlewareError<M, S>> {
        if tx.from.is_none() {
            tx.from = Some(self.address());
        }

        let (gas_price, gas, nonce) = join!(
            maybe(tx.gas_price, self.inner.get_gas_price()),
            maybe(tx.gas, self.inner.estimate_gas(&tx)),
            maybe(
                tx.nonce,
                self.inner.get_transaction_count(self.address(), block)
            ),
        );

        // Validate every lookup before writing, so a failure does not leave
        // `tx` partially updated.
        let gas_price = gas_price.map_err(SignerMiddlewareError::MiddlewareError)?;
        let gas = gas.map_err(SignerMiddlewareError::MiddlewareError)?;
        let nonce = nonce.map_err(SignerMiddlewareError::MiddlewareError)?;

        tx.gas_price = Some(gas_price);
        tx.gas = Some(gas);
        tx.nonce = Some(nonce);
        Ok(())
    }

    /// Returns the cached address of the signer.
    pub fn address(&self) -> Address {
        self.address
    }

    /// Returns a reference to the signer.
    pub fn signer(&self) -> &S {
        &self.signer
    }

    /// Returns a new client that shares a clone of the inner middleware but
    /// signs with `signer`.
    ///
    /// The cached address is refreshed from the new signer. Only the inner
    /// middleware is cloned; the current signer is neither cloned nor reused.
    pub fn with_signer(&self, signer: S) -> Self
    where
        M: Clone,
    {
        SignerMiddleware {
            inner: self.inner.clone(),
            // Must be read before `signer` is moved into the struct below.
            address: signer.address(),
            signer,
        }
    }
}
#[async_trait]
impl<M, S> Middleware for SignerMiddleware<M, S>
where
    M: Middleware,
    S: Signer,
{
    type Error = SignerMiddlewareError<M, S>;
    type Provider = M::Provider;
    type Inner = M;

    /// Returns the wrapped middleware.
    fn inner(&self) -> &M {
        &self.inner
    }

    /// Signs `tx` locally and broadcasts it as a raw transaction.
    ///
    /// An ENS name in `tx.to` is first resolved through the inner
    /// middleware. Missing fields are then filled (see `fill_transaction`),
    /// the transaction is signed, and the signed transaction is sent via the
    /// inner middleware's `send_raw_transaction`.
    ///
    /// # Errors
    ///
    /// Propagates ENS resolution, lookup, signing and broadcast errors.
    async fn send_transaction(
        &self,
        mut tx: TransactionRequest,
        block: Option<BlockNumber>,
    ) -> Result<PendingTransaction<'_, Self::Provider>, Self::Error> {
        // Signing requires a concrete recipient address, so resolve ENS names up front.
        if let Some(NameOrAddress::Name(ens_name)) = &tx.to {
            let addr = self
                .inner
                .resolve_name(ens_name)
                .await
                .map_err(SignerMiddlewareError::MiddlewareError)?;
            tx.to = Some(addr.into());
        }

        self.fill_transaction(&mut tx, block).await?;
        let signed_tx = self.sign_transaction(tx).await?;

        self.inner
            .send_raw_transaction(&signed_tx)
            .await
            .map_err(SignerMiddlewareError::MiddlewareError)
    }

    /// Signs `data` as a message with the local signer.
    ///
    /// The address argument is ignored: messages are always signed by this
    /// client's signer.
    async fn sign<T: Into<Bytes> + Send + Sync>(
        &self,
        data: T,
        _: &Address,
    ) -> Result<Signature, Self::Error> {
        self.signer
            .sign_message(data.into())
            .await
            .map_err(SignerMiddlewareError::SignerError)
    }
}
/// Resolves to `Ok(item)` when a value is already present; otherwise awaits `f`.
///
/// `f` is only awaited in the `None` case; when `item` is `Some`, it is
/// dropped without being polled.
async fn maybe<F, T, E>(item: Option<T>, f: F) -> Result<T, E>
where
    F: Future<Output = Result<T, E>>,
{
    match item {
        Some(present) => ok(present).await,
        None => f.await,
    }
}
#[cfg(all(test, not(feature = "celo")))]
mod tests {
    use super::*;
    use ethers::{providers::Provider, signers::LocalWallet};
    use std::convert::TryFrom;

    /// Signing a fixed legacy transaction with a fixed key on chain id 1 must
    /// reproduce a known transaction hash and signed RLP encoding.
    #[tokio::test]
    async fn signs_tx() {
        let recipient = "F0109fC8DF283027b6285cc889F5aA624EaC1F55"
            .parse::<Address>()
            .unwrap();
        let request = TransactionRequest {
            from: None,
            to: Some(recipient.into()),
            value: Some(1_000_000_000.into()),
            gas: Some(2_000_000.into()),
            nonce: Some(0.into()),
            gas_price: Some(21_000_000_000u128.into()),
            data: None,
        };

        // `sign_transaction` makes no calls on the inner provider, so nothing
        // needs to be listening at this URL.
        let provider = Provider::try_from("http://localhost:8545").unwrap();
        let wallet = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318"
            .parse::<LocalWallet>()
            .unwrap()
            .set_chain_id(1u64);
        let client = SignerMiddleware::new(provider, wallet);

        let signed = client.sign_transaction(request).await.unwrap();

        assert_eq!(
            signed.hash,
            "de8db924885b0803d2edc335f745b2b8750c8848744905684c20b987443a9593"
                .parse()
                .unwrap()
        );
        let expected_rlp = Bytes::from(hex::decode("f869808504e3b29200831e848094f0109fc8df283027b6285cc889f5aa624eac1f55843b9aca008025a0c9cf86333bcb065d140032ecaab5d9281bde80f21b9687b3e94161de42d51895a0727a108a0b8d101465414033c3f705a9c7b826e596766046ee1183dbc8aeaa68").unwrap());
        assert_eq!(signed.rlp(), expected_rlp);
    }
}