bitbox_api/
keypath.rs

1use crate::error::Error;
2
/// Offset that marks a BIP32 keypath element as hardened (bit 31, i.e. `0x8000_0000`).
///
/// A hardened element `i'` is stored as `i + HARDENED`; elements below this value are
/// non-hardened.
pub const HARDENED: u32 = 1 << 31;
4
/// A BIP32 keypath as its raw `u32` elements, with hardened elements carrying
/// the `HARDENED` offset.
///
/// `Eq` is derived alongside `PartialEq` because equality on `Vec<u32>` is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Keypath(Vec<u32>);
7
8impl Keypath {
9    pub fn to_vec(&self) -> Vec<u32> {
10        self.0.clone()
11    }
12
13    pub(crate) fn hardened_prefix(&self) -> Keypath {
14        Keypath(
15            self.0
16                .iter()
17                .cloned()
18                .take_while(|&el| el >= HARDENED)
19                .collect(),
20        )
21    }
22}
23
24fn parse_bip32_keypath(keypath: &str) -> Option<Vec<u32>> {
25    let keypath = keypath.strip_prefix("m/")?;
26    if keypath.is_empty() {
27        return Some(vec![]);
28    }
29    let parts: Vec<&str> = keypath.split('/').collect();
30    let mut res = Vec::new();
31
32    for part in parts {
33        let mut add_prime = 0;
34        let number = if part.ends_with('\'') {
35            add_prime = HARDENED;
36            part[0..part.len() - 1].parse::<u32>()
37        } else {
38            part.parse::<u32>()
39        };
40
41        match number {
42            Ok(n) if n < HARDENED => {
43                res.push(n + add_prime);
44            }
45            _ => return None,
46        }
47    }
48
49    Some(res)
50}
51
52impl TryFrom<&str> for Keypath {
53    type Error = Error;
54    fn try_from(value: &str) -> Result<Self, Self::Error> {
55        Ok(Keypath(
56            parse_bip32_keypath(value).ok_or(Error::KeypathParse(value.into()))?,
57        ))
58    }
59}
60
61impl From<&bitcoin::bip32::DerivationPath> for Keypath {
62    fn from(value: &bitcoin::bip32::DerivationPath) -> Self {
63        Keypath(value.to_u32_vec())
64    }
65}
66
67impl From<&Keypath> for crate::pb::Keypath {
68    fn from(value: &Keypath) -> Self {
69        crate::pb::Keypath {
70            keypath: value.to_vec(),
71        }
72    }
73}
74
#[cfg(feature = "wasm")]
impl<'de> serde::Deserialize<'de> for Keypath {
    /// Deserializes a keypath from either form:
    /// - a keypath string (e.g. `"m/84'/0'/0'"`), parsed via `TryFrom<&str>`;
    /// - a sequence of raw `u32` elements, taken as-is.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct StringOrSeqVisitor;

