Skip to main content

bsv/compat/
bip32.rs

1//! BIP32 Hierarchical Deterministic (HD) key derivation.
2//!
3//! Implements the BIP32 specification for generating a tree of keypairs
4//! from a single seed. Supports extended private and public keys (xprv/xpub),
5//! hardened and normal child derivation, and Base58Check serialization.
6
7use crate::compat::error::CompatError;
8use crate::primitives::base_point::BasePoint;
9use crate::primitives::big_number::{BigNumber, Endian};
10use crate::primitives::curve::Curve;
11use crate::primitives::hash::hash256;
12use crate::primitives::hash::{hash160, sha512_hmac};
13use crate::primitives::private_key::PrivateKey;
14use crate::primitives::public_key::PublicKey;
15use crate::primitives::utils::{base58_decode, base58_encode};
16
/// Serialization prefix of a mainnet extended private key; encodes to "xprv...".
const XPRV_VERSION: [u8; 4] = 0x0488_ADE4u32.to_be_bytes();

/// Serialization prefix of a mainnet extended public key; encodes to "xpub...".
const XPUB_VERSION: [u8; 4] = 0x0488_B21Eu32.to_be_bytes();

/// First hardened child index (2^31); indices at or above it use hardened derivation.
const HARDENED_OFFSET: u32 = 1 << 31;
25
/// A BIP32 extended key (private or public) with chain code and derivation metadata.
///
/// Supports key derivation, serialization to/from Base58Check xprv/xpub strings,
/// and conversion between private and public extended keys.
///
/// `Debug` is implemented by hand so that private key bytes are never written to
/// logs or panic messages; use `Display` (Base58Check) to export a key deliberately.
#[derive(Clone)]
pub struct ExtendedKey {
    /// Private key bytes (32) or compressed public key bytes (33).
    key: Vec<u8>,
    /// 32-byte chain code for child derivation.
    chain_code: Vec<u8>,
    /// Derivation depth (0 = master).
    depth: u8,
    /// First 4 bytes of parent key's hash160 (0x00000000 for master).
    parent_fingerprint: [u8; 4],
    /// Child index that produced this key (0 for master).
    child_index: u32,
    /// Version bytes (XPRV_VERSION or XPUB_VERSION).
    version: [u8; 4],
    /// Whether this is a private extended key.
    is_private: bool,
}

