use bitcoin::bip32::{ChildNumber, DerivationPath};
use crate::consts::MAX_PATH_COMPONENTS;
use crate::error::{Error, Result};
/// Indicator byte announcing an explicitly encoded path.
///
/// In the encoder's output it is followed by a one-byte component count and
/// then every component as an unsigned LEB128 `u32`.
pub const EXPLICIT_PATH_INDICATOR: u8 = 0xFE;

/// Single-byte indicators for frequently used derivation paths, in lookup order.
///
/// Entries `0x01`–`0x07` have `0'` as their second component; entries
/// `0x11`–`0x17` repeat the same path shapes with `1'` in that position.
/// Any indicator byte that is neither listed here nor
/// [`EXPLICIT_PATH_INDICATOR`] is rejected when decoding.
pub const STANDARD_PATHS: &[(u8, &str)] = &[
    (0x01, "m/44'/0'/0'"),
    (0x02, "m/49'/0'/0'"),
    (0x03, "m/84'/0'/0'"),
    (0x04, "m/86'/0'/0'"),
    (0x05, "m/48'/0'/0'/2'"),
    (0x06, "m/48'/0'/0'/1'"),
    (0x07, "m/87'/0'/0'"),
    (0x11, "m/44'/1'/0'"),
    (0x12, "m/49'/1'/0'"),
    (0x13, "m/84'/1'/0'"),
    (0x14, "m/86'/1'/0'"),
    (0x15, "m/48'/1'/0'/2'"),
    (0x16, "m/48'/1'/0'/1'"),
    (0x17, "m/87'/1'/0'"),
];
/// Resolves a single-byte path indicator to the derivation path it stands for.
///
/// Looks `indicator` up in [`STANDARD_PATHS`] and parses the matching entry.
/// Returns `None` when no entry carries that byte (this includes
/// [`EXPLICIT_PATH_INDICATOR`], which is not in the table) or when the entry's
/// path string fails to parse.
pub fn lookup_indicator(indicator: u8) -> Option<DerivationPath> {
    let (_, text) = STANDARD_PATHS
        .iter()
        .copied()
        .find(|&(byte, _)| byte == indicator)?;
    text.parse::<DerivationPath>().ok()
}
/// Finds the indicator byte whose [`STANDARD_PATHS`] entry equals `path`.
///
/// Table entries are scanned in order and parsed on each call; an entry whose
/// string fails to parse is skipped. Returns the first matching indicator, or
/// `None` if `path` is not a standard path.
pub fn lookup_path(path: &DerivationPath) -> Option<u8> {
    for &(indicator, text) in STANDARD_PATHS {
        match text.parse::<DerivationPath>() {
            Ok(candidate) if candidate == *path => return Some(indicator),
            _ => {}
        }
    }
    None
}
/// Serialises `path` into its compact byte form.
///
/// If `path` equals an entry of [`STANDARD_PATHS`], the result is that entry's
/// single indicator byte. Otherwise the result is [`EXPLICIT_PATH_INDICATOR`],
/// then a one-byte component count, then each component's raw `u32` value
/// (`u32::from(ChildNumber)`; bit 31 marks a hardened component) as unsigned
/// LEB128.
///
/// Note: `decode_path` only accepts explicit paths with between 1 and
/// `MAX_PATH_COMPONENTS` components. An empty path (`m`) or a deeper path is
/// still encoded here but will not decode, so callers should validate depth
/// beforehand.
///
/// # Panics
///
/// Panics if `path` has more than 255 components. Such a count does not fit
/// in the one-byte count field, and truncating it would produce bytes that
/// silently describe a different, shorter path.
pub fn encode_path(path: &DerivationPath) -> Vec<u8> {
    if let Some(indicator) = lookup_path(path) {
        return vec![indicator];
    }
    let components = path.into_iter();
    let count = components.len();
    assert!(
        count <= usize::from(u8::MAX),
        "derivation path has {count} components; the one-byte count field holds at most 255"
    );
    // Worst case: indicator + count + 5 LEB128 bytes per component at the decode cap.
    let mut out = Vec::with_capacity(2 + 5 * MAX_PATH_COMPONENTS as usize);
    out.push(EXPLICIT_PATH_INDICATOR);
    // Lossless: bounded by u8::MAX just above.
    out.push(count as u8);
    for &cn in components {
        leb128_encode(u32::from(cn), &mut out);
    }
    out
}
pub fn decode_path(cursor: &mut &[u8]) -> Result<DerivationPath> {
let indicator = read_u8(cursor)?;
if indicator == EXPLICIT_PATH_INDICATOR {
return decode_explicit_path(cursor);
}
if let Some(path) = lookup_indicator(indicator) {
return Ok(path);
}
Err(Error::InvalidPathIndicator(indicator))
}
/// Decodes the body of an explicit path (the bytes after
/// [`EXPLICIT_PATH_INDICATOR`]) from `cursor`, advancing it past the bytes read.
///
/// The layout is one count byte followed by `count` unsigned-LEB128 `u32`
/// values. Each value becomes a `ChildNumber` via `ChildNumber::from(u32)`,
/// where bit 31 selects a hardened component and the low 31 bits are the index.
///
/// # Errors
///
/// - `Error::PathTooDeep` if the count is `0` or exceeds `MAX_PATH_COMPONENTS`
///   (the zero case reuses this variant).
/// - Any error from `leb128_decode_u32`, including `Error::UnexpectedEnd` when
///   the input is truncated.
fn decode_explicit_path(cursor: &mut &[u8]) -> Result<DerivationPath> {
    let count = read_u8(cursor)?;
    if count == 0 || count > MAX_PATH_COMPONENTS {
        return Err(Error::PathTooDeep(count));
    }
    // `ChildNumber::from(u32)` is infallible and treats bit 31 as the hardened
    // flag, mirroring `u32::from(ChildNumber)` on the encode side. `collect`
    // stops at the first decoding error, just like an explicit loop with `?`.
    let components = (0..count)
        .map(|_| leb128_decode_u32(cursor).map(ChildNumber::from))
        .collect::<Result<Vec<ChildNumber>>>()?;
    Ok(DerivationPath::from(components))
}
/// Appends `value` to `out` as unsigned LEB128.
///
/// Bytes are emitted least-significant 7-bit group first, with the high bit
/// set on every byte except the last. A `u32` therefore takes 1 to 5 bytes.
/// Existing contents of `out` are left untouched.
fn leb128_encode(mut value: u32, out: &mut Vec<u8>) {
    while value >= 0x80 {
        // More groups follow: emit the low 7 bits with the continuation flag.
        out.push(((value & 0x7F) as u8) | 0x80);
        value >>= 7;
    }
    // Final group fits in 7 bits, so the continuation flag stays clear.
    out.push(value as u8);
}
fn leb128_decode_u32(cursor: &mut &[u8]) -> Result<u32> {
let mut result: u64 = 0;
let mut shift: u32 = 0;
loop {
let byte = read_u8(cursor)?;
result |= ((byte & 0x7F) as u64) << shift;
if byte & 0x80 == 0 {
break;
}
shift += 7;
if shift >= 35 {
return Err(Error::InvalidPathComponent(format!(
"LEB128 overflow at shift {shift}"
)));
}
}
if result > u32::MAX as u64 {
return Err(Error::InvalidPathComponent(format!(
"LEB128 value {result} > u32::MAX"
)));
}
Ok(result as u32)
}
fn read_u8(cursor: &mut &[u8]) -> Result<u8> {
if cursor.is_empty() {
return Err(Error::UnexpectedEnd);
}
let b = cursor[0];
*cursor = &cursor[1..];
Ok(b)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    // Every table entry must encode to its single indicator byte and decode back.
    #[test]
    fn round_trip_all_standard_paths() {
        for (indicator, path_str) in STANDARD_PATHS {
            let path = DerivationPath::from_str(path_str).unwrap();
            let encoded = encode_path(&path);
            assert_eq!(encoded, vec![*indicator], "round-trip {path_str}");
            let mut cursor: &[u8] = &encoded;
            let decoded = decode_path(&mut cursor).unwrap();
            assert_eq!(decoded, path, "round-trip parsed {path_str}");
            assert!(cursor.is_empty());
        }
    }

    #[test]
    fn round_trip_explicit_path_simple() {
        let path = DerivationPath::from_str("m/0/1/2").unwrap();
        let encoded = encode_path(&path);
        assert_eq!(encoded[0], 0xFE);
        assert_eq!(encoded[1], 3);
        let mut cursor: &[u8] = &encoded;
        let decoded = decode_path(&mut cursor).unwrap();
        assert_eq!(decoded, path);
    }

    // Hardened components have bit 31 set, so each needs the full 5 LEB128 bytes.
    #[test]
    fn round_trip_explicit_path_all_hardened() {
        let path = DerivationPath::from_str("m/9999'/1234'/56'/7'").unwrap();
        let encoded = encode_path(&path);
        assert_eq!(encoded[0], 0xFE);
        assert_eq!(encoded[1], 4);
        assert_eq!(encoded.len(), 1 + 1 + 4 * 5);
        let mut cursor: &[u8] = &encoded;
        let decoded = decode_path(&mut cursor).unwrap();
        assert_eq!(decoded, path);
    }

    // Pins the exact bytes for a path mixing hardened and normal components.
    #[test]
    fn round_trip_explicit_path_mixed() {
        let path = DerivationPath::from_str("m/0'/1/2'/3").unwrap();
        let encoded = encode_path(&path);
        assert_eq!(
            encoded,
            vec![
                0xFE, 4, // explicit indicator, component count
                0x80, 0x80, 0x80, 0x80, 0x08, // 0'
                0x01, // 1
                0x82, 0x80, 0x80, 0x80, 0x08, // 2'
                0x03, // 3
            ]
        );
        let mut cursor: &[u8] = &encoded;
        let decoded = decode_path(&mut cursor).unwrap();
        assert_eq!(decoded, path);
        assert!(cursor.is_empty());
    }

    #[test]
    fn round_trip_explicit_path_at_cap() {
        let path = DerivationPath::from_str("m/0'/1'/2'/3'/4'/5'/6'/7'/8'/9'").unwrap();
        let encoded = encode_path(&path);
        let mut cursor: &[u8] = &encoded;
        let decoded = decode_path(&mut cursor).unwrap();
        assert_eq!(decoded, path);
    }

    #[test]
    fn rejects_path_too_deep() {
        let mut bytes = vec![0xFE, 11u8];
        for i in 0..11 {
            bytes.push(i);
        }
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::PathTooDeep(11)),
        ));
    }

    #[test]
    fn rejects_path_count_zero() {
        let bytes = vec![0xFE, 0u8];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::PathTooDeep(0)),
        ));
    }

    #[test]
    fn rejects_reserved_indicator_zero() {
        let bytes = vec![0x00];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::InvalidPathIndicator(0x00)),
        ));
    }

    #[test]
    fn round_trip_indicator_0x16_added_in_v0_2() {
        let path = DerivationPath::from_str("m/48'/1'/0'/1'").unwrap();
        let encoded = encode_path(&path);
        assert_eq!(encoded, vec![0x16]);
        let mut cursor: &[u8] = &encoded;
        let decoded = decode_path(&mut cursor).unwrap();
        assert_eq!(decoded, path);
        assert!(cursor.is_empty());
    }

    #[test]
    fn rejects_reserved_indicator_high_range() {
        let bytes = vec![0xFD];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::InvalidPathIndicator(0xFD)),
        ));
        let bytes = vec![0xFF];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::InvalidPathIndicator(0xFF)),
        ));
    }

    #[test]
    fn rejects_truncated_explicit_path() {
        let bytes = vec![0xFE, 2u8, 0u8];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            decode_path(&mut cursor),
            Err(Error::UnexpectedEnd),
        ));
    }

    #[test]
    fn leb128_encode_examples() {
        let mut out = Vec::new();
        leb128_encode(0, &mut out);
        assert_eq!(out, vec![0]);
        let mut out = Vec::new();
        leb128_encode(127, &mut out);
        assert_eq!(out, vec![0x7F]);
        let mut out = Vec::new();
        leb128_encode(128, &mut out);
        assert_eq!(out, vec![0x80, 0x01]);
        let mut out = Vec::new();
        leb128_encode(0x8000_0000, &mut out);
        assert_eq!(out, vec![0x80, 0x80, 0x80, 0x80, 0x08]);
    }

    // Encoder and decoder must agree at every 7-bit group boundary and at u32::MAX.
    #[test]
    fn leb128_round_trip_boundaries() {
        for value in [
            0,
            1,
            127,
            128,
            0x3FFF,
            0x4000,
            0x7FFF_FFFF,
            0x8000_0000,
            u32::MAX,
        ] {
            let mut out = Vec::new();
            leb128_encode(value, &mut out);
            let mut cursor: &[u8] = &out;
            assert_eq!(leb128_decode_u32(&mut cursor).unwrap(), value);
            assert!(cursor.is_empty());
        }
    }

    // A fifth byte that still carries the continuation bit is an overflow.
    #[test]
    fn leb128_decode_rejects_six_byte_encoding() {
        let bytes = [0x80_u8, 0x80, 0x80, 0x80, 0x80, 0x00];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            leb128_decode_u32(&mut cursor),
            Err(Error::InvalidPathComponent(_)),
        ));
    }

    // Five bytes whose final group sets bit 32 exceed u32::MAX.
    #[test]
    fn leb128_decode_rejects_value_above_u32_max() {
        let bytes = [0xFF_u8, 0xFF, 0xFF, 0xFF, 0x1F];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            leb128_decode_u32(&mut cursor),
            Err(Error::InvalidPathComponent(_)),
        ));
    }

    #[test]
    fn leb128_decode_rejects_truncated_value() {
        let bytes = [0x80_u8];
        let mut cursor: &[u8] = &bytes;
        assert!(matches!(
            leb128_decode_u32(&mut cursor),
            Err(Error::UnexpectedEnd),
        ));
    }
}