1use std::collections::HashMap;
122use std::path::Path;
123use std::sync::Arc;
124
125use aes_gcm::aead::{Aead, KeyInit, Payload};
126use aes_gcm::{Aes256Gcm, Key, Nonce};
127use bytes::Bytes;
128use md5::{Digest as Md5Digest, Md5};
129use rand::RngCore;
130use thiserror::Error;
131
132use crate::kms::{KmsBackend, KmsError, WrappedDek};
133
/// Frame magic for S4E1: encrypted under a single `SseKey`; the header carries no key id.
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1";
/// Frame magic for S4E2: keyring encryption with a 16-bit key id in the header.
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2";
/// Frame magic for S4E3: SSE-C (customer-supplied key) with the key's MD5 in the header.
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3";
/// Frame magic for S4E4: SSE-KMS, carrying the KMS key id and the wrapped data key.
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4";
/// Alias for the original frame magic; identical to [`SSE_MAGIC_V1`].
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;

/// AES-GCM nonce length in bytes.
const NONCE_LEN: usize = 12;
/// AES-GCM authentication tag length in bytes.
const TAG_LEN: usize = 16;
/// AES-256 key length in bytes.
const KEY_LEN: usize = 32;
/// MD5 digest length in bytes (SSE-C key fingerprint).
const KEY_MD5_LEN: usize = 16;

