1use hmac::{Hmac, KeyInit, Mac};
16use ripemd::Ripemd160;
17use sha2::{Digest, Sha256, Sha512};
18
19use elliptic_curve::ops::Reduce;
20use rustcrypto_ff::PrimeField;
21use rustcrypto_group::prime::PrimeCurveAffine;
22use rustcrypto_group::Curve;
23use zeroize::{Zeroize, ZeroizeOnDrop};
24
25use crate::curve::DklsCurve;
26use crate::protocols::{Party, PublicKeyPackage};
27use crate::utilities::hashes::point_to_bytes;
28
/// Size in bytes of a chain code (the right half of the 64-byte HMAC-SHA512 output).
pub const CHAIN_CODE_LEN: usize = 32;

/// Chain code carried alongside a key and used as the HMAC key for child derivation.
pub type ChainCode = [u8; CHAIN_CODE_LEN];

/// Parent identifier: the first four bytes of `RIPEMD160(SHA256(serialized parent pk))`.
pub type Fingerprint = [u8; 4];
38
/// Error returned when a derivation step or a derivation-path parse is rejected.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ErrorDeriv {
    /// Human-readable reason the derivation was rejected.
    pub description: String,
}

impl ErrorDeriv {
    /// Creates an error whose `description` is a copy of `description`.
    #[must_use]
    pub fn new(description: &str) -> ErrorDeriv {
        ErrorDeriv {
            description: String::from(description),
        }
    }
}

/// Formats the error as its bare `description`, so `{}` and `.to_string()` work.
impl std::fmt::Display for ErrorDeriv {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.description)
    }
}

