use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::encryption::{KEY_LENGTH, NONCE_LENGTH, TAG_LENGTH};
use crate::{CryptoError, EncryptionKey};
/// A 256-bit symmetric key scoped to a single logical field.
///
/// The key bytes are wiped when the value is dropped (`ZeroizeOnDrop`). The same
/// key is used both for AES-256-GCM field encryption ([`encrypt_field`] /
/// [`decrypt_field`]) and for deterministic tokenization ([`tokenize`]).
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct FieldKey([u8; KEY_LENGTH]);

impl FieldKey {
    /// Wraps key bytes produced by an in-crate derivation.
    pub(crate) fn from_derived_bytes(bytes: [u8; KEY_LENGTH]) -> Self {
        Self(bytes)
    }

    /// Rebuilds a field key from raw bytes, e.g. ones previously exported with
    /// [`FieldKey::as_bytes`].
    ///
    /// The bytes are copied; the caller's own buffer is not wiped.
    pub fn from_bytes(bytes: &[u8; KEY_LENGTH]) -> Self {
        Self(*bytes)
    }

    /// Returns a reference to the raw key bytes.
    ///
    /// Any copy a caller makes from this reference is not covered by this
    /// type's zeroize-on-drop behavior.
    pub fn as_bytes(&self) -> &[u8; KEY_LENGTH] {
        &self.0
    }

    /// Derives a per-field key as `SHA-256(parent_key || "field:" || field_name)`.
    ///
    /// The derivation is deterministic: the same parent key and field name always
    /// produce the same field key. It is a plain hash rather than HKDF. Changing the
    /// construction would make every existing ciphertext and token under derived keys
    /// unusable, so it must stay byte-for-byte stable.
    pub fn derive(parent_key: &EncryptionKey, field_name: &str) -> Self {
        let mut hasher = Sha256::new();
        hasher.update(parent_key.to_bytes());
        hasher.update(b"field:");
        hasher.update(field_name.as_bytes());
        let mut derived: [u8; KEY_LENGTH] = hasher.finalize().into();
        let key = Self::from_derived_bytes(derived);
        // `[u8; N]` is `Copy`, so `derived` is still a live copy of the secret on
        // the stack after being handed to `from_derived_bytes`. Wipe it so the
        // `ZeroizeOnDrop` owner is the only holder of the key material.
        derived.zeroize();
        key
    }
}
/// Redacting `Debug`: always renders the fixed text `FieldKey([REDACTED])`, so
/// a key appearing in `{:?}` output never exposes its bytes.
impl std::fmt::Debug for FieldKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("FieldKey([REDACTED])")
    }
}
/// Encrypts `plaintext` with AES-256-GCM under `key` using a fresh random nonce.
///
/// The returned buffer is the nonce (`NONCE_LENGTH` bytes) followed by the
/// AES-GCM output (ciphertext with its authentication tag). Because the nonce is
/// random, encrypting the same plaintext twice yields different buffers.
/// No associated data is authenticated. Binding a ciphertext to a particular
/// record or field, if needed, must happen elsewhere.
///
/// # Panics
/// Panics if the operating-system random number generator fails.
pub fn encrypt_field(key: &FieldKey, plaintext: &[u8]) -> Vec<u8> {
    let mut nonce = [0u8; NONCE_LENGTH];
    getrandom::fill(&mut nonce).expect("CSPRNG failure is catastrophic");

    let sealed = Aes256Gcm::new_from_slice(&key.0)
        .expect("KEY_LENGTH is always valid")
        .encrypt(&aes_gcm::Nonce::from(nonce), plaintext)
        .expect("AES-GCM encryption cannot fail with valid inputs");

    // Envelope layout: nonce || ciphertext-with-tag.
    [nonce.as_slice(), sealed.as_slice()].concat()
}
/// Decrypts a buffer produced by [`encrypt_field`] (`nonce || ciphertext-with-tag`).
///
/// # Errors
/// Returns [`CryptoError::DecryptionError`] in two cases:
/// - the input is shorter than a nonce plus a tag;
/// - AES-GCM authentication fails (wrong key, or tampered/truncated data).
pub fn decrypt_field(key: &FieldKey, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
    // Anything shorter than nonce + tag cannot be a well-formed envelope.
    if ciphertext.len() < NONCE_LENGTH + TAG_LENGTH {
        return Err(CryptoError::DecryptionError);
    }

    let (nonce_part, sealed) = ciphertext.split_at(NONCE_LENGTH);
    let nonce = aes_gcm::Nonce::from(
        <[u8; NONCE_LENGTH]>::try_from(nonce_part).expect("split_at guarantees correct length"),
    );

    Aes256Gcm::new_from_slice(&key.0)
        .expect("KEY_LENGTH is always valid")
        .decrypt(&nonce, sealed)
        .map_err(|_| CryptoError::DecryptionError)
}
/// Length in bytes of a [`Token`] (the size of one SHA-256 digest).
pub const TOKEN_LENGTH: usize = 32;

