1use crate::error::{AllSourceError, Result};
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::{collections::HashMap, sync::Arc};
9
/// Backend that [`KmsManager::new`] instantiates as its [`KmsClient`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KmsProvider {
    /// In-process [`LocalKms`]; key material is kept in memory.
    Local,
}
15
/// Settings passed to [`KmsManager::new`] (and on to [`LocalKms::new`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsConfig {
    /// Which KMS backend to construct.
    pub provider: KmsProvider,

    /// Free-form provider settings as string key/value pairs.
    /// NOTE(review): not read by `LocalKms` in this file; presumably for other providers.
    pub config: HashMap<String, String>,

    /// Presumably enables automatic key rotation.
    /// NOTE(review): not read by any code in this file — confirm where rotation is scheduled.
    pub auto_rotate: bool,

    /// Presumably the interval between automatic rotations, in days.
    /// NOTE(review): not read by any code in this file.
    pub rotation_period_days: u32,
}
31
32impl Default for KmsConfig {
33 fn default() -> Self {
34 Self {
35 provider: KmsProvider::Local,
36 config: HashMap::new(),
37 auto_rotate: true,
38 rotation_period_days: 90,
39 }
40 }
41}
42
/// Descriptive information about a managed key (it carries no key material).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMetadata {
    /// Identifier used to address the key in [`KmsClient`] calls
    /// (`LocalKms` assigns a random UUID v4 string).
    pub key_id: String,

    /// Human-readable name supplied to `create_key`.
    pub alias: String,

    /// Intended use of the key, as supplied to `create_key`.
    pub purpose: KeyPurpose,

    /// Algorithm the key material was generated for.
    pub algorithm: KeyAlgorithm,

    /// Creation time in UTC.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Time of the most recent rotation, or `None` if the key was never rotated.
    pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,

    /// Current lifecycle state; see [`KeyStatus`].
    pub status: KeyStatus,

    /// Key version; `LocalKms` starts at 1 and increments it on every `rotate_key`.
    pub version: u32,
}
70
/// Declared use of a key. `LocalKms` stores it in the metadata but does not act on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyPurpose {
    DataEncryption,
    JwtSigning,
    ApiKeySigning,
    DatabaseEncryption,
    /// Caller-defined purpose described by the contained string.
    Custom(String),
}
79
/// Cryptographic algorithm a key is generated for.
///
/// `LocalKms::create_key` accepts only `Aes256Gcm` and `Aes128Gcm`; the other
/// variants are rejected there with a `ValidationError`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyAlgorithm {
    Aes256Gcm,
    Aes128Gcm,
    ChaCha20Poly1305,
    RsaOaep,
    EcdsaP256,
    Ed25519,
}
89
/// Lifecycle state of a key.
///
/// In `LocalKms`, `encrypt` refuses any key that is not `Active`, `disable_key`
/// sets `Deprecated`, and `enable_key` sets `Active`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyStatus {
    Active,
    /// NOTE(review): never assigned by `LocalKms` in this file.
    Rotating,
    Deprecated,
    /// NOTE(review): never assigned by `LocalKms` in this file.
    Destroyed,
}
97
/// Operations a key-management backend provides.
///
/// The `Send + Sync` bound lets a client be shared as `Arc<dyn KmsClient>`,
/// which is how [`KmsManager`] holds it.
#[async_trait::async_trait]
pub trait KmsClient: Send + Sync {
    /// Creates a new key with the given alias, purpose and algorithm and returns its metadata.
    async fn create_key(
        &self,
        alias: String,
        purpose: KeyPurpose,
        algorithm: KeyAlgorithm,
    ) -> Result<KeyMetadata>;

    /// Returns the metadata of `key_id`.
    async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;

    /// Returns the metadata of every key this client manages.
    async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;

    /// Encrypts `plaintext` under `key_id`; the returned blob's layout is
    /// implementation-defined and is meant to be passed back to [`KmsClient::decrypt`].
    async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;

    /// Decrypts a blob previously produced by [`KmsClient::encrypt`] with `key_id`.
    async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;

    /// Replaces the key material of `key_id` and returns the updated metadata.
    async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;

    /// Disables `key_id` (in `LocalKms`: status becomes `Deprecated`, which blocks `encrypt`).
    async fn disable_key(&self, key_id: &str) -> Result<()>;

    /// Re-enables `key_id` (in `LocalKms`: status becomes `Active`).
    async fn enable_key(&self, key_id: &str) -> Result<()>;