/// Lets `ErrorDeriv` be boxed as `dyn Error` and propagated with `?` by callers.
impl std::error::Error for ErrorDeriv {}
55
/// Per-party derivation state for one node of a non-hardened BIP-32-style tree.
///
/// All fields except `pk` are zeroized when the value is dropped
/// (`ZeroizeOnDrop`, with `pk` marked `#[zeroize(skip)]`).
///
/// Field order is part of the serde representation; do not reorder fields.
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}`
/// prints the secret `poly_point` and the `chain_code`. Consider a redacting
/// `Debug` impl if these values can reach logs.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::AffinePoint: serde::Serialize, C::Scalar: serde::Serialize",
        deserialize = "C::AffinePoint: serde::Deserialize<'de>, C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct DerivData<C: DklsCurve> {
    /// Number of derivation steps from the root. `DerivData::derive_child`
    /// increments it and refuses to go past `MAX_DEPTH`.
    pub depth: u8,
    /// Index passed to `derive_child` to produce this node.
    pub child_number: u32,
    /// First four bytes of HASH160 of the parent's serialized public key.
    pub parent_fingerprint: Fingerprint,
    /// This party's secret key share; `derive_child` adds the derivation tweak to it.
    pub poly_point: C::Scalar,
    // Public data, so it is excluded from zeroization.
    #[zeroize(skip)]
    /// Joint public key of this node; `derive_child` adds `tweak * G` to it.
    pub pk: C::AffinePoint,
    /// Chain code used as the HMAC-SHA512 key when deriving children of this node.
    pub chain_code: ChainCode,
}
87
88pub const MAX_DEPTH: u8 = u8::MAX;
90pub const MAX_CHILD_NUMBER: u32 = i32::MAX as u32; impl<C: DklsCurve> DerivData<C> {
96 pub fn child_tweak(
106 &self,
107 child_number: u32,
108 ) -> Result<(C::Scalar, ChainCode, Fingerprint), ErrorDeriv> {
109 let mut hmac_engine =
110 Hmac::<Sha512>::new_from_slice(&self.chain_code).expect("HMAC accepts any key length");
111
112 let pk_as_bytes = point_to_bytes::<C>(&self.pk);
113 hmac_engine.update(&pk_as_bytes);
114 hmac_engine.update(&child_number.to_be_bytes());
115
116 let hmac_result = hmac_engine.finalize().into_bytes();
117 let hmac_bytes: [u8; 64] = hmac_result.into();
118
119 let tweak_bytes: [u8; CHAIN_CODE_LEN] = hmac_bytes[..CHAIN_CODE_LEN]
122 .try_into()
123 .expect("Half of hmac is guaranteed to be 32 bytes!");
124 let field_bytes: &elliptic_curve::FieldBytes<C> = tweak_bytes
125 .as_slice()
126 .try_into()
127 .expect("tweak_bytes length matches field size");
128 let tweak = <C::Scalar as Reduce<elliptic_curve::FieldBytes<C>>>::reduce(field_bytes);
129 if tweak.to_repr() != *field_bytes {
130 return Err(ErrorDeriv::new(
131 "BIP-32: HMAC left half >= curve order, child index is invalid",
132 ));
133 }
134
135 let chain_code: ChainCode = hmac_bytes[CHAIN_CODE_LEN..]
136 .try_into()
137 .expect("Half of hmac is guaranteed to be 32 bytes!");
138
139 let hash_result = Ripemd160::digest(Sha256::digest(&pk_as_bytes));
142 let fingerprint: Fingerprint = hash_result[0..4]
143 .try_into()
144 .expect("4 is the fingerprint length!");
145
146 Ok((tweak, chain_code, fingerprint))
147 }
148
149 pub fn derive_child(&self, child_number: u32) -> Result<DerivData<C>, ErrorDeriv> {
157 if self.depth == MAX_DEPTH {
158 return Err(ErrorDeriv::new("We are already at maximum depth!"));
159 }
160
161 if child_number > MAX_CHILD_NUMBER {
162 return Err(ErrorDeriv::new(
163 "Child index should be between 0 and 2^31 - 1!",
164 ));
165 }
166
167 let (tweak, new_chain_code, parent_fingerprint) = self.child_tweak(child_number)?;
168
169 let new_poly_point = self.poly_point + tweak;
173 let generator = <C::AffinePoint as PrimeCurveAffine>::generator();
174 let new_pk = ((generator * tweak) + self.pk).to_affine();
175
176 let identity = <C::AffinePoint as PrimeCurveAffine>::identity();
177 if new_pk == identity {
178 return Err(ErrorDeriv::new(
179 "Very improbable: Child index results in value not allowed by BIP-32!",
180 ));
181 }
182
183 Ok(DerivData {
184 depth: self.depth + 1,
185 child_number,
186 parent_fingerprint,
187 poly_point: new_poly_point,
188 pk: new_pk,
189 chain_code: new_chain_code,
190 })
191 }
192
193 pub fn derive_from_path(&self, path: &str) -> Result<DerivData<C>, ErrorDeriv> {
204 let path_parsed = parse_path(path)?;
205
206 let mut final_data = self.clone();
207 for child_number in path_parsed {
208 final_data = final_data.derive_child(child_number)?;
209 }
210
211 Ok(final_data)
212 }
213}
214
215impl<C: DklsCurve> Party<C> {
219 pub fn derive_child(
227 &self,
228 child_number: u32,
229 address_fn: impl Fn(&C::AffinePoint) -> String,
230 ) -> Result<Party<C>, ErrorDeriv> {
231 let new_derivation_data = self.derivation_data.derive_child(child_number)?;
232
233 let new_address = address_fn(&new_derivation_data.pk);
236
237 Ok(Party {
238 parameters: self.parameters.clone(),
239 party_index: self.party_index,
240 session_id: self.session_id.clone(),
241
242 poly_point: new_derivation_data.poly_point,
243 pk: new_derivation_data.pk,
244
245 zero_share: self.zero_share.clone(),
246
247 mul_senders: self.mul_senders.clone(),
248 mul_receivers: self.mul_receivers.clone(),
249
250 derivation_data: new_derivation_data,
251
252 address: new_address,
253 })
254 }
255
256 pub fn derive_from_path(
267 &self,
268 path: &str,
269 address_fn: impl Fn(&C::AffinePoint) -> String,
270 ) -> Result<Party<C>, ErrorDeriv> {
271 let new_derivation_data = self.derivation_data.derive_from_path(path)?;
272
273 let new_address = address_fn(&new_derivation_data.pk);
276
277 Ok(Party {
278 parameters: self.parameters.clone(),
279 party_index: self.party_index,
280 session_id: self.session_id.clone(),
281
282 poly_point: new_derivation_data.poly_point,
283 pk: new_derivation_data.pk,
284
285 zero_share: self.zero_share.clone(),
286
287 mul_senders: self.mul_senders.clone(),
288 mul_receivers: self.mul_receivers.clone(),
289
290 derivation_data: new_derivation_data,
291
292 address: new_address,
293 })
294 }
295}
296
297impl<C: DklsCurve> PublicKeyPackage<C> {
299 pub fn derive_child(
306 &self,
307 chain_code: &ChainCode,
308 child_number: u32,
309 ) -> Result<PublicKeyPackage<C>, ErrorDeriv> {
310 if child_number > MAX_CHILD_NUMBER {
311 return Err(ErrorDeriv::new(
312 "Child index should be between 0 and 2^31 - 1!",
313 ));
314 }
315
316 let mut hmac_engine =
317 Hmac::<Sha512>::new_from_slice(chain_code).expect("HMAC accepts any key length");
318 let pk_as_bytes = point_to_bytes::<C>(self.verifying_key());
319 hmac_engine.update(&pk_as_bytes);
320 hmac_engine.update(&child_number.to_be_bytes());
321 let hmac_result = hmac_engine.finalize().into_bytes();
322 let hmac_bytes: [u8; 64] = hmac_result.into();
323
324 let tweak_bytes: [u8; CHAIN_CODE_LEN] = hmac_bytes[..CHAIN_CODE_LEN]
325 .try_into()
326 .expect("Half of hmac is guaranteed to be 32 bytes!");
327 let field_bytes: &elliptic_curve::FieldBytes<C> = tweak_bytes
328 .as_slice()
329 .try_into()
330 .expect("tweak_bytes length matches field size");
331 let tweak = <C::Scalar as Reduce<elliptic_curve::FieldBytes<C>>>::reduce(field_bytes);
332 if tweak.to_repr() != *field_bytes {
333 return Err(ErrorDeriv::new(
334 "BIP-32: HMAC left half >= curve order, child index is invalid",
335 ));
336 }
337
338 let generator = <C::AffinePoint as PrimeCurveAffine>::generator();
339 let tweak_point = generator * tweak;
340 let new_verifying_key = (tweak_point + *self.verifying_key()).to_affine();
341
342 let identity = <C::AffinePoint as PrimeCurveAffine>::identity();
343 if new_verifying_key == identity {
344 return Err(ErrorDeriv::new(
345 "Very improbable: Child index results in value not allowed by BIP-32!",
346 ));
347 }
348
349 let new_verifying_shares = self
350 .verifying_shares
351 .iter()
352 .map(|(party, share)| (*party, (tweak_point + *share).to_affine()))
353 .collect();
354
355 Ok(PublicKeyPackage::new(
356 new_verifying_key,
357 new_verifying_shares,
358 self.parameters.clone(),
359 ))
360 }
361}
362
363pub fn parse_path(path: &str) -> Result<Vec<u32>, ErrorDeriv> {
370 let mut parts = path.split('/');
371
372 if parts.next().unwrap_or_default() != "m" {
373 return Err(ErrorDeriv::new("Invalid path format!"));
374 }
375
376 let mut path_parsed = Vec::new();
377
378 for part in parts {
379 match part.parse::<u32>() {
380 Ok(num) if num <= MAX_CHILD_NUMBER => path_parsed.push(num),
381 _ => {
382 return Err(ErrorDeriv::new(
383 "Invalid path format or index out of bounds!",
384 ))
385 }
386 }
387 }
388
389 if path_parsed.len() > MAX_DEPTH as usize {
390 return Err(ErrorDeriv::new("The path is too long!"));
391 }
392
393 Ok(path_parsed)
394}
395
#[cfg(test)]
mod tests {

    use super::*;

    type TestCurve = k256::Secp256k1;

    use crate::protocols::re_key::re_key;
    use crate::protocols::signing::*;
    use crate::protocols::{Parameters, PartyIndex};

    use crate::utilities::hashes::*;

    use crate::utilities::rng;
    use elliptic_curve::CurveArithmetic;
    use hex;
    use k256::elliptic_curve::ops::Reduce;
    use k256::elliptic_curve::Field;
    use k256::{AffinePoint, Scalar, U256};
    use rand::RngExt;
    use std::collections::BTreeMap;

    /// Address callback for tests that never inspect `Party::address`:
    /// always returns an empty string.
    fn no_address(_pk: &<TestCurve as CurveArithmetic>::AffinePoint) -> String {
        String::new()
    }

    /// Fixed-vector check of `DerivData::derive_from_path`.
    ///
    /// Starts from a hard-coded secp256k1 secret key and chain code at depth 0,
    /// derives `m/0/1/2/3`, and compares depth, child number, parent
    /// fingerprint, secret, public key and chain code against expected hex.
    #[test]
    fn test_derivation() {
        let sk = Scalar::reduce(&U256::from_be_hex(
            "6728f18f7163f7a0c11cc0ad53140afb4e345d760f966176865a860041549903",
        ));
        let pk = (AffinePoint::GENERATOR * sk).to_affine();
        let chain_code: ChainCode =
            hex::decode("6f990adb9337033001af2487a8617f68586c4ea17433492bbf1659f6e4cf9564")
                .unwrap()
                .try_into()
                .unwrap();

        let data = DerivData::<TestCurve> {
            depth: 0,
            child_number: 0,
            parent_fingerprint: [0u8; 4],
            poly_point: sk,
            pk,
            chain_code,
        };

        let path = "m/0/1/2/3";
        let try_derive = data.derive_from_path(path);

        match try_derive {
            Err(error) => {
                panic!("Error: {:?}", error.description);
            }
            Ok(child) => {
                assert_eq!(child.depth, 4);
                assert_eq!(child.child_number, 3);
                assert_eq!(hex::encode(child.parent_fingerprint), "9502bb8b");
                assert_eq!(
                    hex::encode(scalar_to_bytes::<TestCurve>(&child.poly_point)),
                    "bdebf4ed48fae0b5b3ed6671496f7e1d741996dbb30d79f990933892c8ed316a"
                );
                assert_eq!(
                    hex::encode(point_to_bytes::<TestCurve>(&child.pk)),
                    "037c892dca96d4c940aafb3a1e65f470e43fba57b3146efeb312c2a39a208fffaa"
                );
                assert_eq!(
                    hex::encode(child.chain_code),
                    "c6536c2f5c232aa7613652831b7a3b21e97f4baa3114a3837de3764759f5b2aa"
                );
            }
        }
    }

    /// End-to-end check that derived parties can still sign together.
    ///
    /// A random secret key is split with `re_key` under random parameters
    /// (threshold in `2..=5`, plus `0..=5` extra parties). Every party is
    /// derived along `m/0/1/2/3`, then signing phases 1-4 run among parties
    /// `1..=threshold`. Every phase must succeed, and all phase-3
    /// x-coordinates must agree.
    #[test]
    fn test_derivation_and_signing() {
        let threshold = rng::get_rng().random_range(2..=5);
        let offset = rng::get_rng().random_range(0..=5);

        let parameters = Parameters {
            threshold,
            share_count: threshold + offset,
        };
        let session_id = rng::get_rng().random::<[u8; crate::utilities::ID_LEN]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, _) =
            re_key::<TestCurve>(&parameters, &session_id, &secret_key, None, no_address);

        let path = "m/0/1/2/3";

        // Derive every party along the same path.
        let mut derived_parties: Vec<Party<TestCurve>> =
            Vec::with_capacity(parameters.share_count as usize);
        for i in 0..parameters.share_count {
            let result = parties[i as usize].derive_from_path(path, no_address);
            match result {
                Err(error) => {
                    panic!("Error for Party {}: {:?}", i, error.description);
                }
                Ok(party) => {
                    derived_parties.push(party);
                }
            }
        }

        let parties = derived_parties;

        let sign_id = rng::get_rng().random::<[u8; crate::utilities::ID_LEN]>();
        let message_to_sign = tagged_hash(b"test-sign", &[b"Message to sign!"]);

        // The first `threshold` parties (1-based indices) take part in signing.
        let executing_parties: Vec<PartyIndex> = (1..=parameters.threshold)
            .map(|i| PartyIndex::new(i).unwrap())
            .collect();

        // Each executing party's counterparties are all the other executing parties.
        let mut all_data: BTreeMap<PartyIndex, SignData> = BTreeMap::new();
        for party_index in executing_parties.clone() {
            let mut counterparties = executing_parties.clone();
            counterparties.retain(|index| *index != party_index);

            all_data.insert(
                party_index,
                SignData {
                    sign_id: sign_id.to_vec(),
                    counterparties,
                    message_hash: message_to_sign,
                },
            );
        }

        // Phase 1.
        let mut unique_kept_1to2: BTreeMap<PartyIndex, UniqueKeep1to2<TestCurve>> = BTreeMap::new();
        let mut kept_1to2: BTreeMap<PartyIndex, BTreeMap<PartyIndex, KeepPhase1to2<TestCurve>>> =
            BTreeMap::new();
        let mut transmit_1to2: BTreeMap<PartyIndex, Vec<TransmitPhase1to2>> = BTreeMap::new();
        for party_index in executing_parties.clone() {
            let (unique_keep, keep, transmit) = parties[(party_index.as_u8() - 1) as usize]
                .sign_phase1(all_data.get(&party_index).unwrap())
                .unwrap();

            unique_kept_1to2.insert(party_index, unique_keep);
            kept_1to2.insert(party_index, keep);
            transmit_1to2.insert(party_index, transmit);
        }

        // Route each phase-1 message to the party named as its receiver.
        let mut received_1to2: BTreeMap<PartyIndex, Vec<TransmitPhase1to2>> = BTreeMap::new();

        for &party_index in &executing_parties {
            let new_row: Vec<TransmitPhase1to2> = transmit_1to2
                .iter()
                .flat_map(|(_, messages)| {
                    messages
                        .iter()
                        .filter(|message| message.parties.receiver == party_index)
                        .cloned()
                })
                .collect();

            received_1to2.insert(party_index, new_row);
        }

        // Phase 2.
        let mut unique_kept_2to3: BTreeMap<PartyIndex, UniqueKeep2to3<TestCurve>> = BTreeMap::new();
        let mut kept_2to3: BTreeMap<PartyIndex, BTreeMap<PartyIndex, KeepPhase2to3<TestCurve>>> =
            BTreeMap::new();
        let mut transmit_2to3: BTreeMap<PartyIndex, Vec<TransmitPhase2to3<TestCurve>>> =
            BTreeMap::new();
        for party_index in executing_parties.clone() {
            let result = parties[(party_index.as_u8() - 1) as usize].sign_phase2(
                all_data.get(&party_index).unwrap(),
                unique_kept_1to2.get(&party_index).unwrap(),
                kept_1to2.get(&party_index).unwrap(),
                received_1to2.get(&party_index).unwrap(),
            );
            match result {
                Err(abort) => {
                    panic!("Party {} aborted: {:?}", abort.index, abort.description());
                }
                Ok((unique_keep, keep, transmit)) => {
                    unique_kept_2to3.insert(party_index, unique_keep);
                    kept_2to3.insert(party_index, keep);
                    transmit_2to3.insert(party_index, transmit);
                }
            }
        }

        // Route each phase-2 message to the party named as its receiver.
        let mut received_2to3: BTreeMap<PartyIndex, Vec<TransmitPhase2to3<TestCurve>>> =
            BTreeMap::new();

        for &party_index in &executing_parties {
            let filtered_messages: Vec<TransmitPhase2to3<TestCurve>> = transmit_2to3
                .iter()
                .flat_map(|(_, messages)| {
                    messages
                        .iter()
                        .filter(|message| message.parties.receiver == party_index)
                })
                .cloned()
                .collect();

            received_2to3.insert(party_index, filtered_messages);
        }

        // Phase 3: every executing party must report the same x-coordinate.
        let mut x_coords: Vec<String> = Vec::with_capacity(parameters.threshold as usize);
        let mut broadcast_3to4: Vec<Broadcast3to4<TestCurve>> =
            Vec::with_capacity(parameters.threshold as usize);
        for party_index in executing_parties.clone() {
            let result = parties[(party_index.as_u8() - 1) as usize].sign_phase3(
                all_data.get(&party_index).unwrap(),
                unique_kept_2to3.get(&party_index).unwrap(),
                kept_2to3.get(&party_index).unwrap(),
                received_2to3.get(&party_index).unwrap(),
            );
            match result {
                Err(abort) => {
                    panic!("Party {} aborted: {:?}", abort.index, abort.description());
                }
                Ok((x_coord, broadcast)) => {
                    x_coords.push(x_coord);
                    broadcast_3to4.push(broadcast);
                }
            }
        }

        let x_coord = x_coords[0].clone();
        for i in 1..parameters.threshold {
            assert_eq!(x_coord, x_coords[i as usize]);
        }

        // Phase 4 is run by a single party over all broadcasts.
        // NOTE(review): the final `true` argument is defined in the signing
        // module (presumably it requests verification of the result) — confirm.
        let some_index = executing_parties[0];
        let result = parties[(some_index.as_u8() - 1) as usize].sign_phase4(
            all_data.get(&some_index).unwrap(),
            &x_coord,
            &broadcast_3to4,
            true,
        );
        if let Err(abort) = result {
            panic!("Party {} aborted: {:?}", abort.index, abort.description());
        }
    }

    /// Checks that deriving a `PublicKeyPackage` with the parties' chain code
    /// matches deriving each `Party` directly.
    ///
    /// The derived verifying key must equal each derived party's `pk`, and
    /// `G * poly_point` of each derived party must verify as that party's share.
    #[test]
    fn test_public_key_package_derive_child() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 3,
        };
        let session_id = rng::get_rng().random::<[u8; crate::utilities::ID_LEN]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, pkg) =
            re_key::<TestCurve>(&parameters, &session_id, &secret_key, None, no_address);

        let chain_code = parties[0].derivation_data.chain_code;
        const TEST_CHILD_NUMBER: u32 = 42;
        let child_number = TEST_CHILD_NUMBER;

        let derived_pkg = pkg.derive_child(&chain_code, child_number).unwrap();

        for party in &parties {
            let derived_party = party.derive_child(child_number, no_address).unwrap();
            assert_eq!(*derived_pkg.verifying_key(), derived_party.pk);

            let expected_share = (AffinePoint::GENERATOR * derived_party.poly_point).to_affine();
            assert!(derived_pkg.verify_share(party.party_index, &expected_share));
        }
    }
}