use derive_more::From;
use std::{collections::VecDeque, fmt, num::ParseIntError, str::FromStr};
use thiserror::Error;
/// A hardened child index, stored as its 31-bit value (the hardening bit is not stored).
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct ChildIndexHardened(u32);

impl ChildIndexHardened {
    /// Creates a hardened index from a value that must fit in 31 bits.
    ///
    /// # Errors
    /// Returns [`ChildIndexError::NumberTooLarge`] when bit 31 of `i` is set.
    pub fn from_31_bit(i: u32) -> Result<Self, ChildIndexError> {
        const HIGH_BIT: u32 = 1 << 31;
        if i & HIGH_BIT != 0 {
            return Err(ChildIndexError::NumberTooLarge(i));
        }
        Ok(Self(i))
    }

    /// Returns the hardened index immediately after this one.
    ///
    /// # Errors
    /// Returns [`ChildIndexError::NumberTooLarge`] when the successor no longer fits in 31 bits.
    pub fn next(&self) -> Result<Self, ChildIndexError> {
        let successor = self.0 + 1;
        Self::from_31_bit(successor)
    }
}
/// A normal (non-hardened) child index; valid values fit in 31 bits.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct ChildIndexNormal(u32);

impl ChildIndexNormal {
    /// Creates a normal index, rejecting any value with bit 31 set.
    ///
    /// # Errors
    /// Returns [`ChildIndexError::NumberTooLarge`] when `i` does not fit in 31 bits.
    pub fn normal(i: u32) -> Result<Self, ChildIndexError> {
        match i >> 31 {
            0 => Ok(Self(i)),
            _ => Err(ChildIndexError::NumberTooLarge(i)),
        }
    }

    /// Returns the index one greater than `self`.
    ///
    /// NOTE(review): unlike [`ChildIndexNormal::normal`], the result is not re-validated,
    /// so calling this on `2^31 - 1` produces a value with bit 31 set.
    pub fn next(&self) -> ChildIndexNormal {
        let ChildIndexNormal(current) = *self;
        ChildIndexNormal(current + 1)
    }
}
#[derive(PartialEq, Eq, Clone, Copy, Debug, From)]
pub enum ChildIndex {
Hardened(ChildIndexHardened),
Normal(ChildIndexNormal),
}
impl fmt::Display for ChildIndex {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ChildIndex::Hardened(i) => write!(f, "{}'", i.0),
ChildIndex::Normal(i) => write!(f, "{}", i.0),
}
}
}
impl FromStr for ChildIndex {
    type Err = ChildIndexError;

    /// Parses a single path component.
    ///
    /// `"N"` becomes the normal index `N` and `"N'"` the hardened index `N`. Only one
    /// trailing `'` marks hardening. A quote anywhere else (e.g. `"'5"`, `"1'2"`) or a
    /// doubled quote is rejected rather than stripped.
    ///
    /// # Errors
    /// * [`ChildIndexError::BadIndex`] if the numeric part is not a valid `u32`.
    /// * [`ChildIndexError::NumberTooLarge`] if the value does not fit in 31 bits.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.strip_suffix('\'') {
            Some(digits) => Ok(ChildIndex::Hardened(ChildIndexHardened::from_31_bit(
                digits.parse()?,
            )?)),
            None => Ok(ChildIndex::Normal(ChildIndexNormal::normal(s.parse()?)?)),
        }
    }
}
/// First level of every path built by `DerivationPath::new`: hardened `44'`
/// (presumably the BIP-44 purpose constant — confirm).
const PURPOSE: ChildIndex = ChildIndex::Hardened(ChildIndexHardened(44));
/// Second level of every path built by `DerivationPath::new`: hardened `429'` for ERG
/// (presumably the registered SLIP-44 coin type — confirm).
const ERG: ChildIndex = ChildIndex::Hardened(ChildIndexHardened(429));
/// Fourth level of every path built by `DerivationPath::new`: always the normal index `0`.
const CHANGE: ChildIndex = ChildIndex::Normal(ChildIndexNormal(0));
/// Errors produced when constructing, advancing or parsing a child index.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ChildIndexError {
/// The value has bit 31 set, so it does not fit in the 31-bit index range.
#[error("number too large: {0}")]
NumberTooLarge(u32),
/// The textual index could not be parsed as a `u32`.
#[error("failed to parse index: {0}")]
BadIndex(#[from] ParseIntError),
}
impl ChildIndex {
    /// Creates a normal (non-hardened) child index.
    ///
    /// # Errors
    /// [`ChildIndexError::NumberTooLarge`] if `i` has bit 31 set.
    pub fn normal(i: u32) -> Result<Self, ChildIndexError> {
        Ok(ChildIndex::Normal(ChildIndexNormal::normal(i)?))
    }