        impl<'de> serde::de::Visitor<'de> for StringOrSeqVisitor {
            type Value = Keypath;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("a string or a number sequence")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Keypath::try_from(value).map_err(E::custom)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut elements: Vec<u32> = Vec::new();
                loop {
                    match seq.next_element::<u32>()? {
                        Some(element) => elements.push(element),
                        None => return Ok(Keypath(elements)),
                    }
                }
            }
        }

        // The input may be either a string or a sequence, so let the input drive
        // which `visit_*` method is called.
        deserializer.deserialize_any(StringOrSeqVisitor)
    }
}
112
/// Deserializes a keypath (a keypath string or a `u32` sequence, as accepted by the
/// `Deserialize` impl for `Keypath`) directly into its raw elements.
///
/// The signature matches what serde's `deserialize_with` field attribute expects
/// for a `Vec<u32>` field.
#[cfg(feature = "wasm")]
pub fn serde_deserialize<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    // Move the elements out of the owned `Keypath` instead of cloning them via `to_vec()`.
    Keypath::deserialize(deserializer).map(|keypath| keypath.0)
}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_bip32_keypath() {
        // Test regular cases
        assert_eq!(parse_bip32_keypath("m/44/0/0/0"), Some(vec![44, 0, 0, 0]));
        assert_eq!(
            parse_bip32_keypath("m/44'/0'/0'/0'"),
            Some(vec![HARDENED + 44, HARDENED, HARDENED, HARDENED])
        );

        // Test edge cases
        assert_eq!(parse_bip32_keypath("m/0/0/0"), Some(vec![0, 0, 0]));
        assert_eq!(
            parse_bip32_keypath("m/0'/0'/0'"),
            Some(vec![HARDENED, HARDENED, HARDENED])
        );
        assert_eq!(
            parse_bip32_keypath("m/2147483647/2147483647/2147483647"),
            Some(vec![2147483647, 2147483647, 2147483647])
        );
        assert_eq!(
            parse_bip32_keypath("m/2147483647'/2147483647'/2147483647'"),
            Some(vec![
                HARDENED + 2147483647,
                HARDENED + 2147483647,
                HARDENED + 2147483647
            ])
        );
        assert_eq!(parse_bip32_keypath("m/"), Some(vec![]));

        // Test failure cases
        assert_eq!(parse_bip32_keypath("m/2147483648/0/0"), None);
        assert_eq!(parse_bip32_keypath("m/0/2147483648/0"), None);
        assert_eq!(parse_bip32_keypath("m/0/0/2147483648"), None);
        assert_eq!(parse_bip32_keypath("m/2147483648'/0/0"), None);
        assert_eq!(parse_bip32_keypath("m/0/2147483648'/0"), None);
        assert_eq!(parse_bip32_keypath("m/0/0/2147483648'"), None);
        assert_eq!(parse_bip32_keypath("m/abcd/0/0"), None);
        assert_eq!(parse_bip32_keypath("m/0'/abcd'/0'"), None);
        assert_eq!(parse_bip32_keypath("m/0/0'/abcd'"), None);
        assert_eq!(parse_bip32_keypath("m//0/0"), None);
        assert_eq!(parse_bip32_keypath("m/0//0"), None);
        assert_eq!(parse_bip32_keypath("m/0/0//"), None);
        assert_eq!(parse_bip32_keypath("/0/0/0"), None);
        assert_eq!(parse_bip32_keypath("44/0/0/0"), None);
    }

    #[test]
    fn test_from_derivation_path() {
        let derivation_path: bitcoin::bip32::DerivationPath =
            std::str::FromStr::from_str("m/84'/0'/0'/0/1").unwrap();
        let keypath = Keypath::from(&derivation_path);
        assert_eq!(
            keypath.to_vec().as_slice(),
            &[84 + HARDENED, HARDENED, HARDENED, 0, 1]
        );
    }

    #[test]
    fn test_hardened_prefix() {
        let keypath = Keypath(vec![84 + HARDENED, HARDENED, HARDENED, 0, 1]);
        assert_eq!(
            keypath.hardened_prefix(),
            Keypath(vec![84 + HARDENED, HARDENED, HARDENED])
        );
        // Stops at the first non-hardened element, even if hardened ones follow.
        assert_eq!(Keypath(vec![0, HARDENED]).hardened_prefix(), Keypath(vec![]));
        // A fully hardened path is returned unchanged.
        assert_eq!(
            Keypath(vec![HARDENED, HARDENED + 1]).hardened_prefix(),
            Keypath(vec![HARDENED, HARDENED + 1])
        );
        assert_eq!(Keypath(vec![]).hardened_prefix(), Keypath(vec![]));
    }

    #[test]
    fn test_try_from_str() {
        // `.ok()` avoids requiring `Debug` on the crate error type.
        assert_eq!(
            Keypath::try_from("m/84'/0'/0'/0/1").ok(),
            Some(Keypath(vec![84 + HARDENED, HARDENED, HARDENED, 0, 1]))
        );
        assert!(Keypath::try_from("84'/0'").is_err());
        assert!(Keypath::try_from("m/0//0").is_err());
    }
}