use arkhe_forge_core::pii::{
compute_aad, AeadKind, DekId, DekMessageCounter, PiiError, PiiType, RotationTrigger,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::marker::PhantomData;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Construction-time options for a [`Dek`].
///
/// `replica_id` is copied into the DEK and written as the first four bytes of
/// every counter-derived AES-GCM / AES-GCM-SIV nonce. Each `Dek` starts its
/// counter at zero, so two DEKs built from the same key material with the
/// same `replica_id` would emit identical nonce sequences; give replicas that
/// share key material distinct ids. The derived `Default` gives
/// `replica_id == 0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DekConfig {
    /// Written big-endian into bytes `0..4` of counter-derived nonces.
    pub replica_id: u32,
}
/// A 256-bit data-encryption key plus the per-key state used to derive
/// deterministic AES-GCM / AES-GCM-SIV nonces.
///
/// Only `material` is wiped by the `Zeroize`/`ZeroizeOnDrop` derives; the
/// non-secret `counter` and `replica_id` are skipped. Because `counter` is a
/// `Cell`, `Dek` is not `Sync` and cannot be shared by reference across
/// threads.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct Dek {
    /// Raw key bytes; zeroized on drop.
    material: [u8; 32],
    /// Next counter value to hand out for a nonce; reaching `u64::MAX` means
    /// the key is exhausted for counter-nonce ciphers.
    #[zeroize(skip)]
    counter: Cell<u64>,
    /// Nonce prefix copied from `DekConfig::replica_id`.
    #[zeroize(skip)]
    replica_id: u32,
}
impl Dek {
    /// Wraps raw key bytes using the default configuration (`replica_id = 0`).
    #[inline]
    #[must_use]
    pub fn from_bytes(material: [u8; 32]) -> Self {
        Self::with_config(material, DekConfig::default())
    }

    /// Wraps raw key bytes, taking the nonce prefix from `config`.
    ///
    /// The per-key message counter always starts at zero.
    #[inline]
    #[must_use]
    pub fn with_config(material: [u8; 32], config: DekConfig) -> Self {
        Self {
            replica_id: config.replica_id,
            counter: Cell::new(0),
            material,
        }
    }

    /// Builds a DEK from a borrowed slice, equivalent to [`Dek::from_bytes`].
    ///
    /// # Errors
    /// Returns [`PiiError::InvalidKeyLength`] unless `bytes` is exactly 32
    /// bytes long.
    pub fn try_from_slice(bytes: &[u8]) -> Result<Self, PiiError> {
        match bytes.len() {
            32 => {
                let mut material = [0u8; 32];
                material.copy_from_slice(bytes);
                Ok(Self::from_bytes(material))
            }
            _ => Err(PiiError::InvalidKeyLength),
        }
    }

    /// Crate-internal view of the raw key bytes, for handing to cipher
    /// constructors.
    #[inline]
    #[must_use]
    #[cfg_attr(
        not(any(feature = "tier-1-kms", feature = "tier-2-multi-kms")),
        allow(dead_code)
    )]
    pub(crate) fn as_bytes(&self) -> &[u8; 32] {
        &self.material
    }

    /// Reserves the next nonce counter value and returns it.
    ///
    /// The stored counter is left untouched once it has reached `u64::MAX`,
    /// so the largest value ever handed out is `u64::MAX - 1`.
    ///
    /// # Errors
    /// [`PiiError::DekExhausted`] when no further value can be reserved.
    #[cfg_attr(not(feature = "tier-2-multi-kms"), allow(dead_code))]
    fn advance_counter(&self) -> Result<u64, PiiError> {
        let current = self.counter.get();
        let next = current.checked_add(1).ok_or(PiiError::DekExhausted)?;
        self.counter.set(next);
        Ok(current)
    }

    #[cfg(all(test, feature = "tier-2-multi-kms"))]
    pub(crate) fn set_counter_for_test(&self, n: u64) {
        self.counter.set(n);
    }

    #[cfg(all(test, feature = "tier-2-multi-kms"))]
    pub(crate) fn get_counter_for_test(&self) -> u64 {
        self.counter.get()
    }
}
/// Builds a 96-bit AES-GCM / AES-GCM-SIV nonce from a replica prefix and a
/// message counter.
///
/// Layout: bytes `0..4` are `replica_id` big-endian; bytes `4..12` are
/// `counter` big-endian.
#[cfg_attr(not(feature = "tier-2-multi-kms"), allow(dead_code))]
#[inline]
fn aes_gcm_nonce_from_counter(replica_id: u32, counter: u64) -> [u8; 12] {
    let mut nonce = [0u8; 12];
    let (replica_part, counter_part) = nonce.split_at_mut(4);
    replica_part.copy_from_slice(&replica_id.to_be_bytes());
    counter_part.copy_from_slice(&counter.to_be_bytes());
    nonce
}
/// Redacting `Debug`: renders as `Dek { .. }` so key material and counter
/// state never appear in `{:?}` output.
impl core::fmt::Debug for Dek {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // No fields are registered on purpose; `finish_non_exhaustive` emits `..`.
        let mut redacted = f.debug_struct("Dek");
        redacted.finish_non_exhaustive()
    }
}
/// Nonce stored inside an [`EncryptedPii`] envelope.
///
/// `X24` carries a 24-byte XChaCha20-Poly1305 nonce; `Short12` carries a
/// 12-byte AES-256-GCM / AES-256-GCM-SIV nonce.
// NOTE(review): serde's derive identifies enum variants by index in compact
// formats such as postcard, so reordering these variants changes the wire format.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NonceBytes {
    /// 24-byte nonce (XChaCha20-Poly1305).
    X24([u8; 24]),
    /// 12-byte nonce (AES-256-GCM and AES-256-GCM-SIV).
    Short12([u8; 12]),
}
impl NonceBytes {
    /// Nonce length in bytes required by `kind`.
    ///
    /// Returns 24 for XChaCha20-Poly1305, 12 for AES-256-GCM and
    /// AES-256-GCM-SIV, and 0 for any other AEAD kind.
    #[inline]
    #[must_use]
    pub fn expected_len(kind: AeadKind) -> usize {
        if matches!(kind, AeadKind::XChaCha20Poly1305) {
            24
        } else if matches!(kind, AeadKind::Aes256Gcm | AeadKind::Aes256GcmSiv) {
            12
        } else {
            0
        }
    }

