1use crate::error::{AllSourceError, Result};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub enum KmsProvider {
18 AwsKms,
19 GoogleCloudKms,
20 AzureKeyVault,
21 HashicorpVault,
22 Pkcs11,
23 Local, }
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct KmsConfig {
29 pub provider: KmsProvider,
31
32 pub config: HashMap<String, String>,
34
35 pub auto_rotate: bool,
37
38 pub rotation_period_days: u32,
40}
41
42impl Default for KmsConfig {
43 fn default() -> Self {
44 Self {
45 provider: KmsProvider::Local,
46 config: HashMap::new(),
47 auto_rotate: true,
48 rotation_period_days: 90,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct KeyMetadata {
56 pub key_id: String,
58
59 pub alias: String,
61
62 pub purpose: KeyPurpose,
64
65 pub algorithm: KeyAlgorithm,
67
68 pub created_at: chrono::DateTime<chrono::Utc>,
70
71 pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,
73
74 pub status: KeyStatus,
76
77 pub version: u32,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub enum KeyPurpose {
83 DataEncryption,
84 JwtSigning,
85 ApiKeySigning,
86 DatabaseEncryption,
87 Custom(String),
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
91pub enum KeyAlgorithm {
92 Aes256Gcm,
93 Aes128Gcm,
94 ChaCha20Poly1305,
95 RsaOaep,
96 EcdsaP256,
97 Ed25519,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
101pub enum KeyStatus {
102 Active,
103 Rotating,
104 Deprecated,
105 Destroyed,
106}
107
108#[async_trait::async_trait]
110pub trait KmsClient: Send + Sync {
111 async fn create_key(
113 &self,
114 alias: String,
115 purpose: KeyPurpose,
116 algorithm: KeyAlgorithm,
117 ) -> Result<KeyMetadata>;
118
119 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;
121
122 async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;
124
125 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;
127
128 async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;
130
131 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;
133
134 async fn disable_key(&self, key_id: &str) -> Result<()>;
136
137 async fn enable_key(&self, key_id: &str) -> Result<()>;
139
140 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
142}
143
144pub struct LocalKms {
146 keys: Arc<DashMap<String, StoredKey>>,
147 config: KmsConfig,
148}
149
150struct StoredKey {
151 metadata: KeyMetadata,
152 key_material: Vec<u8>,
153}
154
155impl LocalKms {
156 pub fn new(config: KmsConfig) -> Self {
157 Self {
158 keys: Arc::new(DashMap::new()),
159 config,
160 }
161 }
162}
163
164#[async_trait::async_trait]
165impl KmsClient for LocalKms {
166 async fn create_key(
167 &self,
168 alias: String,
169 purpose: KeyPurpose,
170 algorithm: KeyAlgorithm,
171 ) -> Result<KeyMetadata> {
172 let key_id = uuid::Uuid::new_v4().to_string();
173
174 let key_material = match algorithm {
176 KeyAlgorithm::Aes256Gcm => {
177 let mut key = vec![0u8; 32];
178 use aes_gcm::aead::rand_core::RngCore;
179 use aes_gcm::aead::OsRng;
180 RngCore::fill_bytes(&mut OsRng, &mut key);
181 key
182 }
183 KeyAlgorithm::Aes128Gcm => {
184 let mut key = vec![0u8; 16];
185 use aes_gcm::aead::rand_core::RngCore;
186 use aes_gcm::aead::OsRng;
187 RngCore::fill_bytes(&mut OsRng, &mut key);
188 key
189 }
190 _ => {
191 return Err(AllSourceError::ValidationError(format!(
192 "Algorithm {:?} not supported in local KMS",
193 algorithm
194 )));
195 }
196 };
197
198 let metadata = KeyMetadata {
199 key_id: key_id.clone(),
200 alias,
201 purpose,
202 algorithm,
203 created_at: chrono::Utc::now(),
204 last_rotated: None,
205 status: KeyStatus::Active,
206 version: 1,
207 };
208
209 let stored_key = StoredKey {
210 metadata: metadata.clone(),
211 key_material,
212 };
213
214 self.keys.insert(key_id, stored_key);
215
216 Ok(metadata)
217 }
218
219 async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
220 self.keys
221 .get(key_id)
222 .map(|entry| entry.value().metadata.clone())
223 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))
224 }
225
226 async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
227 Ok(self
228 .keys
229 .iter()
230 .map(|entry| entry.value().metadata.clone())
231 .collect())
232 }
233
234 async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
235 use aes_gcm::aead::Aead;
236 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
237
238 let stored_key = self.keys
239 .get(key_id)
240 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
241
242 if stored_key.metadata.status != KeyStatus::Active {
243 return Err(AllSourceError::ValidationError(
244 "Key is not active".to_string(),
245 ));
246 }
247
248 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
249 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
250
251 use aes_gcm::aead::rand_core::RngCore;
253 use aes_gcm::aead::OsRng;
254 let nonce_bytes = OsRng.next_u64().to_le_bytes();
255 let mut nonce_array = [0u8; 12];
256 nonce_array[..8].copy_from_slice(&nonce_bytes);
257 let nonce = Nonce::from_slice(&nonce_array);
258
259 let ciphertext = cipher
260 .encrypt(nonce, plaintext)
261 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
262
263 let mut result = nonce.to_vec();
265 result.extend_from_slice(&ciphertext);
266
267 Ok(result)
268 }
269
270 async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
271 use aes_gcm::aead::Aead;
272 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
273
274 if ciphertext_with_nonce.len() < 12 {
275 return Err(AllSourceError::ValidationError(
276 "Invalid ciphertext".to_string(),
277 ));
278 }
279
280 let stored_key = self.keys
281 .get(key_id)
282 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
283
284 let cipher = Aes256Gcm::new_from_slice(&stored_key.key_material)
285 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
286
287 let nonce = Nonce::from_slice(&ciphertext_with_nonce[..12]);
289 let ciphertext = &ciphertext_with_nonce[12..];
290
291 cipher
292 .decrypt(nonce, ciphertext)
293 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
294 }
295
296 async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
297 let mut stored_key = self.keys
298 .get_mut(key_id)
299 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
300
301 let new_key_material = {
303 let mut key = vec![0u8; 32];
304 use aes_gcm::aead::rand_core::RngCore;
305 use aes_gcm::aead::OsRng;
306 RngCore::fill_bytes(&mut OsRng, &mut key);
307 key
308 };
309
310 stored_key.key_material = new_key_material;
311 stored_key.metadata.version += 1;
312 stored_key.metadata.last_rotated = Some(chrono::Utc::now());
313
314 Ok(stored_key.metadata.clone())
315 }
316
317 async fn disable_key(&self, key_id: &str) -> Result<()> {
318 let mut stored_key = self.keys
319 .get_mut(key_id)
320 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
321
322 stored_key.metadata.status = KeyStatus::Deprecated;
323 Ok(())
324 }
325
326 async fn enable_key(&self, key_id: &str) -> Result<()> {
327 let mut stored_key = self.keys
328 .get_mut(key_id)
329 .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
330
331 stored_key.metadata.status = KeyStatus::Active;
332 Ok(())
333 }
334
335 async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
336 let mut dek = vec![0u8; 32];
338 use aes_gcm::aead::rand_core::RngCore;
339 use aes_gcm::aead::OsRng;
340 RngCore::fill_bytes(&mut OsRng, &mut dek);
341
342 let encrypted_dek = self.encrypt(key_id, &dek).await?;
344
345 Ok((dek, encrypted_dek))
346 }
347}
348
349pub struct KmsManager {
351 client: Arc<dyn KmsClient>,
352 config: KmsConfig,
353}
354
355impl KmsManager {
356 pub fn new(config: KmsConfig) -> Result<Self> {
358 let client: Arc<dyn KmsClient> = match config.provider {
359 KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
360 _ => {
361 return Err(AllSourceError::ValidationError(format!(
362 "KMS provider {:?} not yet implemented",
363 config.provider
364 )));
365 }
366 };
367
368 Ok(Self { client, config })
369 }
370
371 pub fn client(&self) -> &Arc<dyn KmsClient> {
373 &self.client
374 }
375
376 pub async fn envelope_encrypt(
378 &self,
379 master_key_id: &str,
380 plaintext: &[u8],
381 ) -> Result<EnvelopeEncryptedData> {
382 let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
384
385 use aes_gcm::aead::Aead;
387 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
388
389 let cipher = Aes256Gcm::new_from_slice(&dek)
390 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
391
392 use aes_gcm::aead::rand_core::RngCore;
393 use aes_gcm::aead::OsRng;
394 let nonce_bytes = OsRng.next_u64().to_le_bytes();
395 let mut nonce_array = [0u8; 12];
396 nonce_array[..8].copy_from_slice(&nonce_bytes);
397 let nonce = Nonce::from_slice(&nonce_array);
398
399 let ciphertext = cipher
400 .encrypt(nonce, plaintext)
401 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
402
403 Ok(EnvelopeEncryptedData {
404 ciphertext,
405 nonce: nonce.to_vec(),
406 encrypted_dek,
407 master_key_id: master_key_id.to_string(),
408 })
409 }
410
411 pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
413 let dek = self
415 .client
416 .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
417 .await?;
418
419 use aes_gcm::aead::Aead;
421 use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
422
423 let cipher = Aes256Gcm::new_from_slice(&dek)
424 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
425
426 let nonce = Nonce::from_slice(&encrypted.nonce);
427
428 cipher
429 .decrypt(nonce, encrypted.ciphertext.as_ref())
430 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
431 }
432}
433
434#[derive(Debug, Clone, Serialize, Deserialize)]
436pub struct EnvelopeEncryptedData {
437 pub ciphertext: Vec<u8>,
439
440 pub nonce: Vec<u8>,
442
443 pub encrypted_dek: Vec<u8>,
445
446 pub master_key_id: String,
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an AES-256-GCM data-encryption key named `alias` on any backend.
    async fn aes256_key(kms: &dyn KmsClient, alias: &str) -> KeyMetadata {
        kms.create_key(
            alias.to_string(),
            KeyPurpose::DataEncryption,
            KeyAlgorithm::Aes256Gcm,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = LocalKms::new(KmsConfig::default());

        let created = aes256_key(&kms, "test-key").await;

        assert_eq!(created.alias, "test-key");
        assert_eq!(created.status, KeyStatus::Active);
        assert_eq!(created.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = LocalKms::new(KmsConfig::default());
        let key = aes256_key(&kms, "test-key").await;

        let message = b"sensitive data";
        let sealed = kms.encrypt(&key.key_id, message).await.unwrap();
        let opened = kms.decrypt(&key.key_id, &sealed).await.unwrap();

        assert_eq!(opened, message);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = LocalKms::new(KmsConfig::default());
        let key = aes256_key(&kms, "test-key").await;

        let after = kms.rotate_key(&key.key_id).await.unwrap();

        assert_eq!(after.version, 2);
        assert!(after.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master = aes256_key(manager.client().as_ref(), "master-key").await;

        let message = b"sensitive data for envelope encryption";
        let envelope = manager
            .envelope_encrypt(&master.key_id, message)
            .await
            .unwrap();
        let opened = manager.envelope_decrypt(&envelope).await.unwrap();

        assert_eq!(opened, message);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = LocalKms::new(KmsConfig::default());
        let key = aes256_key(&kms, "test-key").await;
        let id = key.key_id.as_str();

        // Disabled: status flips to Deprecated and encryption is refused.
        kms.disable_key(id).await.unwrap();
        assert_eq!(kms.get_key(id).await.unwrap().status, KeyStatus::Deprecated);
        assert!(kms.encrypt(id, b"test").await.is_err());

        // Re-enabled: status flips back to Active and encryption works again.
        kms.enable_key(id).await.unwrap();
        assert_eq!(kms.get_key(id).await.unwrap().status, KeyStatus::Active);
        assert!(kms.encrypt(id, b"test").await.is_ok());
    }
}