1use async_trait::async_trait;
2use ethers_core::types::{
3 transaction::{eip2718::TypedTransaction, eip2930::AccessListWithGasUsed},
4 Address, BlockId, Bytes, Chain, Signature, TransactionRequest, U256,
5};
6use ethers_providers::{maybe, Middleware, MiddlewareError, PendingTransaction};
7use ethers_signers::Signer;
8use thiserror::Error;
9
10#[derive(Clone, Debug)]
11pub struct SignerMiddleware<M, S> {
63 pub(crate) inner: M,
64 pub(crate) signer: S,
65 pub(crate) address: Address,
66}
67
68#[derive(Error, Debug)]
69pub enum SignerMiddlewareError<M: Middleware, S: Signer> {
71 #[error("{0}")]
72 SignerError(S::Error),
74
75 #[error("{0}")]
76 MiddlewareError(M::Error),
78
79 #[error("no nonce was specified")]
81 NonceMissing,
82 #[error("no gas price was specified")]
84 GasPriceMissing,
85 #[error("no gas was specified")]
87 GasMissing,
88 #[error("specified from address is not signer")]
90 WrongSigner,
91 #[error("specified chain_id is different than the signer's chain_id")]
93 DifferentChainID,
94}
95
96impl<M: Middleware, S: Signer> MiddlewareError for SignerMiddlewareError<M, S> {
97 type Inner = M::Error;
98
99 fn from_err(src: M::Error) -> Self {
100 SignerMiddlewareError::MiddlewareError(src)
101 }
102
103 fn as_inner(&self) -> Option<&Self::Inner> {
104 match self {
105 SignerMiddlewareError::MiddlewareError(e) => Some(e),
106 _ => None,
107 }
108 }
109}
110
111impl<M, S> SignerMiddleware<M, S>
113where
114 M: Middleware,
115 S: Signer,
116{
117 pub fn new(inner: M, signer: S) -> Self {
127 let address = signer.address();
128 SignerMiddleware { inner, signer, address }
129 }
130
131 async fn sign_transaction(
136 &self,
137 mut tx: TypedTransaction,
138 ) -> Result<Bytes, SignerMiddlewareError<M, S>> {
139 let chain_id = self.signer.chain_id();
142 match tx.chain_id() {
143 Some(id) if id.as_u64() != chain_id => {
144 return Err(SignerMiddlewareError::DifferentChainID)
145 }
146 None => {
147 tx.set_chain_id(chain_id);
148 }
149 _ => {}
150 }
151
152 let signature =
153 self.signer.sign_transaction(&tx).await.map_err(SignerMiddlewareError::SignerError)?;
154
155 Ok(tx.rlp_signed(&signature))
157 }
158
159 pub fn address(&self) -> Address {
161 self.address
162 }
163
164 pub fn signer(&self) -> &S {
166 &self.signer
167 }
168
169 #[must_use]
171 pub fn with_signer(&self, signer: S) -> Self
172 where
173 S: Clone,
174 M: Clone,
175 {
176 let mut this = self.clone();
177 this.address = signer.address();
178 this.signer = signer;
179 this
180 }
181
182 pub async fn new_with_provider_chain(
190 inner: M,
191 signer: S,
192 ) -> Result<Self, SignerMiddlewareError<M, S>> {
193 let address = signer.address();
194 let chain_id =
195 inner.get_chainid().await.map_err(|e| SignerMiddlewareError::MiddlewareError(e))?;
196 let signer = signer.with_chain_id(chain_id.as_u64());
197 Ok(SignerMiddleware { inner, signer, address })
198 }
199
200 fn set_tx_from_if_none(&self, tx: &TypedTransaction) -> TypedTransaction {
201 let mut tx = tx.clone();
202 if tx.from().is_none() {
203 tx.set_from(self.address);
204 }
205 tx
206 }
207}
208
209#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
210#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
211impl<M, S> Middleware for SignerMiddleware<M, S>
212where
213 M: Middleware,
214 S: Signer,
215{
216 type Error = SignerMiddlewareError<M, S>;
217 type Provider = M::Provider;
218 type Inner = M;
219
220 fn inner(&self) -> &M {
221 &self.inner
222 }
223
224 fn default_sender(&self) -> Option<Address> {
226 Some(self.address)
227 }
228
229 async fn is_signer(&self) -> bool {
231 true
232 }
233
234 async fn sign_transaction(
235 &self,
236 tx: &TypedTransaction,
237 _: Address,
238 ) -> Result<Signature, Self::Error> {
239 Ok(self.signer.sign_transaction(tx).await.map_err(SignerMiddlewareError::SignerError)?)
240 }
241
242 async fn fill_transaction(
244 &self,
245 tx: &mut TypedTransaction,
246 block: Option<BlockId>,
247 ) -> Result<(), Self::Error> {
248 let from = if tx.from().is_some() && tx.from() != Some(&self.address()) {
250 *tx.from().unwrap()
251 } else {
252 self.address
253 };
254 tx.set_from(from);
255
256 let chain_id = self.signer.chain_id();
258 if tx.chain_id().is_none() {
259 tx.set_chain_id(chain_id);
260 }
261
262 if let Some(chain_id) = tx.chain_id() {
265 let chain = Chain::try_from(chain_id.as_u64());
266 if chain.unwrap_or_default().is_legacy() {
267 if let TypedTransaction::Eip1559(inner) = tx {
268 let tx_req: TransactionRequest = inner.clone().into();
269 *tx = TypedTransaction::Legacy(tx_req);
270 }
271 }
272 }
273
274 let nonce = maybe(tx.nonce().cloned(), self.get_transaction_count(from, block)).await?;
275 tx.set_nonce(nonce);
276 self.inner()
277 .fill_transaction(tx, block)
278 .await
279 .map_err(SignerMiddlewareError::MiddlewareError)?;
280 Ok(())
281 }
282
283 async fn send_transaction<T: Into<TypedTransaction> + Send + Sync>(
287 &self,
288 tx: T,
289 block: Option<BlockId>,
290 ) -> Result<PendingTransaction<'_, Self::Provider>, Self::Error> {
291 let mut tx = tx.into();
292
293 self.fill_transaction(&mut tx, block).await?;
295
296 if tx.from().is_some() && tx.from() != Some(&self.address()) {
298 return self
299 .inner
300 .send_transaction(tx, block)
301 .await
302 .map_err(SignerMiddlewareError::MiddlewareError)
303 }
304
305 let signed_tx = self.sign_transaction(tx).await?;
308
309 self.inner
311 .send_raw_transaction(signed_tx)
312 .await
313 .map_err(SignerMiddlewareError::MiddlewareError)
314 }
315
316 async fn sign<T: Into<Bytes> + Send + Sync>(
319 &self,
320 data: T,
321 _: &Address,
322 ) -> Result<Signature, Self::Error> {
323 self.signer.sign_message(data.into()).await.map_err(SignerMiddlewareError::SignerError)
324 }
325
326 async fn estimate_gas(
327 &self,
328 tx: &TypedTransaction,
329 block: Option<BlockId>,
330 ) -> Result<U256, Self::Error> {
331 let tx = self.set_tx_from_if_none(tx);
332 self.inner.estimate_gas(&tx, block).await.map_err(SignerMiddlewareError::MiddlewareError)
333 }
334
335 async fn create_access_list(
336 &self,
337 tx: &TypedTransaction,
338 block: Option<BlockId>,
339 ) -> Result<AccessListWithGasUsed, Self::Error> {
340 let tx = self.set_tx_from_if_none(tx);
341 self.inner
342 .create_access_list(&tx, block)
343 .await
344 .map_err(SignerMiddlewareError::MiddlewareError)
345 }
346
347 async fn call(
348 &self,
349 tx: &TypedTransaction,
350 block: Option<BlockId>,
351 ) -> Result<Bytes, Self::Error> {
352 let tx = self.set_tx_from_if_none(tx);
353 self.inner().call(&tx, block).await.map_err(SignerMiddlewareError::MiddlewareError)
354 }
355}
356
#[cfg(all(test, not(feature = "celo")))]
mod tests {
    use super::*;
    use ethers_core::{
        types::{Eip1559TransactionRequest, TransactionRequest},
        utils::{self, keccak256, Anvil},
    };
    use ethers_providers::Provider;
    use ethers_signers::LocalWallet;
    use std::convert::TryFrom;

