init4_bin_base/utils/
signer.rs1use crate::utils::from_env::FromEnv;
2use alloy::{
3 consensus::SignableTransaction,
4 primitives::{Address, ChainId, B256},
5 signers::{
6 aws::{AwsSigner, AwsSignerError},
7 local::{LocalSignerError, PrivateKeySigner},
8 Signature,
9 },
10};
11use aws_config::{load_defaults, BehaviorVersion};
12use aws_sdk_kms::Client;
13use std::borrow::Cow;
14
15#[derive(FromEnv, Debug, Clone)]
26#[from_env(crate)]
27pub struct LocalOrAwsConfig {
28 #[from_env(var = "SIGNER_KEY", desc = "AWS KMS key ID or local private key")]
30 key_info: Cow<'static, str>,
31 #[from_env(var = "SIGNER_CHAIN_ID", desc = "Chain ID for AWS signer", optional)]
33 chain_id: Option<u64>,
34}
35
36impl LocalOrAwsConfig {
37 pub async fn connect_remote(&self) -> Result<LocalOrAws, SignerError> {
39 let signer = LocalOrAws::aws_signer(&self.key_info, self.chain_id).await?;
40 Ok(LocalOrAws::Aws(signer))
41 }
42
43 pub fn connect_local(&self) -> Result<LocalOrAws, SignerError> {
45 Ok(LocalOrAws::Local(LocalOrAws::wallet(&self.key_info)?))
46 }
47
48 pub async fn connect(&self) -> Result<LocalOrAws, SignerError> {
50 if let Ok(local) = self.connect_local() {
51 Ok(local)
52 } else {
53 self.connect_remote().await
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
60pub enum LocalOrAws {
61 Local(PrivateKeySigner),
63 Aws(AwsSigner),
65}
66
67#[derive(Debug, thiserror::Error)]
69pub enum SignerError {
70 #[error("failed to connect AWS signer: {0}")]
72 AwsSigner(#[from] Box<AwsSignerError>),
73 #[error("failed to load private key: {0}")]
75 Wallet(#[from] LocalSignerError),
76 #[error("failed to parse hex: {0}")]
78 Hex(#[from] alloy::hex::FromHexError),
79}
80
81impl From<AwsSignerError> for SignerError {
82 fn from(err: AwsSignerError) -> Self {
83 SignerError::AwsSigner(Box::new(err))
84 }
85}
86
87impl LocalOrAws {
88 pub async fn load(key: &str, chain_id: Option<u64>) -> Result<Self, SignerError> {
90 if let Ok(wallet) = LocalOrAws::wallet(key) {
91 Ok(LocalOrAws::Local(wallet))
92 } else {
93 let signer = LocalOrAws::aws_signer(key, chain_id).await?;
94 Ok(LocalOrAws::Aws(signer))
95 }
96 }
97
98 fn wallet(private_key: &str) -> Result<PrivateKeySigner, SignerError> {
104 let bytes = alloy::hex::decode(private_key.strip_prefix("0x").unwrap_or(private_key))?;
105 Ok(PrivateKeySigner::from_slice(&bytes).unwrap())
106 }
107
108 async fn aws_signer(key_id: &str, chain_id: Option<u64>) -> Result<AwsSigner, SignerError> {
110 let config = load_defaults(BehaviorVersion::latest()).await;
111 let client = Client::new(&config);
112 AwsSigner::new(client, key_id.to_string(), chain_id)
113 .await
114 .map_err(Into::into)
115 }
116}
117
118#[async_trait::async_trait]
119impl alloy::network::TxSigner<Signature> for LocalOrAws {
120 fn address(&self) -> Address {
121 match self {
122 LocalOrAws::Local(signer) => signer.address(),
123 LocalOrAws::Aws(signer) => signer.address(),
124 }
125 }
126
127 async fn sign_transaction(
128 &self,
129 tx: &mut dyn SignableTransaction<Signature>,
130 ) -> alloy::signers::Result<Signature> {
131 match self {
132 LocalOrAws::Local(signer) => signer.sign_transaction(tx).await,
133 LocalOrAws::Aws(signer) => signer.sign_transaction(tx).await,
134 }
135 }
136}
137
138#[async_trait::async_trait]
139impl alloy::signers::Signer<Signature> for LocalOrAws {
140 async fn sign_hash(&self, hash: &B256) -> alloy::signers::Result<Signature> {
142 match self {
143 LocalOrAws::Local(signer) => signer.sign_hash(hash).await,
144 LocalOrAws::Aws(signer) => signer.sign_hash(hash).await,
145 }
146 }
147
148 fn address(&self) -> Address {
150 match self {
151 LocalOrAws::Local(signer) => signer.address(),
152 LocalOrAws::Aws(signer) => signer.address(),
153 }
154 }
155
156 fn chain_id(&self) -> Option<ChainId> {
158 match self {
159 LocalOrAws::Local(signer) => signer.chain_id(),
160 LocalOrAws::Aws(signer) => signer.chain_id(),
161 }
162 }
163
164 fn set_chain_id(&mut self, chain_id: Option<ChainId>) {
166 match self {
167 LocalOrAws::Local(signer) => signer.set_chain_id(chain_id),
168 LocalOrAws::Aws(signer) => signer.set_chain_id(chain_id),
169 }
170 }
171}