bitbox_api/
keypath.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::error::Error;
4
/// Offset added to a BIP32 child index to mark it as hardened (bit 31).
///
/// Path elements `>= HARDENED` are hardened; elements below it are normal (non-hardened).
pub const HARDENED: u32 = 1 << 31;
6
7#[derive(Debug, Clone, PartialEq)]
8pub struct Keypath(Vec<u32>);
9
10impl Keypath {
11    pub fn to_vec(&self) -> Vec<u32> {
12        self.0.clone()
13    }
14
15    pub(crate) fn hardened_prefix(&self) -> Keypath {
16        Keypath(
17            self.0
18                .iter()
19                .cloned()
20                .take_while(|&el| el >= HARDENED)
21                .collect(),
22        )
23    }
24}
25
26fn parse_bip32_keypath(keypath: &str) -> Option<Vec<u32>> {
27    let keypath = keypath.strip_prefix("m/")?;
28    if keypath.is_empty() {
29        return Some(vec![]);
30    }
31    let parts: Vec<&str> = keypath.split('/').collect();
32    let mut res = Vec::new();
33
34    for part in parts {
35        let mut add_prime = 0;
36        let number = if part.ends_with('\'') {
37            add_prime = HARDENED;
38            part[0..part.len() - 1].parse::<u32>()
39        } else {
40            part.parse::<u32>()
41        };
42
43        match number {
44            Ok(n) if n < HARDENED => {
45                res.push(n + add_prime);
46            }
47            _ => return None,
48        }
49    }
50
51    Some(res)
52}
53
54impl TryFrom<&str> for Keypath {
55    type Error = Error;
56    fn try_from(value: &str) -> Result<Self, Self::Error> {
57        Ok(Keypath(
58            parse_bip32_keypath(value).ok_or(Error::KeypathParse(value.into()))?,
59        ))
60    }
61}
62
63impl From<&bitcoin::bip32::DerivationPath> for Keypath {
64    fn from(value: &bitcoin::bip32::DerivationPath) -> Self {
65        Keypath(value.to_u32_vec())
66    }
67}
68
69impl From<&Keypath> for crate::pb::Keypath {
70    fn from(value: &Keypath) -> Self {
71        crate::pb::Keypath {
72            keypath: value.to_vec(),
73        }
74    }
75}
76
#[cfg(feature = "wasm")]
impl<'de> serde::Deserialize<'de> for Keypath {
    /// Deserializes a keypath from one of two forms:
    /// - a keypath string (e.g. `"m/84'/0'/0'"`), parsed via `Keypath::try_from`;
    /// - a sequence of raw `u32` path elements, taken as-is without range checks.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{SeqAccess, Visitor};

        struct KeypathVisitor;

        impl<'de> Visitor<'de> for KeypathVisitor {
            type Value = Keypath;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a string or a number sequence")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Keypath::try_from(value).map_err(E::custom)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut elements: Vec<u32> = Vec::new();
                while let Some(element) = seq.next_element::<u32>()? {
                    elements.push(element);
                }
                Ok(Keypath(elements))
            }
        }

        // `deserialize_any` lets a self-describing input choose `visit_str` or `visit_seq`.
        deserializer.deserialize_any(KeypathVisitor)
    }
}
114
#[cfg(feature = "wasm")]
/// Deserializes a [`Keypath`] (string or number sequence) and returns its raw path elements.
///
/// The signature matches what `#[serde(deserialize_with = "...")]` expects for a `Vec<u32>`
/// field.
pub fn serde_deserialize<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    <Keypath as serde::Deserialize<'de>>::deserialize(deserializer)
        .map(|keypath| keypath.to_vec())
}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_bip32_keypath() {
        // Largest non-hardened index, 2147483647.
        const MAX_INDEX: u32 = HARDENED - 1;

        let accepted: Vec<(&str, Vec<u32>)> = vec![
            // Regular cases.
            ("m/44/0/0/0", vec![44, 0, 0, 0]),
            (
                "m/44'/0'/0'/0'",
                vec![HARDENED + 44, HARDENED, HARDENED, HARDENED],
            ),
            // Edge cases.
            ("m/0/0/0", vec![0, 0, 0]),
            ("m/0'/0'/0'", vec![HARDENED, HARDENED, HARDENED]),
            (
                "m/2147483647/2147483647/2147483647",
                vec![MAX_INDEX, MAX_INDEX, MAX_INDEX],
            ),
            (
                "m/2147483647'/2147483647'/2147483647'",
                vec![HARDENED + MAX_INDEX; 3],
            ),
            ("m/", vec![]),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_bip32_keypath(input), Some(expected), "input: {}", input);
        }

        let rejected = [
            // Index out of range, with and without the hardened marker.
            "m/2147483648/0/0",
            "m/0/2147483648/0",
            "m/0/0/2147483648",
            "m/2147483648'/0/0",
            "m/0/2147483648'/0",
            "m/0/0/2147483648'",
            // Non-numeric elements.
            "m/abcd/0/0",
            "m/0'/abcd'/0'",
            "m/0/0'/abcd'",
            // Empty elements.
            "m//0/0",
            "m/0//0",
            "m/0/0//",
            // Missing `m/` prefix.
            "/0/0/0",
            "44/0/0/0",
        ];
        for &input in rejected.iter() {
            assert_eq!(parse_bip32_keypath(input), None, "input: {}", input);
        }
    }

    #[test]
    fn test_from_derivation_path() {
        let derivation_path = "m/84'/0'/0'/0/1"
            .parse::<bitcoin::bip32::DerivationPath>()
            .unwrap();
        let keypath = Keypath::from(&derivation_path);
        assert_eq!(
            keypath.to_vec(),
            vec![HARDENED + 84, HARDENED, HARDENED, 0, 1]
        );
    }
}
184}