use alloc::string::String;
use alloc::vec::Vec;
use zeroize::Zeroizing;
use crate::DeriveError;
/// An account produced by a [`Derive`] implementation.
///
/// Bundles the derivation path, the 32-byte private key (held in a
/// [`Zeroizing`] wrapper), the public key bytes and the address string.
#[derive(Clone)]
pub struct DerivedAccount {
    /// Derivation path this account was derived at.
    path: String,
    /// Raw 32-byte private key.
    private_key: Zeroizing<[u8; 32]>,
    /// Public key bytes (encoding is chosen by the deriving implementation).
    public_key: Vec<u8>,
    /// Address string for this account.
    address: String,
}

// `Debug` is written by hand instead of derived: a derived impl prints every
// field, including the raw private-key bytes, which would leak the secret into
// logs, `dbg!` output and panic messages (e.g. from `Result::unwrap`).
impl core::fmt::Debug for DerivedAccount {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("DerivedAccount")
            .field("path", &self.path)
            .field("private_key", &format_args!("<redacted>"))
            .field("public_key", &self.public_key)
            .field("address", &self.address)
            .finish()
    }
}
impl DerivedAccount {
    /// Creates an account from its already-derived components.
    ///
    /// The parts are stored exactly as given; no check is made that
    /// `private_key`, `public_key` and `address` belong together.
    #[inline]
    #[must_use]
    pub const fn new(
        path: String,
        private_key: Zeroizing<[u8; 32]>,
        public_key: Vec<u8>,
        address: String,
    ) -> Self {
        Self {
            path,
            private_key,
            public_key,
            address,
        }
    }

    /// Returns the derivation path this account was created with.
    #[inline]
    #[must_use]
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Returns the raw 32-byte private key, borrowed in its [`Zeroizing`]
    /// wrapper.
    #[inline]
    #[must_use]
    pub const fn private_key_bytes(&self) -> &Zeroizing<[u8; 32]> {
        &self.private_key
    }

    /// Returns the private key as a lower-case hex string (64 characters).
    ///
    /// The string is wrapped in [`Zeroizing`] so that callers get the same
    /// wrapper as for the raw key bytes.
    #[inline]
    #[must_use]
    pub fn private_key_hex(&self) -> Zeroizing<String> {
        // Encode from a borrowed slice. Passing `*self.private_key` would copy
        // the 32 secret bytes by value into a temporary that is never zeroized.
        Zeroizing::new(hex::encode(self.private_key.as_slice()))
    }

    /// Returns the public key bytes.
    #[inline]
    #[must_use]
    pub fn public_key_bytes(&self) -> &[u8] {
        &self.public_key
    }

    /// Returns the public key as a lower-case hex string.
    #[inline]
    #[must_use]
    pub fn public_key_hex(&self) -> String {
        hex::encode(&self.public_key)
    }

    /// Returns the address string for this account.
    #[inline]
    #[must_use]
    pub fn address(&self) -> &str {
        &self.address
    }
}
/// Calls `f` for each of the `count` consecutive indices starting at `start`
/// (in ascending order) and collects the results.
///
/// Evaluation stops at the first `Err` returned by `f`; that error is
/// returned and no further indices are visited. `count == 0` yields an empty
/// `Vec` without calling `f`.
///
/// # Errors
///
/// Returns `DeriveError::Input` (converted into `E`) when the last requested
/// index, `start + count - 1`, does not fit in a `u32`. Otherwise returns the
/// first error produced by `f`.
pub fn derive_range<T, E, F>(start: u32, count: u32, f: F) -> Result<Vec<T>, E>
where
    F: FnMut(u32) -> Result<T, E>,
    E: From<DeriveError>,
{
    // An empty request is always valid, even when `start == u32::MAX`.
    if count == 0 {
        return Ok(Vec::new());
    }
    // Bound the range by its *last* index rather than one-past-the-end, so that
    // a range ending exactly at `u32::MAX` (e.g. `start = u32::MAX, count = 1`)
    // is accepted instead of being reported as an overflow.
    let last = start.checked_add(count - 1).ok_or_else(|| {
        E::from(DeriveError::Input(String::from(
            "derive_many: start + count overflows u32",
        )))
    })?;
    (start..=last).map(f).collect()
}
pub trait Derive {
type Error: core::fmt::Debug + core::fmt::Display + From<DeriveError>;
fn derive(&self, index: u32) -> Result<DerivedAccount, Self::Error>;
fn derive_path(&self, path: &str) -> Result<DerivedAccount, Self::Error>;
}
/// Convenience methods layered on top of [`Derive`].
///
/// Implemented automatically for every `Derive` type, including trait objects
/// such as `dyn Derive<Error = E>`. Bring the trait into scope to call its
/// methods.
pub trait DeriveExt: Derive {
    /// Derives `count` consecutive accounts, starting at child index `start`.
    ///
    /// Indices are visited in ascending order, and the first failure aborts
    /// the whole batch.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Derive::derive`]. Returns an error
    /// converted from `DeriveError` if the requested index range does not fit
    /// in `u32`.
    #[inline]
    fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAccount>, Self::Error> {
        derive_range(start, count, |i| self.derive(i))
    }
}

// `?Sized` extends the blanket impl to unsized implementors, most notably
// `dyn Derive<Error = E>`; `derive_many` only ever needs `&self`.
impl<T: Derive + ?Sized> DeriveExt for T {}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed fixture values shared by every test below.
    const PATH: &str = "m/44'/60'/0'/0/0";
    const SECRET_HEX: &str = "1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727";
    const PUBLIC_HEX: &str = "0237b0bb7a8288d38ed49a524b5dc98cff3eb5ca824c9f9dc0dfdb3d9cd600f299";
    const ADDRESS: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    /// Builds the fixture account from the constants above.
    fn fixture() -> DerivedAccount {
        let mut secret = Zeroizing::new([0u8; 32]);
        hex::decode_to_slice(SECRET_HEX, secret.as_mut_slice()).unwrap();
        let public = hex::decode(PUBLIC_HEX).unwrap();
        DerivedAccount::new(String::from(PATH), secret, public, String::from(ADDRESS))
    }

    #[test]
    fn accessors_expose_all_fields() {
        let acct = fixture();

        assert_eq!(acct.path(), PATH);
        assert_eq!(acct.address(), ADDRESS);

        assert_eq!(acct.private_key_bytes().len(), 32);
        assert_eq!(acct.private_key_hex().as_str(), SECRET_HEX);

        assert_eq!(acct.public_key_bytes().len(), 33);
        assert_eq!(acct.public_key_hex(), PUBLIC_HEX);
    }

    #[test]
    fn private_key_hex_is_reversible() {
        let acct = fixture();
        let encoded = acct.private_key_hex();

        let mut roundtrip = [0u8; 32];
        hex::decode_to_slice(encoded.as_str(), &mut roundtrip).unwrap();

        assert_eq!(roundtrip, **acct.private_key_bytes());
    }
}