#![allow(unused_assignments)]
use hkdf::Hkdf;
use sha2::Sha256;
use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Maximum length, in bytes, of the `domain` argument to [`derive_domain_key`].
///
/// The domain length is written into the HKDF salt as a single byte, so this
/// limit is pinned to `u8::MAX` (255).
pub const MAX_DOMAIN_LENGTH: usize = u8::MAX as usize;
/// Maximum length, in bytes, of the tenant salt accepted by [`derive_domain_key`].
///
/// The salt length is written into the HKDF salt as a big-endian `u16`, so this
/// limit must stay at or below `u16::MAX`; it is currently 1024 bytes.
pub const MAX_TENANT_SALT_LENGTH: usize = 1 << 10;
/// Errors returned by the key-derivation functions in this module.
#[derive(Error, Debug)]
pub enum KeyDerivationError {
/// The master key was shorter than 16 bytes; carries the actual length.
#[error("Invalid master key length: expected at least 16 bytes, got {0}")]
InvalidMasterKeyLength(usize),
/// The domain string was rejected; carries a human-readable reason
/// (`derive_domain_key` uses this for an empty domain).
#[error("Invalid domain string: {0}")]
InvalidDomain(String),
/// The tenant salt was empty; carries the actual length.
#[error("Invalid salt length: expected at least 1 byte, got {0}")]
InvalidSaltLength(usize),
/// The HKDF expand step reported an error; carries a description.
#[error("Key derivation failed: {0}")]
DerivationFailed(String),
/// The domain was longer than `MAX_DOMAIN_LENGTH` bytes.
#[error("Domain exceeds maximum length")]
DomainTooLong,
/// The tenant salt was longer than `MAX_TENANT_SALT_LENGTH` bytes.
#[error("Tenant salt exceeds maximum length")]
TenantSaltTooLong,
}
/// Derives a 32-byte key from `master_key` for a `(domain, tenant_salt)` pair
/// using HKDF-SHA256.
///
/// The HKDF salt is a length-prefixed encoding of both inputs:
///
/// ```text
/// b"cachekit_v1_" || u8(len(domain)) || domain || be_u16(len(tenant_salt)) || tenant_salt
/// ```
///
/// Because every variable-length field is length-prefixed, distinct pairs such
/// as `("foo", b"bar")` and `("foob", b"ar")` always yield distinct salts. The
/// HKDF `info` input is the domain's UTF-8 bytes.
///
/// The returned array is a plain `[u8; 32]`. It is not wrapped in a zeroizing
/// type, so the caller owns its lifetime.
///
/// # Errors
///
/// - [`KeyDerivationError::InvalidMasterKeyLength`] if `master_key` is shorter
///   than 16 bytes.
/// - [`KeyDerivationError::InvalidDomain`] if `domain` is empty.
/// - [`KeyDerivationError::DomainTooLong`] if `domain` is longer than
///   [`MAX_DOMAIN_LENGTH`] bytes.
/// - [`KeyDerivationError::InvalidSaltLength`] if `tenant_salt` is empty.
/// - [`KeyDerivationError::TenantSaltTooLong`] if `tenant_salt` is longer than
///   [`MAX_TENANT_SALT_LENGTH`] bytes.
/// - [`KeyDerivationError::DerivationFailed`] if HKDF expansion fails. This is
///   not expected for a 32-byte output from SHA-256.
pub fn derive_domain_key(
    master_key: &[u8],
    domain: &str,
    tenant_salt: &[u8],
) -> Result<[u8; 32], KeyDerivationError> {
    /// Version tag that starts every HKDF salt; changing it changes every key.
    const SALT_PREFIX: &[u8] = b"cachekit_v1_";

    // The 1-byte and 2-byte length prefixes are what make the salt encoding
    // injective. Enforce at compile time that the public limits fit those
    // widths, so the `as` casts below can never silently truncate (which would
    // reintroduce (domain, salt) collisions) if a limit is ever raised.
    const _: () = assert!(MAX_DOMAIN_LENGTH <= u8::MAX as usize);
    const _: () = assert!(MAX_TENANT_SALT_LENGTH <= u16::MAX as usize);

    if master_key.len() < 16 {
        return Err(KeyDerivationError::InvalidMasterKeyLength(master_key.len()));
    }
    if domain.is_empty() {
        return Err(KeyDerivationError::InvalidDomain(
            "Domain cannot be empty".into(),
        ));
    }
    let domain_bytes = domain.as_bytes();
    if domain_bytes.len() > MAX_DOMAIN_LENGTH {
        return Err(KeyDerivationError::DomainTooLong);
    }
    if tenant_salt.is_empty() {
        return Err(KeyDerivationError::InvalidSaltLength(tenant_salt.len()));
    }
    if tenant_salt.len() > MAX_TENANT_SALT_LENGTH {
        return Err(KeyDerivationError::TenantSaltTooLong);
    }

    // Lossless: both lengths were bounded above, and the const assertions tie
    // those bounds to the prefix widths.
    let domain_len = domain_bytes.len() as u8;
    let salt_len = tenant_salt.len() as u16;

    let mut salt_data = Vec::with_capacity(
        SALT_PREFIX.len() + 1 + domain_bytes.len() + 2 + tenant_salt.len(),
    );
    salt_data.extend_from_slice(SALT_PREFIX);
    salt_data.push(domain_len);
    salt_data.extend_from_slice(domain_bytes);
    salt_data.extend_from_slice(&salt_len.to_be_bytes());
    salt_data.extend_from_slice(tenant_salt);

    let hkdf = Hkdf::<Sha256>::new(Some(&salt_data), master_key);
    let mut key = [0u8; 32];
    hkdf.expand(domain_bytes, &mut key)
        .map_err(|_| KeyDerivationError::DerivationFailed("HKDF expand failed".into()))?;
    Ok(key)
}
/// Returns a 16-byte fingerprint of `key`.
///
/// The fingerprint is the first 16 bytes of
/// `SHA-256(b"key_fingerprint_v1" || key)`. The fixed tag keeps these digests
/// distinct from a plain SHA-256 of the same bytes.
pub fn key_fingerprint(key: &[u8]) -> [u8; 16] {
    use sha2::Digest;

    const FINGERPRINT_TAG: &[u8] = b"key_fingerprint_v1";

    let mut sha = sha2::Sha256::new();
    for part in [FINGERPRINT_TAG, key] {
        sha.update(part);
    }
    let digest = sha.finalize();

    let mut fingerprint = [0u8; 16];
    let width = fingerprint.len();
    fingerprint.copy_from_slice(&digest[..width]);
    fingerprint
}
/// Derives the complete [`TenantKeys`] set for `tenant_id` from `master_key`.
///
/// The UTF-8 bytes of `tenant_id` are the tenant salt for three
/// domain-separated [`derive_domain_key`] calls. The `"encryption"`,
/// `"authentication"` and `"cache_keys"` domains fill `encryption_key`,
/// `authentication_key` and `cache_key_salt` respectively.
///
/// # Errors
///
/// Propagates the first [`KeyDerivationError`] returned by
/// [`derive_domain_key`]. Examples include a master key shorter than 16 bytes,
/// an empty `tenant_id`, or a `tenant_id` longer than
/// [`MAX_TENANT_SALT_LENGTH`] bytes.
pub fn derive_tenant_keys(
    master_key: &[u8],
    tenant_id: &str,
) -> Result<TenantKeys, KeyDerivationError> {
    let salt = tenant_id.as_bytes();
    let derive_for = |domain: &str| derive_domain_key(master_key, domain, salt);

    // Struct-literal fields are evaluated top to bottom, so the derivations
    // run (and can fail) in the same order as listed here.
    Ok(TenantKeys {
        encryption_key: derive_for("encryption")?,
        authentication_key: derive_for("authentication")?,
        cache_key_salt: derive_for("cache_keys")?,
        tenant_id: tenant_id.to_owned(),
    })
}
/// Per-tenant key material, as produced by `derive_tenant_keys`.
///
/// The three 32-byte keys are wiped when the value is dropped
/// (`ZeroizeOnDrop`) and can be wiped early with `Zeroize::zeroize`.
/// `tenant_id` is excluded from zeroization.
///
/// `Debug` is implemented by hand rather than derived. This means formatting a
/// `TenantKeys` (logs, panic messages, `unwrap` output) never prints key bytes.
// NOTE(review): the `unused_assignments` allows appear to silence warnings
// emitted for the zeroize derive expansion — confirm which rustc/zeroize
// versions need them and record that here.
#[allow(unused_assignments)]
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct TenantKeys {
    /// Key for the `"encryption"` domain.
    pub encryption_key: [u8; 32],
    /// Key for the `"authentication"` domain.
    pub authentication_key: [u8; 32],
    /// Key for the `"cache_keys"` domain.
    pub cache_key_salt: [u8; 32],
    /// Tenant identifier the keys were derived for; not zeroized.
    #[zeroize(skip)]
    #[allow(unused_assignments)]
    pub tenant_id: String,
}

