use crate::eth::EthError;
use hex;
use rand::{thread_rng, RngCore};
use serde::{Deserialize, Serialize};
use sha3::{Digest, Sha3_256};
use thiserror::Error;
use alloy::{
consensus::{SignableTransaction, TxEip1559, TxEnvelope},
network::eip2718::Encodable2718,
network::TxSignerSync,
primitives::TxKind,
signers::{local::PrivateKeySigner, SignerSync},
};
use alloy_primitives::{Address as EthAddress, B256, U256};
use std::str::FromStr;
/// Length in bytes of the random salt prepended to every encrypted blob and mixed into
/// password-based key derivation.
const SALT_SIZE: usize = 128 / 8;
/// Length in bytes of the random nonce stored after the salt and mixed into the keystream.
const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of the derived encryption key (one SHA3-256 digest).
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the authentication tag (a truncated SHA3-256 digest) appended to
/// every encrypted blob.
const TAG_SIZE: usize = 128 / 8;
/// Inputs needed to build and sign a transaction via [`Signer::sign_transaction`].
#[derive(Clone, Debug)]
pub struct TransactionData {
    /// Call target.
    pub to: EthAddress,
    /// Value attached to the call.
    pub value: U256,
    /// Optional calldata; `None` is treated as empty input by `LocalSigner`.
    pub data: Option<Vec<u8>>,
    /// Account nonce for the transaction.
    pub nonce: u64,
    /// Gas limit for the transaction.
    pub gas_limit: u64,
    /// `LocalSigner` uses this as the EIP-1559 `max_fee_per_gas`.
    pub gas_price: u128,
    /// Explicit `max_priority_fee_per_gas`; when `None`, `LocalSigner` derives a default
    /// from `gas_price`.
    pub max_priority_fee: Option<u128>,
    /// Target chain; `LocalSigner` rejects values that differ from its own chain ID.
    pub chain_id: u64,
}
#[derive(Debug, Error)]
pub enum SignerError {
#[error("failed to generate random bytes: {0}")]
RandomGenerationError(String),
#[error("invalid private key format: {0}")]
InvalidPrivateKey(String),
#[error("chain ID mismatch: expected {expected}, got {actual}")]
ChainIdMismatch { expected: u64, actual: u64 },
#[error("failed to sign transaction or message: {0}")]
SigningError(String),
#[error("ethereum error: {0}")]
EthError(#[from] EthError),
#[error("encryption error: {0}")]
EncryptionError(String),
#[error("decryption error: {0}")]
DecryptionError(String),
}
/// Password-encrypted private key together with the metadata needed to rebuild a signer.
///
/// Field order is part of the serialized format and must not be changed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EncryptedSignerData {
    /// Output of `encrypt_data`: `salt || nonce || ciphertext || tag`.
    pub encrypted_data: Vec<u8>,
    /// String form of the signer's address (stored in the clear, not covered by the tag).
    pub address: String,
    /// Chain ID the signer was bound to (stored in the clear, not covered by the tag).
    pub chain_id: u64,
}
/// Something that can sign Ethereum transactions, messages and raw digests.
pub trait Signer {
    /// Address of the signing account.
    fn address(&self) -> EthAddress;

    /// Chain ID this signer is bound to.
    fn chain_id(&self) -> u64;

    /// Signs the transaction described by `tx_data` and returns the encoded signed
    /// transaction bytes.
    fn sign_transaction(&self, tx_data: &TransactionData) -> Result<Vec<u8>, SignerError>;

    /// Signs an arbitrary message (any prefixing/hashing is implementation-defined) and
    /// returns the signature bytes.
    fn sign_message(&self, message: &[u8]) -> Result<Vec<u8>, SignerError>;

    /// Signs an already-computed digest and returns the signature bytes.
    fn sign_hash(&self, hash: &[u8]) -> Result<Vec<u8>, SignerError>;
}
/// A signer backed by an in-memory private key.
///
/// `private_key_hex` holds the raw key as `0x`-prefixed hex. Every clone and every
/// serialized form of this struct contains that secret.
#[derive(Clone)]
pub struct LocalSigner {
    /// Underlying alloy signer used for all signing operations.
    pub inner: PrivateKeySigner,
    /// Address derived from the private key.
    pub address: EthAddress,
    /// Chain ID this signer is bound to; transactions for other chains are rejected.
    pub chain_id: u64,
    /// `0x`-prefixed hex encoding of the private key.
    pub private_key_hex: String,
}

