#![allow(unused_assignments)]
use aws_lc_rs::hkdf;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Version tag for the host-key derivation scheme; matches the `:v1` suffix
/// on the HKDF labels below.
/// NOTE(review): not referenced in this file — presumably consumed by callers; confirm.
pub const HOSTKEY_VERSION: &str = "v1";
/// HKDF-Extract salt applied to the raw host secret in [`HostIdentity::from_secret`].
const HOSTKEY_SALT: &[u8] = b"antq:hostkey:v1";
/// HKDF-Expand `info` label for the network-independent first stage of
/// [`HostIdentity::derive_endpoint_encryption_key`].
const ENDPOINT_ENCRYPT_INFO: &[u8] = b"antq:endpoint-encrypt:v1";
/// HKDF-Expand `info` label used by [`HostIdentity::derive_cache_key`].
const CACHE_KEY_INFO: &[u8] = b"antq:cache-key:v1";
/// Length in bytes of every key derived here (one HKDF-SHA256 output block).
const DERIVED_KEY_SIZE: usize = 32;
/// Selects how endpoint encryption keys relate to network identifiers.
///
/// Read by `HostIdentity::derive_endpoint_encryption_key` to decide whether
/// the caller-supplied network id takes part in the derivation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EndpointKeyPolicy {
    /// Every distinct network id yields its own key. This is the default.
    #[default]
    PerNetwork,
    /// The network id is ignored, so every network gets the same key.
    Shared,
}
/// Host identity from which purpose-specific keys (endpoint encryption, cache,
/// fingerprint) are derived.
///
/// Only the HKDF pseudorandom key extracted from the host secret is stored;
/// the raw secret itself is not kept.
///
/// NOTE(review): both fields are `#[zeroize(skip)]`, so the derived
/// `ZeroizeOnDrop` does not wipe anything itself. Whether `hkdf::Prk` clears
/// its key bytes on drop is up to aws-lc-rs — confirm before relying on it.
#[derive(ZeroizeOnDrop)]
pub struct HostIdentity {
    /// HKDF-SHA256 PRK extracted from the host secret with `HOSTKEY_SALT`.
    #[zeroize(skip)]
    prk: hkdf::Prk,
    /// Policy consulted when deriving endpoint encryption keys; not secret.
    #[zeroize(skip)]
    policy: EndpointKeyPolicy,
}
impl HostIdentity {
    /// Builds an identity from a 32-byte host secret using the default
    /// [`EndpointKeyPolicy`] (`PerNetwork`).
    ///
    /// The secret is passed through HKDF-Extract (SHA-256, salt
    /// `HOSTKEY_SALT`) and only the resulting PRK is retained. This function
    /// wipes its own copy of `secret`. Because `[u8; 32]` is `Copy`, any copy
    /// still held by the caller is the caller's to wipe.
    pub fn from_secret(mut secret: [u8; 32]) -> Self {
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, HOSTKEY_SALT).extract(&secret);
        secret.zeroize();
        Self {
            prk,
            policy: EndpointKeyPolicy::default(),
        }
    }

    /// Same as [`HostIdentity::from_secret`], but with an explicit policy.
    pub fn from_secret_with_policy(mut secret: [u8; 32], policy: EndpointKeyPolicy) -> Self {
        let mut identity = Self::from_secret(secret);
        // `from_secret` received a bitwise copy; this frame's copy must be
        // wiped separately or it lingers on the stack.
        secret.zeroize();
        identity.policy = policy;
        identity
    }

    /// Creates an identity from 32 fresh bytes drawn from `rand::thread_rng`,
    /// using the default policy.
    pub fn generate() -> Self {
        use rand::RngCore;
        let mut secret = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret);
        let identity = Self::from_secret(secret);
        // `[u8; 32]` is `Copy`: `from_secret` only wiped its own copy, so the
        // freshly generated secret must also be cleared here.
        secret.zeroize();
        identity
    }

    /// Same as [`HostIdentity::generate`], but with an explicit policy.
    pub fn generate_with_policy(policy: EndpointKeyPolicy) -> Self {
        let mut identity = Self::generate();
        identity.policy = policy;
        identity
    }

    /// Returns the current endpoint-key policy.
    pub fn policy(&self) -> EndpointKeyPolicy {
        self.policy
    }

    /// Replaces the endpoint-key policy. Keys derived afterwards use the new
    /// policy; previously derived keys are unaffected.
    pub fn set_policy(&mut self, policy: EndpointKeyPolicy) {
        self.policy = policy;
    }

    /// Derives the 32-byte endpoint encryption key for `network_id`.
    ///
    /// The derivation is a two-stage HKDF:
    /// 1. A network-independent base key is expanded from the host PRK with
    ///    `ENDPOINT_ENCRYPT_INFO`.
    /// 2. The base key is re-extracted with the effective network id as salt
    ///    and expanded with the label `antq:endpoint-key:v1`.
    ///
    /// Under [`EndpointKeyPolicy::Shared`], `network_id` is ignored and the
    /// fixed label `antq:shared-identity` is used as the salt, so every network
    /// receives the same key. The intermediate base key is wiped before
    /// returning.
    ///
    /// NOTE(review): because the shared label is used verbatim as a salt, a
    /// `PerNetwork` call with `network_id == b"antq:shared-identity"` yields the
    /// same key as the `Shared` policy. Fixing this changes every derived key,
    /// so it is left as-is for `v1`.
    pub fn derive_endpoint_encryption_key(&self, network_id: &[u8]) -> [u8; DERIVED_KEY_SIZE] {
        let effective_network_id: &[u8] = match self.policy {
            EndpointKeyPolicy::PerNetwork => network_id,
            EndpointKeyPolicy::Shared => b"antq:shared-identity",
        };

        let mut base_key = [0u8; DERIVED_KEY_SIZE];
        Self::expand_into(&self.prk, ENDPOINT_ENCRYPT_INFO, &mut base_key);
        let network_prk =
            hkdf::Salt::new(hkdf::HKDF_SHA256, effective_network_id).extract(&base_key);
        // The base key is secret material shared by all networks; don't leave
        // it on the stack once it has been absorbed into `network_prk`.
        base_key.zeroize();

        let mut key = [0u8; DERIVED_KEY_SIZE];
        Self::expand_into(&network_prk, b"antq:endpoint-key:v1", &mut key);
        key
    }

    /// Derives the 32-byte cache key. It depends only on the host secret, not
    /// on the policy.
    pub fn derive_cache_key(&self) -> [u8; DERIVED_KEY_SIZE] {
        let mut key = [0u8; DERIVED_KEY_SIZE];
        Self::expand_into(&self.prk, CACHE_KEY_INFO, &mut key);
        key
    }

    /// Returns a 16-character lowercase hex fingerprint, safe to display.
    ///
    /// It encodes the first 8 bytes of an HKDF output under the dedicated
    /// label `antq:fingerprint:v1`, so it reveals nothing about the derived
    /// keys.
    pub fn fingerprint(&self) -> String {
        let mut digest = [0u8; DERIVED_KEY_SIZE];
        Self::expand_into(&self.prk, b"antq:fingerprint:v1", &mut digest);
        hex::encode(&digest[..8])
    }

    /// Runs HKDF-Expand (SHA-256) on `prk` with a single `info` label and
    /// writes exactly `DERIVED_KEY_SIZE` bytes into `out`.
    // `expect` is sound here: we always request one SHA-256 output (32 bytes)
    // into a 32-byte buffer. That is well within HKDF's 255×HashLen limit, and
    // the buffer length matches, so neither `expand` nor `fill` can fail.
    #[allow(clippy::expect_used)]
    fn expand_into(prk: &hkdf::Prk, info: &[u8], out: &mut [u8; DERIVED_KEY_SIZE]) {
        prk.expand(&[info], hkdf::HKDF_SHA256)
            .expect("HKDF expand should succeed with valid parameters")
            .fill(out)
            .expect("OKM fill should succeed for 32 bytes");
    }
}
impl std::fmt::Debug for HostIdentity {
    /// Renders the identity without key material. Only the public fingerprint
    /// and the endpoint-key policy appear in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fingerprint = self.fingerprint();
        let mut builder = f.debug_struct("HostIdentity");
        builder.field("fingerprint", &fingerprint);
        builder.field("policy", &self.policy);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_host_identity_from_secret() {
        let secret = [42u8; 32];
        let host = HostIdentity::from_secret(secret);
        assert_eq!(host.policy(), EndpointKeyPolicy::PerNetwork);
        let fingerprint1 = host.fingerprint();
        let host2 = HostIdentity::from_secret([42u8; 32]);
        let fingerprint2 = host2.fingerprint();
        assert_eq!(fingerprint1, fingerprint2);
    }