impl std::fmt::Debug for TenantKeys {
    /// Formats the struct with every key field redacted; only `tenant_id` is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TenantKeys")
            .field("encryption_key", &format_args!("[REDACTED]"))
            .field("authentication_key", &format_args!("[REDACTED]"))
            .field("cache_key_salt", &format_args!("[REDACTED]"))
            .field("tenant_id", &self.tenant_id)
            .finish()
    }
}
impl TenantKeys {
    /// Returns the [`key_fingerprint`] of `encryption_key`.
    ///
    /// The result is a truncated, tagged SHA-256 digest. It can be compared
    /// between instances without handling the key bytes themselves.
    pub fn encryption_fingerprint(&self) -> [u8; 16] {
        let key_bytes: &[u8] = &self.encryption_key;
        key_fingerprint(key_bytes)
    }

    /// Returns the [`key_fingerprint`] of `authentication_key`.
    ///
    /// The result is a truncated, tagged SHA-256 digest. It can be compared
    /// between instances without handling the key bytes themselves.
    pub fn authentication_fingerprint(&self) -> [u8; 16] {
        let key_bytes: &[u8] = &self.authentication_key;
        key_fingerprint(key_bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 32-byte master key shared by most tests.
    const MASTER_KEY: &[u8] = b"test_master_key_32_bytes_long!!!";

    #[test]
    fn test_derive_domain_key_deterministic() {
        let domain = "encryption";
        let tenant_salt = b"tenant123";
        let key1 = derive_domain_key(MASTER_KEY, domain, tenant_salt).unwrap();
        let key2 = derive_domain_key(MASTER_KEY, domain, tenant_salt).unwrap();
        assert_eq!(key1, key2, "Same inputs should produce same key");
    }

    #[test]
    fn test_domain_separation() {
        let tenant_salt = b"tenant123";
        let enc_key = derive_domain_key(MASTER_KEY, "encryption", tenant_salt).unwrap();
        let auth_key = derive_domain_key(MASTER_KEY, "authentication", tenant_salt).unwrap();
        let cache_key = derive_domain_key(MASTER_KEY, "cache_keys", tenant_salt).unwrap();
        assert_ne!(enc_key, auth_key);
        assert_ne!(enc_key, cache_key);
        assert_ne!(auth_key, cache_key);
    }

    #[test]
    fn test_tenant_separation() {
        let domain = "encryption";
        let key1 = derive_domain_key(MASTER_KEY, domain, b"tenant1").unwrap();
        let key2 = derive_domain_key(MASTER_KEY, domain, b"tenant2").unwrap();
        assert_ne!(key1, key2, "Different tenants should get different keys");
    }

    #[test]
    fn test_master_key_sensitivity() {
        let other_master_key = b"different_master_key_32_bytes!!!";
        let domain = "encryption";
        let tenant_salt = b"tenant123";
        let key1 = derive_domain_key(MASTER_KEY, domain, tenant_salt).unwrap();
        let key2 = derive_domain_key(other_master_key, domain, tenant_salt).unwrap();
        assert_ne!(
            key1, key2,
            "Different master keys should produce different derived keys"
        );
    }

    #[test]
    fn test_invalid_inputs() {
        let result = derive_domain_key(b"short", "encryption", b"tenant");
        assert!(matches!(
            result,
            Err(KeyDerivationError::InvalidMasterKeyLength(5))
        ));

        // 16 bytes is the minimum accepted master-key length.
        let result = derive_domain_key(&[0u8; 15], "encryption", b"tenant");
        assert!(matches!(
            result,
            Err(KeyDerivationError::InvalidMasterKeyLength(15))
        ));
        assert!(derive_domain_key(&[0u8; 16], "encryption", b"tenant").is_ok());

        let result = derive_domain_key(MASTER_KEY, "", b"tenant");
        assert!(matches!(result, Err(KeyDerivationError::InvalidDomain(_))));

        let result = derive_domain_key(MASTER_KEY, "encryption", b"");
        assert!(matches!(
            result,
            Err(KeyDerivationError::InvalidSaltLength(0))
        ));
    }

    #[test]
    fn test_tenant_keys_derivation() {
        let tenant_id = "test_tenant_123";
        let keys = derive_tenant_keys(MASTER_KEY, tenant_id).unwrap();
        assert_ne!(keys.encryption_key, keys.authentication_key);
        assert_ne!(keys.encryption_key, keys.cache_key_salt);
        assert_ne!(keys.authentication_key, keys.cache_key_salt);
        assert_eq!(keys.tenant_id, tenant_id);

        // Each field must be the matching domain-separated derivation.
        let salt = tenant_id.as_bytes();
        assert_eq!(
            keys.encryption_key,
            derive_domain_key(MASTER_KEY, "encryption", salt).unwrap()
        );
        assert_eq!(
            keys.authentication_key,
            derive_domain_key(MASTER_KEY, "authentication", salt).unwrap()
        );
        assert_eq!(
            keys.cache_key_salt,
            derive_domain_key(MASTER_KEY, "cache_keys", salt).unwrap()
        );

        let fp1 = keys.encryption_fingerprint();
        let fp2 = keys.encryption_fingerprint();
        assert_eq!(fp1, fp2, "Fingerprints should be deterministic");
        assert_eq!(fp1, key_fingerprint(&keys.encryption_key));
        assert_eq!(
            keys.authentication_fingerprint(),
            key_fingerprint(&keys.authentication_key)
        );
    }

    #[test]
    fn test_key_fingerprint_uniqueness() {
        let key1 = b"test_key_1_with_32_bytes_exactly";
        let key2 = b"test_key_2_with_32_bytes_exactly";
        assert_eq!(key1.len(), 32);
        assert_eq!(key2.len(), 32);
        let fp1 = key_fingerprint(key1);
        let fp2 = key_fingerprint(key2);
        assert_ne!(
            fp1, fp2,
            "Different keys should have different fingerprints"
        );
        let fp1_again = key_fingerprint(key1);
        assert_eq!(fp1, fp1_again, "Fingerprints should be deterministic");
    }

    #[test]
    fn test_hkdf_salt_byte_vector() {
        let mut expected_salt = Vec::new();
        expected_salt.extend_from_slice(b"cachekit_v1_");
        expected_salt.push(5);
        expected_salt.extend_from_slice(b"cache");
        expected_salt.extend_from_slice(&10u16.to_be_bytes());
        expected_salt.extend_from_slice(b"tenant-123");
        assert_eq!(expected_salt.len(), 30);
        assert_eq!(&expected_salt[0..12], b"cachekit_v1_");
        assert_eq!(expected_salt[12], 5);
        assert_eq!(&expected_salt[13..18], b"cache");
        assert_eq!(&expected_salt[18..20], &[0x00, 0x0a]);
        assert_eq!(&expected_salt[20..30], b"tenant-123");

        // Pin derive_domain_key to this exact salt layout (and to the domain
        // as HKDF info) by recomputing the key with HKDF directly.
        let mut expected_key = [0u8; 32];
        Hkdf::<Sha256>::new(Some(&expected_salt), MASTER_KEY)
            .expand(b"cache", &mut expected_key)
            .unwrap();

        let key1 = derive_domain_key(MASTER_KEY, "cache", b"tenant-123").unwrap();
        let key2 = derive_domain_key(MASTER_KEY, "cache", b"tenant-123").unwrap();
        assert_eq!(
            key1, expected_key,
            "derive_domain_key must use the documented salt layout"
        );
        assert_eq!(key1, key2, "Same inputs should produce same derived key");
    }

    #[test]
    fn test_hkdf_salt_collision_resistance() {
        let key1 = derive_domain_key(MASTER_KEY, "foo", b"bar").unwrap();
        let key2 = derive_domain_key(MASTER_KEY, "foob", b"ar").unwrap();
        assert_ne!(
            key1, key2,
            "Different (domain, salt) pairs must produce different keys"
        );
        let key3 = derive_domain_key(MASTER_KEY, "ab", b"cd").unwrap();
        let key4 = derive_domain_key(MASTER_KEY, "a", b"bcd").unwrap();
        assert_ne!(key3, key4);
    }

    #[test]
    fn test_domain_and_salt_length_limits() {
        let max_domain = "a".repeat(MAX_DOMAIN_LENGTH);
        let result = derive_domain_key(MASTER_KEY, &max_domain, b"salt");
        assert!(result.is_ok());
        let oversized_domain = "a".repeat(MAX_DOMAIN_LENGTH + 1);
        let result = derive_domain_key(MASTER_KEY, &oversized_domain, b"salt");
        assert!(matches!(result, Err(KeyDerivationError::DomainTooLong)));

        let max_salt = vec![0u8; MAX_TENANT_SALT_LENGTH];
        let result = derive_domain_key(MASTER_KEY, "domain", &max_salt);
        assert!(result.is_ok());
        let oversized_salt = vec![0u8; MAX_TENANT_SALT_LENGTH + 1];
        let result = derive_domain_key(MASTER_KEY, "domain", &oversized_salt);
        assert!(matches!(result, Err(KeyDerivationError::TenantSaltTooLong)));
    }

    #[test]
    fn test_hkdf_output_sensitivity() {
        let key1 = derive_domain_key(MASTER_KEY, "encryption", b"tenant_abc").unwrap();
        assert_ne!(key1, [0u8; 32]);
        assert_ne!(key1, [0xffu8; 32]);
        assert_eq!(key1.len(), 32);
    }
}