use alloc::string::String;
use alloc::vec::Vec;
use zeroize::Zeroizing;
use crate::DeriveError;
/// A public key produced by a derivation, tagged with its encoding.
///
/// Each variant stores the raw key bytes in a fixed-size array whose length
/// is determined by the encoding (see `DerivedPublicKey::byte_len`).
/// The byte contents are not validated by this type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DerivedPublicKey {
/// 33-byte compressed secp256k1 key (conventionally a 0x02/0x03 prefix plus x; prefix not checked here).
Secp256k1Compressed([u8; 33]),
/// 65-byte uncompressed secp256k1 key (conventionally a 0x04 prefix plus x and y; prefix not checked here).
Secp256k1Uncompressed([u8; 65]),
/// 32-byte Ed25519 public key.
Ed25519([u8; 32]),
/// 32-byte x-only secp256k1 key (x coordinate only, judging by the name — confirm with producers).
Secp256k1XOnly([u8; 32]),
}
impl DerivedPublicKey {
    /// Borrows the raw key bytes, whichever encoding this key uses.
    #[inline]
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8] {
        match self {
            Self::Ed25519(raw) | Self::Secp256k1XOnly(raw) => raw,
            Self::Secp256k1Uncompressed(raw) => raw,
            Self::Secp256k1Compressed(raw) => raw,
        }
    }

    /// Number of raw key bytes (33, 65 or 32 depending on the variant).
    ///
    /// Always equal to `self.as_bytes().len()`.
    #[inline]
    #[must_use]
    pub const fn byte_len(&self) -> usize {
        self.as_bytes().len()
    }

    /// Lowercase hex encoding of the raw key bytes.
    #[inline]
    #[must_use]
    pub fn to_hex(&self) -> String {
        let raw = self.as_bytes();
        hex::encode(raw)
    }

    /// The encoding of this key, without its bytes.
    #[inline]
    #[must_use]
    pub const fn kind(&self) -> PublicKeyKind {
        match *self {
            Self::Ed25519(_) => PublicKeyKind::Ed25519,
            Self::Secp256k1XOnly(_) => PublicKeyKind::Secp256k1XOnly,
            Self::Secp256k1Uncompressed(_) => PublicKeyKind::Secp256k1Uncompressed,
            Self::Secp256k1Compressed(_) => PublicKeyKind::Secp256k1Compressed,
        }
    }

    /// Builds a `Secp256k1Compressed` key from a byte slice.
    ///
    /// Only the length is checked; the bytes are copied as-is.
    ///
    /// # Errors
    ///
    /// Returns `DeriveError::Crypto` when `bytes` is not exactly 33 bytes long.
    pub fn compressed(bytes: &[u8]) -> Result<Self, DeriveError> {
        match <[u8; 33]>::try_from(bytes) {
            Ok(raw) => Ok(Self::Secp256k1Compressed(raw)),
            Err(_) => Err(Self::length_mismatch("compressed", 33, bytes.len())),
        }
    }

    /// Builds a `Secp256k1Uncompressed` key from a byte slice.
    ///
    /// Only the length is checked; the bytes are copied as-is.
    ///
    /// # Errors
    ///
    /// Returns `DeriveError::Crypto` when `bytes` is not exactly 65 bytes long.
    pub fn uncompressed(bytes: &[u8]) -> Result<Self, DeriveError> {
        match <[u8; 65]>::try_from(bytes) {
            Ok(raw) => Ok(Self::Secp256k1Uncompressed(raw)),
            Err(_) => Err(Self::length_mismatch("uncompressed", 65, bytes.len())),
        }
    }

    /// Shared error for the slice constructors: `form` names the secp256k1
    /// encoding, `expected` its required length, `actual` what was supplied.
    fn length_mismatch(form: &str, expected: usize, actual: usize) -> DeriveError {
        DeriveError::Crypto(alloc::format!(
            "{form} secp256k1 public key requires {expected} bytes, got {actual}"
        ))
    }
}
impl AsRef<[u8]> for DerivedPublicKey {
    /// Exposes the raw key bytes (same slice as `DerivedPublicKey::as_bytes`)
    /// to generic code that accepts `impl AsRef<[u8]>`.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        Self::as_bytes(self)
    }
}
/// The encoding of a `DerivedPublicKey`, without the key bytes.
///
/// Returned by `DerivedPublicKey::kind`; one variant per `DerivedPublicKey` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PublicKeyKind {
/// Compressed secp256k1 key, 33 bytes.
Secp256k1Compressed,
/// Uncompressed secp256k1 key, 65 bytes.
Secp256k1Uncompressed,
/// Ed25519 key, 32 bytes.
Ed25519,
/// x-only secp256k1 key, 32 bytes.
Secp256k1XOnly,
}
impl PublicKeyKind {
    /// Number of raw bytes a key of this kind occupies.
    #[inline]
    #[must_use]
    pub const fn byte_len(self) -> usize {
        match self {
            Self::Secp256k1Uncompressed => 65,
            Self::Secp256k1Compressed => 33,
            Self::Ed25519 => 32,
            Self::Secp256k1XOnly => 32,
        }
    }

