use blake3::Hasher;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Identifier of a data-encryption key (DEK).
///
/// Serialized transparently as its inner 16-byte array, and copied verbatim into the
/// first 16 bytes of the AEAD associated data built by `compute_aad`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DekId(pub [u8; 16]);
/// AEAD cipher used to seal a PII payload.
///
/// The explicit `u8` discriminants are written into the last byte of the associated
/// data by `compute_aad`, so they are bound into every ciphertext and must never be
/// renumbered or reused. Serde encodes variants by name or declaration index (not by
/// discriminant), so the declaration order is significant for non-self-describing
/// formats as well. `#[non_exhaustive]` leaves room for additional ciphers.
#[non_exhaustive]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AeadKind {
    /// XChaCha20-Poly1305; AAD kind byte `0`.
    XChaCha20Poly1305 = 0,
    /// AES-256-GCM; AAD kind byte `1`.
    Aes256Gcm = 1,
    /// AES-256-GCM-SIV; AAD kind byte `2`.
    Aes256GcmSiv = 2,
}
/// Canonical PII type codes, bound to the canonical payload types through
/// `PiiType::PII_CODE`.
///
/// Each code is written big-endian into bytes 16..18 of the AAD by `compute_aad`, so the
/// values are persistent and must not change. Canonical codes stay below
/// `SHELL_PII_CODE_RANGE` (which starts at `0x0100`); that range is reserved for
/// shell-registered types.
pub mod pii_code {
    /// Code for the `ActorHandle` payload type.
    pub const ACTOR_HANDLE: u16 = 0x0001;
    /// Code for the `EntryBody` payload type.
    pub const ENTRY_BODY: u16 = 0x0002;
    /// Code for the `ActivityExtraBytes` payload type.
    pub const ACTIVITY_EXTRA_BYTES: u16 = 0x0003;
    /// Code for the `AuthCredentialSecret` payload type.
    pub const AUTH_CREDENTIAL_SECRET: u16 = 0x0004;
}
/// Sensitivity level attached to a shell-registered PII type (`ShellPiiType::sensitivity`).
///
/// Variants carry explicit `u8` discriminants under `#[repr(u8)]`; serde encodes them by
/// name or declaration index, so neither the order nor the values should change.
/// NOTE(review): how each level maps to handling policy is not visible in this file.
#[non_exhaustive]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Sensitivity {
    /// Lowest sensitivity level (`0`).
    Low = 0,
    /// Intermediate sensitivity level (`1`).
    Medium = 1,
    /// Highest sensitivity level (`2`).
    High = 2,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShellPiiType {
pii_code: u16,
sensitivity: Sensitivity,
aead_kind: AeadKind,
}
pub const SHELL_PII_CODE_RANGE: core::ops::RangeInclusive<u16> = 0x0100..=0xFFFF;
impl ShellPiiType {
pub fn new(
pii_code: u16,
sensitivity: Sensitivity,
aead_kind: AeadKind,
) -> Result<Self, PiiError> {
if !SHELL_PII_CODE_RANGE.contains(&pii_code) {
return Err(PiiError::ShellPiiCodeOutOfRange);
}
Ok(Self {
pii_code,
sensitivity,
aead_kind,
})
}
#[inline]
#[must_use]
pub fn pii_code(&self) -> u16 {
self.pii_code
}
#[inline]
#[must_use]
pub fn sensitivity(&self) -> Sensitivity {
self.sensitivity
}
#[inline]
#[must_use]
pub fn aead_kind(&self) -> AeadKind {
self.aead_kind
}
}
/// Registry of shell-defined PII types, keyed by `pii_code`.
///
/// Backed by a `BTreeMap`, so lookups are O(log n) and iteration runs in ascending
/// code order.
#[derive(Debug, Default)]
pub struct ShellPiiRegistry {
    entries: BTreeMap<u16, ShellPiiType>,
}

impl ShellPiiRegistry {
    /// Creates an empty registry.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `entry` under its `pii_code`.
    ///
    /// # Errors
    ///
    /// - [`PiiError::ShellPiiCodeOutOfRange`] if the code is outside
    ///   [`SHELL_PII_CODE_RANGE`]. This is re-checked here as defense in depth, because
    ///   in-crate code can construct a `ShellPiiType` without going through its `new`.
    /// - [`PiiError::ShellPiiAlreadyRegistered`] if the code is already taken. The existing
    ///   registration is left untouched.
    pub fn register(&mut self, entry: ShellPiiType) -> Result<(), PiiError> {
        use std::collections::btree_map::Entry;

        if !SHELL_PII_CODE_RANGE.contains(&entry.pii_code) {
            return Err(PiiError::ShellPiiCodeOutOfRange);
        }
        // One tree walk: the entry API both tests for and claims the slot.
        match self.entries.entry(entry.pii_code) {
            Entry::Occupied(_) => Err(PiiError::ShellPiiAlreadyRegistered),
            Entry::Vacant(slot) => {
                slot.insert(entry);
                Ok(())
            }
        }
    }

