1use crate::error::{AllSourceError, Result};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
/// Backend that a [`KmsManager`] delegates key operations to.
///
/// Only `Local` is implemented in this file; `KmsManager::new` rejects every
/// other variant with a `ValidationError`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KmsProvider {
    /// AWS Key Management Service (not implemented yet).
    AwsKms,
    /// Google Cloud KMS (not implemented yet).
    GoogleCloudKms,
    /// Azure Key Vault (not implemented yet).
    AzureKeyVault,
    /// HashiCorp Vault (not implemented yet).
    HashicorpVault,
    /// PKCS#11 backend (not implemented yet).
    Pkcs11,
    /// In-process [`LocalKms`].
    Local,
}
25
/// Settings used to construct a [`KmsManager`] or a [`LocalKms`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsConfig {
    /// Which backend to use.
    pub provider: KmsProvider,

    /// Provider-specific string options. No code in this file reads them.
    // NOTE(review): presumably intended for the not-yet-implemented remote providers.
    pub config: HashMap<String, String>,

    /// Whether keys should be rotated automatically. No code in this file acts
    /// on it; rotation happens only through an explicit `rotate_key` call.
    pub auto_rotate: bool,

    /// Intended interval between automatic rotations, in days. No code in this
    /// file reads it.
    pub rotation_period_days: u32,
}
41
42impl Default for KmsConfig {
43 fn default() -> Self {
44 Self {
45 provider: KmsProvider::Local,
46 config: HashMap::new(),
47 auto_rotate: true,
48 rotation_period_days: 90,
49 }
50 }
51}
52
/// Descriptive information about a managed key; it carries no key material.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMetadata {
    /// Unique key identifier (`LocalKms` assigns a random UUID v4 string).
    pub key_id: String,

    /// Human-readable name supplied by the caller of `create_key`.
    pub alias: String,

    /// What the key is intended to be used for.
    pub purpose: KeyPurpose,

    /// Algorithm the key material was generated for.
    pub algorithm: KeyAlgorithm,

    /// Creation timestamp, in UTC.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Time of the most recent rotation; `None` until the key is first rotated.
    pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,

    /// Current lifecycle state.
    pub status: KeyStatus,

    /// Material version: `LocalKms` starts at 1 and increments it on each rotation.
    pub version: u32,
}
80
/// Label describing what a key is meant for. `LocalKms` stores it in the
/// metadata but does not enforce it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyPurpose {
    /// Encryption of application data.
    DataEncryption,
    /// Signing of JSON Web Tokens.
    JwtSigning,
    /// Signing of API keys.
    ApiKeySigning,
    /// Encryption of database contents.
    DatabaseEncryption,
    /// Caller-defined purpose label.
    Custom(String),
}
89
/// Algorithm a key is generated for.
///
/// `LocalKms::create_key` accepts only `Aes256Gcm` and `Aes128Gcm`; it returns a
/// `ValidationError` for every other variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyAlgorithm {
    /// AES-256 in GCM mode.
    Aes256Gcm,
    /// AES-128 in GCM mode.
    Aes128Gcm,
    /// ChaCha20-Poly1305 AEAD.
    ChaCha20Poly1305,
    /// RSA with OAEP padding.
    RsaOaep,
    /// ECDSA over the NIST P-256 curve.
    EcdsaP256,
    /// Ed25519 signatures.
    Ed25519,
}
99
/// Lifecycle state of a key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyStatus {
    /// Set on creation and by `enable_key`; `LocalKms::encrypt` only accepts
    /// keys in this state.
    Active,
    /// Not set by any code in this file.
    Rotating,
    /// Set by `LocalKms::disable_key`; `LocalKms::encrypt` rejects such keys.
    Deprecated,
    /// Not set by any code in this file.
    Destroyed,
}
107
/// Operations a key-management backend provides.
///
/// Implemented in this file by [`LocalKms`]; [`KmsManager`] holds an
/// implementation as `Arc<dyn KmsClient>`.
#[async_trait::async_trait]
pub trait KmsClient: Send + Sync {
    /// Creates a new key with the given alias, purpose and algorithm and
    /// returns its metadata.
    async fn create_key(
        &self,
        alias: String,
        purpose: KeyPurpose,
        algorithm: KeyAlgorithm,
    ) -> Result<KeyMetadata>;