/// A fixed-size, deterministic stand-in for a field value.
///
/// `Display` renders the full lowercase hex encoding. `Debug` renders only the
/// first 8 bytes (16 hex digits) followed by `...`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Token([u8; TOKEN_LENGTH]);

impl Token {
    /// Wraps raw token bytes, e.g. ones loaded back from storage.
    pub fn from_bytes(bytes: [u8; TOKEN_LENGTH]) -> Self {
        Self(bytes)
    }

    /// Returns the raw token bytes.
    pub fn as_bytes(&self) -> &[u8; TOKEN_LENGTH] {
        &self.0
    }

    /// Encodes the token as `2 * TOKEN_LENGTH` lowercase hexadecimal characters.
    pub fn to_hex(&self) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        self.0
            .iter()
            .fold(String::with_capacity(TOKEN_LENGTH * 2), |mut hex, &byte| {
                hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
                hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
                hex
            })
    }
}

impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only an 8-byte prefix is shown, which keeps debug output short.
        f.write_str("Token(")?;
        for byte in &self.0[..8] {
            write!(f, "{byte:02x}")?;
        }
        f.write_str("...)")
    }
}

impl std::fmt::Display for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.to_hex())
    }
}
pub fn tokenize(key: &FieldKey, value: &[u8]) -> Token {
let mut hasher = Sha256::new();
hasher.update(key.0);
hasher.update(value);
let hash: [u8; 32] = hasher.finalize().into();
Token(hash)
}
/// Returns `true` if `value` tokenizes to `token` under `key`.
///
/// The derived `PartialEq` on `Token` may stop at the first differing byte,
/// which would leak through timing how many leading bytes matched. This function
/// therefore compares by OR-accumulating the XOR of every byte pair, so the same
/// amount of work is done wherever a mismatch occurs.
///
/// `std::hint::black_box` is a best-effort barrier against the optimizer
/// reintroducing an early exit; it is not a formal constant-time guarantee.
pub fn matches_token(key: &FieldKey, value: &[u8], token: &Token) -> bool {
    let candidate = tokenize(key, value);
    let difference = candidate
        .0
        .iter()
        .zip(token.0.iter())
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    std::hint::black_box(difference) == 0
}
/// A deterministic [`Token`] paired with an AES-GCM encryption of the same value.
///
/// The token supports equality matching; the ciphertext lets a key holder
/// recover the original value.
#[derive(Clone)]
pub struct ReversibleToken {
    /// Deterministic token of the original value (see [`tokenize`]).
    pub token: Token,
    /// `nonce || ciphertext-with-tag` as produced by [`encrypt_field`].
    pub encrypted: Vec<u8>,
}

impl ReversibleToken {
    /// Tokenizes and encrypts `value`, both under the same field key.
    pub fn create(key: &FieldKey, value: &[u8]) -> Self {
        Self {
            token: tokenize(key, value),
            encrypted: encrypt_field(key, value),
        }
    }

    /// Decrypts and returns the original value.
    ///
    /// # Errors
    /// Returns [`CryptoError::DecryptionError`] if `key` is wrong or the stored
    /// ciphertext is truncated or has been tampered with.
    ///
    /// NOTE(review): the recovered plaintext is not re-checked against
    /// `self.token`, so a ciphertext swapped in from another record under the same
    /// key would be revealed without complaint.
    pub fn reveal(&self, key: &FieldKey) -> Result<Vec<u8>, CryptoError> {
        decrypt_field(key, &self.encrypted)
    }
}

