use curve25519_dalek::constants;
use curve25519_dalek::scalar::Scalar;
use super::*;
use crate::context::{SigningTranscript};
/// Length in bytes of the array wrapped by [`ChainCode`].
pub const CHAIN_CODE_LENGTH: usize = 32;
/// Fixed-size byte string that accompanies a key through derivation.
///
/// Every derivation routine in this module takes a parent chain code as
/// input and returns a fresh one next to the derived key, so it can be fed
/// into the next derivation step.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ChainCode(pub [u8; CHAIN_CODE_LENGTH]);
/// Key types from which a child key and a new chain code can be derived.
///
/// Implemented in this module for `Keypair`, `SecretKey` and `PublicKey`.
pub trait Derivation: Sized {
    /// Derives a child key from `self`, the transcript `t` and the parent
    /// chain code `cc`.
    ///
    /// Returns the derived key together with the chain code to use for any
    /// further derivation from it.
    fn derived_key<T>(&self, t: T, cc: ChainCode) -> (Self, ChainCode)
    where
        T: SigningTranscript;

    /// Convenience wrapper around [`Derivation::derived_key`] for callers
    /// that only have raw bytes `i` identifying the child.
    ///
    /// A fresh `merlin` transcript labelled `SchnorrRistrettoHDKD` is
    /// created, `i` is appended under the label `sign-bytes`, and the
    /// result is handed to `derived_key` with the chain code `cc`.
    fn derived_key_simple<B: AsRef<[u8]>>(&self, cc: ChainCode, i: B) -> (Self, ChainCode) {
        let mut transcript = merlin::Transcript::new(b"SchnorrRistrettoHDKD");
        transcript.append_message(b"sign-bytes", i.as_ref());
        self.derived_key(transcript, cc)
    }

    /// Same as [`Derivation::derived_key_simple`], except the transcript is
    /// first combined with the caller-supplied `rng` via
    /// `context::attach_rng`.
    ///
    /// NOTE(review): presumably this makes any randomness drawn during
    /// derivation come from `rng` — confirm against `attach_rng`.
    fn derived_key_simple_rng<B, R>(&self, cc: ChainCode, i: B, rng: R) -> (Self, ChainCode)
    where
        B: AsRef<[u8]>,
        R: RngCore + CryptoRng,
    {
        let mut transcript = merlin::Transcript::new(b"SchnorrRistrettoHDKD");
        transcript.append_message(b"sign-bytes", i.as_ref());
        let with_rng = super::context::attach_rng(transcript, rng);
        self.derived_key(with_rng, cc)
    }
}
impl PublicKey {
    /// Produces the scalar offset and the next chain code for a derivation
    /// rooted at this public key.
    ///
    /// The parent chain code `cc` and this public key are committed to the
    /// transcript `t`; a scalar is then drawn under `HDKD-scalar`, followed
    /// by 32 bytes under `HDKD-chaincode`. Only public data is used, so both
    /// the secret and the public derivation paths in this module call it.
    fn derive_scalar_and_chaincode<T>(&self, t: &mut T, cc: ChainCode) -> (Scalar, ChainCode)
    where
        T: SigningTranscript,
    {
        // Bind the inputs before any output is extracted.
        t.commit_bytes(b"chain-code", &cc.0);
        t.commit_point(b"public-key", self.as_compressed());

        // The scalar must be drawn before the chain code; swapping the two
        // challenges would change both outputs.
        let offset = t.challenge_scalar(b"HDKD-scalar");

        let mut next_cc = ChainCode([0u8; 32]);
        t.challenge_bytes(b"HDKD-chaincode", &mut next_cc.0);

        (offset, next_cc)
    }
}
impl SecretKey {
pub fn hard_derive_mini_secret_key<B: AsRef<[u8]>>(
&self,
cc: Option<ChainCode>,
i: B,
) -> (MiniSecretKey, ChainCode) {
let mut t = merlin::Transcript::new(b"SchnorrRistrettoHDKD");
t.append_message(b"sign-bytes", i.as_ref());
if let Some(c) = cc {
t.append_message(b"chain-code", &c.0);
}
t.append_message(b"secret-key", &self.key.to_bytes() as &[u8]);
let mut msk = [0u8; MINI_SECRET_KEY_LENGTH];
t.challenge_bytes(b"HDKD-hard", &mut msk);
let mut chaincode = [0u8; 32];
t.challenge_bytes(b"HDKD-chaincode", &mut chaincode);
(MiniSecretKey(msk), ChainCode(chaincode))
}
}
impl MiniSecretKey {
    /// Derives a new `MiniSecretKey` and chain code from this mini secret
    /// key.
    ///
    /// The key is first expanded with `mode`; the derivation itself is then
    /// delegated to `hard_derive_mini_secret_key` on the expanded key, with
    /// `cc` and `i` passed through unchanged.
    pub fn hard_derive_mini_secret_key<B: AsRef<[u8]>>(
        &self,
        cc: Option<ChainCode>,
        i: B,
        mode: ExpansionMode,
    ) -> (MiniSecretKey, ChainCode) {
        let expanded = self.expand(mode);
        expanded.hard_derive_mini_secret_key(cc, i)
    }
}
impl Keypair {
    /// Derives a new `MiniSecretKey` and chain code from the secret half of
    /// this keypair.
    ///
    /// Pure delegation to `hard_derive_mini_secret_key` on `self.secret`.
    pub fn hard_derive_mini_secret_key<B: AsRef<[u8]>>(
        &self,
        cc: Option<ChainCode>,
        i: B,
    ) -> (MiniSecretKey, ChainCode) {
        let secret = &self.secret;
        secret.hard_derive_mini_secret_key(cc, i)
    }

