1use crate::error::{AllSourceError, Result};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use std::{collections::HashMap, sync::Arc};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub enum KmsProvider {
17 AwsKms,
18 GoogleCloudKms,
19 AzureKeyVault,
20 HashicorpVault,
21 Pkcs11,
22 Local, }
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct KmsConfig {
28 pub provider: KmsProvider,
30
31 pub config: HashMap<String, String>,
33
34 pub auto_rotate: bool,
36
37 pub rotation_period_days: u32,
39}
40
41impl Default for KmsConfig {
42 fn default() -> Self {
43 Self {
44 provider: KmsProvider::Local,
45 config: HashMap::new(),
46 auto_rotate: true,
47 rotation_period_days: 90,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct KeyMetadata {
55 pub key_id: String,
57
58 pub alias: String,
60
61 pub purpose: KeyPurpose,
63
64 pub algorithm: KeyAlgorithm,
66
67 pub created_at: chrono::DateTime<chrono::Utc>,
69
70 pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,
72
73 pub status: KeyStatus,
75
76 pub version: u32,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
81pub enum KeyPurpose {
82 DataEncryption,
83 JwtSigning,
84 ApiKeySigning,
85 DatabaseEncryption,
86 Custom(String),
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90pub enum KeyAlgorithm {
91 Aes256Gcm,
92 Aes128Gcm,
93 ChaCha20Poly1305,
94 RsaOaep,
95 EcdsaP256,
96 Ed25519,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
100pub enum KeyStatus {
101 Active,
102 Rotating,
103 Deprecated,
104 Destroyed,
105}
106
107#[async_trait::async_trait]
109pub trait KmsClient: Send + Sync {
110 async fn create_key(
112 &self,
113 alias: String,
114 purpose: KeyPurpose,
115 algorithm: KeyAlgorithm,
116 ) -> Result<KeyMetadata>;
117
118 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;
120
121 async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;
123
124 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;
126
127 async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;
129
130 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;
132
133 async fn disable_key(&self, key_id: &str) -> Result<()>;
135
136 async fn enable_key(&self, key_id: &str) -> Result<()>;
138
139 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
141}
142
143pub struct LocalKms {
145 keys: Arc<DashMap<String, StoredKey>>,
146 config: KmsConfig,
147}
148
149struct StoredKey {
150 metadata: KeyMetadata,
151 key_material: Vec<u8>,
152}
153
154impl LocalKms {
155 pub fn new(config: KmsConfig) -> Self {
156 Self {
157 keys: Arc::new(DashMap::new()),
158 config,
159 }
160 }
161}
162
163#[async_trait::async_trait]
164impl KmsClient for LocalKms {
165 async fn create_key(
166 &self,
167 alias: String,
168 purpose: KeyPurpose,
169 algorithm: KeyAlgorithm,
170 ) -> Result<KeyMetadata> {
171 let key_id = uuid::Uuid::new_v4().to_string();
172
173 let key_material = match algorithm {
175 KeyAlgorithm::Aes256Gcm => {
176 let mut key = vec![0u8; 32];
177 use aes_gcm::aead::{OsRng, rand_core::RngCore};
178 RngCore::fill_bytes(&mut OsRng, &mut key);
179 key
180 }
181 KeyAlgorithm::Aes128Gcm => {
182 let mut key = vec![0u8; 16];
183 use aes_gcm::aead::{OsRng, rand_core::RngCore};
184 RngCore::fill_bytes(&mut OsRng, &mut key);
185 key
186 }
187 _ => {
188 return Err(AllSourceError::ValidationError(format!(
189 "Algorithm {:?} not supported in local KMS",
190 algorithm
191 )));
192 }
193 };
194
195 let metadata = KeyMetadata {
196 key_id: key_id.clone(),
197 alias,
198 purpose,
199 algorithm,
200 created_at: chrono::Utc::now(),
201 last_rotated: None,
202 status: KeyStatus::Active,
203 version: 1,
204 };
205
206 let stored_key = StoredKey {
207 metadata: metadata.clone(),
208 key_material,
209 };
210
211 self.keys.insert(key_id, stored_key);
212
213 Ok(metadata)
214 }
215
216 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
217 self.keys
218 .get(key_id)
219 .map(|entry| entry.value().metadata.clone())
220 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))
221 }
222
223 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
224 Ok(self
225 .keys
226 .iter()
227 .map(|entry| entry.value().metadata.clone())
228 .collect())
229 }
230
231 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
232 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
233
234 let stored_key = self
235 .keys
236 .get(key_id)
237 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
238
239 if stored_key.metadata.status != KeyStatus::Active {
240 return Err(AllSourceError::ValidationError(
241 "Key is not active".to_string(),
242 ));
243 }
244
245 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
246 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
247
248 use aes_gcm::aead::{OsRng, rand_core::RngCore};
250 let nonce_bytes = OsRng.next_u64().to_le_bytes();
251 let mut nonce_array = [0u8; 12];
252 nonce_array[..8].copy_from_slice(&nonce_bytes);
253 let nonce = Nonce::from_slice(&nonce_array);
254
255 let ciphertext = cipher
256 .encrypt(nonce, plaintext)
257 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
258
259 let mut result = nonce.to_vec();
261 result.extend_from_slice(&ciphertext);
262
263 Ok(result)
264 }
265
266 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
267 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
268
269 if ciphertext_with_nonce.len() < 12 {
270 return Err(AllSourceError::ValidationError(
271 "Invalid ciphertext".to_string(),
272 ));
273 }
274
275 let stored_key = self
276 .keys
277 .get(key_id)
278 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
279
280 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
281 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
282
283 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
285 let ciphertext = &ciphertext_with_nonce[12..];
286
287 cipher
288 .decrypt(nonce, ciphertext)
289 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
290 }
291
292 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
293 let mut stored_key = self
294 .keys
295 .get_mut(key_id)
296 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
297
298 let new_key_material = {
300 let mut key = vec![0u8; 32];
301 use aes_gcm::aead::{OsRng, rand_core::RngCore};
302 RngCore::fill_bytes(&mut OsRng, &mut key);
303 key
304 };
305
306 stored_key.key_material = new_key_material;
307 stored_key.metadata.version += 1;
308 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
309
310 Ok(stored_key.metadata.clone())
311 }
312
313 async fn disable_key(&self, key_id: &str) -> Result<()> {
314 let mut stored_key = self
315 .keys
316 .get_mut(key_id)
317 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
318
319 stored_key.metadata.status = KeyStatus::Deprecated;
320 Ok(())
321 }
322
323 async fn enable_key(&self, key_id: &str) -> Result<()> {
324 let mut stored_key = self
325 .keys
326 .get_mut(key_id)
327 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
328
329 stored_key.metadata.status = KeyStatus::Active;
330 Ok(())
331 }
332
333 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
334 let mut dek = vec![0u8; 32];
336 use aes_gcm::aead::{OsRng, rand_core::RngCore};
337 RngCore::fill_bytes(&mut OsRng, &mut dek);
338
339 let encrypted_dek = self.encrypt(key_id, &dek).await?;
341
342 Ok((dek, encrypted_dek))
343 }
344}
345
346pub struct KmsManager {
348 client: Arc<dyn KmsClient>,
349 config: KmsConfig,
350}
351
352impl KmsManager {
353 pub fn new(config: KmsConfig) -> Result<Self> {
355 let client: Arc<dyn KmsClient> = match config.provider {
356 KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
357 _ => {
358 return Err(AllSourceError::ValidationError(format!(
359 "KMS provider {:?} not yet implemented",
360 config.provider
361 )));
362 }
363 };
364
365 Ok(Self { client, config })
366 }
367
368 pub fn client(&self) -> &Arc<dyn KmsClient> {
370 &self.client
371 }
372
373 pub async fn envelope_encrypt(
375 &self,
376 master_key_id: &str,
377 plaintext: &[u8],
378 ) -> Result<EnvelopeEncryptedData> {
379 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
381
382 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
384
385 let cipher = Aes256Gcm::new_from_slice(&dek)
386 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
387
388 use aes_gcm::aead::{OsRng, rand_core::RngCore};
389 let nonce_bytes = OsRng.next_u64().to_le_bytes();
390 let mut nonce_array = [0u8; 12];
391 nonce_array[..8].copy_from_slice(&nonce_bytes);
392 let nonce = Nonce::from_slice(&nonce_array);
393
394 let ciphertext = cipher
395 .encrypt(nonce, plaintext)
396 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
397
398 Ok(EnvelopeEncryptedData {
399 ciphertext,
400 nonce: nonce.to_vec(),
401 encrypted_dek,
402 master_key_id: master_key_id.to_string(),
403 })
404 }
405
406 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
408 let dek = self
410 .client
411 .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
412 .await?;
413
414 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
416
417 let cipher = Aes256Gcm::new_from_slice(&dek)
418 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
419
420 let nonce = Nonce::from_slice(&encrypted.nonce);
421
422 cipher
423 .decrypt(nonce, encrypted.ciphertext.as_ref())
424 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
425 }
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct EnvelopeEncryptedData {
431 pub ciphertext: Vec<u8>,
433
434 pub nonce: Vec<u8>,
436
437 pub encrypted_dek: Vec<u8>,
439
440 pub master_key_id: String,
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Local KMS built from the default configuration.
    fn default_local_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key with the given alias.
    async fn new_aes256_key(client: &dyn KmsClient, alias: &str) -> KeyMetadata {
        client
            .create_key(
                alias.to_string(),
                KeyPurpose::DataEncryption,
                KeyAlgorithm::Aes256Gcm,
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = default_local_kms();

        let metadata = new_aes256_key(&kms, "test-key").await;

        assert_eq!(metadata.alias, "test-key");
        assert_eq!(metadata.status, KeyStatus::Active);
        assert_eq!(metadata.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = default_local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        let plaintext = b"sensitive data";
        let ciphertext = kms.encrypt(&key.key_id, plaintext).await.unwrap();
        let decrypted = kms.decrypt(&key.key_id, &ciphertext).await.unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = default_local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        let rotated = kms.rotate_key(&key.key_id).await.unwrap();

        assert_eq!(rotated.version, 2);
        assert!(rotated.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master_key = new_aes256_key(&**manager.client(), "master-key").await;

        let plaintext = b"sensitive data for envelope encryption";
        let encrypted = manager
            .envelope_encrypt(&master_key.key_id, plaintext)
            .await
            .unwrap();
        let decrypted = manager.envelope_decrypt(&encrypted).await.unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = default_local_kms();
        let key = new_aes256_key(&kms, "test-key").await;

        // A disabled key reports `Deprecated` and refuses to encrypt.
        kms.disable_key(&key.key_id).await.unwrap();
        let disabled = kms.get_key(&key.key_id).await.unwrap();
        assert_eq!(disabled.status, KeyStatus::Deprecated);
        assert!(kms.encrypt(&key.key_id, b"test").await.is_err());

        // Re-enabling restores `Active` and encryption works again.
        kms.enable_key(&key.key_id).await.unwrap();
        let enabled = kms.get_key(&key.key_id).await.unwrap();
        assert_eq!(enabled.status, KeyStatus::Active);
        assert!(kms.encrypt(&key.key_id, b"test").await.is_ok());
    }
}