    /// Borrows the raw nonce bytes, whichever variant holds them.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        match self {
            Self::X24(wide) => wide.as_slice(),
            Self::Short12(short) => short.as_slice(),
        }
    }
}
/// An AEAD-encrypted PII value, tagged at the type level with the plaintext
/// type `T`.
///
/// `T` is recorded only through `PhantomData<fn() -> T>`, so the envelope
/// owns no `T` and its auto traits do not depend on `T`. The numeric
/// `pii_code` is stored explicitly; `new` sets it from `T::PII_CODE`, and it
/// is the field that gets serialized and bound into the AAD on decrypt.
#[derive(Debug, PartialEq, Eq)]
pub struct EncryptedPii<T: PiiType> {
    /// Identifier of the DEK the ciphertext was sealed under (part of the AAD).
    pub dek_id: DekId,
    /// Numeric PII type code (part of the AAD).
    pub pii_code: u16,
    /// AEAD algorithm used to seal `ciphertext` (part of the AAD).
    pub aead_kind: AeadKind,
    /// Nonce used for sealing; its variant must match `aead_kind`.
    pub nonce: NonceBytes,
    /// AEAD output as returned by the cipher's `encrypt` call.
    pub ciphertext: Bytes,
    pub(crate) _marker: PhantomData<fn() -> T>,
}
/// Hand-written rather than derived so that cloning an envelope does not
/// require `T: Clone`; `T` is only a phantom type tag.
impl<T: PiiType> Clone for EncryptedPii<T> {
    fn clone(&self) -> Self {
        // Heap-backed fields are cloned explicitly; the remaining fields
        // (ids, codes, the phantom marker) are `Copy` and come from `*self`.
        Self {
            nonce: self.nonce.clone(),
            ciphertext: self.ciphertext.clone(),
            ..*self
        }
    }
}
/// Serde mirror of [`EncryptedPii`] without the phantom type marker.
///
/// This is the serialized representation of every envelope. The field order
/// below defines the encoded layout, so it must not be reordered.
#[derive(Serialize, Deserialize)]
struct EncryptedPiiWire {
    dek_id: DekId,
    pii_code: u16,
    aead_kind: AeadKind,
    nonce: NonceBytes,
    ciphertext: Bytes,
}
impl<T: PiiType> Serialize for EncryptedPii<T> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
EncryptedPiiWire {
dek_id: self.dek_id,
pii_code: self.pii_code,
aead_kind: self.aead_kind,
nonce: self.nonce.clone(),
ciphertext: self.ciphertext.clone(),
}
.serialize(serializer)
}
}
impl<'de, T: PiiType> Deserialize<'de> for EncryptedPii<T> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let wire = EncryptedPiiWire::deserialize(deserializer)?;
Ok(Self {
dek_id: wire.dek_id,
pii_code: wire.pii_code,
aead_kind: wire.aead_kind,
nonce: wire.nonce,
ciphertext: wire.ciphertext,
_marker: PhantomData,
})
}
}
impl<T: PiiType> EncryptedPii<T> {
    /// Assembles an envelope, stamping `pii_code` with `T::PII_CODE`.
    #[inline]
    #[must_use]
    pub fn new(dek_id: DekId, aead_kind: AeadKind, nonce: NonceBytes, ciphertext: Bytes) -> Self {
        let pii_code = T::PII_CODE;
        Self {
            ciphertext,
            nonce,
            aead_kind,
            pii_code,
            dek_id,
            _marker: PhantomData,
        }
    }

