rustrails-record 0.1.2

ORM layer (ActiveRecord equivalent)
Documentation
use std::collections::HashMap;

use rustrails_support::encryption::{EncryptorError, MessageEncryptor, MessageVerifier};

/// Per-attribute encryption settings declared by a record type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EncryptedFieldConfig {
    /// Name of the attribute that is stored encrypted.
    pub field: String,
    /// When `true`, encrypting this attribute also yields a deterministic blind index so the
    /// attribute can be matched with equality queries.
    pub deterministic: bool,
}

impl EncryptedFieldConfig {
    /// Builds settings for `field` with deterministic mode switched off.
    #[must_use]
    pub fn new(field: &str) -> Self {
        let field = String::from(field);
        Self {
            field,
            deterministic: false,
        }
    }

    /// Consumes these settings and returns them with deterministic mode switched on.
    #[must_use]
    pub fn deterministic(self) -> Self {
        Self {
            deterministic: true,
            ..self
        }
    }
}

/// Declares an encrypted attribute named `field`; equivalent to [`EncryptedFieldConfig::new`].
#[must_use]
pub fn encrypts(field: &str) -> EncryptedFieldConfig {
    EncryptedFieldConfig::new(field)
}

/// Persisted form of an encrypted attribute.
///
/// It holds the ciphertext envelope, the id of the key that produced it, and an optional
/// deterministic blind index used for equality lookups.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StoredEncryptedValue {
    /// Ciphertext envelope as produced by the encryptor.
    pub ciphertext: String,
    /// Id of the keyring entry that encrypted `ciphertext`.
    pub key_id: String,
    /// Equality token for deterministic fields.
    ///
    /// `EncryptionKeyRing::encrypt_value` fills this in only when the field is deterministic.
    pub blind_index: Option<String>,
}

/// Errors returned by encrypted-attribute helpers.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum EncryptionError {
    /// The configured key id does not exist.
    #[error("encryption key not found: {0}")]
    MissingKey(String),
    /// Encryption failed.
    #[error("encryption failed: {0}")]
    Encrypt(#[from] EncryptorError),
    /// Decryption failed.
    #[error("decryption failed: {0}")]
    Decrypt(EncryptorError),
    /// Generated plaintext was not valid UTF-8.
    #[error("decrypted plaintext is not valid utf-8")]
    InvalidUtf8,
}

/// Keyring used for encryption, decryption, blind indexing, and key rotation.
///
/// New values are encrypted with the active key. Any key in the ring may be used to decrypt
/// values written before a rotation.
///
/// The `Debug` output lists key ids only; raw key bytes are never printed.
#[derive(Clone)]
pub struct EncryptionKeyRing {
    /// Id of the key used for new encryptions and their blind indexes.
    active_key_id: String,
    /// All known 32-byte secrets (active and rotated), keyed by id.
    keys: HashMap<String, [u8; 32]>,
}

impl std::fmt::Debug for EncryptionKeyRing {
    // Hand-written instead of derived so that secret key material cannot leak into logs,
    // panic messages, or test output through `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        // HashMap iteration order is unspecified; sort for stable, comparable output.
        key_ids.sort_unstable();
        f.debug_struct("EncryptionKeyRing")
            .field("active_key_id", &self.active_key_id)
            .field("key_ids", &key_ids)
            .finish_non_exhaustive()
    }
}

impl EncryptionKeyRing {
    /// Creates a keyring that encrypts with `active_key_id` and can decrypt with any key in `keys`.
    ///
    /// The active id is not checked against `keys` here. If it is absent,
    /// [`EncryptionKeyRing::encrypt_value`] returns [`EncryptionError::MissingKey`].
    #[must_use]
    pub fn new(active_key_id: &str, keys: HashMap<String, [u8; 32]>) -> Self {
        Self {
            active_key_id: active_key_id.to_owned(),
            keys,
        }
    }

    /// Encrypts `plaintext` with the active key according to `config`.
    ///
    /// The result records the active key id so [`EncryptionKeyRing::decrypt_value`] tries that
    /// key first. When `config.deterministic` is set, a blind index derived from the active key
    /// is attached as well.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::MissingKey`] if the active key id is not in the keyring.
    /// Returns [`EncryptionError::Encrypt`] if the underlying encryptor cannot be built or
    /// fails to encrypt.
    pub fn encrypt_value(
        &self,
        config: &EncryptedFieldConfig,
        plaintext: &str,
    ) -> Result<StoredEncryptedValue, EncryptionError> {
        let encryptor = self.encryptor(&self.active_key_id)?;
        let ciphertext = encryptor.encrypt_and_sign(plaintext.as_bytes())?;
        let blind_index = if config.deterministic {
            Some(self.blind_index_for(&self.active_key_id, plaintext)?)
        } else {
            None
        };

        Ok(StoredEncryptedValue {
            ciphertext,
            key_id: self.active_key_id.clone(),
            blind_index,
        })
    }

