wagyu_model/
derivation_path.rs1use std::{
2 fmt,
3 fmt::{Debug, Display},
4 str::FromStr,
5};
6
7pub trait DerivationPath: Clone + Debug + Display + FromStr + Send + Sync + 'static + Eq + Sized {
9 fn to_vec(&self) -> Result<Vec<ChildIndex>, DerivationPathError>;
11
12 fn from_vec(path: &Vec<ChildIndex>) -> Result<Self, DerivationPathError>;
14}
15
16#[derive(Debug, Fail, PartialEq, Eq)]
17pub enum DerivationPathError {
18 #[fail(display = "expected BIP32 path")]
19 ExpectedBIP32Path,
20
21 #[fail(display = "expected BIP44 path")]
22 ExpectedBIP44Path,
23
24 #[fail(display = "expected BIP49 path")]
25 ExpectedBIP49Path,
26
27 #[fail(display = "expected valid Ethereum derivation path")]
28 ExpectedValidEthereumDerivationPath,
29
30 #[fail(display = "expected ZIP32 path")]
31 ExpectedZIP32Path,
32
33 #[fail(display = "expected hardened path")]
34 ExpectedHardenedPath,
35
36 #[fail(display = "expected normal path")]
37 ExpectedNormalPath,
38
39 #[fail(display = "invalid child number: {}", _0)]
40 InvalidChildNumber(u32),
41
42 #[fail(display = "invalid child number format")]
43 InvalidChildNumberFormat,
44
45 #[fail(display = "invalid derivation path: {}", _0)]
46 InvalidDerivationPath(String),
47}
48
/// A single component of a hierarchical derivation path.
///
/// The wrapped `u32` is the index *without* the hardened flag; the encoded
/// form with bit 31 set is produced by the `u32` conversion / `to_index`.
// Variant order is significant: the derived `Ord` places every `Normal`
// before every `Hardened`, regardless of the wrapped value.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum ChildIndex {
    /// A non-hardened child; the `normal` constructor only accepts values below 2^31.
    Normal(u32),
    /// A hardened child; the `hardened` constructor only accepts values below 2^31.
    Hardened(u32),
}
57
58impl ChildIndex {
59 pub fn normal(index: u32) -> Result<Self, DerivationPathError> {
61 if index & (1 << 31) == 0 {
62 Ok(ChildIndex::Normal(index))
63 } else {
64 Err(DerivationPathError::InvalidChildNumber(index))
65 }
66 }
67
68 pub fn hardened(index: u32) -> Result<Self, DerivationPathError> {
70 if index & (1 << 31) == 0 {
71 Ok(ChildIndex::Hardened(index))
72 } else {
73 Err(DerivationPathError::InvalidChildNumber(index))
74 }
75 }
76
77 pub fn is_normal(&self) -> bool {
79 !self.is_hardened()
80 }
81
82 pub fn is_hardened(&self) -> bool {
84 match *self {
85 ChildIndex::Hardened(_) => true,
86 ChildIndex::Normal(_) => false,
87 }
88 }
89
90 pub fn to_index(&self) -> u32 {
92 match self {
93 &ChildIndex::Hardened(i) => i + (1 << 31),
94 &ChildIndex::Normal(i) => i,
95 }
96 }
97}
98
99impl From<u32> for ChildIndex {
100 fn from(number: u32) -> Self {
101 if number & (1 << 31) != 0 {
102 ChildIndex::Hardened(number ^ (1 << 31))
103 } else {
104 ChildIndex::Normal(number)
105 }
106 }
107}
108
109impl From<ChildIndex> for u32 {
110 fn from(index: ChildIndex) -> Self {
111 match index {
112 ChildIndex::Normal(number) => number,
113 ChildIndex::Hardened(number) => number | (1 << 31),
114 }
115 }
116}
117
118impl FromStr for ChildIndex {
119 type Err = DerivationPathError;
120
121 fn from_str(inp: &str) -> Result<Self, Self::Err> {
122 Ok(match inp.chars().last().map_or(false, |l| l == '\'' || l == 'h') {
123 true => Self::hardened(
124 inp[0..inp.len() - 1]
125 .parse()
126 .map_err(|_| DerivationPathError::InvalidChildNumberFormat)?,
127 )?,
128 false => Self::normal(inp.parse().map_err(|_| DerivationPathError::InvalidChildNumberFormat)?)?,
129 })
130 }
131}
132
133impl fmt::Display for ChildIndex {
134 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135 match *self {
136 ChildIndex::Hardened(number) => write!(f, "{}'", number),
137 ChildIndex::Normal(number) => write!(f, "{}", number),
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    mod child_index {
        use super::*;

        /// Bit 31, which marks a hardened index in the `u32` encoding.
        const HARDENED_BIT: u32 = 1 << 31;

        /// Stride used to sample each half of the `u32` range. Asserting on
        /// all 2^31 (or 2^32) values is far too slow for a unit test,
        /// especially in debug builds.
        const STRIDE: usize = 1 << 14;

        /// Values with bit 31 clear: an even sample of `0..2^31` plus both
        /// boundaries.
        fn in_range() -> Vec<u32> {
            let mut values: Vec<u32> = (0..HARDENED_BIT).step_by(STRIDE).collect();
            values.extend_from_slice(&[0, 1, HARDENED_BIT - 2, HARDENED_BIT - 1]);
            values
        }

        /// Values with bit 31 set: an even sample of `2^31..=u32::MAX` plus
        /// both boundaries. The inclusive range ensures `u32::MAX` is covered.
        fn out_of_range() -> Vec<u32> {
            let mut values: Vec<u32> =
                (HARDENED_BIT..=std::u32::MAX).step_by(STRIDE).collect();
            values.extend_from_slice(&[
                HARDENED_BIT,
                HARDENED_BIT + 1,
                std::u32::MAX - 1,
                std::u32::MAX,
            ]);
            values
        }

        #[test]
        fn normal() {
            for i in in_range() {
                assert_eq!(ChildIndex::Normal(i), ChildIndex::normal(i).unwrap());
            }
            for i in out_of_range() {
                assert_eq!(Err(DerivationPathError::InvalidChildNumber(i)), ChildIndex::normal(i));
            }
        }

        #[test]
        fn hardened() {
            for i in in_range() {
                assert_eq!(ChildIndex::Hardened(i), ChildIndex::hardened(i).unwrap());
            }
            for i in out_of_range() {
                assert_eq!(Err(DerivationPathError::InvalidChildNumber(i)), ChildIndex::hardened(i));
            }
        }

        #[test]
        fn is_normal() {
            for i in in_range() {
                assert!(ChildIndex::Normal(i).is_normal());
                assert!(!ChildIndex::Hardened(i).is_normal());
            }
        }

        #[test]
        fn is_hardened() {
            for i in in_range() {
                assert!(!ChildIndex::Normal(i).is_hardened());
                assert!(ChildIndex::Hardened(i).is_hardened());
            }
        }

        #[test]
        fn to_index() {
            for i in in_range() {
                assert_eq!(i, ChildIndex::Normal(i).to_index());
                assert_eq!(i | HARDENED_BIT, ChildIndex::Hardened(i).to_index());
            }
        }

        #[test]
        fn from() {
            for i in in_range() {
                assert_eq!(ChildIndex::Normal(i), ChildIndex::from(i));
            }
            for i in out_of_range() {
                assert_eq!(ChildIndex::Hardened(i ^ HARDENED_BIT), ChildIndex::from(i));
            }
        }

        #[test]
        fn u32_round_trip() {
            for i in in_range().into_iter().chain(out_of_range()) {
                assert_eq!(i, u32::from(ChildIndex::from(i)));
            }
        }

        #[test]
        fn from_str() {
            for i in in_range() {
                assert_eq!(ChildIndex::Normal(i), ChildIndex::from_str(&i.to_string()).unwrap());
                assert_eq!(
                    ChildIndex::Hardened(i),
                    ChildIndex::from_str(&format!("{}\'", i)).unwrap()
                );
                assert_eq!(
                    ChildIndex::Hardened(i),
                    ChildIndex::from_str(&format!("{}h", i)).unwrap()
                );
            }
        }

        #[test]
        fn from_str_rejects_invalid_input() {
            for input in &["", "'", "h", "abc", "-1", "1.5", "4294967296", "4294967296'"] {
                assert_eq!(
                    Err(DerivationPathError::InvalidChildNumberFormat),
                    ChildIndex::from_str(input)
                );
            }
            for i in out_of_range() {
                assert_eq!(
                    Err(DerivationPathError::InvalidChildNumber(i)),
                    ChildIndex::from_str(&i.to_string())
                );
                assert_eq!(
                    Err(DerivationPathError::InvalidChildNumber(i)),
                    ChildIndex::from_str(&format!("{}'", i))
                );
            }
        }

        #[test]
        fn to_string() {
            for i in in_range() {
                assert_eq!(i.to_string(), ChildIndex::Normal(i).to_string());
                assert_eq!(format!("{}\'", i), ChildIndex::Hardened(i).to_string());
            }
        }

        #[test]
        fn display_from_str_round_trip() {
            for i in in_range() {
                for index in &[ChildIndex::Normal(i), ChildIndex::Hardened(i)] {
                    assert_eq!(*index, ChildIndex::from_str(&index.to_string()).unwrap());
                }
            }
        }
    }
}