1use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
47use sha2::{Digest, Sha256};
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
50use crate::encryption::{KEY_LENGTH, NONCE_LENGTH, TAG_LENGTH};
51use crate::{CryptoError, EncryptionKey};
52
/// Symmetric key material scoped to a single field.
///
/// Wraps exactly `KEY_LENGTH` raw bytes. The `Zeroize` / `ZeroizeOnDrop`
/// derives wipe those bytes when the value is dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct FieldKey([u8; KEY_LENGTH]);
69
70impl FieldKey {
71 pub(crate) fn from_derived_bytes(bytes: [u8; KEY_LENGTH]) -> Self {
79 Self(bytes)
80 }
81
82 pub fn from_bytes(bytes: &[u8; KEY_LENGTH]) -> Self {
86 Self(*bytes)
87 }
88
89 pub fn as_bytes(&self) -> &[u8; KEY_LENGTH] {
91 &self.0
92 }
93
94 pub fn derive(parent_key: &EncryptionKey, field_name: &str) -> Self {
114 let mut hasher = Sha256::new();
116 hasher.update(parent_key.to_bytes());
117 hasher.update(b"field:");
118 hasher.update(field_name.as_bytes());
119 let derived: [u8; 32] = hasher.finalize().into();
120
121 Self::from_derived_bytes(derived)
122 }
123}
124
125impl std::fmt::Debug for FieldKey {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(f, "FieldKey([REDACTED])")
129 }
130}
131
132pub fn encrypt_field(key: &FieldKey, plaintext: &[u8]) -> Vec<u8> {
149 let mut nonce_bytes = [0u8; NONCE_LENGTH];
151 getrandom::fill(&mut nonce_bytes).expect("CSPRNG failure is catastrophic");
152
153 let cipher = Aes256Gcm::new_from_slice(&key.0).expect("KEY_LENGTH is always valid");
154 let nonce = aes_gcm::Nonce::from(nonce_bytes);
155
156 let ciphertext = cipher
157 .encrypt(&nonce, plaintext)
158 .expect("AES-GCM encryption cannot fail with valid inputs");
159
160 let mut result = Vec::with_capacity(NONCE_LENGTH + ciphertext.len());
162 result.extend_from_slice(&nonce_bytes);
163 result.extend_from_slice(&ciphertext);
164 result
165}
166
167pub fn decrypt_field(key: &FieldKey, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
176 const MIN_SIZE: usize = NONCE_LENGTH + TAG_LENGTH;
178 if ciphertext.len() < MIN_SIZE {
179 return Err(CryptoError::DecryptionError);
180 }
181
182 let (nonce_bytes, encrypted) = ciphertext.split_at(NONCE_LENGTH);
183 let nonce_array: [u8; NONCE_LENGTH] = nonce_bytes
184 .try_into()
185 .expect("split_at guarantees correct length");
186 let nonce = aes_gcm::Nonce::from(nonce_array);
187
188 let cipher = Aes256Gcm::new_from_slice(&key.0).expect("KEY_LENGTH is always valid");
189
190 cipher
191 .decrypt(&nonce, encrypted)
192 .map_err(|_| CryptoError::DecryptionError)
193}
194
/// Size of a [`Token`] in bytes.
pub const TOKEN_LENGTH: usize = 32;

