use crate::{EncryptionKey, EncryptionNonce, decrypt, encrypt, generate_key, generate_nonce, hash};
use crate::{KeyPair, PublicKey, SecretKey, SigningError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use thiserror::Error;
/// Errors returned by the key-version, encrypted-key, key-ring and
/// re-encryption types defined in this file.
#[derive(Debug, Error)]
pub enum RotationError {
/// No entry exists for the requested key; the payload describes the lookup
/// (the key rings format it as `"version N"`).
#[error("Key not found: {0}")]
KeyNotFound(String),
/// The key version carried in the payload is past its expiry time.
#[error("Key expired: version {0}")]
KeyExpired(u32),
/// The key version carried in the payload has been revoked.
#[error("Key revoked: version {0}")]
KeyRevoked(u32),
/// The underlying `encrypt` call failed; its error detail is discarded.
#[error("Encryption error")]
EncryptionError,
/// The underlying `decrypt` call failed; its error detail is discarded.
#[error("Decryption error")]
DecryptionError,
/// Decrypted key material did not have the expected length of 32 bytes.
#[error("Invalid key format")]
InvalidKeyFormat,
/// A `SigningError` propagated from the signing layer (converted via `From`).
#[error("Signing error: {0}")]
SigningError(#[from] SigningError),
/// Serialization failure with a free-form message.
/// NOTE(review): not constructed anywhere in the code visible here.
#[error("Serialization error: {0}")]
SerializationError(String),
}
/// Metadata describing one version of a key: when it was created, when it
/// expires, and whether (and why) it was revoked. Holds no key material.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyVersion {
/// Version number of the key this metadata describes.
pub version: u32,
/// Creation time in seconds since the UNIX epoch.
pub created_at: u64,
/// Expiry time in seconds since the UNIX epoch; `None` means the key never expires.
pub expires_at: Option<u64>,
/// `true` once the key has been revoked.
pub revoked: bool,
/// Time of revocation in seconds since the UNIX epoch, if revoked.
pub revoked_at: Option<u64>,
/// Optional free-form reason supplied when the key was revoked.
pub revocation_reason: Option<String>,
/// Caller-supplied fingerprint string; the key rings in this file pass the
/// hex encoding of the first 16 bytes of a hash of the key.
pub fingerprint: String,
}
impl KeyVersion {
    /// Returns the current wall-clock time in whole seconds since the UNIX epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_secs()
    }

    /// Creates metadata for a new, non-revoked key version stamped with the
    /// current time.
    ///
    /// When `ttl` is `Some`, the key expires `ttl` (truncated to whole seconds)
    /// after creation; when `None`, it never expires.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn new(version: u32, fingerprint: String, ttl: Option<Duration>) -> Self {
        let now = Self::unix_now_secs();
        Self {
            version,
            created_at: now,
            // Saturate rather than overflow: an enormous TTL must mean
            // "effectively never expires", not a panic (debug builds) or a
            // wrapped-around expiry that already lies in the past (release).
            expires_at: ttl.map(|d| now.saturating_add(d.as_secs())),
            revoked: false,
            revoked_at: None,
            revocation_reason: None,
            fingerprint,
        }
    }

    /// Returns `true` when the key is neither revoked nor expired.
    pub fn is_valid(&self) -> bool {
        !self.revoked && !self.is_expired()
    }

    /// Returns `true` when an expiry time is set and the current time is
    /// strictly later than it. Keys without an expiry never expire.
    pub fn is_expired(&self) -> bool {
        self.expires_at
            .map_or(false, |deadline| Self::unix_now_secs() > deadline)
    }

    /// Marks the key as revoked, recording the current time and the optional
    /// `reason`. Calling this again overwrites the recorded time and reason.
    pub fn revoke(&mut self, reason: Option<String>) {
        self.revoked = true;
        self.revoked_at = Some(Self::unix_now_secs());
        self.revocation_reason = reason;
    }
}
/// Key material encrypted under a master key, together with the nonce
/// needed to decrypt it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedKey {
/// Output of `encrypt` over the raw key bytes.
pub ciphertext: Vec<u8>,
/// The 12-byte nonce that was used for encryption.
pub nonce: [u8; 12],
/// Always set to 0 by the constructors in this file.
/// NOTE(review): whether this is a format version or a key version is not
/// established by the visible code — confirm before relying on it.
pub version: u32,
/// Always set to `None` by the constructors in this file.
/// NOTE(review): presumably reserved for key derivation — confirm.
pub salt: Option<Vec<u8>>,
}
impl EncryptedKey {
pub fn encrypt_secret_key(
secret_key: &SecretKey,
master_key: &EncryptionKey,
) -> Result<Self, RotationError> {
let nonce = generate_nonce();
let ciphertext =
encrypt(secret_key, master_key, &nonce).map_err(|_| RotationError::EncryptionError)?;
Ok(Self {
ciphertext,
nonce,
version: 0,
salt: None,
})
}
pub fn decrypt_secret_key(
&self,
master_key: &EncryptionKey,
) -> Result<SecretKey, RotationError> {
let decrypted = decrypt(&self.ciphertext, master_key, &self.nonce)
.map_err(|_| RotationError::DecryptionError)?;
if decrypted.len() != 32 {
return Err(RotationError::InvalidKeyFormat);
}
let mut key = [0u8; 32];
key.copy_from_slice(&decrypted);
Ok(key)
}
pub fn encrypt_encryption_key(
key: &EncryptionKey,
master_key: &EncryptionKey,
) -> Result<Self, RotationError> {
let nonce = generate_nonce();
let ciphertext =
encrypt(key, master_key, &nonce).map_err(|_| RotationError::EncryptionError)?;
Ok(Self {
ciphertext,
nonce,
version: 0,
salt: None,
})
}
pub fn decrypt_encryption_key(
&self,
master_key: &EncryptionKey,
) -> Result<EncryptionKey, RotationError> {
let decrypted = decrypt(&self.ciphertext, master_key, &self.nonce)
.map_err(|_| RotationError::DecryptionError)?;
if decrypted.len() != 32 {
return Err(RotationError::InvalidKeyFormat);
}
let mut key = [0u8; 32];
key.copy_from_slice(&decrypted);
Ok(key)
}
}
/// Settings that control when a key ring rotates its current key and how
/// many older versions it keeps.
#[derive(Debug, Clone)]
pub struct RotationPolicy {
    /// Age of the current key beyond which the key rings report that a
    /// rotation is needed.
    pub max_age: Duration,
    /// Number of versions the key rings keep in addition to one more when
    /// pruning (they retain the newest `retention_count + 1` versions).
    pub retention_count: usize,
    /// Whether `rotate_if_needed` on the key rings may generate a new key.
    pub auto_rotate: bool,
}