// Hand-written instead of derived so that `{:?}` (logs, panics, error chains) never
// prints the private key. `inner` is omitted because it wraps the same secret.
impl std::fmt::Debug for LocalSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalSigner")
            .field("address", &self.address)
            .field("chain_id", &self.chain_id)
            .field("private_key_hex", &"<redacted>")
            .finish_non_exhaustive()
    }
}
impl Serialize for LocalSigner {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("LocalSigner", 3)?;
state.serialize_field("address", &self.address)?;
state.serialize_field("chain_id", &self.chain_id)?;
state.serialize_field("private_key_hex", &self.private_key_hex)?;
state.end()
}
}
impl<'de> Deserialize<'de> for LocalSigner {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct LocalSignerData {
#[allow(dead_code)]
address: EthAddress,
chain_id: u64,
private_key_hex: String,
}
let data = LocalSignerData::deserialize(deserializer)?;
match LocalSigner::from_private_key(&data.private_key_hex, data.chain_id) {
Ok(signer) => Ok(signer),
Err(e) => Err(serde::de::Error::custom(format!(
"Failed to reconstruct signer: {}",
e
))),
}
}
}
impl LocalSigner {
pub fn new_random(chain_id: u64) -> Result<Self, SignerError> {
let mut rng = thread_rng();
let mut private_key_bytes = [0u8; 32];
rng.fill_bytes(&mut private_key_bytes);
let max_scalar =
hex::decode("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364140")
.map_err(|_| {
SignerError::RandomGenerationError("Failed to decode max scalar".to_string())
})?;
if private_key_bytes.as_slice().cmp(max_scalar.as_slice()) != std::cmp::Ordering::Less {
rng.fill_bytes(&mut private_key_bytes);
}
let key = B256::from_slice(&private_key_bytes);
let private_key_hex = format!("0x{}", hex::encode(private_key_bytes));
let inner = match PrivateKeySigner::from_bytes(&key) {
Ok(signer) => signer,
Err(e) => return Err(SignerError::InvalidPrivateKey(e.to_string())),
};
let address = inner.address();
Ok(Self {
inner,
address,
chain_id,
private_key_hex,
})
}
pub fn from_private_key(private_key: &str, chain_id: u64) -> Result<Self, SignerError> {
let clean_key = private_key.trim_start_matches("0x");
if clean_key.len() != 64 {
return Err(SignerError::InvalidPrivateKey(
"Private key must be 32 bytes (64 hex characters)".to_string(),
));
}
let key_bytes =
hex::decode(clean_key).map_err(|e| SignerError::InvalidPrivateKey(e.to_string()))?;
Self::from_bytes(&key_bytes, chain_id, format!("0x{}", clean_key))
}
fn from_bytes(
bytes: &[u8],
chain_id: u64,
private_key_hex: String,
) -> Result<Self, SignerError> {
if bytes.len() != 32 {
return Err(SignerError::InvalidPrivateKey(
"Private key must be exactly 32 bytes".to_string(),
));
}
let key = B256::from_slice(bytes);
let inner = match PrivateKeySigner::from_bytes(&key) {
Ok(wallet) => wallet,
Err(e) => return Err(SignerError::InvalidPrivateKey(e.to_string())),
};
let address = inner.address();
Ok(Self {
inner,
address,
chain_id,
private_key_hex,
})
}
pub fn encrypt(&self, password: &str) -> Result<EncryptedSignerData, SignerError> {
let clean_key = self.private_key_hex.trim_start_matches("0x");
let key_bytes =
hex::decode(clean_key).map_err(|e| SignerError::EncryptionError(e.to_string()))?;
let encrypted_data =
encrypt_data(&key_bytes, password).map_err(|e| SignerError::EncryptionError(e))?;
Ok(EncryptedSignerData {
encrypted_data,
address: self.address.to_string(),
chain_id: self.chain_id,
})
}
pub fn decrypt(encrypted: &EncryptedSignerData, password: &str) -> Result<Self, SignerError> {
let decrypted_bytes = decrypt_data(&encrypted.encrypted_data, password)
.map_err(|e| SignerError::DecryptionError(e))?;
let private_key_hex = format!("0x{}", hex::encode(&decrypted_bytes));
Self::from_bytes(&decrypted_bytes, encrypted.chain_id, private_key_hex)
}
pub fn export_private_key(&self) -> String {
self.private_key_hex.clone()
}
}
impl Signer for LocalSigner {
    fn address(&self) -> EthAddress {
        self.address
    }

    fn chain_id(&self) -> u64 {
        self.chain_id
    }