    /// Strips the phantom type tag, keeping every stored field unchanged.
    ///
    /// The stored `pii_code` is carried over verbatim, not re-derived from `T`.
    fn into_raw(self) -> RawEncryptedPii {
        let Self {
            dek_id,
            pii_code,
            aead_kind,
            nonce,
            ciphertext,
            _marker,
        } = self;
        RawEncryptedPii {
            dek_id,
            pii_code,
            aead_kind,
            nonce,
            ciphertext,
        }
    }
}
/// Type-erased copy of an [`EncryptedPii`] envelope (same fields, no phantom
/// `T`), produced by `EncryptedPii::into_raw` and consumed by
/// `CryptoCoordinator::decrypt_raw_under` during DEK rotation.
#[derive(Debug, Clone)]
struct RawEncryptedPii {
    dek_id: DekId,
    pii_code: u16,
    aead_kind: AeadKind,
    nonce: NonceBytes,
    ciphertext: Bytes,
}
/// Encrypts and decrypts PII envelopes with a single, fixed AEAD
/// (`manifest_cipher`).
///
/// Decryption rejects envelopes whose `aead_kind` differs from the manifest
/// cipher. `nonce_source` is read only by the XChaCha20-Poly1305 path, which
/// is compiled only with `tier-1-kms`; hence the conditional `dead_code` allowance.
#[derive(Debug)]
pub struct CryptoCoordinator<N: NonceSource = OsNonceSource> {
    manifest_cipher: AeadKind,
    #[cfg_attr(not(feature = "tier-1-kms"), allow(dead_code))]
    nonce_source: N,
}
/// Supplier of nonce bytes for random-nonce AEADs.
///
/// Implementations must fill the whole of `out`. In this file the coordinator
/// calls it with a 24-byte buffer for XChaCha20-Poly1305 nonces, so the
/// output must never repeat under the same key.
pub trait NonceSource {
    fn fill(&self, out: &mut [u8]);
}
/// Default [`NonceSource`] backed by the OS RNG when `tier-1-kms` is enabled.
/// Without that feature it zero-fills its output.
#[derive(Debug, Default, Clone, Copy)]
pub struct OsNonceSource;
impl NonceSource for OsNonceSource {
    /// Fills `out` from the operating-system RNG.
    #[cfg(feature = "tier-1-kms")]
    fn fill(&self, out: &mut [u8]) {
        use chacha20poly1305::aead::{rand_core::RngCore, OsRng};
        let mut rng = OsRng;
        rng.fill_bytes(out);
    }

    /// Without `tier-1-kms` no RNG is wired up and `out` is zero-filled.
    ///
    /// The coordinator never reads nonces in this configuration, because
    /// XChaCha encryption returns `TierTooLow` first. Direct callers,
    /// however, get all-zero bytes and must not use them as nonces.
    #[cfg(not(feature = "tier-1-kms"))]
    fn fill(&self, out: &mut [u8]) {
        out.fill(0);
    }
}
impl<N: NonceSource> CryptoCoordinator<N> {
    /// Builds a coordinator that seals with `manifest_cipher`.
    ///
    /// `nonce_source` supplies the random 24-byte XChaCha20-Poly1305 nonces.
    /// The AES variants derive their nonces from the DEK's counter and never
    /// consult it.
    #[inline]
    #[must_use]
    pub fn new(manifest_cipher: AeadKind, nonce_source: N) -> Self {
        Self {
            manifest_cipher,
            nonce_source,
        }
    }

    /// The AEAD this coordinator seals with and requires when opening.
    #[inline]
    #[must_use]
    pub fn manifest_cipher(&self) -> AeadKind {
        self.manifest_cipher
    }

    /// Serializes `plaintext` with postcard and seals it under `dek`.
    ///
    /// The AAD is `compute_aad(dek_id, T::PII_CODE, manifest_cipher)`, so
    /// changing any of those envelope fields makes decryption fail.
    ///
    /// The serialized plaintext buffer is zeroized before returning, on both
    /// the success and the error path. This is best-effort: copies left by
    /// reallocation inside the serializer are out of reach.
    ///
    /// # Errors
    /// - `EncryptFailed` if serialization or the AEAD seal fails.
    /// - `TierTooLow` for XChaCha20-Poly1305 without `tier-1-kms`.
    /// - `UnsupportedAead` for AES variants without `tier-2-multi-kms`, or
    ///   for an unrecognised cipher.
    /// - `DekExhausted` when an AES counter nonce can no longer be reserved.
    pub fn encrypt<T: PiiType>(
        &self,
        plaintext: &T,
        dek: &Dek,
        dek_id: DekId,
    ) -> Result<EncryptedPii<T>, PiiError> {
        let aad = compute_aad(&dek_id, T::PII_CODE, self.manifest_cipher);
        let mut pt_bytes = postcard::to_stdvec(plaintext).map_err(|_| PiiError::EncryptFailed)?;
        let sealed = self.encrypt_raw(dek, &aad, &pt_bytes);
        // Wipe the serialized PII before propagating either outcome.
        pt_bytes.as_mut_slice().zeroize();
        let (nonce, ciphertext) = sealed?;
        Ok(EncryptedPii::new(
            dek_id,
            self.manifest_cipher,
            nonce,
            Bytes::from(ciphertext),
        ))
    }

    /// Opens `envelope` with `dek` and postcard-decodes the plaintext as `T`.
    ///
    /// The header checks run before any cryptography. The decrypted byte
    /// buffer is zeroized after decoding, whether or not decoding succeeded.
    ///
    /// # Errors
    /// - `TypeMismatch` if `envelope.pii_code != T::PII_CODE`.
    /// - `CipherDowngrade` if the envelope's AEAD differs from the manifest
    ///   cipher.
    /// - `AadMismatch` on authentication failure, or when the nonce variant
    ///   does not fit the cipher.
    /// - `TierTooLow` / `UnsupportedAead` when the cipher is not compiled in.
    /// - `DecodeFailed` if the plaintext does not decode as `T`.
    pub fn decrypt<T: PiiType>(
        &self,
        envelope: &EncryptedPii<T>,
        dek: &Dek,
    ) -> Result<T, PiiError> {
        if envelope.pii_code != T::PII_CODE {
            return Err(PiiError::TypeMismatch);
        }
        if envelope.aead_kind != self.manifest_cipher {
            return Err(PiiError::CipherDowngrade);
        }
        let aad = compute_aad(&envelope.dek_id, envelope.pii_code, envelope.aead_kind);
        let mut pt = self.decrypt_raw(
            dek,
            envelope.aead_kind,
            &envelope.nonce,
            &aad,
            &envelope.ciphertext,
        )?;
        let decoded = postcard::from_bytes::<T>(&pt).map_err(|_| PiiError::DecodeFailed);
        // The decoded `T` owns its data, so the raw plaintext can be wiped now.
        pt.as_mut_slice().zeroize();
        decoded
    }

