use alloc::format;
use alloc::string::String;
use core::str::FromStr;
use ethers::{types::Address as Addr, utils::to_checksum};
use generic_array::{
sequence::Split,
typenum::{U12, U20},
GenericArray,
};
use serde::{Deserialize, Serialize};
use sha3::{digest::Update, Digest, Keccak256};
use umbral_pre::{serde_bytes, PublicKey};
/// A 20-byte account address in the Ethereum style (see `from_public_key`, which takes the
/// trailing 20 bytes of a Keccak-256 digest).
///
/// The raw bytes are serialized through `umbral_pre::serde_bytes::as_hex`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Address(#[serde(with = "serde_bytes::as_hex")] [u8; Address::SIZE]);
impl Address {
pub const SIZE: usize = 20;
pub fn new(bytes: &[u8; Self::SIZE]) -> Self {
Self(*bytes)
}
pub(crate) fn from_public_key(pk: &PublicKey) -> Self {
let pk_bytes = pk.to_uncompressed_bytes();
let digest = Keccak256::new().chain(&pk_bytes[1..]).finalize();
let (_prefix, address): (GenericArray<u8, U12>, GenericArray<u8, U20>) = digest.split();
Self(address.into())
}
pub fn to_checksum_address(&self) -> String {
to_checksum(&Addr::from(self.0), None)
}
}
impl FromStr for Address {
    type Err = String;

    /// Parses an address from a hex string, with or without a leading `0x`.
    ///
    /// # Errors
    ///
    /// Returns a message starting with `Invalid hex string` if the digits do not decode,
    /// or `Invalid address length` if they decode to anything other than 20 bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let hex_digits = match s.strip_prefix("0x") {
            Some(rest) => rest,
            None => s,
        };

        let decoded = match hex::decode(hex_digits) {
            Ok(bytes) => bytes,
            Err(err) => return Err(format!("Invalid hex string: {}", err)),
        };

        if decoded.len() == Self::SIZE {
            let mut raw = [0u8; Self::SIZE];
            raw.copy_from_slice(&decoded);
            Ok(Self::new(&raw))
        } else {
            Err(format!(
                "Invalid address length: expected {} bytes, got {} bytes",
                Self::SIZE,
                decoded.len()
            ))
        }
    }
}
impl AsRef<[u8]> for Address {
    /// Borrows the raw address bytes as a slice of length [`Address::SIZE`].
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use umbral_pre::SecretKey;

    /// A valid checksummed address used across several tests.
    const CHECKSUMMED: &str = "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed";

    /// Builds an `Address` from bare (unprefixed) hex that must encode exactly 20 bytes.
    fn address_from_hex(hex_str: &str) -> Address {
        let bytes = hex::decode(hex_str).expect("test vector must be valid hex");
        let mut array = [0u8; Address::SIZE];
        array.copy_from_slice(&bytes);
        Address::new(&array)
    }

    #[test]
    fn test_checksum_address() {
        // (raw hex, expected checksummed form)
        let cases = [
            (
                "5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
                "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
            ),
            (
                "fb6916095ca1df60bb79ce92ce3ea74c37c5d359",
                "0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359",
            ),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(address_from_hex(raw).to_checksum_address(), expected);
        }
    }

    #[test]
    fn test_from_str() {
        let address = Address::from_str(CHECKSUMMED).unwrap();
        assert_eq!(address.to_checksum_address(), CHECKSUMMED);
    }

    #[test]
    fn test_from_str_prefix_optional_and_case_insensitive() {
        let expected = Address::from_str(CHECKSUMMED).unwrap();
        assert_eq!(
            Address::from_str("5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed").unwrap(),
            expected
        );
        assert_eq!(
            Address::from_str("0x5aaeb6053f3e94c9b9a09f33669435e7ef1beaed").unwrap(),
            expected
        );
    }

    #[test]
    fn test_from_str_invalid_hex() {
        // Non-hex characters and an odd digit count both fail in hex decoding.
        for input in ["0x5aAzzzzzzz", "0xabc"].iter() {
            let err = Address::from_str(input).unwrap_err();
            assert!(err.contains("Invalid hex string"), "unexpected error: {}", err);
        }
    }

    #[test]
    fn test_from_str_invalid_length() {
        let too_short = Address::from_str("0x5aAeb6053F3E94C9b9A09f3366").unwrap_err();
        assert_eq!(
            too_short,
            "Invalid address length: expected 20 bytes, got 13 bytes"
        );

        let too_long =
            Address::from_str("0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed00").unwrap_err();
        assert_eq!(
            too_long,
            "Invalid address length: expected 20 bytes, got 21 bytes"
        );
    }

    #[test]
    fn test_from_str_empty() {
        // Empty input decodes to zero bytes, so it is rejected on length rather than as bad hex.
        for input in ["", "0x"].iter() {
            let err = Address::from_str(input).unwrap_err();
            assert_eq!(err, "Invalid address length: expected 20 bytes, got 0 bytes");
        }
    }

    #[test]
    fn test_as_ref() {
        let address = Address::from_str(CHECKSUMMED).unwrap();
        let bytes: &[u8] = address.as_ref();
        assert_eq!(bytes.len(), Address::SIZE);
        assert_eq!(bytes, &hex::decode(&CHECKSUMMED[2..]).unwrap()[..]);
    }

    #[test]
    fn test_from_public_key() {
        let public_key = SecretKey::random().public_key();
        let address = Address::from_public_key(&public_key);

        let checksummed = address.to_checksum_address();
        // "0x" followed by two hex digits per byte.
        assert_eq!(checksummed.len(), 2 + 2 * Address::SIZE);
        assert_eq!(Address::from_str(&checksummed).unwrap(), address);

        // Derivation is deterministic for a given key.
        assert_eq!(Address::from_public_key(&public_key), address);
    }
}