impl Default for RotationPolicy {
    /// Returns a policy with a 30-day maximum age, a retention count of 3
    /// and automatic rotation enabled.
    fn default() -> Self {
        const SECONDS_PER_DAY: u64 = 24 * 3600;
        RotationPolicy {
            auto_rotate: true,
            retention_count: 3,
            max_age: Duration::from_secs(30 * SECONDS_PER_DAY),
        }
    }
}
/// A versioned collection of signing key pairs whose secret keys are stored
/// encrypted under a master key.
pub struct SigningKeyRing {
// Version of the most recently added key; 0 until the first key is added.
current_version: u32,
// Per-version metadata (creation, expiry, revocation).
versions: HashMap<u32, KeyVersion>,
// Per-version secret key, encrypted under `master_key`.
encrypted_keys: HashMap<u32, EncryptedKey>,
// Per-version public key, stored in the clear.
public_keys: HashMap<u32, PublicKey>,
// Key used to encrypt and decrypt the stored secret keys.
master_key: EncryptionKey,
// Rotation and retention settings.
policy: RotationPolicy,
}
impl SigningKeyRing {
pub fn new(master_key: EncryptionKey, policy: RotationPolicy) -> Self {
Self {
current_version: 0,
versions: HashMap::new(),
encrypted_keys: HashMap::new(),
public_keys: HashMap::new(),
master_key,
policy,
}
}
pub fn add_key(
&mut self,
key_pair: &KeyPair,
ttl: Option<Duration>,
) -> Result<u32, RotationError> {
let version = self.current_version + 1;
let public_key = key_pair.public_key();
let secret_key = key_pair.secret_key();
let fingerprint = hex::encode(&hash(&public_key)[..16]);
let key_version = KeyVersion::new(version, fingerprint, ttl);
let encrypted = EncryptedKey::encrypt_secret_key(&secret_key, &self.master_key)?;
self.versions.insert(version, key_version);
self.encrypted_keys.insert(version, encrypted);
self.public_keys.insert(version, public_key);
self.current_version = version;
self.cleanup_old_keys();
Ok(version)
}
pub fn generate_key(
&mut self,
ttl: Option<Duration>,
) -> Result<(u32, PublicKey), RotationError> {
let key_pair = KeyPair::generate();
let public_key = key_pair.public_key();
let version = self.add_key(&key_pair, ttl)?;
Ok((version, public_key))
}
pub fn current_version(&self) -> u32 {
self.current_version
}
pub fn get_version(&self, version: u32) -> Option<&KeyVersion> {
self.versions.get(&version)
}
pub fn get_public_key(&self, version: u32) -> Option<&PublicKey> {
self.public_keys.get(&version)
}
pub fn current_public_key(&self) -> Option<&PublicKey> {
self.public_keys.get(&self.current_version)
}
pub fn get_key_pair(&self, version: u32) -> Result<KeyPair, RotationError> {
let version_meta = self
.versions
.get(&version)
.ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
if version_meta.revoked {
return Err(RotationError::KeyRevoked(version));
}
if version_meta.is_expired() {
return Err(RotationError::KeyExpired(version));
}
let encrypted = self
.encrypted_keys
.get(&version)
.ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
let secret_key = encrypted.decrypt_secret_key(&self.master_key)?;
KeyPair::from_secret_key(&secret_key).map_err(RotationError::from)
}
pub fn current_key_pair(&self) -> Result<KeyPair, RotationError> {
self.get_key_pair(self.current_version)
}
pub fn revoke_key(
&mut self,
version: u32,
reason: Option<String>,
) -> Result<(), RotationError> {
let version_meta = self
.versions
.get_mut(&version)
.ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
version_meta.revoke(reason);
Ok(())
}
pub fn needs_rotation(&self) -> bool {
if let Some(version) = self.versions.get(&self.current_version) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let age = now.saturating_sub(version.created_at);
age > self.policy.max_age.as_secs() || version.revoked || version.is_expired()
} else {
true
}
}
pub fn rotate_if_needed(&mut self) -> Result<Option<u32>, RotationError> {
if self.needs_rotation() && self.policy.auto_rotate {
let (version, _) = self.generate_key(Some(self.policy.max_age))?;
Ok(Some(version))
} else {
Ok(None)
}
}
pub fn list_versions(&self) -> Vec<&KeyVersion> {
let mut versions: Vec<_> = self.versions.values().collect();
versions.sort_by_key(|v| v.version);
versions
}
pub fn valid_versions(&self) -> Vec<u32> {
self.versions
.iter()
.filter(|(_, v)| v.is_valid())
.map(|(k, _)| *k)
.collect()
}
fn cleanup_old_keys(&mut self) {
let mut versions: Vec<_> = self.versions.keys().copied().collect();
versions.sort();
let to_remove = versions
.len()
.saturating_sub(self.policy.retention_count + 1);
for version in versions.into_iter().take(to_remove) {
if version != self.current_version {
self.versions.remove(&version);
self.encrypted_keys.remove(&version);
self.public_keys.remove(&version);
}
}
}
}
/// A versioned collection of symmetric encryption keys, each stored
/// encrypted under a master key.
pub struct EncryptionKeyRing {
// Version of the most recently added key; 0 until the first key is added.
current_version: u32,
// Per-version metadata (creation, expiry, revocation).
versions: HashMap<u32, KeyVersion>,
// Per-version key, encrypted under `master_key`.
encrypted_keys: HashMap<u32, EncryptedKey>,
// Key used to encrypt and decrypt the stored keys.
master_key: EncryptionKey,
// Rotation and retention settings.
policy: RotationPolicy,
}
impl EncryptionKeyRing {
    /// Creates an empty key ring that encrypts stored keys under `master_key`
    /// and follows `policy` for rotation and retention.
    pub fn new(master_key: EncryptionKey, policy: RotationPolicy) -> Self {
        Self {
            current_version: 0,
            versions: HashMap::new(),
            encrypted_keys: HashMap::new(),
            master_key,
            policy,
        }
    }

