use std::collections::HashMap;
use rustrails_support::encryption::{EncryptorError, MessageEncryptor, MessageVerifier};
/// Declares how a single named field participates in encryption.
///
/// A configuration starts out non-deterministic; chaining
/// [`EncryptedFieldConfig::deterministic`] flips the flag. When the flag is
/// set, `EncryptionKeyRing::encrypt_value` additionally emits a blind index
/// for the value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedFieldConfig {
    /// Name of the field this configuration applies to.
    pub field: String,
    /// `true` once deterministic mode has been requested for the field.
    pub deterministic: bool,
}

impl EncryptedFieldConfig {
    /// Creates a configuration for `field` with deterministic mode switched off.
    #[must_use]
    pub fn new(field: &str) -> Self {
        let field = String::from(field);
        Self {
            field,
            deterministic: false,
        }
    }

    /// Consumes the configuration and returns it with deterministic mode enabled.
    ///
    /// The field name is carried over unchanged.
    #[must_use]
    pub fn deterministic(self) -> Self {
        Self {
            deterministic: true,
            ..self
        }
    }
}
/// Starts an encrypted-field declaration for `field`.
///
/// Shorthand for [`EncryptedFieldConfig::new`]: the returned configuration is
/// non-deterministic until `.deterministic()` is chained onto it.
#[must_use]
pub fn encrypts(field: &str) -> EncryptedFieldConfig {
EncryptedFieldConfig::new(field)
}
/// The persisted form of one encrypted field value, as produced by
/// `EncryptionKeyRing::encrypt_value` and consumed by
/// `EncryptionKeyRing::decrypt_value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StoredEncryptedValue {
/// Output of the encryptor's `encrypt_and_sign` call for the plaintext bytes.
pub ciphertext: String,
/// Id of the key ring's active key at the time the value was encrypted.
pub key_id: String,
/// Blind index for the plaintext; `Some` only when the field's configuration
/// was deterministic, `None` otherwise.
pub blind_index: Option<String>,
}
/// Failures reported by `EncryptionKeyRing` operations.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum EncryptionError {
/// The named key id has no entry in the key ring. Carries the missing id.
#[error("encryption key not found: {0}")]
MissingKey(String),
/// The underlying encryptor reported an error. Because of `#[from]`, this is
/// what `?` produces for an `EncryptorError`; it is also used when building a
/// `MessageEncryptor` fails, including on the decryption path.
#[error("encryption failed: {0}")]
Encrypt(#[from] EncryptorError),
/// No candidate key could decrypt and verify the ciphertext. Carries the
/// error from the last key that was tried.
#[error("decryption failed: {0}")]
Decrypt(EncryptorError),
/// Decryption succeeded but the resulting bytes are not valid UTF-8.
#[error("decrypted plaintext is not valid utf-8")]
InvalidUtf8,
}
/// A set of named 32-byte secrets plus the id of the one used for new writes.
///
/// `Debug` is implemented by hand rather than derived so that formatting a key
/// ring (for example in a log line or a panic message) never prints raw key
/// bytes; only the key ids are shown.
#[derive(Clone)]
pub struct EncryptionKeyRing {
    /// Id of the key used when encrypting new values.
    active_key_id: String,
    /// Every known secret, keyed by its id.
    keys: HashMap<String, [u8; 32]>,
}

impl std::fmt::Debug for EncryptionKeyRing {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified; sort so the output is stable.
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        key_ids.sort_unstable();
        f.debug_struct("EncryptionKeyRing")
            .field("active_key_id", &self.active_key_id)
            .field("key_ids", &key_ids)
            // Signals that the secret bytes were deliberately left out.
            .finish_non_exhaustive()
    }
}
impl EncryptionKeyRing {
    /// Builds a key ring from `keys`, using `active_key_id` for new encryptions.
    ///
    /// `active_key_id` is not validated here: if it has no entry in `keys`,
    /// [`Self::encrypt_value`] fails with [`EncryptionError::MissingKey`].
    #[must_use]
    pub fn new(active_key_id: &str, keys: HashMap<String, [u8; 32]>) -> Self {
        Self {
            active_key_id: active_key_id.to_owned(),
            keys,
        }
    }

