foundation_urtypes/registry/
keypath.rs

1// SPDX-FileCopyrightText: © 2023 Foundation Devices, Inc. <hello@foundationdevices.com>
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6use core::{num::NonZeroU32, ops::Range};
7
8use minicbor::{data::Type, decode::Error, encode::Write, Decode, Decoder, Encode, Encoder};
9
/// Metadata for the complete or partial derivation path of a key
/// (non-owned, zero copy).
///
/// The path components are borrowed from their backing storage (see
/// [`PathComponents`]) instead of being collected into an owned container.
#[doc(alias("crypto-keypath"))]
#[derive(Debug, Clone, PartialEq)]
pub struct KeypathRef<'a> {
    /// Path components.
    pub components: PathComponents<'a>,
    /// Fingerprint from the ancestor key.
    ///
    /// `None` when the fingerprint is not present; a present fingerprint is
    /// never zero.
    pub source_fingerprint: Option<NonZeroU32>,
    /// How many derivations this key is from the master (which is 0).
    ///
    /// `None` when the depth is not present.
    pub depth: Option<u8>,
}
22
23impl<'a> KeypathRef<'a> {
24    /// Create a new key path for a master extended public key.
25    ///
26    /// The `source_fingerprint` parameter is the fingerprint of the master key.
27    pub fn new_master(source_fingerprint: NonZeroU32) -> Self {
28        Self {
29            components: PathComponents {
30                storage: PathStorage::RawDerivationPath(&[]),
31            },
32            source_fingerprint: Some(source_fingerprint),
33            depth: Some(0),
34        }
35    }
36}
37
38impl<'b, C> Decode<'b, C> for KeypathRef<'b> {
39    fn decode(d: &mut Decoder<'b>, ctx: &mut C) -> Result<Self, Error> {
40        let mut components = None;
41        let mut source_fingerprint = None;
42        let mut depth = None;
43
44        let mut len = d.map()?;
45        loop {
46            match len {
47                Some(0) => break,
48                Some(n) => len = Some(n - 1),
49                None => {
50                    if d.datatype()? == Type::Break {
51                        break;
52                    }
53                }
54            }
55
56            match d.u32()? {
57                1 => components = Some(PathComponents::decode(d, ctx)?),
58                2 => {
59                    source_fingerprint = Some(
60                        NonZeroU32::new(d.u32()?)
61                            .ok_or_else(|| Error::message("source-fingerprint is zero"))?,
62                    )
63                }
64                3 => depth = Some(d.u8()?),
65                _ => return Err(Error::message("unknown map entry")),
66            }
67        }
68
69        Ok(Self {
70            components: components.ok_or_else(|| Error::message("components is missing"))?,
71            source_fingerprint,
72            depth,
73        })
74    }
75}
76
77impl<'a, C> Encode<C> for KeypathRef<'a> {
78    fn encode<W: Write>(
79        &self,
80        e: &mut Encoder<W>,
81        ctx: &mut C,
82    ) -> Result<(), minicbor::encode::Error<W::Error>> {
83        let len =
84            1 + u64::from(self.source_fingerprint.is_some()) + u64::from(self.depth.is_some());
85        e.map(len)?;
86
87        e.u8(1)?;
88        self.components.encode(e, ctx)?;
89
90        if let Some(source_fingerprint) = self.source_fingerprint {
91            e.u8(2)?.u32(source_fingerprint.get())?;
92        }
93
94        if let Some(depth) = self.depth {
95            e.u8(3)?.u8(depth)?;
96        }
97
98        Ok(())
99    }
100}
101
#[cfg(feature = "bitcoin")]
impl<'a> From<&'a bitcoin::bip32::DerivationPath> for KeypathRef<'a> {
    /// Borrow a `bitcoin` derivation path as a key path.
    ///
    /// The source fingerprint and the depth are left unset (`None`).
    fn from(derivation_path: &'a bitcoin::bip32::DerivationPath) -> Self {
        let storage = PathStorage::DerivationPath(derivation_path.as_ref());

        Self {
            components: PathComponents { storage },
            source_fingerprint: None,
            depth: None,
        }
    }
}
114
/// Metadata for the complete or partial derivation path of a key.
///
/// Owned counterpart of [`KeypathRef`]: the path components are stored in a
/// `Vec`, so this type is only available with the `alloc` feature.
#[doc(alias("crypto-keypath"))]
#[cfg(feature = "alloc")]
#[derive(Debug, Clone, PartialEq)]
pub struct Keypath {
    /// Path components.
    pub components: Vec<PathComponent>,
    /// Fingerprint from the ancestor key.
    ///
    /// `None` when the fingerprint is not present; a present fingerprint is
    /// never zero.
    pub source_fingerprint: Option<NonZeroU32>,
    /// How many derivations this key is from the master (which is 0).
    ///
    /// `None` when the depth is not present.
    pub depth: Option<u8>,
}
127
#[cfg(feature = "alloc")]
impl<'a> From<KeypathRef<'a>> for Keypath {
    /// Convert a borrowed key path into an owned one by collecting every
    /// path component into a `Vec`.
    fn from(keypath: KeypathRef<'a>) -> Self {
        let KeypathRef {
            components,
            source_fingerprint,
            depth,
        } = keypath;

        Self {
            components: components.iter().collect(),
            source_fingerprint,
            depth,
        }
    }
}
138
/// Collection of [`PathComponent`]s.
///
/// The components are not stored as a list of [`PathComponent`] values;
/// they are produced on demand by [`PathComponents::iter`] from the
/// borrowed backing storage.
#[derive(Debug, Clone)]
pub struct PathComponents<'a> {
    // Borrowed backing storage the components are read from.
    storage: PathStorage<'a>,
}
144
// Backing storage of a `PathComponents`.
#[derive(Debug, Clone)]
enum PathStorage<'a> {
    // Path components still encoded as CBOR.
    Cbor {
        // Decoder positioned at the first encoded path component, right
        // after the array header.
        d: Decoder<'a>,
        // Number of path components that were validated while decoding.
        len: usize,
    },
    // Raw child numbers; bit 31 of each entry marks a hardened component.
    RawDerivationPath(&'a [u32]),
    // Child numbers borrowed from a `bitcoin` derivation path.
    #[cfg(feature = "bitcoin")]
    DerivationPath(&'a [bitcoin::bip32::ChildNumber]),
}
155
156impl<'a> PathStorage<'a> {
157    fn len(&self) -> usize {
158        match self {
159            PathStorage::Cbor { len, .. } => *len,
160            PathStorage::RawDerivationPath(path) => path.len(),
161            #[cfg(feature = "bitcoin")]
162            PathStorage::DerivationPath(path) => path.len(),
163        }
164    }
165}
166
167impl<'a> PathComponents<'a> {
168    pub fn len(&self) -> usize {
169        self.storage.len()
170    }
171
172    pub fn is_empty(&self) -> bool {
173        self.len() == 0
174    }
175
176    pub fn iter(&self) -> PathComponentsIter<'a> {
177        PathComponentsIter {
178            storage: self.storage.clone(),
179            index: 0,
180        }
181    }
182}
183
184impl<'b, C> Decode<'b, C> for PathComponents<'b> {
185    fn decode(d: &mut Decoder<'b>, ctx: &mut C) -> Result<Self, Error> {
186        // Eat the array type bytes.
187        let mut array_len = d.array()?.map(|len| len / 2);
188
189        // Clone the original decoder as the "starting point" of the
190        // path components.
191        let path_decoder = d.clone();
192
193        // Iterate over the path components in order to verify the data and
194        // to consume the bytes of the passed decoder.
195        let mut len: usize = 0;
196        loop {
197            match array_len {
198                Some(0) => break,
199                Some(n) => array_len = Some(n.saturating_sub(1)),
200                None => {
201                    if d.datatype()? == Type::Break {
202                        break;
203                    }
204                }
205            }
206
207            // Consume the path component in order to advance the decoder.
208            PathComponent::decode(d, ctx)?;
209            match len.overflowing_add(1) {
210                (new_len, false) => len = new_len,
211                (_, true) => return Err(Error::message("too many elements")),
212            }
213        }
214
215        Ok(Self {
216            storage: PathStorage::Cbor {
217                d: path_decoder,
218                len,
219            },
220        })
221    }
222}
223
224impl<'a, C> Encode<C> for PathComponents<'a> {
225    fn encode<W: Write>(
226        &self,
227        e: &mut Encoder<W>,
228        ctx: &mut C,
229    ) -> Result<(), minicbor::encode::Error<W::Error>> {
230        e.array(self.len() as u64 * 2)?;
231
232        for elt in self.iter() {
233            elt.encode(e, ctx)?;
234        }
235
236        Ok(())
237    }
238}
239
240impl<'a> PartialEq for PathComponents<'a> {
241    fn eq(&self, other: &Self) -> bool {
242        if self.len() != other.len() {
243            return false;
244        }
245
246        for (lhs, rhs) in self.iter().zip(other.iter()) {
247            if lhs != rhs {
248                return false;
249            }
250        }
251
252        true
253    }
254}
255
256impl<'a> From<&'a [u32]> for PathComponents<'a> {
257    fn from(path: &'a [u32]) -> Self {
258        Self {
259            storage: PathStorage::RawDerivationPath(path),
260        }
261    }
262}
263
264impl<'a, const N: usize> From<&'a [u32; N]> for PathComponents<'a> {
265    fn from(path: &'a [u32; N]) -> Self {
266        Self {
267            storage: PathStorage::RawDerivationPath(path),
268        }
269    }
270}
271
/// Iterator over the path components of a [`PathComponents`].
///
/// Yields owned [`PathComponent`] values, in path order.
pub struct PathComponentsIter<'a> {
    // Storage the components are read from.
    storage: PathStorage<'a>,
    // Number of components yielded so far.
    index: usize,
}
277
278impl<'a> Iterator for PathComponentsIter<'a> {
279    type Item = PathComponent;
280
281    fn next(&mut self) -> Option<Self::Item> {
282        if self.index >= self.storage.len() {
283            return None;
284        }
285
286        let component = match self.storage {
287            PathStorage::Cbor { ref mut d, .. } => {
288                PathComponent::decode(d, &mut ()).expect("path component should be valid")
289            }
290            PathStorage::RawDerivationPath(path) => {
291                let (number, is_hardened) = if path[self.index] & (1 << 31) != 0 {
292                    (path[self.index] ^ (1 << 31), true)
293                } else {
294                    (path[self.index], false)
295                };
296
297                PathComponent {
298                    number: ChildNumber::Number(number),
299                    is_hardened,
300                }
301            }
302            #[cfg(feature = "bitcoin")]
303            PathStorage::DerivationPath(path) => PathComponent::from(path[self.index]),
304        };
305
306        self.index += 1;
307        Some(component)
308    }
309}
310
311impl<'a> ExactSizeIterator for PathComponentsIter<'a> {
312    fn len(&self) -> usize {
313        self.storage.len()
314    }
315}
316
/// A derivation path component.
///
/// Encoded in CBOR as two consecutive items: the child number (or child
/// number range) followed by the hardened flag.
#[doc(alias("path-component"))]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PathComponent {
    /// The child number.
    pub number: ChildNumber,
    /// Hardened key?
    pub is_hardened: bool,
}
326
327impl<'b, C> Decode<'b, C> for PathComponent {
328    fn decode(d: &mut Decoder<'b>, _ctx: &mut C) -> Result<Self, Error> {
329        let number = match d.datatype()? {
330            Type::U8 | Type::U16 | Type::U32 => ChildNumber::Number(d.u32()?),
331            Type::Array => {
332                let mut array = d.array_iter::<u32>()?;
333                let low = array
334                    .next()
335                    .ok_or_else(|| Error::message("low child-index not present"))??;
336                let high = array
337                    .next()
338                    .ok_or_else(|| Error::message("high child-index not present"))??;
339                if array.next().is_some() {
340                    return Err(Error::message("invalid child-index-range size"));
341                }
342
343                ChildNumber::Range(low..high)
344            }
345            _ => return Err(Error::message("unknown child number")),
346        };
347
348        Ok(Self {
349            number,
350            is_hardened: d.bool()?,
351        })
352    }
353}
354
355impl<C> Encode<C> for PathComponent {
356    fn encode<W: Write>(
357        &self,
358        e: &mut Encoder<W>,
359        _ctx: &mut C,
360    ) -> Result<(), minicbor::encode::Error<W::Error>> {
361        match self.number {
362            ChildNumber::Number(n) => e.u32(n)?,
363            ChildNumber::Range(ref range) => e.array(2)?.u32(range.start)?.u32(range.end)?,
364        };
365
366        e.bool(self.is_hardened)?;
367
368        Ok(())
369    }
370}
371
#[cfg(feature = "bitcoin")]
impl From<bitcoin::bip32::ChildNumber> for PathComponent {
    /// Convert a `bitcoin` child number into a path component, keeping the
    /// index and mapping the variant to the hardened flag.
    fn from(number: bitcoin::bip32::ChildNumber) -> Self {
        use bitcoin::bip32::ChildNumber as Bip32ChildNumber;

        let (index, is_hardened) = match number {
            Bip32ChildNumber::Normal { index } => (index, false),
            Bip32ChildNumber::Hardened { index } => (index, true),
        };

        PathComponent {
            number: ChildNumber::Number(index),
            is_hardened,
        }
    }
}
387
/// The child number of a path component.
///
/// Whether the component is hardened is not part of the child number; it
/// is carried separately by [`PathComponent::is_hardened`].
// TODO: add wildcard support.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ChildNumber {
    /// A single child number.
    Number(u32),
    /// A range of child numbers.
    ///
    /// Decoded from a two-element array `[low, high]` and stored as
    /// `low..high`.
    Range(Range<u32>),
}
396}