    /// Looks up the registered type for `pii_code`, if any.
    #[inline]
    #[must_use]
    pub fn get(&self, pii_code: u16) -> Option<&ShellPiiType> {
        self.entries.get(&pii_code)
    }

    /// Iterates over the registered types in ascending `pii_code` order.
    pub fn iter(&self) -> impl Iterator<Item = &ShellPiiType> + '_ {
        self.entries.values()
    }

    /// Number of registered types.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when nothing has been registered.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Builds the 19-byte AEAD associated data that binds a ciphertext to its context.
///
/// Layout (fixed width, no separators needed):
/// - bytes `0..16`: the DEK id, verbatim;
/// - bytes `16..18`: `pii_code`, big-endian;
/// - byte `18`: the `AeadKind` discriminant.
#[must_use]
pub fn compute_aad(dek_id: &DekId, pii_code: u16, aead_kind: AeadKind) -> [u8; 19] {
    let [code_hi, code_lo] = pii_code.to_be_bytes();
    let mut out = [0u8; 19];
    let (key_id_part, trailer) = out.split_at_mut(16);
    key_id_part.copy_from_slice(&dek_id.0);
    trailer.copy_from_slice(&[code_hi, code_lo, aead_kind as u8]);
    out
}
/// Per-user 16-byte salt mixed into `compute_body_hash`.
///
/// The bytes are wiped when the value is dropped (`ZeroizeOnDrop`), and `Debug` prints
/// no salt material.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct UserSalt([u8; 16]);

impl UserSalt {
    /// Wraps raw salt bytes.
    ///
    /// `[u8; 16]` is `Copy`, so any copy the caller still holds is not wiped by this type.
    #[inline]
    #[must_use]
    pub fn from_bytes(bytes: [u8; 16]) -> Self {
        UserSalt(bytes)
    }

    /// Borrows the raw salt bytes.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; 16] {
        let UserSalt(inner) = self;
        inner
    }
}

impl core::fmt::Debug for UserSalt {
    /// Redacted rendering: `UserSalt { .. }`, with no field values.
    fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = out.debug_struct("UserSalt");
        builder.finish_non_exhaustive()
    }
}
/// BLAKE3 digest of `body || user_salt || entry_nonce`.
///
/// The salt and nonce are fixed-width trailing fields, so the concatenation is
/// unambiguous even though `body` has variable length. Changing the input order would
/// change every previously computed hash.
#[must_use]
pub fn compute_body_hash(body: &[u8], user_salt: &UserSalt, entry_nonce: &[u8; 16]) -> [u8; 32] {
    let digest = Hasher::new()
        .update(body)
        .update(user_salt.as_bytes())
        .update(entry_nonce)
        .finalize();
    *digest.as_bytes()
}
/// Running count of messages recorded against a single DEK, mapped onto a
/// [`RotationTrigger`].
///
/// The count saturates at `u64::MAX` instead of wrapping.
#[derive(Debug, Clone)]
pub struct DekMessageCounter {
    dek_id: DekId,
    count: u64,
}

/// Rotation recommendation derived from a [`DekMessageCounter`].
///
/// Thresholds: warn from `2^30` recorded messages, must rotate from `2^31`.
/// NOTE(review): the rationale for these limits is not stated here; presumably a
/// conservative nonce-usage bound for the supported AEADs. Confirm before changing.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotationTrigger {
    /// Fewer than `2^30` messages recorded.
    Healthy,
    /// At least `2^30` but fewer than `2^31` messages recorded.
    WarnApproachingLimit,
    /// `2^31` or more messages recorded.
    MustRotate,
}

impl DekMessageCounter {
    /// Starts a fresh counter (zero messages) for `dek_id`.
    pub fn new(dek_id: DekId) -> Self {
        DekMessageCounter { count: 0, dek_id }
    }

    /// The DEK this counter tracks.
    #[must_use]
    pub fn dek_id(&self) -> DekId {
        self.dek_id
    }

    /// Messages recorded so far (saturated at `u64::MAX`).
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Records one more message; saturates rather than overflowing.
    pub fn record_message(&mut self) {
        let next = self.count.saturating_add(1);
        self.count = next;
    }