    /// Encrypts `plaintext` with the active key.
    ///
    /// The result records the active key id, and carries a blind index only
    /// when `config.deterministic` is set.
    ///
    /// # Errors
    ///
    /// * [`EncryptionError::MissingKey`] if the active key id is not in the ring.
    /// * [`EncryptionError::Encrypt`] if the encryptor cannot be built or
    ///   `encrypt_and_sign` fails.
    pub fn encrypt_value(
        &self,
        config: &EncryptedFieldConfig,
        plaintext: &str,
    ) -> Result<StoredEncryptedValue, EncryptionError> {
        let encryptor = self.encryptor(&self.active_key_id)?;
        let ciphertext = encryptor.encrypt_and_sign(plaintext.as_bytes())?;
        let blind_index = config
            .deterministic
            .then(|| self.blind_index_for(&self.active_key_id, plaintext))
            .transpose()?;
        Ok(StoredEncryptedValue {
            ciphertext,
            key_id: self.active_key_id.clone(),
            blind_index,
        })
    }

    /// Decrypts a stored value back into its UTF-8 plaintext.
    ///
    /// The key named by `value.key_id` is tried first; if it fails, every other
    /// key in the ring is tried in sorted key-id order and the first successful
    /// decryption wins.
    ///
    /// # Errors
    ///
    /// * [`EncryptionError::MissingKey`] if `value.key_id` is not in the ring
    ///   (no other key is attempted in that case).
    /// * [`EncryptionError::Encrypt`] if an encryptor cannot be built for a key.
    /// * [`EncryptionError::Decrypt`] if no key can decrypt the ciphertext; the
    ///   wrapped error comes from the last key tried.
    /// * [`EncryptionError::InvalidUtf8`] if the decrypted bytes are not UTF-8.
    pub fn decrypt_value(&self, value: &StoredEncryptedValue) -> Result<String, EncryptionError> {
        if !self.keys.contains_key(&value.key_id) {
            return Err(EncryptionError::MissingKey(value.key_id.clone()));
        }

        let mut last_error = None;
        for key_id in self.prioritized_key_ids(&value.key_id) {
            let encryptor = self.encryptor(key_id)?;
            match encryptor.decrypt_and_verify(&value.ciphertext) {
                Ok(bytes) => {
                    return String::from_utf8(bytes).map_err(|_| EncryptionError::InvalidUtf8);
                }
                Err(error) => last_error = Some(error),
            }
        }

        // The ring holds at least `value.key_id`, so the loop ran and `last_error`
        // is set; the `MissingKey` arm is only a defensive fallback.
        Err(last_error.map_or_else(
            || EncryptionError::MissingKey(value.key_id.clone()),
            EncryptionError::Decrypt,
        ))
    }

    /// Computes the blind index of `plaintext` under every key in the ring.
    ///
    /// The token for the active key comes first, followed by the remaining keys
    /// in sorted key-id order. If the active key id is not present in the ring,
    /// the tokens are simply returned in sorted key-id order.
    ///
    /// # Errors
    ///
    /// Propagates any error from computing an individual blind index.
    pub fn equality_tokens(&self, plaintext: &str) -> Result<Vec<String>, EncryptionError> {
        self.prioritized_key_ids(&self.active_key_id)
            .into_iter()
            .map(|key_id| self.blind_index_for(key_id, plaintext))
            .collect()
    }

    /// Returns every key id in sorted order, with `preferred` moved to the front
    /// when it is present in the ring.
    fn prioritized_key_ids(&self, preferred: &str) -> Vec<&str> {
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        // Map keys are unique, so an unstable sort yields the same order as a stable one.
        key_ids.sort_unstable();
        if let Some(index) = key_ids.iter().position(|key_id| *key_id == preferred) {
            // Rotating the prefix moves `preferred` to index 0 and shifts the
            // ids before it one slot right, keeping their relative order.
            key_ids[..=index].rotate_right(1);
        }
        key_ids
    }