    /// Derives a child `SecretKey` and chain code from this keypair.
    ///
    /// The scalar offset and new chain code come from
    /// `derive_scalar_and_chaincode` on the public half, so they depend only
    /// on public data. The child's key is the parent's key plus that offset.
    /// The child's nonce is produced by `witness_bytes` under the label
    /// `HDKD-nonce`, seeded with the parent's nonce and the parent secret
    /// key's byte encoding.
    pub fn derive_secret_key<T>(&self, mut t: T, cc: ChainCode) -> (SecretKey, ChainCode)
    where
        T: SigningTranscript,
    {
        let (offset, chaincode) = self.public.derive_scalar_and_chaincode(&mut t, cc);

        // The nonce is drawn from the transcript only after the scalar and
        // chain code have been extracted.
        let secret_bytes = self.secret.to_bytes();
        let mut nonce = [0u8; 32];
        t.witness_bytes(
            b"HDKD-nonce",
            &mut nonce,
            &[&self.secret.nonce, &secret_bytes as &[u8]],
        );

        let key = self.secret.key + offset;
        (SecretKey { key, nonce }, chaincode)
    }
}
impl Derivation for Keypair {
    /// Derives a child keypair: the secret half comes from
    /// `derive_secret_key`, and the public half is recomputed from that
    /// secret.
    fn derived_key<T>(&self, t: T, cc: ChainCode) -> (Keypair, ChainCode)
    where
        T: SigningTranscript,
    {
        let (secret, next_cc) = self.derive_secret_key(t, cc);
        // `public` is listed first so the secret is borrowed before it is
        // moved into the struct.
        let child = Keypair { public: secret.to_public(), secret };
        (child, next_cc)
    }
}
impl Derivation for SecretKey {
    /// Derives a child secret key by converting a clone of this key into a
    /// keypair and delegating to `Keypair::derive_secret_key`.
    fn derived_key<T>(&self, t: T, cc: ChainCode) -> (SecretKey, ChainCode)
    where
        T: SigningTranscript,
    {
        // The derivation needs the public key, which the keypair carries.
        let keypair = self.clone().to_keypair();
        keypair.derive_secret_key(t, cc)
    }
}
impl Derivation for PublicKey {
    /// Derives a child public key without access to any secret.
    ///
    /// The scalar offset from `derive_scalar_and_chaincode` is multiplied by
    /// the Ristretto basepoint and added to this key's point. This mirrors
    /// `Keypair::derive_secret_key`, which adds the same offset to the
    /// secret scalar.
    fn derived_key<T>(&self, mut t: T, cc: ChainCode) -> (PublicKey, ChainCode)
    where
        T: SigningTranscript,
    {
        let (offset, next_cc) = self.derive_scalar_and_chaincode(&mut t, cc);
        let shift = &offset * constants::RISTRETTO_BASEPOINT_TABLE;
        let child_point = self.as_point() + shift;
        (PublicKey::from_point(child_point), next_cc)
    }
}
/// A key bundled with the chain code used to derive children from it.
///
/// `K` is the key type; the derivation helpers on this struct are available
/// when `K` implements [`Derivation`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ExtendedKey<K> {
    /// The key itself.
    pub key: K,
    /// Chain code passed to the next derivation from `key`.
    pub chaincode: ChainCode,
}
impl<K: Derivation> ExtendedKey<K> {
    /// Derives a child extended key from the transcript `t`, using the
    /// chain code stored in `self`.
    ///
    /// The returned value carries the derived key and the new chain code.
    pub fn derived_key<T>(&self, t: T) -> ExtendedKey<K>
    where
        T: SigningTranscript,
    {
        let (child, child_cc) = self.key.derived_key(t, self.chaincode);
        ExtendedKey { key: child, chaincode: child_cc }
    }