    /// Classifies the current count against the warn / force thresholds.
    #[must_use]
    pub fn rotation_trigger(&self) -> RotationTrigger {
        const WARN_AT: u64 = 1 << 30;
        const ROTATE_AT: u64 = 1 << 31;

        if self.count < WARN_AT {
            return RotationTrigger::Healthy;
        }
        if self.count < ROTATE_AT {
            return RotationTrigger::WarnApproachingLimit;
        }
        RotationTrigger::MustRotate
    }
}
mod pii_seal {
pub trait Sealed {}
}
pub trait PiiType: pii_seal::Sealed + Serialize + for<'de> Deserialize<'de> {
const PII_CODE: u16;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ActorHandle(pub Vec<u8>);
impl pii_seal::Sealed for ActorHandle {}
impl PiiType for ActorHandle {
const PII_CODE: u16 = pii_code::ACTOR_HANDLE;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EntryBody(pub Vec<u8>);
impl pii_seal::Sealed for EntryBody {}
impl PiiType for EntryBody {
const PII_CODE: u16 = pii_code::ENTRY_BODY;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ActivityExtraBytes(pub Vec<u8>);
impl pii_seal::Sealed for ActivityExtraBytes {}
impl PiiType for ActivityExtraBytes {
const PII_CODE: u16 = pii_code::ACTIVITY_EXTRA_BYTES;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthCredentialSecret(pub Vec<u8>);
impl pii_seal::Sealed for AuthCredentialSecret {}
impl PiiType for AuthCredentialSecret {
const PII_CODE: u16 = pii_code::AUTH_CREDENTIAL_SECRET;
}
/// Errors produced by the PII encryption layer.
///
/// No variant carries payload data, so the `Display` text cannot leak PII.
/// `#[non_exhaustive]`: downstream matches need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum PiiError {
    /// Encryption refused: the compliance tier is too low.
    #[error("compliance tier too low for encryption")]
    TierTooLow,
    /// The PII type marker did not match the expected type.
    #[error("PII type marker mismatch")]
    TypeMismatch,
    /// A request to use a weaker AEAD cipher was rejected.
    #[error("AEAD cipher downgrade rejected")]
    CipherDowngrade,
    /// AEAD tag verification failed.
    ///
    /// A mismatched AAD (DEK id, PII code or cipher kind, as laid out by `compute_aad`)
    /// is one cause; AEAD cannot distinguish that from a tampered ciphertext.
    #[error("AEAD tag verification failed")]
    AadMismatch,
    /// A payload failed to decode.
    #[error("payload decode failed")]
    DecodeFailed,
    /// DEK material was not exactly 32 bytes long.
    #[error("DEK must be exactly 32 bytes")]
    InvalidKeyLength,
    /// The AEAD encrypt operation failed.
    #[error("AEAD encrypt failed")]
    EncryptFailed,
    /// The requested `AeadKind` is unsupported with the enabled feature set.
    #[error("AEAD kind unsupported for current feature set")]
    UnsupportedAead,
    /// The DEK's nonce counter is exhausted, so the key must be rotated.
    #[error("DEK nonce counter exhausted; rotation required")]
    DekExhausted,
    /// Returned by `ShellPiiType::new` and `ShellPiiRegistry::register` for codes
    /// outside `SHELL_PII_CODE_RANGE`.
    #[error("shell PII code outside the 0x0100..=0xFFFF range")]
    ShellPiiCodeOutOfRange,
    /// Returned by `ShellPiiRegistry::register` when the code is already registered.
    #[error("shell PII code already registered")]
    ShellPiiAlreadyRegistered,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn aad_is_exactly_19_bytes_and_deterministic() {
        let dek_id = DekId([0xAB; 16]);
        let aad1 = compute_aad(&dek_id, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        let aad2 = compute_aad(&dek_id, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        assert_eq!(aad1.len(), 19);
        assert_eq!(aad1, aad2);
    }

    #[test]
    fn aad_differs_on_dek_id_change() {
        let first = DekId([0u8; 16]);
        let second = DekId([1u8; 16]);
        let a = compute_aad(&first, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        let b = compute_aad(&second, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        assert_ne!(a, b);
    }

    #[test]
    fn aad_differs_on_pii_code_change() {
        let dek_id = DekId([0u8; 16]);
        let a = compute_aad(&dek_id, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        let b = compute_aad(&dek_id, pii_code::ENTRY_BODY, AeadKind::XChaCha20Poly1305);
        assert_ne!(a, b);
    }

    #[test]
    fn aad_differs_on_aead_kind_change() {
        let dek_id = DekId([0u8; 16]);
        let a = compute_aad(&dek_id, pii_code::ACTOR_HANDLE, AeadKind::XChaCha20Poly1305);
        let b = compute_aad(&dek_id, pii_code::ACTOR_HANDLE, AeadKind::Aes256Gcm);
        assert_ne!(a, b);
    }

    #[test]
    fn aad_layout_dek_first_then_code_then_kind() {
        let dek_id = DekId([0x11; 16]);
        let aad = compute_aad(&dek_id, 0xABCD, AeadKind::Aes256GcmSiv);
        assert_eq!(&aad[..16], &[0x11; 16]);
        assert_eq!(&aad[16..18], &[0xAB, 0xCD]);
        assert_eq!(aad[18], AeadKind::Aes256GcmSiv as u8);
    }

    #[test]
    fn aead_kind_discriminants_are_stable() {
        // These bytes are bound into existing ciphertexts via the AAD.
        assert_eq!(AeadKind::XChaCha20Poly1305 as u8, 0);
        assert_eq!(AeadKind::Aes256Gcm as u8, 1);
        assert_eq!(AeadKind::Aes256GcmSiv as u8, 2);
    }

    #[test]
    fn canonical_codes_are_bound_and_outside_shell_range() {
        assert_eq!(ActorHandle::PII_CODE, pii_code::ACTOR_HANDLE);
        assert_eq!(EntryBody::PII_CODE, pii_code::ENTRY_BODY);
        assert_eq!(ActivityExtraBytes::PII_CODE, pii_code::ACTIVITY_EXTRA_BYTES);
        assert_eq!(AuthCredentialSecret::PII_CODE, pii_code::AUTH_CREDENTIAL_SECRET);
        let canonical = [
            pii_code::ACTOR_HANDLE,
            pii_code::ENTRY_BODY,
            pii_code::ACTIVITY_EXTRA_BYTES,
            pii_code::AUTH_CREDENTIAL_SECRET,
        ];
        for code in canonical.iter() {
            assert!(
                !SHELL_PII_CODE_RANGE.contains(code),
                "canonical code {:#06x} overlaps the shell range",
                code
            );
        }
    }

    #[test]
    fn body_hash_deterministic_and_32_bytes() {
        let body = b"hello world";
        let salt = UserSalt::from_bytes([0x01; 16]);
        let nonce = [0x02; 16];
        let h1 = compute_body_hash(body, &salt, &nonce);
        let h2 = compute_body_hash(body, &salt, &nonce);
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 32);
    }

    #[test]
    fn body_hash_differs_on_salt_change() {
        let body = b"x";
        let s1 = UserSalt::from_bytes([0x01; 16]);
        let s2 = UserSalt::from_bytes([0x02; 16]);
        let nonce = [0u8; 16];
        assert_ne!(
            compute_body_hash(body, &s1, &nonce),
            compute_body_hash(body, &s2, &nonce)
        );
    }

    #[test]
    fn body_hash_differs_on_nonce_change() {
        let salt = UserSalt::from_bytes([0x01; 16]);
        assert_ne!(
            compute_body_hash(b"x", &salt, &[0u8; 16]),
            compute_body_hash(b"x", &salt, &[1u8; 16])
        );
    }

    #[test]
    fn body_hash_differs_on_body_change() {
        let salt = UserSalt::from_bytes([0x01; 16]);
        let nonce = [0u8; 16];
        assert_ne!(
            compute_body_hash(b"x", &salt, &nonce),
            compute_body_hash(b"y", &salt, &nonce)
        );
    }

    #[test]
    fn user_salt_debug_does_not_leak_material() {
        let salt = UserSalt::from_bytes([0xBB; 16]);
        // A derived `Debug` would print the bytes in decimal (`187`), so checking only
        // for hex spellings would pass vacuously. Cover both, in plain and pretty form.
        for rendered in [format!("{:?}", salt), format!("{:#?}", salt)].iter() {
            assert!(rendered.starts_with("UserSalt"));
            for needle in ["BB", "bb", "187"].iter() {
                assert!(!rendered.contains(needle), "salt bytes leaked: {}", rendered);
            }
        }
    }

    #[test]
    fn shell_pii_type_rejects_canonical_code() {
        let err = ShellPiiType::new(0x0050, Sensitivity::Medium, AeadKind::XChaCha20Poly1305)
            .unwrap_err();
        assert!(matches!(err, PiiError::ShellPiiCodeOutOfRange));
    }

    #[test]
    fn shell_pii_type_accepts_shell_code() {
        let t = ShellPiiType::new(0x0200, Sensitivity::High, AeadKind::Aes256Gcm).unwrap();
        assert_eq!(t.pii_code(), 0x0200);
        assert_eq!(t.sensitivity(), Sensitivity::High);
        assert_eq!(t.aead_kind(), AeadKind::Aes256Gcm);
    }

    #[test]
    fn shell_pii_type_range_bounds_are_inclusive() {
        assert!(ShellPiiType::new(0x0100, Sensitivity::Low, AeadKind::Aes256Gcm).is_ok());
        assert!(ShellPiiType::new(0xFFFF, Sensitivity::Low, AeadKind::Aes256Gcm).is_ok());
        assert!(matches!(
            ShellPiiType::new(0x00FF, Sensitivity::Low, AeadKind::Aes256Gcm),
            Err(PiiError::ShellPiiCodeOutOfRange)
        ));
    }

    #[test]
    fn shell_pii_registry_rejects_duplicate() {
        let mut reg = ShellPiiRegistry::new();
        let a = ShellPiiType::new(0x0300, Sensitivity::Low, AeadKind::XChaCha20Poly1305).unwrap();
        reg.register(a).unwrap();
        let dup = ShellPiiType::new(0x0300, Sensitivity::Medium, AeadKind::Aes256Gcm).unwrap();
        let err = reg.register(dup).unwrap_err();
        assert!(matches!(err, PiiError::ShellPiiAlreadyRegistered));
        assert_eq!(reg.len(), 1);
        // The original registration must survive the rejected duplicate.
        assert_eq!(reg.get(0x0300).unwrap().sensitivity(), Sensitivity::Low);
    }

    #[test]
    fn shell_pii_registry_lookup_returns_entry() {
        let mut reg = ShellPiiRegistry::new();
        assert!(reg.is_empty());
        let entry = ShellPiiType::new(0x0400, Sensitivity::Medium, AeadKind::Aes256GcmSiv).unwrap();
        reg.register(entry).unwrap();
        let found = reg.get(0x0400).expect("registered marker");
        assert_eq!(found.aead_kind(), AeadKind::Aes256GcmSiv);
        assert!(reg.get(0x0401).is_none());
    }

    #[test]
    fn shell_pii_registry_rejects_out_of_range() {
        let mut reg = ShellPiiRegistry::new();
        let leaked = ShellPiiType {
            pii_code: 0x00FF,
            sensitivity: Sensitivity::Medium,
            aead_kind: AeadKind::XChaCha20Poly1305,
        };
        let err = reg.register(leaked).unwrap_err();
        assert!(matches!(err, PiiError::ShellPiiCodeOutOfRange));
        assert!(reg.is_empty());
    }

    #[test]
    fn rotation_trigger_transitions() {
        let dek_id = DekId([0u8; 16]);
        let mut c = DekMessageCounter::new(dek_id);
        assert_eq!(c.rotation_trigger(), RotationTrigger::Healthy);
        c.count = 1u64 << 30;
        assert_eq!(c.rotation_trigger(), RotationTrigger::WarnApproachingLimit);
        c.count = 1u64 << 31;
        assert_eq!(c.rotation_trigger(), RotationTrigger::MustRotate);
    }

    #[test]
    fn rotation_trigger_boundaries() {
        let mut c = DekMessageCounter::new(DekId([0u8; 16]));
        c.count = (1u64 << 30) - 1;
        assert_eq!(c.rotation_trigger(), RotationTrigger::Healthy);
        c.count = (1u64 << 31) - 1;
        assert_eq!(c.rotation_trigger(), RotationTrigger::WarnApproachingLimit);
        c.count = u64::MAX;
        assert_eq!(c.rotation_trigger(), RotationTrigger::MustRotate);
    }

    #[test]
    fn record_message_increments_and_keeps_dek_id() {
        let dek_id = DekId([0x42; 16]);
        let mut c = DekMessageCounter::new(dek_id);
        assert_eq!(c.count(), 0);
        c.record_message();
        c.record_message();
        assert_eq!(c.count(), 2);
        assert_eq!(c.dek_id(), dek_id);
    }

    #[test]
    fn counter_saturates_at_u64_max() {
        let dek_id = DekId([0u8; 16]);
        let mut c = DekMessageCounter::new(dek_id);
        c.count = u64::MAX;
        c.record_message();
        assert_eq!(c.count(), u64::MAX);
    }
}