    /// Returns the metadata of the key identified by `key_id`.
    async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;

    /// Returns the metadata of every key known to this client.
    async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;

    /// Encrypts `plaintext` under `key_id`. The returned bytes are opaque; pass
    /// them back to [`KmsClient::decrypt`] with the same key id.
    async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;

    /// Reverses [`KmsClient::encrypt`] for the same `key_id`.
    async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;

    /// Rotates the material behind `key_id` and returns the updated metadata.
    async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;

    /// Disables `key_id` (`LocalKms` marks it `KeyStatus::Deprecated`).
    async fn disable_key(&self, key_id: &str) -> Result<()>;

    /// Re-enables `key_id` (`LocalKms` marks it `KeyStatus::Active`).
    async fn enable_key(&self, key_id: &str) -> Result<()>;

    /// Generates a data-encryption key (DEK) and returns
    /// `(plaintext_dek, encrypted_dek)`, where `encrypted_dek` is the DEK
    /// encrypted under `key_id`.
    async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
}
143
/// In-process [`KmsClient`] that keeps key material in a map behind a
/// `parking_lot::RwLock`. Nothing in this file persists the map, so keys (and
/// therefore anything encrypted with them) live only as long as the process.
pub struct LocalKms {
    /// Stored keys indexed by `key_id`.
    keys: Arc<RwLock<HashMap<String, StoredKey>>>,
    /// Configuration supplied to `LocalKms::new`.
    config: KmsConfig,
}
149
150struct StoredKey {
151 metadata: KeyMetadata,
152 key_material: Vec<u8>,
153}
154
155impl LocalKms {
156 pub fn new(config: KmsConfig) -> Self {
157 Self {
158 keys: Arc::new(RwLock::new(HashMap::new())),
159 config,
160 }
161 }
162}
163
164#[async_trait::async_trait]
165impl KmsClient for LocalKms {
166 async fn create_key(
167 &self,
168 alias: String,
169 purpose: KeyPurpose,
170 algorithm: KeyAlgorithm,
171 ) -> Result<KeyMetadata> {
172 let key_id = uuid::Uuid::new_v4().to_string();
173
174 let key_material = match algorithm {
176 KeyAlgorithm::Aes256Gcm => {
177 let mut key = vec![0u8; 32];
178 use aes_gcm::aead::rand_core::RngCore;
179 use aes_gcm::aead::OsRng;
180 RngCore::fill_bytes(&mut OsRng, &mut key);
181 key
182 }
183 KeyAlgorithm::Aes128Gcm => {
184 let mut key = vec![0u8; 16];
185 use aes_gcm::aead::rand_core::RngCore;
186 use aes_gcm::aead::OsRng;
187 RngCore::fill_bytes(&mut OsRng, &mut key);
188 key
189 }
190 _ => {
191 return Err(AllSourceError::ValidationError(format!(
192 "Algorithm {:?} not supported in local KMS",
193 algorithm
194 )));
195 }
196 };
197
198 let metadata = KeyMetadata {
199 key_id: key_id.clone(),
200 alias,
201 purpose,
202 algorithm,
203 created_at: chrono::Utc::now(),
204 last_rotated: None,
205 status: KeyStatus::Active,
206 version: 1,
207 };
208
209 let stored_key = StoredKey {
210 metadata: metadata.clone(),
211 key_material,
212 };
213
214 self.keys.write().insert(key_id, stored_key);
215
216 Ok(metadata)
217 }
218
219 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
220 self.keys
221 .read()
222 .get(key_id)
223 .map(|k| k.metadata.clone())
224 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))
225 }
226
227 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
228 Ok(self
229 .keys
230 .read()
231 .values()
232 .map(|k| k.metadata.clone())
233 .collect())
234 }
235
236 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
237 use aes_gcm::aead::Aead;
238 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
239
240 let keys = self.keys.read();
241 let stored_key = keys
242 .get(key_id)
243 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
244
245 if stored_key.metadata.status != KeyStatus::Active {
246 return Err(AllSourceError::ValidationError(
247 "Key is not active".to_string(),
248 ));
249 }
250
251 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
252 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
253
254 use aes_gcm::aead::rand_core::RngCore;
256 use aes_gcm::aead::OsRng;
257 let nonce_bytes = OsRng.next_u64().to_le_bytes();
258 let mut nonce_array = [0u8; 12];
259 nonce_array[..8].copy_from_slice(&nonce_bytes);
260 let nonce = Nonce::from_slice(&nonce_array);
261
262 let ciphertext = cipher
263 .encrypt(nonce, plaintext)
264 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
265
266 let mut result = nonce.to_vec();
268 result.extend_from_slice(&ciphertext);
269
270 Ok(result)
271 }
272
273 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
274 use aes_gcm::aead::Aead;
275 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
276
277 if ciphertext_with_nonce.len() < 12 {
278 return Err(AllSourceError::ValidationError(
279 "Invalid ciphertext".to_string(),
280 ));
281 }
282
283 let keys = self.keys.read();
284 let stored_key = keys
285 .get(key_id)
286 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
287
288 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
289 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
290
291 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
293 let ciphertext = &ciphertext_with_nonce[12..];
294
295 cipher
296 .decrypt(nonce, ciphertext)
297 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))
298 }
299
300 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
301 let mut keys = self.keys.write();
302 let stored_key = keys
303 .get_mut(key_id)
304 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
305
306 let new_key_material = {
308 let mut key = vec![0u8; 32];
309 use aes_gcm::aead::rand_core::RngCore;
310 use aes_gcm::aead::OsRng;
311 RngCore::fill_bytes(&mut OsRng, &mut key);
312 key
313 };
314
315 stored_key.key_material = new_key_material;
316 stored_key.metadata.version += 1;
317 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
318
319 Ok(stored_key.metadata.clone())
320 }
321
322 async fn disable_key(&self, key_id: &str) -> Result<()> {
323 let mut keys = self.keys.write();
324 let stored_key = keys
325 .get_mut(key_id)
326 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
327
328 stored_key.metadata.status = KeyStatus::Deprecated;
329 Ok(())
330 }
331
332 async fn enable_key(&self, key_id: &str) -> Result<()> {
333 let mut keys = self.keys.write();
334 let stored_key = keys
335 .get_mut(key_id)
336 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
337
338 stored_key.metadata.status = KeyStatus::Active;
339 Ok(())
340 }
341
342 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
343 let mut dek = vec![0u8; 32];
345 use aes_gcm::aead::rand_core::RngCore;
346 use aes_gcm::aead::OsRng;
347 RngCore::fill_bytes(&mut OsRng, &mut dek);
348
349 let encrypted_dek = self.encrypt(key_id, &dek).await?;
351
352 Ok((dek, encrypted_dek))
353 }
354}
355
/// Entry point pairing a [`KmsClient`] with envelope-encryption helpers.
pub struct KmsManager {
    /// Backend selected from `KmsConfig::provider` in `KmsManager::new`.
    client: Arc<dyn KmsClient>,
    /// Configuration the manager was built from.
    config: KmsConfig,
}
361
362impl KmsManager {
363 pub fn new(config: KmsConfig) -> Result<Self> {
365 let client: Arc<dyn KmsClient> = match config.provider {
366 KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
367 _ => {
368 return Err(AllSourceError::ValidationError(format!(
369 "KMS provider {:?} not yet implemented",
370 config.provider
371 )));
372 }
373 };
374
375 Ok(Self { client, config })
376 }
377
378 pub fn client(&self) -> &Arc<dyn KmsClient> {
380 &self.client
381 }
382
383 pub async fn envelope_encrypt(
385 &self,
386 master_key_id: &str,
387 plaintext: &[u8],
388 ) -> Result<EnvelopeEncryptedData> {
389 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
391
392 use aes_gcm::aead::Aead;
394 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
395
396 let cipher = Aes256Gcm::new_from_slice(&dek)
397 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
398
399 use aes_gcm::aead::rand_core::RngCore;
400 use aes_gcm::aead::OsRng;
401 let nonce_bytes = OsRng.next_u64().to_le_bytes();
402 let mut nonce_array = [0u8; 12];
403 nonce_array[..8].copy_from_slice(&nonce_bytes);
404 let nonce = Nonce::from_slice(&nonce_array);
405
406 let ciphertext = cipher
407 .encrypt(nonce, plaintext)
408 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
409
410 Ok(EnvelopeEncryptedData {
411 ciphertext,
412 nonce: nonce.to_vec(),
413 encrypted_dek,
414 master_key_id: master_key_id.to_string(),
415 })
416 }
417
418 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
420 let dek = self
422 .client
423 .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
424 .await?;
425
426 use aes_gcm::aead::Aead;
428 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
429
430 let cipher = Aes256Gcm::new_from_slice(&dek)
431 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
432
433 let nonce = Nonce::from_slice(&encrypted.nonce);
434
435 cipher
436 .decrypt(nonce, encrypted.ciphertext.as_ref())
437 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))
438 }
439}
440
/// Output of [`KmsManager::envelope_encrypt`]: a payload encrypted with a data
/// key (DEK), plus that DEK encrypted under a master key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvelopeEncryptedData {
    /// Payload encrypted with the DEK using AES-256-GCM.
    pub ciphertext: Vec<u8>,

    /// AES-GCM nonce used for `ciphertext`. `envelope_encrypt` produces 12 bytes,
    /// and `envelope_decrypt` requires exactly that length.
    pub nonce: Vec<u8>,

    /// The DEK encrypted under `master_key_id`, as returned by
    /// `KmsClient::generate_data_key`.
    pub encrypted_dek: Vec<u8>,

    /// Identifier of the master key that wraps the DEK.
    pub master_key_id: String,
}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// A local KMS built from the default configuration.
    fn local_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key called `alias` on `kms` and
    /// returns its metadata, panicking if creation fails.
    async fn aes256_key(kms: &dyn KmsClient, alias: &str) -> KeyMetadata {
        kms.create_key(
            alias.to_string(),
            KeyPurpose::DataEncryption,
            KeyAlgorithm::Aes256Gcm,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = local_kms();

        let created = aes256_key(&kms, "test-key").await;

        assert_eq!(created.alias, "test-key");
        assert_eq!(created.status, KeyStatus::Active);
        assert_eq!(created.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = local_kms();
        let key_id = aes256_key(&kms, "test-key").await.key_id;

        let message = b"sensitive data";
        let sealed = kms.encrypt(&key_id, message).await.unwrap();
        let opened = kms.decrypt(&key_id, &sealed).await.unwrap();

        assert_eq!(opened, message);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = local_kms();
        let original = aes256_key(&kms, "test-key").await;

        let rotated = kms.rotate_key(&original.key_id).await.unwrap();

        assert_eq!(rotated.version, 2);
        assert!(rotated.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master = aes256_key(manager.client().as_ref(), "master-key").await;

        let message = b"sensitive data for envelope encryption";
        let envelope = manager
            .envelope_encrypt(&master.key_id, message)
            .await
            .unwrap();
        let recovered = manager.envelope_decrypt(&envelope).await.unwrap();

        assert_eq!(recovered, message);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = local_kms();
        let key_id = aes256_key(&kms, "test-key").await.key_id;

        kms.disable_key(&key_id).await.unwrap();
        let status = kms.get_key(&key_id).await.unwrap().status;
        assert_eq!(status, KeyStatus::Deprecated);
        assert!(kms.encrypt(&key_id, b"test").await.is_err());

        kms.enable_key(&key_id).await.unwrap();
        let status = kms.get_key(&key_id).await.unwrap().status;
        assert_eq!(status, KeyStatus::Active);
        assert!(kms.encrypt(&key_id, b"test").await.is_ok());
    }
}