    /// Stores `key` as the next version, makes it current and prunes old
    /// versions according to the retention policy.
    ///
    /// `ttl` is forwarded to [`KeyVersion::new`]. Returns the new version number.
    ///
    /// # Errors
    ///
    /// Returns [`RotationError::EncryptionError`] if the key cannot be
    /// encrypted; the ring is left unchanged in that case.
    pub fn add_key(
        &mut self,
        key: &EncryptionKey,
        ttl: Option<Duration>,
    ) -> Result<u32, RotationError> {
        let version = self.current_version + 1;
        // Fingerprint: hex of the first 16 bytes of the key's hash.
        let fingerprint = hex::encode(&hash(key)[..16]);
        let key_version = KeyVersion::new(version, fingerprint, ttl);
        // Encrypt before touching any map so a failure leaves the ring intact.
        let encrypted = EncryptedKey::encrypt_encryption_key(key, &self.master_key)?;
        self.versions.insert(version, key_version);
        self.encrypted_keys.insert(version, encrypted);
        self.current_version = version;
        self.cleanup_old_keys();
        Ok(version)
    }

    /// Generates a fresh key, adds it via [`Self::add_key`] and returns its
    /// version number.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::add_key`].
    pub fn generate_key(&mut self, ttl: Option<Duration>) -> Result<u32, RotationError> {
        let key = generate_key();
        self.add_key(&key, ttl)
    }

