1use crate::error::{AllSourceError, Result};
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::{collections::HashMap, sync::Arc};
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12pub enum KmsProvider {
13 Local,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct KmsConfig {
19 pub provider: KmsProvider,
21
22 pub config: HashMap<String, String>,
24
25 pub auto_rotate: bool,
27
28 pub rotation_period_days: u32,
30}
31
32impl Default for KmsConfig {
33 fn default() -> Self {
34 Self {
35 provider: KmsProvider::Local,
36 config: HashMap::new(),
37 auto_rotate: true,
38 rotation_period_days: 90,
39 }
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct KeyMetadata {
46 pub key_id: String,
48
49 pub alias: String,
51
52 pub purpose: KeyPurpose,
54
55 pub algorithm: KeyAlgorithm,
57
58 pub created_at: chrono::DateTime<chrono::Utc>,
60
61 pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,
63
64 pub status: KeyStatus,
66
67 pub version: u32,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum KeyPurpose {
73 DataEncryption,
74 JwtSigning,
75 ApiKeySigning,
76 DatabaseEncryption,
77 Custom(String),
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
81pub enum KeyAlgorithm {
82 Aes256Gcm,
83 Aes128Gcm,
84 ChaCha20Poly1305,
85 RsaOaep,
86 EcdsaP256,
87 Ed25519,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
91pub enum KeyStatus {
92 Active,
93 Rotating,
94 Deprecated,
95 Destroyed,
96}
97
98#[async_trait::async_trait]
100pub trait KmsClient: Send + Sync {
101 async fn create_key(
103 &self,
104 alias: String,
105 purpose: KeyPurpose,
106 algorithm: KeyAlgorithm,
107 ) -> Result<KeyMetadata>;
108
109 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;
111
112 async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;
114
115 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;
117
118 async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;
120
121 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;
123
124 async fn disable_key(&self, key_id: &str) -> Result<()>;
126
127 async fn enable_key(&self, key_id: &str) -> Result<()>;
129
130 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
132}
133
134pub struct LocalKms {
136 keys: Arc<DashMap<String, StoredKey>>,
137 config: KmsConfig,
138}
139
140struct StoredKey {
141 metadata: KeyMetadata,
142 key_material: Vec<u8>,
143}
144
145impl LocalKms {
146 pub fn new(config: KmsConfig) -> Self {
147 Self {
148 keys: Arc::new(DashMap::new()),
149 config,
150 }
151 }
152}
153
154#[async_trait::async_trait]
155impl KmsClient for LocalKms {
156 async fn create_key(
157 &self,
158 alias: String,
159 purpose: KeyPurpose,
160 algorithm: KeyAlgorithm,
161 ) -> Result<KeyMetadata> {
162 let key_id = uuid::Uuid::new_v4().to_string();
163
164 let key_material = match algorithm {
166 KeyAlgorithm::Aes256Gcm => {
167 let mut key = vec![0u8; 32];
168 use aes_gcm::aead::{OsRng, rand_core::RngCore};
169 RngCore::fill_bytes(&mut OsRng, &mut key);
170 key
171 }
172 KeyAlgorithm::Aes128Gcm => {
173 let mut key = vec![0u8; 16];
174 use aes_gcm::aead::{OsRng, rand_core::RngCore};
175 RngCore::fill_bytes(&mut OsRng, &mut key);
176 key
177 }
178 _ => {
179 return Err(AllSourceError::ValidationError(format!(
180 "Algorithm {:?} not supported in local KMS",
181 algorithm
182 )));
183 }
184 };
185
186 let metadata = KeyMetadata {
187 key_id: key_id.clone(),
188 alias,
189 purpose,
190 algorithm,
191 created_at: chrono::Utc::now(),
192 last_rotated: None,
193 status: KeyStatus::Active,
194 version: 1,
195 };
196
197 let stored_key = StoredKey {
198 metadata: metadata.clone(),
199 key_material,
200 };
201
202 self.keys.insert(key_id, stored_key);
203
204 Ok(metadata)
205 }
206
207 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
208 self.keys
209 .get(key_id)
210 .map(|entry| entry.value().metadata.clone())
211 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))
212 }
213
214 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
215 Ok(self
216 .keys
217 .iter()
218 .map(|entry| entry.value().metadata.clone())
219 .collect())
220 }
221
222 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
223 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
224
225 let stored_key = self
226 .keys
227 .get(key_id)
228 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
229
230 if stored_key.metadata.status != KeyStatus::Active {
231 return Err(AllSourceError::ValidationError(
232 "Key is not active".to_string(),
233 ));
234 }
235
236 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
237 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
238
239 use aes_gcm::aead::{OsRng, rand_core::RngCore};
241 let nonce_bytes = OsRng.next_u64().to_le_bytes();
242 let mut nonce_array = [0u8; 12];
243 nonce_array[..8].copy_from_slice(&nonce_bytes);
244 let nonce = Nonce::from_slice(&nonce_array);
245
246 let ciphertext = cipher
247 .encrypt(nonce, plaintext)
248 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
249
250 let mut result = nonce.to_vec();
252 result.extend_from_slice(&ciphertext);
253
254 Ok(result)
255 }
256
257 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
258 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
259
260 if ciphertext_with_nonce.len() < 12 {
261 return Err(AllSourceError::ValidationError(
262 "Invalid ciphertext".to_string(),
263 ));
264 }
265
266 let stored_key = self
267 .keys
268 .get(key_id)
269 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
270
271 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
272 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
273
274 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
276 let ciphertext = &ciphertext_with_nonce[12..];
277
278 cipher
279 .decrypt(nonce, ciphertext)
280 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
281 }
282
283 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
284 let mut stored_key = self
285 .keys
286 .get_mut(key_id)
287 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
288
289 let new_key_material = {
291 let mut key = vec![0u8; 32];
292 use aes_gcm::aead::{OsRng, rand_core::RngCore};
293 RngCore::fill_bytes(&mut OsRng, &mut key);
294 key
295 };
296
297 stored_key.key_material = new_key_material;
298 stored_key.metadata.version += 1;
299 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
300
301 Ok(stored_key.metadata.clone())
302 }
303
304 async fn disable_key(&self, key_id: &str) -> Result<()> {
305 let mut stored_key = self
306 .keys
307 .get_mut(key_id)
308 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
309
310 stored_key.metadata.status = KeyStatus::Deprecated;
311 Ok(())
312 }
313
314 async fn enable_key(&self, key_id: &str) -> Result<()> {
315 let mut stored_key = self
316 .keys
317 .get_mut(key_id)
318 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
319
320 stored_key.metadata.status = KeyStatus::Active;
321 Ok(())
322 }
323
324 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
325 let mut dek = vec![0u8; 32];
327 use aes_gcm::aead::{OsRng, rand_core::RngCore};
328 RngCore::fill_bytes(&mut OsRng, &mut dek);
329
330 let encrypted_dek = self.encrypt(key_id, &dek).await?;
332
333 Ok((dek, encrypted_dek))
334 }
335}
336
337pub struct KmsManager {
339 client: Arc<dyn KmsClient>,
340 config: KmsConfig,
341}
342
343impl KmsManager {
344 pub fn new(config: KmsConfig) -> Result<Self> {
346 let client: Arc<dyn KmsClient> = match config.provider {
347 KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
348 };
349
350 Ok(Self { client, config })
351 }
352
353 pub fn client(&self) -> &Arc<dyn KmsClient> {
355 &self.client
356 }
357
358 pub async fn envelope_encrypt(
360 &self,
361 master_key_id: &str,
362 plaintext: &[u8],
363 ) -> Result<EnvelopeEncryptedData> {
364 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
366
367 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
369
370 let cipher = Aes256Gcm::new_from_slice(&dek)
371 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
372
373 use aes_gcm::aead::{OsRng, rand_core::RngCore};
374 let nonce_bytes = OsRng.next_u64().to_le_bytes();
375 let mut nonce_array = [0u8; 12];
376 nonce_array[..8].copy_from_slice(&nonce_bytes);
377 let nonce = Nonce::from_slice(&nonce_array);
378
379 let ciphertext = cipher
380 .encrypt(nonce, plaintext)
381 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
382
383 Ok(EnvelopeEncryptedData {
384 ciphertext,
385 nonce: nonce.to_vec(),
386 encrypted_dek,
387 master_key_id: master_key_id.to_string(),
388 })
389 }
390
391 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
393 let dek = self
395 .client
396 .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
397 .await?;
398
399 use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
401
402 let cipher = Aes256Gcm::new_from_slice(&dek)
403 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
404
405 let nonce = Nonce::from_slice(&encrypted.nonce);
406
407 cipher
408 .decrypt(nonce, encrypted.ciphertext.as_ref())
409 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
410 }
411}
412
413#[derive(Debug, Clone, Serialize, Deserialize)]
415pub struct EnvelopeEncryptedData {
416 pub ciphertext: Vec<u8>,
418
419 pub nonce: Vec<u8>,
421
422 pub encrypted_dek: Vec<u8>,
424
425 pub master_key_id: String,
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// A `LocalKms` built from the default configuration.
    fn default_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key named `alias`, panicking on
    /// failure.
    async fn aes256_key(kms: &dyn KmsClient, alias: &str) -> KeyMetadata {
        kms.create_key(
            alias.to_string(),
            KeyPurpose::DataEncryption,
            KeyAlgorithm::Aes256Gcm,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = default_kms();
        let created = aes256_key(&kms, "test-key").await;

        assert_eq!(created.alias, "test-key");
        assert_eq!(created.status, KeyStatus::Active);
        assert_eq!(created.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = default_kms();
        let key_id = aes256_key(&kms, "test-key").await.key_id;

        let message = b"sensitive data";
        let sealed = kms.encrypt(&key_id, message).await.unwrap();
        let opened = kms.decrypt(&key_id, &sealed).await.unwrap();

        assert_eq!(opened, message);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = default_kms();
        let key_id = aes256_key(&kms, "test-key").await.key_id;

        let after_rotation = kms.rotate_key(&key_id).await.unwrap();

        assert_eq!(after_rotation.version, 2);
        assert!(after_rotation.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master_key_id = aes256_key(manager.client().as_ref(), "master-key")
            .await
            .key_id;

        let message = b"sensitive data for envelope encryption";
        let envelope = manager
            .envelope_encrypt(&master_key_id, message)
            .await
            .unwrap();
        let opened = manager.envelope_decrypt(&envelope).await.unwrap();

        assert_eq!(opened, message);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = default_kms();
        let key_id = aes256_key(&kms, "test-key").await.key_id;

        // Disabled keys report `Deprecated` and refuse to encrypt.
        kms.disable_key(&key_id).await.unwrap();
        assert_eq!(
            kms.get_key(&key_id).await.unwrap().status,
            KeyStatus::Deprecated
        );
        assert!(kms.encrypt(&key_id, b"test").await.is_err());

        // Re-enabled keys are `Active` and encrypt again.
        kms.enable_key(&key_id).await.unwrap();
        assert_eq!(kms.get_key(&key_id).await.unwrap().status, KeyStatus::Active);
        assert!(kms.encrypt(&key_id, b"test").await.is_ok());
    }
}