    /// Decrypts `value`, trying the recorded key first and then every other configured key.
    ///
    /// The remaining keys are tried in ascending order of their ids.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::MissingKey`] if `value.key_id` is not configured. This is returned
    ///   without trying the other keys.
    /// - [`EncryptionError::Decrypt`] if no key verifies the ciphertext; it carries the recorded
    ///   key's error. It is also returned if an encryptor cannot be built for a key.
    /// - [`EncryptionError::InvalidUtf8`] if the first key that verifies the ciphertext yields
    ///   bytes that are not valid UTF-8.
    pub fn decrypt_value(&self, value: &StoredEncryptedValue) -> Result<String, EncryptionError> {
        if !self.keys.contains_key(&value.key_id) {
            return Err(EncryptionError::MissingKey(value.key_id.clone()));
        }

        // Keep the recorded key's failure: it is the authoritative one for diagnostics. The
        // rotated keys are only a fallback for mislabelled or re-keyed data.
        let mut first_error = None;
        for key_id in self.ordered_key_ids(&value.key_id) {
            let encryptor =
                MessageEncryptor::new(self.secret(key_id)?).map_err(EncryptionError::Decrypt)?;
            match encryptor.decrypt_and_verify(&value.ciphertext) {
                Ok(bytes) => {
                    return String::from_utf8(bytes).map_err(|_| EncryptionError::InvalidUtf8);
                }
                Err(error) => {
                    if first_error.is_none() {
                        first_error = Some(error);
                    }
                }
            }
        }

        // The recorded key is always in the iteration above, so `first_error` is populated
        // here. The fallback only exists to avoid a panic path.
        Err(first_error.map_or_else(
            || EncryptionError::MissingKey(value.key_id.clone()),
            EncryptionError::Decrypt,
        ))
    }

    /// Returns one equality token per configured key.
    ///
    /// The active key's token comes first when that key is configured. The rest follow in
    /// ascending order of key id. Querying with all tokens matches rows written before a key
    /// rotation.
    ///
    /// # Errors
    ///
    /// Propagates [`EncryptionError::MissingKey`] from the key lookup. This cannot occur for
    /// ids taken from the keyring itself.
    pub fn equality_tokens(&self, plaintext: &str) -> Result<Vec<String>, EncryptionError> {
        self.ordered_key_ids(&self.active_key_id)
            .into_iter()
            .map(|key_id| self.blind_index_for(key_id, plaintext))
            .collect()
    }

    /// Returns every configured key id, with `preferred` first (when configured) and the
    /// remaining ids in ascending lexicographic order.
    fn ordered_key_ids(&self, preferred: &str) -> Vec<&str> {
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        // `false` sorts before `true`, so the preferred id (if present) lands first.
        key_ids.sort_unstable_by_key(|&key_id| (key_id != preferred, key_id));
        key_ids
    }

    /// Looks up the raw secret for `key_id`.
    fn secret(&self, key_id: &str) -> Result<&[u8; 32], EncryptionError> {
        self.keys
            .get(key_id)
            .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))
    }

    /// Builds an encryptor for `key_id`; construction failures are reported as `Encrypt`.
    fn encryptor(&self, key_id: &str) -> Result<MessageEncryptor, EncryptionError> {
        MessageEncryptor::new(self.secret(key_id)?).map_err(EncryptionError::Encrypt)
    }

    /// Computes the deterministic blind index of `plaintext` under `key_id`'s secret.
    ///
    /// NOTE(review): this reuses the encryption secret as the MAC key. Confirm that
    /// `MessageVerifier::generate` returns only a digest and does not embed the message.
    /// Rails' `MessageVerifier` emits `base64(data)--digest`; if this port does the same, the
    /// stored blind index would expose an encoding of the plaintext. Deriving a separate
    /// per-purpose sub-key would also avoid key reuse.
    fn blind_index_for(&self, key_id: &str, plaintext: &str) -> Result<String, EncryptionError> {
        let secret = self.secret(key_id)?;
        Ok(MessageVerifier::new(secret).generate(plaintext.as_bytes()))
    }
}

/// Trait implemented by records that declare encrypted attributes.
pub trait EncryptedAttribute {
    /// Returns encrypted-attribute metadata for the record type.
    fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
        &[]
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for field configuration, the keyring's encrypt/decrypt/rotation behavior,
    // and blind-index generation.
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use rustrails_support::encryption::MessageEncryptor;

    use super::{
        EncryptedAttribute, EncryptedFieldConfig, EncryptionError, EncryptionKeyRing, encrypts,
    };

    struct UserRecord;