    /// Builds a `MessageEncryptor` for the secret stored under `key_id`.
    ///
    /// Fails with [`EncryptionError::MissingKey`] when the id is unknown;
    /// construction failures are reported as [`EncryptionError::Encrypt`].
    fn encryptor(&self, key_id: &str) -> Result<MessageEncryptor, EncryptionError> {
        let secret = self.secret(key_id)?;
        MessageEncryptor::new(secret).map_err(EncryptionError::Encrypt)
    }

    /// Returns the `MessageVerifier::generate` output for `plaintext` under the
    /// secret stored under `key_id`.
    ///
    /// Fails with [`EncryptionError::MissingKey`] when the id is unknown.
    fn blind_index_for(&self, key_id: &str, plaintext: &str) -> Result<String, EncryptionError> {
        let secret = self.secret(key_id)?;
        Ok(MessageVerifier::new(secret).generate(plaintext.as_bytes()))
    }

    /// Looks up the secret for `key_id`, or reports it as a missing key.
    fn secret(&self, key_id: &str) -> Result<&[u8; 32], EncryptionError> {
        self.keys
            .get(key_id)
            .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))
    }
}
/// Implemented by types that declare which of their fields are encrypted.
pub trait EncryptedAttribute {
/// Returns the type's encrypted-field declarations.
///
/// The default implementation declares no encrypted fields.
fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
&[]
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use super::{
        EncryptedAttribute, EncryptedFieldConfig, EncryptionError, EncryptionKeyRing,
        StoredEncryptedValue, encrypts,
    };

    /// Plaintext used by the tests that exercise a non-deterministic field.
    const EMAIL: &str = "alice@example.com";
    /// Plaintext used by the tests that exercise a deterministic field.
    const SSN: &str = "123-45-6789";

    /// Model stand-in declaring one plain and one deterministic encrypted field.
    struct UserRecord;

