1use std::collections::HashMap;
122use std::path::Path;
123use std::sync::Arc;
124
125use aes_gcm::aead::{Aead, KeyInit, Payload};
126use aes_gcm::{Aes256Gcm, Key, Nonce};
127use bytes::Bytes;
128use md5::{Digest as Md5Digest, Md5};
129use rand::RngCore;
130use thiserror::Error;
131
132use crate::kms::{KmsBackend, KmsError, WrappedDek};
133
/// Magic prefix of the single-key frame produced by [`encrypt`].
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1";
/// Magic prefix of the keyring frame produced by [`encrypt_v2`]; its header carries a `u16` key id.
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2";
/// Magic prefix of the customer-key (SSE-C) frame; its header carries the key's MD5.
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3";
/// Magic prefix of the SSE-KMS frame; its header carries a key id and a wrapped DEK.
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4";
/// Magic prefix of the chunked frame variant that uses a 4-byte salt.
pub const SSE_MAGIC_V5: &[u8; 4] = b"S4E5";
/// Magic prefix of the chunked frame variant that uses an 8-byte salt
/// (the one written by [`encrypt_v2_chunked`]).
pub const SSE_MAGIC_V6: &[u8; 4] = b"S4E6";
/// Alias for [`SSE_MAGIC_V1`].
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;
157
/// Header size of S4E1/S4E2 frames:
/// magic (4) + algo (1) + 3 bytes (zero in S4E1; key id + one zero byte in S4E2)
/// + nonce (12) + tag (16).
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + 12 + 16;
/// Header size of S4E3 frames: magic (4) + algo (1) + key MD5 + nonce (12) + tag (16).
pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + 12 + 16;
/// Algorithm tag stored in byte 4 of every frame; the only value the parsers accept.
pub const ALGO_AES_256_GCM: u8 = 1;
/// AES-GCM nonce length in bytes.
const NONCE_LEN: usize = 12;
/// AES-GCM authentication tag length in bytes.
const TAG_LEN: usize = 16;
/// AES-256 key length in bytes.
const KEY_LEN: usize = 32;
/// MD5 digest length in bytes (used for the SSE-C key fingerprint).
const KEY_MD5_LEN: usize = 16;
/// The only algorithm name accepted by [`parse_customer_key_headers`].
pub const SSE_C_ALGORITHM: &str = "AES256";
177
/// Every failure the SSE encrypt/decrypt/parse routines in this module can report.
#[derive(Debug, Error)]
pub enum SseError {
    /// Reading the key file failed (returned by `SseKey::from_path`).
    #[error("SSE key file {path:?}: {source}")]
    KeyFileIo {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// The key material was not 32 raw bytes, 64 hex chars, or base64 decoding to 32 bytes.
    /// `got` is the length of the *input* buffer handed to `SseKey::from_bytes`.
    #[error(
        "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
    )]
    BadKeyLength { got: usize },
    /// The body is shorter than the frame header. Also returned for S4E3 bodies shorter
    /// than `SSE_HEADER_BYTES_V3`, although the message always cites `SSE_HEADER_BYTES`.
    #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
    TooShort { got: usize },
    /// The first four bytes are not one of the known frame magics.
    #[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4/S4E5/S4E6, got {got:?}")]
    BadMagic { got: [u8; 4] },
    /// Byte 4 of the frame is not `ALGO_AES_256_GCM`.
    #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
    UnsupportedAlgo { tag: u8 },
    /// The key id recorded in an S4E2 header has no entry in the supplied keyring.
    #[error(
        "SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
    )]
    KeyNotInKeyring { id: u16 },
    /// AES-GCM decryption failed (for S4E1: with every key in the keyring).
    #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
    DecryptFailed,
    /// The MD5 supplied for an S4E3 decrypt differs from the MD5 stored in the frame header.
    #[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
    WrongCustomerKey,
    /// An SSE-C header value failed base64 decoding, a length check, or the MD5 self-check.
    #[error("SSE-C customer-key headers invalid: {reason}")]
    InvalidCustomerKey { reason: &'static str },
    /// The SSE-C algorithm header was not `SSE_C_ALGORITHM`.
    #[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
    CustomerKeyAlgorithmUnsupported { algo: String },
    /// `decrypt` saw an S4E3 frame but the source was not `SseSource::CustomerKey`
    /// (returned for both `Keyring` and `Kms` sources).
    #[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
    CustomerKeyRequired,
    /// `decrypt` saw an S4E1/S4E2/S4E5/S4E6 frame but the source was not
    /// `SseSource::Keyring` (returned for both `CustomerKey` and `Kms` sources).
    #[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
    CustomerKeyUnexpected,
    /// The synchronous `decrypt` was handed an S4E4 frame.
    #[error(
        "S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
    )]
    KmsAsyncRequired,
    /// An S4E4 body is shorter than the minimum header size.
    #[error("S4E4 frame too short ({got} bytes; need at least {min})")]
    KmsFrameTooShort { got: usize, min: usize },
    /// A length field in an S4E4 header points past the end of the body.
    #[error("S4E4 frame field length out of bounds: {what}")]
    KmsFrameFieldOob { what: &'static str },
    /// The key id bytes in an S4E4 header are not valid UTF-8.
    #[error("S4E4 key_id is not valid UTF-8")]
    KmsKeyIdNotUtf8,
    /// Not constructed in the visible part of this file; per its message, a supplied
    /// wrapped DEK's key id disagrees with the one in the frame.
    #[error(
        "S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
    )]
    KmsWrappedDekMismatch { supplied: String, stored: String },
    /// Not constructed in the visible part of this file.
    #[error("S4E4 frame requires SseSource::Kms")]
    KmsRequired,
    /// An error from the KMS backend; `?` on a `KmsError` converts into this variant.
    #[error("KMS unwrap: {0}")]
    KmsBackend(#[from] KmsError),
    /// Not constructed in the visible part of this file; per its message, one chunk of a
    /// chunked frame failed authentication.
    #[error(
        "S4E5 chunk {chunk_index} auth tag verify failed (key mismatch or chunk tampered with)"
    )]
    ChunkAuthFailed { chunk_index: u32 },
    /// A chunk size of zero was requested for encryption or found in a chunked header.
    #[error("S4E5 chunk_size must be > 0 (got 0)")]
    ChunkSizeInvalid,
    /// A chunked frame is shorter than its header, or declares zero chunks.
    #[error("S4E5 frame truncated: {what}")]
    ChunkFrameTruncated { what: &'static str },
    /// The chunk count does not fit `S4E6_MAX_CHUNK_COUNT`.
    #[error("S4E6 chunk_count {got} exceeds 24-bit max ({max}) — pick a larger --sse-chunk-size")]
    ChunkCountTooLarge { got: u32, max: u32 },
    /// Size arithmetic on a chunked header overflowed or exceeded a limit.
    #[error("S4E5/S4E6 chunked frame declares an over-large size: {details}")]
    ChunkFrameTooLarge { details: &'static str },
}
341
/// Default body-size limit in bytes (5 GiB).
///
/// NOTE(review): 5 GiB does not fit in a 32-bit `usize`, so this constant expression
/// presumably only builds on 64-bit targets — confirm that is the intended support matrix.
pub const DEFAULT_MAX_BODY_BYTES: usize = 5 * 1024 * 1024 * 1024;
349
/// A 32-byte AES-256 key.
///
/// `Debug` is implemented by hand elsewhere in this file so the key bytes are not printed.
pub struct SseKey {
    /// Raw key bytes.
    pub bytes: [u8; 32],
}
357
358impl SseKey {
359 pub fn from_path(path: &Path) -> Result<Self, SseError> {
363 let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
364 path: path.to_path_buf(),
365 source,
366 })?;
367 Self::from_bytes(&raw)
368 }
369
370 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
371 if bytes.len() == KEY_LEN {
373 let mut k = [0u8; KEY_LEN];
374 k.copy_from_slice(bytes);
375 return Ok(Self { bytes: k });
376 }
377 let s = std::str::from_utf8(bytes).unwrap_or("").trim();
379 if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
380 let mut k = [0u8; KEY_LEN];
381 for (i, k_byte) in k.iter_mut().enumerate() {
382 *k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
383 .map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
384 }
385 return Ok(Self { bytes: k });
386 }
387 if let Ok(decoded) =
388 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
389 && decoded.len() == KEY_LEN
390 {
391 let mut k = [0u8; KEY_LEN];
392 k.copy_from_slice(&decoded);
393 return Ok(Self { bytes: k });
394 }
395 Err(SseError::BadKeyLength { got: bytes.len() })
396 }
397
398 fn as_aes_key(&self) -> &Key<Aes256Gcm> {
399 Key::<Aes256Gcm>::from_slice(&self.bytes)
400 }
401}
402
403impl std::fmt::Debug for SseKey {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 f.debug_struct("SseKey")
406 .field("len", &KEY_LEN)
407 .field("key", &"<redacted>")
408 .finish()
409 }
410}
411
/// A set of keys addressed by `u16` id, one of which is the active key.
#[derive(Clone)]
pub struct SseKeyring {
    /// Id of the key returned by `active()` and used by the keyring encrypt paths.
    active: u16,
    /// All known keys by id; `new` inserts the active key here.
    keys: HashMap<u16, Arc<SseKey>>,
}
421
422impl SseKeyring {
423 pub fn new(active: u16, key: Arc<SseKey>) -> Self {
427 let mut keys = HashMap::new();
428 keys.insert(active, key);
429 Self { active, keys }
430 }
431
432 pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
436 self.keys.insert(id, key);
437 }
438
439 pub fn active(&self) -> (u16, &SseKey) {
442 let id = self.active;
443 let key = self
444 .keys
445 .get(&id)
446 .expect("active key id must be present in keyring (constructor invariant)");
447 (id, key.as_ref())
448 }
449
450 pub fn get(&self, id: u16) -> Option<&SseKey> {
453 self.keys.get(&id).map(Arc::as_ref)
454 }
455}
456
457impl std::fmt::Debug for SseKeyring {
458 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459 f.debug_struct("SseKeyring")
460 .field("active", &self.active)
461 .field("key_count", &self.keys.len())
462 .field("key_ids", &self.keys.keys().collect::<Vec<_>>())
463 .finish()
464 }
465}
466
/// A reference-counted, shareable [`SseKeyring`].
pub type SharedSseKeyring = Arc<SseKeyring>;
468
469pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
476 let cipher = Aes256Gcm::new(key.as_aes_key());
477 let mut nonce_bytes = [0u8; NONCE_LEN];
478 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
479 let nonce = Nonce::from_slice(&nonce_bytes);
480 let mut aad = [0u8; 8];
482 aad[..4].copy_from_slice(SSE_MAGIC_V1);
483 aad[4] = ALGO_AES_256_GCM;
484 let ct_with_tag = cipher
485 .encrypt(
486 nonce,
487 Payload {
488 msg: plaintext,
489 aad: &aad,
490 },
491 )
492 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
493 debug_assert!(ct_with_tag.len() >= TAG_LEN);
494 let split = ct_with_tag.len() - TAG_LEN;
495 let (ct, tag) = ct_with_tag.split_at(split);
496
497 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
498 out.extend_from_slice(SSE_MAGIC_V1);
499 out.push(ALGO_AES_256_GCM);
500 out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
502 out.extend_from_slice(tag);
503 out.extend_from_slice(ct);
504 Bytes::from(out)
505}
506
507pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
512 let (key_id, key) = keyring.active();
513 let cipher = Aes256Gcm::new(key.as_aes_key());
514 let mut nonce_bytes = [0u8; NONCE_LEN];
515 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
516 let nonce = Nonce::from_slice(&nonce_bytes);
517 let aad = aad_v2(key_id);
518 let ct_with_tag = cipher
519 .encrypt(
520 nonce,
521 Payload {
522 msg: plaintext,
523 aad: &aad,
524 },
525 )
526 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
527 debug_assert!(ct_with_tag.len() >= TAG_LEN);
528 let split = ct_with_tag.len() - TAG_LEN;
529 let (ct, tag) = ct_with_tag.split_at(split);
530
531 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
532 out.extend_from_slice(SSE_MAGIC_V2);
533 out.push(ALGO_AES_256_GCM);
534 out.extend_from_slice(&key_id.to_be_bytes()); out.push(0u8); out.extend_from_slice(&nonce_bytes);
537 out.extend_from_slice(tag);
538 out.extend_from_slice(ct);
539 Bytes::from(out)
540}
541
542fn aad_v1() -> [u8; 8] {
543 let mut aad = [0u8; 8];
544 aad[..4].copy_from_slice(SSE_MAGIC_V1);
545 aad[4] = ALGO_AES_256_GCM;
546 aad
547}
548
549fn aad_v2(key_id: u16) -> [u8; 8] {
550 let mut aad = [0u8; 8];
551 aad[..4].copy_from_slice(SSE_MAGIC_V2);
552 aad[4] = ALGO_AES_256_GCM;
553 aad[5..7].copy_from_slice(&key_id.to_be_bytes());
554 aad[7] = 0u8;
555 aad
556}
557
558fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
564 let mut aad = [0u8; 4 + 1 + KEY_MD5_LEN];
565 aad[..4].copy_from_slice(SSE_MAGIC_V3);
566 aad[4] = ALGO_AES_256_GCM;
567 aad[5..5 + KEY_MD5_LEN].copy_from_slice(key_md5);
568 aad
569}
570
/// A validated customer-supplied (SSE-C) key and its MD5, as produced by
/// [`parse_customer_key_headers`].
#[derive(Clone)]
pub struct CustomerKeyMaterial {
    /// The 32-byte AES key.
    pub key: [u8; KEY_LEN],
    /// MD5 digest of `key`.
    pub key_md5: [u8; KEY_MD5_LEN],
}
581
582impl std::fmt::Debug for CustomerKeyMaterial {
583 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584 f.debug_struct("CustomerKeyMaterial")
587 .field("key", &"<redacted>")
588 .field("key_md5_hex", &hex_lower(&self.key_md5))
589 .finish()
590 }
591}
592
/// Renders `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_lower(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    let mut s = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Format straight into the preallocated buffer instead of building a
        // temporary String per byte. Writing to a String never fails.
        let _ = write!(s, "{b:02x}");
    }
    s
}
600
601#[derive(Debug, Clone, Copy)]
609pub enum SseSource<'a> {
610 Keyring(&'a SseKeyring),
613 CustomerKey {
617 key: &'a [u8; KEY_LEN],
618 key_md5: &'a [u8; KEY_MD5_LEN],
619 },
620 Kms {
626 dek: &'a [u8; KEY_LEN],
628 wrapped: &'a WrappedDek,
631 },
632}
633
634impl<'a> From<&'a SseKeyring> for SseSource<'a> {
641 fn from(kr: &'a SseKeyring) -> Self {
642 SseSource::Keyring(kr)
643 }
644}
645
646impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
650 fn from(kr: &'a Arc<SseKeyring>) -> Self {
651 SseSource::Keyring(kr.as_ref())
652 }
653}
654
655impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
656 fn from(m: &'a CustomerKeyMaterial) -> Self {
657 SseSource::CustomerKey {
658 key: &m.key,
659 key_md5: &m.key_md5,
660 }
661 }
662}
663
664pub fn parse_customer_key_headers(
676 algorithm: &str,
677 key_base64: &str,
678 key_md5_base64: &str,
679) -> Result<CustomerKeyMaterial, SseError> {
680 use base64::Engine as _;
681 if algorithm != SSE_C_ALGORITHM {
682 return Err(SseError::CustomerKeyAlgorithmUnsupported {
683 algo: algorithm.to_string(),
684 });
685 }
686 let key_bytes = base64::engine::general_purpose::STANDARD
687 .decode(key_base64.trim().as_bytes())
688 .map_err(|_| SseError::InvalidCustomerKey {
689 reason: "base64 decode of key",
690 })?;
691 if key_bytes.len() != KEY_LEN {
692 return Err(SseError::InvalidCustomerKey {
693 reason: "key length (must be 32 bytes after base64 decode)",
694 });
695 }
696 let supplied_md5 = base64::engine::general_purpose::STANDARD
697 .decode(key_md5_base64.trim().as_bytes())
698 .map_err(|_| SseError::InvalidCustomerKey {
699 reason: "base64 decode of key MD5",
700 })?;
701 if supplied_md5.len() != KEY_MD5_LEN {
702 return Err(SseError::InvalidCustomerKey {
703 reason: "key MD5 length (must be 16 bytes after base64 decode)",
704 });
705 }
706 let actual_md5 = compute_key_md5(&key_bytes);
707 if !constant_time_eq(&actual_md5, &supplied_md5) {
710 return Err(SseError::InvalidCustomerKey {
711 reason: "supplied MD5 does not match MD5 of supplied key",
712 });
713 }
714 let mut key = [0u8; KEY_LEN];
715 key.copy_from_slice(&key_bytes);
716 let mut key_md5 = [0u8; KEY_MD5_LEN];
717 key_md5.copy_from_slice(&actual_md5);
718 Ok(CustomerKeyMaterial { key, key_md5 })
719}
720
721pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
726 let mut h = Md5::new();
727 h.update(key);
728 let out = h.finalize();
729 let mut md5 = [0u8; KEY_MD5_LEN];
730 md5.copy_from_slice(&out);
731 md5
732}
733
/// Compares two byte slices without returning early at the first differing byte.
///
/// Slices of different length compare unequal immediately; for equal lengths every
/// byte pair is XOR-ed and OR-accumulated, and the slices are equal iff the
/// accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |diff, (x, y)| diff | (x ^ y))
            == 0
}
746
747pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
757 match source {
758 SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
759 SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
760 SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
761 }
762}
763
764fn encrypt_v3(plaintext: &[u8], key: &[u8; KEY_LEN], key_md5: &[u8; KEY_MD5_LEN]) -> Bytes {
765 let aes_key = Key::<Aes256Gcm>::from_slice(key);
766 let cipher = Aes256Gcm::new(aes_key);
767 let mut nonce_bytes = [0u8; NONCE_LEN];
768 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
769 let nonce = Nonce::from_slice(&nonce_bytes);
770 let aad = aad_v3(key_md5);
771 let ct_with_tag = cipher
772 .encrypt(
773 nonce,
774 Payload {
775 msg: plaintext,
776 aad: &aad,
777 },
778 )
779 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
780 debug_assert!(ct_with_tag.len() >= TAG_LEN);
781 let split = ct_with_tag.len() - TAG_LEN;
782 let (ct, tag) = ct_with_tag.split_at(split);
783
784 let mut out = Vec::with_capacity(SSE_HEADER_BYTES_V3 + ct.len());
785 out.extend_from_slice(SSE_MAGIC_V3);
786 out.push(ALGO_AES_256_GCM);
787 out.extend_from_slice(key_md5);
788 out.extend_from_slice(&nonce_bytes);
789 out.extend_from_slice(tag);
790 out.extend_from_slice(ct);
791 Bytes::from(out)
792}
793
794pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
813 let source = source.into();
814 if body.len() < SSE_HEADER_BYTES {
820 return Err(SseError::TooShort { got: body.len() });
821 }
822 let mut magic = [0u8; 4];
823 magic.copy_from_slice(&body[..4]);
824 match &magic {
825 m if m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 => {
826 let keyring = match source {
827 SseSource::Keyring(kr) => kr,
828 SseSource::CustomerKey { .. } => return Err(SseError::CustomerKeyUnexpected),
829 SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
835 };
836 if m == SSE_MAGIC_V1 {
837 decrypt_v1_with_keyring(body, keyring)
838 } else {
839 decrypt_v2_with_keyring(body, keyring)
840 }
841 }
842 m if m == SSE_MAGIC_V3 => {
843 if body.len() < SSE_HEADER_BYTES_V3 {
845 return Err(SseError::TooShort { got: body.len() });
846 }
847 let (key, key_md5) = match source {
848 SseSource::CustomerKey { key, key_md5 } => (key, key_md5),
849 SseSource::Keyring(_) => return Err(SseError::CustomerKeyRequired),
850 SseSource::Kms { .. } => return Err(SseError::CustomerKeyRequired),
851 };
852 decrypt_v3(body, key, key_md5)
853 }
854 m if m == SSE_MAGIC_V4 => {
855 Err(SseError::KmsAsyncRequired)
860 }
861 m if m == SSE_MAGIC_V5 || m == SSE_MAGIC_V6 => {
862 let keyring = match source {
870 SseSource::Keyring(kr) => kr,
871 SseSource::CustomerKey { .. } => {
872 return Err(SseError::CustomerKeyUnexpected);
873 }
874 SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
875 };
876 decrypt_chunked_buffered_default(body, keyring)
883 }
884 _ => Err(SseError::BadMagic { got: magic }),
885 }
886}
887
888fn decrypt_v3(
889 body: &[u8],
890 key: &[u8; KEY_LEN],
891 supplied_md5: &[u8; KEY_MD5_LEN],
892) -> Result<Bytes, SseError> {
893 let algo = body[4];
894 if algo != ALGO_AES_256_GCM {
895 return Err(SseError::UnsupportedAlgo { tag: algo });
896 }
897 let mut stored_md5 = [0u8; KEY_MD5_LEN];
898 stored_md5.copy_from_slice(&body[5..5 + KEY_MD5_LEN]);
899 if !constant_time_eq(supplied_md5, &stored_md5) {
905 return Err(SseError::WrongCustomerKey);
906 }
907 let nonce_off = 5 + KEY_MD5_LEN;
908 let tag_off = nonce_off + NONCE_LEN;
909 let mut nonce_bytes = [0u8; NONCE_LEN];
910 nonce_bytes.copy_from_slice(&body[nonce_off..nonce_off + NONCE_LEN]);
911 let mut tag_bytes = [0u8; TAG_LEN];
912 tag_bytes.copy_from_slice(&body[tag_off..tag_off + TAG_LEN]);
913 let ct = &body[SSE_HEADER_BYTES_V3..];
914
915 let aad = aad_v3(&stored_md5);
916 let nonce = Nonce::from_slice(&nonce_bytes);
917 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
918 ct_with_tag.extend_from_slice(ct);
919 ct_with_tag.extend_from_slice(&tag_bytes);
920
921 let aes_key = Key::<Aes256Gcm>::from_slice(key);
922 let cipher = Aes256Gcm::new(aes_key);
923 let plain = cipher
924 .decrypt(
925 nonce,
926 Payload {
927 msg: &ct_with_tag,
928 aad: &aad,
929 },
930 )
931 .map_err(|_| SseError::DecryptFailed)?;
932 Ok(Bytes::from(plain))
933}
934
935fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
946 let mut aad = Vec::with_capacity(4 + 1 + 1 + key_id.len() + 4 + wrapped_dek.len());
947 aad.extend_from_slice(SSE_MAGIC_V4);
948 aad.push(ALGO_AES_256_GCM);
949 aad.push(key_id.len() as u8);
950 aad.extend_from_slice(key_id);
951 aad.extend_from_slice(&(wrapped_dek.len() as u32).to_be_bytes());
952 aad.extend_from_slice(wrapped_dek);
953 aad
954}
955
956fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
957 assert!(
965 !wrapped.key_id.is_empty() && wrapped.key_id.len() <= u8::MAX as usize,
966 "S4E4 key_id must be 1..=255 bytes (got {})",
967 wrapped.key_id.len()
968 );
969 assert!(
970 wrapped.ciphertext.len() <= u32::MAX as usize,
971 "S4E4 wrapped_dek longer than u32::MAX",
972 );
973
974 let aes_key = Key::<Aes256Gcm>::from_slice(dek);
975 let cipher = Aes256Gcm::new(aes_key);
976 let mut nonce_bytes = [0u8; NONCE_LEN];
977 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
978 let nonce = Nonce::from_slice(&nonce_bytes);
979 let aad = aad_v4(wrapped.key_id.as_bytes(), &wrapped.ciphertext);
980 let ct_with_tag = cipher
981 .encrypt(
982 nonce,
983 Payload {
984 msg: plaintext,
985 aad: &aad,
986 },
987 )
988 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
989 debug_assert!(ct_with_tag.len() >= TAG_LEN);
990 let split = ct_with_tag.len() - TAG_LEN;
991 let (ct, tag) = ct_with_tag.split_at(split);
992
993 let key_id_bytes = wrapped.key_id.as_bytes();
994 let mut out = Vec::with_capacity(
995 4 + 1
996 + 1
997 + key_id_bytes.len()
998 + 4
999 + wrapped.ciphertext.len()
1000 + NONCE_LEN
1001 + TAG_LEN
1002 + ct.len(),
1003 );
1004 out.extend_from_slice(SSE_MAGIC_V4);
1005 out.push(ALGO_AES_256_GCM);
1006 out.push(key_id_bytes.len() as u8);
1007 out.extend_from_slice(key_id_bytes);
1008 out.extend_from_slice(&(wrapped.ciphertext.len() as u32).to_be_bytes());
1009 out.extend_from_slice(&wrapped.ciphertext);
1010 out.extend_from_slice(&nonce_bytes);
1011 out.extend_from_slice(tag);
1012 out.extend_from_slice(ct);
1013 Bytes::from(out)
1014}
1015
/// Borrowed view of the fields of an S4E4 frame, as returned by [`parse_s4e4_header`].
/// Every field is a slice into the original body.
#[derive(Debug)]
pub struct S4E4Header<'a> {
    /// Key id from the header, validated as UTF-8.
    pub key_id: &'a str,
    /// The wrapped data-encryption key bytes.
    pub wrapped_dek: &'a [u8],
    /// The 12-byte AES-GCM nonce.
    pub nonce: &'a [u8],
    /// The 16-byte AES-GCM tag.
    pub tag: &'a [u8],
    /// Everything after the tag.
    pub ciphertext: &'a [u8],
}
1029
1030pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
1034 const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN; if body.len() < S4E4_MIN {
1041 return Err(SseError::KmsFrameTooShort {
1042 got: body.len(),
1043 min: S4E4_MIN,
1044 });
1045 }
1046 let magic = &body[..4];
1047 if magic != SSE_MAGIC_V4 {
1048 let mut got = [0u8; 4];
1049 got.copy_from_slice(magic);
1050 return Err(SseError::BadMagic { got });
1051 }
1052 let algo = body[4];
1053 if algo != ALGO_AES_256_GCM {
1054 return Err(SseError::UnsupportedAlgo { tag: algo });
1055 }
1056 let key_id_len = body[5] as usize;
1057 let key_id_off: usize = 6;
1058 let key_id_end = key_id_off
1059 .checked_add(key_id_len)
1060 .ok_or(SseError::KmsFrameFieldOob { what: "key_id_len" })?;
1061 if key_id_end + 4 > body.len() {
1062 return Err(SseError::KmsFrameFieldOob { what: "key_id" });
1063 }
1064 let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
1065 .map_err(|_| SseError::KmsKeyIdNotUtf8)?;
1066 let wrapped_len_off = key_id_end;
1067 let wrapped_dek_len = u32::from_be_bytes([
1068 body[wrapped_len_off],
1069 body[wrapped_len_off + 1],
1070 body[wrapped_len_off + 2],
1071 body[wrapped_len_off + 3],
1072 ]) as usize;
1073 let wrapped_off = wrapped_len_off + 4;
1074 let wrapped_end =
1075 wrapped_off
1076 .checked_add(wrapped_dek_len)
1077 .ok_or(SseError::KmsFrameFieldOob {
1078 what: "wrapped_dek_len",
1079 })?;
1080 if wrapped_end + NONCE_LEN + TAG_LEN > body.len() {
1081 return Err(SseError::KmsFrameFieldOob {
1082 what: "wrapped_dek",
1083 });
1084 }
1085 let wrapped_dek = &body[wrapped_off..wrapped_end];
1086 let nonce_off = wrapped_end;
1087 let tag_off = nonce_off + NONCE_LEN;
1088 let ct_off = tag_off + TAG_LEN;
1089 let nonce = &body[nonce_off..nonce_off + NONCE_LEN];
1090 let tag = &body[tag_off..tag_off + TAG_LEN];
1091 let ciphertext = &body[ct_off..];
1092 Ok(S4E4Header {
1093 key_id,
1094 wrapped_dek,
1095 nonce,
1096 tag,
1097 ciphertext,
1098 })
1099}
1100
1101pub async fn decrypt_with_kms(body: &[u8], kms: &dyn KmsBackend) -> Result<Bytes, SseError> {
1117 let hdr = parse_s4e4_header(body)?;
1118 let wrapped = WrappedDek {
1119 key_id: hdr.key_id.to_string(),
1120 ciphertext: hdr.wrapped_dek.to_vec(),
1121 };
1122 let dek_vec = kms.decrypt_dek(&wrapped).await?;
1123 if dek_vec.len() != KEY_LEN {
1124 return Err(SseError::KmsBackend(KmsError::BackendUnavailable {
1129 message: format!(
1130 "KMS returned {} byte DEK; expected {KEY_LEN}",
1131 dek_vec.len()
1132 ),
1133 }));
1134 }
1135 let mut dek = [0u8; KEY_LEN];
1136 dek.copy_from_slice(&dek_vec);
1137
1138 let aad = aad_v4(hdr.key_id.as_bytes(), hdr.wrapped_dek);
1139 let aes_key = Key::<Aes256Gcm>::from_slice(&dek);
1140 let cipher = Aes256Gcm::new(aes_key);
1141 let nonce = Nonce::from_slice(hdr.nonce);
1142 let mut ct_with_tag = Vec::with_capacity(hdr.ciphertext.len() + TAG_LEN);
1143 ct_with_tag.extend_from_slice(hdr.ciphertext);
1144 ct_with_tag.extend_from_slice(hdr.tag);
1145 let plain = cipher
1146 .decrypt(
1147 nonce,
1148 Payload {
1149 msg: &ct_with_tag,
1150 aad: &aad,
1151 },
1152 )
1153 .map_err(|_| SseError::DecryptFailed)?;
1154 Ok(Bytes::from(plain))
1155}
1156
1157fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1158 let algo = body[4];
1159 if algo != ALGO_AES_256_GCM {
1160 return Err(SseError::UnsupportedAlgo { tag: algo });
1161 }
1162 let mut nonce_bytes = [0u8; NONCE_LEN];
1165 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1166 let mut tag_bytes = [0u8; TAG_LEN];
1167 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1168 let ct = &body[SSE_HEADER_BYTES..];
1169
1170 let aad = aad_v1();
1171 let nonce = Nonce::from_slice(&nonce_bytes);
1172 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1173 ct_with_tag.extend_from_slice(ct);
1174 ct_with_tag.extend_from_slice(&tag_bytes);
1175
1176 let (active_id, _active_key) = keyring.active();
1180 let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
1181 ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
1182 for id in ids {
1183 let key = keyring.get(id).expect("id came from keyring iteration");
1184 let cipher = Aes256Gcm::new(key.as_aes_key());
1185 if let Ok(plain) = cipher.decrypt(
1186 nonce,
1187 Payload {
1188 msg: &ct_with_tag,
1189 aad: &aad,
1190 },
1191 ) {
1192 return Ok(Bytes::from(plain));
1193 }
1194 }
1195 Err(SseError::DecryptFailed)
1196}
1197
1198fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1199 let algo = body[4];
1200 if algo != ALGO_AES_256_GCM {
1201 return Err(SseError::UnsupportedAlgo { tag: algo });
1202 }
1203 let key_id = u16::from_be_bytes([body[5], body[6]]);
1204 let key = keyring
1206 .get(key_id)
1207 .ok_or(SseError::KeyNotInKeyring { id: key_id })?;
1208 let mut nonce_bytes = [0u8; NONCE_LEN];
1209 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1210 let mut tag_bytes = [0u8; TAG_LEN];
1211 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1212 let ct = &body[SSE_HEADER_BYTES..];
1213
1214 let aad = aad_v2(key_id);
1215 let nonce = Nonce::from_slice(&nonce_bytes);
1216 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1217 ct_with_tag.extend_from_slice(ct);
1218 ct_with_tag.extend_from_slice(&tag_bytes);
1219 let cipher = Aes256Gcm::new(key.as_aes_key());
1220 let plain = cipher
1221 .decrypt(
1222 nonce,
1223 Payload {
1224 msg: &ct_with_tag,
1225 aad: &aad,
1226 },
1227 )
1228 .map_err(|_| SseError::DecryptFailed)?;
1229 Ok(Bytes::from(plain))
1230}
1231
1232pub fn looks_encrypted(body: &[u8]) -> bool {
1243 if body.len() < SSE_HEADER_BYTES {
1244 return false;
1245 }
1246 let m = &body[..4];
1247 m == SSE_MAGIC_V1
1248 || m == SSE_MAGIC_V2
1249 || m == SSE_MAGIC_V3
1250 || m == SSE_MAGIC_V4
1251 || m == SSE_MAGIC_V5
1252 || m == SSE_MAGIC_V6
1253}
1254
1255pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
1266 if body.len() < SSE_HEADER_BYTES {
1267 return None;
1268 }
1269 match &body[..4] {
1270 m if m == SSE_MAGIC_V1 => Some("S4E1"),
1271 m if m == SSE_MAGIC_V2 => Some("S4E2"),
1272 m if m == SSE_MAGIC_V3 => Some("S4E3"),
1273 m if m == SSE_MAGIC_V4 => Some("S4E4"),
1274 m if m == SSE_MAGIC_V5 => Some("S4E5"),
1279 m if m == SSE_MAGIC_V6 => Some("S4E6"),
1281 _ => None,
1282 }
1283}
1284
/// A reference-counted, shareable [`SseKey`].
pub type SharedSseKey = Arc<SseKey>;
1286
/// S4E5 header size: magic (4) + algo (1) + key id (2) + 1 byte + chunk size (4)
/// + chunk count (4) + salt (4).
pub const S4E5_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 4;
/// Per-chunk overhead of an S4E5 frame: one AES-GCM tag.
pub const S4E5_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
/// S4E6 header size: same as S4E5 but with an 8-byte salt.
pub const S4E6_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 8;
/// Per-chunk overhead of an S4E6 frame: one AES-GCM tag.
pub const S4E6_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
/// Largest chunk count an S4E6 frame may declare (2^24 - 1); the S4E6 nonce keeps
/// only the low three bytes of the chunk index.
pub const S4E6_MAX_CHUNK_COUNT: u32 = (1u32 << 24) - 1;
/// First four bytes of every S4E5 per-chunk nonce.
const S4E5_NONCE_TAG: [u8; 4] = [b'E', b'5', 0, 0];