    impl EncryptedAttribute for UserRecord {
        fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
            static CONFIGS: LazyLock<Vec<EncryptedFieldConfig>> =
                LazyLock::new(|| vec![encrypts("email"), encrypts("ssn").deterministic()]);
            CONFIGS.as_slice()
        }
    }

    /// Keyring with active key "new" and a rotated key "old".
    fn keyring() -> EncryptionKeyRing {
        EncryptionKeyRing::new(
            "new",
            HashMap::from([
                ("new".to_owned(), [1_u8; 32]),
                ("old".to_owned(), [2_u8; 32]),
            ]),
        )
    }

    #[test]
    fn encrypts_builder_enables_deterministic_mode() {
        let config = encrypts("email").deterministic();
        assert!(config.deterministic);
        assert_eq!(config.field, "email");
    }

    #[test]
    fn encrypt_value_round_trips_plaintext() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("decryption should succeed");
        assert_eq!(plaintext, "alice@example.com");
    }

    #[test]
    fn deterministic_fields_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), "123-45-6789")
            .expect("encryption should succeed");
        assert!(stored.blind_index.is_some());
    }

    #[test]
    fn encrypted_attribute_metadata_preserves_field_order_and_flags() {
        assert_eq!(
            UserRecord::encrypted_attributes(),
            &[encrypts("email"), encrypts("ssn").deterministic()]
        );
    }

    #[test]
    fn non_deterministic_fields_do_not_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");

        assert_eq!(stored.blind_index, None);
    }

    #[test]
    fn deterministic_encryptions_keep_blind_index_stable_but_change_ciphertext() {
        let config = encrypts("ssn").deterministic();
        let first = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");
        let second = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");

        assert_eq!(first.blind_index, second.blind_index);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn equality_tokens_order_active_key_before_rotated_keys() {
        let plaintext = "123-45-6789";
        let active = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");
        let old_ring =
            EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let rotated = old_ring
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");

        let tokens = keyring()
            .equality_tokens(plaintext)
            .expect("tokens should generate");

        assert_eq!(
            tokens,
            vec![
                active
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
                rotated
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
            ]
        );
    }

    #[test]
    fn encrypt_value_returns_missing_key_when_active_key_is_unconfigured() {
        let keyring =
            EncryptionKeyRing::new("missing", HashMap::from([("old".to_owned(), [2_u8; 32])]));

        assert_eq!(
            keyring.encrypt_value(&encrypts("email"), "alice@example.com"),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn tampered_ciphertext_returns_decrypt_error() {
        let mut stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        stored.ciphertext.push_str("tampered");

        assert!(matches!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::Decrypt(_))
        ));
    }

    #[test]
    fn ciphertext_is_non_empty_and_does_not_echo_plaintext() {
        let plaintext = "alice@example.com";
        let stored = keyring()
            .encrypt_value(&encrypts("email"), plaintext)
            .expect("encryption should succeed");

        assert!(!stored.ciphertext.is_empty());
        assert_ne!(stored.ciphertext, plaintext);
        assert_eq!(stored.key_id, "new");
    }

    #[test]
    fn equality_tokens_are_stable_for_same_plaintext() {
        let first = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        assert_eq!(first, second);
    }

    #[test]
    fn equality_tokens_change_for_different_plaintext() {
        let first = keyring()
            .equality_tokens("alpha")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("beta")
            .expect("tokens should generate");
        assert_ne!(first, second);
    }

    #[test]
    fn rotated_keys_can_decrypt_old_ciphertext() {
        let old_ring =
            EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let stored = old_ring
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");

        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("rotated keys should decrypt legacy values");
        assert_eq!(plaintext, "legacy@example.com");
    }

    #[test]
    fn decrypt_falls_back_to_other_keys_when_recorded_key_does_not_verify() {
        let old_ring =
            EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let mut stored = old_ring
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");
        // Mislabel the value so the recorded key fails and the fallback loop is exercised.
        stored.key_id = "new".to_owned();

        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("fallback keys should decrypt mislabelled values");
        assert_eq!(plaintext, "legacy@example.com");
    }

    #[test]
    fn non_utf8_plaintext_returns_invalid_utf8() {
        // Encrypt raw non-UTF-8 bytes with the same secret as the keyring's "new" key.
        let ciphertext = MessageEncryptor::new(&[1_u8; 32])
            .expect("a 32-byte key should build an encryptor")
            .encrypt_and_sign(&[0xff_u8, 0xfe, 0xfd])
            .expect("encryption should succeed");
        let stored = super::StoredEncryptedValue {
            ciphertext,
            key_id: "new".to_owned(),
            blind_index: None,
        };

        assert_eq!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::InvalidUtf8)
        );
    }

    #[test]
    fn missing_key_returns_error() {
        let stored = super::StoredEncryptedValue {
            ciphertext: "ciphertext".to_owned(),
            key_id: "missing".to_owned(),
            blind_index: None,
        };

        assert_eq!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn trait_exposes_declared_encrypted_attributes() {
        assert_eq!(UserRecord::encrypted_attributes().len(), 2);
        assert!(UserRecord::encrypted_attributes()[1].deterministic);
    }
}