    /// Hex-encoded private key shared by the deterministic signing tests.
    const SIGNING_KEY: &str = "4c0883a69102937d6231471b5dbb6204fe5129617082792ae468d01a3f362318";

    /// Recipient address used by the fixture transactions.
    const RECIPIENT: &str = "F0109fC8DF283027b6285cc889F5aA624EaC1F55";

    /// Parses the fixture recipient address.
    fn recipient() -> Address {
        RECIPIENT.parse::<Address>().unwrap()
    }

    /// Builds the fixture wallet configured for `chain_id`.
    fn wallet_for_chain(chain_id: u64) -> LocalWallet {
        SIGNING_KEY.parse::<LocalWallet>().unwrap().with_chain_id(chain_id)
    }

    /// Builds the EIP-1559 fixture request with no fee fields and no chain id.
    fn blank_eip1559() -> TypedTransaction {
        TypedTransaction::Eip1559(Eip1559TransactionRequest {
            from: None,
            to: Some(recipient().into()),
            value: Some(1_000_000_000.into()),
            gas: Some(2_000_000.into()),
            nonce: Some(U256::zero()),
            access_list: Default::default(),
            max_priority_fee_per_gas: None,
            data: None,
            chain_id: None,
            max_fee_per_gas: None,
        })
    }

