1use crate::compat::error::CompatError;
8use crate::primitives::base_point::BasePoint;
9use crate::primitives::big_number::{BigNumber, Endian};
10use crate::primitives::curve::Curve;
11use crate::primitives::hash::hash256;
12use crate::primitives::hash::{hash160, sha512_hmac};
13use crate::primitives::private_key::PrivateKey;
14use crate::primitives::public_key::PublicKey;
15use crate::primitives::utils::{base58_decode, base58_encode};
16
/// BIP32 mainnet version bytes placed at the start of a serialized extended private key.
const XPRV_VERSION: [u8; 4] = 0x0488_ADE4_u32.to_be_bytes();

/// BIP32 mainnet version bytes placed at the start of a serialized extended public key.
const XPUB_VERSION: [u8; 4] = 0x0488_B21E_u32.to_be_bytes();

/// Child indices at or above this value (bit 31 set) select hardened derivation.
const HARDENED_OFFSET: u32 = 1 << 31;
25
/// A BIP32 extended key: a private or public key paired with a chain code and
/// the metadata needed for its 78-byte serialization.
///
/// `Debug` is implemented by hand (rather than derived) so that formatting a
/// private extended key with `{:?}` — e.g. in an `unwrap`/`expect` panic
/// message — never prints the secret key bytes.
#[derive(Clone)]
pub struct ExtendedKey {
    /// Private keys: the 32-byte secret scalar (big-endian).
    /// Public keys: the serialized (compressed) public key bytes.
    key: Vec<u8>,
    /// 32-byte chain code used as the HMAC key for child derivation.
    chain_code: Vec<u8>,
    /// Number of derivation steps from the master key (0 for the master).
    depth: u8,
    /// First 4 bytes of HASH160 of the parent's compressed public key.
    parent_fingerprint: [u8; 4],
    /// Index this key was derived with; bit 31 set means hardened.
    child_index: u32,
    /// 4-byte serialization version (xprv or xpub magic).
    version: [u8; 4],
    /// Whether `key` holds a private scalar rather than a public key.
    is_private: bool,
}