    #[test]
    fn test_host_identity_generate() {
        let host1 = HostIdentity::generate();
        let host2 = HostIdentity::generate();
        assert_ne!(host1.fingerprint(), host2.fingerprint());
    }

    #[test]
    fn test_with_policy_constructors() {
        let host = HostIdentity::from_secret_with_policy([7u8; 32], EndpointKeyPolicy::Shared);
        assert_eq!(host.policy(), EndpointKeyPolicy::Shared);
        let host = HostIdentity::generate_with_policy(EndpointKeyPolicy::Shared);
        assert_eq!(host.policy(), EndpointKeyPolicy::Shared);
    }

    #[test]
    fn test_derive_endpoint_encryption_key_deterministic() {
        let secret = [1u8; 32];
        let host = HostIdentity::from_secret(secret);
        let key1 = host.derive_endpoint_encryption_key(b"network-1");
        let key2 = host.derive_endpoint_encryption_key(b"network-1");
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_derive_endpoint_encryption_key_per_network_isolation() {
        let secret = [1u8; 32];
        let host = HostIdentity::from_secret(secret);
        let key1 = host.derive_endpoint_encryption_key(b"network-1");
        let key2 = host.derive_endpoint_encryption_key(b"network-2");
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_endpoint_encryption_key_shared_policy() {
        let secret = [1u8; 32];
        let mut host = HostIdentity::from_secret(secret);
        host.set_policy(EndpointKeyPolicy::Shared);
        let key1 = host.derive_endpoint_encryption_key(b"network-1");
        let key2 = host.derive_endpoint_encryption_key(b"network-2");
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_policy_does_not_affect_fingerprint_or_cache_key() {
        let per_network = HostIdentity::from_secret([3u8; 32]);
        let shared = HostIdentity::from_secret_with_policy([3u8; 32], EndpointKeyPolicy::Shared);
        assert_eq!(per_network.fingerprint(), shared.fingerprint());
        assert_eq!(per_network.derive_cache_key(), shared.derive_cache_key());
    }

    #[test]
    fn test_derive_cache_key() {
        let secret = [1u8; 32];
        let host = HostIdentity::from_secret(secret);
        let key1 = host.derive_cache_key();
        let key2 = host.derive_cache_key();
        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn test_cache_key_differs_from_endpoint_key() {
        let secret = [1u8; 32];
        let host = HostIdentity::from_secret(secret);
        let cache_key = host.derive_cache_key();
        let endpoint_key = host.derive_endpoint_encryption_key(b"test-network");
        assert_ne!(cache_key, endpoint_key);
    }

    #[test]
    fn test_fingerprint_safe_for_display() {
        let host = HostIdentity::generate();
        let fingerprint = host.fingerprint();
        assert_eq!(fingerprint.len(), 16);
        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_debug_shows_fingerprint_but_not_keys() {
        let host = HostIdentity::from_secret([9u8; 32]);
        let rendered = format!("{host:?}");
        assert!(rendered.contains(&host.fingerprint()));
        assert!(!rendered.contains(&hex::encode(host.derive_cache_key())));
        assert!(!rendered.contains(&hex::encode(host.derive_endpoint_encryption_key(b"n"))));
    }

    #[test]
    fn test_different_secrets_different_keys() {
        let host1 = HostIdentity::from_secret([1u8; 32]);
        let host2 = HostIdentity::from_secret([2u8; 32]);
        let key1 = host1.derive_endpoint_encryption_key(b"network");
        let key2 = host2.derive_endpoint_encryption_key(b"network");
        assert_ne!(key1, key2);
        assert_ne!(host1.derive_cache_key(), host2.derive_cache_key());
    }
}