/// Fixed-size, copyable token value.
///
/// Equality and hashing are derived over the raw bytes, so tokens can be used
/// as map/set keys.
///
/// NOTE(review): the derived `==` is not guaranteed to run in constant time —
/// confirm that is acceptable wherever tokens are compared against
/// caller-supplied values.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Token([u8; TOKEN_LENGTH]);
216
217impl Token {
218 pub fn from_bytes(bytes: [u8; TOKEN_LENGTH]) -> Self {
220 Self(bytes)
221 }
222
223 pub fn as_bytes(&self) -> &[u8; TOKEN_LENGTH] {
225 &self.0
226 }
227
228 pub fn to_hex(&self) -> String {
230 let mut s = String::with_capacity(TOKEN_LENGTH * 2);
231 for byte in &self.0 {
232 use std::fmt::Write;
233 write!(s, "{byte:02x}").expect("formatting cannot fail");
234 }
235 s
236 }
237}
238
239impl std::fmt::Debug for Token {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 write!(f, "Token({}...)", &self.to_hex()[..16])
242 }
243}
244
245impl std::fmt::Display for Token {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 write!(f, "{}", self.to_hex())
248 }
249}
250
251pub fn tokenize(key: &FieldKey, value: &[u8]) -> Token {
280 let mut hasher = Sha256::new();
283 hasher.update(key.0);
284 hasher.update(value);
285 let hash: [u8; 32] = hasher.finalize().into();
286
287 Token(hash)
288}
289
290pub fn matches_token(key: &FieldKey, value: &[u8], token: &Token) -> bool {
295 tokenize(key, value) == *token
296}
297
/// A token paired with an encrypted copy of the value it was computed from.
#[derive(Clone)]
pub struct ReversibleToken {
    /// Deterministic token of the value (see `tokenize`).
    pub token: Token,
    /// Output of `encrypt_field` for the same value.
    pub encrypted: Vec<u8>,
}
314
315impl ReversibleToken {
316 pub fn create(key: &FieldKey, value: &[u8]) -> Self {
321 let token = tokenize(key, value);
322 let encrypted = encrypt_field(key, value);
323 Self { token, encrypted }
324 }
325
326 pub fn reveal(&self, key: &FieldKey) -> Result<Vec<u8>, CryptoError> {
332 decrypt_field(key, &self.encrypted)
333 }
334}
335
336impl std::fmt::Debug for ReversibleToken {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 f.debug_struct("ReversibleToken")
339 .field("token", &self.token)
340 .field("encrypted_len", &self.encrypted.len())
341 .finish()
342 }
343}
344
// Unit, property-based (proptest) and parameterised (test_case) tests for
// field key derivation, field encryption, tokenization and reversible tokens.
#[cfg(test)]
mod tests {
    use super::*;

    // Deriving twice from the same parent and field name yields equal bytes.
    #[test]
    fn field_key_derivation_is_deterministic() {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, "patient_ssn");
        let key2 = FieldKey::derive(&parent, "patient_ssn");

        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    // The field name is part of the derivation input.
    #[test]
    fn different_field_names_produce_different_keys() {
        let parent = EncryptionKey::generate();
        let ssn_key = FieldKey::derive(&parent, "ssn");
        let dob_key = FieldKey::derive(&parent, "dob");

        assert_ne!(ssn_key.as_bytes(), dob_key.as_bytes());
    }

    // Decrypting the output of encrypt_field returns the original bytes.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "test_field");