    /// Creates a hardened child index from its 31-bit value.
    ///
    /// # Errors
    /// [`ChildIndexError::NumberTooLarge`] if `i` has bit 31 set.
    pub fn hardened(i: u32) -> Result<Self, ChildIndexError> {
        Ok(ChildIndex::Hardened(ChildIndexHardened::from_31_bit(i)?))
    }

    /// Returns the 32-bit encoding of the index.
    ///
    /// Hardened indices have bit 31 set; normal indices are returned unchanged.
    pub fn to_bits(&self) -> u32 {
        match self {
            ChildIndex::Hardened(index) => (1 << 31) | index.0,
            ChildIndex::Normal(index) => index.0,
        }
    }

    /// Returns the next index of the same kind (hardened stays hardened, normal stays normal).
    ///
    /// # Errors
    /// [`ChildIndexError::NumberTooLarge`] if the successor would not fit in 31 bits.
    pub fn next(&self) -> Result<Self, ChildIndexError> {
        match self {
            ChildIndex::Hardened(i) => Ok(ChildIndex::Hardened(i.next()?)),
            // `ChildIndexNormal::next` does not re-validate, so going through the checked
            // constructor keeps `2^31 - 1` from rolling into the hardened bit range.
            ChildIndex::Normal(i) => Self::normal(i.0 + 1),
        }
    }
}
/// A sequence of child indices below the master node. The empty sequence is the
/// master path (see `DerivationPath::master_path`).
#[derive(PartialEq, Eq, Debug, Clone, From)]
pub struct DerivationPath(pub(super) Box<[ChildIndex]>);
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum DerivationPathError {
#[error("derivation path is empty")]
EmptyPath,
#[error("invalid derivation path format")]
InvalidFormat(String),
#[error("child error: {0}")]
ChildIndex(#[from] ChildIndexError),
}
impl DerivationPath {
    /// Builds the path `m/44'/429'/{acc}'/0/{address_indices...}`.
    ///
    /// The first four levels are fixed (`PURPOSE`, `ERG`, the hardened account `acc`,
    /// `CHANGE`). Each entry of `address_indices` is then appended, in order, as a normal index.
    pub fn new(acc: ChildIndexHardened, address_indices: Vec<ChildIndexNormal>) -> Self {
        let mut res = vec![PURPOSE, ERG, ChildIndex::Hardened(acc), CHANGE];
        res.extend(address_indices.into_iter().map(ChildIndex::Normal));
        Self(res.into_boxed_slice())
    }

    /// Returns the master path, which has no child indices.
    pub fn master_path() -> Self {
        Self(Box::new([]))
    }

    /// Number of child indices below the master node.
    pub fn depth(&self) -> usize {
        self.0.len()
    }

    /// Returns a new path with `index` appended; `self` is left unchanged.
    pub fn extend(&self, index: ChildIndex) -> DerivationPath {
        let mut res = Vec::with_capacity(self.0.len() + 1);
        res.extend_from_slice(&self.0);
        res.push(index);
        DerivationPath(res.into_boxed_slice())
    }

    /// Returns a copy of the path with its last child index advanced by one.
    ///
    /// # Errors
    /// * [`DerivationPathError::EmptyPath`] if the path has no child indices.
    /// * [`DerivationPathError::ChildIndex`] if the last index cannot be advanced.
    pub fn next(&self) -> Result<DerivationPath, DerivationPathError> {
        let mut new_path = self.0.to_vec();
        let last = new_path
            .last_mut()
            .ok_or(DerivationPathError::EmptyPath)?;
        *last = ChildIndex::next(last)?;
        Ok(DerivationPath(new_path.into_boxed_slice()))
    }

    /// Serializes the path as one length byte, followed by each index's 32-bit encoding
    /// (`ChildIndex::to_bits`) in big-endian order.
    ///
    /// NOTE(review): presumably the layout a Ledger device expects — confirm.
    pub fn ledger_bytes(&self) -> Vec<u8> {
        let mut res = Vec::with_capacity(1 + 4 * self.0.len());
        // NOTE(review): `as u8` silently truncates the length for paths deeper than 255 levels.
        res.push(self.0.len() as u8);
        for index in self.0.iter() {
            res.extend_from_slice(&index.to_bits().to_be_bytes());
        }
        res
    }
}
impl fmt::Display for DerivationPath {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "m/")?;
let children = self
.0
.iter()
.map(ChildIndex::to_string)
.collect::<Vec<_>>()
.join("/");
write!(f, "{}", children)?;
Ok(())
}
}
impl FromStr for DerivationPath {
    type Err = DerivationPathError;

