1use alloc::format;
2use alloc::string::String;
3use core::str::FromStr;
4
5use ethers::{types::Address as Addr, utils::to_checksum};
6use generic_array::{
7 sequence::Split,
8 typenum::{U12, U20},
9 GenericArray,
10};
11use serde::{Deserialize, Serialize};
12use sha3::{digest::Update, Digest, Keccak256};
13use umbral_pre::{serde_bytes, PublicKey};
14
15#[derive(PartialEq, Debug, Serialize, Deserialize, Copy, Clone, PartialOrd, Eq, Ord)]
22pub struct Address(#[serde(with = "serde_bytes::as_hex")] [u8; Address::SIZE]);
23
24impl Address {
25 pub const SIZE: usize = 20;
27
28 pub fn new(bytes: &[u8; Self::SIZE]) -> Self {
30 Self(*bytes)
31 }
32
33 pub(crate) fn from_public_key(pk: &PublicKey) -> Self {
34 let pk_bytes = pk.to_uncompressed_bytes();
37 let digest = Keccak256::new().chain(&pk_bytes[1..]).finalize();
38
39 let (_prefix, address): (GenericArray<u8, U12>, GenericArray<u8, U20>) = digest.split();
40
41 Self(address.into())
42 }
43
44 pub fn to_checksum_address(&self) -> String {
46 to_checksum(&Addr::from(self.0), None)
47 }
48}
49
50impl FromStr for Address {
51 type Err = String;
52
53 fn from_str(s: &str) -> Result<Self, Self::Err> {
54 let s = s.strip_prefix("0x").unwrap_or(s);
55 let bytes = hex::decode(s).map_err(|e| format!("Invalid hex string: {}", e))?;
56 if bytes.len() != Self::SIZE {
57 return Err(format!(
58 "Invalid address length: expected {} bytes, got {} bytes",
59 Self::SIZE,
60 bytes.len()
61 ));
62 }
63 let mut array = [0u8; Self::SIZE];
64 array.copy_from_slice(&bytes);
65 Ok(Self(array))
66 }
67}
68
69impl AsRef<[u8]> for Address {
70 fn as_ref(&self) -> &[u8] {
71 &self.0
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use umbral_pre::SecretKey;

    /// Builds an [`Address`] from an unprefixed 40-character hex string,
    /// panicking if the input is malformed (test-only convenience).
    fn address_from_hex(hex_str: &str) -> Address {
        let decoded = hex::decode(hex_str).unwrap();
        let mut raw = [0u8; 20];
        raw.copy_from_slice(&decoded);
        Address::new(&raw)
    }

    #[test]
    fn test_checksum_address() {
        // (raw hex input, expected EIP-55 checksum output)
        let cases = [
            (
                "5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
                "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
            ),
            (
                "fb6916095ca1df60bb79ce92ce3ea74c37c5d359",
                "0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359",
            ),
        ];

        for &(raw_hex, expected) in &cases {
            assert_eq!(address_from_hex(raw_hex).to_checksum_address(), expected);
        }
    }

    #[test]
    fn test_from_str() {
        let text = "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed";
        let parsed: Address = text.parse().unwrap();
        assert_eq!(parsed.to_checksum_address(), text);
    }

    #[test]
    fn test_from_str_invalid_hex() {
        let err = Address::from_str("0x5aAzzzzzzz").unwrap_err();
        assert!(err.contains("Invalid hex string"));
    }

    #[test]
    fn test_from_str_invalid_length() {
        let err = Address::from_str("0x5aAeb6053F3E94C9b9A09f3366").unwrap_err();
        assert!(err.contains("Invalid address length"));
    }

    #[test]
    fn test_from_public_key() {
        let pk = SecretKey::random().public_key();
        let derived = Address::from_public_key(&pk);
        let checksum = derived.to_checksum_address();

        // Round-trip: parsing the checksum string must yield the same address.
        let reparsed = Address::from_str(&checksum).unwrap();
        assert_eq!(reparsed.to_checksum_address(), checksum);
    }
}