ergo_lib/wallet/
derivation_path.rs1use derive_more::From;
6use std::{collections::VecDeque, fmt, num::ParseIntError, str::FromStr};
7use thiserror::Error;
8
9#[derive(PartialEq, Eq, Clone, Copy, Debug)]
11pub struct ChildIndexHardened(u32);
12
13impl ChildIndexHardened {
14 pub fn from_31_bit(i: u32) -> Result<Self, ChildIndexError> {
16 if i & (1 << 31) == 0 {
17 Ok(ChildIndexHardened(i))
18 } else {
19 Err(ChildIndexError::NumberTooLarge(i))
20 }
21 }
22
23 pub fn next(&self) -> Result<Self, ChildIndexError> {
25 ChildIndexHardened::from_31_bit(self.0 + 1)
26 }
27}
28
29#[derive(PartialEq, Eq, Clone, Copy, Debug)]
31pub struct ChildIndexNormal(u32);
32
33impl ChildIndexNormal {
34 pub fn normal(i: u32) -> Result<Self, ChildIndexError> {
36 if i & (1 << 31) == 0 {
37 Ok(ChildIndexNormal(i))
38 } else {
39 Err(ChildIndexError::NumberTooLarge(i))
40 }
41 }
42
43 pub fn next(&self) -> ChildIndexNormal {
45 ChildIndexNormal(self.0 + 1)
46 }
47}
48
49#[derive(PartialEq, Eq, Clone, Copy, Debug, From)]
51pub enum ChildIndex {
52 Hardened(ChildIndexHardened),
54 Normal(ChildIndexNormal),
56}
57
58impl fmt::Display for ChildIndex {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 ChildIndex::Hardened(i) => write!(f, "{}'", i.0),
62 ChildIndex::Normal(i) => write!(f, "{}", i.0),
63 }
64 }
65}
66
67impl FromStr for ChildIndex {
68 type Err = ChildIndexError;
69
70 fn from_str(s: &str) -> Result<Self, Self::Err> {
71 if s.contains('\'') {
72 let idx = s.replace('\'', "");
73 Ok(ChildIndex::Hardened(ChildIndexHardened::from_31_bit(
74 idx.parse()?,
75 )?))
76 } else {
77 Ok(ChildIndex::Normal(ChildIndexNormal::normal(s.parse()?)?))
78 }
79 }
80}
81
82const PURPOSE: ChildIndex = ChildIndex::Hardened(ChildIndexHardened(44));
83const ERG: ChildIndex = ChildIndex::Hardened(ChildIndexHardened(429));
84const CHANGE: ChildIndex = ChildIndex::Normal(ChildIndexNormal(0));
86
87#[derive(Error, Debug, Clone, PartialEq, Eq)]
89pub enum ChildIndexError {
90 #[error("number too large: {0}")]
92 NumberTooLarge(u32),
93 #[error("failed to parse index: {0}")]
95 BadIndex(#[from] ParseIntError),
96}
97
98impl ChildIndex {
99 pub fn normal(i: u32) -> Result<Self, ChildIndexError> {
101 Ok(ChildIndex::Normal(ChildIndexNormal::normal(i)?))
102 }
103
104 pub fn hardened(i: u32) -> Result<Self, ChildIndexError> {
106 Ok(ChildIndex::Hardened(ChildIndexHardened::from_31_bit(i)?))
107 }
108
109 pub fn to_bits(&self) -> u32 {
112 match self {
113 ChildIndex::Hardened(index) => (1 << 31) | index.0,
114 ChildIndex::Normal(index) => index.0,
115 }
116 }
117
118 pub fn next(&self) -> Result<Self, ChildIndexError> {
120 match self {
121 ChildIndex::Hardened(i) => Ok(ChildIndex::Hardened(i.next()?)),
122 ChildIndex::Normal(i) => Ok(ChildIndex::Normal(i.next())),
123 }
124 }
125}
126
127#[derive(PartialEq, Eq, Debug, Clone, From)]
131pub struct DerivationPath(pub(super) Box<[ChildIndex]>);
132
133#[derive(Error, Debug, Clone, PartialEq, Eq)]
135pub enum DerivationPathError {
136 #[error("derivation path is empty")]
139 EmptyPath,
140 #[error("invalid derivation path format")]
143 InvalidFormat(String),
144 #[error("child error: {0}")]
146 ChildIndex(#[from] ChildIndexError),
147}
148
149impl DerivationPath {
150 pub fn new(acc: ChildIndexHardened, address_indices: Vec<ChildIndexNormal>) -> Self {
155 let mut res = vec![PURPOSE, ERG, ChildIndex::Hardened(acc), CHANGE];
156 res.append(
157 address_indices
158 .into_iter()
159 .map(ChildIndex::Normal)
160 .collect::<Vec<ChildIndex>>()
161 .as_mut(),
162 );
163 Self(res.into_boxed_slice())
164 }
165
166 pub fn master_path() -> Self {
168 Self(Box::new([]))
169 }
170
171 pub fn depth(&self) -> usize {
173 self.0.len()
174 }
175
176 pub fn extend(&self, index: ChildIndex) -> DerivationPath {
179 let mut res = self.0.to_vec();
180 res.push(index);
181 DerivationPath(res.into_boxed_slice())
182 }
183
184 pub fn next(&self) -> Result<DerivationPath, DerivationPathError> {
187 #[allow(clippy::unwrap_used)]
188 if self.0.len() > 0 {
189 let mut new_path = self.0.to_vec();
190 let last_idx = new_path.len() - 1;
191 new_path[last_idx] = new_path.last().unwrap().next()?;
193
194 Ok(DerivationPath(new_path.into_boxed_slice()))
195 } else {
196 Err(DerivationPathError::EmptyPath)
197 }
198 }
199
200 pub fn ledger_bytes(&self) -> Vec<u8> {
230 let mut res = vec![self.0.len() as u8];
231 self.0
232 .iter()
233 .for_each(|i| res.append(&mut i.to_bits().to_be_bytes().to_vec()));
234 res
235 }
236}
237
238impl fmt::Display for DerivationPath {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 write!(f, "m/")?;
241 let children = self
242 .0
243 .iter()
244 .map(ChildIndex::to_string)
245 .collect::<Vec<_>>()
246 .join("/");
247 write!(f, "{}", children)?;
248
249 Ok(())
250 }
251}
252
253impl FromStr for DerivationPath {
254 type Err = DerivationPathError;
255
256 fn from_str(s: &str) -> Result<Self, Self::Err> {
257 let cleaned_parts = s.split_whitespace().collect::<String>();
258 let mut parts = cleaned_parts.split('/').collect::<VecDeque<_>>();
259 let master_key_id = parts.pop_front().ok_or(DerivationPathError::EmptyPath)?;
260 if master_key_id != "m" && master_key_id != "M" {
261 return Err(DerivationPathError::InvalidFormat(format!(
262 "Master node must be either 'm' or 'M', got {}",
263 master_key_id
264 )));
265 }
266 let path = parts
267 .into_iter()
268 .flat_map(ChildIndex::from_str)
269 .collect::<Vec<_>>();
270 Ok(path.into_boxed_slice().into())
271 }
272}
273
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_derivation_path_to_string() {
        let path = DerivationPath::new(ChildIndexHardened(1), vec![ChildIndexNormal(3)]);
        let expected = "m/44'/429'/1'/0/3";

        assert_eq!(expected, path.to_string())
    }

    #[test]
    fn test_derivation_path_to_string_no_addr() {
        let path = DerivationPath::new(ChildIndexHardened(0), vec![]);
        let expected = "m/44'/429'/0'/0";

        assert_eq!(expected, path.to_string())
    }

    #[test]
    fn test_string_to_derivation_path() {
        let path = "m/44'/429'/0'/0/1";
        let expected = DerivationPath::new(ChildIndexHardened(0), vec![ChildIndexNormal(1)]);

        assert_eq!(expected, path.parse::<DerivationPath>().unwrap())
    }

    #[test]
    fn test_derivation_path_next() {
        let path = DerivationPath::new(ChildIndexHardened(1), vec![ChildIndexNormal(3)]);
        let new_path = path.next().unwrap();
        let expected = "m/44'/429'/1'/0/4";

        assert_eq!(expected, new_path.to_string());
    }

    #[test]
    fn test_derivation_path_next_returns_err_if_empty() {
        let path = DerivationPath(Box::new([]));

        assert_eq!(path.next(), Err(DerivationPathError::EmptyPath))
    }

    // The master path is displayed as "m/" and must parse back to itself.
    #[test]
    fn test_master_path_round_trip() {
        let master = DerivationPath::master_path();
        assert_eq!(master.depth(), 0);
        assert_eq!(master.to_string(), "m/");
        assert_eq!("m/".parse::<DerivationPath>().unwrap(), master);
        assert_eq!("m".parse::<DerivationPath>().unwrap(), master);
    }

    #[test]
    fn test_extend_appends_without_mutating() {
        let base = DerivationPath::new(ChildIndexHardened(0), vec![]);
        let extended = base.extend(ChildIndex::Normal(ChildIndexNormal(7)));
        assert_eq!(base.depth(), 4);
        assert_eq!(extended.depth(), 5);
        assert_eq!(extended.to_string(), "m/44'/429'/0'/0/7");
    }

    // Length byte, then each index as 4 big-endian bytes with bit 31 set when hardened.
    #[test]
    fn test_ledger_bytes() {
        let path = DerivationPath::new(ChildIndexHardened(0), vec![ChildIndexNormal(1)]);
        let expected = vec![
            5, // depth
            0x80, 0, 0, 44, // 44'
            0x80, 0, 0x01, 0xAD, // 429'
            0x80, 0, 0, 0, // 0'
            0, 0, 0, 0, // 0
            0, 0, 0, 1, // 1
        ];
        assert_eq!(path.ledger_bytes(), expected);
    }

    #[test]
    fn test_child_index_round_trip() {
        for s in &["0", "7", "44'", "2147483647'"] {
            assert_eq!(s.parse::<ChildIndex>().unwrap().to_string(), *s);
        }
    }

    #[test]
    fn test_child_index_rejects_31_bit_overflow() {
        assert_eq!(
            "2147483648".parse::<ChildIndex>(),
            Err(ChildIndexError::NumberTooLarge(2147483648))
        );
        assert_eq!(
            ChildIndex::hardened(1 << 31),
            Err(ChildIndexError::NumberTooLarge(1 << 31))
        );
    }
}