use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};
#[derive(Debug, thiserror::Error)]
pub enum HdError {
#[error("invalid master secret hex: {0}")]
InvalidHex(String),
}
pub type HdResult<T> = Result<T, HdError>;
/// Identifies which derivation chain a wallet secret belongs to.
///
/// The explicit discriminants (0 = `Receive`, 1 = `Pay`, 2 = `Change`, 3 = `Mining`) are the
/// numeric values returned by [`ChainCode::as_u64`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainCode {
    Receive = 0,
    Pay = 1,
    Change = 2,
    Mining = 3,
}

impl ChainCode {
    /// Every chain code, laid out so that `ALL[i]` is the code whose numeric value is `i`.
    pub const ALL: [Self; 4] = [Self::Receive, Self::Pay, Self::Change, Self::Mining];

    /// Numeric value of this chain code (its enum discriminant).
    pub const fn as_u64(self) -> u64 {
        self as u64
    }

    /// Inverse of [`ChainCode::as_u64`]: maps `0..=3` back to a chain code and returns
    /// `None` for any other value.
    pub const fn from_u64(n: u64) -> Option<Self> {
        let code = match n {
            0 => Self::Receive,
            1 => Self::Pay,
            2 => Self::Change,
            3 => Self::Mining,
            _ => return None,
        };
        Some(code)
    }