    /// Lowercase, hyphen-separated label for this kind (also what `Display` prints).
    #[inline]
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Ed25519 => "ed25519",
            Self::Secp256k1XOnly => "secp256k1-xonly",
            Self::Secp256k1Uncompressed => "secp256k1-uncompressed",
            Self::Secp256k1Compressed => "secp256k1-compressed",
        }
    }
}
impl core::fmt::Display for PublicKeyKind {
    /// Writes the kind's label (`PublicKeyKind::as_str`).
    ///
    /// Uses `Formatter::pad` rather than `write_str` so width, fill, alignment
    /// and precision flags are honoured, e.g. `format!("{:>25}", kind)` when
    /// printing aligned tables.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(self.as_str())
    }
}
/// One derived account: its derivation path, private key, public key and
/// address string.
///
/// `Debug` is implemented by hand instead of derived. `Zeroizing` forwards
/// `Debug` to the wrapped array, so a derived impl would print the secret key
/// into logs or panic messages. The private key is shown as `<redacted>`.
#[derive(Clone)]
#[non_exhaustive]
pub struct DerivedAccount {
    /// Path string the account was derived at.
    path: String,
    /// 32-byte private key; `Zeroizing` wipes it when the value is dropped.
    private_key: Zeroizing<[u8; 32]>,
    /// Public key paired with `private_key` (not cross-checked by this type).
    public_key: DerivedPublicKey,
    /// Address string supplied by the deriver.
    address: String,
}

impl core::fmt::Debug for DerivedAccount {
    /// Formats every field except the private key, which is replaced by a
    /// fixed placeholder so secrets never reach `{:?}` output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("DerivedAccount")
            .field("path", &self.path)
            .field("private_key", &format_args!("<redacted>"))
            .field("public_key", &self.public_key)
            .field("address", &self.address)
            .finish()
    }
}
impl DerivedAccount {
    /// Bundles already-derived parts into an account.
    ///
    /// No validation is performed. The caller is responsible for `public_key`
    /// and `address` actually corresponding to `private_key`.
    #[inline]
    #[must_use]
    pub const fn new(
        path: String,
        private_key: Zeroizing<[u8; 32]>,
        public_key: DerivedPublicKey,
        address: String,
    ) -> Self {
        Self {
            path,
            private_key,
            public_key,
            address,
        }
    }

    /// The path string this account was created with.
    #[inline]
    #[must_use]
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Borrows the 32-byte private key, still wrapped in `Zeroizing`.
    #[inline]
    #[must_use]
    pub const fn private_key_bytes(&self) -> &Zeroizing<[u8; 32]> {
        &self.private_key
    }

