use crate::error::{Error, Result};
use crate::primitives::hash::sha256_hmac;
use crate::primitives::{PrivateKey, PublicKey, SymmetricKey};
use super::types::{validate_key_id, validate_protocol_name, Counterparty, Protocol};
#[cfg(test)]
use super::types::SecurityLevel;
/// Derives child keys, symmetric keys, and shared secrets from a single root private key.
///
/// `Debug` is implemented by hand for this type and prints only the identity public key,
/// so the root private key does not show up in debug output.
#[derive(Clone)]
pub struct KeyDeriver {
    // Root private key; every derivation performed by this type starts from it.
    root_key: PrivateKey,
}
impl KeyDeriver {
pub fn new(root_key: Option<PrivateKey>) -> Self {
let root_key = root_key.unwrap_or_else(|| Self::anyone_key().0);
Self { root_key }
}
pub fn anyone_key() -> (PrivateKey, PublicKey) {
let mut key_bytes = [0u8; 32];
key_bytes[31] = 1; let private_key = PrivateKey::from_bytes(&key_bytes).expect("valid key");
let public_key = private_key.public_key();
(private_key, public_key)
}
pub fn root_key(&self) -> &PrivateKey {
&self.root_key
}
pub fn identity_key(&self) -> PublicKey {
self.root_key.public_key()
}
pub fn identity_key_hex(&self) -> String {
self.identity_key().to_hex()
}
pub fn derive_public_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
for_self: bool,
) -> Result<PublicKey> {
let counterparty_key = self.normalize_counterparty(counterparty)?;
let invoice_number = self.compute_invoice_number(protocol, key_id)?;
if for_self {
let derived = self
.root_key
.derive_child(&counterparty_key, &invoice_number)?;
Ok(derived.public_key())
} else {
counterparty_key.derive_child(&self.root_key, &invoice_number)
}
}
pub fn derive_private_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
) -> Result<PrivateKey> {
let counterparty_key = self.normalize_counterparty(counterparty)?;
let invoice_number = self.compute_invoice_number(protocol, key_id)?;
self.root_key
.derive_child(&counterparty_key, &invoice_number)
}
pub fn derive_private_key_raw(
&self,
invoice_number: &str,
counterparty: &Counterparty,
) -> Result<PrivateKey> {
let counterparty_key = self.normalize_counterparty(counterparty)?;
self.root_key
.derive_child(&counterparty_key, invoice_number)
}
pub fn derive_symmetric_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
) -> Result<SymmetricKey> {
let actual_counterparty = match counterparty {
Counterparty::Anyone => {
let (_, anyone_pub) = Self::anyone_key();
Counterparty::Other(anyone_pub)
}
other => other.clone(),
};
let derived_public =
self.derive_public_key(protocol, key_id, &actual_counterparty, false)?;
let derived_private = self.derive_private_key(protocol, key_id, &actual_counterparty)?;
let shared_secret = derived_private.derive_shared_secret(&derived_public)?;
let x_bytes = shared_secret.x();
SymmetricKey::from_bytes(&x_bytes)
}
pub fn reveal_specific_secret(
&self,
counterparty: &Counterparty,
protocol: &Protocol,
key_id: &str,
) -> Result<Vec<u8>> {
let counterparty_key = self.normalize_counterparty(counterparty)?;
let shared_secret = self.root_key.derive_shared_secret(&counterparty_key)?;
let invoice_number = self.compute_invoice_number(protocol, key_id)?;
Ok(sha256_hmac(&shared_secret.to_compressed(), invoice_number.as_bytes()).to_vec())
}
pub fn reveal_counterparty_secret(&self, counterparty: &Counterparty) -> Result<PublicKey> {
if matches!(counterparty, Counterparty::Self_) {
return Err(Error::InvalidCounterparty(
"counterparty secrets cannot be revealed for 'self'".to_string(),
));
}
let counterparty_key = self.normalize_counterparty(counterparty)?;
let self_pub = self.root_key.public_key();
if counterparty_key == self_pub {
return Err(Error::InvalidCounterparty(
"counterparty secrets cannot be revealed if counterparty key is self".to_string(),
));
}
self.root_key.derive_shared_secret(&counterparty_key)
}
fn normalize_counterparty(&self, counterparty: &Counterparty) -> Result<PublicKey> {
match counterparty {
Counterparty::Self_ => Ok(self.root_key.public_key()),
Counterparty::Anyone => Ok(Self::anyone_key().1),
Counterparty::Other(pubkey) => Ok(pubkey.clone()),
}
}
fn compute_invoice_number(&self, protocol: &Protocol, key_id: &str) -> Result<String> {
let level = protocol.security_level.as_u8();
if level > 2 {
return Err(Error::ProtocolValidationError(
"security level must be 0, 1, or 2".to_string(),
));
}
validate_key_id(key_id)?;
let protocol_name = validate_protocol_name(&protocol.protocol_name)?;
Ok(format!("{}-{}-{}", level, protocol_name, key_id))
}
}
impl std::fmt::Debug for KeyDeriver {
    /// Shows only the hex identity (public) key; the root private key is deliberately
    /// omitted, and the output is marked non-exhaustive to signal the hidden field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity_hex = self.identity_key_hex();
        let mut out = f.debug_struct("KeyDeriver");
        out.field("identity_key", &identity_hex);
        out.finish_non_exhaustive()
    }
}
/// Key-derivation operations, abstracted so callers can depend on a trait instead of
/// the concrete [`KeyDeriver`].
pub trait KeyDeriverApi {
    /// Returns the identity (root) public key.
    fn identity_key(&self) -> PublicKey;