        let plaintext = b"sensitive data";
        let ciphertext = encrypt_field(&key, plaintext);
        let decrypted = decrypt_field(&key, &ciphertext).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    // Each encryption uses a fresh random nonce, so outputs differ.
    #[test]
    fn randomized_encryption_produces_different_ciphertext() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "test_field");

        let plaintext = b"sensitive data";
        let ct1 = encrypt_field(&key, plaintext);
        let ct2 = encrypt_field(&key, plaintext);

        assert_ne!(ct1, ct2);

        // Both ciphertexts must still decrypt to the same plaintext.
        assert_eq!(decrypt_field(&key, &ct1).unwrap(), plaintext);
        assert_eq!(decrypt_field(&key, &ct2).unwrap(), plaintext);
    }

    // A key derived from a different parent cannot decrypt.
    #[test]
    fn wrong_key_fails_decryption() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent1, "field");
        let key2 = FieldKey::derive(&parent2, "field");

        let ciphertext = encrypt_field(&key1, b"secret");
        let result = decrypt_field(&key2, &ciphertext);

        assert!(result.is_err());
    }

    // Same key and value always give the same token.
    #[test]
    fn tokenization_is_deterministic() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "ssn");

        let token1 = tokenize(&key, b"123-45-6789");
        let token2 = tokenize(&key, b"123-45-6789");

        assert_eq!(token1, token2);
    }

    // Different values under one key give different tokens.
    #[test]
    fn different_values_produce_different_tokens() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "ssn");

        let token1 = tokenize(&key, b"123-45-6789");
        let token2 = tokenize(&key, b"987-65-4321");

        assert_ne!(token1, token2);
    }

    // The same value under different field keys gives different tokens.
    #[test]
    fn different_keys_produce_different_tokens() {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, "field1");
        let key2 = FieldKey::derive(&parent, "field2");

        let token1 = tokenize(&key1, b"same value");
        let token2 = tokenize(&key2, b"same value");

        assert_ne!(token1, token2);
    }

    // matches_token accepts the original value and rejects another one.
    #[test]
    fn matches_token_works() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let value = b"test value";
        let token = tokenize(&key, value);

        assert!(matches_token(&key, value, &token));
        assert!(!matches_token(&key, b"other value", &token));
    }

    // A reversible token carries the plain token and reveals the value.
    #[test]
    fn reversible_token_roundtrip() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let value = b"original value";
        let rt = ReversibleToken::create(&key, value);

        assert_eq!(rt.token, tokenize(&key, value));

        let revealed = rt.reveal(&key).unwrap();
        assert_eq!(revealed, value);
    }

    // 32 bytes of 0xab render as 64 hex characters, all 'a' or 'b'.
    #[test]
    fn token_hex_formatting() {
        let token = Token::from_bytes([0xab; 32]);
        let hex = token.to_hex();

        assert_eq!(hex.len(), 64);
        assert!(hex.chars().all(|c| c == 'a' || c == 'b'));
    }

    // An empty plaintext round-trips to an empty result.
    #[test]
    fn empty_plaintext_encrypts() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let encrypted = encrypt_field(&key, b"");
        let decrypted = decrypt_field(&key, &encrypted).unwrap();

        assert!(decrypted.is_empty());
    }

    use proptest::prelude::*;

    // Property-based versions of the checks above, run over generated inputs.
    proptest! {
        // Derivation is deterministic for generated field names.
        #[test]
        fn prop_field_key_derivation_deterministic(field_name in "\\PC{1,100}") {
            let parent = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent, &field_name);
            let key2 = FieldKey::derive(&parent, &field_name);

            prop_assert_eq!(key1.as_bytes(), key2.as_bytes());
        }

        // Two distinct generated field names give distinct keys.
        #[test]
        fn prop_different_fields_different_keys(
            fields in ("\\PC{1,100}", "\\PC{1,100}").prop_filter("fields must differ", |(f1, f2)| f1 != f2),
        ) {
            let (field1, field2) = fields;
            let parent = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent, &field1);
            let key2 = FieldKey::derive(&parent, &field2);

            prop_assert_ne!(key1.as_bytes(), key2.as_bytes());
        }

        // Round-trip holds for byte vectors of length 0 up to 9999.
        #[test]
        fn prop_field_encrypt_decrypt_roundtrip(
            plaintext in prop::collection::vec(any::<u8>(), 0..10000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "test_field");

            let encrypted = encrypt_field(&key, &plaintext);
            let decrypted = decrypt_field(&key, &encrypted)
                .expect("decryption should succeed");

            prop_assert_eq!(decrypted, plaintext);
        }

        // Two encryptions of one plaintext differ yet both decrypt correctly.
        #[test]
        fn prop_randomized_encryption_differs(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let ct1 = encrypt_field(&key, &plaintext);
            let ct2 = encrypt_field(&key, &plaintext);

            let decrypted1 = decrypt_field(&key, &ct1).unwrap();
            let decrypted2 = decrypt_field(&key, &ct2).unwrap();
            prop_assert_eq!(&decrypted1[..], &plaintext[..]);
            prop_assert_eq!(&decrypted2[..], &plaintext[..]);

            prop_assert_ne!(ct1, ct2);
        }

        // Tokenization is deterministic for generated values.
        #[test]
        fn prop_tokenization_deterministic(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let token1 = tokenize(&key, &plaintext);
            let token2 = tokenize(&key, &plaintext);

            prop_assert_eq!(token1, token2);
        }

        // Two distinct generated values give distinct tokens.
        #[test]
        fn prop_different_values_different_tokens(
            values in (prop::collection::vec(any::<u8>(), 1..1000), prop::collection::vec(any::<u8>(), 1..1000))
                .prop_filter("values must differ", |(v1, v2)| v1 != v2),
        ) {
            let (value1, value2) = values;
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let token1 = tokenize(&key, &value1);
            let token2 = tokenize(&key, &value2);

            prop_assert_ne!(token1, token2);
        }

        // matches_token accepts the value a token was computed from.
        #[test]
        fn prop_matches_token_consistent(
            value in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let token = tokenize(&key, &value);

            prop_assert!(matches_token(&key, &value, &token));
        }

        // matches_token rejects a value other than the tokenized one.
        #[test]
        fn prop_matches_token_rejects_different(
            values in (prop::collection::vec(any::<u8>(), 1..1000), prop::collection::vec(any::<u8>(), 1..1000))
                .prop_filter("values must differ", |(v1, v2)| v1 != v2),
        ) {
            let (value1, value2) = values;
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let token1 = tokenize(&key, &value1);

            prop_assert!(!matches_token(&key, &value2, &token1));
        }

        // Reversible tokens round-trip for generated values.
        #[test]
        fn prop_reversible_token_roundtrip(
            value in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let key = FieldKey::derive(&parent, "field");

            let rt = ReversibleToken::create(&key, &value);

            prop_assert_eq!(rt.token, tokenize(&key, &value));

            let revealed = rt.reveal(&key).expect("reveal should succeed");
            prop_assert_eq!(revealed, value);
        }

        // A key rebuilt via as_bytes/from_bytes behaves like the original
        // for both tokenization and decryption.
        #[test]
        fn prop_field_key_serialization(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent = EncryptionKey::generate();
            let original = FieldKey::derive(&parent, "field");

            let bytes = original.as_bytes();
            let restored = FieldKey::from_bytes(bytes);

            let token1 = tokenize(&original, &plaintext);
            let token2 = tokenize(&restored, &plaintext);
            prop_assert_eq!(token1, token2);

            let encrypted = encrypt_field(&original, &plaintext);
            let decrypted = decrypt_field(&restored, &encrypted)
                .expect("restored key should decrypt");
            prop_assert_eq!(decrypted, plaintext);
        }

        // Decrypting with a key from another parent fails for generated data.
        #[test]
        fn prop_wrong_key_fails_decryption(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let parent1 = EncryptionKey::generate();
            let parent2 = EncryptionKey::generate();
            let key1 = FieldKey::derive(&parent1, "field");
            let key2 = FieldKey::derive(&parent2, "field");

            let encrypted = encrypt_field(&key1, &plaintext);
            let result = decrypt_field(&key2, &encrypted);

            prop_assert!(result.is_err(), "wrong key must fail decryption");
        }
    }

    use test_case::test_case;

    // Determinism check repeated over a fixed set of field names, including
    // the empty string and a long name.
    #[test_case("ssn"; "social security number")]
    #[test_case("patient_id"; "patient identifier")]
    #[test_case("email_address"; "email")]
    #[test_case(""; "empty field name")]
    #[test_case("a"; "single char")]
    #[test_case("very_long_field_name_with_many_underscores_and_characters_to_test_edge_cases"; "long name")]
    fn field_key_derivation_various_names(field_name: &str) {
        let parent = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent, field_name);
        let key2 = FieldKey::derive(&parent, field_name);

        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    // The parent key is part of the derivation input.
    #[test]
    fn different_parent_keys_produce_different_field_keys() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();

        let key1 = FieldKey::derive(&parent1, "same_field");
        let key2 = FieldKey::derive(&parent2, "same_field");

        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    // An all-zero token renders as 64 ASCII hex digits.
    #[test]
    fn token_hex_length() {
        let token = Token::from_bytes([0u8; 32]);
        let hex = token.to_hex();

        assert_eq!(hex.len(), 64); assert!(hex.chars().all(|c| c.is_ascii_hexdigit()));
    }

    // Round-trip of a 10 MiB buffer filled with 0xCC.
    #[test]
    fn large_plaintext_field_encryption() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "large_field");

        let plaintext = vec![0xCC; 10 * 1024 * 1024];

        let encrypted = encrypt_field(&key, &plaintext);
        let decrypted = decrypt_field(&key, &encrypted).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    // Flipping one bit of the first byte (part of the nonce prefix written by
    // encrypt_field) must make decryption fail.
    #[test]
    fn corrupted_field_ciphertext_fails() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let encrypted = encrypt_field(&key, b"sensitive data");
        let mut corrupted = encrypted.clone();

        if !corrupted.is_empty() {
            corrupted[0] ^= 0x01;
        }

        let result = decrypt_field(&key, &corrupted);
        assert!(result.is_err(), "corrupted field ciphertext must fail");
    }

    // Tokenizing an empty value still yields a 32-byte token that matches
    // only the empty value.
    #[test]
    fn token_from_empty_value() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let token = tokenize(&key, b"");

        assert_eq!(token.as_bytes().len(), 32);

        assert!(matches_token(&key, b"", &token));
        assert!(!matches_token(&key, b"non-empty", &token));
    }

    // reveal with a key from another parent returns an error.
    #[test]
    fn reversible_token_wrong_key_fails() {
        let parent1 = EncryptionKey::generate();
        let parent2 = EncryptionKey::generate();
        let key1 = FieldKey::derive(&parent1, "field");
        let key2 = FieldKey::derive(&parent2, "field");

        let value = b"secret value";
        let rt = ReversibleToken::create(&key1, value);

        let result = rt.reveal(&key2);
        assert!(result.is_err(), "wrong key must fail to reveal token");
    }

    // Every possible byte value 0..=255 survives a round-trip.
    #[test]
    fn field_encryption_preserves_binary_data() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "binary_field");

        let binary_data: Vec<u8> = (0..=255).collect();

        let encrypted = encrypt_field(&key, &binary_data);
        let decrypted = decrypt_field(&key, &encrypted).unwrap();

        assert_eq!(decrypted, binary_data);
    }

    // 1000 distinct values produce 1000 distinct tokens.
    #[test]
    fn tokenization_collision_resistance() {
        let parent = EncryptionKey::generate();
        let key = FieldKey::derive(&parent, "field");

        let mut tokens = std::collections::HashSet::new();

        for i in 0..1000 {
            let value = format!("value_{i}");
            let token = tokenize(&key, value.as_bytes());
            // HashSet::insert returns false when the token was already present.
            assert!(
                tokens.insert(token),
                "token collision detected for different values"
            );
        }
    }

    // A key rebuilt from its bytes produces the same token as the original.
    #[test]
    fn field_key_bytes_roundtrip() {
        let parent = EncryptionKey::generate();
        let original = FieldKey::derive(&parent, "test");

        let bytes = original.as_bytes();
        let restored = FieldKey::from_bytes(bytes);

        let test_value = b"test value";
        let token1 = tokenize(&original, test_value);
        let token2 = tokenize(&restored, test_value);

        assert_eq!(token1, token2);
    }
}