impl std::fmt::Debug for ReversibleToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The ciphertext is summarized by its length only; the token uses its own
        // truncated `Debug` rendering.
        let encrypted_len = self.encrypted.len();
        f.debug_struct("ReversibleToken")
            .field("token", &self.token)
            .field("encrypted_len", &encrypted_len)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    //! Unit, table-driven and property tests for field-key derivation, field
    //! encryption, tokenization and their `Debug` redaction.

    use super::*;
    use proptest::prelude::*;
    use test_case::test_case;

    #[test]
    fn field_key_derivation_is_deterministic() {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, "patient_ssn");
        let key2 = FieldKey::derive(&parent, "patient_ssn");
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn different_field_names_produce_different_keys() {
        let parent = EncryptionKey::generate();
        let ssn_key = FieldKey::derive(&parent, "ssn");
        let dob_key = FieldKey::derive(&parent, "dob");
        assert_ne!(ssn_key.as_bytes(), dob_key.as_bytes());
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "test_field");
        let plaintext = b"sensitive data";
        let ciphertext = encrypt_field(&key, plaintext);
        let decrypted = decrypt_field(&key, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn randomized_encryption_produces_different_ciphertext() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "test_field");
        let plaintext = b"sensitive data";
        let ct1 = encrypt_field(&key, plaintext);
        let ct2 = encrypt_field(&key, plaintext);
        assert_ne!(ct1, ct2);
        assert_eq!(decrypt_field(&key, &ct1).unwrap(), plaintext);
        assert_eq!(decrypt_field(&key, &ct2).unwrap(), plaintext);
    }

    #[test]
    fn wrong_key_fails_decryption() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent1, "field");
        let key2 = FieldKey::derive(&parent2, "field");
        let ciphertext = encrypt_field(&key1, b"secret");
        let result = decrypt_field(&key2, &ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn tokenization_is_deterministic() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "ssn");
        let token1 = tokenize(&key, b"123-45-6789");
        let token2 = tokenize(&key, b"123-45-6789");
        assert_eq!(token1, token2);
    }

    #[test]
    fn different_values_produce_different_tokens() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "ssn");
        let token1 = tokenize(&key, b"123-45-6789");
        let token2 = tokenize(&key, b"987-65-4321");
        assert_ne!(token1, token2);
    }

    #[test]
    fn different_keys_produce_different_tokens() {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, "field1");
        let key2 = FieldKey::derive(&parent, "field2");
        let token1 = tokenize(&key1, b"same value");
        let token2 = tokenize(&key2, b"same value");
        assert_ne!(token1, token2);
    }

    #[test]
    fn matches_token_works() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let value = b"test value";
        let token = tokenize(&key, value);
        assert!(matches_token(&key, value, &token));
        assert!(!matches_token(&key, b"other value", &token));
    }

    #[test]
    fn reversible_token_roundtrip() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let value = b"original value";
        let rt = ReversibleToken::create(&key, value);
        assert_eq!(rt.token, tokenize(&key, value));
        let revealed = rt.reveal(&key).unwrap();
        assert_eq!(revealed, value);
    }

    #[test]
    fn token_hex_formatting() {
        let token = Token::from_bytes([0xab; TOKEN_LENGTH]);
        assert_eq!(token.to_hex(), "ab".repeat(TOKEN_LENGTH));
    }

    #[test]
    fn empty_plaintext_encrypts() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let encrypted = encrypt_field(&key, b"");
        // An empty plaintext still carries a full nonce and authentication tag.
        assert_eq!(encrypted.len(), NONCE_LENGTH + TAG_LENGTH);
        let decrypted = decrypt_field(&key, &encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn truncated_ciphertext_is_rejected() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        assert!(matches!(
            decrypt_field(&key, &[]),
            Err(CryptoError::DecryptionError)
        ));
        let too_short = vec![0u8; NONCE_LENGTH + TAG_LENGTH - 1];
        assert!(matches!(
            decrypt_field(&key, &too_short),
            Err(CryptoError::DecryptionError)
        ));
        let encrypted = encrypt_field(&key, b"x");
        assert!(decrypt_field(&key, &encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn debug_output_does_not_leak_secrets() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        assert_eq!(format!("{key:?}"), "FieldKey([REDACTED])");

        let token = Token::from_bytes([0xab; TOKEN_LENGTH]);
        assert_eq!(format!("{token:?}"), "Token(abababababababab...)");
        assert_eq!(token.to_string(), token.to_hex());

        let rt = ReversibleToken::create(&key, b"secret value");
        let rendered = format!("{rt:?}");
        assert!(rendered.contains(&format!("encrypted_len: {}", rt.encrypted.len())));
        assert!(!rendered.contains("secret value"));
    }

    proptest! {
        #[test]
        fn prop_field_key_derivation_deterministic(field_name in "\\PC{1,100}") {
            let parent = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent, &field_name);
            let key2 = FieldKey::derive(&parent, &field_name);
            prop_assert_eq!(key1.as_bytes(), key2.as_bytes());
        }

        #[test]
        fn prop_different_fields_different_keys(
            fields in ("\\PC{1,100}", "\\PC{1,100}").prop_filter("fields must differ", |(f1, f2)| f1 != f2),
        ) {
            let (field1, field2) = fields;
            let parent = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent, &field1);
            let key2 = FieldKey::derive(&parent, &field2);
            prop_assert_ne!(key1.as_bytes(), key2.as_bytes());
        }

        #[test]
        fn prop_field_encrypt_decrypt_roundtrip(
            plaintext in prop::collection::vec(any::<u8>(), 0..10000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "test_field");
            let encrypted = encrypt_field(&key, &plaintext);
            let decrypted = decrypt_field(&key, &encrypted)
                .expect("decryption should succeed");
            prop_assert_eq!(decrypted, plaintext);
        }

        #[test]
        fn prop_randomized_encryption_differs(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let ct1 = encrypt_field(&key, &plaintext);
            let ct2 = encrypt_field(&key, &plaintext);
            let decrypted1 = decrypt_field(&key, &ct1).unwrap();
            let decrypted2 = decrypt_field(&key, &ct2).unwrap();
            prop_assert_eq!(&decrypted1[..], &plaintext[..]);
            prop_assert_eq!(&decrypted2[..], &plaintext[..]);
            prop_assert_ne!(ct1, ct2);
        }

        #[test]
        fn prop_tokenization_deterministic(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let token1 = tokenize(&key, &plaintext);
            let token2 = tokenize(&key, &plaintext);
            prop_assert_eq!(token1, token2);
        }

        #[test]
        fn prop_different_values_different_tokens(
            values in (prop::collection::vec(any::<u8>(), 1..1000), prop::collection::vec(any::<u8>(), 1..1000))
                .prop_filter("values must differ", |(v1, v2)| v1 != v2),
        ) {
            let (value1, value2) = values;
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let token1 = tokenize(&key, &value1);
            let token2 = tokenize(&key, &value2);
            prop_assert_ne!(token1, token2);
        }

        #[test]
        fn prop_matches_token_consistent(
            value in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let token = tokenize(&key, &value);
            prop_assert!(matches_token(&key, &value, &token));
        }

        #[test]
        fn prop_matches_token_rejects_different(
            values in (prop::collection::vec(any::<u8>(), 1..1000), prop::collection::vec(any::<u8>(), 1..1000))
                .prop_filter("values must differ", |(v1, v2)| v1 != v2),
        ) {
            let (value1, value2) = values;
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let token1 = tokenize(&key, &value1);
            prop_assert!(!matches_token(&key, &value2, &token1));
        }

        #[test]
        fn prop_reversible_token_roundtrip(
            value in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");
            let rt = ReversibleToken::create(&key, &value);
            prop_assert_eq!(rt.token, tokenize(&key, &value));
            let revealed = rt.reveal(&key).expect("reveal should succeed");
            prop_assert_eq!(revealed, value);
        }

        #[test]
        fn prop_field_key_serialization(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let original = FieldKey::derive(&parent, "field");
            let bytes = original.as_bytes();
            let restored = FieldKey::from_bytes(bytes);
            let token1 = tokenize(&original, &plaintext);
            let token2 = tokenize(&restored, &plaintext);
            prop_assert_eq!(token1, token2);
            let encrypted = encrypt_field(&original, &plaintext);
            let decrypted = decrypt_field(&restored, &encrypted)
                .expect("restored key should decrypt");
            prop_assert_eq!(decrypted, plaintext);
        }

        #[test]
        fn prop_wrong_key_fails_decryption(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent1 = EncryptionKey::generate();
            let parent2 = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent1, "field");
            let key2 = FieldKey::derive(&parent2, "field");
            let encrypted = encrypt_field(&key1, &plaintext);
            let result = decrypt_field(&key2, &encrypted);
            prop_assert!(result.is_err(), "wrong key must fail decryption");
        }
    }

    #[test_case("ssn"; "social security number")]
    #[test_case("patient_id"; "patient identifier")]
    #[test_case("email_address"; "email")]
    #[test_case(""; "empty field name")]
    #[test_case("a"; "single char")]
    #[test_case("very_long_field_name_with_many_underscores_and_characters_to_test_edge_cases"; "long name")]
    fn field_key_derivation_various_names(field_name: &str) {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, field_name);
        let key2 = FieldKey::derive(&parent, field_name);
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn different_parent_keys_produce_different_field_keys() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent1, "same_field");
        let key2 = FieldKey::derive(&parent2, "same_field");
        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn token_hex_length() {
        let token = Token::from_bytes([0u8; TOKEN_LENGTH]);
        let hex = token.to_hex();
        assert_eq!(hex.len(), 2 * TOKEN_LENGTH);
        assert_eq!(hex, "0".repeat(2 * TOKEN_LENGTH));
    }

    #[test]
    fn large_plaintext_field_encryption() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "large_field");
        let plaintext = vec![0xCC; 10 * 1024 * 1024];
        let encrypted = encrypt_field(&key, &plaintext);
        let decrypted = decrypt_field(&key, &encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn corrupted_field_ciphertext_fails() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let encrypted = encrypt_field(&key, b"sensitive data");
        // Flip one bit in the nonce, the first ciphertext byte, and the last tag byte.
        for index in [0, NONCE_LENGTH, encrypted.len() - 1] {
            let mut corrupted = encrypted.clone();
            corrupted[index] ^= 0x01;
            let result = decrypt_field(&key, &corrupted);
            assert!(
                result.is_err(),
                "corrupted field ciphertext must fail (byte {index})"
            );
        }
    }

    #[test]
    fn token_from_empty_value() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let token = tokenize(&key, b"");
        assert_eq!(token.as_bytes().len(), TOKEN_LENGTH);
        assert!(matches_token(&key, b"", &token));
        assert!(!matches_token(&key, b"non-empty", &token));
    }

    #[test]
    fn reversible_token_wrong_key_fails() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent1, "field");
        let key2 = FieldKey::derive(&parent2, "field");
        let value = b"secret value";
        let rt = ReversibleToken::create(&key1, value);
        let result = rt.reveal(&key2);
        assert!(result.is_err(), "wrong key must fail to reveal token");
    }

    #[test]
    fn field_encryption_preserves_binary_data() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "binary_field");
        let binary_data: Vec<u8> = (0..=255).collect();
        let encrypted = encrypt_field(&key, &binary_data);
        let decrypted = decrypt_field(&key, &encrypted).unwrap();
        assert_eq!(decrypted, binary_data);
    }

    #[test]
    fn tokenization_collision_resistance() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");
        let mut tokens = std::collections::HashSet::new();
        for i in 0..1000 {
            let value = format!("value_{i}");
            let token = tokenize(&key, value.as_bytes());
            assert!(
                tokens.insert(token),
                "token collision detected for different values"
            );
        }
    }

    #[test]
    fn field_key_bytes_roundtrip() {
        let parent = EncryptionKey::generate();
        let original = FieldKey::derive(&parent, "test");
        let bytes = original.as_bytes();
        let restored = FieldKey::from_bytes(bytes);
        let test_value = b"test value";
        let token1 = tokenize(&original, test_value);
        let token2 = tokenize(&restored, test_value);
        assert_eq!(token1, token2);
    }
}