    /// Opens a type-erased envelope, rebuilding the AAD from its own stored
    /// fields.
    ///
    /// Unlike [`Self::decrypt`], this does no `pii_code` or manifest-cipher
    /// check; callers are responsible for those. The returned buffer is
    /// plaintext, which the caller should zeroize.
    fn decrypt_raw_under(&self, dek: &Dek, raw: &RawEncryptedPii) -> Result<Vec<u8>, PiiError> {
        let aad = compute_aad(&raw.dek_id, raw.pii_code, raw.aead_kind);
        self.decrypt_raw(dek, raw.aead_kind, &raw.nonce, &aad, &raw.ciphertext)
    }

    /// Seals `plaintext` with the manifest cipher, returning the nonce used
    /// and the AEAD output.
    fn encrypt_raw(
        &self,
        dek: &Dek,
        aad: &[u8; 19],
        plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        match self.manifest_cipher {
            AeadKind::XChaCha20Poly1305 => self.encrypt_xchacha(dek, aad, plaintext),
            AeadKind::Aes256Gcm => self.encrypt_aes_gcm(dek, aad, plaintext),
            AeadKind::Aes256GcmSiv => self.encrypt_aes_gcm_siv(dek, aad, plaintext),
            _ => Err(PiiError::UnsupportedAead),
        }
    }

    /// Opens `ciphertext` with the AEAD named by `kind`. Note that `kind` is
    /// the caller's choice, not necessarily the manifest cipher.
    fn decrypt_raw(
        &self,
        dek: &Dek,
        kind: AeadKind,
        nonce: &NonceBytes,
        aad: &[u8; 19],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        match kind {
            AeadKind::XChaCha20Poly1305 => self.decrypt_xchacha(dek, nonce, aad, ciphertext),
            AeadKind::Aes256Gcm => self.decrypt_aes_gcm(dek, nonce, aad, ciphertext),
            AeadKind::Aes256GcmSiv => self.decrypt_aes_gcm_siv(dek, nonce, aad, ciphertext),
            _ => Err(PiiError::UnsupportedAead),
        }
    }

    /// XChaCha20-Poly1305 seal with a fresh 24-byte nonce from `nonce_source`.
    #[cfg(feature = "tier-1-kms")]
    fn encrypt_xchacha(
        &self,
        dek: &Dek,
        aad: &[u8; 19],
        plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        use chacha20poly1305::aead::{Aead, KeyInit, Payload};
        use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
        let key = Key::from_slice(dek.as_bytes());
        let cipher = XChaCha20Poly1305::new(key);
        let mut nonce_buf = [0u8; 24];
        self.nonce_source.fill(&mut nonce_buf);
        let nonce = XNonce::from_slice(&nonce_buf);
        let ciphertext = cipher
            .encrypt(
                nonce,
                Payload {
                    msg: plaintext,
                    aad,
                },
            )
            .map_err(|_| PiiError::EncryptFailed)?;
        Ok((NonceBytes::X24(nonce_buf), ciphertext))
    }

    /// XChaCha20-Poly1305 is unavailable below `tier-1-kms`.
    #[cfg(not(feature = "tier-1-kms"))]
    fn encrypt_xchacha(
        &self,
        _dek: &Dek,
        _aad: &[u8; 19],
        _plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        Err(PiiError::TierTooLow)
    }

    /// XChaCha20-Poly1305 open. A 12-byte nonce is reported as `AadMismatch`.
    #[cfg(feature = "tier-1-kms")]
    fn decrypt_xchacha(
        &self,
        dek: &Dek,
        nonce: &NonceBytes,
        aad: &[u8; 19],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        use chacha20poly1305::aead::{Aead, KeyInit, Payload};
        use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
        let bytes_24 = match nonce {
            NonceBytes::X24(b) => b,
            NonceBytes::Short12(_) => return Err(PiiError::AadMismatch),
        };
        let key = Key::from_slice(dek.as_bytes());
        let cipher = XChaCha20Poly1305::new(key);
        let nonce = XNonce::from_slice(bytes_24);
        cipher
            .decrypt(
                nonce,
                Payload {
                    msg: ciphertext,
                    aad,
                },
            )
            .map_err(|_| PiiError::AadMismatch)
    }

    /// XChaCha20-Poly1305 is unavailable below `tier-1-kms`.
    #[cfg(not(feature = "tier-1-kms"))]
    fn decrypt_xchacha(
        &self,
        _dek: &Dek,
        _nonce: &NonceBytes,
        _aad: &[u8; 19],
        _ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        Err(PiiError::TierTooLow)
    }

    /// AES-256-GCM seal. The nonce is `replica_id || counter`, and the
    /// counter is reserved from the DEK before sealing.
    #[cfg(feature = "tier-2-multi-kms")]
    fn encrypt_aes_gcm(
        &self,
        dek: &Dek,
        aad: &[u8; 19],
        plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        use aes_gcm::aead::{Aead, KeyInit, Payload};
        use aes_gcm::{Aes256Gcm, Key, Nonce};
        let counter = dek.advance_counter()?;
        let nonce_buf = aes_gcm_nonce_from_counter(dek.replica_id, counter);
        let key = Key::<Aes256Gcm>::from_slice(dek.as_bytes());
        let cipher = Aes256Gcm::new(key);
        let nonce = Nonce::from_slice(&nonce_buf);
        let ciphertext = cipher
            .encrypt(
                nonce,
                Payload {
                    msg: plaintext,
                    aad,
                },
            )
            .map_err(|_| PiiError::EncryptFailed)?;
        Ok((NonceBytes::Short12(nonce_buf), ciphertext))
    }