    /// Returns the identity key encoded with `PublicKey::to_hex`.
    fn identity_key_hex(&self) -> String {
        let identity = self.identity_key();
        identity.to_hex()
    }

    /// Derives a child public key; `for_self` selects our side or the counterparty's side.
    fn derive_public_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
        for_self: bool,
    ) -> Result<PublicKey>;

    /// Derives our child private key for a validated protocol / key ID.
    fn derive_private_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
    ) -> Result<PrivateKey>;

    /// Derives our child private key from a raw, caller-supplied invoice number.
    fn derive_private_key_raw(
        &self,
        invoice_number: &str,
        counterparty: &Counterparty,
    ) -> Result<PrivateKey>;

    /// Derives a symmetric key shared with `counterparty` for a protocol / key ID.
    fn derive_symmetric_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
    ) -> Result<SymmetricKey>;

    /// Reveals the secret bytes for a single protocol / key-ID derivation.
    fn reveal_specific_secret(
        &self,
        counterparty: &Counterparty,
        protocol: &Protocol,
        key_id: &str,
    ) -> Result<Vec<u8>>;

    /// Reveals the root-level shared secret point with `counterparty`.
    fn reveal_counterparty_secret(&self, counterparty: &Counterparty) -> Result<PublicKey>;
}
impl KeyDeriverApi for KeyDeriver {
fn identity_key(&self) -> PublicKey {
self.identity_key()
}
fn derive_public_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
for_self: bool,
) -> Result<PublicKey> {
self.derive_public_key(protocol, key_id, counterparty, for_self)
}
fn derive_private_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
) -> Result<PrivateKey> {
self.derive_private_key(protocol, key_id, counterparty)
}
fn derive_private_key_raw(
&self,
invoice_number: &str,
counterparty: &Counterparty,
) -> Result<PrivateKey> {
self.derive_private_key_raw(invoice_number, counterparty)
}
fn derive_symmetric_key(
&self,
protocol: &Protocol,
key_id: &str,
counterparty: &Counterparty,
) -> Result<SymmetricKey> {
self.derive_symmetric_key(protocol, key_id, counterparty)
}
fn reveal_specific_secret(
&self,
counterparty: &Counterparty,
protocol: &Protocol,
key_id: &str,
) -> Result<Vec<u8>> {
self.reveal_specific_secret(counterparty, protocol, key_id)
}
fn reveal_counterparty_secret(&self, counterparty: &Counterparty) -> Result<PublicKey> {
self.reveal_counterparty_secret(counterparty)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_deriver_creation() {
        let key = PrivateKey::random();
        let deriver = KeyDeriver::new(Some(key.clone()));
        assert_eq!(deriver.identity_key(), key.public_key());
    }

    #[test]
    fn test_key_deriver_anyone() {
        let deriver = KeyDeriver::new(None);
        let (_anyone_priv, anyone_pub) = KeyDeriver::anyone_key();
        assert_eq!(deriver.identity_key(), anyone_pub);
    }

    #[test]
    fn test_anyone_key_is_deterministic() {
        let (priv1, pub1) = KeyDeriver::anyone_key();
        let (priv2, pub2) = KeyDeriver::anyone_key();
        assert_eq!(priv1.to_bytes(), priv2.to_bytes());
        assert_eq!(pub1.to_compressed(), pub2.to_compressed());
    }

    #[test]
    fn test_derive_public_key_for_self() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "test application");
        let key_id = "invoice-123";
        let pub_key = deriver
            .derive_public_key(&protocol, key_id, &Counterparty::Self_, true)
            .unwrap();
        let pub_key2 = deriver
            .derive_public_key(&protocol, key_id, &Counterparty::Self_, true)
            .unwrap();
        assert_eq!(pub_key.to_compressed(), pub_key2.to_compressed());
    }

    #[test]
    fn test_derive_private_key_matches_public() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "test application");
        let key_id = "invoice-123";
        let priv_key = deriver
            .derive_private_key(&protocol, key_id, &Counterparty::Self_)
            .unwrap();
        let pub_key = deriver
            .derive_public_key(&protocol, key_id, &Counterparty::Self_, true)
            .unwrap();
        assert_eq!(
            priv_key.public_key().to_compressed(),
            pub_key.to_compressed()
        );
    }

    #[test]
    fn test_two_party_derivation() {
        let alice = KeyDeriver::new(Some(PrivateKey::random()));
        let bob = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "test application");
        let key_id = "payment-456";
        let alice_counterparty = Counterparty::Other(alice.identity_key());
        let bob_counterparty = Counterparty::Other(bob.identity_key());
        let bob_priv = bob
            .derive_private_key(&protocol, key_id, &alice_counterparty)
            .unwrap();
        let bob_pub_from_bob = bob
            .derive_public_key(&protocol, key_id, &alice_counterparty, true)
            .unwrap();
        // Alice, knowing only Bob's identity key, must arrive at Bob's child public key.
        let bob_pub_from_alice = alice
            .derive_public_key(&protocol, key_id, &bob_counterparty, false)
            .unwrap();
        assert_eq!(
            bob_priv.public_key().to_compressed(),
            bob_pub_from_bob.to_compressed()
        );
        assert_eq!(
            bob_priv.public_key().to_compressed(),
            bob_pub_from_alice.to_compressed()
        );
    }

    #[test]
    fn test_derive_symmetric_key() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "encryption test");
        let key_id = "message-789";
        let sym_key = deriver
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Self_)
            .unwrap();
        let sym_key2 = deriver
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Self_)
            .unwrap();
        assert_eq!(sym_key.as_bytes(), sym_key2.as_bytes());
    }

    #[test]
    fn test_symmetric_key_agreement_between_parties() {
        let alice = KeyDeriver::new(Some(PrivateKey::random()));
        let bob = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "encryption test");
        let key_id = "message-789";
        let alice_key = alice
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Other(bob.identity_key()))
            .unwrap();
        let bob_key = bob
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Other(alice.identity_key()))
            .unwrap();
        assert_eq!(alice_key.as_bytes(), bob_key.as_bytes());
    }

    #[test]
    fn test_symmetric_key_anyone_matches_explicit_anyone_key() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "encryption test");
        let key_id = "message-789";
        let (_, anyone_pub) = KeyDeriver::anyone_key();
        let via_anyone = deriver
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Anyone)
            .unwrap();
        let via_explicit = deriver
            .derive_symmetric_key(&protocol, key_id, &Counterparty::Other(anyone_pub))
            .unwrap();
        assert_eq!(via_anyone.as_bytes(), via_explicit.as_bytes());
    }

    #[test]
    fn test_reveal_specific_secret() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "test application");
        let key_id = "secret-123";
        let secret = deriver
            .reveal_specific_secret(&Counterparty::Self_, &protocol, key_id)
            .unwrap();
        assert_eq!(secret.len(), 32);
        let secret2 = deriver
            .reveal_specific_secret(&Counterparty::Self_, &protocol, key_id)
            .unwrap();
        assert_eq!(secret, secret2);
    }

    #[test]
    fn test_reveal_counterparty_secret_fails_for_self() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let result = deriver.reveal_counterparty_secret(&Counterparty::Self_);
        assert!(result.is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_fails_for_own_identity_key() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let own = Counterparty::Other(deriver.identity_key());
        assert!(deriver.reveal_counterparty_secret(&own).is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_succeeds_for_other() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let other = PrivateKey::random().public_key();
        let result = deriver.reveal_counterparty_secret(&Counterparty::Other(other));
        assert!(result.is_ok());
    }

    #[test]
    fn test_reveal_counterparty_secret_is_shared() {
        let alice = KeyDeriver::new(Some(PrivateKey::random()));
        let bob = KeyDeriver::new(Some(PrivateKey::random()));
        let from_alice = alice
            .reveal_counterparty_secret(&Counterparty::Other(bob.identity_key()))
            .unwrap();
        let from_bob = bob
            .reveal_counterparty_secret(&Counterparty::Other(alice.identity_key()))
            .unwrap();
        assert_eq!(from_alice.to_compressed(), from_bob.to_compressed());
    }

    #[test]
    fn test_invalid_protocol_name() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "bad");
        let result = deriver.derive_private_key(&protocol, "key-1", &Counterparty::Self_);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_key_id() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let protocol = Protocol::new(SecurityLevel::App, "test application");
        let result = deriver.derive_private_key(&protocol, "", &Counterparty::Self_);
        assert!(result.is_err());
        let long_key = "a".repeat(801);
        let result = deriver.derive_private_key(&protocol, &long_key, &Counterparty::Self_);
        assert!(result.is_err());
    }

    #[test]
    fn test_different_security_levels_produce_different_keys() {
        let deriver = KeyDeriver::new(Some(PrivateKey::random()));
        let key_id = "test-key";
        let proto0 = Protocol::new(SecurityLevel::Silent, "test application");
        let proto1 = Protocol::new(SecurityLevel::App, "test application");
        let proto2 = Protocol::new(SecurityLevel::Counterparty, "test application");
        let key0 = deriver
            .derive_public_key(&proto0, key_id, &Counterparty::Self_, true)
            .unwrap();
        let key1 = deriver
            .derive_public_key(&proto1, key_id, &Counterparty::Self_, true)
            .unwrap();
        let key2 = deriver
            .derive_public_key(&proto2, key_id, &Counterparty::Self_, true)
            .unwrap();
        assert_ne!(key0.to_compressed(), key1.to_compressed());
        assert_ne!(key1.to_compressed(), key2.to_compressed());
        assert_ne!(key0.to_compressed(), key2.to_compressed());
    }
}