    /// Builds an EIP-1559 transaction from `tx_data`, signs it with the local key and
    /// returns its EIP-2718 encoding.
    ///
    /// `tx_data.gas_price` becomes `max_fee_per_gas`. If `tx_data.max_priority_fee` is
    /// `None`, the tip defaults to 1/5 of the max fee on chain 8453 and 1/10 elsewhere.
    ///
    /// # Errors
    /// * [`SignerError::ChainIdMismatch`] if `tx_data.chain_id` is not this signer's chain.
    /// * [`SignerError::SigningError`] if signing fails, or if the priority fee exceeds the
    ///   max fee. EIP-1559 requires `max_priority_fee_per_gas <= max_fee_per_gas`, so such
    ///   a transaction would be invalid.
    fn sign_transaction(&self, tx_data: &TransactionData) -> Result<Vec<u8>, SignerError> {
        if tx_data.chain_id != self.chain_id {
            return Err(SignerError::ChainIdMismatch {
                expected: self.chain_id,
                actual: tx_data.chain_id,
            });
        }

        let max_fee_per_gas = tx_data.gas_price;
        let max_priority_fee_per_gas = tx_data.max_priority_fee.unwrap_or_else(|| {
            if tx_data.chain_id == 8453 {
                max_fee_per_gas / 5
            } else {
                max_fee_per_gas / 10
            }
        });
        if max_priority_fee_per_gas > max_fee_per_gas {
            return Err(SignerError::SigningError(format!(
                "max priority fee per gas ({}) exceeds max fee per gas ({})",
                max_priority_fee_per_gas, max_fee_per_gas
            )));
        }

        let mut tx = TxEip1559 {
            chain_id: tx_data.chain_id,
            nonce: tx_data.nonce,
            to: TxKind::Call(tx_data.to),
            gas_limit: tx_data.gas_limit,
            max_fee_per_gas,
            max_priority_fee_per_gas,
            input: tx_data.data.clone().unwrap_or_default().into(),
            value: tx_data.value,
            ..Default::default()
        };
        let sig = self
            .inner
            .sign_transaction_sync(&mut tx)
            .map_err(|e| SignerError::SigningError(e.to_string()))?;
        let signed = TxEnvelope::from(tx.into_signed(sig));
        let mut buf = Vec::new();
        signed.encode_2718(&mut buf);
        Ok(buf)
    }

    /// Signs `message` with the EIP-191 personal-message scheme.
    ///
    /// The message is prefixed with `"\x19Ethereum Signed Message:\n"` plus its decimal
    /// byte length, and the result is hashed with Keccak-256. That digest is signed via
    /// [`Signer::sign_hash`].
    fn sign_message(&self, message: &[u8]) -> Result<Vec<u8>, SignerError> {
        let prefix = format!("\x19Ethereum Signed Message:\n{}", message.len());
        let prefixed_message = [prefix.as_bytes(), message].concat();
        let hash = sha3::Keccak256::digest(&prefixed_message);
        self.sign_hash(hash.as_slice())
    }

