1use crate::{EncryptionKey, EncryptionNonce, decrypt, encrypt, generate_key, generate_nonce, hash};
10use crate::{KeyPair, PublicKey, SecretKey, SigningError};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use thiserror::Error;
15
16#[derive(Debug, Error)]
18pub enum RotationError {
19 #[error("Key not found: {0}")]
20 KeyNotFound(String),
21
22 #[error("Key expired: version {0}")]
23 KeyExpired(u32),
24
25 #[error("Key revoked: version {0}")]
26 KeyRevoked(u32),
27
28 #[error("Encryption error")]
29 EncryptionError,
30
31 #[error("Decryption error")]
32 DecryptionError,
33
34 #[error("Invalid key format")]
35 InvalidKeyFormat,
36
37 #[error("Signing error: {0}")]
38 SigningError(#[from] SigningError),
39
40 #[error("Serialization error: {0}")]
41 SerializationError(String),
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct KeyVersion {
47 pub version: u32,
49 pub created_at: u64,
51 pub expires_at: Option<u64>,
53 pub revoked: bool,
55 pub revoked_at: Option<u64>,
57 pub revocation_reason: Option<String>,
59 pub fingerprint: String,
61}
62
63impl KeyVersion {
64 pub fn new(version: u32, fingerprint: String, ttl: Option<Duration>) -> Self {
66 let now = SystemTime::now()
67 .duration_since(UNIX_EPOCH)
68 .unwrap()
69 .as_secs();
70
71 Self {
72 version,
73 created_at: now,
74 expires_at: ttl.map(|d| now + d.as_secs()),
75 revoked: false,
76 revoked_at: None,
77 revocation_reason: None,
78 fingerprint,
79 }
80 }
81
82 pub fn is_valid(&self) -> bool {
84 if self.revoked {
85 return false;
86 }
87
88 if let Some(expires_at) = self.expires_at {
89 let now = SystemTime::now()
90 .duration_since(UNIX_EPOCH)
91 .unwrap()
92 .as_secs();
93 if now > expires_at {
94 return false;
95 }
96 }
97
98 true
99 }
100
101 pub fn is_expired(&self) -> bool {
103 if let Some(expires_at) = self.expires_at {
104 let now = SystemTime::now()
105 .duration_since(UNIX_EPOCH)
106 .unwrap()
107 .as_secs();
108 return now > expires_at;
109 }
110 false
111 }
112
113 pub fn revoke(&mut self, reason: Option<String>) {
115 self.revoked = true;
116 self.revoked_at = Some(
117 SystemTime::now()
118 .duration_since(UNIX_EPOCH)
119 .unwrap()
120 .as_secs(),
121 );
122 self.revocation_reason = reason;
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct EncryptedKey {
129 pub ciphertext: Vec<u8>,
131 pub nonce: [u8; 12],
133 pub version: u32,
135 pub salt: Option<Vec<u8>>,
137}
138
139impl EncryptedKey {
140 pub fn encrypt_secret_key(
142 secret_key: &SecretKey,
143 master_key: &EncryptionKey,
144 ) -> Result<Self, RotationError> {
145 let nonce = generate_nonce();
146 let ciphertext =
147 encrypt(secret_key, master_key, &nonce).map_err(|_| RotationError::EncryptionError)?;
148
149 Ok(Self {
150 ciphertext,
151 nonce,
152 version: 0,
153 salt: None,
154 })
155 }
156
157 pub fn decrypt_secret_key(
159 &self,
160 master_key: &EncryptionKey,
161 ) -> Result<SecretKey, RotationError> {
162 let decrypted = decrypt(&self.ciphertext, master_key, &self.nonce)
163 .map_err(|_| RotationError::DecryptionError)?;
164
165 if decrypted.len() != 32 {
166 return Err(RotationError::InvalidKeyFormat);
167 }
168
169 let mut key = [0u8; 32];
170 key.copy_from_slice(&decrypted);
171 Ok(key)
172 }
173
174 pub fn encrypt_encryption_key(
176 key: &EncryptionKey,
177 master_key: &EncryptionKey,
178 ) -> Result<Self, RotationError> {
179 let nonce = generate_nonce();
180 let ciphertext =
181 encrypt(key, master_key, &nonce).map_err(|_| RotationError::EncryptionError)?;
182
183 Ok(Self {
184 ciphertext,
185 nonce,
186 version: 0,
187 salt: None,
188 })
189 }
190
191 pub fn decrypt_encryption_key(
193 &self,
194 master_key: &EncryptionKey,
195 ) -> Result<EncryptionKey, RotationError> {
196 let decrypted = decrypt(&self.ciphertext, master_key, &self.nonce)
197 .map_err(|_| RotationError::DecryptionError)?;
198
199 if decrypted.len() != 32 {
200 return Err(RotationError::InvalidKeyFormat);
201 }
202
203 let mut key = [0u8; 32];
204 key.copy_from_slice(&decrypted);
205 Ok(key)
206 }
207}
208
/// Settings that govern when key rings rotate and how many older versions they keep.
#[derive(Clone, Debug)]
pub struct RotationPolicy {
    /// Age after which the current key is due for rotation.
    pub max_age: Duration,
    /// Number of previous versions kept alongside the current one.
    pub retention_count: usize,
    /// Whether automatic rotation is permitted.
    pub auto_rotate: bool,
}

impl Default for RotationPolicy {
    /// Rotate every 30 days, keep three previous versions, and allow automatic rotation.
    fn default() -> Self {
        const THIRTY_DAYS_SECS: u64 = 60 * 60 * 24 * 30;

        RotationPolicy {
            auto_rotate: true,
            retention_count: 3,
            max_age: Duration::from_secs(THIRTY_DAYS_SECS),
        }
    }
}
229
230pub struct SigningKeyRing {
232 current_version: u32,
234 versions: HashMap<u32, KeyVersion>,
236 encrypted_keys: HashMap<u32, EncryptedKey>,
238 public_keys: HashMap<u32, PublicKey>,
240 master_key: EncryptionKey,
242 policy: RotationPolicy,
244}
245
246impl SigningKeyRing {
247 pub fn new(master_key: EncryptionKey, policy: RotationPolicy) -> Self {
249 Self {
250 current_version: 0,
251 versions: HashMap::new(),
252 encrypted_keys: HashMap::new(),
253 public_keys: HashMap::new(),
254 master_key,
255 policy,
256 }
257 }
258
259 pub fn add_key(
261 &mut self,
262 key_pair: &KeyPair,
263 ttl: Option<Duration>,
264 ) -> Result<u32, RotationError> {
265 let version = self.current_version + 1;
266 let public_key = key_pair.public_key();
267 let secret_key = key_pair.secret_key();
268
269 let fingerprint = hex::encode(&hash(&public_key)[..16]);
271
272 let key_version = KeyVersion::new(version, fingerprint, ttl);
274
275 let encrypted = EncryptedKey::encrypt_secret_key(&secret_key, &self.master_key)?;
277
278 self.versions.insert(version, key_version);
279 self.encrypted_keys.insert(version, encrypted);
280 self.public_keys.insert(version, public_key);
281 self.current_version = version;
282
283 self.cleanup_old_keys();
285
286 Ok(version)
287 }
288
289 pub fn generate_key(
291 &mut self,
292 ttl: Option<Duration>,
293 ) -> Result<(u32, PublicKey), RotationError> {
294 let key_pair = KeyPair::generate();
295 let public_key = key_pair.public_key();
296 let version = self.add_key(&key_pair, ttl)?;
297 Ok((version, public_key))
298 }
299
300 pub fn current_version(&self) -> u32 {
302 self.current_version
303 }
304
305 pub fn get_version(&self, version: u32) -> Option<&KeyVersion> {
307 self.versions.get(&version)
308 }
309
310 pub fn get_public_key(&self, version: u32) -> Option<&PublicKey> {
312 self.public_keys.get(&version)
313 }
314
315 pub fn current_public_key(&self) -> Option<&PublicKey> {
317 self.public_keys.get(&self.current_version)
318 }
319
320 pub fn get_key_pair(&self, version: u32) -> Result<KeyPair, RotationError> {
322 let version_meta = self
323 .versions
324 .get(&version)
325 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
326
327 if version_meta.revoked {
328 return Err(RotationError::KeyRevoked(version));
329 }
330
331 if version_meta.is_expired() {
332 return Err(RotationError::KeyExpired(version));
333 }
334
335 let encrypted = self
336 .encrypted_keys
337 .get(&version)
338 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
339
340 let secret_key = encrypted.decrypt_secret_key(&self.master_key)?;
341 KeyPair::from_secret_key(&secret_key).map_err(RotationError::from)
342 }
343
344 pub fn current_key_pair(&self) -> Result<KeyPair, RotationError> {
346 self.get_key_pair(self.current_version)
347 }
348
349 pub fn revoke_key(
351 &mut self,
352 version: u32,
353 reason: Option<String>,
354 ) -> Result<(), RotationError> {
355 let version_meta = self
356 .versions
357 .get_mut(&version)
358 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
359
360 version_meta.revoke(reason);
361 Ok(())
362 }
363
364 pub fn needs_rotation(&self) -> bool {
366 if let Some(version) = self.versions.get(&self.current_version) {
367 let now = SystemTime::now()
368 .duration_since(UNIX_EPOCH)
369 .unwrap()
370 .as_secs();
371 let age = now.saturating_sub(version.created_at);
372 age > self.policy.max_age.as_secs() || version.revoked || version.is_expired()
373 } else {
374 true
375 }
376 }
377
378 pub fn rotate_if_needed(&mut self) -> Result<Option<u32>, RotationError> {
380 if self.needs_rotation() && self.policy.auto_rotate {
381 let (version, _) = self.generate_key(Some(self.policy.max_age))?;
382 Ok(Some(version))
383 } else {
384 Ok(None)
385 }
386 }
387
388 pub fn list_versions(&self) -> Vec<&KeyVersion> {
390 let mut versions: Vec<_> = self.versions.values().collect();
391 versions.sort_by_key(|v| v.version);
392 versions
393 }
394
395 pub fn valid_versions(&self) -> Vec<u32> {
397 self.versions
398 .iter()
399 .filter(|(_, v)| v.is_valid())
400 .map(|(k, _)| *k)
401 .collect()
402 }
403
404 fn cleanup_old_keys(&mut self) {
406 let mut versions: Vec<_> = self.versions.keys().copied().collect();
407 versions.sort();
408
409 let to_remove = versions
411 .len()
412 .saturating_sub(self.policy.retention_count + 1);
413 for version in versions.into_iter().take(to_remove) {
414 if version != self.current_version {
416 self.versions.remove(&version);
417 self.encrypted_keys.remove(&version);
418 self.public_keys.remove(&version);
419 }
420 }
421 }
422}
423
424pub struct EncryptionKeyRing {
426 current_version: u32,
428 versions: HashMap<u32, KeyVersion>,
430 encrypted_keys: HashMap<u32, EncryptedKey>,
432 master_key: EncryptionKey,
434 policy: RotationPolicy,
436}
437
438impl EncryptionKeyRing {
439 pub fn new(master_key: EncryptionKey, policy: RotationPolicy) -> Self {
441 Self {
442 current_version: 0,
443 versions: HashMap::new(),
444 encrypted_keys: HashMap::new(),
445 master_key,
446 policy,
447 }
448 }
449
450 pub fn add_key(
452 &mut self,
453 key: &EncryptionKey,
454 ttl: Option<Duration>,
455 ) -> Result<u32, RotationError> {
456 let version = self.current_version + 1;
457
458 let fingerprint = hex::encode(&hash(key)[..16]);
460
461 let key_version = KeyVersion::new(version, fingerprint, ttl);
463
464 let encrypted = EncryptedKey::encrypt_encryption_key(key, &self.master_key)?;
466
467 self.versions.insert(version, key_version);
468 self.encrypted_keys.insert(version, encrypted);
469 self.current_version = version;
470
471 self.cleanup_old_keys();
473
474 Ok(version)
475 }
476
477 pub fn generate_key(&mut self, ttl: Option<Duration>) -> Result<u32, RotationError> {
479 let key = generate_key();
480 self.add_key(&key, ttl)
481 }
482
483 pub fn current_version(&self) -> u32 {
485 self.current_version
486 }
487
488 pub fn get_key(&self, version: u32) -> Result<EncryptionKey, RotationError> {
490 let version_meta = self
491 .versions
492 .get(&version)
493 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
494
495 if version_meta.revoked {
496 return Err(RotationError::KeyRevoked(version));
497 }
498
499 let encrypted = self
502 .encrypted_keys
503 .get(&version)
504 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
505
506 encrypted.decrypt_encryption_key(&self.master_key)
507 }
508
509 pub fn current_key(&self) -> Result<EncryptionKey, RotationError> {
511 let version_meta = self.versions.get(&self.current_version).ok_or_else(|| {
512 RotationError::KeyNotFound(format!("version {}", self.current_version))
513 })?;
514
515 if !version_meta.is_valid() {
517 if version_meta.is_expired() {
518 return Err(RotationError::KeyExpired(self.current_version));
519 }
520 if version_meta.revoked {
521 return Err(RotationError::KeyRevoked(self.current_version));
522 }
523 }
524
525 self.get_key(self.current_version)
526 }
527
528 pub fn revoke_key(
530 &mut self,
531 version: u32,
532 reason: Option<String>,
533 ) -> Result<(), RotationError> {
534 let version_meta = self
535 .versions
536 .get_mut(&version)
537 .ok_or_else(|| RotationError::KeyNotFound(format!("version {}", version)))?;
538
539 version_meta.revoke(reason);
540 Ok(())
541 }
542
543 pub fn needs_rotation(&self) -> bool {
545 if let Some(version) = self.versions.get(&self.current_version) {
546 let now = SystemTime::now()
547 .duration_since(UNIX_EPOCH)
548 .unwrap()
549 .as_secs();
550 let age = now.saturating_sub(version.created_at);
551 age > self.policy.max_age.as_secs() || version.revoked || version.is_expired()
552 } else {
553 true
554 }
555 }
556
557 pub fn rotate_if_needed(&mut self) -> Result<Option<u32>, RotationError> {
559 if self.needs_rotation() && self.policy.auto_rotate {
560 let version = self.generate_key(Some(self.policy.max_age))?;
561 Ok(Some(version))
562 } else {
563 Ok(None)
564 }
565 }
566
567 pub fn list_versions(&self) -> Vec<&KeyVersion> {
569 let mut versions: Vec<_> = self.versions.values().collect();
570 versions.sort_by_key(|v| v.version);
571 versions
572 }
573
574 fn cleanup_old_keys(&mut self) {
576 let mut versions: Vec<_> = self.versions.keys().copied().collect();
577 versions.sort();
578
579 let to_remove = versions
580 .len()
581 .saturating_sub(self.policy.retention_count + 1);
582 for version in versions.into_iter().take(to_remove) {
583 if version != self.current_version {
584 self.versions.remove(&version);
585 self.encrypted_keys.remove(&version);
586 }
587 }
588 }
589}
590
591pub struct ReEncryptor<'a> {
593 old_key: EncryptionKey,
595 new_key: EncryptionKey,
597 old_nonce: &'a EncryptionNonce,
599}
600
601impl<'a> ReEncryptor<'a> {
602 pub fn new(
604 old_key: EncryptionKey,
605 new_key: EncryptionKey,
606 old_nonce: &'a EncryptionNonce,
607 ) -> Self {
608 Self {
609 old_key,
610 new_key,
611 old_nonce,
612 }
613 }
614
615 pub fn re_encrypt(
617 &self,
618 ciphertext: &[u8],
619 ) -> Result<(Vec<u8>, EncryptionNonce), RotationError> {
620 let plaintext = decrypt(ciphertext, &self.old_key, self.old_nonce)
622 .map_err(|_| RotationError::DecryptionError)?;
623
624 let new_nonce = generate_nonce();
626 let new_ciphertext = encrypt(&plaintext, &self.new_key, &new_nonce)
627 .map_err(|_| RotationError::EncryptionError)?;
628
629 Ok((new_ciphertext, new_nonce))
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh key with a one-hour TTL is valid, unexpired, and unrevoked.
    #[test]
    fn test_key_version_validity() {
        let one_hour = Duration::from_secs(3600);
        let meta = KeyVersion::new(1, String::from("abc123"), Some(one_hour));

        assert!(meta.is_valid());
        assert!(!meta.is_expired());
        assert!(!meta.revoked);
    }

    /// Revocation invalidates the key and records a timestamp.
    #[test]
    fn test_key_revocation() {
        let mut meta = KeyVersion::new(1, String::from("abc123"), None);
        assert!(meta.is_valid());

        meta.revoke(Some(String::from("Compromised")));

        assert!(!meta.is_valid());
        assert!(meta.revoked);
        assert!(meta.revoked_at.is_some());
    }

    /// Sealing then unsealing a secret key round-trips the bytes.
    #[test]
    fn test_encrypted_key() {
        let master = generate_key();
        let original: SecretKey = [1u8; 32];

        let sealed = EncryptedKey::encrypt_secret_key(&original, &master).unwrap();
        let opened = sealed.decrypt_secret_key(&master).unwrap();

        assert_eq!(original, opened);
    }

    /// Versions increment and each stored pair is recoverable with its public key.
    #[test]
    fn test_signing_key_ring() {
        let mut ring = SigningKeyRing::new(generate_key(), RotationPolicy::default());

        let (first_version, first_pk) = ring.generate_key(None).unwrap();
        assert_eq!(first_version, 1);
        assert_eq!(ring.current_version(), 1);

        let (second_version, second_pk) = ring.generate_key(None).unwrap();
        assert_eq!(second_version, 2);
        assert_ne!(first_pk, second_pk);

        let first_pair = ring.get_key_pair(1).unwrap();
        assert_eq!(first_pair.public_key(), first_pk);

        let current_pair = ring.current_key_pair().unwrap();
        assert_eq!(current_pair.public_key(), second_pk);
    }

    /// The current key tracks the newest version, and distinct versions hold distinct keys.
    #[test]
    fn test_encryption_key_ring() {
        let mut ring = EncryptionKeyRing::new(generate_key(), RotationPolicy::default());

        let first_version = ring.generate_key(None).unwrap();
        assert_eq!(first_version, 1);

        let first_key = ring.get_key(1).unwrap();
        let current = ring.current_key().unwrap();
        assert_eq!(first_key, current);

        let second_version = ring.generate_key(None).unwrap();
        assert_eq!(second_version, 2);

        let second_key = ring.current_key().unwrap();
        assert_ne!(first_key, second_key);
    }

    /// Re-encrypted data decrypts under the new key and nonce to the original plaintext.
    #[test]
    fn test_re_encryption() {
        let (old_key, new_key) = (generate_key(), generate_key());
        let old_nonce = generate_nonce();
        let message = b"Secret data for re-encryption";

        let original_ct = encrypt(message, &old_key, &old_nonce).unwrap();

        let rotator = ReEncryptor::new(old_key, new_key, &old_nonce);
        let (rotated_ct, rotated_nonce) = rotator.re_encrypt(&original_ct).unwrap();

        let recovered = decrypt(&rotated_ct, &new_key, &rotated_nonce).unwrap();
        assert_eq!(message.as_slice(), recovered.as_slice());
    }
}