impl std::fmt::Debug for ExtendedKey {
    /// Debug-format every field except the secret: for private keys the `key`
    /// field is replaced by a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ExtendedKey");
        if self.is_private {
            dbg.field("key", &"<redacted>");
        } else {
            dbg.field("key", &self.key);
        }
        dbg.field("chain_code", &self.chain_code)
            .field("depth", &self.depth)
            .field("parent_fingerprint", &self.parent_fingerprint)
            .field("child_index", &self.child_index)
            .field("version", &self.version)
            .field("is_private", &self.is_private)
            .finish()
    }
}
47
48impl ExtendedKey {
49    /// Derive a master extended key from a seed.
50    ///
51    /// Uses HMAC-SHA512 with key "Bitcoin seed" per BIP32 specification.
52    /// Seed must be between 16 and 64 bytes (128-512 bits).
53    pub fn from_seed(seed: &[u8]) -> Result<Self, CompatError> {
54        if seed.len() < 16 {
55            return Err(CompatError::InvalidEntropy(
56                "seed must be at least 128 bits".to_string(),
57            ));
58        }
59        if seed.len() > 64 {
60            return Err(CompatError::InvalidEntropy(
61                "seed must be at most 512 bits".to_string(),
62            ));
63        }
64
65        let hmac = sha512_hmac(b"Bitcoin seed", seed);
66        let secret_key = &hmac[0..32];
67        let chain_code = &hmac[32..64];
68
69        // Validate secret key is in valid range [1, n-1]
70        let key_num = BigNumber::from_bytes(secret_key, Endian::Big);
71        let curve = Curve::secp256k1();
72        if key_num.cmp(&curve.n) >= 0 || key_num.is_zero() {
73            return Err(CompatError::UnusableSeed);
74        }
75
76        Ok(ExtendedKey {
77            key: secret_key.to_vec(),
78            chain_code: chain_code.to_vec(),
79            depth: 0,
80            parent_fingerprint: [0, 0, 0, 0],
81            child_index: 0,
82            version: XPRV_VERSION,
83            is_private: true,
84        })
85    }
86
87    /// Derive a child extended key by path string.
88    ///
89    /// Path format: "m/0'/1/2'" where apostrophe or "h" suffix means hardened.
90    /// Leading "m/" or "m" prefix is stripped.
91    pub fn derive(&self, path: &str) -> Result<Self, CompatError> {
92        let path = path.trim();
93
94        // Strip leading "m" or "m/"
95        let components = if path == "m" || path == "M" {
96            return Ok(self.clone());
97        } else if let Some(rest) = path.strip_prefix("m/").or_else(|| path.strip_prefix("M/")) {
98            rest
99        } else {
100            path
101        };
102
103        let mut current = self.clone();
104        for component in components.split('/') {
105            let component = component.trim();
106            if component.is_empty() {
107                continue;
108            }
109
110            let (index_str, hardened) = if let Some(s) = component.strip_suffix('\'') {
111                (s, true)
112            } else if let Some(s) = component.strip_suffix('h') {
113                (s, true)
114            } else {
115                (component, false)
116            };
117
118            let index: u32 = index_str
119                .parse()
120                .map_err(|_| CompatError::InvalidPath(format!("invalid index: {}", index_str)))?;
121
122            let child_index = if hardened {
123                index
124                    .checked_add(HARDENED_OFFSET)
125                    .ok_or_else(|| CompatError::InvalidPath("index overflow".to_string()))?
126            } else {
127                index
128            };
129
130            current = current.derive_child(child_index)?;
131        }
132
133        Ok(current)
134    }
135
136    /// Derive a single child key by index.
137    ///
138    /// Index >= 0x80000000 is hardened derivation (requires private key).
139    /// Index < 0x80000000 is normal derivation (works with public key).
140    pub fn derive_child(&self, index: u32) -> Result<Self, CompatError> {
141        if self.depth == 255 {
142            return Err(CompatError::DepthExceeded);
143        }
144
145        let is_hardened = index >= HARDENED_OFFSET;
146        if is_hardened && !self.is_private {
147            return Err(CompatError::HardenedFromPublic);
148        }
149
150        // Build HMAC data
151        let mut data = Vec::with_capacity(37);
152        if is_hardened {
153            // Hardened: 0x00 || private_key(32) || index(4)
154            data.push(0x00);
155            let padded_key = self.padded_key_bytes(32);
156            data.extend_from_slice(&padded_key);
157        } else {
158            // Normal: compressed_pubkey(33) || index(4)
159            let pubkey_bytes = self.compressed_pubkey_bytes()?;
160            data.extend_from_slice(&pubkey_bytes);
161        }
162        data.extend_from_slice(&index.to_be_bytes());
163
164        let hmac = sha512_hmac(&self.chain_code, &data);
165        let il = &hmac[0..32];
166        let ir = &hmac[32..64];
167
168        let curve = Curve::secp256k1();
169        let il_num = BigNumber::from_bytes(il, Endian::Big);
170
171        // Validate IL < n
172        if il_num.cmp(&curve.n) >= 0 {
173            return Err(CompatError::InvalidChild);
174        }
175
176        // Compute parent fingerprint
177        let parent_pubkey = self.compressed_pubkey_bytes()?;
178        let parent_hash = hash160(&parent_pubkey);
179        let mut fingerprint = [0u8; 4];
180        fingerprint.copy_from_slice(&parent_hash[..4]);
181
182        if self.is_private {
183            // Private child: key = (IL + parent_key) mod n
184            let parent_num = BigNumber::from_bytes(&self.key, Endian::Big);
185            let child_num = il_num.add(&parent_num).umod(&curve.n).map_err(|e| {
186                CompatError::Primitives(crate::primitives::error::PrimitivesError::ArithmeticError(
187                    e.to_string(),
188                ))
189            })?;
190
191            if child_num.is_zero() {
192                return Err(CompatError::InvalidChild);
193            }
194
195            let child_key = child_num.to_array(Endian::Big, Some(32));
196
197            Ok(ExtendedKey {
198                key: child_key,
199                chain_code: ir.to_vec(),
200                depth: self.depth + 1,
201                parent_fingerprint: fingerprint,
202                child_index: index,
203                version: XPRV_VERSION,
204                is_private: true,
205            })
206        } else {
207            // Public child: key = point(IL) + parent_pubkey
208            let il_point = BasePoint::instance().mul(&il_num);
209            let parent_point = PublicKey::from_der_bytes(&parent_pubkey)?;
210            let child_point = il_point.add(parent_point.point());
211
212            if child_point.is_infinity() {
213                return Err(CompatError::InvalidChild);
214            }
215
216            let child_pubkey = child_point.to_der(true);
217
218            Ok(ExtendedKey {
219                key: child_pubkey,
220                chain_code: ir.to_vec(),
221                depth: self.depth + 1,
222                parent_fingerprint: fingerprint,
223                child_index: index,
224                version: XPUB_VERSION,
225                is_private: false,
226            })
227        }
228    }
229
230    /// Convert a private extended key to its public counterpart.
231    ///
232    /// Returns a new ExtendedKey with the public key and xpub version bytes.
233    pub fn to_public(&self) -> Result<Self, CompatError> {
234        if !self.is_private {
235            return Ok(self.clone());
236        }
237
238        let pubkey_bytes = self.compressed_pubkey_bytes()?;
239
240        Ok(ExtendedKey {
241            key: pubkey_bytes,
242            chain_code: self.chain_code.clone(),
243            depth: self.depth,
244            parent_fingerprint: self.parent_fingerprint,
245            child_index: self.child_index,
246            version: XPUB_VERSION,
247            is_private: false,
248        })
249    }
250
251    /// Serialize to a Base58Check string (xprv or xpub).
252    ///
253    /// 78-byte payload: version(4) || depth(1) || fingerprint(4) ||
254    /// child_index(4) || chain_code(32) || key(33, with 0x00 prefix for private).
255    pub fn to_base58(&self) -> String {
256        let mut payload = Vec::with_capacity(78);
257        payload.extend_from_slice(&self.version);
258        payload.push(self.depth);
259        payload.extend_from_slice(&self.parent_fingerprint);
260        payload.extend_from_slice(&self.child_index.to_be_bytes());
261        payload.extend_from_slice(&self.chain_code);
262
263        if self.is_private {
264            payload.push(0x00);
265            let padded = self.padded_key_bytes(32);
266            payload.extend_from_slice(&padded);
267        } else {
268            payload.extend_from_slice(&self.key);
269        }
270
271        assert_eq!(payload.len(), 78, "BIP32 payload must be exactly 78 bytes");
272
273        // Manual checksum using hash256 (matches Go SDK pattern and works
274        // correctly with the 4-byte version prefix).
275        let checksum = hash256(&payload);
276        payload.extend_from_slice(&checksum[..4]);
277
278        base58_encode(&payload)
279    }
280
281    /// Parse an extended key from a Base58Check string (xprv or xpub).
282    pub fn from_string(s: &str) -> Result<Self, CompatError> {
283        let decoded = base58_decode(s)
284            .map_err(|e| CompatError::InvalidExtendedKey(format!("base58 decode: {}", e)))?;
285
286        if decoded.len() != 82 {
287            return Err(CompatError::InvalidExtendedKey(format!(
288                "expected 82 bytes, got {}",
289                decoded.len()
290            )));
291        }
292
293        // Verify checksum
294        let payload = &decoded[..78];
295        let checksum = &decoded[78..82];
296        let expected_checksum = hash256(payload);
297        if checksum != &expected_checksum[..4] {
298            return Err(CompatError::ChecksumMismatch);
299        }
300
301        let mut version = [0u8; 4];
302        version.copy_from_slice(&payload[0..4]);
303
304        let is_private = if version == XPRV_VERSION {
305            true
306        } else if version == XPUB_VERSION {
307            false
308        } else {
309            return Err(CompatError::InvalidMagic);
310        };
311
312        let depth = payload[4];
313        let mut parent_fingerprint = [0u8; 4];
314        parent_fingerprint.copy_from_slice(&payload[5..9]);
315        let child_index = u32::from_be_bytes([payload[9], payload[10], payload[11], payload[12]]);
316        let chain_code = payload[13..45].to_vec();
317
318        let key = if is_private {
319            // Private key: 0x00 prefix + 32 bytes
320            if payload[45] != 0x00 {
321                return Err(CompatError::InvalidExtendedKey(
322                    "private key must start with 0x00".to_string(),
323                ));
324            }
325            payload[46..78].to_vec()
326        } else {
327            // Public key: 33 bytes (compressed)
328            payload[45..78].to_vec()
329        };
330
331        Ok(ExtendedKey {
332            key,
333            chain_code,
334            depth,
335            parent_fingerprint,
336            child_index,
337            version,
338            is_private,
339        })
340    }
341
342    /// Get the public key for this extended key.
343    ///
344    /// If private, derives the public key. If public, parses the stored key.
345    pub fn public_key(&self) -> Result<PublicKey, CompatError> {
346        if self.is_private {
347            let priv_key = PrivateKey::from_bytes(&self.key)?;
348            Ok(priv_key.to_public_key())
349        } else {
350            Ok(PublicKey::from_der_bytes(&self.key)?)
351        }
352    }
353
354    /// Whether this is a private extended key.
355    pub fn is_private(&self) -> bool {
356        self.is_private
357    }
358
359    /// Get the derivation depth.
360    pub fn depth(&self) -> u8 {
361        self.depth
362    }
363
364    // -----------------------------------------------------------------------
365    // Private helpers
366    // -----------------------------------------------------------------------
367
368    /// Get the compressed public key bytes (33 bytes) for this key.
369    fn compressed_pubkey_bytes(&self) -> Result<Vec<u8>, CompatError> {
370        if self.is_private {
371            let priv_key = PrivateKey::from_bytes(&self.key)?;
372            Ok(priv_key.to_public_key().to_der())
373        } else {
374            Ok(self.key.clone())
375        }
376    }
377
378    /// Get the key bytes padded to the specified length.
379    fn padded_key_bytes(&self, len: usize) -> Vec<u8> {
380        if self.key.len() >= len {
381            self.key[self.key.len() - len..].to_vec()
382        } else {
383            let mut padded = vec![0u8; len - self.key.len()];
384            padded.extend_from_slice(&self.key);
385            padded
386        }
387    }
388}
389
390impl std::fmt::Display for ExtendedKey {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        write!(f, "{}", self.to_base58())
393    }
394}
395
396// ============================================================================
397// Tests
398// ============================================================================
399
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Decode an even-length hex string. Test-only helper; panics on malformed input.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        (0..hex.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
            .collect()
    }

    #[derive(Deserialize)]
    struct ChainVector {
        path: String,
        xprv: String,
        xpub: String,
    }

    #[derive(Deserialize)]
    struct SeedVector {
        seed: String,
        chains: Vec<ChainVector>,
    }

    #[derive(Deserialize)]
    struct Bip32Vectors {
        vectors: Vec<SeedVector>,
    }

    fn load_vectors() -> Bip32Vectors {
        let data = include_str!("../../test-vectors/bip32_vectors.json");
        serde_json::from_str(data).unwrap()
    }

    /// Build the master extended key for a seed vector.
    fn master_for(v: &SeedVector) -> ExtendedKey {
        ExtendedKey::from_seed(&hex_to_bytes(&v.seed)).unwrap()
    }

    /// Assert that every chain entry of `v` derives to the expected xprv/xpub.
    fn assert_all_chains(v: &SeedVector) {
        let master = master_for(v);
        for chain in &v.chains {
            let derived = master.derive(&chain.path).unwrap();
            assert_eq!(
                derived.to_string(),
                chain.xprv,
                "xprv mismatch for path {}",
                chain.path
            );
            assert_eq!(
                derived.to_public().unwrap().to_string(),
                chain.xpub,
                "xpub mismatch for path {}",
                chain.path
            );
        }
    }

    // Test 1: from_seed with vector 1 produces correct master xprv and xpub
    #[test]
    fn test_vector1_master_key() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];
        let master = master_for(v);

        assert_eq!(master.to_string(), v.chains[0].xprv);
        assert_eq!(master.to_public().unwrap().to_string(), v.chains[0].xpub);
    }

    // Test 2: from_seed with vector 2 produces correct master key
    #[test]
    fn test_vector2_master_key() {
        let vectors = load_vectors();
        let v = &vectors.vectors[1];
        let master = master_for(v);

        assert_eq!(master.to_string(), v.chains[0].xprv);
        assert_eq!(master.to_public().unwrap().to_string(), v.chains[0].xpub);
    }

    // Test 3: derive("m/0'") from master produces correct hardened child
    #[test]
    fn test_vector1_hardened_child() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];
        let master = master_for(v);

        let child = master.derive("m/0'").unwrap();
        assert_eq!(child.to_string(), v.chains[1].xprv);
        assert_eq!(child.to_public().unwrap().to_string(), v.chains[1].xpub);
    }

    // Test 4: Full derivation path produces correct keys at each level
    #[test]
    fn test_vector1_full_derivation() {
        let vectors = load_vectors();
        assert_all_chains(&vectors.vectors[0]);
    }

    // Test 4b: Full derivation of vector 2
    #[test]
    fn test_vector2_full_derivation() {
        let vectors = load_vectors();
        assert_all_chains(&vectors.vectors[1]);
    }

    // Test 5: to_public() produces correct xpub serialization
    #[test]
    fn test_to_public() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];
        let master = master_for(v);

        let public = master.to_public().unwrap();
        assert!(!public.is_private());
        assert_eq!(public.to_string(), v.chains[0].xpub);
    }

    // Test 6: from_string() round-trips
    #[test]
    fn test_from_string_round_trip() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];

        let parsed_priv = ExtendedKey::from_string(&v.chains[0].xprv).unwrap();
        assert_eq!(parsed_priv.to_string(), v.chains[0].xprv);

        let parsed_pub = ExtendedKey::from_string(&v.chains[0].xpub).unwrap();
        assert_eq!(parsed_pub.to_string(), v.chains[0].xpub);
    }

    // Test 7: derive from public xpub for normal children
    #[test]
    fn test_public_derivation() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];
        let master = master_for(v);

        // Derive m/0' privately, then get the public key
        let child_pub = master.derive("m/0'").unwrap().to_public().unwrap();

        // From public key, derive normal child m/0'/1; must match private derivation
        let grandchild_pub = child_pub.derive("m/1").unwrap();
        assert_eq!(
            grandchild_pub.to_string(),
            v.chains[2].xpub,
            "public derivation of normal child should match"
        );
    }

    // Test 8: derive hardened child from public key returns error
    #[test]
    fn test_hardened_from_public_error() {
        let vectors = load_vectors();
        let public = master_for(&vectors.vectors[0]).to_public().unwrap();

        let result = public.derive("m/0'");
        assert!(result.is_err(), "hardened from public should fail");
        match result.unwrap_err() {
            CompatError::HardenedFromPublic => {}
            e => panic!("expected HardenedFromPublic, got {:?}", e),
        }
    }

    // Test 9: depth exceeding 255 returns error
    #[test]
    fn test_depth_exceeded() {
        let seed = hex_to_bytes("000102030405060708090a0b0c0d0e0f");
        let master = ExtendedKey::from_seed(&seed).unwrap();

        // Create a key at depth 255 by manipulating internals
        let deep_key = ExtendedKey {
            key: master.key.clone(),
            chain_code: master.chain_code.clone(),
            depth: 255,
            parent_fingerprint: [0; 4],
            child_index: 0,
            version: XPRV_VERSION,
            is_private: true,
        };

        let result = deep_key.derive_child(0);
        assert!(result.is_err(), "depth 255 derivation should fail");
        match result.unwrap_err() {
            CompatError::DepthExceeded => {}
            e => panic!("expected DepthExceeded, got {:?}", e),
        }
    }

    // Test: from_string/to_string round-trip for all vector keys
    #[test]
    fn test_all_vectors_from_string_round_trip() {
        let vectors = load_vectors();
        for chain in vectors.vectors.iter().flat_map(|v| v.chains.iter()) {
            let priv_key = ExtendedKey::from_string(&chain.xprv).unwrap();
            assert_eq!(
                priv_key.to_string(),
                chain.xprv,
                "xprv round-trip failed for {}",
                chain.path
            );

            let pub_key = ExtendedKey::from_string(&chain.xpub).unwrap();
            assert_eq!(
                pub_key.to_string(),
                chain.xpub,
                "xpub round-trip failed for {}",
                chain.path
            );
        }
    }

    // Test: public derivation for vector 1 m/0'/1/2'/2 -> m/0'/1/2'/2/1000000000
    #[test]
    fn test_public_derivation_deep() {
        let vectors = load_vectors();
        let v = &vectors.vectors[0];
        let master = master_for(v);

        let child_pub = master.derive("m/0'/1/2'/2").unwrap().to_public().unwrap();
        let grandchild_pub = child_pub.derive("m/1000000000").unwrap();
        assert_eq!(
            grandchild_pub.to_string(),
            v.chains[5].xpub,
            "public derivation of deep normal child should match"
        );
    }

    // Test: seeds outside 16..=64 bytes are rejected; the bounds themselves are accepted
    #[test]
    fn test_seed_length_bounds() {
        for &len in &[15usize, 65] {
            match ExtendedKey::from_seed(&vec![0x42; len]) {
                Err(CompatError::InvalidEntropy(_)) => {}
                other => panic!("expected InvalidEntropy for {} bytes, got {:?}", len, other),
            }
        }
        assert!(ExtendedKey::from_seed(&[0x42; 16]).is_ok());
        assert!(ExtendedKey::from_seed(&[0x42; 64]).is_ok());
    }

    // Test: "h" and "'" hardened markers are interchangeable
    #[test]
    fn test_hardened_suffix_forms_agree() {
        let vectors = load_vectors();
        let master = master_for(&vectors.vectors[0]);

        let apostrophe = master.derive("m/0'/1/2'").unwrap();
        let h_suffix = master.derive("m/0h/1/2h").unwrap();
        assert_eq!(apostrophe.to_string(), h_suffix.to_string());
    }

    // Test: the root path "m" returns the key itself
    #[test]
    fn test_derive_root_path_returns_self() {
        let vectors = load_vectors();
        let master = master_for(&vectors.vectors[0]);
        assert_eq!(master.derive("m").unwrap().to_string(), master.to_string());
    }

    // Test: non-numeric path components are rejected as InvalidPath
    #[test]
    fn test_invalid_path_component() {
        let vectors = load_vectors();
        let master = master_for(&vectors.vectors[0]);
        match master.derive("m/abc") {
            Err(CompatError::InvalidPath(_)) => {}
            other => panic!("expected InvalidPath, got {:?}", other),
        }
    }

    // Test: a corrupted Base58Check string fails the checksum
    #[test]
    fn test_from_string_rejects_bad_checksum() {
        let vectors = load_vectors();
        let mut s = vectors.vectors[0].chains[0].xprv.clone();
        let last = s.pop().unwrap();
        s.push(if last == '2' { '3' } else { '2' });

        match ExtendedKey::from_string(&s) {
            Err(CompatError::ChecksumMismatch) => {}
            other => panic!("expected ChecksumMismatch, got {:?}", other),
        }
    }

    // Test: strings that are not Base58 extended keys are rejected
    #[test]
    fn test_from_string_rejects_garbage() {
        assert!(ExtendedKey::from_string("xprv-not-a-key").is_err());
        assert!(ExtendedKey::from_string("").is_err());
    }
}