    /// Upper-case name of this chain code, e.g. `"MINING"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Receive => "RECEIVE",
            Self::Pay => "PAY",
            Self::Change => "CHANGE",
            Self::Mining => "MINING",
        }
    }
}
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct HdWallet {
master_secret: [u8; 32],
}
impl HdWallet {
pub fn from_master_secret(master_secret: [u8; 32]) -> Self {
Self { master_secret }
}
pub fn from_hex(hex_str: &str) -> HdResult<Self> {
let bytes =
hex::decode(hex_str.trim()).map_err(|e| HdError::InvalidHex(format!("decode: {e}")))?;
if bytes.len() != 32 {
return Err(HdError::InvalidHex(format!(
"expected 32 bytes, got {}",
bytes.len()
)));
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(Self::from_master_secret(arr))
}
pub fn new() -> HdResult<Self> {
let mut bytes = [0u8; 32];
getrandom::getrandom(&mut bytes).map_err(|e| HdError::InvalidHex(format!("rng: {e}")))?;
Ok(Self::from_master_secret(bytes))
}
pub fn master_secret(&self) -> &[u8; 32] {
&self.master_secret
}
pub fn master_secret_hex(&self) -> String {
hex::encode(self.master_secret)
}
pub fn derive_secret(&self, chain: ChainCode, depth: u64) -> String {
let tag = Sha256::digest(b"webcashwalletv1");
let mut h = Sha256::new();
h.update(tag);
h.update(tag);
h.update(self.master_secret);
h.update(chain.as_u64().to_be_bytes());
h.update(depth.to_be_bytes());
hex::encode(h.finalize())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived secrets are 64 lowercase hex chars, distinct across (chain, depth) pairs, and
    /// identical across repeated calls.
    #[test]
    fn derive_secret_is_well_formed_distinct_and_deterministic() {
        let hd = HdWallet::from_master_secret([0x42u8; 32]);
        let cases = [
            (ChainCode::Receive, 0u64),
            (ChainCode::Pay, 0),
            (ChainCode::Change, 0),
            (ChainCode::Mining, 0),
            (ChainCode::Mining, 7),
        ];
        let derived: Vec<String> = cases.iter().map(|&(c, d)| hd.derive_secret(c, d)).collect();
        for s in &derived {
            assert_eq!(s.len(), 64);
            assert!(s.chars().all(|ch| matches!(ch, '0'..='9' | 'a'..='f')), "{s}");
        }
        let unique: std::collections::HashSet<&String> = derived.iter().collect();
        assert_eq!(unique.len(), derived.len(), "collision among {cases:?}");
        for (&(c, d), expected) in cases.iter().zip(&derived) {
            assert_eq!(&hd.derive_secret(c, d), expected);
        }
    }

    #[test]
    fn derive_secret_depends_on_master_secret() {
        let a = HdWallet::from_master_secret([0x42u8; 32]);
        let b = HdWallet::from_master_secret([0x43u8; 32]);
        assert_ne!(
            a.derive_secret(ChainCode::Receive, 0),
            b.derive_secret(ChainCode::Receive, 0)
        );
    }

    /// Recomputes the derivation from one concatenated preimage hashed in a single shot, so
    /// it does not share the incremental-hasher code path used by `derive_secret`. The
    /// chain values are literals (legacy mapping) rather than `as_u64()`.
    #[test]
    fn cross_check_against_legacy_implementation() {
        let seed = [0x42u8; 32];
        let hd = HdWallet::from_master_secret(seed);
        // NOTE(review): the value
        //   8acd9c43cf36ec040ed16f4a86b86b4a3a98e3814de63b3d6cd5b8db83080acc
        // was recorded for (Mining, 0) with this seed but was never asserted. Confirm it
        // against the legacy implementation, then pin it here with `assert_eq!`.
        let tag = Sha256::digest(b"webcashwalletv1");
        for (chain, chain_value, depth) in
            [(ChainCode::Mining, 3u64, 0u64), (ChainCode::Receive, 0, 5), (ChainCode::Pay, 1, 1)]
        {
            let mut preimage = Vec::with_capacity(32 * 3 + 16);
            preimage.extend_from_slice(tag.as_slice());
            preimage.extend_from_slice(tag.as_slice());
            preimage.extend_from_slice(&seed);
            preimage.extend_from_slice(&chain_value.to_be_bytes());
            preimage.extend_from_slice(&depth.to_be_bytes());
            let manual = hex::encode(Sha256::digest(&preimage));
            assert_eq!(hd.derive_secret(chain, depth), manual);
        }
    }

    #[test]
    fn from_hex_roundtrip() {
        let seed = [0x99u8; 32];
        let hex_str = hex::encode(seed);
        let hd = HdWallet::from_hex(&hex_str).unwrap();
        assert_eq!(hd.master_secret_hex(), hex_str);
        assert_eq!(hd.master_secret(), &seed);
    }

    #[test]
    fn from_hex_trims_surrounding_whitespace() {
        let hex_str = hex::encode([0x99u8; 32]);
        let hd = HdWallet::from_hex(&format!("  {hex_str}\n")).unwrap();
        assert_eq!(hd.master_secret_hex(), hex_str);
    }

    #[test]
    fn from_hex_rejects_bad_input() {
        let is_invalid = |s: &str| matches!(HdWallet::from_hex(s), Err(HdError::InvalidHex(_)));
        assert!(is_invalid("nothex"));
        assert!(is_invalid("00"));
        assert!(is_invalid(""));
        assert!(is_invalid(&"00".repeat(31)));
        assert!(is_invalid(&"00".repeat(33)));
        assert!(is_invalid(&"0".repeat(63)));
    }

    #[test]
    fn new_generates_distinct_secrets() {
        let a = HdWallet::new().unwrap();
        let b = HdWallet::new().unwrap();
        assert_ne!(a.master_secret(), b.master_secret());
    }

    #[test]
    fn chain_code_all_is_complete_and_ordered() {
        assert_eq!(ChainCode::ALL.len(), 4);
        for (i, c) in ChainCode::ALL.iter().enumerate() {
            assert_eq!(c.as_u64() as usize, i);
            assert_eq!(ChainCode::from_u64(c.as_u64()), Some(*c));
        }
        assert_eq!(ChainCode::from_u64(4), None);
    }

    #[test]
    fn chain_code_names() {
        let names: Vec<&str> = ChainCode::ALL.iter().map(|c| c.as_str()).collect();
        assert_eq!(names, ["RECEIVE", "PAY", "CHANGE", "MINING"]);
    }
}