    // Signing on chain id 1 must reproduce the known hash and raw encoding.
    #[tokio::test]
    async fn signs_tx() {
        let request: TypedTransaction = TransactionRequest {
            from: None,
            to: Some(recipient().into()),
            value: Some(1_000_000_000.into()),
            gas: Some(2_000_000.into()),
            nonce: Some(0.into()),
            gas_price: Some(21_000_000_000u128.into()),
            data: None,
            chain_id: None,
        }
        .into();
        let chain_id = 1u64;

        let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let client = SignerMiddleware::new(provider, wallet_for_chain(chain_id));

        let raw = client.sign_transaction(request).await.unwrap();

        assert_eq!(
            keccak256(&raw)[..],
            hex::decode("de8db924885b0803d2edc335f745b2b8750c8848744905684c20b987443a9593")
                .unwrap()
        );

        let expected_rlp = Bytes::from(hex::decode("f869808504e3b29200831e848094f0109fc8df283027b6285cc889f5aa624eac1f55843b9aca008025a0c9cf86333bcb065d140032ecaab5d9281bde80f21b9687b3e94161de42d51895a0727a108a0b8d101465414033c3f705a9c7b826e596766046ee1183dbc8aeaa68").unwrap());
        assert_eq!(raw, expected_rlp);
    }

    // A request without a chain id must be signed with the signer's chain id.
    #[tokio::test]
    async fn signs_tx_none_chainid() {
        let request: TypedTransaction = TransactionRequest {
            from: None,
            to: Some(recipient().into()),
            value: Some(1_000_000_000.into()),
            gas: Some(2_000_000.into()),
            nonce: Some(U256::zero()),
            gas_price: Some(21_000_000_000u128.into()),
            data: None,
            chain_id: None,
        }
        .into();
        let chain_id = 1337u64;

        let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let client = SignerMiddleware::new(provider, wallet_for_chain(chain_id));

        let raw = client.sign_transaction(request).await.unwrap();

        let expected_rlp = Bytes::from(hex::decode("f86b808504e3b29200831e848094f0109fc8df283027b6285cc889f5aa624eac1f55843b9aca0080820a95a08290324bae25ca0490077e0d1f4098730333088f6a500793fa420243f35c6b23a06aca42876cd28fdf614a4641e64222fee586391bb3f4061ed5dfefac006be850").unwrap());
        assert_eq!(raw, expected_rlp);
    }