    /// AES-256-GCM is unavailable without `tier-2-multi-kms`.
    #[cfg(not(feature = "tier-2-multi-kms"))]
    fn encrypt_aes_gcm(
        &self,
        _dek: &Dek,
        _aad: &[u8; 19],
        _plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        Err(PiiError::UnsupportedAead)
    }

    /// AES-256-GCM open. A 24-byte nonce is reported as `AadMismatch`.
    #[cfg(feature = "tier-2-multi-kms")]
    fn decrypt_aes_gcm(
        &self,
        dek: &Dek,
        nonce: &NonceBytes,
        aad: &[u8; 19],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        use aes_gcm::aead::{Aead, KeyInit, Payload};
        use aes_gcm::{Aes256Gcm, Key, Nonce};
        let bytes_12 = match nonce {
            NonceBytes::Short12(b) => b,
            NonceBytes::X24(_) => return Err(PiiError::AadMismatch),
        };
        let key = Key::<Aes256Gcm>::from_slice(dek.as_bytes());
        let cipher = Aes256Gcm::new(key);
        let nonce = Nonce::from_slice(bytes_12);
        cipher
            .decrypt(
                nonce,
                Payload {
                    msg: ciphertext,
                    aad,
                },
            )
            .map_err(|_| PiiError::AadMismatch)
    }

    /// AES-256-GCM is unavailable without `tier-2-multi-kms`.
    #[cfg(not(feature = "tier-2-multi-kms"))]
    fn decrypt_aes_gcm(
        &self,
        _dek: &Dek,
        _nonce: &NonceBytes,
        _aad: &[u8; 19],
        _ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        Err(PiiError::UnsupportedAead)
    }

    /// AES-256-GCM-SIV seal, using the same counter-nonce construction as
    /// AES-GCM.
    #[cfg(feature = "tier-2-multi-kms")]
    fn encrypt_aes_gcm_siv(
        &self,
        dek: &Dek,
        aad: &[u8; 19],
        plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        use aes_gcm_siv::aead::{Aead, KeyInit, Payload};
        use aes_gcm_siv::{Aes256GcmSiv, Key, Nonce};
        let counter = dek.advance_counter()?;
        let nonce_buf = aes_gcm_nonce_from_counter(dek.replica_id, counter);
        let key = Key::<Aes256GcmSiv>::from_slice(dek.as_bytes());
        let cipher = Aes256GcmSiv::new(key);
        let nonce = Nonce::from_slice(&nonce_buf);
        let ciphertext = cipher
            .encrypt(
                nonce,
                Payload {
                    msg: plaintext,
                    aad,
                },
            )
            .map_err(|_| PiiError::EncryptFailed)?;
        Ok((NonceBytes::Short12(nonce_buf), ciphertext))
    }

    /// AES-256-GCM-SIV is unavailable without `tier-2-multi-kms`.
    #[cfg(not(feature = "tier-2-multi-kms"))]
    fn encrypt_aes_gcm_siv(
        &self,
        _dek: &Dek,
        _aad: &[u8; 19],
        _plaintext: &[u8],
    ) -> Result<(NonceBytes, Vec<u8>), PiiError> {
        Err(PiiError::UnsupportedAead)
    }

    /// AES-256-GCM-SIV open. A 24-byte nonce is reported as `AadMismatch`.
    #[cfg(feature = "tier-2-multi-kms")]
    fn decrypt_aes_gcm_siv(
        &self,
        dek: &Dek,
        nonce: &NonceBytes,
        aad: &[u8; 19],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        use aes_gcm_siv::aead::{Aead, KeyInit, Payload};
        use aes_gcm_siv::{Aes256GcmSiv, Key, Nonce};
        let bytes_12 = match nonce {
            NonceBytes::Short12(b) => b,
            NonceBytes::X24(_) => return Err(PiiError::AadMismatch),
        };
        let key = Key::<Aes256GcmSiv>::from_slice(dek.as_bytes());
        let cipher = Aes256GcmSiv::new(key);
        let nonce = Nonce::from_slice(bytes_12);
        cipher
            .decrypt(
                nonce,
                Payload {
                    msg: ciphertext,
                    aad,
                },
            )
            .map_err(|_| PiiError::AadMismatch)
    }