    impl EncryptedAttribute for UserRecord {
        fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
            static CONFIGS: LazyLock<Vec<EncryptedFieldConfig>> =
                LazyLock::new(|| vec![encrypts("email"), encrypts("ssn").deterministic()]);
            &CONFIGS[..]
        }
    }

    /// Ring whose active key is "new" and which still holds the rotated "old" key.
    fn keyring() -> EncryptionKeyRing {
        let keys = HashMap::from([
            ("new".to_owned(), [1_u8; 32]),
            ("old".to_owned(), [2_u8; 32]),
        ]);
        EncryptionKeyRing::new("new", keys)
    }

    /// Ring that holds only the "old" key and uses `active_key_id` as its active id.
    fn old_only_keyring(active_key_id: &str) -> EncryptionKeyRing {
        let keys = HashMap::from([("old".to_owned(), [2_u8; 32])]);
        EncryptionKeyRing::new(active_key_id, keys)
    }

    #[test]
    fn encrypts_builder_enables_deterministic_mode() {
        let EncryptedFieldConfig {
            field,
            deterministic,
        } = encrypts("email").deterministic();
        assert!(deterministic);
        assert_eq!(field, "email");
    }

    #[test]
    fn encrypt_value_round_trips_plaintext() {
        let ring = keyring();
        let stored = ring
            .encrypt_value(&encrypts("email"), EMAIL)
            .expect("encryption should succeed");
        let decrypted = ring
            .decrypt_value(&stored)
            .expect("decryption should succeed");
        assert_eq!(decrypted, EMAIL);
    }

    #[test]
    fn deterministic_fields_emit_blind_indexes() {
        let config = encrypts("ssn").deterministic();
        let stored = keyring()
            .encrypt_value(&config, SSN)
            .expect("encryption should succeed");
        assert!(stored.blind_index.is_some());
    }

    #[test]
    fn encrypted_attribute_metadata_preserves_field_order_and_flags() {
        let expected = [encrypts("email"), encrypts("ssn").deterministic()];
        assert_eq!(UserRecord::encrypted_attributes(), &expected);
    }

    #[test]
    fn non_deterministic_fields_do_not_emit_blind_indexes() {
        let config = encrypts("email");
        let stored = keyring()
            .encrypt_value(&config, EMAIL)
            .expect("encryption should succeed");
        assert_eq!(stored.blind_index, None);
    }

    #[test]
    fn deterministic_encryptions_keep_blind_index_stable_but_change_ciphertext() {
        let config = encrypts("ssn").deterministic();
        let encrypt_once = || {
            keyring()
                .encrypt_value(&config, SSN)
                .expect("encryption should succeed")
        };
        let (first, second) = (encrypt_once(), encrypt_once());
        assert_eq!(first.blind_index, second.blind_index);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn equality_tokens_order_active_key_before_rotated_keys() {
        let config = encrypts("ssn").deterministic();
        let blind_index_under = |ring: &EncryptionKeyRing| {
            ring.encrypt_value(&config, SSN)
                .expect("encryption should succeed")
                .blind_index
                .expect("deterministic field should emit a blind index")
        };
        let active_index = blind_index_under(&keyring());
        let rotated_index = blind_index_under(&old_only_keyring("old"));

        let tokens = keyring()
            .equality_tokens(SSN)
            .expect("tokens should generate");
        assert_eq!(tokens, vec![active_index, rotated_index]);
    }

    #[test]
    fn encrypt_value_returns_missing_key_when_active_key_is_unconfigured() {
        let ring = old_only_keyring("missing");
        let outcome = ring.encrypt_value(&encrypts("email"), EMAIL);
        assert_eq!(
            outcome,
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn tampered_ciphertext_returns_decrypt_error() {
        let ring = keyring();
        let mut stored = ring
            .encrypt_value(&encrypts("email"), EMAIL)
            .expect("encryption should succeed");
        stored.ciphertext += "tampered";
        let outcome = ring.decrypt_value(&stored);
        assert!(matches!(outcome, Err(EncryptionError::Decrypt(_))));
    }

    #[test]
    fn ciphertext_is_non_empty_and_does_not_echo_plaintext() {
        let StoredEncryptedValue {
            ciphertext, key_id, ..
        } = keyring()
            .encrypt_value(&encrypts("email"), EMAIL)
            .expect("encryption should succeed");
        assert!(!ciphertext.is_empty());
        assert_ne!(ciphertext, EMAIL);
        assert_eq!(key_id, "new");
    }

    #[test]
    fn equality_tokens_are_stable_for_same_plaintext() {
        let tokens = || {
            keyring()
                .equality_tokens(SSN)
                .expect("tokens should generate")
        };
        assert_eq!(tokens(), tokens());
    }

    #[test]
    fn equality_tokens_change_for_different_plaintext() {
        let tokens_for = |plaintext: &str| {
            keyring()
                .equality_tokens(plaintext)
                .expect("tokens should generate")
        };
        assert_ne!(tokens_for("alpha"), tokens_for("beta"));
    }

    #[test]
    fn rotated_keys_can_decrypt_old_ciphertext() {
        let stored = old_only_keyring("old")
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");
        let decrypted = keyring()
            .decrypt_value(&stored)
            .expect("rotated keys should decrypt legacy values");
        assert_eq!(decrypted, "legacy@example.com");
    }

    #[test]
    fn missing_key_returns_error() {
        let stored = StoredEncryptedValue {
            ciphertext: "ciphertext".to_owned(),
            key_id: "missing".to_owned(),
            blind_index: None,
        };
        let outcome = keyring().decrypt_value(&stored);
        assert_eq!(
            outcome,
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn trait_exposes_declared_encrypted_attributes() {
        let attributes = UserRecord::encrypted_attributes();
        assert_eq!(attributes.len(), 2);
        assert!(attributes[1].deterministic);
    }
}