/// Fixed header size of S4E1/S4E2 frames:
/// `magic(4) | algo(1) | 3 bytes (S4E1: zero; S4E2: key_id(2, BE) + zero) | nonce | tag`.
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + NONCE_LEN + TAG_LEN;
/// Fixed header size of S4E3 frames: `magic(4) | algo(1) | key_md5(16) | nonce | tag`.
pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + NONCE_LEN + TAG_LEN;
/// Algorithm tag stored in byte 4 of every frame; the only algorithm this build knows.
pub const ALGO_AES_256_GCM: u8 = 1;
/// The only value accepted for the SSE-C algorithm request header.
pub const SSE_C_ALGORITHM: &str = "AES256";
160
161#[derive(Debug, Error)]
162pub enum SseError {
163 #[error("SSE key file {path:?}: {source}")]
164 KeyFileIo {
165 path: std::path::PathBuf,
166 source: std::io::Error,
167 },
168 #[error(
169 "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
170 )]
171 BadKeyLength { got: usize },
172 #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
173 TooShort { got: usize },
174 #[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4, got {got:?}")]
175 BadMagic { got: [u8; 4] },
176 #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
177 UnsupportedAlgo { tag: u8 },
178 #[error(
179 "SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
180 )]
181 KeyNotInKeyring { id: u16 },
182 #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
183 DecryptFailed,
184 #[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
192 WrongCustomerKey,
193 #[error("SSE-C customer-key headers invalid: {reason}")]
198 InvalidCustomerKey { reason: &'static str },
199 #[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
203 CustomerKeyAlgorithmUnsupported { algo: String },
204 #[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
209 CustomerKeyRequired,
210 #[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
215 CustomerKeyUnexpected,
216 #[error(
223 "S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
224 )]
225 KmsAsyncRequired,
226 #[error("S4E4 frame too short ({got} bytes; need at least {min})")]
230 KmsFrameTooShort { got: usize, min: usize },
231 #[error("S4E4 frame field length out of bounds: {what}")]
236 KmsFrameFieldOob { what: &'static str },
237 #[error("S4E4 key_id is not valid UTF-8")]
242 KmsKeyIdNotUtf8,
243 #[error(
250 "S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
251 )]
252 KmsWrappedDekMismatch {
253 supplied: String,
254 stored: String,
255 },
256 #[error("S4E4 frame requires SseSource::Kms")]
263 KmsRequired,
264 #[error("KMS unwrap: {0}")]
267 KmsBackend(#[from] KmsError),
268}
269
270pub struct SseKey {
275 pub bytes: [u8; 32],
276}
277
278impl SseKey {
279 pub fn from_path(path: &Path) -> Result<Self, SseError> {
283 let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
284 path: path.to_path_buf(),
285 source,
286 })?;
287 Self::from_bytes(&raw)
288 }
289
290 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
291 if bytes.len() == KEY_LEN {
293 let mut k = [0u8; KEY_LEN];
294 k.copy_from_slice(bytes);
295 return Ok(Self { bytes: k });
296 }
297 let s = std::str::from_utf8(bytes).unwrap_or("").trim();
299 if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
300 let mut k = [0u8; KEY_LEN];
301 for (i, k_byte) in k.iter_mut().enumerate() {
302 *k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
303 .map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
304 }
305 return Ok(Self { bytes: k });
306 }
307 if let Ok(decoded) =
308 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
309 && decoded.len() == KEY_LEN
310 {
311 let mut k = [0u8; KEY_LEN];
312 k.copy_from_slice(&decoded);
313 return Ok(Self { bytes: k });
314 }
315 Err(SseError::BadKeyLength { got: bytes.len() })
316 }
317
318 fn as_aes_key(&self) -> &Key<Aes256Gcm> {
319 Key::<Aes256Gcm>::from_slice(&self.bytes)
320 }
321}
322
323impl std::fmt::Debug for SseKey {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("SseKey")
326 .field("len", &KEY_LEN)
327 .field("key", &"<redacted>")
328 .finish()
329 }
330}
331
332#[derive(Clone)]
337pub struct SseKeyring {
338 active: u16,
339 keys: HashMap<u16, Arc<SseKey>>,
340}
341
342impl SseKeyring {
343 pub fn new(active: u16, key: Arc<SseKey>) -> Self {
347 let mut keys = HashMap::new();
348 keys.insert(active, key);
349 Self { active, keys }
350 }
351
352 pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
356 self.keys.insert(id, key);
357 }
358
359 pub fn active(&self) -> (u16, &SseKey) {
362 let id = self.active;
363 let key = self
364 .keys
365 .get(&id)
366 .expect("active key id must be present in keyring (constructor invariant)");
367 (id, key.as_ref())
368 }
369
370 pub fn get(&self, id: u16) -> Option<&SseKey> {
373 self.keys.get(&id).map(Arc::as_ref)
374 }
375}
376
377impl std::fmt::Debug for SseKeyring {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("SseKeyring")
380 .field("active", &self.active)
381 .field("key_count", &self.keys.len())
382 .field("key_ids", &self.keys.keys().collect::<Vec<_>>())
383 .finish()
384 }
385}
386
387pub type SharedSseKeyring = Arc<SseKeyring>;
388
389pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
396 let cipher = Aes256Gcm::new(key.as_aes_key());
397 let mut nonce_bytes = [0u8; NONCE_LEN];
398 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
399 let nonce = Nonce::from_slice(&nonce_bytes);
400 let mut aad = [0u8; 8];
402 aad[..4].copy_from_slice(SSE_MAGIC_V1);
403 aad[4] = ALGO_AES_256_GCM;
404 let ct_with_tag = cipher
405 .encrypt(
406 nonce,
407 Payload {
408 msg: plaintext,
409 aad: &aad,
410 },
411 )
412 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
413 debug_assert!(ct_with_tag.len() >= TAG_LEN);
414 let split = ct_with_tag.len() - TAG_LEN;
415 let (ct, tag) = ct_with_tag.split_at(split);
416
417 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
418 out.extend_from_slice(SSE_MAGIC_V1);
419 out.push(ALGO_AES_256_GCM);
420 out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
422 out.extend_from_slice(tag);
423 out.extend_from_slice(ct);
424 Bytes::from(out)
425}
426
427pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
432 let (key_id, key) = keyring.active();
433 let cipher = Aes256Gcm::new(key.as_aes_key());
434 let mut nonce_bytes = [0u8; NONCE_LEN];
435 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
436 let nonce = Nonce::from_slice(&nonce_bytes);
437 let aad = aad_v2(key_id);
438 let ct_with_tag = cipher
439 .encrypt(
440 nonce,
441 Payload {
442 msg: plaintext,
443 aad: &aad,
444 },
445 )
446 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
447 debug_assert!(ct_with_tag.len() >= TAG_LEN);
448 let split = ct_with_tag.len() - TAG_LEN;
449 let (ct, tag) = ct_with_tag.split_at(split);
450
451 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
452 out.extend_from_slice(SSE_MAGIC_V2);
453 out.push(ALGO_AES_256_GCM);
454 out.extend_from_slice(&key_id.to_be_bytes()); out.push(0u8); out.extend_from_slice(&nonce_bytes);
457 out.extend_from_slice(tag);
458 out.extend_from_slice(ct);
459 Bytes::from(out)
460}
461
462fn aad_v1() -> [u8; 8] {
463 let mut aad = [0u8; 8];
464 aad[..4].copy_from_slice(SSE_MAGIC_V1);
465 aad[4] = ALGO_AES_256_GCM;
466 aad
467}
468
469fn aad_v2(key_id: u16) -> [u8; 8] {
470 let mut aad = [0u8; 8];
471 aad[..4].copy_from_slice(SSE_MAGIC_V2);
472 aad[4] = ALGO_AES_256_GCM;
473 aad[5..7].copy_from_slice(&key_id.to_be_bytes());
474 aad[7] = 0u8;
475 aad
476}
477
478fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
484 let mut aad = [0u8; 4 + 1 + KEY_MD5_LEN];
485 aad[..4].copy_from_slice(SSE_MAGIC_V3);
486 aad[4] = ALGO_AES_256_GCM;
487 aad[5..5 + KEY_MD5_LEN].copy_from_slice(key_md5);
488 aad
489}
490
491#[derive(Clone)]
497pub struct CustomerKeyMaterial {
498 pub key: [u8; KEY_LEN],
499 pub key_md5: [u8; KEY_MD5_LEN],
500}
501
502impl std::fmt::Debug for CustomerKeyMaterial {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 f.debug_struct("CustomerKeyMaterial")
507 .field("key", &"<redacted>")
508 .field("key_md5_hex", &hex_lower(&self.key_md5))
509 .finish()
510 }
511}
512
/// Formats `bytes` as lowercase hexadecimal, two digits per byte
/// (e.g. `[0x0a, 0xff]` -> `"0aff"`).
///
/// Writes straight into one pre-sized buffer instead of allocating a temporary
/// `String` per byte with `format!`.
fn hex_lower(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    let mut s = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        write!(s, "{b:02x}").expect("writing to a String never fails");
    }
    s
}
520
521#[derive(Debug, Clone, Copy)]
529pub enum SseSource<'a> {
530 Keyring(&'a SseKeyring),
533 CustomerKey {
537 key: &'a [u8; KEY_LEN],
538 key_md5: &'a [u8; KEY_MD5_LEN],
539 },
540 Kms {
546 dek: &'a [u8; KEY_LEN],
548 wrapped: &'a WrappedDek,
551 },
552}
553
554impl<'a> From<&'a SseKeyring> for SseSource<'a> {
561 fn from(kr: &'a SseKeyring) -> Self {
562 SseSource::Keyring(kr)
563 }
564}
565
566impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
570 fn from(kr: &'a Arc<SseKeyring>) -> Self {
571 SseSource::Keyring(kr.as_ref())
572 }
573}
574
575impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
576 fn from(m: &'a CustomerKeyMaterial) -> Self {
577 SseSource::CustomerKey {
578 key: &m.key,
579 key_md5: &m.key_md5,
580 }
581 }
582}
583
584pub fn parse_customer_key_headers(
596 algorithm: &str,
597 key_base64: &str,
598 key_md5_base64: &str,
599) -> Result<CustomerKeyMaterial, SseError> {
600 use base64::Engine as _;
601 if algorithm != SSE_C_ALGORITHM {
602 return Err(SseError::CustomerKeyAlgorithmUnsupported {
603 algo: algorithm.to_string(),
604 });
605 }
606 let key_bytes = base64::engine::general_purpose::STANDARD
607 .decode(key_base64.trim().as_bytes())
608 .map_err(|_| SseError::InvalidCustomerKey {
609 reason: "base64 decode of key",
610 })?;
611 if key_bytes.len() != KEY_LEN {
612 return Err(SseError::InvalidCustomerKey {
613 reason: "key length (must be 32 bytes after base64 decode)",
614 });
615 }
616 let supplied_md5 = base64::engine::general_purpose::STANDARD
617 .decode(key_md5_base64.trim().as_bytes())
618 .map_err(|_| SseError::InvalidCustomerKey {
619 reason: "base64 decode of key MD5",
620 })?;
621 if supplied_md5.len() != KEY_MD5_LEN {
622 return Err(SseError::InvalidCustomerKey {
623 reason: "key MD5 length (must be 16 bytes after base64 decode)",
624 });
625 }
626 let actual_md5 = compute_key_md5(&key_bytes);
627 if !constant_time_eq(&actual_md5, &supplied_md5) {
630 return Err(SseError::InvalidCustomerKey {
631 reason: "supplied MD5 does not match MD5 of supplied key",
632 });
633 }
634 let mut key = [0u8; KEY_LEN];
635 key.copy_from_slice(&key_bytes);
636 let mut key_md5 = [0u8; KEY_MD5_LEN];
637 key_md5.copy_from_slice(&actual_md5);
638 Ok(CustomerKeyMaterial { key, key_md5 })
639}
640
641pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
646 let mut h = Md5::new();
647 h.update(key);
648 let out = h.finalize();
649 let mut md5 = [0u8; KEY_MD5_LEN];
650 md5.copy_from_slice(&out);
651 md5
652}
653
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately; only the contents,
/// not the length, are compared in a data-independent way.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
666
667pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
677 match source {
678 SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
679 SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
680 SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
681 }
682}
683
684fn encrypt_v3(
685 plaintext: &[u8],
686 key: &[u8; KEY_LEN],
687 key_md5: &[u8; KEY_MD5_LEN],
688) -> Bytes {
689 let aes_key = Key::<Aes256Gcm>::from_slice(key);
690 let cipher = Aes256Gcm::new(aes_key);
691 let mut nonce_bytes = [0u8; NONCE_LEN];
692 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
693 let nonce = Nonce::from_slice(&nonce_bytes);
694 let aad = aad_v3(key_md5);
695 let ct_with_tag = cipher
696 .encrypt(
697 nonce,
698 Payload {
699 msg: plaintext,
700 aad: &aad,
701 },
702 )
703 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
704 debug_assert!(ct_with_tag.len() >= TAG_LEN);
705 let split = ct_with_tag.len() - TAG_LEN;
706 let (ct, tag) = ct_with_tag.split_at(split);
707
708 let mut out = Vec::with_capacity(SSE_HEADER_BYTES_V3 + ct.len());
709 out.extend_from_slice(SSE_MAGIC_V3);
710 out.push(ALGO_AES_256_GCM);
711 out.extend_from_slice(key_md5);
712 out.extend_from_slice(&nonce_bytes);
713 out.extend_from_slice(tag);
714 out.extend_from_slice(ct);
715 Bytes::from(out)
716}
717
718pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
737 let source = source.into();
738 if body.len() < SSE_HEADER_BYTES {
744 return Err(SseError::TooShort { got: body.len() });
745 }
746 let mut magic = [0u8; 4];
747 magic.copy_from_slice(&body[..4]);
748 match &magic {
749 m if m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 => {
750 let keyring = match source {
751 SseSource::Keyring(kr) => kr,
752 SseSource::CustomerKey { .. } => return Err(SseError::CustomerKeyUnexpected),
753 SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
759 };
760 if m == SSE_MAGIC_V1 {
761 decrypt_v1_with_keyring(body, keyring)
762 } else {
763 decrypt_v2_with_keyring(body, keyring)
764 }
765 }
766 m if m == SSE_MAGIC_V3 => {
767 if body.len() < SSE_HEADER_BYTES_V3 {
769 return Err(SseError::TooShort { got: body.len() });
770 }
771 let (key, key_md5) = match source {
772 SseSource::CustomerKey { key, key_md5 } => (key, key_md5),
773 SseSource::Keyring(_) => return Err(SseError::CustomerKeyRequired),
774 SseSource::Kms { .. } => return Err(SseError::CustomerKeyRequired),
775 };
776 decrypt_v3(body, key, key_md5)
777 }
778 m if m == SSE_MAGIC_V4 => {
779 Err(SseError::KmsAsyncRequired)
784 }
785 _ => Err(SseError::BadMagic { got: magic }),
786 }
787}
788
789fn decrypt_v3(
790 body: &[u8],
791 key: &[u8; KEY_LEN],
792 supplied_md5: &[u8; KEY_MD5_LEN],
793) -> Result<Bytes, SseError> {
794 let algo = body[4];
795 if algo != ALGO_AES_256_GCM {
796 return Err(SseError::UnsupportedAlgo { tag: algo });
797 }
798 let mut stored_md5 = [0u8; KEY_MD5_LEN];
799 stored_md5.copy_from_slice(&body[5..5 + KEY_MD5_LEN]);
800 if !constant_time_eq(supplied_md5, &stored_md5) {
806 return Err(SseError::WrongCustomerKey);
807 }
808 let nonce_off = 5 + KEY_MD5_LEN;
809 let tag_off = nonce_off + NONCE_LEN;
810 let mut nonce_bytes = [0u8; NONCE_LEN];
811 nonce_bytes.copy_from_slice(&body[nonce_off..nonce_off + NONCE_LEN]);
812 let mut tag_bytes = [0u8; TAG_LEN];
813 tag_bytes.copy_from_slice(&body[tag_off..tag_off + TAG_LEN]);
814 let ct = &body[SSE_HEADER_BYTES_V3..];
815
816 let aad = aad_v3(&stored_md5);
817 let nonce = Nonce::from_slice(&nonce_bytes);
818 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
819 ct_with_tag.extend_from_slice(ct);
820 ct_with_tag.extend_from_slice(&tag_bytes);
821
822 let aes_key = Key::<Aes256Gcm>::from_slice(key);
823 let cipher = Aes256Gcm::new(aes_key);
824 let plain = cipher
825 .decrypt(
826 nonce,
827 Payload {
828 msg: &ct_with_tag,
829 aad: &aad,
830 },
831 )
832 .map_err(|_| SseError::DecryptFailed)?;
833 Ok(Bytes::from(plain))
834}
835
836fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
847 let mut aad = Vec::with_capacity(4 + 1 + 1 + key_id.len() + 4 + wrapped_dek.len());
848 aad.extend_from_slice(SSE_MAGIC_V4);
849 aad.push(ALGO_AES_256_GCM);
850 aad.push(key_id.len() as u8);
851 aad.extend_from_slice(key_id);
852 aad.extend_from_slice(&(wrapped_dek.len() as u32).to_be_bytes());
853 aad.extend_from_slice(wrapped_dek);
854 aad
855}
856
857fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
858 assert!(
866 !wrapped.key_id.is_empty() && wrapped.key_id.len() <= u8::MAX as usize,
867 "S4E4 key_id must be 1..=255 bytes (got {})",
868 wrapped.key_id.len()
869 );
870 assert!(
871 wrapped.ciphertext.len() <= u32::MAX as usize,
872 "S4E4 wrapped_dek longer than u32::MAX",
873 );
874
875 let aes_key = Key::<Aes256Gcm>::from_slice(dek);
876 let cipher = Aes256Gcm::new(aes_key);
877 let mut nonce_bytes = [0u8; NONCE_LEN];
878 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
879 let nonce = Nonce::from_slice(&nonce_bytes);
880 let aad = aad_v4(wrapped.key_id.as_bytes(), &wrapped.ciphertext);
881 let ct_with_tag = cipher
882 .encrypt(
883 nonce,
884 Payload {
885 msg: plaintext,
886 aad: &aad,
887 },
888 )
889 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
890 debug_assert!(ct_with_tag.len() >= TAG_LEN);
891 let split = ct_with_tag.len() - TAG_LEN;
892 let (ct, tag) = ct_with_tag.split_at(split);
893
894 let key_id_bytes = wrapped.key_id.as_bytes();
895 let mut out = Vec::with_capacity(
896 4 + 1 + 1 + key_id_bytes.len() + 4 + wrapped.ciphertext.len() + NONCE_LEN + TAG_LEN + ct.len(),
897 );
898 out.extend_from_slice(SSE_MAGIC_V4);
899 out.push(ALGO_AES_256_GCM);
900 out.push(key_id_bytes.len() as u8);
901 out.extend_from_slice(key_id_bytes);
902 out.extend_from_slice(&(wrapped.ciphertext.len() as u32).to_be_bytes());
903 out.extend_from_slice(&wrapped.ciphertext);
904 out.extend_from_slice(&nonce_bytes);
905 out.extend_from_slice(tag);
906 out.extend_from_slice(ct);
907 Bytes::from(out)
908}
909
/// Borrowed views into the fields of an S4E4 (SSE-KMS) frame, as produced by
/// `parse_s4e4_header`. All slices point into the original body.
#[derive(Debug)]
pub struct S4E4Header<'a> {
    /// KMS key id stored in the frame (validated as UTF-8 by the parser).
    pub key_id: &'a str,
    /// KMS-wrapped data key bytes.
    pub wrapped_dek: &'a [u8],
    /// AES-GCM nonce; the parser slices exactly `NONCE_LEN` bytes.
    pub nonce: &'a [u8],
    /// AES-GCM tag; the parser slices exactly `TAG_LEN` bytes.
    pub tag: &'a [u8],
    /// Remaining bytes after the tag: the encrypted payload (may be empty).
    pub ciphertext: &'a [u8],
}
923
924pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
928 const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN; if body.len() < S4E4_MIN {
935 return Err(SseError::KmsFrameTooShort {
936 got: body.len(),
937 min: S4E4_MIN,
938 });
939 }
940 let magic = &body[..4];
941 if magic != SSE_MAGIC_V4 {
942 let mut got = [0u8; 4];
943 got.copy_from_slice(magic);
944 return Err(SseError::BadMagic { got });
945 }
946 let algo = body[4];
947 if algo != ALGO_AES_256_GCM {
948 return Err(SseError::UnsupportedAlgo { tag: algo });
949 }
950 let key_id_len = body[5] as usize;
951 let key_id_off: usize = 6;
952 let key_id_end = key_id_off
953 .checked_add(key_id_len)
954 .ok_or(SseError::KmsFrameFieldOob { what: "key_id_len" })?;
955 if key_id_end + 4 > body.len() {
956 return Err(SseError::KmsFrameFieldOob { what: "key_id" });
957 }
958 let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
959 .map_err(|_| SseError::KmsKeyIdNotUtf8)?;
960 let wrapped_len_off = key_id_end;
961 let wrapped_dek_len = u32::from_be_bytes([
962 body[wrapped_len_off],
963 body[wrapped_len_off + 1],
964 body[wrapped_len_off + 2],
965 body[wrapped_len_off + 3],
966 ]) as usize;
967 let wrapped_off = wrapped_len_off + 4;
968 let wrapped_end = wrapped_off
969 .checked_add(wrapped_dek_len)
970 .ok_or(SseError::KmsFrameFieldOob { what: "wrapped_dek_len" })?;
971 if wrapped_end + NONCE_LEN + TAG_LEN > body.len() {
972 return Err(SseError::KmsFrameFieldOob { what: "wrapped_dek" });
973 }
974 let wrapped_dek = &body[wrapped_off..wrapped_end];
975 let nonce_off = wrapped_end;
976 let tag_off = nonce_off + NONCE_LEN;
977 let ct_off = tag_off + TAG_LEN;
978 let nonce = &body[nonce_off..nonce_off + NONCE_LEN];
979 let tag = &body[tag_off..tag_off + TAG_LEN];
980 let ciphertext = &body[ct_off..];
981 Ok(S4E4Header {
982 key_id,
983 wrapped_dek,
984 nonce,
985 tag,
986 ciphertext,
987 })
988}
989
990pub async fn decrypt_with_kms(
1006 body: &[u8],
1007 kms: &dyn KmsBackend,
1008) -> Result<Bytes, SseError> {
1009 let hdr = parse_s4e4_header(body)?;
1010 let wrapped = WrappedDek {
1011 key_id: hdr.key_id.to_string(),
1012 ciphertext: hdr.wrapped_dek.to_vec(),
1013 };
1014 let dek_vec = kms.decrypt_dek(&wrapped).await?;
1015 if dek_vec.len() != KEY_LEN {
1016 return Err(SseError::KmsBackend(KmsError::BackendUnavailable {
1021 message: format!(
1022 "KMS returned {} byte DEK; expected {KEY_LEN}",
1023 dek_vec.len()
1024 ),
1025 }));
1026 }
1027 let mut dek = [0u8; KEY_LEN];
1028 dek.copy_from_slice(&dek_vec);
1029
1030 let aad = aad_v4(hdr.key_id.as_bytes(), hdr.wrapped_dek);
1031 let aes_key = Key::<Aes256Gcm>::from_slice(&dek);
1032 let cipher = Aes256Gcm::new(aes_key);
1033 let nonce = Nonce::from_slice(hdr.nonce);
1034 let mut ct_with_tag = Vec::with_capacity(hdr.ciphertext.len() + TAG_LEN);
1035 ct_with_tag.extend_from_slice(hdr.ciphertext);
1036 ct_with_tag.extend_from_slice(hdr.tag);
1037 let plain = cipher
1038 .decrypt(
1039 nonce,
1040 Payload {
1041 msg: &ct_with_tag,
1042 aad: &aad,
1043 },
1044 )
1045 .map_err(|_| SseError::DecryptFailed)?;
1046 Ok(Bytes::from(plain))
1047}
1048
1049fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1050 let algo = body[4];
1051 if algo != ALGO_AES_256_GCM {
1052 return Err(SseError::UnsupportedAlgo { tag: algo });
1053 }
1054 let mut nonce_bytes = [0u8; NONCE_LEN];
1057 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1058 let mut tag_bytes = [0u8; TAG_LEN];
1059 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1060 let ct = &body[SSE_HEADER_BYTES..];
1061
1062 let aad = aad_v1();
1063 let nonce = Nonce::from_slice(&nonce_bytes);
1064 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1065 ct_with_tag.extend_from_slice(ct);
1066 ct_with_tag.extend_from_slice(&tag_bytes);
1067
1068 let (active_id, _active_key) = keyring.active();
1072 let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
1073 ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
1074 for id in ids {
1075 let key = keyring.get(id).expect("id came from keyring iteration");
1076 let cipher = Aes256Gcm::new(key.as_aes_key());
1077 if let Ok(plain) = cipher.decrypt(
1078 nonce,
1079 Payload {
1080 msg: &ct_with_tag,
1081 aad: &aad,
1082 },
1083 ) {
1084 return Ok(Bytes::from(plain));
1085 }
1086 }
1087 Err(SseError::DecryptFailed)
1088}
1089
1090fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1091 let algo = body[4];
1092 if algo != ALGO_AES_256_GCM {
1093 return Err(SseError::UnsupportedAlgo { tag: algo });
1094 }
1095 let key_id = u16::from_be_bytes([body[5], body[6]]);
1096 let key = keyring
1098 .get(key_id)
1099 .ok_or(SseError::KeyNotInKeyring { id: key_id })?;
1100 let mut nonce_bytes = [0u8; NONCE_LEN];
1101 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1102 let mut tag_bytes = [0u8; TAG_LEN];
1103 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1104 let ct = &body[SSE_HEADER_BYTES..];
1105
1106 let aad = aad_v2(key_id);
1107 let nonce = Nonce::from_slice(&nonce_bytes);
1108 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1109 ct_with_tag.extend_from_slice(ct);
1110 ct_with_tag.extend_from_slice(&tag_bytes);
1111 let cipher = Aes256Gcm::new(key.as_aes_key());
1112 let plain = cipher
1113 .decrypt(
1114 nonce,
1115 Payload {
1116 msg: &ct_with_tag,
1117 aad: &aad,
1118 },
1119 )
1120 .map_err(|_| SseError::DecryptFailed)?;
1121 Ok(Bytes::from(plain))
1122}
1123
1124pub fn looks_encrypted(body: &[u8]) -> bool {
1135 if body.len() < SSE_HEADER_BYTES {
1136 return false;
1137 }
1138 let m = &body[..4];
1139 m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 || m == SSE_MAGIC_V3 || m == SSE_MAGIC_V4
1140}
1141
1142pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
1153 if body.len() < SSE_HEADER_BYTES {
1154 return None;
1155 }
1156 match &body[..4] {
1157 m if m == SSE_MAGIC_V1 => Some("S4E1"),
1158 m if m == SSE_MAGIC_V2 => Some("S4E2"),
1159 m if m == SSE_MAGIC_V3 => Some("S4E3"),
1160 m if m == SSE_MAGIC_V4 => Some("S4E4"),
1161 _ => None,
1162 }
1163}
1164
1165pub type SharedSseKey = Arc<SseKey>;
1166
1167#[cfg(test)]
1168mod tests {
1169 use super::*;
1170
1171 fn key32(seed: u8) -> Arc<SseKey> {
1172 Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
1173 }
1174
1175 fn keyring_single(seed: u8) -> SseKeyring {
1176 SseKeyring::new(1, key32(seed))
1177 }
1178
1179 #[test]
1180 fn roundtrip_basic_v1() {
1181 let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
1183 let pt = b"the quick brown fox jumps over the lazy dog";
1184 let ct = encrypt(&k, pt);
1185 assert!(looks_encrypted(&ct));
1186 assert_eq!(&ct[..4], SSE_MAGIC_V1);
1187 assert_eq!(ct[4], ALGO_AES_256_GCM);
1188 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
1189 let kr = SseKeyring::new(1, Arc::new(k));
1191 let pt2 = decrypt(&ct, &kr).unwrap();
1192 assert_eq!(pt2.as_ref(), pt);
1193 }
1194
1195 #[test]
1196 fn s4e2_roundtrip_active_key() {
1197 let kr = keyring_single(7);
1198 let pt = b"S4E2 active-key roundtrip";
1199 let ct = encrypt_v2(pt, &kr);
1200 assert_eq!(&ct[..4], SSE_MAGIC_V2);
1201 assert_eq!(ct[4], ALGO_AES_256_GCM);
1202 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
1203 assert_eq!(ct[7], 0, "reserved byte");
1204 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
1205 assert!(looks_encrypted(&ct));
1206 let pt2 = decrypt(&ct, &kr).unwrap();
1207 assert_eq!(pt2.as_ref(), pt);
1208 }
1209
1210 #[test]
1211 fn decrypt_s4e1_via_active_only_keyring() {
1212 let k_arc = key32(11);
1215 let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
1216 assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
1217 let kr = SseKeyring::new(1, Arc::clone(&k_arc));
1218 let plain = decrypt(&legacy_ct, &kr).unwrap();
1219 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
1220 }
1221
1222 #[test]
1223 fn decrypt_s4e2_under_old_key_after_rotation() {
1224 let k1 = key32(1);
1228 let k2 = key32(2);
1229 let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
1230 let ct = encrypt_v2(b"old-rotation object", &kr_old);
1231 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
1232
1233 kr_old.add(2, Arc::clone(&k2));
1235 let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
1236 kr_new.add(1, Arc::clone(&k1));
1237
1238 let plain = decrypt(&ct, &kr_new).unwrap();
1239 assert_eq!(plain.as_ref(), b"old-rotation object");
1240
1241 let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
1243 assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
1244 let plain_new = decrypt(&new_ct, &kr_new).unwrap();
1245 assert_eq!(plain_new.as_ref(), b"new-rotation object");
1246 }
1247
1248 #[test]
1249 fn s4e2_unknown_key_id_errors() {
1250 let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
1252 let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
1254 assert!(
1255 matches!(err, SseError::KeyNotInKeyring { id: 99 }),
1256 "got {err:?}"
1257 );
1258 }
1259
1260 #[test]
1261 fn s4e2_tampered_key_id_fails_auth() {
1262 let kr = SseKeyring::new(1, key32(4));
1263 let mut kr_with_2 = kr.clone();
1264 kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
1266 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
1270 ct[5] = 0;
1271 ct[6] = 2;
1272 let err = decrypt(&ct, &kr_with_2).unwrap_err();
1273 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
1274 }
1275
1276 #[test]
1277 fn s4e2_tampered_ciphertext_fails() {
1278 let kr = SseKeyring::new(7, key32(9));
1279 let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
1280 let last = ct.len() - 1;
1281 ct[last] ^= 0x01;
1282 let err = decrypt(&ct, &kr).unwrap_err();
1283 assert!(matches!(err, SseError::DecryptFailed));
1284 }
1285
1286 #[test]
1287 fn s4e2_tampered_algo_byte_fails() {
1288 let kr = SseKeyring::new(1, key32(2));
1289 let mut ct = encrypt_v2(b"hi", &kr).to_vec();
1290 ct[4] = 99;
1291 let err = decrypt(&ct, &kr).unwrap_err();
1292 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
1293 }
1294
1295 #[test]
1296 fn wrong_key_fails_v1_via_keyring() {
1297 let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
1299 let ct = encrypt(&k1, b"secret");
1300 let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
1301 let err = decrypt(&ct, &kr_wrong).unwrap_err();
1302 assert!(matches!(err, SseError::DecryptFailed));
1303 }
1304
1305 #[test]
1306 fn rejects_short_body() {
1307 let kr = SseKeyring::new(1, key32(1));
1308 let err = decrypt(b"short", &kr).unwrap_err();
1309 assert!(matches!(err, SseError::TooShort { got: 5 }));
1310 }
1311
1312 #[test]
1313 fn looks_encrypted_passthrough_returns_false() {
1314 let f2 = b"S4F2\x01\x00\x00\x00........................................";
1316 assert!(!looks_encrypted(f2));
1317 assert!(!looks_encrypted(b""));
1318 }
1319
1320 #[test]
1321 fn looks_encrypted_detects_both_v1_and_v2() {
1322 let kr = SseKeyring::new(1, key32(8));
1323 let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
1324 let v2 = encrypt_v2(b"x", &kr);
1325 assert!(looks_encrypted(&v1));
1326 assert!(looks_encrypted(&v2));
1327 }
1328
1329 #[test]
1330 fn key_from_hex_string() {
1331 let bad =
1332 SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
1333 .unwrap_err();
1334 assert!(matches!(bad, SseError::BadKeyLength { .. }));
1335 let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
1336 let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
1337 }
1338
1339 #[test]
1340 fn encrypt_v2_uses_random_nonce() {
1341 let kr = SseKeyring::new(1, key32(3));
1342 let pt = b"deterministic input";
1343 let a = encrypt_v2(pt, &kr);
1344 let b = encrypt_v2(pt, &kr);
1345 assert_ne!(a, b, "nonce must be random per-call");
1346 }
1347
1348 #[test]
1349 fn keyring_active_and_get() {
1350 let k1 = key32(1);
1351 let k2 = key32(2);
1352 let mut kr = SseKeyring::new(1, Arc::clone(&k1));
1353 kr.add(2, Arc::clone(&k2));
1354 let (id, active) = kr.active();
1355 assert_eq!(id, 1);
1356 assert_eq!(active.bytes, [1u8; 32]);
1357 assert!(kr.get(2).is_some());
1358 assert!(kr.get(3).is_none());
1359 }
1360
1361 use base64::Engine as _;
1366
1367 fn cust_key(seed: u8) -> CustomerKeyMaterial {
1368 let key = [seed; KEY_LEN];
1369 let key_md5 = compute_key_md5(&key);
1370 CustomerKeyMaterial { key, key_md5 }
1371 }
1372
1373 #[test]
1374 fn s4e3_roundtrip_happy_path() {
1375 let m = cust_key(42);
1376 let pt = b"top-secret SSE-C payload";
1377 let ct = encrypt_with_source(
1378 pt,
1379 SseSource::CustomerKey {
1380 key: &m.key,
1381 key_md5: &m.key_md5,
1382 },
1383 );
1384 assert_eq!(&ct[..4], SSE_MAGIC_V3);
1386 assert_eq!(ct[4], ALGO_AES_256_GCM);
1387 assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
1388 assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
1389 assert!(looks_encrypted(&ct));
1390 let plain = decrypt(
1392 &ct,
1393 SseSource::CustomerKey {
1394 key: &m.key,
1395 key_md5: &m.key_md5,
1396 },
1397 )
1398 .unwrap();
1399 assert_eq!(plain.as_ref(), pt);
1400 let plain2 = decrypt(&ct, &m).unwrap();
1402 assert_eq!(plain2.as_ref(), pt);
1403 }
1404
1405 #[test]
1406 fn s4e3_wrong_key_yields_wrong_customer_key_error() {
1407 let m = cust_key(1);
1408 let other = cust_key(2);
1409 let ct = encrypt_with_source(b"payload", (&m).into());
1410 let err = decrypt(
1411 &ct,
1412 SseSource::CustomerKey {
1413 key: &other.key,
1414 key_md5: &other.key_md5,
1415 },
1416 )
1417 .unwrap_err();
1418 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
1419 }
1420
1421 #[test]
1422 fn s4e3_tampered_stored_md5_is_caught() {
1423 let m = cust_key(7);
1430 let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
1431 ct[5] ^= 0x55;
1433 let err = decrypt(
1435 &ct,
1436 SseSource::CustomerKey {
1437 key: &m.key,
1438 key_md5: &m.key_md5,
1439 },
1440 )
1441 .unwrap_err();
1442 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
1443 }
1444
1445 #[test]
1446 fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
1447 let m = cust_key(3);
1451 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
1452 ct[5] ^= 0xFF;
1453 let mut bogus_md5 = m.key_md5;
1454 bogus_md5[0] ^= 0xFF;
1455 let err = decrypt(
1456 &ct,
1457 SseSource::CustomerKey {
1458 key: &m.key,
1459 key_md5: &bogus_md5,
1460 },
1461 )
1462 .unwrap_err();
1463 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
1464 }
1465
1466 #[test]
1467 fn s4e3_tampered_ciphertext_fails_aead() {
1468 let m = cust_key(8);
1469 let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
1470 let last = ct.len() - 1;
1471 ct[last] ^= 0x01;
1472 let err = decrypt(&ct, &m).unwrap_err();
1473 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
1474 }
1475
1476 #[test]
1477 fn s4e3_tampered_algo_byte_rejected() {
1478 let m = cust_key(9);
1479 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
1480 ct[4] = 99;
1481 let err = decrypt(&ct, &m).unwrap_err();
1482 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
1483 }
1484
1485 #[test]
1486 fn s4e3_uses_random_nonce() {
1487 let m = cust_key(10);
1488 let a = encrypt_with_source(b"deterministic input", (&m).into());
1489 let b = encrypt_with_source(b"deterministic input", (&m).into());
1490 assert_ne!(a, b, "nonce must be random per-call");
1491 }
1492
1493 #[test]
1494 fn parse_customer_key_headers_happy_path() {
1495 let key = [11u8; KEY_LEN];
1496 let md5 = compute_key_md5(&key);
1497 let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
1498 let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
1499 let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
1500 assert_eq!(m.key, key);
1501 assert_eq!(m.key_md5, md5);
1502 }
1503
1504 #[test]
1505 fn parse_customer_key_headers_rejects_wrong_algorithm() {
1506 let key = [1u8; KEY_LEN];
1507 let md5 = compute_key_md5(&key);
1508 let kb = base64::engine::general_purpose::STANDARD.encode(key);
1509 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
1510 let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
1511 assert!(
1512 matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
1513 "got {err:?}"
1514 );
1515 let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
1517 assert!(
1518 matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
1519 "got {err2:?}"
1520 );
1521 }
1522
1523 #[test]
1524 fn parse_customer_key_headers_rejects_wrong_key_length() {
1525 let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
1527 let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
1528 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
1529 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
1530 assert!(
1531 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
1532 "got {err:?}"
1533 );
1534 }
1535
1536 #[test]
1537 fn parse_customer_key_headers_rejects_wrong_md5_length() {
1538 let key = [3u8; KEY_LEN];
1539 let kb = base64::engine::general_purpose::STANDARD.encode(key);
1540 let bad_md5 = vec![0u8; 15];
1542 let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
1543 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
1544 assert!(
1545 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
1546 "got {err:?}"
1547 );
1548 }
1549
1550 #[test]
1551 fn parse_customer_key_headers_rejects_md5_mismatch() {
1552 let key = [4u8; KEY_LEN];
1553 let other = [5u8; KEY_LEN];
1554 let kb = base64::engine::general_purpose::STANDARD.encode(key);
1555 let wrong_md5 = compute_key_md5(&other);
1556 let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
1557 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
1558 assert!(
1559 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
1560 "got {err:?}"
1561 );
1562 }
1563
1564 #[test]
1565 fn parse_customer_key_headers_rejects_bad_base64() {
1566 let valid_key = [0u8; KEY_LEN];
1567 let md5 = compute_key_md5(&valid_key);
1568 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
1569 let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
1570 assert!(
1571 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
1572 "got {err:?}"
1573 );
1574 let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
1576 let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
1577 assert!(
1578 matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
1579 "got {err2:?}"
1580 );
1581 }
1582
1583 #[test]
1584 fn parse_customer_key_headers_trims_whitespace() {
1585 let key = [12u8; KEY_LEN];
1587 let md5 = compute_key_md5(&key);
1588 let kb = format!(
1589 " {}\n",
1590 base64::engine::general_purpose::STANDARD.encode(key)
1591 );
1592 let mb = format!(
1593 "\t{} ",
1594 base64::engine::general_purpose::STANDARD.encode(md5)
1595 );
1596 let m = parse_customer_key_headers("AES256", &kb, &mb).unwrap();
1597 assert_eq!(m.key, key);
1598 }
1599
1600 #[test]
1605 fn back_compat_decrypt_s4e1_with_keyring_source() {
1606 let k = key32(33);
1607 let legacy_ct = encrypt(&k, b"v0.4 vintage object");
1608 let kr = SseKeyring::new(1, Arc::clone(&k));
1609 let plain = decrypt(&legacy_ct, &kr).unwrap();
1612 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
1613 let plain2 = decrypt(&legacy_ct, SseSource::Keyring(&kr)).unwrap();
1614 assert_eq!(plain2.as_ref(), b"v0.4 vintage object");
1615 }
1616
1617 #[test]
1618 fn back_compat_decrypt_s4e2_with_keyring_source() {
1619 let kr = keyring_single(34);
1620 let ct = encrypt_v2(b"v0.5 #29 object", &kr);
1621 let plain = decrypt(&ct, &kr).unwrap();
1622 assert_eq!(plain.as_ref(), b"v0.5 #29 object");
1623 let ct2 = encrypt_with_source(b"v0.5 #29 object", SseSource::Keyring(&kr));
1626 assert_eq!(&ct2[..4], SSE_MAGIC_V2);
1627 let plain2 = decrypt(&ct2, &kr).unwrap();
1628 assert_eq!(plain2.as_ref(), b"v0.5 #29 object");
1629 }
1630
1631 #[test]
1632 fn s4e2_blob_with_customer_key_source_is_rejected() {
1633 let kr = keyring_single(50);
1637 let ct = encrypt_v2(b"server-managed object", &kr);
1638 let m = cust_key(99);
1639 let err = decrypt(
1640 &ct,
1641 SseSource::CustomerKey {
1642 key: &m.key,
1643 key_md5: &m.key_md5,
1644 },
1645 )
1646 .unwrap_err();
1647 assert!(matches!(err, SseError::CustomerKeyUnexpected), "got {err:?}");
1648 }
1649
1650 #[test]
1651 fn s4e3_blob_with_keyring_source_is_rejected() {
1652 let m = cust_key(60);
1655 let ct = encrypt_with_source(b"customer-key object", (&m).into());
1656 let kr = keyring_single(60);
1657 let err = decrypt(&ct, &kr).unwrap_err();
1658 assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
1659 }
1660
1661 #[test]
1662 fn looks_encrypted_detects_s4e3() {
1663 let m = cust_key(13);
1664 let ct = encrypt_with_source(b"x", (&m).into());
1665 assert!(looks_encrypted(&ct));
1666 }
1667
1668 #[test]
1669 fn s4e3_rejects_short_body() {
1670 let mut short = Vec::new();
1673 short.extend_from_slice(SSE_MAGIC_V3);
1674 short.push(ALGO_AES_256_GCM);
1675 short.extend_from_slice(&[0u8; SSE_HEADER_BYTES - 5]);
1678 assert_eq!(short.len(), SSE_HEADER_BYTES);
1679 let m = cust_key(1);
1680 let err = decrypt(
1681 &short,
1682 SseSource::CustomerKey {
1683 key: &m.key,
1684 key_md5: &m.key_md5,
1685 },
1686 )
1687 .unwrap_err();
1688 assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
1689 }
1690
1691 #[test]
1692 fn customer_key_material_debug_redacts_key() {
1693 let m = cust_key(99);
1694 let s = format!("{m:?}");
1695 assert!(s.contains("redacted"));
1696 assert!(!s.contains(&format!("{:?}", m.key.as_slice())));
1697 }
1698
1699 #[test]
1700 fn constant_time_eq_basic() {
1701 assert!(constant_time_eq(b"abc", b"abc"));
1702 assert!(!constant_time_eq(b"abc", b"abd"));
1703 assert!(!constant_time_eq(b"abc", b"abcd"));
1704 assert!(constant_time_eq(b"", b""));
1705 }
1706
1707 #[test]
1708 fn compute_key_md5_known_vector() {
1709 let got = compute_key_md5(b"");
1711 let expected_hex = "d41d8cd98f00b204e9800998ecf8427e";
1712 assert_eq!(hex_lower(&got), expected_hex);
1713 }
1714
1715 use crate::kms::{KmsBackend, LocalKms};
1720 use std::collections::HashMap;
1721 use std::path::PathBuf;
1722
1723 fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
1724 let mut keks: HashMap<String, [u8; 32]> = HashMap::new();
1725 for (id, k) in key_ids {
1726 keks.insert((*id).to_string(), *k);
1727 }
1728 LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
1729 }
1730
1731 #[tokio::test]
1732 async fn s4e4_roundtrip_via_local_kms() {
1733 let kms = local_kms_with(&[("alpha", [42u8; 32])]);
1734 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
1735 let mut dek = [0u8; 32];
1736 dek.copy_from_slice(&dek_vec);
1737 let pt = b"SSE-KMS envelope payload across the S4E4 frame";
1738 let ct = encrypt_with_source(
1739 pt,
1740 SseSource::Kms {
1741 dek: &dek,
1742 wrapped: &wrapped,
1743 },
1744 );
1745 assert_eq!(&ct[..4], SSE_MAGIC_V4);
1747 assert_eq!(ct[4], ALGO_AES_256_GCM);
1748 let key_id_len = ct[5] as usize;
1749 assert_eq!(key_id_len, "alpha".len());
1750 assert_eq!(&ct[6..6 + key_id_len], b"alpha");
1751 assert!(looks_encrypted(&ct));
1753 assert_eq!(peek_magic(&ct), Some("S4E4"));
1754 let plain = decrypt_with_kms(&ct, &kms).await.unwrap();
1756 assert_eq!(plain.as_ref(), pt);
1757 }
1758
1759 #[tokio::test]
1760 async fn s4e4_tampered_key_id_fails_aead() {
1761 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
1762 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
1763 let mut dek = [0u8; 32];
1764 dek.copy_from_slice(&dek_vec);
1765 let mut ct = encrypt_with_source(
1766 b"do not redirect",
1767 SseSource::Kms {
1768 dek: &dek,
1769 wrapped: &wrapped,
1770 },
1771 )
1772 .to_vec();
1773 let key_id_off = 6;
1778 ct[key_id_off] = b'b';
1779 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
1780 assert!(
1781 matches!(
1782 err,
1783 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
1784 | SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
1785 ),
1786 "got {err:?}"
1787 );
1788 }
1789
1790 #[tokio::test]
1791 async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
1792 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
1798 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
1799 let mut dek = [0u8; 32];
1800 dek.copy_from_slice(&dek_vec);
1801 let mut ct = encrypt_with_source(
1802 b"redirect attempt",
1803 SseSource::Kms {
1804 dek: &dek,
1805 wrapped: &wrapped,
1806 },
1807 )
1808 .to_vec();
1809 let key_id_off = 6;
1812 ct[key_id_off..key_id_off + 5].copy_from_slice(b"beta_");
1813 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
1820 assert!(
1821 matches!(
1822 err,
1823 SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
1824 ),
1825 "got {err:?}"
1826 );
1827 }
1828
1829 #[tokio::test]
1830 async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
1831 let kms = local_kms_with(&[("k", [3u8; 32])]);
1832 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
1833 let mut dek = [0u8; 32];
1834 dek.copy_from_slice(&dek_vec);
1835 let mut ct = encrypt_with_source(
1836 b"target body",
1837 SseSource::Kms {
1838 dek: &dek,
1839 wrapped: &wrapped,
1840 },
1841 )
1842 .to_vec();
1843 let key_id_len = ct[5] as usize;
1847 let wrapped_len_off = 6 + key_id_len;
1848 let wrapped_off = wrapped_len_off + 4;
1849 let mid = wrapped_off + (wrapped.ciphertext.len() / 2);
1850 ct[mid] ^= 0xFF;
1851 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
1852 assert!(
1853 matches!(
1854 err,
1855 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
1856 ),
1857 "got {err:?}"
1858 );
1859 }
1860
1861 #[tokio::test]
1862 async fn s4e4_tampered_ciphertext_fails_aead() {
1863 let kms = local_kms_with(&[("k", [4u8; 32])]);
1864 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
1865 let mut dek = [0u8; 32];
1866 dek.copy_from_slice(&dek_vec);
1867 let mut ct = encrypt_with_source(
1868 b"sealed body",
1869 SseSource::Kms {
1870 dek: &dek,
1871 wrapped: &wrapped,
1872 },
1873 )
1874 .to_vec();
1875 let last = ct.len() - 1;
1876 ct[last] ^= 0x01;
1877 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
1878 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
1879 }
1880
1881 #[tokio::test]
1882 async fn s4e4_uses_random_nonce_and_dek_per_put() {
1883 let kms = local_kms_with(&[("k", [5u8; 32])]);
1884 let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
1887 let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
1888 let mut dek1 = [0u8; 32];
1889 dek1.copy_from_slice(&dek1_vec);
1890 let mut dek2 = [0u8; 32];
1891 dek2.copy_from_slice(&dek2_vec);
1892 let pt = b"deterministic input";
1893 let a = encrypt_with_source(
1894 pt,
1895 SseSource::Kms {
1896 dek: &dek1,
1897 wrapped: &wrapped1,
1898 },
1899 );
1900 let b = encrypt_with_source(
1901 pt,
1902 SseSource::Kms {
1903 dek: &dek2,
1904 wrapped: &wrapped2,
1905 },
1906 );
1907 assert_ne!(a, b);
1908 let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
1910 let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
1911 assert_eq!(plain_a.as_ref(), pt);
1912 assert_eq!(plain_b.as_ref(), pt);
1913 }
1914
1915 #[tokio::test]
1916 async fn s4e4_sync_decrypt_returns_kms_async_required() {
1917 let kms = local_kms_with(&[("k", [6u8; 32])]);
1922 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
1923 let mut dek = [0u8; 32];
1924 dek.copy_from_slice(&dek_vec);
1925 let ct = encrypt_with_source(
1926 b"async only",
1927 SseSource::Kms {
1928 dek: &dek,
1929 wrapped: &wrapped,
1930 },
1931 );
1932 let kr = SseKeyring::new(1, key32(0));
1934 let err = decrypt(&ct, &kr).unwrap_err();
1935 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
1936 }
1937
1938 #[test]
1939 fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
1940 let k = key32(7);
1943 let v1 = encrypt(&k, b"v0.4 vintage");
1944 let kr = SseKeyring::new(1, Arc::clone(&k));
1945 assert_eq!(decrypt(&v1, &kr).unwrap().as_ref(), b"v0.4 vintage");
1946
1947 let v2 = encrypt_v2(b"v0.5 #29 vintage", &kr);
1948 assert_eq!(
1949 decrypt(&v2, &kr).unwrap().as_ref(),
1950 b"v0.5 #29 vintage"
1951 );
1952
1953 let m = cust_key(7);
1954 let v3 = encrypt_with_source(b"v0.5 #27 vintage", (&m).into());
1955 assert_eq!(
1956 decrypt(&v3, &m).unwrap().as_ref(),
1957 b"v0.5 #27 vintage"
1958 );
1959 }
1960
1961 #[test]
1962 fn peek_magic_distinguishes_all_variants() {
1963 let k = key32(9);
1966 let v1 = encrypt(&k, b"x");
1967 assert_eq!(peek_magic(&v1), Some("S4E1"));
1968 let kr = SseKeyring::new(1, Arc::clone(&k));
1969 let v2 = encrypt_v2(b"x", &kr);
1970 assert_eq!(peek_magic(&v2), Some("S4E2"));
1971 let m = cust_key(9);
1972 let v3 = encrypt_with_source(b"x", (&m).into());
1973 assert_eq!(peek_magic(&v3), Some("S4E3"));
1974 let mut v4 = Vec::new();
1979 v4.extend_from_slice(SSE_MAGIC_V4);
1980 v4.extend_from_slice(&[0u8; 40]);
1981 assert_eq!(peek_magic(&v4), Some("S4E4"));
1982 assert!(peek_magic(b"NOPE").is_none());
1984 assert!(peek_magic(b"short").is_none());
1985 assert!(peek_magic(&[0u8; 100]).is_none());
1986 }
1987
1988 #[tokio::test]
1989 async fn s4e4_truncated_frame_errors_cleanly() {
1990 let truncated = b"S4E4\x01\x05hi";
1993 let kms = local_kms_with(&[("k", [1u8; 32])]);
1994 let err = decrypt_with_kms(truncated, &kms).await.unwrap_err();
1995 assert!(
1996 matches!(err, SseError::KmsFrameTooShort { .. }),
1997 "got {err:?}"
1998 );
1999 }
2000
2001 #[tokio::test]
2002 async fn s4e4_oob_key_id_len_errors() {
2003 let mut body = Vec::new();
2007 body.extend_from_slice(SSE_MAGIC_V4);
2008 body.push(ALGO_AES_256_GCM);
2009 body.push(200u8); body.extend_from_slice(&[0u8; 50]);
2014 let kms = local_kms_with(&[("k", [1u8; 32])]);
2015 let err = decrypt_with_kms(&body, &kms).await.unwrap_err();
2016 assert!(
2017 matches!(err, SseError::KmsFrameFieldOob { .. }),
2018 "got {err:?}"
2019 );
2020 }
2021
2022 #[tokio::test]
2023 async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
2024 let kms = local_kms_with(&[("k", [9u8; 32])]);
2030 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2031 let mut dek = [0u8; 32];
2032 dek.copy_from_slice(&dek_vec);
2033 let ct = encrypt_with_source(
2034 b"x",
2035 SseSource::Kms {
2036 dek: &dek,
2037 wrapped: &wrapped,
2038 },
2039 );
2040 let m = cust_key(1);
2041 let err = decrypt(&ct, &m).unwrap_err();
2042 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
2043 }
2044
2045 #[tokio::test]
2046 async fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
2047 let mut not_s4e4 = Vec::new();
2049 not_s4e4.extend_from_slice(b"S4F4");
2050 not_s4e4.extend_from_slice(&[0u8; 60]);
2051 assert!(!looks_encrypted(¬_s4e4));
2052 assert_eq!(peek_magic(¬_s4e4), None);
2053 }
2054
2055 #[tokio::test]
2056 async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
2057 let kms = local_kms_with(&[("kk", [11u8; 32])]);
2064 let (dek_vec, wrapped) = kms.generate_dek("kk").await.unwrap();
2065 let mut dek = [0u8; 32];
2066 dek.copy_from_slice(&dek_vec);
2067 let mut ct = encrypt_with_source(
2068 b"length-shift defense",
2069 SseSource::Kms {
2070 dek: &dek,
2071 wrapped: &wrapped,
2072 },
2073 )
2074 .to_vec();
2075 let key_id_len = ct[5] as usize;
2076 let wrapped_len_off = 6 + key_id_len;
2077 let original_len = u32::from_be_bytes([
2083 ct[wrapped_len_off],
2084 ct[wrapped_len_off + 1],
2085 ct[wrapped_len_off + 2],
2086 ct[wrapped_len_off + 3],
2087 ]);
2088 let new_len = (original_len - 1).to_be_bytes();
2089 ct[wrapped_len_off..wrapped_len_off + 4].copy_from_slice(&new_len);
2090 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2091 assert!(
2094 matches!(
2095 err,
2096 SseError::KmsBackend(_)
2097 | SseError::DecryptFailed
2098 | SseError::KmsFrameFieldOob { .. }
2099 | SseError::KmsFrameTooShort { .. }
2100 ),
2101 "got {err:?}"
2102 );
2103 }
2104}