1use crate::error::{AllSourceError, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use parking_lot::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub enum KmsProvider {
19 AwsKms,
20 GoogleCloudKms,
21 AzureKeyVault,
22 HashicorpVault,
23 Pkcs11,
24 Local, }
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct KmsConfig {
30 pub provider: KmsProvider,
32
33 pub config: HashMap<String, String>,
35
36 pub auto_rotate: bool,
38
39 pub rotation_period_days: u32,
41}
42
43impl Default for KmsConfig {
44 fn default() -> Self {
45 Self {
46 provider: KmsProvider::Local,
47 config: HashMap::new(),
48 auto_rotate: true,
49 rotation_period_days: 90,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct KeyMetadata {
57 pub key_id: String,
59
60 pub alias: String,
62
63 pub purpose: KeyPurpose,
65
66 pub algorithm: KeyAlgorithm,
68
69 pub created_at: chrono::DateTime<chrono::Utc>,
71
72 pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,
74
75 pub status: KeyStatus,
77
78 pub version: u32,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub enum KeyPurpose {
84 DataEncryption,
85 JwtSigning,
86 ApiKeySigning,
87 DatabaseEncryption,
88 Custom(String),
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
92pub enum KeyAlgorithm {
93 Aes256Gcm,
94 Aes128Gcm,
95 ChaCha20Poly1305,
96 RsaOaep,
97 EcdsaP256,
98 Ed25519,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
102pub enum KeyStatus {
103 Active,
104 Rotating,
105 Deprecated,
106 Destroyed,
107}
108
109#[async_trait::async_trait]
111pub trait KmsClient: Send + Sync {
112 async fn create_key(&self, alias: String, purpose: KeyPurpose, algorithm: KeyAlgorithm) -> Result<KeyMetadata>;
114
115 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;
117
118 async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;
120
121 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;
123
124 async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;
126
127 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;
129
130 async fn disable_key(&self, key_id: &str) -> Result<()>;
132
133 async fn enable_key(&self, key_id: &str) -> Result<()>;
135
136 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
138}
139
140pub struct LocalKms {
142 keys: Arc<RwLock<HashMap<String, StoredKey>>>,
143 config: KmsConfig,
144}
145
146struct StoredKey {
147 metadata: KeyMetadata,
148 key_material: Vec<u8>,
149}
150
151impl LocalKms {
152 pub fn new(config: KmsConfig) -> Self {
153 Self {
154 keys: Arc::new(RwLock::new(HashMap::new())),
155 config,
156 }
157 }
158}
159
160#[async_trait::async_trait]
161impl KmsClient for LocalKms {
162 async fn create_key(&self, alias: String, purpose: KeyPurpose, algorithm: KeyAlgorithm) -> Result<KeyMetadata> {
163 let key_id = uuid::Uuid::new_v4().to_string();
164
165 let key_material = match algorithm {
167 KeyAlgorithm::Aes256Gcm => {
168 let mut key = vec![0u8; 32];
169 use aes_gcm::aead::OsRng;
170 use aes_gcm::aead::rand_core::RngCore;
171 RngCore::fill_bytes(&mut OsRng, &mut key);
172 key
173 }
174 KeyAlgorithm::Aes128Gcm => {
175 let mut key = vec![0u8; 16];
176 use aes_gcm::aead::OsRng;
177 use aes_gcm::aead::rand_core::RngCore;
178 RngCore::fill_bytes(&mut OsRng, &mut key);
179 key
180 }
181 _ => {
182 return Err(AllSourceError::ValidationError(
183 format!("Algorithm {:?} not supported in local KMS", algorithm)
184 ));
185 }
186 };
187
188 let metadata = KeyMetadata {
189 key_id: key_id.clone(),
190 alias,
191 purpose,
192 algorithm,
193 created_at: chrono::Utc::now(),
194 last_rotated: None,
195 status: KeyStatus::Active,
196 version: 1,
197 };
198
199 let stored_key = StoredKey {
200 metadata: metadata.clone(),
201 key_material,
202 };
203
204 self.keys.write().insert(key_id, stored_key);
205
206 Ok(metadata)
207 }
208
209 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
210 self.keys.read()
211 .get(key_id)
212 .map(|k| k.metadata.clone())
213 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))
214 }
215
216 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
217 Ok(self.keys.read()
218 .values()
219 .map(|k| k.metadata.clone())
220 .collect())
221 }
222
223 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
224 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
225 use aes_gcm::aead::Aead;
226
227 let keys = self.keys.read();
228 let stored_key = keys.get(key_id)
229 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
230
231 if stored_key.metadata.status != KeyStatus::Active {
232 return Err(AllSourceError::ValidationError("Key is not active".to_string()));
233 }
234
235 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
236 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
237
238 use aes_gcm::aead::OsRng;
240 use aes_gcm::aead::rand_core::RngCore;
241 let nonce_bytes = OsRng.next_u64().to_le_bytes();
242 let mut nonce_array = [0u8; 12];
243 nonce_array[..8].copy_from_slice(&nonce_bytes);
244 let nonce = Nonce::from_slice(&nonce_array);
245
246 let ciphertext = cipher.encrypt(nonce, plaintext)
247 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
248
249 let mut result = nonce.to_vec();
251 result.extend_from_slice(&ciphertext);
252
253 Ok(result)
254 }
255
256 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
257 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
258 use aes_gcm::aead::Aead;
259
260 if ciphertext_with_nonce.len() < 12 {
261 return Err(AllSourceError::ValidationError("Invalid ciphertext".to_string()));
262 }
263
264 let keys = self.keys.read();
265 let stored_key = keys.get(key_id)
266 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
267
268 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
269 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
270
271 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
273 let ciphertext = &ciphertext_with_nonce[12..];
274
275 cipher.decrypt(nonce, ciphertext)
276 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))
277 }
278
279 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
280 let mut keys = self.keys.write();
281 let stored_key = keys.get_mut(key_id)
282 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
283
284 let new_key_material = {
286 let mut key = vec![0u8; 32];
287 use aes_gcm::aead::OsRng;
288 use aes_gcm::aead::rand_core::RngCore;
289 RngCore::fill_bytes(&mut OsRng, &mut key);
290 key
291 };
292
293 stored_key.key_material = new_key_material;
294 stored_key.metadata.version += 1;
295 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
296
297 Ok(stored_key.metadata.clone())
298 }
299
300 async fn disable_key(&self, key_id: &str) -> Result<()> {
301 let mut keys = self.keys.write();
302 let stored_key = keys.get_mut(key_id)
303 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
304
305 stored_key.metadata.status = KeyStatus::Deprecated;
306 Ok(())
307 }
308
309 async fn enable_key(&self, key_id: &str) -> Result<()> {
310 let mut keys = self.keys.write();
311 let stored_key = keys.get_mut(key_id)
312 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {} not found", key_id)))?;
313
314 stored_key.metadata.status = KeyStatus::Active;
315 Ok(())
316 }
317
318 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
319 let mut dek = vec![0u8; 32];
321 use aes_gcm::aead::OsRng;
322 use aes_gcm::aead::rand_core::RngCore;
323 RngCore::fill_bytes(&mut OsRng, &mut dek);
324
325 let encrypted_dek = self.encrypt(key_id, &dek).await?;
327
328 Ok((dek, encrypted_dek))
329 }
330}
331
332pub struct KmsManager {
334 client: Arc<dyn KmsClient>,
335 config: KmsConfig,
336}
337
338impl KmsManager {
339 pub fn new(config: KmsConfig) -> Result<Self> {
341 let client: Arc<dyn KmsClient> = match config.provider {
342 KmsProvider::Local => {
343 Arc::new(LocalKms::new(config.clone()))
344 }
345 _ => {
346 return Err(AllSourceError::ValidationError(
347 format!("KMS provider {:?} not yet implemented", config.provider)
348 ));
349 }
350 };
351
352 Ok(Self { client, config })
353 }
354
355 pub fn client(&self) -> &Arc<dyn KmsClient> {
357 &self.client
358 }
359
360 pub async fn envelope_encrypt(&self, master_key_id: &str, plaintext: &[u8]) -> Result<EnvelopeEncryptedData> {
362 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
364
365 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
367 use aes_gcm::aead::Aead;
368
369 let cipher = Aes256Gcm::new_from_slice(&dek)
370 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
371
372 use aes_gcm::aead::OsRng;
373 use aes_gcm::aead::rand_core::RngCore;
374 let nonce_bytes = OsRng.next_u64().to_le_bytes();
375 let mut nonce_array = [0u8; 12];
376 nonce_array[..8].copy_from_slice(&nonce_bytes);
377 let nonce = Nonce::from_slice(&nonce_array);
378
379 let ciphertext = cipher.encrypt(nonce, plaintext)
380 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
381
382 Ok(EnvelopeEncryptedData {
383 ciphertext,
384 nonce: nonce.to_vec(),
385 encrypted_dek,
386 master_key_id: master_key_id.to_string(),
387 })
388 }
389
390 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
392 let dek = self.client.decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek).await?;
394
395 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
397 use aes_gcm::aead::Aead;
398
399 let cipher = Aes256Gcm::new_from_slice(&dek)
400 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
401
402 let nonce = Nonce::from_slice(&encrypted.nonce);
403
404 cipher.decrypt(nonce, encrypted.ciphertext.as_ref())
405 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))
406 }
407}
408
409#[derive(Debug, Clone, Serialize, Deserialize)]
411pub struct EnvelopeEncryptedData {
412 pub ciphertext: Vec<u8>,
414
415 pub nonce: Vec<u8>,
417
418 pub encrypted_dek: Vec<u8>,
420
421 pub master_key_id: String,
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Local KMS built from the default configuration.
    fn local_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key with the given alias.
    async fn new_aes256_key<C: KmsClient + ?Sized>(client: &C, alias: &str) -> KeyMetadata {
        client
            .create_key(
                alias.to_string(),
                KeyPurpose::DataEncryption,
                KeyAlgorithm::Aes256Gcm,
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = local_kms();

        let metadata = new_aes256_key(&kms, "test-key").await;

        assert_eq!(metadata.alias, "test-key");
        assert_eq!(metadata.status, KeyStatus::Active);
        assert_eq!(metadata.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        let plaintext = b"sensitive data";
        let ciphertext = kms.encrypt(&key.key_id, plaintext).await.unwrap();
        let decrypted = kms.decrypt(&key.key_id, &ciphertext).await.unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        let rotated = kms.rotate_key(&key.key_id).await.unwrap();

        assert_eq!(rotated.version, 2);
        assert!(rotated.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master_key = new_aes256_key(manager.client().as_ref(), "master-key").await;

        let plaintext = b"sensitive data for envelope encryption";
        let encrypted = manager
            .envelope_encrypt(&master_key.key_id, plaintext)
            .await
            .unwrap();
        let decrypted = manager.envelope_decrypt(&encrypted).await.unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        // A disabled key reports `Deprecated` and refuses to encrypt.
        kms.disable_key(&key.key_id).await.unwrap();
        let metadata = kms.get_key(&key.key_id).await.unwrap();
        assert_eq!(metadata.status, KeyStatus::Deprecated);
        assert!(kms.encrypt(&key.key_id, b"test").await.is_err());

        // Re-enabling restores `Active` and encryption works again.
        kms.enable_key(&key.key_id).await.unwrap();
        let metadata = kms.get_key(&key.key_id).await.unwrap();
        assert_eq!(metadata.status, KeyStatus::Active);
        assert!(kms.encrypt(&key.key_id, b"test").await.is_ok());
    }
}