    /// Returns the version number of the current key (0 if none was added yet).
    pub fn current_version(&self) -> u32 {
        self.current_version
    }

    /// Decrypts and returns the key stored under `version`.
    ///
    /// Expiry is deliberately not checked here, so an expired (but not
    /// revoked) version can still be fetched; [`Self::current_key`] is the
    /// method that enforces expiry.
    ///
    /// # Errors
    ///
    /// * [`RotationError::KeyNotFound`] if the version is not stored.
    /// * [`RotationError::KeyRevoked`] if it was revoked.
    /// * [`RotationError::DecryptionError`] / [`RotationError::InvalidKeyFormat`]
    ///   if the stored key cannot be recovered.
    pub fn get_key(&self, version: u32) -> Result<EncryptionKey, RotationError> {
        let version_meta = self
            .versions
            .get(&version)
            .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
        if version_meta.revoked {
            return Err(RotationError::KeyRevoked(version));
        }
        let encrypted = self
            .encrypted_keys
            .get(&version)
            .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
        encrypted.decrypt_encryption_key(&self.master_key)
    }

    /// Returns the current key, refusing expired or revoked keys.
    ///
    /// # Errors
    ///
    /// * [`RotationError::KeyNotFound`] if no key has been added.
    /// * [`RotationError::KeyExpired`] if the current key is past its expiry
    ///   (reported in preference to revocation when both apply).
    /// * Any error from [`Self::get_key`], including
    ///   [`RotationError::KeyRevoked`].
    pub fn current_key(&self) -> Result<EncryptionKey, RotationError> {
        let version_meta = self.versions.get(&self.current_version).ok_or_else(|| {
            RotationError::KeyNotFound(format!("version {}", self.current_version))
        })?;
        if version_meta.is_expired() {
            return Err(RotationError::KeyExpired(self.current_version));
        }
        // `get_key` performs the revocation check and the decryption.
        self.get_key(self.current_version)
    }