impl std::fmt::Debug for ExtendedKey {
    /// Formats all fields, replacing the key material with `<redacted>` when
    /// this is a private extended key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ExtendedKey");
        if self.is_private {
            out.field("key", &format_args!("<redacted>"));
        } else {
            out.field("key", &self.key);
        }
        out.field("chain_code", &self.chain_code)
            .field("depth", &self.depth)
            .field("parent_fingerprint", &self.parent_fingerprint)
            .field("child_index", &self.child_index)
            .field("version", &self.version)
            .field("is_private", &self.is_private)
            .finish()
    }
}
47
48impl ExtendedKey {
49 pub fn from_seed(seed: &[u8]) -> Result<Self, CompatError> {
54 if seed.len() < 16 {
55 return Err(CompatError::InvalidEntropy(
56 "seed must be at least 128 bits".to_string(),
57 ));
58 }
59 if seed.len() > 64 {
60 return Err(CompatError::InvalidEntropy(
61 "seed must be at most 512 bits".to_string(),
62 ));
63 }
64
65 let hmac = sha512_hmac(b"Bitcoin seed", seed);
66 let secret_key = &hmac[0..32];
67 let chain_code = &hmac[32..64];
68
69 let key_num = BigNumber::from_bytes(secret_key, Endian::Big);
71 let curve = Curve::secp256k1();
72 if key_num.cmp(&curve.n) >= 0 || key_num.is_zero() {
73 return Err(CompatError::UnusableSeed);
74 }
75
76 Ok(ExtendedKey {
77 key: secret_key.to_vec(),
78 chain_code: chain_code.to_vec(),
79 depth: 0,
80 parent_fingerprint: [0, 0, 0, 0],
81 child_index: 0,
82 version: XPRV_VERSION,
83 is_private: true,
84 })
85 }
86
87 pub fn derive(&self, path: &str) -> Result<Self, CompatError> {
92 let path = path.trim();
93
94 let components = if path == "m" || path == "M" {
96 return Ok(self.clone());
97 } else if let Some(rest) = path.strip_prefix("m/").or_else(|| path.strip_prefix("M/")) {
98 rest
99 } else {
100 path
101 };
102
103 let mut current = self.clone();
104 for component in components.split('/') {
105 let component = component.trim();
106 if component.is_empty() {
107 continue;
108 }
109
110 let (index_str, hardened) = if let Some(s) = component.strip_suffix('\'') {
111 (s, true)
112 } else if let Some(s) = component.strip_suffix('h') {
113 (s, true)
114 } else {
115 (component, false)
116 };
117
118 let index: u32 = index_str
119 .parse()
120 .map_err(|_| CompatError::InvalidPath(format!("invalid index: {}", index_str)))?;
121
122 let child_index = if hardened {
123 index
124 .checked_add(HARDENED_OFFSET)
125 .ok_or_else(|| CompatError::InvalidPath("index overflow".to_string()))?
126 } else {
127 index
128 };
129
130 current = current.derive_child(child_index)?;
131 }
132
133 Ok(current)
134 }
135
136 pub fn derive_child(&self, index: u32) -> Result<Self, CompatError> {
141 if self.depth == 255 {
142 return Err(CompatError::DepthExceeded);
143 }
144
145 let is_hardened = index >= HARDENED_OFFSET;
146 if is_hardened && !self.is_private {
147 return Err(CompatError::HardenedFromPublic);
148 }
149
150 let mut data = Vec::with_capacity(37);
152 if is_hardened {
153 data.push(0x00);
155 let padded_key = self.padded_key_bytes(32);
156 data.extend_from_slice(&padded_key);
157 } else {
158 let pubkey_bytes = self.compressed_pubkey_bytes()?;
160 data.extend_from_slice(&pubkey_bytes);
161 }
162 data.extend_from_slice(&index.to_be_bytes());
163
164 let hmac = sha512_hmac(&self.chain_code, &data);
165 let il = &hmac[0..32];
166 let ir = &hmac[32..64];
167
168 let curve = Curve::secp256k1();
169 let il_num = BigNumber::from_bytes(il, Endian::Big);
170
171 if il_num.cmp(&curve.n) >= 0 {
173 return Err(CompatError::InvalidChild);
174 }
175
176 let parent_pubkey = self.compressed_pubkey_bytes()?;
178 let parent_hash = hash160(&parent_pubkey);
179 let mut fingerprint = [0u8; 4];
180 fingerprint.copy_from_slice(&parent_hash[..4]);
181
182 if self.is_private {
183 let parent_num = BigNumber::from_bytes(&self.key, Endian::Big);
185 let child_num = il_num.add(&parent_num).umod(&curve.n).map_err(|e| {
186 CompatError::Primitives(crate::primitives::error::PrimitivesError::ArithmeticError(
187 e.to_string(),
188 ))
189 })?;
190
191 if child_num.is_zero() {
192 return Err(CompatError::InvalidChild);
193 }
194
195 let child_key = child_num.to_array(Endian::Big, Some(32));
196
197 Ok(ExtendedKey {
198 key: child_key,
199 chain_code: ir.to_vec(),
200 depth: self.depth + 1,
201 parent_fingerprint: fingerprint,
202 child_index: index,
203 version: XPRV_VERSION,
204 is_private: true,
205 })
206 } else {
207 let il_point = BasePoint::instance().mul(&il_num);
209 let parent_point = PublicKey::from_der_bytes(&parent_pubkey)?;
210 let child_point = il_point.add(parent_point.point());
211
212 if child_point.is_infinity() {
213 return Err(CompatError::InvalidChild);
214 }
215
216 let child_pubkey = child_point.to_der(true);
217
218 Ok(ExtendedKey {
219 key: child_pubkey,
220 chain_code: ir.to_vec(),
221 depth: self.depth + 1,
222 parent_fingerprint: fingerprint,
223 child_index: index,
224 version: XPUB_VERSION,
225 is_private: false,
226 })
227 }
228 }
229
230 pub fn to_public(&self) -> Result<Self, CompatError> {
234 if !self.is_private {
235 return Ok(self.clone());
236 }
237
238 let pubkey_bytes = self.compressed_pubkey_bytes()?;
239
240 Ok(ExtendedKey {
241 key: pubkey_bytes,
242 chain_code: self.chain_code.clone(),
243 depth: self.depth,
244 parent_fingerprint: self.parent_fingerprint,
245 child_index: self.child_index,
246 version: XPUB_VERSION,
247 is_private: false,
248 })
249 }
250
251 pub fn to_base58(&self) -> String {
256 let mut payload = Vec::with_capacity(78);
257 payload.extend_from_slice(&self.version);
258 payload.push(self.depth);
259 payload.extend_from_slice(&self.parent_fingerprint);
260 payload.extend_from_slice(&self.child_index.to_be_bytes());
261 payload.extend_from_slice(&self.chain_code);
262
263 if self.is_private {
264 payload.push(0x00);
265 let padded = self.padded_key_bytes(32);
266 payload.extend_from_slice(&padded);
267 } else {
268 payload.extend_from_slice(&self.key);
269 }
270
271 assert_eq!(payload.len(), 78, "BIP32 payload must be exactly 78 bytes");
272
273 let checksum = hash256(&payload);
276 payload.extend_from_slice(&checksum[..4]);
277
278 base58_encode(&payload)
279 }
280
281 pub fn from_string(s: &str) -> Result<Self, CompatError> {
283 let decoded = base58_decode(s)
284 .map_err(|e| CompatError::InvalidExtendedKey(format!("base58 decode: {}", e)))?;
285
286 if decoded.len() != 82 {
287 return Err(CompatError::InvalidExtendedKey(format!(
288 "expected 82 bytes, got {}",
289 decoded.len()
290 )));
291 }
292
293 let payload = &decoded[..78];
295 let checksum = &decoded[78..82];
296 let expected_checksum = hash256(payload);
297 if checksum != &expected_checksum[..4] {
298 return Err(CompatError::ChecksumMismatch);
299 }
300
301 let mut version = [0u8; 4];
302 version.copy_from_slice(&payload[0..4]);
303
304 let is_private = if version == XPRV_VERSION {
305 true
306 } else if version == XPUB_VERSION {
307 false
308 } else {
309 return Err(CompatError::InvalidMagic);
310 };
311
312 let depth = payload[4];
313 let mut parent_fingerprint = [0u8; 4];
314 parent_fingerprint.copy_from_slice(&payload[5..9]);
315 let child_index = u32::from_be_bytes([payload[9], payload[10], payload[11], payload[12]]);
316 let chain_code = payload[13..45].to_vec();
317
318 let key = if is_private {
319 if payload[45] != 0x00 {
321 return Err(CompatError::InvalidExtendedKey(
322 "private key must start with 0x00".to_string(),
323 ));
324 }
325 payload[46..78].to_vec()
326 } else {
327 payload[45..78].to_vec()
329 };
330
331 Ok(ExtendedKey {
332 key,
333 chain_code,
334 depth,
335 parent_fingerprint,
336 child_index,
337 version,
338 is_private,
339 })
340 }
341
342 pub fn public_key(&self) -> Result<PublicKey, CompatError> {
346 if self.is_private {
347 let priv_key = PrivateKey::from_bytes(&self.key)?;
348 Ok(priv_key.to_public_key())
349 } else {
350 Ok(PublicKey::from_der_bytes(&self.key)?)
351 }
352 }
353
354 pub fn is_private(&self) -> bool {
356 self.is_private
357 }
358
359 pub fn depth(&self) -> u8 {
361 self.depth
362 }
363
364 fn compressed_pubkey_bytes(&self) -> Result<Vec<u8>, CompatError> {
370 if self.is_private {
371 let priv_key = PrivateKey::from_bytes(&self.key)?;
372 Ok(priv_key.to_public_key().to_der())
373 } else {
374 Ok(self.key.clone())
375 }
376 }
377
378 fn padded_key_bytes(&self, len: usize) -> Vec<u8> {
380 if self.key.len() >= len {
381 self.key[self.key.len() - len..].to_vec()
382 } else {
383 let mut padded = vec![0u8; len - self.key.len()];
384 padded.extend_from_slice(&self.key);
385 padded
386 }
387 }
388}
389
390impl std::fmt::Display for ExtendedKey {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 write!(f, "{}", self.to_base58())
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Decodes a hex fixture string two characters at a time.
    /// Panics on malformed input, which is fine for fixed test data.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        let mut bytes = Vec::new();
        let mut pos = 0;
        while pos < hex.len() {
            bytes.push(u8::from_str_radix(&hex[pos..pos + 2], 16).unwrap());
            pos += 2;
        }
        bytes
    }

    /// One derivation path with its expected serializations.
    #[derive(Deserialize)]
    struct ChainVector {
        path: String,
        xprv: String,
        xpub: String,
    }

    /// A seed and the chains derived from it.
    #[derive(Deserialize)]
    struct SeedVector {
        seed: String,
        chains: Vec<ChainVector>,
    }

    /// Top-level shape of the fixture file.
    #[derive(Deserialize)]
    struct Bip32Vectors {
        vectors: Vec<SeedVector>,
    }

    /// Loads the bundled BIP32 fixture file.
    fn load_vectors() -> Bip32Vectors {
        serde_json::from_str(include_str!("../../test-vectors/bip32_vectors.json")).unwrap()
    }

    /// Builds the master key for one seed vector.
    fn master_from(vector: &SeedVector) -> ExtendedKey {
        let seed = hex_to_bytes(&vector.seed);
        ExtendedKey::from_seed(&seed).unwrap()
    }

    /// Checks that `key` serializes to the chain's xprv and its neutered form to the xpub.
    fn assert_matches_chain(key: &ExtendedKey, chain: &ChainVector) {
        assert_eq!(key.to_string(), chain.xprv, "xprv mismatch for path {}", chain.path);
        assert_eq!(
            key.to_public().unwrap().to_string(),
            chain.xpub,
            "xpub mismatch for path {}",
            chain.path
        );
    }

    /// Derives every chain listed under seed vector `idx` from its master key.
    fn check_full_derivation(idx: usize) {
        let vectors = load_vectors();
        let vector = &vectors.vectors[idx];
        let master = master_from(vector);
        for chain in &vector.chains {
            let derived = master.derive(&chain.path).unwrap();
            assert_matches_chain(&derived, chain);
        }
    }

    #[test]
    fn test_vector1_master_key() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[0];
        assert_matches_chain(&master_from(vector), &vector.chains[0]);
    }

    #[test]
    fn test_vector2_master_key() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[1];
        assert_matches_chain(&master_from(vector), &vector.chains[0]);
    }

    #[test]
    fn test_vector1_hardened_child() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[0];
        let child = master_from(vector).derive("m/0'").unwrap();
        assert_matches_chain(&child, &vector.chains[1]);
    }

    #[test]
    fn test_vector1_full_derivation() {
        check_full_derivation(0);
    }

    #[test]
    fn test_vector2_full_derivation() {
        check_full_derivation(1);
    }

    #[test]
    fn test_to_public() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[0];
        let public = master_from(vector).to_public().unwrap();
        assert!(!public.is_private());
        assert_eq!(public.to_string(), vector.chains[0].xpub);
    }

    #[test]
    fn test_from_string_round_trip() {
        let vectors = load_vectors();
        let first = &vectors.vectors[0].chains[0];
        for encoded in [&first.xprv, &first.xpub] {
            let parsed = ExtendedKey::from_string(encoded).unwrap();
            assert_eq!(&parsed.to_string(), encoded);
        }
    }

    #[test]
    fn test_public_derivation() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[0];
        let child_pub = master_from(vector)
            .derive("m/0'")
            .unwrap()
            .to_public()
            .unwrap();

        // Normal (non-hardened) children are derivable from the public key alone.
        let grandchild_pub = child_pub.derive("m/1").unwrap();
        assert_eq!(
            grandchild_pub.to_string(),
            vector.chains[2].xpub,
            "public derivation of normal child should match"
        );
    }

    #[test]
    fn test_hardened_from_public_error() {
        let vectors = load_vectors();
        let public = master_from(&vectors.vectors[0]).to_public().unwrap();

        let err = public
            .derive("m/0'")
            .expect_err("hardened from public should fail");
        assert!(
            matches!(err, CompatError::HardenedFromPublic),
            "expected HardenedFromPublic, got {:?}",
            err
        );
    }

    #[test]
    fn test_depth_exceeded() {
        let master = ExtendedKey::from_seed(&hex_to_bytes("000102030405060708090a0b0c0d0e0f"))
            .unwrap();

        // Same key material, but already at the maximum representable depth.
        let deep_key = ExtendedKey {
            key: master.key.clone(),
            chain_code: master.chain_code.clone(),
            depth: 255,
            parent_fingerprint: [0; 4],
            child_index: 0,
            version: XPRV_VERSION,
            is_private: true,
        };

        let err = deep_key
            .derive_child(0)
            .expect_err("depth 255 derivation should fail");
        assert!(
            matches!(err, CompatError::DepthExceeded),
            "expected DepthExceeded, got {:?}",
            err
        );
    }

    #[test]
    fn test_all_vectors_from_string_round_trip() {
        let vectors = load_vectors();
        for chain in vectors.vectors.iter().flat_map(|v| v.chains.iter()) {
            let priv_key = ExtendedKey::from_string(&chain.xprv).unwrap();
            assert_eq!(
                priv_key.to_string(),
                chain.xprv,
                "xprv round-trip failed for {}",
                chain.path
            );

            let pub_key = ExtendedKey::from_string(&chain.xpub).unwrap();
            assert_eq!(
                pub_key.to_string(),
                chain.xpub,
                "xpub round-trip failed for {}",
                chain.path
            );
        }
    }

    #[test]
    fn test_public_derivation_deep() {
        let vectors = load_vectors();
        let vector = &vectors.vectors[0];
        let child_pub = master_from(vector)
            .derive("m/0'/1/2'/2")
            .unwrap()
            .to_public()
            .unwrap();

        let grandchild_pub = child_pub.derive("m/1000000000").unwrap();
        assert_eq!(
            grandchild_pub.to_string(),
            vector.chains[5].xpub,
            "public derivation of deep normal child should match"
        );
    }
}