    /// Lowercase hex encoding of the private key (64 characters), wrapped in
    /// `Zeroizing` so the returned string is wiped when dropped.
    #[inline]
    #[must_use]
    pub fn private_key_hex(&self) -> Zeroizing<String> {
        // Borrow the key as a slice. Writing `hex::encode(*self.private_key)`
        // would copy the array by value into `hex::encode`'s frame, and that
        // copy is never zeroized. Any scratch memory used inside `hex::encode`
        // itself remains outside our control.
        Zeroizing::new(hex::encode(self.private_key.as_slice()))
    }

    /// Borrows the account's public key.
    #[inline]
    #[must_use]
    pub const fn public_key(&self) -> &DerivedPublicKey {
        &self.public_key
    }

    /// Raw bytes of the public key (same as `self.public_key().as_bytes()`).
    #[inline]
    #[must_use]
    pub const fn public_key_bytes(&self) -> &[u8] {
        self.public_key.as_bytes()
    }

    /// Hex encoding of the public key (delegates to `DerivedPublicKey::to_hex`).
    #[inline]
    #[must_use]
    pub fn public_key_hex(&self) -> String {
        self.public_key.to_hex()
    }

    /// The address string supplied when the account was built.
    #[inline]
    #[must_use]
    pub fn address(&self) -> &str {
        &self.address
    }
}
impl AsRef<Self> for DerivedAccount {
#[inline]
fn as_ref(&self) -> &Self {
self
}
}
/// Calls `f` for each index `start, start + 1, …, start + count - 1` in
/// ascending order and collects the results.
///
/// The final index may be `u32::MAX`. For example, `start = u32::MAX`,
/// `count = 1` derives exactly that one index. `count == 0` yields an empty
/// vector and never calls `f`.
///
/// # Errors
///
/// - `DeriveError::Input` (converted into `E`) if the last index,
///   `start + count - 1`, does not fit in `u32`. `f` is not called in that case.
/// - Otherwise, the first error returned by `f`. Collection stops there and
///   later indices are not visited.
pub fn derive_range<T, E, F>(start: u32, count: u32, f: F) -> Result<Vec<T>, E>
where
    F: FnMut(u32) -> Result<T, E>,
    E: From<DeriveError>,
{
    if count == 0 {
        return Ok(Vec::new());
    }
    // Validate the *last* index rather than the exclusive end. `start + count`
    // legitimately equals `u32::MAX + 1` when the final index is `u32::MAX`.
    let last = start.checked_add(count - 1).ok_or_else(|| {
        E::from(DeriveError::Input(String::from(
            "derive_many: start + count overflows u32",
        )))
    })?;
    // Inclusive range so `u32::MAX` itself can be visited.
    (start..=last).map(f).collect()
}
/// A source of accounts, addressable either by numeric index or by a path string.
pub trait Derive {
/// Account type produced; must be borrowable as a `DerivedAccount`.
type Account: AsRef<DerivedAccount>;
/// Error type; must be printable and constructible from a crate-level `DeriveError`.
type Error: core::fmt::Debug + core::fmt::Display + From<DeriveError>;
/// Derives the account at `index`.
/// NOTE(review): how `index` maps to a path is defined by each implementation — confirm per impl.
fn derive(&self, index: u32) -> Result<Self::Account, Self::Error>;
/// Derives the account at an explicit `path` (path syntax is defined by the implementation).
fn derive_path(&self, path: &str) -> Result<Self::Account, Self::Error>;
}
/// Convenience methods available on every `Derive` implementation.
pub trait DeriveExt: Derive {
    /// Derives `count` consecutive accounts, starting at index `start`.
    ///
    /// Indices are visited in ascending order through `derive_range`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `Derive::derive`, or a
    /// `DeriveError` (converted into `Self::Error`) when the requested index
    /// range cannot be represented in `u32`.
    #[inline]
    fn derive_many(&self, start: u32, count: u32) -> Result<Vec<Self::Account>, Self::Error> {
        let derive_one = |index: u32| Self::derive(self, index);
        derive_range(start, count, derive_one)
    }
}