    /// Parses a textual path such as `m/44'/429'/0'/0/1`.
    ///
    /// Parsing rules:
    /// * All whitespace is ignored.
    /// * The master node may be written `m` or `M`.
    /// * A single trailing `/` is tolerated, so the master path's `Display` output (`m/`)
    ///   parses back to the master path.
    ///
    /// # Errors
    /// * [`DerivationPathError::InvalidFormat`] if the first component is not `m`/`M`.
    /// * [`DerivationPathError::ChildIndex`] if any child component is not a valid index.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let cleaned_parts = s.split_whitespace().collect::<String>();
        let mut parts = cleaned_parts.split('/').collect::<VecDeque<_>>();
        // `split` always yields at least one item; this only unwraps the Option safely.
        let master_key_id = parts.pop_front().ok_or(DerivationPathError::EmptyPath)?;
        if master_key_id != "m" && master_key_id != "M" {
            return Err(DerivationPathError::InvalidFormat(format!(
                "Master node must be either 'm' or 'M', got {}",
                master_key_id
            )));
        }
        if matches!(parts.back(), Some(last) if last.is_empty()) {
            parts.pop_back();
        }
        // Collecting into `Result` stops at the first malformed component. Dropping it
        // silently would produce a different, valid-looking path.
        let path = parts
            .into_iter()
            .map(ChildIndex::from_str)
            .collect::<Result<Vec<_>, _>>()?;
        Ok(path.into_boxed_slice().into())
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_derivation_path_to_string() {
        let path = DerivationPath::new(ChildIndexHardened(1), vec![ChildIndexNormal(3)]);
        let expected = "m/44'/429'/1'/0/3";
        assert_eq!(expected, path.to_string())
    }

    #[test]
    fn test_derivation_path_to_string_no_addr() {
        let path = DerivationPath::new(ChildIndexHardened(0), vec![]);
        let expected = "m/44'/429'/0'/0";
        assert_eq!(expected, path.to_string())
    }

    #[test]
    fn test_string_to_derivation_path() {
        let path = "m/44'/429'/0'/0/1";
        let expected = DerivationPath::new(ChildIndexHardened(0), vec![ChildIndexNormal(1)]);
        assert_eq!(expected, path.parse::<DerivationPath>().unwrap())
    }

    #[test]
    fn test_derivation_path_next() {
        let path = DerivationPath::new(ChildIndexHardened(1), vec![ChildIndexNormal(3)]);
        let new_path = path.next().unwrap();
        let expected = "m/44'/429'/1'/0/4";
        assert_eq!(expected, new_path.to_string());
    }

    #[test]
    fn test_derivation_path_next_returns_err_if_empty() {
        let path = DerivationPath(Box::new([]));
        assert_eq!(path.next(), Err(DerivationPathError::EmptyPath))
    }

    #[test]
    fn test_derivation_path_next_keeps_hardened_last_index_hardened() {
        let path = DerivationPath::master_path().extend(ChildIndex::hardened(5).unwrap());
        assert_eq!("m/6'", path.next().unwrap().to_string());
    }

    #[test]
    fn test_master_path_round_trips_through_string() {
        let master = DerivationPath::master_path();
        assert_eq!("m/", master.to_string());
        assert_eq!(master, "m/".parse::<DerivationPath>().unwrap());
        assert_eq!(master, "m".parse::<DerivationPath>().unwrap());
    }

    #[test]
    fn test_string_to_derivation_path_accepts_uppercase_master_and_whitespace() {
        let expected = DerivationPath::new(ChildIndexHardened(0), vec![ChildIndexNormal(1)]);
        let parsed = "M / 44' / 429' / 0' / 0 / 1"
            .parse::<DerivationPath>()
            .unwrap();
        assert_eq!(expected, parsed);
    }

    #[test]
    fn test_string_to_derivation_path_rejects_bad_master() {
        assert!(matches!(
            "x/44'".parse::<DerivationPath>(),
            Err(DerivationPathError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_ledger_bytes() {
        let path = DerivationPath::new(ChildIndexHardened(0), vec![]);
        let expected: Vec<u8> = vec![
            4, 0x80, 0, 0, 0x2C, 0x80, 0, 0x01, 0xAD, 0x80, 0, 0, 0, 0, 0, 0, 0,
        ];
        assert_eq!(expected, path.ledger_bytes());
    }

    #[test]
    fn test_child_index_rejects_values_outside_31_bits() {
        assert_eq!(
            ChildIndex::hardened(1 << 31),
            Err(ChildIndexError::NumberTooLarge(1 << 31))
        );
        assert_eq!(
            ChildIndex::normal(1 << 31),
            Err(ChildIndexError::NumberTooLarge(1 << 31))
        );
    }
}