1use std::collections::HashMap;
36use std::path::PathBuf;
37
38use aes_gcm::aead::{Aead, KeyInit, Payload};
39use aes_gcm::{Aes256Gcm, Key, Nonce};
40use async_trait::async_trait;
41use rand::RngCore;
42
/// Length in bytes of a key-encryption key (KEK); used as an AES-256 key.
const KEK_LEN: usize = 32;
/// Length in bytes of a data-encryption key (DEK) generated by `LocalKms`.
const DEK_LEN: usize = 32;
/// Length in bytes of the random AES-GCM nonce prepended to a locally wrapped DEK.
const WRAP_NONCE_LEN: usize = 12;
/// Length in bytes of the AES-GCM authentication tag at the end of a locally wrapped DEK.
const WRAP_TAG_LEN: usize = 16;
/// Shortest blob `LocalKms` will attempt to unwrap: a nonce plus a tag around an empty body.
const LOCAL_WRAP_MIN_LEN: usize = WRAP_TAG_LEN + WRAP_NONCE_LEN;
53
54#[derive(Debug, thiserror::Error)]
55pub enum KmsError {
56 #[error("KMS key id {key_id:?} not found in backend")]
57 KeyNotFound { key_id: String },
58 #[error("KMS KEK file {path:?}: {source}")]
59 KekFileIo {
60 path: PathBuf,
61 source: std::io::Error,
62 },
63 #[error("KMS KEK file {path:?} must be exactly {expected} raw bytes; got {got}")]
64 KekBadLength {
65 path: PathBuf,
66 expected: usize,
67 got: usize,
68 },
69 #[error("KMS KEK directory {path:?}: {source}")]
70 KekDirIo {
71 path: PathBuf,
72 source: std::io::Error,
73 },
74 #[error("KMS wrapped DEK too short ({got} bytes; need at least {min})")]
79 WrappedDekTooShort { got: usize, min: usize },
80 #[error("KMS unwrap failed (wrapped DEK auth tag mismatch for key_id {key_id:?})")]
84 UnwrapFailed { key_id: String },
85 #[error("KMS backend unavailable: {message}")]
88 BackendUnavailable { message: String },
89}
90
/// A data-encryption key (DEK) in wrapped form, as returned by
/// [`KmsBackend::generate_dek`] and accepted by [`KmsBackend::decrypt_dek`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WrappedDek {
    /// Id of the key-encryption key the DEK was wrapped under.
    ///
    /// `LocalKms` also authenticates this id as AES-GCM associated data. The AWS backend
    /// stores the key id reported in the `GenerateDataKey` response when one is present.
    pub key_id: String,
    /// Backend-specific wrapped bytes. Callers should treat them as opaque.
    pub ciphertext: Vec<u8>,
}
108
109#[async_trait]
110pub trait KmsBackend: Send + Sync + std::fmt::Debug {
111 async fn generate_dek(&self, key_id: &str) -> Result<(Vec<u8>, WrappedDek), KmsError>;
121
122 async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Vec<u8>, KmsError>;
127}
128
129pub struct LocalKms {
154 dir: PathBuf,
155 keks: HashMap<String, [u8; KEK_LEN]>,
156}
157
158impl std::fmt::Debug for LocalKms {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("LocalKms")
161 .field("dir", &self.dir)
162 .field("key_count", &self.keks.len())
163 .field("key_ids", &self.keks.keys().collect::<Vec<_>>())
164 .finish()
165 }
166}
167
168impl LocalKms {
169 pub fn open(dir: PathBuf) -> Result<Self, KmsError> {
177 let read_dir = std::fs::read_dir(&dir).map_err(|source| KmsError::KekDirIo {
178 path: dir.clone(),
179 source,
180 })?;
181 let mut keks = HashMap::new();
182 for entry in read_dir {
183 let entry = entry.map_err(|source| KmsError::KekDirIo {
184 path: dir.clone(),
185 source,
186 })?;
187 let path = entry.path();
188 if path.extension().and_then(|s| s.to_str()) != Some("kek") {
189 continue;
190 }
191 let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
192 continue;
193 };
194 let key_id = stem.to_string();
195 let bytes = std::fs::read(&path).map_err(|source| KmsError::KekFileIo {
196 path: path.clone(),
197 source,
198 })?;
199 if bytes.len() != KEK_LEN {
200 return Err(KmsError::KekBadLength {
201 path: path.clone(),
202 expected: KEK_LEN,
203 got: bytes.len(),
204 });
205 }
206 let mut k = [0u8; KEK_LEN];
207 k.copy_from_slice(&bytes);
208 keks.insert(key_id, k);
209 }
210 Ok(Self { dir, keks })
211 }
212
213 pub fn from_keks(dir: PathBuf, keks: HashMap<String, [u8; KEK_LEN]>) -> Self {
218 Self { dir, keks }
219 }
220
221 pub fn key_ids(&self) -> Vec<String> {
225 let mut ids: Vec<String> = self.keks.keys().cloned().collect();
226 ids.sort();
227 ids
228 }
229
230 fn kek(&self, key_id: &str) -> Result<&[u8; KEK_LEN], KmsError> {
231 self.keks.get(key_id).ok_or_else(|| KmsError::KeyNotFound {
232 key_id: key_id.to_string(),
233 })
234 }
235}
236
237#[async_trait]
238impl KmsBackend for LocalKms {
239 async fn generate_dek(&self, key_id: &str) -> Result<(Vec<u8>, WrappedDek), KmsError> {
240 let kek = self.kek(key_id)?;
241 let mut dek = vec![0u8; DEK_LEN];
242 rand::rngs::OsRng.fill_bytes(&mut dek);
243
244 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(kek));
245 let mut nonce_bytes = [0u8; WRAP_NONCE_LEN];
246 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
247 let nonce = Nonce::from_slice(&nonce_bytes);
248 let aad = key_id.as_bytes();
249 let ct_with_tag = cipher
250 .encrypt(
251 nonce,
252 Payload {
253 msg: &dek,
254 aad,
255 },
256 )
257 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
258
259 let mut wrapped = Vec::with_capacity(WRAP_NONCE_LEN + ct_with_tag.len());
262 wrapped.extend_from_slice(&nonce_bytes);
263 wrapped.extend_from_slice(&ct_with_tag);
264
265 Ok((
266 dek,
267 WrappedDek {
268 key_id: key_id.to_string(),
269 ciphertext: wrapped,
270 },
271 ))
272 }
273
274 async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Vec<u8>, KmsError> {
275 let kek = self.kek(&wrapped.key_id)?;
276 if wrapped.ciphertext.len() < LOCAL_WRAP_MIN_LEN {
277 return Err(KmsError::WrappedDekTooShort {
278 got: wrapped.ciphertext.len(),
279 min: LOCAL_WRAP_MIN_LEN,
280 });
281 }
282 let (nonce_bytes, ct_with_tag) = wrapped.ciphertext.split_at(WRAP_NONCE_LEN);
283 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(kek));
284 let nonce = Nonce::from_slice(nonce_bytes);
285 let aad = wrapped.key_id.as_bytes();
286 let dek = cipher
287 .decrypt(
288 nonce,
289 Payload {
290 msg: ct_with_tag,
291 aad,
292 },
293 )
294 .map_err(|_| KmsError::UnwrapFailed {
295 key_id: wrapped.key_id.clone(),
296 })?;
297 Ok(dek)
298 }
299}
300
#[cfg(feature = "aws-kms")]
pub mod aws {
    //! [`KmsBackend`] implementation backed by AWS KMS (compiled with the `aws-kms` feature).

