use alloc::string::String;
use alloc::vec::Vec;
use zeroize::Zeroizing;
use crate::DeriveError;
/// A single account produced by a [`Derive`] implementation.
///
/// Key material is stored hex-encoded. `private_key` is wrapped in
/// [`Zeroizing`] so its buffer is wiped when the value is dropped; cloning the
/// account produces another independently-zeroized copy of the secret.
///
/// `Debug` is implemented by hand (rather than derived) so that the private
/// key is never written to logs or panic messages: the derived impl would
/// forward to `Zeroizing`'s own `Debug`, which prints the wrapped string.
#[derive(Clone)]
#[non_exhaustive]
pub struct DerivedAccount {
    /// Derivation path this account was derived at (e.g. `m/44'/60'/0'/0/0`).
    pub path: String,
    /// Hex-encoded private key. Sensitive: zeroized on drop, redacted from `Debug`.
    pub private_key: Zeroizing<String>,
    /// Hex-encoded public key.
    pub public_key: String,
    /// Account address string as produced by the deriver.
    pub address: String,
}

impl core::fmt::Debug for DerivedAccount {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("DerivedAccount")
            .field("path", &self.path)
            // Never format the secret itself.
            .field("private_key", &"<redacted>")
            .field("public_key", &self.public_key)
            .field("address", &self.address)
            .finish()
    }
}
impl DerivedAccount {
    /// Bundles already-derived components into a `DerivedAccount`.
    ///
    /// No validation happens here; the hex fields are only checked when
    /// [`Self::private_key_bytes`] or [`Self::public_key_bytes`] is called.
    #[must_use]
    pub const fn new(
        path: String,
        private_key: Zeroizing<String>,
        public_key: String,
        address: String,
    ) -> Self {
        Self {
            address,
            public_key,
            private_key,
            path,
        }
    }

    /// Decodes the hex private key into a fixed 32-byte buffer.
    ///
    /// The buffer is wrapped in [`Zeroizing`] so the raw key bytes are wiped
    /// when the caller drops it.
    ///
    /// # Errors
    ///
    /// Returns [`DeriveError::InvalidHex`] if `private_key` is not valid hex or
    /// does not decode to exactly 32 bytes.
    pub fn private_key_bytes(&self) -> Result<Zeroizing<[u8; 32]>, DeriveError> {
        // Decode straight into the zeroizing buffer so no plain copy of the key
        // bytes is ever created.
        let mut key = Zeroizing::new([0u8; 32]);
        let outcome = hex::decode_to_slice(self.private_key.as_str(), key.as_mut_slice());
        outcome
            .map(|()| key)
            .map_err(|err| DeriveError::InvalidHex(alloc::format!("private_key: {err}")))
    }

    /// Decodes the hex public key into its raw bytes.
    ///
    /// Unlike [`Self::private_key_bytes`], no particular length is enforced.
    ///
    /// # Errors
    ///
    /// Returns [`DeriveError::InvalidHex`] if `public_key` is not valid hex.
    pub fn public_key_bytes(&self) -> Result<Vec<u8>, DeriveError> {
        let raw = hex::decode(&self.public_key)
            .map_err(|err| DeriveError::InvalidHex(alloc::format!("public_key: {err}")))?;
        Ok(raw)
    }
}
pub trait Derive {
type Error: core::fmt::Debug + core::fmt::Display + From<DeriveError>;
fn derive(&self, index: u32) -> Result<DerivedAccount, Self::Error>;
fn derive_path(&self, path: &str) -> Result<DerivedAccount, Self::Error>;
}
/// Convenience methods available on every [`Derive`] implementation.
pub trait DeriveExt: Derive {
    /// Derives `count` consecutive accounts starting at index `start`.
    ///
    /// Accounts are returned in index order `start, start + 1, ...,
    /// start + count - 1`. A `count` of zero yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns `DeriveError::IndexOverflow` (converted into `Self::Error`)
    /// before deriving anything if the last requested index would exceed
    /// `u32::MAX`. Otherwise stops at, and returns, the first error produced
    /// by [`Derive::derive`].
    fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAccount>, Self::Error> {
        // Nothing to derive; also keeps `count - 1` below from underflowing.
        if count == 0 {
            return Ok(Vec::new());
        }
        // Validate the *last* index rather than the exclusive end, so a range
        // ending exactly at `u32::MAX` (itself a valid index) is accepted.
        let last = start
            .checked_add(count - 1)
            .ok_or(DeriveError::IndexOverflow)?;
        (start..=last).map(|i| self.derive(i)).collect()
    }
}

/// Every [`Derive`] implementation gets the [`DeriveExt`] helpers for free.
impl<T: Derive> DeriveExt for T {}
#[cfg(test)]
mod tests {
    use super::*;

    const PATH: &str = "m/44'/60'/0'/0/0";
    const PRIVATE_HEX: &str = "1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727";
    const PUBLIC_HEX: &str = "0237b0bb7a8288d38ed49a524b5dc98cff3eb5ca824c9f9dc0dfdb3d9cd600f299";
    const ADDRESS: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    /// Builds an account from string slices, wrapping the private key in `Zeroizing`.
    fn account(path: &str, private_hex: &str, public_hex: &str, address: &str) -> DerivedAccount {
        DerivedAccount::new(
            String::from(path),
            Zeroizing::new(String::from(private_hex)),
            String::from(public_hex),
            String::from(address),
        )
    }

    /// A well-formed account with a 32-byte private key and 33-byte public key.
    fn sample_account() -> DerivedAccount {
        account(PATH, PRIVATE_HEX, PUBLIC_HEX, ADDRESS)
    }

    #[test]
    fn private_key_bytes_roundtrip() {
        let acct = sample_account();
        let decoded = acct.private_key_bytes().unwrap();
        assert_eq!(decoded.len(), 32);
        assert_eq!(hex::encode(*decoded), acct.private_key.as_str());
    }

    #[test]
    fn public_key_bytes_roundtrip() {
        let acct = sample_account();
        let decoded = acct.public_key_bytes().unwrap();
        assert_eq!(decoded.len(), 33);
        assert_eq!(hex::encode(&decoded), acct.public_key);
    }

    #[test]
    fn private_key_bytes_rejects_short_hex() {
        // Valid hex, but only 4 bytes instead of the required 32.
        let bad = account("m/0", "deadbeef", "", "");
        let result = bad.private_key_bytes();
        assert!(matches!(result, Err(DeriveError::InvalidHex(_))));
    }

    #[test]
    fn public_key_bytes_rejects_non_hex() {
        let bad = account("m/0", "", "not-hex!", "");
        let result = bad.public_key_bytes();
        assert!(matches!(result, Err(DeriveError::InvalidHex(_))));
    }
}