    /// Generates a data key and returns `(plaintext_key, key_encrypted_under_key_id)`;
    /// [`KmsManager::envelope_encrypt`] relies on this ordering.
    async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
}
133
134pub struct LocalKms {
136 keys: Arc<DashMap<String, StoredKey>>,
137 config: KmsConfig,
138}
139
140struct StoredKey {
141 metadata: KeyMetadata,
142 key_material: Vec<u8>,
143}
144
145impl LocalKms {
146 pub fn new(config: KmsConfig) -> Self {
147 Self {
148 keys: Arc::new(DashMap::new()),
149 config,
150 }
151 }
152}
153
154#[async_trait::async_trait]
155impl KmsClient for LocalKms {
156 async fn create_key(
157 &self,
158 alias: String,
159 purpose: KeyPurpose,
160 algorithm: KeyAlgorithm,
161 ) -> Result<KeyMetadata> {
162 let key_id = uuid::Uuid::new_v4().to_string();
163
164 let key_material = match algorithm {
166 KeyAlgorithm::Aes256Gcm => {
167 let mut key = vec![0u8; 32];
168 use aes_gcm::aead::{OsRng, rand_core::RngCore};
169 RngCore::fill_bytes(&mut OsRng, &mut key);
170 key
171 }
172 KeyAlgorithm::Aes128Gcm => {
173 let mut key = vec![0u8; 16];
174 use aes_gcm::aead::{OsRng, rand_core::RngCore};
175 RngCore::fill_bytes(&mut OsRng, &mut key);
176 key
177 }
178 _ => {
179 return Err(AllSourceError::ValidationError(format!(
180 "Algorithm {algorithm:?} not supported in local KMS"
181 )));
182 }
183 };
184
185 let metadata = KeyMetadata {
186 key_id: key_id.clone(),
187 alias,
188 purpose,
189 algorithm,
190 created_at: chrono::Utc::now(),
191 last_rotated: None,
192 status: KeyStatus::Active,
193 version: 1,
194 };
195
196 let stored_key = StoredKey {
197 metadata: metadata.clone(),
198 key_material,
199 };
200
201 self.keys.insert(key_id, stored_key);
202
203 Ok(metadata)
204 }
205
206 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
207 self.keys
208 .get(key_id)
209 .map(|entry| entry.value().metadata.clone())
210 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))
211 }
212
213 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
214 Ok(self
215 .keys
216 .iter()
217 .map(|entry| entry.value().metadata.clone())
218 .collect())
219 }
220
221 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
222 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
223
224 let stored_key = self
225 .keys
226 .get(key_id)
227 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
228
229 if stored_key.metadata.status != KeyStatus::Active {
230 return Err(AllSourceError::ValidationError(
231 "Key is not active".to_string(),
232 ));
233 }
234
235 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
236 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
237
238 use aes_gcm::aead::{OsRng, rand_core::RngCore};
240 let nonce_bytes = OsRng.next_u64().to_le_bytes();
241 let mut nonce_array = [0u8; 12];
242 nonce_array[..8].copy_from_slice(&nonce_bytes);
243 let nonce = Nonce::from_slice(&nonce_array);
244
245 let ciphertext = cipher
246 .encrypt(nonce, plaintext)
247 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
248
249 let mut result = nonce.to_vec();
251 result.extend_from_slice(&ciphertext);
252
253 Ok(result)
254 }
255
256 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
257 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
258
259 if ciphertext_with_nonce.len() < 12 {
260 return Err(AllSourceError::ValidationError(
261 "Invalid ciphertext".to_string(),
262 ));
263 }
264
265 let stored_key = self
266 .keys
267 .get(key_id)
268 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
269
270 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
271 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
272
273 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
275 let ciphertext = &ciphertext_with_nonce[12..];
276
277 cipher
278 .decrypt(nonce, ciphertext)
279 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
280 }
281
282 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
283 let mut stored_key = self
284 .keys
285 .get_mut(key_id)
286 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
287
288 let new_key_material = {
290 let mut key = vec![0u8; 32];
291 use aes_gcm::aead::{OsRng, rand_core::RngCore};
292 RngCore::fill_bytes(&mut OsRng, &mut key);
293 key
294 };
295
296 stored_key.key_material = new_key_material;
297 stored_key.metadata.version += 1;
298 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
299
300 Ok(stored_key.metadata.clone())
301 }
302
303 async fn disable_key(&self, key_id: &str) -> Result<()> {
304 let mut stored_key = self
305 .keys
306 .get_mut(key_id)
307 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
308
309 stored_key.metadata.status = KeyStatus::Deprecated;
310 Ok(())
311 }
312
313 async fn enable_key(&self, key_id: &str) -> Result<()> {
314 let mut stored_key = self
315 .keys
316 .get_mut(key_id)
317 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
318
319 stored_key.metadata.status = KeyStatus::Active;
320 Ok(())
321 }
322
323 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
324 let mut dek = vec![0u8; 32];
326 use aes_gcm::aead::{OsRng, rand_core::RngCore};
327 RngCore::fill_bytes(&mut OsRng, &mut dek);
328
329 let encrypted_dek = self.encrypt(key_id, &dek).await?;
331
332 Ok((dek, encrypted_dek))
333 }
334}
335
/// Owns a [`KmsClient`] and provides envelope encryption on top of it.
pub struct KmsManager {
    /// Client selected from `config.provider` in [`KmsManager::new`].
    client: Arc<dyn KmsClient>,
    /// Configuration the manager was built with.
    config: KmsConfig,
}
341
342impl KmsManager {
343 pub fn new(config: KmsConfig) -> Result<Self> {
345 let client: Arc<dyn KmsClient> = match config.provider {
346 KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
347 };
348
349 Ok(Self { client, config })
350 }
351
352 pub fn client(&self) -> &Arc<dyn KmsClient> {
354 &self.client
355 }
356
357 pub async fn envelope_encrypt(
359 &self,
360 master_key_id: &str,
361 plaintext: &[u8],
362 ) -> Result<EnvelopeEncryptedData> {
363 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
365
366 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
368
369 let cipher = Aes256Gcm::new_from_slice(&dek)
370 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
371
372 use aes_gcm::aead::{OsRng, rand_core::RngCore};
373 let nonce_bytes = OsRng.next_u64().to_le_bytes();
374 let mut nonce_array = [0u8; 12];
375 nonce_array[..8].copy_from_slice(&nonce_bytes);
376 let nonce = Nonce::from_slice(&nonce_array);
377
378 let ciphertext = cipher
379 .encrypt(nonce, plaintext)
380 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
381
382 Ok(EnvelopeEncryptedData {
383 ciphertext,
384 nonce: nonce.to_vec(),
385 encrypted_dek,
386 master_key_id: master_key_id.to_string(),
387 })
388 }
389
390 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
392 let dek = self
394 .client
395 .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
396 .await?;
397
398 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
400
401 let cipher = Aes256Gcm::new_from_slice(&dek)
402 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
403
404 let nonce = Nonce::from_slice(&encrypted.nonce);
405
406 cipher
407 .decrypt(nonce, encrypted.ciphertext.as_ref())
408 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
409 }
410}
411
/// Output of [`KmsManager::envelope_encrypt`]. It holds a payload encrypted under a
/// per-call data key (DEK), plus that DEK encrypted under a master key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvelopeEncryptedData {
    /// AES-256-GCM ciphertext of the payload under the DEK (tag appended by `aes_gcm`).
    pub ciphertext: Vec<u8>,

    /// Nonce used with the DEK; `envelope_encrypt` writes 12 bytes.
    pub nonce: Vec<u8>,

    /// The DEK as returned by the client's `generate_data_key`, i.e. encrypted under
    /// `master_key_id` in the client's own ciphertext format.
    pub encrypted_dek: Vec<u8>,

    /// Key used to encrypt the DEK; `envelope_decrypt` passes it to the client's `decrypt`.
    pub master_key_id: String,
}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LocalKms` with the default configuration.
    fn local_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key named `alias`, panicking on failure.
    async fn new_aes256_key(kms: &dyn KmsClient, alias: &str) -> KeyMetadata {
        kms.create_key(
            alias.to_string(),
            KeyPurpose::DataEncryption,
            KeyAlgorithm::Aes256Gcm,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = local_kms();
        let metadata = new_aes256_key(&kms, "test-key").await;

        assert_eq!(metadata.alias, "test-key");
        assert_eq!(metadata.status, KeyStatus::Active);
        assert_eq!(metadata.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = local_kms();
        let key_id = new_aes256_key(&kms, "test-key").await.key_id;

        let plaintext = b"sensitive data";
        let sealed = kms.encrypt(&key_id, plaintext).await.unwrap();
        let opened = kms.decrypt(&key_id, &sealed).await.unwrap();

        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = local_kms();
        let key_id = new_aes256_key(&kms, "test-key").await.key_id;

        let rotated = kms.rotate_key(&key_id).await.unwrap();

        assert_eq!(rotated.version, 2);
        assert!(rotated.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master_key = new_aes256_key(manager.client().as_ref(), "master-key").await;

        let plaintext = b"sensitive data for envelope encryption";
        let envelope = manager
            .envelope_encrypt(&master_key.key_id, plaintext)
            .await
            .unwrap();
        let opened = manager.envelope_decrypt(&envelope).await.unwrap();

        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = local_kms();
        let key_id = new_aes256_key(&kms, "test-key").await.key_id;

        // Disabled keys report `Deprecated` and refuse to encrypt.
        kms.disable_key(&key_id).await.unwrap();
        assert_eq!(
            kms.get_key(&key_id).await.unwrap().status,
            KeyStatus::Deprecated
        );
        assert!(kms.encrypt(&key_id, b"test").await.is_err());

        // Re-enabled keys are `Active` and encrypt again.
        kms.enable_key(&key_id).await.unwrap();
        assert_eq!(
            kms.get_key(&key_id).await.unwrap().status,
            KeyStatus::Active
        );
        assert!(kms.encrypt(&key_id, b"test").await.is_ok());
    }
}