    /// Marks `version` as revoked with an optional `reason`. The stored key
    /// material is kept, but [`Self::get_key`] refuses to return it.
    ///
    /// # Errors
    ///
    /// Returns [`RotationError::KeyNotFound`] if the version is not stored.
    pub fn revoke_key(
        &mut self,
        version: u32,
        reason: Option<String>,
    ) -> Result<(), RotationError> {
        let version_meta = self
            .versions
            .get_mut(&version)
            .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
        version_meta.revoke(reason);
        Ok(())
    }

    /// Returns `true` when there is no current key, or when the current key is
    /// older than the policy's `max_age`, revoked, or expired.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn needs_rotation(&self) -> bool {
        match self.versions.get(&self.current_version) {
            Some(current) => {
                let now = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .expect("system clock is set before the UNIX epoch")
                    .as_secs();
                // saturating_sub: a creation time in the future counts as age 0.
                let age = now.saturating_sub(current.created_at);
                age > self.policy.max_age.as_secs() || current.revoked || current.is_expired()
            }
            None => true,
        }
    }

    /// Generates a new current key when the policy allows automatic rotation
    /// and [`Self::needs_rotation`] reports that one is due.
    ///
    /// The new key's TTL is the policy's `max_age`. Returns `Ok(Some(version))`
    /// when a key was generated and `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::generate_key`].
    pub fn rotate_if_needed(&mut self) -> Result<Option<u32>, RotationError> {
        // Check the cheap flag first so a disabled policy never reads the clock.
        if self.policy.auto_rotate && self.needs_rotation() {
            let version = self.generate_key(Some(self.policy.max_age))?;
            Ok(Some(version))
        } else {
            Ok(None)
        }
    }

    /// Returns the metadata of every stored version, ordered by ascending
    /// version number.
    pub fn list_versions(&self) -> Vec<&KeyVersion> {
        let mut versions: Vec<_> = self.versions.values().collect();
        versions.sort_by_key(|v| v.version);
        versions
    }

    /// Drops the oldest versions so that at most `retention_count + 1`
    /// versions remain. The current version is never removed.
    fn cleanup_old_keys(&mut self) {
        // saturating_add: with `retention_count == usize::MAX` ("keep everything")
        // a plain `+ 1` panics in debug builds and wraps to 0 in release builds,
        // which would delete every non-current key.
        let keep = self.policy.retention_count.saturating_add(1);
        let excess = self.versions.len().saturating_sub(keep);
        if excess == 0 {
            return;
        }
        let mut versions: Vec<u32> = self.versions.keys().copied().collect();
        versions.sort_unstable();
        for version in versions.into_iter().take(excess) {
            if version != self.current_version {
                self.versions.remove(&version);
                self.encrypted_keys.remove(&version);
            }
        }
    }
}
/// Moves ciphertext from an old key to a new key by decrypting with the
/// old key and nonce and encrypting again with the new key.
pub struct ReEncryptor<'a> {
// Key the existing ciphertext was encrypted with.
old_key: EncryptionKey,
// Key the re-encrypted output is produced with.
new_key: EncryptionKey,
// Borrowed nonce that is passed to `decrypt` for the existing ciphertext.
old_nonce: &'a EncryptionNonce,
}
impl<'a> ReEncryptor<'a> {
pub fn new(
old_key: EncryptionKey,
new_key: EncryptionKey,
old_nonce: &'a EncryptionNonce,
) -> Self {
Self {
old_key,
new_key,
old_nonce,
}
}
pub fn re_encrypt(
&self,
ciphertext: &[u8],
) -> Result<(Vec<u8>, EncryptionNonce), RotationError> {
let plaintext = decrypt(ciphertext, &self.old_key, self.old_nonce)
.map_err(|_| RotationError::DecryptionError)?;
let new_nonce = generate_nonce();
let new_ciphertext = encrypt(&plaintext, &self.new_key, &new_nonce)
.map_err(|_| RotationError::EncryptionError)?;
Ok((new_ciphertext, new_nonce))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created version with a one-hour TTL is usable straight away.
    #[test]
    fn test_key_version_validity() {
        let one_hour = Duration::from_secs(3600);
        let fresh = KeyVersion::new(1, "abc123".to_string(), Some(one_hour));
        assert!(fresh.is_valid());
        assert!(!fresh.is_expired());
        assert!(!fresh.revoked);
    }

    /// Revoking a version invalidates it and records when that happened.
    #[test]
    fn test_key_revocation() {
        let mut meta = KeyVersion::new(1, "abc123".to_string(), None);
        assert!(meta.is_valid());

        meta.revoke(Some("Compromised".to_string()));

        assert!(!meta.is_valid());
        assert!(meta.revoked);
        assert!(meta.revoked_at.is_some());
    }

    /// A secret key survives an encrypt/decrypt round trip under the master key.
    #[test]
    fn test_encrypted_key() {
        let master = generate_key();
        let original: SecretKey = [1u8; 32];

        let sealed = EncryptedKey::encrypt_secret_key(&original, &master).unwrap();
        let recovered = sealed.decrypt_secret_key(&master).unwrap();

        assert_eq!(original, recovered);
    }

    /// Generated signing keys get increasing versions and can be fetched back.
    #[test]
    fn test_signing_key_ring() {
        let mut ring = SigningKeyRing::new(generate_key(), RotationPolicy::default());

        let (first_version, first_public) = ring.generate_key(None).unwrap();
        assert_eq!(first_version, 1);
        assert_eq!(ring.current_version(), 1);

        let (second_version, second_public) = ring.generate_key(None).unwrap();
        assert_eq!(second_version, 2);
        assert_ne!(first_public, second_public);

        // The superseded version is still retrievable by number.
        let first_pair = ring.get_key_pair(1).unwrap();
        assert_eq!(first_pair.public_key(), first_public);

        let current_pair = ring.current_key_pair().unwrap();
        assert_eq!(current_pair.public_key(), second_public);
    }

    /// Generated encryption keys get increasing versions; the current key
    /// follows the newest one.
    #[test]
    fn test_encryption_key_ring() {
        let mut ring = EncryptionKeyRing::new(generate_key(), RotationPolicy::default());

        let first_version = ring.generate_key(None).unwrap();
        assert_eq!(first_version, 1);

        let first_key = ring.get_key(1).unwrap();
        let current_before = ring.current_key().unwrap();
        assert_eq!(first_key, current_before);

        let second_version = ring.generate_key(None).unwrap();
        assert_eq!(second_version, 2);

        let current_after = ring.current_key().unwrap();
        assert_ne!(first_key, current_after);
    }

    /// Data re-encrypted under a new key decrypts to the original plaintext.
    #[test]
    fn test_re_encryption() {
        let old_key = generate_key();
        let new_key = generate_key();
        let old_nonce = generate_nonce();
        let plaintext = b"Secret data for re-encryption";
        let sealed_old = encrypt(plaintext, &old_key, &old_nonce).unwrap();

        let worker = ReEncryptor::new(old_key, new_key, &old_nonce);
        let (sealed_new, new_nonce) = worker.re_encrypt(&sealed_old).unwrap();

        let recovered = decrypt(&sealed_new, &new_key, &new_nonce).unwrap();
        assert_eq!(plaintext.as_slice(), recovered.as_slice());
    }
}