// Blanket impl: anything implementing `Derive` gets `derive_many` for free.
impl<T: Derive> DeriveExt for T {}
// Unit tests for the account / public-key types defined above.
#[cfg(test)]
mod tests {
use super::*;
// Builds a fixed secp256k1 account from hex-encoded key material; shared by several tests.
fn sample_account() -> DerivedAccount {
let mut sk = Zeroizing::new([0u8; 32]);
hex::decode_to_slice(
"1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727",
sk.as_mut_slice(),
)
.unwrap();
let mut pk = [0u8; 33];
hex::decode_to_slice(
"0237b0bb7a8288d38ed49a524b5dc98cff3eb5ca824c9f9dc0dfdb3d9cd600f299",
&mut pk,
)
.unwrap();
DerivedAccount::new(
String::from("m/44'/60'/0'/0/0"),
sk,
DerivedPublicKey::Secp256k1Compressed(pk),
String::from("0x9858EfFD232B4033E47d90003D41EC34EcaEda94"),
)
}
// Every accessor returns exactly what `sample_account` stored.
#[test]
fn accessors_expose_all_fields() {
let acct = sample_account();
assert_eq!(acct.path(), "m/44'/60'/0'/0/0");
assert_eq!(acct.private_key_bytes().len(), 32);
assert_eq!(
acct.private_key_hex().as_str(),
"1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727"
);
assert_eq!(acct.public_key().kind(), PublicKeyKind::Secp256k1Compressed);
assert_eq!(acct.public_key().byte_len(), 33);
assert_eq!(acct.public_key_bytes().len(), 33);
assert_eq!(
acct.public_key_hex(),
"0237b0bb7a8288d38ed49a524b5dc98cff3eb5ca824c9f9dc0dfdb3d9cd600f299"
);
assert_eq!(acct.address(), "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
}
// Decoding `private_key_hex` gives back the original key bytes.
#[test]
fn private_key_hex_is_reversible() {
let acct = sample_account();
let hex = acct.private_key_hex();
let mut decoded = [0u8; 32];
hex::decode_to_slice(hex.as_str(), &mut decoded).unwrap();
assert_eq!(&decoded, acct.private_key_bytes().as_ref());
}
// `compressed` accepts exactly 33 bytes and rejects 32 and 34.
#[test]
fn derived_public_key_compressed_constructor_validates_length() {
let ok = DerivedPublicKey::compressed(&[0x02; 33]).unwrap();
assert_eq!(ok.kind(), PublicKeyKind::Secp256k1Compressed);
assert!(DerivedPublicKey::compressed(&[0u8; 32]).is_err());
assert!(DerivedPublicKey::compressed(&[0u8; 34]).is_err());
}
// `uncompressed` accepts exactly 65 bytes and rejects 64 and 66.
#[test]
fn derived_public_key_uncompressed_constructor_validates_length() {
let ok = DerivedPublicKey::uncompressed(&[0x04; 65]).unwrap();
assert_eq!(ok.kind(), PublicKeyKind::Secp256k1Uncompressed);
assert!(DerivedPublicKey::uncompressed(&[0u8; 64]).is_err());
assert!(DerivedPublicKey::uncompressed(&[0u8; 66]).is_err());
}
// Instance `byte_len` agrees with the per-kind `byte_len` for the 32-byte variants.
#[test]
fn public_key_kind_length_round_trips() {
let ed = DerivedPublicKey::Ed25519([0u8; 32]);
assert_eq!(ed.byte_len(), PublicKeyKind::Ed25519.byte_len());
let xonly = DerivedPublicKey::Secp256k1XOnly([0u8; 32]);
assert_eq!(xonly.byte_len(), PublicKeyKind::Secp256k1XOnly.byte_len());
}
// `AsRef<DerivedAccount>` returns the very same object, not a copy.
#[test]
fn derived_account_as_ref_is_identity() {
let acct = sample_account();
let borrowed: &DerivedAccount = acct.as_ref();
assert!(core::ptr::eq(borrowed, &raw const acct));
}
}