/// First byte of every S4E6 per-chunk nonce.
const S4E6_NONCE_PREFIX: u8 = b'E';
1415
/// Which of the two chunked frame formats a body uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChunkedVariant {
    /// S4E5 (4-byte salt).
    V5,
    /// S4E6 (8-byte salt).
    V6,
}
1425
1426impl ChunkedVariant {
1427 fn header_bytes(self) -> usize {
1428 match self {
1429 ChunkedVariant::V5 => S4E5_HEADER_BYTES,
1430 ChunkedVariant::V6 => S4E6_HEADER_BYTES,
1431 }
1432 }
1433}
1434
1435fn aad_v5(
1440 chunk_index: u32,
1441 total_chunks: u32,
1442 key_id: u16,
1443 salt: &[u8; 4],
1444) -> [u8; 4 + 1 + 4 + 4 + 2 + 4] {
1445 let mut aad = [0u8; 4 + 1 + 4 + 4 + 2 + 4]; aad[..4].copy_from_slice(SSE_MAGIC_V5);
1447 aad[4] = ALGO_AES_256_GCM;
1448 aad[5..9].copy_from_slice(&chunk_index.to_be_bytes());
1449 aad[9..13].copy_from_slice(&total_chunks.to_be_bytes());
1450 aad[13..15].copy_from_slice(&key_id.to_be_bytes());
1451 aad[15..19].copy_from_slice(salt);
1452 aad
1453}
1454
1455fn aad_v6(
1461 chunk_index: u32,
1462 total_chunks: u32,
1463 key_id: u16,
1464 salt: &[u8; 8],
1465) -> [u8; 4 + 1 + 4 + 4 + 2 + 8] {
1466 let mut aad = [0u8; 4 + 1 + 4 + 4 + 2 + 8]; aad[..4].copy_from_slice(SSE_MAGIC_V6);
1468 aad[4] = ALGO_AES_256_GCM;
1469 aad[5..9].copy_from_slice(&chunk_index.to_be_bytes());
1470 aad[9..13].copy_from_slice(&total_chunks.to_be_bytes());
1471 aad[13..15].copy_from_slice(&key_id.to_be_bytes());
1472 aad[15..23].copy_from_slice(salt);
1473 aad
1474}
1475
1476fn nonce_v5(salt: &[u8; 4], chunk_index: u32) -> [u8; NONCE_LEN] {
1482 let mut n = [0u8; NONCE_LEN];
1483 n[..4].copy_from_slice(&S4E5_NONCE_TAG);
1484 n[4..8].copy_from_slice(salt);
1485 n[8..12].copy_from_slice(&chunk_index.to_be_bytes());
1486 n
1487}
1488
1489fn nonce_v6(salt: &[u8; 8], chunk_index: u32) -> [u8; NONCE_LEN] {
1497 debug_assert!(
1498 chunk_index <= S4E6_MAX_CHUNK_COUNT,
1499 "S4E6 chunk_index {chunk_index} exceeds 24-bit cap (caller MUST validate)",
1500 );
1501 let mut n = [0u8; NONCE_LEN];
1502 n[0] = S4E6_NONCE_PREFIX;
1503 n[1..9].copy_from_slice(salt);
1504 let be = chunk_index.to_be_bytes(); n[9..12].copy_from_slice(&be[1..4]);
1508 n
1509}
1510
1511pub fn encrypt_v2_chunked(
1534 plaintext: &[u8],
1535 keyring: &SseKeyring,
1536 chunk_size: usize,
1537) -> Result<Bytes, SseError> {
1538 if chunk_size == 0 {
1539 return Err(SseError::ChunkSizeInvalid);
1540 }
1541 let (key_id, key) = keyring.active();
1542 let cipher = Aes256Gcm::new(key.as_aes_key());
1543 let mut salt = [0u8; 8];
1544 rand::rngs::OsRng.fill_bytes(&mut salt);
1545
1546 let chunk_count_usize = if plaintext.is_empty() {
1549 1
1550 } else {
1551 plaintext.len().div_ceil(chunk_size)
1552 };
1553 let chunk_count: u32 = u32::try_from(chunk_count_usize).unwrap_or(u32::MAX);
1557 if chunk_count > S4E6_MAX_CHUNK_COUNT {
1558 return Err(SseError::ChunkCountTooLarge {
1559 got: chunk_count,
1560 max: S4E6_MAX_CHUNK_COUNT,
1561 });
1562 }
1563
1564 let mut out = Vec::with_capacity(
1565 S4E6_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E6_PER_CHUNK_OVERHEAD),
1566 );
1567 out.extend_from_slice(SSE_MAGIC_V6);
1568 out.push(ALGO_AES_256_GCM);
1569 out.extend_from_slice(&key_id.to_be_bytes());
1570 out.push(0u8); out.extend_from_slice(&(chunk_size as u32).to_be_bytes());
1572 out.extend_from_slice(&chunk_count.to_be_bytes());
1573 out.extend_from_slice(&salt);
1574
1575 for i in 0..chunk_count {
1576 let off = (i as usize).saturating_mul(chunk_size);
1577 let end = off.saturating_add(chunk_size).min(plaintext.len());
1578 let chunk_pt: &[u8] = if off >= plaintext.len() {
1579 &[]
1582 } else {
1583 &plaintext[off..end]
1584 };
1585 let nonce_bytes = nonce_v6(&salt, i);
1586 let nonce = Nonce::from_slice(&nonce_bytes);
1587 let aad = aad_v6(i, chunk_count, key_id, &salt);
1588 let ct_with_tag = cipher
1589 .encrypt(
1590 nonce,
1591 Payload {
1592 msg: chunk_pt,
1593 aad: &aad,
1594 },
1595 )
1596 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
1597 debug_assert!(ct_with_tag.len() >= TAG_LEN);
1598 let split = ct_with_tag.len() - TAG_LEN;
1599 let (ct, tag) = ct_with_tag.split_at(split);
1600 out.extend_from_slice(tag);
1601 out.extend_from_slice(ct);
1602 crate::metrics::record_sse_streaming_chunk("encrypt");
1603 }
1604 Ok(Bytes::from(out))
1605}
1606
/// The salt read from a chunked frame header; its width depends on the frame variant.
#[derive(Debug, Clone, Copy)]
enum ChunkedSalt {
    /// 4-byte salt of an S4E5 frame.
    V5([u8; 4]),
    /// 8-byte salt of an S4E6 frame.
    V6([u8; 8]),
}
1615
/// Decoded header of a chunked (S4E5 or S4E6) frame.
#[derive(Debug, Clone, Copy)]
struct ChunkedHeader {
    /// Which chunked format the frame uses.
    #[allow(dead_code)]
    variant: ChunkedVariant,
    /// Key id from header bytes 5..7 (big-endian).
    key_id: u16,
    /// Declared chunk size from header bytes 8..12 (big-endian).
    chunk_size: u32,
    /// Declared chunk count from header bytes 12..16 (big-endian).
    chunk_count: u32,
    /// Salt from the header, sized per variant.
    salt: ChunkedSalt,
    /// NOTE(review): presumably the byte offset at which chunk data starts in the
    /// body — the assignment is outside the visible part of this file; confirm.
    chunks_offset: usize,
}
1638
/// Borrowed view of an S4E6 frame header, as returned by [`parse_s4e6_header`].
#[derive(Debug, Clone, Copy)]
pub struct S4E6Header<'a> {
    /// Key id from header bytes 5..7 (big-endian).
    pub key_id: u16,
    /// Declared chunk size (non-zero).
    pub chunk_size: u32,
    /// Declared chunk count (between 1 and `S4E6_MAX_CHUNK_COUNT`).
    pub chunk_count: u32,
    /// The 8-byte salt, borrowed from the parsed blob.
    pub salt: &'a [u8; 8],
}
1652
1653pub fn parse_s4e6_header(blob: &[u8]) -> Result<S4E6Header<'_>, SseError> {
1657 if blob.len() < S4E6_HEADER_BYTES {
1658 return Err(SseError::ChunkFrameTruncated { what: "header" });
1659 }
1660 if &blob[..4] != SSE_MAGIC_V6 {
1661 let mut got = [0u8; 4];
1662 got.copy_from_slice(&blob[..4]);
1663 return Err(SseError::BadMagic { got });
1664 }
1665 let algo = blob[4];
1666 if algo != ALGO_AES_256_GCM {
1667 return Err(SseError::UnsupportedAlgo { tag: algo });
1668 }
1669 let key_id = u16::from_be_bytes([blob[5], blob[6]]);
1670 let chunk_size = u32::from_be_bytes([blob[8], blob[9], blob[10], blob[11]]);
1672 let chunk_count = u32::from_be_bytes([blob[12], blob[13], blob[14], blob[15]]);
1673 if chunk_size == 0 {
1674 return Err(SseError::ChunkSizeInvalid);
1675 }
1676 if chunk_count == 0 {
1677 return Err(SseError::ChunkFrameTruncated {
1678 what: "chunk_count == 0",
1679 });
1680 }
1681 if chunk_count > S4E6_MAX_CHUNK_COUNT {
1682 return Err(SseError::ChunkCountTooLarge {
1683 got: chunk_count,
1684 max: S4E6_MAX_CHUNK_COUNT,
1685 });
1686 }
1687 let salt: &[u8; 8] = (&blob[16..24]).try_into().expect("8B salt slice");
1688 Ok(S4E6Header {
1689 key_id,
1690 chunk_size,
1691 chunk_count,
1692 salt,
1693 })
1694}
1695
1696fn parse_chunked_header(body: &[u8], max_body_bytes: usize) -> Result<ChunkedHeader, SseError> {
1697 if body.len() < 4 {
1698 return Err(SseError::ChunkFrameTruncated { what: "magic" });
1699 }
1700 let magic = &body[..4];
1701 let variant = if magic == SSE_MAGIC_V5 {
1702 ChunkedVariant::V5
1703 } else if magic == SSE_MAGIC_V6 {
1704 ChunkedVariant::V6
1705 } else {
1706 let mut got = [0u8; 4];
1707 got.copy_from_slice(magic);
1708 return Err(SseError::BadMagic { got });
1709 };
1710 let header_bytes = variant.header_bytes();
1711 if body.len() < header_bytes {
1712 return Err(SseError::ChunkFrameTruncated { what: "header" });
1713 }
1714 let algo = body[4];
1715 if algo != ALGO_AES_256_GCM {
1716 return Err(SseError::UnsupportedAlgo { tag: algo });
1717 }
1718 let key_id = u16::from_be_bytes([body[5], body[6]]);
1719 let chunk_size = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
1721 let chunk_count = u32::from_be_bytes([body[12], body[13], body[14], body[15]]);
1722 if chunk_size == 0 {
1723 return Err(SseError::ChunkSizeInvalid);
1724 }
1725 if chunk_count == 0 {
1726 return Err(SseError::ChunkFrameTruncated {
1727 what: "chunk_count == 0",
1728 });
1729 }
1730 let salt = match variant {
1731 ChunkedVariant::V5 => {
1732 let mut s = [0u8; 4];
1733 s.copy_from_slice(&body[16..20]);
1734 ChunkedSalt::V5(s)
1735 }
1736 ChunkedVariant::V6 => {
1737 if chunk_count > S4E6_MAX_CHUNK_COUNT {
1742 return Err(SseError::ChunkCountTooLarge {
1743 got: chunk_count,
1744 max: S4E6_MAX_CHUNK_COUNT,
1745 });
1746 }
1747 let mut s = [0u8; 8];
1748 s.copy_from_slice(&body[16..24]);
1749 ChunkedSalt::V6(s)
1750 }
1751 };
1752
1753 let chunk_size_u64 = chunk_size as u64;
1781 let chunk_count_u64 = chunk_count as u64;
1782 let expected_plain_size =
1783 chunk_size_u64
1784 .checked_mul(chunk_count_u64)
1785 .ok_or(SseError::ChunkFrameTooLarge {
1786 details: "chunk_size * chunk_count overflows u64",
1787 })?;
1788 let per_chunk_overhead = S4E5_PER_CHUNK_OVERHEAD as u64; let total_tag_overhead =
1790 per_chunk_overhead
1791 .checked_mul(chunk_count_u64)
1792 .ok_or(SseError::ChunkFrameTooLarge {
1793 details: "tag_len * chunk_count overflows u64",
1794 })?;
1795 let max_total = expected_plain_size
1796 .checked_add(total_tag_overhead)
1797 .and_then(|t| t.checked_add(header_bytes as u64))
1798 .ok_or(SseError::ChunkFrameTooLarge {
1799 details: "header + plaintext + tag overhead overflows u64",
1800 })?;
1801 if (body.len() as u64) > max_total {
1810 return Err(SseError::ChunkFrameTruncated {
1811 what: "trailing bytes past declared chunk geometry",
1812 });
1813 }
1814 if expected_plain_size > max_body_bytes as u64 {
1819 return Err(SseError::ChunkFrameTooLarge {
1820 details: "declared plaintext exceeds gateway max_body_bytes",
1821 });
1822 }
1823
1824 Ok(ChunkedHeader {
1825 variant,
1826 key_id,
1827 chunk_size,
1828 chunk_count,
1829 salt,
1830 chunks_offset: header_bytes,
1831 })
1832}
1833
1834fn decrypt_chunked_chunk(
1838 cipher: &Aes256Gcm,
1839 chunk_index: u32,
1840 chunk_count: u32,
1841 key_id: u16,
1842 salt: &ChunkedSalt,
1843 tag: &[u8; TAG_LEN],
1844 ct: &[u8],
1845) -> Result<Bytes, SseError> {
1846 let nonce_bytes = match salt {
1847 ChunkedSalt::V5(s) => nonce_v5(s, chunk_index),
1848 ChunkedSalt::V6(s) => nonce_v6(s, chunk_index),
1849 };
1850 let nonce = Nonce::from_slice(&nonce_bytes);
1851 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1852 ct_with_tag.extend_from_slice(ct);
1853 ct_with_tag.extend_from_slice(tag);
1854 let result = match salt {
1855 ChunkedSalt::V5(s) => {
1856 let aad = aad_v5(chunk_index, chunk_count, key_id, s);
1857 cipher.decrypt(
1858 nonce,
1859 Payload {
1860 msg: &ct_with_tag,
1861 aad: &aad,
1862 },
1863 )
1864 }
1865 ChunkedSalt::V6(s) => {
1866 let aad = aad_v6(chunk_index, chunk_count, key_id, s);
1867 cipher.decrypt(
1868 nonce,
1869 Payload {
1870 msg: &ct_with_tag,
1871 aad: &aad,
1872 },
1873 )
1874 }
1875 };
1876 result
1877 .map(Bytes::from)
1878 .map_err(|_| SseError::ChunkAuthFailed { chunk_index })
1879}
1880
1881fn walk_chunked<F: FnMut(Bytes) -> Result<(), SseError>>(
1887 body: &[u8],
1888 keyring: &SseKeyring,
1889 max_body_bytes: usize,
1890 mut emit: F,
1891) -> Result<(), SseError> {
1892 let hdr = parse_chunked_header(body, max_body_bytes)?;
1893 let key = keyring
1894 .get(hdr.key_id)
1895 .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
1896 let cipher = Aes256Gcm::new(key.as_aes_key());
1897
1898 let mut cursor = hdr.chunks_offset;
1899 let chunk_size = hdr.chunk_size as usize;
1900 for i in 0..hdr.chunk_count {
1901 if cursor + TAG_LEN > body.len() {
1902 return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
1903 }
1904 let tag_off = cursor;
1905 let ct_off = tag_off + TAG_LEN;
1906 let is_last = i + 1 == hdr.chunk_count;
1907 let ct_len = if is_last {
1908 if ct_off > body.len() {
1909 return Err(SseError::ChunkFrameTruncated {
1910 what: "final chunk ciphertext",
1911 });
1912 }
1913 let remaining = body.len() - ct_off;
1914 if remaining > chunk_size {
1915 return Err(SseError::ChunkFrameTruncated {
1916 what: "trailing bytes after final chunk",
1917 });
1918 }
1919 remaining
1920 } else {
1921 chunk_size
1922 };
1923 let ct_end = ct_off + ct_len;
1924 if ct_end > body.len() {
1925 return Err(SseError::ChunkFrameTruncated {
1926 what: "chunk ciphertext",
1927 });
1928 }
1929 let mut tag = [0u8; TAG_LEN];
1930 tag.copy_from_slice(&body[tag_off..ct_off]);
1931 let ct = &body[ct_off..ct_end];
1932 let plain =
1933 decrypt_chunked_chunk(&cipher, i, hdr.chunk_count, hdr.key_id, &hdr.salt, &tag, ct)?;
1934 crate::metrics::record_sse_streaming_chunk("decrypt");
1935 emit(plain)?;
1936 cursor = ct_end;
1937 }
1938 if cursor != body.len() {
1939 return Err(SseError::ChunkFrameTruncated {
1940 what: "trailing bytes after declared chunk_count",
1941 });
1942 }
1943 Ok(())
1944}
1945
1946pub fn decrypt_chunked_buffered(
1959 body: &[u8],
1960 keyring: &SseKeyring,
1961 max_body_bytes: usize,
1962) -> Result<Bytes, SseError> {
1963 let hdr = parse_chunked_header(body, max_body_bytes)?;
1964 let mut out = Vec::with_capacity(hdr.chunk_size as usize * hdr.chunk_count as usize);
1970 walk_chunked(body, keyring, max_body_bytes, |chunk| {
1971 out.extend_from_slice(&chunk);
1972 Ok(())
1973 })?;
1974 Ok(Bytes::from(out))
1975}
1976
/// Convenience wrapper around [`decrypt_chunked_buffered`] that uses
/// `DEFAULT_MAX_BODY_BYTES` as the declared-plaintext cap.
///
/// # Errors
///
/// Same as [`decrypt_chunked_buffered`].
pub fn decrypt_chunked_buffered_default(
    body: &[u8],
    keyring: &SseKeyring,
) -> Result<Bytes, SseError> {
    decrypt_chunked_buffered(body, keyring, DEFAULT_MAX_BODY_BYTES)
}
1989
1990pub fn decrypt_chunked_stream(
2015 body: bytes::Bytes,
2016 keyring: &SseKeyring,
2017) -> impl futures::Stream<Item = Result<Bytes, SseError>> + 'static {
2018 use futures::stream::{self, StreamExt};
2019
2020 let prelude = (|| {
2027 let hdr = parse_chunked_header(&body, usize::MAX)?;
2037 let key = keyring
2038 .get(hdr.key_id)
2039 .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
2040 let cipher = Aes256Gcm::new(key.as_aes_key());
2041 Ok::<_, SseError>((hdr, cipher))
2042 })();
2043
2044 match prelude {
2045 Err(e) => stream::iter(std::iter::once(Err(e))).left_stream(),
2046 Ok((hdr, cipher)) => {
2047 let chunks_offset = hdr.chunks_offset;
2048 let state = ChunkedDecryptState {
2049 body,
2050 cipher,
2051 hdr,
2052 cursor: chunks_offset,
2053 next_index: 0,
2054 };
2055 stream::try_unfold(state, decrypt_next_chunk).right_stream()
2056 }
2057 }
2058}
2059
/// State threaded through `decrypt_next_chunk` by the streaming decryptor.
struct ChunkedDecryptState {
    /// The complete encrypted frame, header included.
    body: bytes::Bytes,
    /// Cipher already keyed with the key selected by the frame header.
    cipher: Aes256Gcm,
    /// The parsed frame header.
    hdr: ChunkedHeader,
    /// Offset into `body` of the next chunk's tag; starts at
    /// `hdr.chunks_offset`.
    cursor: usize,
    /// Zero-based index of the next chunk to decrypt.
    next_index: u32,
}
2070
2071async fn decrypt_next_chunk(
2072 mut state: ChunkedDecryptState,
2073) -> Result<Option<(Bytes, ChunkedDecryptState)>, SseError> {
2074 if state.next_index >= state.hdr.chunk_count {
2075 if state.cursor != state.body.len() {
2078 return Err(SseError::ChunkFrameTruncated {
2079 what: "trailing bytes after declared chunk_count",
2080 });
2081 }
2082 return Ok(None);
2083 }
2084 let i = state.next_index;
2085 let chunk_size = state.hdr.chunk_size as usize;
2086 if state.cursor + TAG_LEN > state.body.len() {
2087 return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
2088 }
2089 let tag_off = state.cursor;
2090 let ct_off = tag_off + TAG_LEN;
2091 let is_last = i + 1 == state.hdr.chunk_count;
2092 let ct_len = if is_last {
2093 if ct_off > state.body.len() {
2094 return Err(SseError::ChunkFrameTruncated {
2095 what: "final chunk ciphertext",
2096 });
2097 }
2098 let remaining = state.body.len() - ct_off;
2099 if remaining > chunk_size {
2100 return Err(SseError::ChunkFrameTruncated {
2101 what: "trailing bytes after final chunk",
2102 });
2103 }
2104 remaining
2105 } else {
2106 chunk_size
2107 };
2108 let ct_end = ct_off + ct_len;
2109 if ct_end > state.body.len() {
2110 return Err(SseError::ChunkFrameTruncated {
2111 what: "chunk ciphertext",
2112 });
2113 }
2114 let mut tag = [0u8; TAG_LEN];
2115 tag.copy_from_slice(&state.body[tag_off..ct_off]);
2116 let ct = &state.body[ct_off..ct_end];
2117 let plain = decrypt_chunked_chunk(
2118 &state.cipher,
2119 i,
2120 state.hdr.chunk_count,
2121 state.hdr.key_id,
2122 &state.hdr.salt,
2123 &tag,
2124 ct,
2125 )?;
2126 crate::metrics::record_sse_streaming_chunk("decrypt");
2127 state.cursor = ct_end;
2128 state.next_index += 1;
2129 Ok(Some((plain, state)))
2130}
2131
/// Test-only encoder that produces an `S4E5` chunked frame under the
/// keyring's active key.
///
/// Frame layout: magic, algorithm tag, big-endian key id, one zero byte,
/// big-endian `chunk_size`, big-endian `chunk_count`, a random 4-byte salt,
/// then for every chunk its `TAG_LEN`-byte tag followed by its ciphertext.
/// Empty plaintext is encoded as a single empty chunk.
///
/// # Errors
///
/// Returns `SseError::ChunkSizeInvalid` when `chunk_size` is zero or does
/// not fit the header's 32-bit field. Without the second check the header
/// would carry a truncated size that disagrees with the chunk boundaries
/// actually written.
///
/// # Panics
///
/// Panics if the number of chunks does not fit in a `u32`.
#[cfg(test)]
fn encrypt_v2_chunked_s4e5_for_test(
    plaintext: &[u8],
    keyring: &SseKeyring,
    chunk_size: usize,
) -> Result<Bytes, SseError> {
    let chunk_size_field = match u32::try_from(chunk_size) {
        Ok(size) if size != 0 => size,
        _ => return Err(SseError::ChunkSizeInvalid),
    };
    let (key_id, key) = keyring.active();
    let cipher = Aes256Gcm::new(key.as_aes_key());
    let mut salt = [0u8; 4];
    rand::rngs::OsRng.fill_bytes(&mut salt);

    let chunk_count: u32 = if plaintext.is_empty() {
        // An empty object is still framed as one (empty) authenticated chunk.
        1
    } else {
        plaintext
            .len()
            .div_ceil(chunk_size)
            .try_into()
            .expect("chunk_count overflows u32")
    };

    let mut out = Vec::with_capacity(
        S4E5_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E5_PER_CHUNK_OVERHEAD),
    );
    out.extend_from_slice(SSE_MAGIC_V5);
    out.push(ALGO_AES_256_GCM);
    out.extend_from_slice(&key_id.to_be_bytes());
    out.push(0u8);
    out.extend_from_slice(&chunk_size_field.to_be_bytes());
    out.extend_from_slice(&chunk_count.to_be_bytes());
    out.extend_from_slice(&salt);

    for i in 0..chunk_count {
        // Clamping both ends yields an empty slice once past the plaintext.
        let start = (i as usize)
            .saturating_mul(chunk_size)
            .min(plaintext.len());
        let end = start.saturating_add(chunk_size).min(plaintext.len());
        let chunk_pt = &plaintext[start..end];

        let nonce_bytes = nonce_v5(&salt, i);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let aad = aad_v5(i, chunk_count, key_id, &salt);
        let ct_with_tag = cipher
            .encrypt(
                nonce,
                Payload {
                    msg: chunk_pt,
                    aad: &aad,
                },
            )
            .expect("aes-gcm encrypt cannot fail with a 32-byte key");

        // The cipher returns ciphertext followed by tag; the frame stores
        // the tag first.
        let (ct, tag) = ct_with_tag.split_at(ct_with_tag.len() - TAG_LEN);
        out.extend_from_slice(tag);
        out.extend_from_slice(ct);
    }
    Ok(Bytes::from(out))
}
2199
2200#[cfg(test)]
2201mod tests {
2202 use super::*;
2203
2204 fn key32(seed: u8) -> Arc<SseKey> {
2205 Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
2206 }
2207
2208 fn keyring_single(seed: u8) -> SseKeyring {
2209 SseKeyring::new(1, key32(seed))
2210 }
2211
2212 #[test]
2213 fn roundtrip_basic_v1() {
2214 let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
2216 let pt = b"the quick brown fox jumps over the lazy dog";
2217 let ct = encrypt(&k, pt);
2218 assert!(looks_encrypted(&ct));
2219 assert_eq!(&ct[..4], SSE_MAGIC_V1);
2220 assert_eq!(ct[4], ALGO_AES_256_GCM);
2221 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
2222 let kr = SseKeyring::new(1, Arc::new(k));
2224 let pt2 = decrypt(&ct, &kr).unwrap();
2225 assert_eq!(pt2.as_ref(), pt);
2226 }
2227
2228 #[test]
2229 fn s4e2_roundtrip_active_key() {
2230 let kr = keyring_single(7);
2231 let pt = b"S4E2 active-key roundtrip";
2232 let ct = encrypt_v2(pt, &kr);
2233 assert_eq!(&ct[..4], SSE_MAGIC_V2);
2234 assert_eq!(ct[4], ALGO_AES_256_GCM);
2235 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
2236 assert_eq!(ct[7], 0, "reserved byte");
2237 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
2238 assert!(looks_encrypted(&ct));
2239 let pt2 = decrypt(&ct, &kr).unwrap();
2240 assert_eq!(pt2.as_ref(), pt);
2241 }
2242
2243 #[test]
2244 fn decrypt_s4e1_via_active_only_keyring() {
2245 let k_arc = key32(11);
2248 let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
2249 assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
2250 let kr = SseKeyring::new(1, Arc::clone(&k_arc));
2251 let plain = decrypt(&legacy_ct, &kr).unwrap();
2252 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
2253 }
2254
2255 #[test]
2256 fn decrypt_s4e2_under_old_key_after_rotation() {
2257 let k1 = key32(1);
2261 let k2 = key32(2);
2262 let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
2263 let ct = encrypt_v2(b"old-rotation object", &kr_old);
2264 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
2265
2266 kr_old.add(2, Arc::clone(&k2));
2268 let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
2269 kr_new.add(1, Arc::clone(&k1));
2270
2271 let plain = decrypt(&ct, &kr_new).unwrap();
2272 assert_eq!(plain.as_ref(), b"old-rotation object");
2273
2274 let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
2276 assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
2277 let plain_new = decrypt(&new_ct, &kr_new).unwrap();
2278 assert_eq!(plain_new.as_ref(), b"new-rotation object");
2279 }
2280
2281 #[test]
2282 fn s4e2_unknown_key_id_errors() {
2283 let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
2285 let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
2287 assert!(
2288 matches!(err, SseError::KeyNotInKeyring { id: 99 }),
2289 "got {err:?}"
2290 );
2291 }
2292
2293 #[test]
2294 fn s4e2_tampered_key_id_fails_auth() {
2295 let kr = SseKeyring::new(1, key32(4));
2296 let mut kr_with_2 = kr.clone();
2297 kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
2299 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
2303 ct[5] = 0;
2304 ct[6] = 2;
2305 let err = decrypt(&ct, &kr_with_2).unwrap_err();
2306 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2307 }
2308
2309 #[test]
2310 fn s4e2_tampered_ciphertext_fails() {
2311 let kr = SseKeyring::new(7, key32(9));
2312 let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
2313 let last = ct.len() - 1;
2314 ct[last] ^= 0x01;
2315 let err = decrypt(&ct, &kr).unwrap_err();
2316 assert!(matches!(err, SseError::DecryptFailed));
2317 }
2318
2319 #[test]
2320 fn s4e2_tampered_algo_byte_fails() {
2321 let kr = SseKeyring::new(1, key32(2));
2322 let mut ct = encrypt_v2(b"hi", &kr).to_vec();
2323 ct[4] = 99;
2324 let err = decrypt(&ct, &kr).unwrap_err();
2325 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
2326 }
2327
2328 #[test]
2329 fn wrong_key_fails_v1_via_keyring() {
2330 let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
2332 let ct = encrypt(&k1, b"secret");
2333 let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
2334 let err = decrypt(&ct, &kr_wrong).unwrap_err();
2335 assert!(matches!(err, SseError::DecryptFailed));
2336 }
2337
2338 #[test]
2339 fn rejects_short_body() {
2340 let kr = SseKeyring::new(1, key32(1));
2341 let err = decrypt(b"short", &kr).unwrap_err();
2342 assert!(matches!(err, SseError::TooShort { got: 5 }));
2343 }
2344
2345 #[test]
2346 fn looks_encrypted_passthrough_returns_false() {
2347 let f2 = b"S4F2\x01\x00\x00\x00........................................";
2349 assert!(!looks_encrypted(f2));
2350 assert!(!looks_encrypted(b""));
2351 }
2352
2353 #[test]
2354 fn looks_encrypted_detects_both_v1_and_v2() {
2355 let kr = SseKeyring::new(1, key32(8));
2356 let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
2357 let v2 = encrypt_v2(b"x", &kr);
2358 assert!(looks_encrypted(&v1));
2359 assert!(looks_encrypted(&v2));
2360 }
2361
2362 #[test]
2363 fn key_from_hex_string() {
2364 let bad =
2365 SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
2366 .unwrap_err();
2367 assert!(matches!(bad, SseError::BadKeyLength { .. }));
2368 let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
2369 let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
2370 }
2371
2372 #[test]
2373 fn encrypt_v2_uses_random_nonce() {
2374 let kr = SseKeyring::new(1, key32(3));
2375 let pt = b"deterministic input";
2376 let a = encrypt_v2(pt, &kr);
2377 let b = encrypt_v2(pt, &kr);
2378 assert_ne!(a, b, "nonce must be random per-call");
2379 }
2380
2381 #[test]
2382 fn keyring_active_and_get() {
2383 let k1 = key32(1);
2384 let k2 = key32(2);
2385 let mut kr = SseKeyring::new(1, Arc::clone(&k1));
2386 kr.add(2, Arc::clone(&k2));
2387 let (id, active) = kr.active();
2388 assert_eq!(id, 1);
2389 assert_eq!(active.bytes, [1u8; 32]);
2390 assert!(kr.get(2).is_some());
2391 assert!(kr.get(3).is_none());
2392 }
2393
2394 use base64::Engine as _;
2399
2400 fn cust_key(seed: u8) -> CustomerKeyMaterial {
2401 let key = [seed; KEY_LEN];
2402 let key_md5 = compute_key_md5(&key);
2403 CustomerKeyMaterial { key, key_md5 }
2404 }
2405
2406 #[test]
2407 fn s4e3_roundtrip_happy_path() {
2408 let m = cust_key(42);
2409 let pt = b"top-secret SSE-C payload";
2410 let ct = encrypt_with_source(
2411 pt,
2412 SseSource::CustomerKey {
2413 key: &m.key,
2414 key_md5: &m.key_md5,
2415 },
2416 );
2417 assert_eq!(&ct[..4], SSE_MAGIC_V3);
2419 assert_eq!(ct[4], ALGO_AES_256_GCM);
2420 assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
2421 assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
2422 assert!(looks_encrypted(&ct));
2423 let plain = decrypt(
2425 &ct,
2426 SseSource::CustomerKey {
2427 key: &m.key,
2428 key_md5: &m.key_md5,
2429 },
2430 )
2431 .unwrap();
2432 assert_eq!(plain.as_ref(), pt);
2433 let plain2 = decrypt(&ct, &m).unwrap();
2435 assert_eq!(plain2.as_ref(), pt);
2436 }
2437
2438 #[test]
2439 fn s4e3_wrong_key_yields_wrong_customer_key_error() {
2440 let m = cust_key(1);
2441 let other = cust_key(2);
2442 let ct = encrypt_with_source(b"payload", (&m).into());
2443 let err = decrypt(
2444 &ct,
2445 SseSource::CustomerKey {
2446 key: &other.key,
2447 key_md5: &other.key_md5,
2448 },
2449 )
2450 .unwrap_err();
2451 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
2452 }
2453
2454 #[test]
2455 fn s4e3_tampered_stored_md5_is_caught() {
2456 let m = cust_key(7);
2463 let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
2464 ct[5] ^= 0x55;
2466 let err = decrypt(
2468 &ct,
2469 SseSource::CustomerKey {
2470 key: &m.key,
2471 key_md5: &m.key_md5,
2472 },
2473 )
2474 .unwrap_err();
2475 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
2476 }
2477
2478 #[test]
2479 fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
2480 let m = cust_key(3);
2484 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
2485 ct[5] ^= 0xFF;
2486 let mut bogus_md5 = m.key_md5;
2487 bogus_md5[0] ^= 0xFF;
2488 let err = decrypt(
2489 &ct,
2490 SseSource::CustomerKey {
2491 key: &m.key,
2492 key_md5: &bogus_md5,
2493 },
2494 )
2495 .unwrap_err();
2496 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2497 }
2498
2499 #[test]
2500 fn s4e3_tampered_ciphertext_fails_aead() {
2501 let m = cust_key(8);
2502 let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
2503 let last = ct.len() - 1;
2504 ct[last] ^= 0x01;
2505 let err = decrypt(&ct, &m).unwrap_err();
2506 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2507 }
2508
2509 #[test]
2510 fn s4e3_tampered_algo_byte_rejected() {
2511 let m = cust_key(9);
2512 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
2513 ct[4] = 99;
2514 let err = decrypt(&ct, &m).unwrap_err();
2515 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
2516 }
2517
2518 #[test]
2519 fn s4e3_uses_random_nonce() {
2520 let m = cust_key(10);
2521 let a = encrypt_with_source(b"deterministic input", (&m).into());
2522 let b = encrypt_with_source(b"deterministic input", (&m).into());
2523 assert_ne!(a, b, "nonce must be random per-call");
2524 }
2525
2526 #[test]
2527 fn parse_customer_key_headers_happy_path() {
2528 let key = [11u8; KEY_LEN];
2529 let md5 = compute_key_md5(&key);
2530 let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
2531 let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
2532 let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
2533 assert_eq!(m.key, key);
2534 assert_eq!(m.key_md5, md5);
2535 }
2536
2537 #[test]
2538 fn parse_customer_key_headers_rejects_wrong_algorithm() {
2539 let key = [1u8; KEY_LEN];
2540 let md5 = compute_key_md5(&key);
2541 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2542 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2543 let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
2544 assert!(
2545 matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
2546 "got {err:?}"
2547 );
2548 let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
2550 assert!(
2551 matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
2552 "got {err2:?}"
2553 );
2554 }
2555
2556 #[test]
2557 fn parse_customer_key_headers_rejects_wrong_key_length() {
2558 let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
2560 let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
2561 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2562 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2563 assert!(
2564 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
2565 "got {err:?}"
2566 );
2567 }
2568
2569 #[test]
2570 fn parse_customer_key_headers_rejects_wrong_md5_length() {
2571 let key = [3u8; KEY_LEN];
2572 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2573 let bad_md5 = vec![0u8; 15];
2575 let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
2576 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2577 assert!(
2578 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
2579 "got {err:?}"
2580 );
2581 }
2582
2583 #[test]
2584 fn parse_customer_key_headers_rejects_md5_mismatch() {
2585 let key = [4u8; KEY_LEN];
2586 let other = [5u8; KEY_LEN];
2587 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2588 let wrong_md5 = compute_key_md5(&other);
2589 let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
2590 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2591 assert!(
2592 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
2593 "got {err:?}"
2594 );
2595 }
2596
2597 #[test]
2598 fn parse_customer_key_headers_rejects_bad_base64() {
2599 let valid_key = [0u8; KEY_LEN];
2600 let md5 = compute_key_md5(&valid_key);
2601 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2602 let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
2603 assert!(
2604 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2605 "got {err:?}"
2606 );
2607 let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
2609 let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
2610 assert!(
2611 matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2612 "got {err2:?}"
2613 );
2614 }
2615
2616 #[test]
2617 fn parse_customer_key_headers_trims_whitespace() {
2618 let key = [12u8; KEY_LEN];
2620 let md5 = compute_key_md5(&key);
2621 let kb = format!(
2622 " {}\n",
2623 base64::engine::general_purpose::STANDARD.encode(key)
2624 );
2625 let mb = format!(
2626 "\t{} ",
2627 base64::engine::general_purpose::STANDARD.encode(md5)
2628 );
2629 let m = parse_customer_key_headers("AES256", &kb, &mb).unwrap();
2630 assert_eq!(m.key, key);
2631 }
2632
2633 #[test]
2638 fn back_compat_decrypt_s4e1_with_keyring_source() {
2639 let k = key32(33);
2640 let legacy_ct = encrypt(&k, b"v0.4 vintage object");
2641 let kr = SseKeyring::new(1, Arc::clone(&k));
2642 let plain = decrypt(&legacy_ct, &kr).unwrap();
2645 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
2646 let plain2 = decrypt(&legacy_ct, SseSource::Keyring(&kr)).unwrap();
2647 assert_eq!(plain2.as_ref(), b"v0.4 vintage object");
2648 }
2649
2650 #[test]
2651 fn back_compat_decrypt_s4e2_with_keyring_source() {
2652 let kr = keyring_single(34);
2653 let ct = encrypt_v2(b"v0.5 #29 object", &kr);
2654 let plain = decrypt(&ct, &kr).unwrap();
2655 assert_eq!(plain.as_ref(), b"v0.5 #29 object");
2656 let ct2 = encrypt_with_source(b"v0.5 #29 object", SseSource::Keyring(&kr));
2659 assert_eq!(&ct2[..4], SSE_MAGIC_V2);
2660 let plain2 = decrypt(&ct2, &kr).unwrap();
2661 assert_eq!(plain2.as_ref(), b"v0.5 #29 object");
2662 }
2663
2664 #[test]
2665 fn s4e2_blob_with_customer_key_source_is_rejected() {
2666 let kr = keyring_single(50);
2670 let ct = encrypt_v2(b"server-managed object", &kr);
2671 let m = cust_key(99);
2672 let err = decrypt(
2673 &ct,
2674 SseSource::CustomerKey {
2675 key: &m.key,
2676 key_md5: &m.key_md5,
2677 },
2678 )
2679 .unwrap_err();
2680 assert!(
2681 matches!(err, SseError::CustomerKeyUnexpected),
2682 "got {err:?}"
2683 );
2684 }
2685
2686 #[test]
2687 fn s4e3_blob_with_keyring_source_is_rejected() {
2688 let m = cust_key(60);
2691 let ct = encrypt_with_source(b"customer-key object", (&m).into());
2692 let kr = keyring_single(60);
2693 let err = decrypt(&ct, &kr).unwrap_err();
2694 assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
2695 }
2696
2697 #[test]
2698 fn looks_encrypted_detects_s4e3() {
2699 let m = cust_key(13);
2700 let ct = encrypt_with_source(b"x", (&m).into());
2701 assert!(looks_encrypted(&ct));
2702 }
2703
2704 #[test]
2705 fn s4e3_rejects_short_body() {
2706 let mut short = Vec::new();
2709 short.extend_from_slice(SSE_MAGIC_V3);
2710 short.push(ALGO_AES_256_GCM);
2711 short.extend_from_slice(&[0u8; SSE_HEADER_BYTES - 5]);
2714 assert_eq!(short.len(), SSE_HEADER_BYTES);
2715 let m = cust_key(1);
2716 let err = decrypt(
2717 &short,
2718 SseSource::CustomerKey {
2719 key: &m.key,
2720 key_md5: &m.key_md5,
2721 },
2722 )
2723 .unwrap_err();
2724 assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
2725 }
2726
2727 #[test]
2728 fn customer_key_material_debug_redacts_key() {
2729 let m = cust_key(99);
2730 let s = format!("{m:?}");
2731 assert!(s.contains("redacted"));
2732 assert!(!s.contains(&format!("{:?}", m.key.as_slice())));
2733 }
2734
2735 #[test]
2736 fn constant_time_eq_basic() {
2737 assert!(constant_time_eq(b"abc", b"abc"));
2738 assert!(!constant_time_eq(b"abc", b"abd"));
2739 assert!(!constant_time_eq(b"abc", b"abcd"));
2740 assert!(constant_time_eq(b"", b""));
2741 }
2742
2743 #[test]
2744 fn compute_key_md5_known_vector() {
2745 let got = compute_key_md5(b"");
2747 let expected_hex = "d41d8cd98f00b204e9800998ecf8427e";
2748 assert_eq!(hex_lower(&got), expected_hex);
2749 }
2750
2751 use crate::kms::{KmsBackend, LocalKms};
2756 use std::collections::HashMap;
2757 use std::path::PathBuf;
2758
2759 fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
2760 let mut keks: HashMap<String, [u8; 32]> = HashMap::new();
2761 for (id, k) in key_ids {
2762 keks.insert((*id).to_string(), *k);
2763 }
2764 LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
2765 }
2766
2767 #[tokio::test]
2768 async fn s4e4_roundtrip_via_local_kms() {
2769 let kms = local_kms_with(&[("alpha", [42u8; 32])]);
2770 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2771 let mut dek = [0u8; 32];
2772 dek.copy_from_slice(&dek_vec);
2773 let pt = b"SSE-KMS envelope payload across the S4E4 frame";
2774 let ct = encrypt_with_source(
2775 pt,
2776 SseSource::Kms {
2777 dek: &dek,
2778 wrapped: &wrapped,
2779 },
2780 );
2781 assert_eq!(&ct[..4], SSE_MAGIC_V4);
2783 assert_eq!(ct[4], ALGO_AES_256_GCM);
2784 let key_id_len = ct[5] as usize;
2785 assert_eq!(key_id_len, "alpha".len());
2786 assert_eq!(&ct[6..6 + key_id_len], b"alpha");
2787 assert!(looks_encrypted(&ct));
2789 assert_eq!(peek_magic(&ct), Some("S4E4"));
2790 let plain = decrypt_with_kms(&ct, &kms).await.unwrap();
2792 assert_eq!(plain.as_ref(), pt);
2793 }
2794
2795 #[tokio::test]
2796 async fn s4e4_tampered_key_id_fails_aead() {
2797 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2798 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2799 let mut dek = [0u8; 32];
2800 dek.copy_from_slice(&dek_vec);
2801 let mut ct = encrypt_with_source(
2802 b"do not redirect",
2803 SseSource::Kms {
2804 dek: &dek,
2805 wrapped: &wrapped,
2806 },
2807 )
2808 .to_vec();
2809 let key_id_off = 6;
2814 ct[key_id_off] = b'b';
2815 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2816 assert!(
2817 matches!(
2818 err,
2819 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2820 | SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2821 ),
2822 "got {err:?}"
2823 );
2824 }
2825
2826 #[tokio::test]
2827 async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
2828 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2834 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2835 let mut dek = [0u8; 32];
2836 dek.copy_from_slice(&dek_vec);
2837 let mut ct = encrypt_with_source(
2838 b"redirect attempt",
2839 SseSource::Kms {
2840 dek: &dek,
2841 wrapped: &wrapped,
2842 },
2843 )
2844 .to_vec();
2845 let key_id_off = 6;
2848 ct[key_id_off..key_id_off + 5].copy_from_slice(b"beta_");
2849 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2856 assert!(
2857 matches!(
2858 err,
2859 SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2860 ),
2861 "got {err:?}"
2862 );
2863 }
2864
2865 #[tokio::test]
2866 async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
2867 let kms = local_kms_with(&[("k", [3u8; 32])]);
2868 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2869 let mut dek = [0u8; 32];
2870 dek.copy_from_slice(&dek_vec);
2871 let mut ct = encrypt_with_source(
2872 b"target body",
2873 SseSource::Kms {
2874 dek: &dek,
2875 wrapped: &wrapped,
2876 },
2877 )
2878 .to_vec();
2879 let key_id_len = ct[5] as usize;
2883 let wrapped_len_off = 6 + key_id_len;
2884 let wrapped_off = wrapped_len_off + 4;
2885 let mid = wrapped_off + (wrapped.ciphertext.len() / 2);
2886 ct[mid] ^= 0xFF;
2887 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2888 assert!(
2889 matches!(
2890 err,
2891 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2892 ),
2893 "got {err:?}"
2894 );
2895 }
2896
2897 #[tokio::test]
2898 async fn s4e4_tampered_ciphertext_fails_aead() {
2899 let kms = local_kms_with(&[("k", [4u8; 32])]);
2900 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2901 let mut dek = [0u8; 32];
2902 dek.copy_from_slice(&dek_vec);
2903 let mut ct = encrypt_with_source(
2904 b"sealed body",
2905 SseSource::Kms {
2906 dek: &dek,
2907 wrapped: &wrapped,
2908 },
2909 )
2910 .to_vec();
2911 let last = ct.len() - 1;
2912 ct[last] ^= 0x01;
2913 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2914 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2915 }
2916
2917 #[tokio::test]
2918 async fn s4e4_uses_random_nonce_and_dek_per_put() {
2919 let kms = local_kms_with(&[("k", [5u8; 32])]);
2920 let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
2923 let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
2924 let mut dek1 = [0u8; 32];
2925 dek1.copy_from_slice(&dek1_vec);
2926 let mut dek2 = [0u8; 32];
2927 dek2.copy_from_slice(&dek2_vec);
2928 let pt = b"deterministic input";
2929 let a = encrypt_with_source(
2930 pt,
2931 SseSource::Kms {
2932 dek: &dek1,
2933 wrapped: &wrapped1,
2934 },
2935 );
2936 let b = encrypt_with_source(
2937 pt,
2938 SseSource::Kms {
2939 dek: &dek2,
2940 wrapped: &wrapped2,
2941 },
2942 );
2943 assert_ne!(a, b);
2944 let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
2946 let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
2947 assert_eq!(plain_a.as_ref(), pt);
2948 assert_eq!(plain_b.as_ref(), pt);
2949 }
2950
2951 #[tokio::test]
2952 async fn s4e4_sync_decrypt_returns_kms_async_required() {
2953 let kms = local_kms_with(&[("k", [6u8; 32])]);
2958 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2959 let mut dek = [0u8; 32];
2960 dek.copy_from_slice(&dek_vec);
2961 let ct = encrypt_with_source(
2962 b"async only",
2963 SseSource::Kms {
2964 dek: &dek,
2965 wrapped: &wrapped,
2966 },
2967 );
2968 let kr = SseKeyring::new(1, key32(0));
2970 let err = decrypt(&ct, &kr).unwrap_err();
2971 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
2972 }
2973
/// Frames produced by `encrypt`, `encrypt_v2`, and the customer-key flavour
/// of `encrypt_with_source` must all still round-trip through the
/// synchronous `decrypt`.
#[test]
fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
    let key = key32(7);
    let keyring = SseKeyring::new(1, Arc::clone(&key));

    // Raw-key frame.
    let v1_plain = b"v0.4 vintage";
    let v1_frame = encrypt(&key, v1_plain);
    assert_eq!(decrypt(&v1_frame, &keyring).unwrap().as_ref(), v1_plain);

    // Keyring frame.
    let v2_plain = b"v0.5 #29 vintage";
    let v2_frame = encrypt_v2(v2_plain, &keyring);
    assert_eq!(decrypt(&v2_frame, &keyring).unwrap().as_ref(), v2_plain);

    // Customer-key frame.
    let customer = cust_key(7);
    let v3_plain = b"v0.5 #27 vintage";
    let v3_frame = encrypt_with_source(v3_plain, (&customer).into());
    assert_eq!(decrypt(&v3_frame, &customer).unwrap().as_ref(), v3_plain);
}
2990
/// `peek_magic` must name the S4E1..S4E4 magics and return `None` for
/// inputs that carry none of them.
#[test]
fn peek_magic_distinguishes_all_variants() {
    let key = key32(9);
    let v1 = encrypt(&key, b"x");
    assert_eq!(peek_magic(&v1), Some("S4E1"));

    let keyring = SseKeyring::new(1, Arc::clone(&key));
    let v2 = encrypt_v2(b"x", &keyring);
    assert_eq!(peek_magic(&v2), Some("S4E2"));

    let customer = cust_key(9);
    let v3 = encrypt_with_source(b"x", (&customer).into());
    assert_eq!(peek_magic(&v3), Some("S4E3"));

    // Hand-assembled S4E4 stand-in: the magic followed by 40 zero bytes.
    let v4 = [&SSE_MAGIC_V4[..], &[0u8; 40][..]].concat();
    assert_eq!(peek_magic(&v4), Some("S4E4"));

    // Inputs that must not be recognised.
    let zeros = [0u8; 100];
    let rejects: [&[u8]; 3] = [b"NOPE", b"short", &zeros];
    for junk in rejects {
        assert!(peek_magic(junk).is_none());
    }
}
3017
3018 #[tokio::test]
3019 async fn s4e4_truncated_frame_errors_cleanly() {
3020 let truncated = b"S4E4\x01\x05hi";
3023 let kms = local_kms_with(&[("k", [1u8; 32])]);
3024 let err = decrypt_with_kms(truncated, &kms).await.unwrap_err();
3025 assert!(
3026 matches!(err, SseError::KmsFrameTooShort { .. }),
3027 "got {err:?}"
3028 );
3029 }
3030
3031 #[tokio::test]
3032 async fn s4e4_oob_key_id_len_errors() {
3033 let mut body = Vec::new();
3037 body.extend_from_slice(SSE_MAGIC_V4);
3038 body.push(ALGO_AES_256_GCM);
3039 body.push(200u8); body.extend_from_slice(&[0u8; 50]);
3044 let kms = local_kms_with(&[("k", [1u8; 32])]);
3045 let err = decrypt_with_kms(&body, &kms).await.unwrap_err();
3046 assert!(
3047 matches!(err, SseError::KmsFrameFieldOob { .. }),
3048 "got {err:?}"
3049 );
3050 }
3051
3052 #[tokio::test]
3053 async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
3054 let kms = local_kms_with(&[("k", [9u8; 32])]);
3060 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
3061 let mut dek = [0u8; 32];
3062 dek.copy_from_slice(&dek_vec);
3063 let ct = encrypt_with_source(
3064 b"x",
3065 SseSource::Kms {
3066 dek: &dek,
3067 wrapped: &wrapped,
3068 },
3069 );
3070 let m = cust_key(1);
3071 let err = decrypt(&ct, &m).unwrap_err();
3072 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
3073 }
3074
3075 #[tokio::test]
3076 async fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
3077 let mut not_s4e4 = Vec::new();
3079 not_s4e4.extend_from_slice(b"S4F4");
3080 not_s4e4.extend_from_slice(&[0u8; 60]);
3081 assert!(!looks_encrypted(¬_s4e4));
3082 assert_eq!(peek_magic(¬_s4e4), None);
3083 }
3084
3085 #[tokio::test]
3086 async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
3087 let kms = local_kms_with(&[("kk", [11u8; 32])]);
3094 let (dek_vec, wrapped) = kms.generate_dek("kk").await.unwrap();
3095 let mut dek = [0u8; 32];
3096 dek.copy_from_slice(&dek_vec);
3097 let mut ct = encrypt_with_source(
3098 b"length-shift defense",
3099 SseSource::Kms {
3100 dek: &dek,
3101 wrapped: &wrapped,
3102 },
3103 )
3104 .to_vec();
3105 let key_id_len = ct[5] as usize;
3106 let wrapped_len_off = 6 + key_id_len;
3107 let original_len = u32::from_be_bytes([
3113 ct[wrapped_len_off],
3114 ct[wrapped_len_off + 1],
3115 ct[wrapped_len_off + 2],
3116 ct[wrapped_len_off + 3],
3117 ]);
3118 let new_len = (original_len - 1).to_be_bytes();
3119 ct[wrapped_len_off..wrapped_len_off + 4].copy_from_slice(&new_len);
3120 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
3121 assert!(
3124 matches!(
3125 err,
3126 SseError::KmsBackend(_)
3127 | SseError::DecryptFailed
3128 | SseError::KmsFrameFieldOob { .. }
3129 | SseError::KmsFrameTooShort { .. }
3130 ),
3131 "got {err:?}"
3132 );
3133 }
3134
3135 use futures::StreamExt;
3140
3141 async fn collect_chunks(
3144 s: impl futures::Stream<Item = Result<Bytes, SseError>>,
3145 ) -> Result<Vec<Bytes>, SseError> {
3146 let mut out = Vec::new();
3147 let mut s = std::pin::pin!(s);
3148 while let Some(item) = s.next().await {
3149 out.push(item?);
3150 }
3151 Ok(out)
3152 }
3153
/// Pins the frame layout for a 10 MiB plaintext split into 1 MiB chunks.
///
/// Byte offsets checked: magic `0..4`, algorithm tag `4`, big-endian key id
/// `5..7`, reserved byte `7`, big-endian chunk size `8..12`, big-endian
/// chunk count `12..16`, and the salt from `16` to the end of the header.
#[test]
fn s4e6_encrypt_layout_10mb_at_1mib() {
    let kr = keyring_single(0x42);
    let chunk_size = 1024 * 1024;
    let pt_len = 10 * 1024 * 1024;
    let pt = vec![0xAB_u8; pt_len];
    let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).expect("encrypt ok");

    assert_eq!(&ct[..4], SSE_MAGIC_V6, "new PUTs emit S4E6 (v0.8.1 #57)");
    assert_eq!(ct[4], ALGO_AES_256_GCM);
    assert_eq!(
        u16::from_be_bytes([ct[5], ct[6]]),
        1,
        "key_id BE = active id"
    );
    assert_eq!(ct[7], 0, "reserved must be 0");
    assert_eq!(
        u32::from_be_bytes([ct[8], ct[9], ct[10], ct[11]]),
        chunk_size as u32,
        "chunk_size BE",
    );
    assert_eq!(
        u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
        10,
        "chunk_count BE — 10 MiB / 1 MiB = 10 (no remainder)",
    );

    // Slice up to the real header size rather than a literal `24`: a
    // fixed-width `16..24` slice always has length 8, which would make the
    // width assertion below vacuous.
    let salt = &ct[16..S4E6_HEADER_BYTES];
    assert_eq!(salt.len(), 8, "S4E6 salt slot is 8 bytes");
    assert_ne!(salt, &[0u8; 8], "S4E6 salt must be random, not zeros");

    assert_eq!(
        ct.len(),
        S4E6_HEADER_BYTES + 10 * S4E6_PER_CHUNK_OVERHEAD + pt_len,
        "total = header (24) + 10 tags + plaintext",
    );
    assert!(looks_encrypted(&ct), "looks_encrypted must accept S4E6");
    assert_eq!(peek_magic(&ct), Some("S4E6"));
}
3200
3201 #[tokio::test]
3202 async fn s4e6_decrypt_chunked_stream_byte_equal() {
3203 let kr = keyring_single(0x55);
3206 let pt: Vec<u8> = (0..(10 * 1024 * 1024_u32))
3207 .map(|i| (i & 0xFF) as u8)
3208 .collect();
3209 let ct = encrypt_v2_chunked(&pt, &kr, 1024 * 1024).unwrap();
3210 assert_eq!(&ct[..4], SSE_MAGIC_V6, "new emit is S4E6");
3212 let stream = decrypt_chunked_stream(ct, &kr);
3213 let chunks = collect_chunks(stream).await.expect("stream ok");
3214 assert_eq!(chunks.len(), 10, "10 chunks expected for 10 MiB / 1 MiB");
3215 let mut joined = Vec::with_capacity(pt.len());
3216 for c in chunks {
3217 joined.extend_from_slice(&c);
3218 }
3219 assert_eq!(joined.len(), pt.len(), "byte length matches");
3220 assert_eq!(joined, pt, "byte-equal round-trip");
3221 }
3222
3223 #[tokio::test]
3224 async fn s4e6_single_chunk_for_small_object() {
3225 let kr = keyring_single(0x77);
3229 let pt = b"tiny payload, smaller than chunk_size";
3230 let ct = encrypt_v2_chunked(pt, &kr, 1024 * 1024).unwrap();
3231 assert_eq!(
3232 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
3233 1,
3234 "small plaintext = single chunk",
3235 );
3236 let stream = decrypt_chunked_stream(ct, &kr);
3237 let chunks = collect_chunks(stream).await.expect("stream ok");
3238 assert_eq!(chunks.len(), 1);
3239 assert_eq!(chunks[0].as_ref(), pt);
3240 }
3241
3242 #[tokio::test]
3243 async fn s4e6_tampered_chunk_n_reports_chunk_index() {
3244 let kr = keyring_single(0x91);
3249 let chunk_size = 1024;
3250 let pt = vec![0xCD_u8; chunk_size * 8]; let mut ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap().to_vec();
3252 let target = S4E6_HEADER_BYTES + 3 * (TAG_LEN + chunk_size) + TAG_LEN;
3255 ct[target] ^= 0x42;
3256 let stream = decrypt_chunked_stream(bytes::Bytes::from(ct), &kr);
3257 let mut s = std::pin::pin!(stream);
3258 for expected_i in 0..3_u32 {
3260 let item = s.next().await.expect("yield");
3261 item.unwrap_or_else(|e| panic!("chunk {expected_i}: {e:?}"));
3262 }
3263 let err = s.next().await.expect("yield error").unwrap_err();
3265 assert!(
3266 matches!(err, SseError::ChunkAuthFailed { chunk_index: 3 }),
3267 "got {err:?}",
3268 );
3269 }
3270
3271 #[tokio::test]
3272 async fn s4e5_back_compat_s4e2_blob_rejected_with_clear_error() {
3273 let kr = keyring_single(0x12);
3277 let s4e2 = encrypt_v2(b"a v2 blob, not chunked", &kr);
3278 let stream = decrypt_chunked_stream(s4e2, &kr);
3279 let result = collect_chunks(stream).await;
3280 let err = result.unwrap_err();
3281 assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");
3282 }
3283
/// Encrypts the same one-byte plaintext 1024 times and requires that more
/// than half of the salts (frame bytes 16..24) are distinct.
#[test]
fn s4e6_salt_randomness_smoke() {
    let kr = keyring_single(0x33);
    let n = 1024;

    // Gather the salt of every frame into a set; duplicates collapse.
    let salts: std::collections::HashSet<[u8; 8]> = (0..n)
        .map(|_| {
            let frame = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
            let mut salt = [0u8; 8];
            salt.copy_from_slice(&frame[16..24]);
            salt
        })
        .collect();

    assert!(
        salts.len() > n / 2,
        "expected most of the {n} salts to be unique (got {} unique)",
        salts.len(),
    );
}
3307
/// A chunk size of zero must be refused with `SseError::ChunkSizeInvalid`.
#[test]
fn s4e6_chunk_size_zero_invalid() {
    let keyring = keyring_single(0x66);
    let refusal = encrypt_v2_chunked(b"hi", &keyring, 0).unwrap_err();
    let is_invalid_size = matches!(refusal, SseError::ChunkSizeInvalid);
    assert!(is_invalid_size);
}
3314
3315 #[tokio::test]
3316 async fn s4e6_truncated_body_reports_frame_truncated() {
3317 let kr = keyring_single(0xA1);
3320 let chunk_size = 256;
3321 let pt = vec![0u8; chunk_size * 4];
3322 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
3323 let trunc = S4E6_HEADER_BYTES + 2 * (TAG_LEN + chunk_size) + 8;
3326 let truncated = bytes::Bytes::copy_from_slice(&ct[..trunc]);
3327 let stream = decrypt_chunked_stream(truncated, &kr);
3328 let result = collect_chunks(stream).await;
3329 let err = result.unwrap_err();
3330 assert!(
3331 matches!(err, SseError::ChunkFrameTruncated { .. }),
3332 "got {err:?}",
3333 );
3334 }
3335
/// A chunked frame (13-byte chunks) must round-trip through the top-level
/// synchronous `decrypt`.
#[test]
fn s4e6_decrypt_buffered_round_trip_via_top_level_decrypt() {
    let keyring = keyring_single(0xDE);
    let original = b"buffered sync decrypt path".repeat(32);

    let frame = encrypt_v2_chunked(&original, &keyring, 13).unwrap();
    let recovered = decrypt(&frame, &keyring).expect("buffered S4E6 decrypt ok");

    assert_eq!(recovered.as_ref(), original.as_slice());
}
3347
3348 #[tokio::test]
3349 async fn s4e6_unknown_key_id_in_frame_errors() {
3350 let kr_put = SseKeyring::new(7, key32(0xCC));
3352 let kr_get = keyring_single(0xCC); let ct = encrypt_v2_chunked(b"orphan key", &kr_put, 64).unwrap();
3354 let err = decrypt(&ct, &kr_get).unwrap_err();
3356 assert!(
3357 matches!(err, SseError::KeyNotInKeyring { id: 7 }),
3358 "got {err:?}"
3359 );
3360 let stream = decrypt_chunked_stream(ct, &kr_get);
3362 let result = collect_chunks(stream).await;
3363 assert!(
3364 matches!(result, Err(SseError::KeyNotInKeyring { id: 7 })),
3365 "got {result:?}",
3366 );
3367 }
3368
3369 #[tokio::test]
3370 async fn s4e6_final_chunk_smaller_than_chunk_size() {
3371 let kr = keyring_single(0xEF);
3374 let chunk_size = 100;
3375 let pt: Vec<u8> = (0..250_u32).map(|i| i as u8).collect();
3376 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
3377 assert_eq!(
3378 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
3379 3,
3380 "ceil(250/100) = 3 chunks",
3381 );
3382 assert_eq!(ct.len(), S4E6_HEADER_BYTES + 48 + 250);
3384 let stream = decrypt_chunked_stream(ct, &kr);
3385 let chunks = collect_chunks(stream).await.expect("stream ok");
3386 assert_eq!(chunks.len(), 3);
3387 assert_eq!(chunks[0].len(), 100);
3388 assert_eq!(chunks[1].len(), 100);
3389 assert_eq!(chunks[2].len(), 50, "final chunk is the remainder");
3390 let joined: Vec<u8> = chunks.iter().flat_map(|c| c.iter().copied()).collect();
3391 assert_eq!(joined, pt);
3392 }
3393
/// A fixture produced by `encrypt_v2_chunked_s4e5_for_test` must carry the
/// S4E5 magic and decrypt through both the buffered `decrypt` and the
/// streaming decoder.
#[test]
fn s4e6_back_compat_read_s4e5_blob() {
    let kr = keyring_single(0x57);
    let pt = b"v0.8.0 vintage chunked SSE-S4 object".repeat(64);
    let fixture = encrypt_v2_chunked_s4e5_for_test(&pt, &kr, 91).unwrap();

    assert_eq!(&fixture[..4], SSE_MAGIC_V5, "fixture must be S4E5");
    assert_eq!(peek_magic(&fixture), Some("S4E5"));

    // Buffered path.
    let buffered = decrypt(&fixture, &kr).expect("sync S4E5 decrypt ok");
    assert_eq!(buffered.as_ref(), pt.as_slice());

    // Streaming path, driven to completion on the current thread.
    let stream = decrypt_chunked_stream(fixture.clone(), &kr);
    let streamed =
        futures::executor::block_on(collect_chunks(stream)).expect("stream S4E5 decrypt ok");
    let joined: Vec<u8> = streamed
        .iter()
        .flat_map(|chunk| chunk.iter().copied())
        .collect();
    assert_eq!(joined, pt, "S4E5 streaming round-trip byte-equal");
}
3430
/// Pins the header-size constants: the S4E6 header is 24 bytes, four more
/// than the S4E5 header, and the per-chunk overhead equals `TAG_LEN`.
#[test]
fn s4e6_layout_24_bytes_header() {
    let s4e6_header = S4E6_HEADER_BYTES;
    let s4e5_header = S4E5_HEADER_BYTES;

    assert_eq!(s4e6_header, 24);
    assert_eq!(S4E6_PER_CHUNK_OVERHEAD, TAG_LEN);
    assert_eq!(s4e6_header, s4e5_header + 4);
}
3440
/// `parse_s4e6_header` must recover the fields of a freshly encrypted frame,
/// reject a buffer with a different magic as `BadMagic`, and reject a
/// 10-byte prefix as `ChunkFrameTruncated`.
#[test]
fn s4e6_parse_header_round_trip() {
    let kr = keyring_single(0xAB);
    let chunk_size = 256;
    let pt = vec![1u8; 7 * chunk_size];
    let frame = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();

    // Happy path: seven full chunks.
    let hdr = parse_s4e6_header(&frame).expect("parse ok");
    assert_eq!(hdr.key_id, 1);
    assert_eq!(hdr.chunk_size, chunk_size as u32);
    assert_eq!(hdr.chunk_count, 7);
    assert_eq!(hdr.salt.len(), 8);

    // Wrong magic.
    let bogus = b"S4E2\x01\x00\x00\x00........................";
    let magic_err = parse_s4e6_header(bogus).unwrap_err();
    assert!(
        matches!(magic_err, SseError::BadMagic { .. }),
        "got {magic_err:?}"
    );

    // Too few bytes to hold a header.
    let short_err = parse_s4e6_header(&frame[..10]).unwrap_err();
    assert!(
        matches!(short_err, SseError::ChunkFrameTruncated { .. }),
        "got {short_err:?}"
    );
}
3466
/// Encrypts 16384 one-byte objects and requires every 8-byte salt (frame
/// bytes 16..24) to be distinct. It also counts, for information only, how
/// many collisions the first four salt bytes alone would have produced.
#[test]
fn s4e6_salt_uniqueness_smoke_16m() {
    let kr = keyring_single(0xA6);
    let n = 16384_usize;
    let mut salts = std::collections::HashSet::with_capacity(n);
    let mut prefixes = std::collections::HashSet::with_capacity(n);

    for _ in 0..n {
        let frame = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
        let mut salt = [0u8; 8];
        salt.copy_from_slice(&frame[16..24]);
        let mut top4 = [0u8; 4];
        top4.copy_from_slice(&salt[..4]);
        salts.insert(salt);
        prefixes.insert(top4);
    }

    // Every repeated prefix failed to enlarge the set, so the shortfall
    // against `n` is exactly the number of prefix collisions.
    let collisions_top4 = n - prefixes.len();

    assert_eq!(
        salts.len(),
        n,
        "all 8-byte salts must be unique across {n} PUTs (got {} unique)",
        salts.len(),
    );
    eprintln!(
        "s4e6_salt_uniqueness_smoke_16m: 16k PUTs, full 8B salts \
         all unique ({}/{}), simulated 4B-truncated salt yielded \
         {} collisions (this is what S4E5 would have shipped)",
        salts.len(),
        n,
        collisions_top4,
    );
}
3534
/// Pins `S4E6_MAX_CHUNK_COUNT` at 2^24 - 1 and checks that both the encoder
/// and `parse_s4e6_header` refuse a chunk count one above that cap, while an
/// object with 1023 chunks is accepted.
#[test]
fn s4e6_max_chunks_24bit() {
    // Shared predicate for the two over-cap rejections below.
    let is_one_over_cap = |e: &SseError| {
        matches!(
            e,
            SseError::ChunkCountTooLarge {
                got: 16_777_216,
                max: 16_777_215
            }
        )
    };

    assert_eq!(S4E6_MAX_CHUNK_COUNT, (1u32 << 24) - 1);
    assert_eq!(S4E6_MAX_CHUNK_COUNT, 16_777_215);

    let kr = keyring_single(0xC4);

    // Encoder side: one-byte chunks over a plaintext one byte longer than
    // the cap.
    let oversized = vec![0u8; (S4E6_MAX_CHUNK_COUNT as usize) + 1];
    let err = encrypt_v2_chunked(&oversized, &kr, 1).unwrap_err();
    assert!(is_one_over_cap(&err), "got {err:?}");

    // Well under the cap: 1023 one-byte chunks.
    let small = vec![0u8; 1023];
    let frame = encrypt_v2_chunked(&small, &kr, 1).expect("under-cap PUT must succeed");
    let hdr = parse_s4e6_header(&frame).unwrap();
    assert_eq!(hdr.chunk_count, 1023);

    // Parser side: overwrite the chunk-count field (bytes 12..16) with
    // cap + 1.
    let mut tampered = frame.to_vec();
    let over_cap_be = (S4E6_MAX_CHUNK_COUNT + 1).to_be_bytes();
    tampered[12..16].copy_from_slice(&over_cap_be);
    let err2 = parse_s4e6_header(&tampered).unwrap_err();
    assert!(is_one_over_cap(&err2), "got {err2:?}");
}
3597
/// Pins the byte layout produced by `nonce_v6`: byte 0 is `b'E'`, bytes 1..9
/// are the salt, and bytes 9..12 are the low 24 bits of the chunk index in
/// big-endian order.
#[test]
fn s4e6_nonce_v6_layout() {
    let salt = [0xAA_u8; 8];

    let first = nonce_v6(&salt, 0);
    assert_eq!(first[0], b'E');
    assert_eq!(&first[1..9], &salt);
    assert_eq!(&first[9..12], &[0, 0, 0]);

    // (chunk index, expected trailing three bytes)
    let cases: [(u32, [u8; 3]); 3] = [
        (1, [0, 0, 1]),
        (0x123456, [0x12, 0x34, 0x56]),
        (S4E6_MAX_CHUNK_COUNT, [0xFF, 0xFF, 0xFF]),
    ];
    for (index, counter) in cases {
        let nonce = nonce_v6(&salt, index);
        assert_eq!(&nonce[9..12], &counter);
    }
}
3615
3616 #[tokio::test]
3617 async fn s4e6_tampered_salt_byte_fails_aead() {
3618 let kr = keyring_single(0xB6);
3623 let pt = b"salt-in-aad coverage".repeat(64);
3624 let mut ct = encrypt_v2_chunked(&pt, &kr, 128).unwrap().to_vec();
3625 ct[20] ^= 0x01;
3627 let err = decrypt(&ct, &kr).unwrap_err();
3628 assert!(
3629 matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
3630 "got {err:?}",
3631 );
3632 }
3633
3634 fn synth_s4e6_header(chunk_size: u32, chunk_count: u32) -> Vec<u8> {
3649 let mut blob = Vec::with_capacity(S4E6_HEADER_BYTES);
3650 blob.extend_from_slice(SSE_MAGIC_V6);
3651 blob.push(ALGO_AES_256_GCM);
3652 blob.extend_from_slice(&1_u16.to_be_bytes()); blob.push(0); blob.extend_from_slice(&chunk_size.to_be_bytes());
3655 blob.extend_from_slice(&chunk_count.to_be_bytes());
3656 blob.extend_from_slice(&[0u8; 8]); debug_assert_eq!(blob.len(), S4E6_HEADER_BYTES);
3658 blob
3659 }
3660
/// A header declaring 100 chunks of 1 GiB, followed by only 100 bytes of
/// body, must be rejected with `ChunkFrameTooLarge` under both the default
/// cap and an explicit 1 MiB cap.
#[test]
fn s4e6_header_claims_huge_size_rejected_pre_alloc() {
    let kr = keyring_single(0x01);
    let declared_chunk_size: u32 = 1 << 30;
    let declared_chunk_count: u32 = 100;

    let mut frame = synth_s4e6_header(declared_chunk_size, declared_chunk_count);
    let padded_len = frame.len() + 100;
    frame.resize(padded_len, 0);

    // Default cap.
    let err = decrypt_chunked_buffered_default(&frame, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (declared 100 GiB > 5 GiB cap), got {err:?}",
    );

    // Explicit, much tighter cap.
    let err2 = decrypt_chunked_buffered(&frame, &kr, 1024 * 1024).unwrap_err();
    assert!(
        matches!(err2, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge under tighter cap, got {err2:?}",
    );
}
3688
/// An S4E5-magic header with both chunk size and chunk count set to
/// `u32::MAX` must be rejected with `ChunkFrameTooLarge`, by the buffered
/// decoder and by `parse_chunked_header` called with a `usize::MAX` cap.
#[test]
fn s4e6_header_chunk_size_x_chunk_count_overflows_u64() {
    let kr = keyring_single(0x02);
    let key_id_be = 1_u16.to_be_bytes();
    let max_be = u32::MAX.to_be_bytes();

    // magic | algo | key id | reserved | chunk size | chunk count | 4 zero bytes
    let fields: [&[u8]; 7] = [
        SSE_MAGIC_V5,
        &[ALGO_AES_256_GCM],
        &key_id_be,
        &[0u8],
        &max_be,
        &max_be,
        &[0u8; 4],
    ];
    let blob = fields.concat();
    debug_assert_eq!(blob.len(), S4E5_HEADER_BYTES);

    let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge for u64 overflow, got {err:?}",
    );

    let direct = parse_chunked_header(&blob, usize::MAX).unwrap_err();
    assert!(
        matches!(direct, SseError::ChunkFrameTooLarge { .. }),
        "streaming path: expected ChunkFrameTooLarge, got {direct:?}",
    );
}
3729
/// A synthetic frame declaring 100 chunks of 1 MiB, padded with zeros to its
/// full declared length, must get past the size guard under the default cap
/// and fail only at authentication of chunk 0.
#[test]
fn s4e6_header_within_max_body_bytes_passes() {
    let kr = keyring_single(0x03);
    let chunk_size: u32 = 1024 * 1024;
    let chunk_count: u32 = 100;

    let mut frame = synth_s4e6_header(chunk_size, chunk_count);
    // Zero-fill every declared record (per-chunk overhead + chunk bytes).
    let record_len = S4E6_PER_CHUNK_OVERHEAD + chunk_size as usize;
    let body_len = (chunk_count as usize) * record_len;
    let full_len = frame.len() + body_len;
    frame.resize(full_len, 0);

    let err = decrypt_chunked_buffered(&frame, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
        "expected ChunkAuthFailed (guard let it through), got {err:?}",
    );
}
3762
/// Two rejections by the size guard: a bare header declaring 6000 chunks of
/// 1 MiB against the default cap, and a fully padded 100 x 1 MiB frame
/// against an explicit 1 MiB cap. Both must yield `ChunkFrameTooLarge`.
#[test]
fn s4e6_header_exceeds_max_body_bytes_rejected() {
    let kr = keyring_single(0x04);
    let one_mib: u32 = 1024 * 1024;

    // Case A: header only, declared total above the default cap.
    let header_only = synth_s4e6_header(one_mib, 6000);
    let err = decrypt_chunked_buffered(&header_only, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
    assert!(
        matches!(err, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (6 GiB declared > 5 GiB cap), got {err:?}",
    );

    // Case B: body present in full, but the caller's cap is far smaller.
    let chunk_count_b: u32 = 100;
    let mut padded = synth_s4e6_header(one_mib, chunk_count_b);
    let record_len = S4E6_PER_CHUNK_OVERHEAD + one_mib as usize;
    let full_len = padded.len() + (chunk_count_b as usize) * record_len;
    padded.resize(full_len, 0);
    let err_b = decrypt_chunked_buffered(&padded, &kr, 1024 * 1024).unwrap_err();
    assert!(
        matches!(err_b, SseError::ChunkFrameTooLarge { .. }),
        "expected ChunkFrameTooLarge (cap < declared), got {err_b:?}",
    );
}
3799
/// Seeded fuzz smoke test: feeds 100,000 random bodies of 0..=256 bytes to
/// `parse_chunked_header` under a rotating set of size caps and discards the
/// result. The test passes as long as no call panics.
#[test]
fn s4e6_random_header_never_panics() {
    use rand::{Rng, SeedableRng, rngs::StdRng};

    let mut rng = StdRng::seed_from_u64(0xC0FF_EE64_6464_64DE);
    // Caps are used round-robin, one per iteration.
    let caps = [
        0_usize,
        1024,
        1024 * 1024,
        DEFAULT_MAX_BODY_BYTES,
        usize::MAX,
    ];

    for round in 0..100_000_usize {
        let body_len = rng.gen_range(0..=256_usize);
        let mut body = vec![0u8; body_len];
        rng.fill(body.as_mut_slice());

        // For a quarter of the long-enough bodies, stamp a chunked magic so
        // the parser gets past its magic check.
        if body_len >= 4 && rng.gen_bool(0.25) {
            let magic = if rng.gen_bool(0.5) {
                SSE_MAGIC_V5
            } else {
                SSE_MAGIC_V6
            };
            body[..4].copy_from_slice(magic);
        }

        let max_cap = caps[round % caps.len()];
        let _ = parse_chunked_header(&body, max_cap);
    }
}
3850
/// An S4E5-magic header whose chunk size and chunk count are both `u32::MAX`
/// must be rejected by the default buffered decoder with
/// `ChunkFrameTooLarge`.
#[test]
fn s4e5_extreme_overflow_chunk_count_u32_max() {
    let kr = keyring_single(0x05);

    // magic | algo | key id 1 | reserved 0 | chunk size | chunk count | 4 zero bytes
    let blob: Vec<u8> = SSE_MAGIC_V5
        .iter()
        .copied()
        .chain([ALGO_AES_256_GCM])
        .chain(1_u16.to_be_bytes())
        .chain([0u8])
        .chain(u32::MAX.to_be_bytes())
        .chain(u32::MAX.to_be_bytes())
        .chain([0u8; 4])
        .collect();

    let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
    let too_large = matches!(err, SseError::ChunkFrameTooLarge { .. });
    assert!(
        too_large,
        "expected ChunkFrameTooLarge for extreme overflow, got {err:?}",
    );
}
3877}