    use super::{KmsBackend, KmsError, WrappedDek};
    use async_trait::async_trait;
    use aws_sdk_kms::error::DisplayErrorContext;

    /// A [`KmsBackend`] that delegates DEK generation and unwrapping to AWS KMS.
    pub struct AwsKms {
        client: aws_sdk_kms::Client,
    }

    impl std::fmt::Debug for AwsKms {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The SDK client is intentionally not printed; `..` marks the hidden field.
            f.debug_struct("AwsKms").finish_non_exhaustive()
        }
    }

    impl AwsKms {
        /// Wraps an already-configured AWS KMS client.
        pub fn new(client: aws_sdk_kms::Client) -> Self {
            Self { client }
        }

        /// Builds a client from the default AWS configuration chain.
        pub async fn from_default_env() -> Self {
            let cfg = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
            let client = aws_sdk_kms::Client::new(&cfg);
            Self { client }
        }
    }

    #[async_trait]
    impl KmsBackend for AwsKms {
        /// Calls `GenerateDataKey` with an AES-256 key spec under `key_id`.
        ///
        /// The returned [`WrappedDek::key_id`] is the key id reported in the response. It
        /// falls back to the requested `key_id` when the response omits it.
        ///
        /// # Errors
        ///
        /// [`KmsError::BackendUnavailable`] if the call fails or the response lacks the
        /// plaintext or ciphertext blob.
        async fn generate_dek(&self, key_id: &str) -> Result<(Vec<u8>, WrappedDek), KmsError> {
            let resp = self
                .client
                .generate_data_key()
                .key_id(key_id)
                .key_spec(aws_sdk_kms::types::DataKeySpec::Aes256)
                .send()
                .await
                .map_err(|e| KmsError::BackendUnavailable {
                    // `SdkError`'s own `Display` is only a terse category such as
                    // "service error". `DisplayErrorContext` renders the full source chain,
                    // so the underlying AWS error is not lost.
                    message: format!("GenerateDataKey({key_id}): {}", DisplayErrorContext(e)),
                })?;
            let dek = resp
                .plaintext
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!("GenerateDataKey({key_id}): missing Plaintext in response"),
                })?
                .into_inner();
            let ciphertext = resp
                .ciphertext_blob
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!("GenerateDataKey({key_id}): missing CiphertextBlob in response"),
                })?
                .into_inner();
            let stored_id = resp.key_id.unwrap_or_else(|| key_id.to_string());
            Ok((
                dek,
                WrappedDek {
                    key_id: stored_id,
                    ciphertext,
                },
            ))
        }

        /// Calls `Decrypt` on the wrapped blob, pinning the request to `wrapped.key_id`.
        ///
        /// # Errors
        ///
        /// [`KmsError::BackendUnavailable`] if the call fails or the response lacks the
        /// plaintext.
        async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Vec<u8>, KmsError> {
            let resp = self
                .client
                .decrypt()
                .ciphertext_blob(aws_sdk_kms::primitives::Blob::new(
                    wrapped.ciphertext.clone(),
                ))
                .key_id(&wrapped.key_id)
                .send()
                .await
                .map_err(|e| KmsError::BackendUnavailable {
                    // See `generate_dek`: keep the full SDK error chain in the message.
                    message: format!("Decrypt({}): {}", wrapped.key_id, DisplayErrorContext(e)),
                })?;
            let dek = resp
                .plaintext
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!("Decrypt({}): missing Plaintext in response", wrapped.key_id),
                })?
                .into_inner();
            Ok(dek)
        }
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::Path;
    use tempfile::TempDir;

    /// Writes `bytes` as the KEK file `<name>.kek` inside `dir`.
    fn write_kek(dir: &Path, name: &str, bytes: &[u8]) {
        std::fs::write(dir.join(format!("{name}.kek")), bytes).unwrap();
    }

    #[tokio::test]
    async fn open_empty_dir_is_ok() {
        let tmp = TempDir::new().unwrap();
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        assert!(kms.key_ids().is_empty());
        let err = kms.generate_dek("missing").await.unwrap_err();
        assert!(
            matches!(err, KmsError::KeyNotFound { ref key_id } if key_id == "missing"),
            "got {err:?}"
        );
    }

    #[test]
    fn open_missing_dir_errors() {
        let tmp = TempDir::new().unwrap();
        let missing = tmp.path().join("does-not-exist");
        let err = LocalKms::open(missing.clone()).unwrap_err();
        assert!(
            matches!(err, KmsError::KekDirIo { ref path, .. } if *path == missing),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn open_loads_kek_files_and_skips_others() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "alpha", &[1u8; KEK_LEN]);
        write_kek(tmp.path(), "beta", &[2u8; KEK_LEN]);
        std::fs::write(tmp.path().join("README"), b"hello").unwrap();
        // Extension is `bak`, not `kek`, so this wrong-length file must be ignored.
        std::fs::write(tmp.path().join("alpha.kek.bak"), [9u8; 99]).unwrap();
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let ids = kms.key_ids();
        assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[tokio::test]
    async fn open_rejects_truncated_kek_file() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "short", &[7u8; KEK_LEN - 1]);
        let err = LocalKms::open(tmp.path().to_path_buf()).unwrap_err();
        assert!(
            matches!(
                err,
                KmsError::KekBadLength { expected, got, .. } if expected == KEK_LEN && got == KEK_LEN - 1
            ),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn generate_then_decrypt_roundtrip() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "main", &[42u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (dek, wrapped) = kms.generate_dek("main").await.unwrap();
        assert_eq!(dek.len(), DEK_LEN);
        assert_eq!(wrapped.key_id, "main");
        assert_eq!(wrapped.ciphertext.len(), WRAP_NONCE_LEN + DEK_LEN + WRAP_TAG_LEN);

        let unwrapped = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(unwrapped, dek);
    }

    #[tokio::test]
    async fn generate_uses_random_dek_and_nonce() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[5u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (dek1, w1) = kms.generate_dek("k").await.unwrap();
        let (dek2, w2) = kms.generate_dek("k").await.unwrap();
        assert_ne!(dek1, dek2, "DEK must be random per call");
        assert_ne!(w1.ciphertext, w2.ciphertext, "wrap nonce must be random per call");
    }

    #[tokio::test]
    async fn decrypt_unknown_key_id_errors() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "real", &[1u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let bogus = WrappedDek {
            key_id: "phantom".to_string(),
            ciphertext: vec![0u8; LOCAL_WRAP_MIN_LEN + DEK_LEN],
        };
        let err = kms.decrypt_dek(&bogus).await.unwrap_err();
        assert!(
            matches!(err, KmsError::KeyNotFound { ref key_id } if key_id == "phantom"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_tampered_ciphertext_fails_unwrap() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[3u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (_dek, mut wrapped) = kms.generate_dek("k").await.unwrap();
        // Flip a byte in the middle of the blob, past the nonce prefix.
        let mid = wrapped.ciphertext.len() / 2;
        wrapped.ciphertext[mid] ^= 0xFF;
        let err = kms.decrypt_dek(&wrapped).await.unwrap_err();
        assert!(
            matches!(err, KmsError::UnwrapFailed { ref key_id } if key_id == "k"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_short_ciphertext_errors() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[8u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let bogus = WrappedDek {
            key_id: "k".to_string(),
            ciphertext: vec![0u8; 5],
        };
        let err = kms.decrypt_dek(&bogus).await.unwrap_err();
        assert!(
            matches!(err, KmsError::WrappedDekTooShort { got: 5, .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_wrong_key_id_aad_fails_unwrap() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "alpha", &[1u8; KEK_LEN]);
        write_kek(tmp.path(), "beta", &[2u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (_dek, wrapped) = kms.generate_dek("alpha").await.unwrap();
        let forged = WrappedDek {
            key_id: "beta".to_string(),
            ciphertext: wrapped.ciphertext.clone(),
        };
        let err = kms.decrypt_dek(&forged).await.unwrap_err();
        assert!(
            matches!(err, KmsError::UnwrapFailed { ref key_id } if key_id == "beta"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn from_keks_constructor_works() {
        let mut keks = HashMap::new();
        keks.insert("inline".to_string(), [9u8; KEK_LEN]);
        let kms = LocalKms::from_keks(PathBuf::from("/tmp/none"), keks);
        let (dek, wrapped) = kms.generate_dek("inline").await.unwrap();
        assert_eq!(wrapped.key_id, "inline");
        let back = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(back, dek, "unwrapped DEK must match the generated one");
    }

    #[test]
    fn debug_omits_kek_bytes() {
        let mut keks = HashMap::new();
        // 0xAB renders as `171` in array Debug output, which must not appear below.
        keks.insert("secret".to_string(), [0xABu8; KEK_LEN]);
        let kms = LocalKms::from_keks(PathBuf::from("/tmp/none"), keks);
        let rendered = format!("{kms:?}");
        assert!(rendered.contains("secret"), "got {rendered}");
        assert!(rendered.contains("key_count: 1"), "got {rendered}");
        assert!(!rendered.contains("171"), "KEK bytes leaked: {rendered}");
    }

    #[cfg(feature = "aws-kms")]
    #[tokio::test]
    #[ignore = "requires AWS credentials and a real KMS key (set S4_KMS_TEST_KEY_ID)"]
    async fn aws_kms_roundtrip() {
        let key_id = std::env::var("S4_KMS_TEST_KEY_ID")
            .expect("set S4_KMS_TEST_KEY_ID to a real AWS KMS key ARN/alias");
        let kms = super::aws::AwsKms::from_default_env().await;
        let (dek, wrapped) = kms.generate_dek(&key_id).await.unwrap();
        assert_eq!(dek.len(), DEK_LEN);
        let unwrapped = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(unwrapped, dek);
    }
}