    /// Signs a raw 32-byte digest as-is (no prefixing or re-hashing).
    ///
    /// # Errors
    /// [`SignerError::SigningError`] if `hash` is not exactly 32 bytes or signing fails.
    fn sign_hash(&self, hash: &[u8]) -> Result<Vec<u8>, SignerError> {
        if hash.len() != 32 {
            return Err(SignerError::SigningError(
                "Hash must be exactly 32 bytes".to_string(),
            ));
        }
        let hash_bytes = B256::from_slice(hash);
        self.inner
            .sign_hash_sync(&hash_bytes)
            .map(|signature| signature.as_bytes().to_vec())
            .map_err(|e| SignerError::SigningError(e.to_string()))
    }
}
/// Encrypts `data` under a password and returns a self-contained blob.
///
/// Layout of the result: `salt (SALT_SIZE) || nonce (NONCE_SIZE) || ciphertext || tag
/// (TAG_SIZE)`. A fresh random salt and nonce are drawn for every call. The key comes
/// from `derive_key(password, salt)`; the tag comes from `compute_tag` over salt, nonce,
/// ciphertext and key.
///
/// This function never returns `Err`; the `Result` is kept for API compatibility.
///
/// NOTE(review): the cipher is the custom XOR construction in `encrypt_with_key`, not
/// a standard AEAD. See the note there before using this for data longer than a key.
pub fn encrypt_data(data: &[u8], password: &str) -> Result<Vec<u8>, String> {
    let mut salt = [0u8; SALT_SIZE];
    let mut nonce = [0u8; NONCE_SIZE];
    let mut rng = thread_rng();
    rng.fill_bytes(&mut salt);
    rng.fill_bytes(&mut nonce);

    let key = derive_key(password.as_bytes(), &salt);
    let ciphertext = encrypt_with_key(data, &key, &nonce);
    let tag = compute_tag(&salt, &nonce, &ciphertext, &key);

    let mut blob = Vec::with_capacity(SALT_SIZE + NONCE_SIZE + ciphertext.len() + tag.len());
    blob.extend_from_slice(&salt);
    blob.extend_from_slice(&nonce);
    blob.extend_from_slice(&ciphertext);
    blob.extend_from_slice(&tag);
    Ok(blob)
}
/// Decrypts a blob produced by [`encrypt_data`] using `password`.
///
/// Expects `salt (SALT_SIZE) || nonce (NONCE_SIZE) || ciphertext || tag (TAG_SIZE)`.
/// The key is re-derived from the password and salt, and the tag is recomputed and
/// verified before any plaintext is returned.
///
/// # Errors
/// * `"Encrypted data is too short"` if the blob cannot hold salt, nonce and tag.
/// * `"Decryption failed: Authentication tag mismatch"` on a wrong password or a
///   modified blob.
pub fn decrypt_data(encrypted_data: &[u8], password: &str) -> Result<Vec<u8>, String> {
    if encrypted_data.len() < SALT_SIZE + NONCE_SIZE + TAG_SIZE {
        return Err("Encrypted data is too short".into());
    }
    let (salt, rest) = encrypted_data.split_at(SALT_SIZE);
    let (nonce, rest) = rest.split_at(NONCE_SIZE);
    let (ciphertext, tag) = rest.split_at(rest.len() - TAG_SIZE);

    let key = derive_key(password.as_bytes(), salt);
    let expected_tag = compute_tag(salt, nonce, ciphertext, &key);

    // Constant-time comparison: a short-circuiting `!=` returns earlier the sooner a byte
    // differs. That timing difference would let an attacker who can submit blobs forge a
    // tag byte by byte.
    let diff = tag
        .iter()
        .zip(expected_tag.iter())
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    let tags_match = tag.len() == expected_tag.len() && std::hint::black_box(diff) == 0;
    if !tags_match {
        return Err("Decryption failed: Authentication tag mismatch".into());
    }

    Ok(decrypt_with_key(ciphertext, &key, nonce))
}
/// Derives a `KEY_SIZE`-byte key from `password` and `salt`.
///
/// The seed is `SHA3-256(salt || password)`. It is then re-hashed 10,000 times as
/// `key = SHA3-256(key || salt)`.
///
/// NOTE(review): this is an iterated hash, not a memory-hard KDF such as scrypt or
/// Argon2. Changing it would invalidate every existing ciphertext, so any upgrade needs a
/// versioned blob format.
fn derive_key(password: &[u8], salt: &[u8]) -> [u8; KEY_SIZE] {
    let mut seed_hasher = Sha3_256::new();
    for part in [salt, password] {
        seed_hasher.update(part);
    }
    let seed: [u8; KEY_SIZE] = seed_hasher.finalize().into();

    (0..10000).fold(seed, |previous, _| {
        let mut round = Sha3_256::new();
        round.update(previous);
        round.update(salt);
        round.finalize().into()
    })
}
/// XORs `data` with a keystream built from `key`, `nonce` and the byte position.
///
/// Keystream byte `i` is `key[i % KEY_SIZE] ^ nonce[i % nonce.len()] ^ (i as u8)`.
/// The transformation is its own inverse. Panics if `nonce` is empty and `data` is not.
///
/// NOTE(review): this is not a secure stream cipher for long inputs. 96 is a multiple
/// of both the 32-byte key and the 12-byte nonce, so positions `i` and `i + 96` share
/// key and nonce bytes and differ only by the public `(i as u8)` term. For inputs longer
/// than 96 bytes, `p[i] ^ p[i + 96]` can be recovered from the ciphertext alone.
/// Replacing the construction changes the on-disk format, so it needs a versioned blob
/// format.
fn encrypt_with_key(data: &[u8], key: &[u8; KEY_SIZE], nonce: &[u8]) -> Vec<u8> {
    data.iter()
        .enumerate()
        .map(|(pos, &plain)| {
            let pad = key[pos % key.len()] ^ nonce[pos % nonce.len()] ^ (pos as u8);
            plain ^ pad
        })
        .collect()
}
/// Inverts `encrypt_with_key`.
///
/// The keystream is XOR-ed onto the input, so applying the same transformation with the
/// same key and nonce restores the plaintext.
fn decrypt_with_key(ciphertext: &[u8], key: &[u8; KEY_SIZE], nonce: &[u8]) -> Vec<u8> {
    // XOR with a fixed keystream is an involution: decryption is encryption.
    encrypt_with_key(ciphertext, key, nonce)
}
/// Computes the authentication tag: the first `TAG_SIZE` bytes of
/// `SHA3-256(salt || nonce || data || key)`.
///
/// The parts are not length-prefixed. Unambiguous framing relies on callers passing a
/// fixed-size salt, nonce and key.
fn compute_tag(salt: &[u8], nonce: &[u8], data: &[u8], key: &[u8]) -> Vec<u8> {
    let mut mac = Sha3_256::new();
    for part in [salt, nonce, data, key] {
        mac.update(part);
    }
    let digest = mac.finalize();
    Vec::from(&digest[..TAG_SIZE])
}