    /// AES-256-GCM-SIV is unavailable without `tier-2-multi-kms`.
    #[cfg(not(feature = "tier-2-multi-kms"))]
    fn decrypt_aes_gcm_siv(
        &self,
        _dek: &Dek,
        _nonce: &NonceBytes,
        _aad: &[u8; 19],
        _ciphertext: &[u8],
    ) -> Result<Vec<u8>, PiiError> {
        Err(PiiError::UnsupportedAead)
    }
}
pub fn rotate_dek<T: PiiType>(
coordinator: &CryptoCoordinator<impl NonceSource>,
old_dek: &Dek,
new_dek: &Dek,
new_dek_id: DekId,
ciphertexts: &mut [EncryptedPii<T>],
counter: &mut DekMessageCounter,
) -> Result<(), PiiError> {
let originals: Vec<EncryptedPii<T>> = ciphertexts.to_vec();
for slot in ciphertexts.iter_mut() {
let raw = slot.clone().into_raw();
let plaintext_bytes = match coordinator.decrypt_raw_under(old_dek, &raw) {
Ok(v) => v,
Err(err) => {
for (target, backup) in ciphertexts.iter_mut().zip(originals.iter()) {
*target = backup.clone();
}
return Err(err);
}
};
let aad = compute_aad(&new_dek_id, T::PII_CODE, coordinator.manifest_cipher);
let (nonce, new_ct) = match coordinator.encrypt_raw(new_dek, &aad, &plaintext_bytes) {
Ok(v) => v,
Err(err) => {
for (target, backup) in ciphertexts.iter_mut().zip(originals.iter()) {
*target = backup.clone();
}
return Err(err);
}
};
*slot = EncryptedPii::new(
new_dek_id,
coordinator.manifest_cipher,
nonce,
Bytes::from(new_ct),
);
counter.record_message();
}
Ok(())
}
/// Returns the rotation recommendation for the key tracked by `counter`.
///
/// This is a thin wrapper that delegates entirely to
/// `DekMessageCounter::rotation_trigger`.
#[inline]
#[must_use]
pub fn rotation_advice(counter: &DekMessageCounter) -> RotationTrigger {
    counter.rotation_trigger()
}
// Unit tests. Tests gated on `tier-1-kms` / `tier-2-multi-kms` compile only
// when the corresponding cargo feature is enabled.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use arkhe_forge_core::pii::ActorHandle;

    // Deterministic NonceSource for tests: byte i of the output is `i & 0xFF`,
    // so XChaCha nonces are reproducible. Never use outside tests.
    #[derive(Clone, Copy, Default)]
    struct FixedNonce;
    impl NonceSource for FixedNonce {
        fn fill(&self, out: &mut [u8]) {
            for (i, byte) in out.iter_mut().enumerate() {
                *byte = (i & 0xFF) as u8;
            }
        }
    }

    // A DEK whose 32 key bytes all equal `byte` (default config, replica_id 0).
    fn make_dek(byte: u8) -> Dek {
        Dek::from_bytes([byte; 32])
    }

    // A DekId whose 16 bytes all equal `byte`.
    fn make_dek_id(byte: u8) -> DekId {
        DekId([byte; 16])
    }

    #[test]
    fn dek_from_bytes_exposes_material_via_crate_accessor() {
        let d = make_dek(0x42);
        assert_eq!(d.as_bytes(), &[0x42u8; 32]);
    }

    #[test]
    fn dek_try_from_slice_rejects_short_key() {
        let err = Dek::try_from_slice(&[0u8; 16]).unwrap_err();
        assert!(matches!(err, PiiError::InvalidKeyLength));
    }

    #[test]
    fn dek_try_from_slice_accepts_32_bytes() {
        let key = [0x77u8; 32];
        let dek = Dek::try_from_slice(&key).unwrap();
        assert_eq!(dek.as_bytes(), &key);
    }

    // The redacting Debug impl must not print key bytes in either hex case.
    #[test]
    fn dek_debug_does_not_expose_material() {
        let d = make_dek(0xAB);
        let s = format!("{:?}", d);
        assert!(!s.contains("AB"), "Debug output must not leak key bytes");
        assert!(!s.contains("ab"));
    }

    #[test]
    fn nonce_bytes_expected_len_matches_kind() {
        assert_eq!(NonceBytes::expected_len(AeadKind::XChaCha20Poly1305), 24);
        assert_eq!(NonceBytes::expected_len(AeadKind::Aes256Gcm), 12);
        assert_eq!(NonceBytes::expected_len(AeadKind::Aes256GcmSiv), 12);
    }

    // Envelope -> postcard -> envelope is lossless, and pii_code survives.
    #[test]
    fn encrypted_pii_wire_layout_roundtrips_through_postcard() {
        let envelope = EncryptedPii::<ActorHandle>::new(
            make_dek_id(0x11),
            AeadKind::XChaCha20Poly1305,
            NonceBytes::X24([0x22; 24]),
            Bytes::from_static(&[0x33; 48]),
        );
        let bytes = postcard::to_stdvec(&envelope).unwrap();
        let back: EncryptedPii<ActorHandle> = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(envelope, back);
        assert_eq!(back.pii_code, ActorHandle::PII_CODE);
    }

    // With no tier features, XChaCha encryption is refused with TierTooLow.
    #[cfg(not(feature = "tier-1-kms"))]
    #[test]
    fn tier0_default_rejects_encryption() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let err = coord
            .encrypt::<ActorHandle>(
                &ActorHandle(b"alice".to_vec()),
                &make_dek(0x00),
                make_dek_id(0x11),
            )
            .unwrap_err();
        assert!(matches!(err, PiiError::TierTooLow));
    }

    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_xchacha_encrypt_decrypt_roundtrip() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let handle = ActorHandle(b"alice".to_vec());
        let dek = make_dek(0xA5);
        let dek_id = make_dek_id(0x11);
        let env = coord.encrypt(&handle, &dek, dek_id).unwrap();
        assert_eq!(env.aead_kind, AeadKind::XChaCha20Poly1305);
        assert_eq!(env.pii_code, ActorHandle::PII_CODE);
        let back: ActorHandle = coord.decrypt(&env, &dek).unwrap();
        assert_eq!(back, handle);
    }

    // dek_id is bound into the AAD, so editing it breaks authentication.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_xchacha_aad_tamper_fails_tag() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let handle = ActorHandle(b"alice".to_vec());
        let dek = make_dek(0x01);
        let mut env = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap();
        env.dek_id = make_dek_id(0x12);
        let err = coord.decrypt::<ActorHandle>(&env, &dek).unwrap_err();
        assert!(matches!(err, PiiError::AadMismatch));
    }

    // Flipping one ciphertext bit must fail authentication.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_ciphertext_tamper_fails_tag() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let handle = ActorHandle(b"alice".to_vec());
        let dek = make_dek(0x03);
        let env = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap();
        let mut ct = env.ciphertext.to_vec();
        if let Some(first) = ct.first_mut() {
            *first ^= 0x01;
        }
        let tampered =
            EncryptedPii::<ActorHandle>::new(env.dek_id, env.aead_kind, env.nonce, Bytes::from(ct));
        let err = coord.decrypt::<ActorHandle>(&tampered, &dek).unwrap_err();
        assert!(matches!(err, PiiError::AadMismatch));
    }

    // A stored pii_code of another PII type is rejected before any crypto.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_wrong_pii_code_rejected_as_type_mismatch() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let handle = ActorHandle(b"alice".to_vec());
        let dek = make_dek(0x07);
        let env = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap();
        let wrong = EncryptedPii::<ActorHandle> {
            dek_id: env.dek_id,
            pii_code: arkhe_forge_core::pii::EntryBody::PII_CODE,
            aead_kind: env.aead_kind,
            nonce: env.nonce,
            ciphertext: env.ciphertext,
            _marker: PhantomData,
        };
        let err = coord.decrypt::<ActorHandle>(&wrong, &dek).unwrap_err();
        assert!(matches!(err, PiiError::TypeMismatch));
    }

    // An envelope naming a cipher other than the manifest cipher is refused.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_aead_downgrade_rejected_by_coordinator_manifest() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let env = EncryptedPii::<ActorHandle>::new(
            make_dek_id(0x11),
            AeadKind::Aes256Gcm,
            NonceBytes::Short12([0u8; 12]),
            Bytes::from_static(&[0u8; 48]),
        );
        let err = coord
            .decrypt::<ActorHandle>(&env, &make_dek(0x00))
            .unwrap_err();
        assert!(matches!(err, PiiError::CipherDowngrade));
    }

    // AES-GCM works only with tier-2; the expectation flips on that feature.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn tier1_aes_gcm_without_tier2_is_unsupported() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let handle = ActorHandle(b"alice".to_vec());
        let out = coord.encrypt(&handle, &make_dek(0x00), make_dek_id(0x11));
        #[cfg(feature = "tier-2-multi-kms")]
        assert!(out.is_ok());
        #[cfg(not(feature = "tier-2-multi-kms"))]
        assert!(matches!(out, Err(PiiError::UnsupportedAead)));
    }

    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn tier2_aes_gcm_roundtrip() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let handle = ActorHandle(b"aes-user".to_vec());
        let dek = make_dek(0x5A);
        let env = coord.encrypt(&handle, &dek, make_dek_id(0x21)).unwrap();
        assert_eq!(env.aead_kind, AeadKind::Aes256Gcm);
        assert!(matches!(env.nonce, NonceBytes::Short12(_)));
        let back: ActorHandle = coord.decrypt(&env, &dek).unwrap();
        assert_eq!(back, handle);
    }

    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn tier2_aes_gcm_siv_roundtrip() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256GcmSiv, FixedNonce);
        let handle = ActorHandle(b"aes-siv-user".to_vec());
        let dek = make_dek(0x7B);
        let env = coord.encrypt(&handle, &dek, make_dek_id(0x22)).unwrap();
        assert_eq!(env.aead_kind, AeadKind::Aes256GcmSiv);
        let back: ActorHandle = coord.decrypt(&env, &dek).unwrap();
        assert_eq!(back, handle);
    }

    // Successive AES-GCM encryptions under one DEK use counters 0 then 1.
    // Bytes 0..4 are the replica_id prefix (0 here).
    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn aes_gcm_nonce_is_deterministic_counter() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let dek = make_dek(0x5A);
        let handle = ActorHandle(b"alice".to_vec());
        let env1 = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap();
        let env2 = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap();
        let NonceBytes::Short12(n1) = &env1.nonce else {
            panic!("AES-GCM always returns Short12");
        };
        let NonceBytes::Short12(n2) = &env2.nonce else {
            panic!("AES-GCM always returns Short12");
        };
        assert_eq!(&n1[0..4], &[0u8; 4], "invocation field zeros");
        assert_eq!(&n1[4..12], &0u64.to_be_bytes());
        assert_eq!(&n2[4..12], &1u64.to_be_bytes());
        assert_ne!(n1, n2);
        assert_eq!(coord.decrypt::<ActorHandle>(&env1, &dek).unwrap(), handle);
        assert_eq!(coord.decrypt::<ActorHandle>(&env2, &dek).unwrap(), handle);
    }

    // Same key material with different replica_ids yields disjoint nonces.
    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn aes_gcm_nonce_honours_dek_replica_id() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let dek_a = Dek::with_config([0xC3; 32], DekConfig { replica_id: 0 });
        let dek_b = Dek::with_config(
            [0xC3; 32],
            DekConfig {
                replica_id: 0xDEAD_BEEF,
            },
        );
        let handle = ActorHandle(b"alice".to_vec());
        let env_a = coord.encrypt(&handle, &dek_a, make_dek_id(0x11)).unwrap();
        let env_b = coord.encrypt(&handle, &dek_b, make_dek_id(0x11)).unwrap();
        let NonceBytes::Short12(na) = &env_a.nonce else {
            panic!("AES-GCM returns Short12");
        };
        let NonceBytes::Short12(nb) = &env_b.nonce else {
            panic!("AES-GCM returns Short12");
        };
        assert_eq!(&na[0..4], &0u32.to_be_bytes());
        assert_eq!(&nb[0..4], &0xDEAD_BEEFu32.to_be_bytes());
        assert_eq!(&na[4..12], &0u64.to_be_bytes());
        assert_eq!(&nb[4..12], &0u64.to_be_bytes());
        assert_ne!(na, nb);
    }

    // AES-GCM-SIV shares the counter construction and advances the DEK counter.
    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn aes_gcm_siv_nonce_is_deterministic_counter() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256GcmSiv, FixedNonce);
        let dek = make_dek(0x7B);
        let handle = ActorHandle(b"siv".to_vec());
        let env1 = coord.encrypt(&handle, &dek, make_dek_id(0x22)).unwrap();
        let NonceBytes::Short12(n1) = &env1.nonce else {
            panic!("AES-GCM-SIV returns Short12");
        };
        assert_eq!(&n1[4..12], &0u64.to_be_bytes());
        assert_eq!(dek.get_counter_for_test(), 1);
    }

    // At u64::MAX the DEK refuses to encrypt and the counter stays put.
    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn dek_counter_exhaustion_errors() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let dek = make_dek(0xA5);
        dek.set_counter_for_test(u64::MAX);
        let handle = ActorHandle(b"alice".to_vec());
        let err = coord.encrypt(&handle, &dek, make_dek_id(0x11)).unwrap_err();
        assert!(matches!(err, PiiError::DekExhausted));
        assert_eq!(dek.get_counter_for_test(), u64::MAX);
    }

    // Rotation advances only the new DEK's counter, starting from 0.
    #[cfg(feature = "tier-2-multi-kms")]
    #[test]
    fn rotate_dek_starts_new_counter_from_zero() {
        let coord = CryptoCoordinator::new(AeadKind::Aes256Gcm, FixedNonce);
        let old = make_dek(0x10);
        let new = make_dek(0x20);
        let new_id = make_dek_id(0x02);
        let plaintexts: Vec<ActorHandle> = (0..3u8).map(|i| ActorHandle(vec![i; 8])).collect();
        let mut envs: Vec<EncryptedPii<ActorHandle>> = plaintexts
            .iter()
            .map(|pt| coord.encrypt(pt, &old, make_dek_id(0x01)).unwrap())
            .collect();
        assert_eq!(old.get_counter_for_test(), 3);
        assert_eq!(new.get_counter_for_test(), 0);
        let mut rotation_metric = DekMessageCounter::new(new_id);
        rotate_dek(&coord, &old, &new, new_id, &mut envs, &mut rotation_metric).unwrap();
        assert_eq!(new.get_counter_for_test(), 3);
        assert_eq!(rotation_metric.count(), 3);
        for (i, env) in envs.iter().enumerate() {
            let NonceBytes::Short12(n) = &env.nonce else {
                panic!("AES-GCM returns Short12");
            };
            assert_eq!(
                &n[4..12],
                &(i as u64).to_be_bytes(),
                "counter values run 0,1,2 under new DEK"
            );
        }
    }

    // After rotation every envelope carries the new id and opens under the new key.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn dek_rotate_preserves_plaintext() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let old = make_dek(0x10);
        let new = make_dek(0x20);
        let new_id = make_dek_id(0x02);
        let plaintexts: Vec<ActorHandle> = (0..4u8).map(|i| ActorHandle(vec![i; 8])).collect();
        let mut envelopes: Vec<EncryptedPii<ActorHandle>> = plaintexts
            .iter()
            .map(|pt| coord.encrypt(pt, &old, make_dek_id(0x01)).unwrap())
            .collect();
        let mut counter = DekMessageCounter::new(make_dek_id(0x02));
        rotate_dek(&coord, &old, &new, new_id, &mut envelopes, &mut counter).unwrap();
        assert_eq!(counter.count(), 4);
        for (env, pt) in envelopes.iter().zip(plaintexts.iter()) {
            assert_eq!(env.dek_id, new_id);
            assert_eq!(&coord.decrypt::<ActorHandle>(env, &new).unwrap(), pt);
        }
    }

    // A failed rotation leaves the input envelopes untouched.
    #[cfg(feature = "tier-1-kms")]
    #[test]
    fn dek_rotate_with_wrong_old_key_rolls_back() {
        let coord = CryptoCoordinator::new(AeadKind::XChaCha20Poly1305, FixedNonce);
        let real_old = make_dek(0x10);
        let wrong_old = make_dek(0xFF);
        let new = make_dek(0x20);
        let original_envelope = coord
            .encrypt(
                &ActorHandle(b"alice".to_vec()),
                &real_old,
                make_dek_id(0x01),
            )
            .unwrap();
        let mut envelopes = vec![original_envelope.clone()];
        let mut counter = DekMessageCounter::new(make_dek_id(0x02));
        let err = rotate_dek(
            &coord,
            &wrong_old,
            &new,
            make_dek_id(0x02),
            &mut envelopes,
            &mut counter,
        )
        .unwrap_err();
        assert!(matches!(err, PiiError::AadMismatch));
        assert_eq!(envelopes[0], original_envelope);
    }
}