    /// Derives a child extended key identified by the raw bytes `i`, using
    /// the chain code stored in `self`.
    ///
    /// Delegates to [`Derivation::derived_key_simple`] on the wrapped key.
    pub fn derived_key_simple<B: AsRef<[u8]>>(&self, i: B) -> ExtendedKey<K> {
        let (child, child_cc) = self.key.derived_key_simple(self.chaincode, i);
        ExtendedKey { key: child, chaincode: child_cc }
    }
}
impl ExtendedKey<SecretKey> {
    /// Derives a child extended secret key through the mini-secret-key
    /// route.
    ///
    /// The wrapped secret key and stored chain code are passed to
    /// `SecretKey::hard_derive_mini_secret_key` along with `i`. The
    /// resulting mini secret key is expanded with `mode` and returned with
    /// the new chain code.
    pub fn hard_derive_mini_secret_key<B: AsRef<[u8]>>(
        &self,
        i: B,
        mode: ExpansionMode,
    ) -> ExtendedKey<SecretKey> {
        let parent_cc = Some(self.chaincode);
        let (mini, chaincode) = self.key.hard_derive_mini_secret_key(parent_cc, i);
        ExtendedKey { key: mini.expand(mode), chaincode }
    }
}
#[cfg(test)]
mod tests {
    use sha3::digest::Update;
    use sha3::Shake128;

    use super::*;

    /// Walks a 30-step derivation chain twice — once from the keypair and
    /// once from the public key alone — and checks that both paths agree on
    /// the chain code and public key at every step. On every fifth step it
    /// also checks that signatures made with the derived keypair verify
    /// under the independently derived public key, and that mismatched
    /// message/signature pairs are rejected.
    #[cfg(feature = "getrandom")]
    #[test]
    fn derive_key_public_vs_private_paths() {
        let chaincode = ChainCode([0u8; CHAIN_CODE_LENGTH]);
        let msg: &'static [u8] = b"Just some test message!";
        let mut hasher = Shake128::default().chain(msg);

        let mut csprng = rand_core::OsRng;
        let key = Keypair::generate_with(&mut csprng);

        let mut public_chain = ExtendedKey { key: key.public.clone(), chaincode };
        let mut keypair_chain = ExtendedKey { key, chaincode };

        let ctx = signing_context(b"testing testing 1 2 3");

        for round in 0..30 {
            let next_keypair = keypair_chain.derived_key_simple(msg);
            let next_public = public_chain.derived_key_simple(msg);

            assert_eq!(
                next_keypair.chaincode, next_public.chaincode,
                "Chain code derivation failed!"
            );
            assert_eq!(
                next_keypair.key.public, next_public.key,
                "Public and secret key derivation missmatch!"
            );

            keypair_chain = next_keypair;
            public_chain = next_public;
            hasher.update(b"Another");

            // Signature checks are only run on every fifth round.
            if round % 5 != 0 {
                continue;
            }

            let good_sig = keypair_chain.key.sign(ctx.xof(hasher.clone()));
            let bad_hasher = hasher.clone().chain(b"oops");
            let bad_sig = keypair_chain.key.sign(ctx.xof(bad_hasher.clone()));

            assert!(
                public_chain.key.verify(ctx.xof(hasher.clone()), &good_sig).is_ok(),
                "Verification of a valid signature failed!"
            );
            assert!(
                public_chain.key.verify(ctx.xof(hasher.clone()), &bad_sig).is_err(),
                "Verification of a signature on a different message passed!"
            );
            assert!(
                public_chain.key.verify(ctx.xof(bad_hasher), &good_sig).is_err(),
                "Verification of a signature on a different message passed!"
            );
        }
    }
}