    // With Anvil's default chain id, provider, middleware and signer must agree.
    #[tokio::test]
    async fn anvil_consistent_chainid() {
        let anvil = Anvil::new().spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let provider_chainid = provider.get_chainid().await.unwrap();
        assert_eq!(provider_chainid, U256::from(31337));

        let wallet = LocalWallet::new(&mut rand::thread_rng());

        let client = SignerMiddleware::new_with_provider_chain(provider, wallet).await.unwrap();
        let middleware_chainid = client.get_chainid().await.unwrap();
        assert_eq!(provider_chainid, middleware_chainid);

        let signer_chainid = client.signer().chain_id();
        assert_eq!(provider_chainid.as_u64(), signer_chainid);
    }

    // Same agreement check with a non-default chain id passed to Anvil.
    #[tokio::test]
    async fn anvil_consistent_chainid_not_default() {
        let anvil = Anvil::new().args(vec!["--chain-id", "13371337"]).spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let provider_chainid = provider.get_chainid().await.unwrap();
        assert_eq!(provider_chainid, U256::from(13371337));

        let wallet = LocalWallet::new(&mut rand::thread_rng());

        let client = SignerMiddleware::new_with_provider_chain(provider, wallet).await.unwrap();
        let middleware_chainid = client.get_chainid().await.unwrap();
        assert_eq!(provider_chainid, middleware_chainid);

        let signer_chainid = client.signer().chain_id();
        assert_eq!(provider_chainid.as_u64(), signer_chainid);
    }

    // Covers the three `from` cases: unset, the signer itself, and another account.
    #[tokio::test]
    async fn handles_tx_from_field() {
        let anvil = Anvil::new().spawn();
        let funded_account = anvil.addresses()[0];
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let wallet = LocalWallet::new(&mut rand::thread_rng()).with_chain_id(1u32);

        // Fund the fresh wallet from the node's first account.
        let funding = TransactionRequest::pay(wallet.address(), utils::parse_ether(1u64).unwrap())
            .from(funded_account);
        provider.send_transaction(funding, None).await.unwrap().await.unwrap().unwrap();

        let client = SignerMiddleware::new_with_provider_chain(provider, wallet).await.unwrap();

        let base_request = TransactionRequest::new();

        // No `from`: the signer's address is used.
        let request_from_none = base_request.clone();
        let hash = *client.send_transaction(request_from_none, None).await.unwrap();
        let mined = client.get_transaction(hash).await.unwrap().unwrap();
        assert_eq!(mined.from, client.address());

        // `from` equal to the signer: signed locally.
        let request_from_signer = base_request.clone().from(client.address());
        let hash = *client.send_transaction(request_from_signer, None).await.unwrap();
        let mined = client.get_transaction(hash).await.unwrap().unwrap();
        assert_eq!(mined.from, client.address());

        // `from` set to another account: forwarded to the inner middleware.
        let request_from_other = base_request.from(funded_account);
        let hash = *client.send_transaction(request_from_other, None).await.unwrap();
        let mined = client.get_transaction(hash).await.unwrap().unwrap();
        assert_eq!(mined.from, funded_account);
    }

    // On chain id 324 an EIP-1559 request must come back as a legacy request.
    #[tokio::test]
    async fn converts_tx_to_legacy_to_match_chain() {
        let mut tx = blank_eip1559();

        let chain_id = 324u64;
        let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let client = SignerMiddleware::new(provider, wallet_for_chain(chain_id));
        client.fill_transaction(&mut tx, None).await.unwrap();

        assert!(tx.as_eip1559_ref().is_none());
        assert_eq!(tx, TypedTransaction::Legacy(tx.as_legacy_ref().unwrap().clone()));
    }

    // On chain id 1 an EIP-1559 request must stay EIP-1559.
    #[tokio::test]
    async fn does_not_convert_to_legacy_for_eip1559_chain() {
        let mut tx = blank_eip1559();

        let chain_id = 1u64;
        let anvil = Anvil::new().args(vec!["--chain-id".to_string(), chain_id.to_string()]).spawn();
        let provider = Provider::try_from(anvil.endpoint()).unwrap();
        let client = SignerMiddleware::new(provider, wallet_for_chain(chain_id));
        client.fill_transaction(&mut tx, None).await.unwrap();

        assert!(tx.as_legacy_ref().is_none());
        assert_eq!(tx, TypedTransaction::Eip1559(tx.as_eip1559_ref().unwrap().clone()));
    }
}
597}