rustrails-record 0.1.2

ORM layer (ActiveRecord equivalent)
Documentation
use std::collections::HashSet;

use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

/// Describes one secure-token field: where the token is stored and how many
/// characters a freshly generated value should contain.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SecureTokenConfig {
    /// Name of the record field that holds the token.
    pub field: String,
    /// Number of characters in a generated token.
    pub length: usize,
}

impl SecureTokenConfig {
    /// Character count used when no explicit length is requested.
    const DEFAULT_LENGTH: usize = 24;

    /// Builds metadata for `field` using the default token length of 24.
    #[must_use]
    pub fn new(field: &str) -> Self {
        let field = String::from(field);
        Self {
            field,
            length: Self::DEFAULT_LENGTH,
        }
    }

    /// Returns a copy of this metadata with the generated length set to `length`.
    ///
    /// Values below one are raised to one, so the configured length is never zero.
    #[must_use]
    pub fn length(self, length: usize) -> Self {
        Self {
            length: std::cmp::max(length, 1),
            ..self
        }
    }
}

/// Errors returned by secure-token helpers.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum SecureTokenError {
    /// The requested token field is unknown.
    #[error("unknown secure token field: {0}")]
    UnknownField(String),
    /// Token generation exhausted uniqueness attempts.
    #[error("could not generate a unique token for {0}")]
    ExhaustedAttempts(String),
    /// A provided token collides with an existing token.
    #[error("token for {0} must be unique")]
    DuplicateToken(String),
}

/// Declares a secure token stored in `field`, using the default length.
///
/// This is shorthand for [`SecureTokenConfig::new`]; chain
/// [`SecureTokenConfig::length`] on the result to choose a different size.
#[must_use]
pub fn has_secure_token(field: &str) -> SecureTokenConfig {
    SecureTokenConfig::new(field)
}

/// Trait implemented by records that expose secure-token fields.
///
/// Implementors describe their token fields via
/// [`secure_token_configurations`](Self::secure_token_configurations) and provide a
/// name-keyed getter/setter pair. The provided methods build on those three to fill in
/// missing tokens and rotate existing ones.
pub trait SecureToken {
    /// Returns secure-token metadata for the record type.
    fn secure_token_configurations() -> &'static [SecureTokenConfig];
    /// Reads the current token value for `field`.
    fn get_secure_token(&self, field: &str) -> Option<&str>;
    /// Stores a token value for `field`.
    fn set_secure_token(&mut self, field: &str, token: String);

    /// Generates a cryptographically random token string.
    ///
    /// The token is 32 URL-safe base64 characters long. This is longer than the
    /// 24-character default that [`SecureTokenConfig::new`] applies to declared fields.
    fn generate_token() -> String {
        generate_token_with_length(32)
    }

    /// Ensures every declared secure-token field has a unique value.
    ///
    /// Tokens already present on the record are kept. Each one must be absent from
    /// `existing_tokens` and must differ from the record's other tokens. Missing
    /// tokens are generated with their configured length and avoid all of those values.
    ///
    /// The record is modified only after every field has been validated and every
    /// missing token generated, so an error leaves the record unchanged.
    ///
    /// # Errors
    ///
    /// * [`SecureTokenError::DuplicateToken`] if a token on the record is already in
    ///   `existing_tokens` or repeats another field's token.
    /// * [`SecureTokenError::ExhaustedAttempts`] if no unique token could be generated
    ///   for a missing field.
    fn ensure_secure_tokens(
        &mut self,
        existing_tokens: &HashSet<String>,
    ) -> Result<(), SecureTokenError> {
        let configs = Self::secure_token_configurations();
        let mut reserved = existing_tokens.clone();

        // Pass 1: validate tokens already on the record. `insert` returning `false` means
        // the value is either taken elsewhere or repeats an earlier field on this record.
        let mut missing = Vec::new();
        for config in configs {
            match self.get_secure_token(&config.field) {
                Some(token) => {
                    if !reserved.insert(token.to_owned()) {
                        return Err(SecureTokenError::DuplicateToken(config.field.clone()));
                    }
                }
                None => missing.push(config),
            }
        }

        // Pass 2: generate values for empty fields, reserving each one so later fields
        // cannot receive the same token.
        let mut generated = Vec::with_capacity(missing.len());
        for config in missing {
            let token = generate_unique_token(config.length, &reserved, &config.field)?;
            reserved.insert(token.clone());
            generated.push((config.field.as_str(), token));
        }

        // Pass 3: only now touch the record, so any error above leaves it unchanged.
        for (field, token) in generated {
            self.set_secure_token(field, token);
        }
        Ok(())
    }

    /// Replaces the token stored in `field` and returns the new value.
    ///
    /// The new token avoids `existing_tokens`, the field's previous value, and the tokens
    /// currently held by the record's other secure-token fields.
    ///
    /// # Errors
    ///
    /// * [`SecureTokenError::UnknownField`] if `field` is not declared in
    ///   [`Self::secure_token_configurations`].
    /// * [`SecureTokenError::ExhaustedAttempts`] if no unique token could be generated.
    fn regenerate_token(
        &mut self,
        field: &str,
        existing_tokens: &HashSet<String>,
    ) -> Result<String, SecureTokenError> {
        let configs = Self::secure_token_configurations();
        let config = configs
            .iter()
            .find(|config| config.field == field)
            .ok_or_else(|| SecureTokenError::UnknownField(field.to_owned()))?;

        // Keep every token the record currently holds reserved. Reserving this field's
        // old value guarantees the regenerated token differs from it. Reserving the
        // sibling fields' values matches the intra-record uniqueness that
        // `ensure_secure_tokens` enforces.
        let mut reserved = existing_tokens.clone();
        for declared in configs {
            if let Some(token) = self.get_secure_token(&declared.field) {
                reserved.insert(token.to_owned());
            }
        }

        let token = generate_unique_token(config.length, &reserved, field)?;
        self.set_secure_token(field, token.clone());
        Ok(token)
    }
}

/// Draws up to 32 random tokens of `length` characters and returns the first one
/// that is not already in `existing_tokens`.
///
/// # Errors
///
/// Returns [`SecureTokenError::ExhaustedAttempts`] naming `field` when all 32 draws
/// collide with an existing token.
fn generate_unique_token(
    length: usize,
    existing_tokens: &HashSet<String>,
    field: &str,
) -> Result<String, SecureTokenError> {
    const MAX_ATTEMPTS: usize = 32;

    // `repeat_with` is lazy, so draws stop as soon as a free candidate is found.
    std::iter::repeat_with(|| generate_token_with_length(length))
        .take(MAX_ATTEMPTS)
        .find(|candidate| !existing_tokens.contains(candidate))
        .ok_or_else(|| SecureTokenError::ExhaustedAttempts(field.to_owned()))
}

/// Builds a random token of exactly `length` URL-safe base64 characters.
///
/// Random bytes are drawn in 24-byte chunks, each encoding to 32 unpadded base64
/// characters, until enough characters exist. The excess is then cut off. A
/// `length` of 0 yields an empty string.
fn generate_token_with_length(length: usize) -> String {
    // 24 input bytes map to exactly 32 unpadded base64 characters.
    const CHUNK_BYTES: usize = 24;
    const CHUNK_CHARS: usize = 32;

    // Reserve the final size up front and encode straight into `token`. This avoids
    // a temporary `String` per chunk and any regrowth of the buffer.
    let capacity = length.div_ceil(CHUNK_CHARS).saturating_mul(CHUNK_CHARS);
    let mut token = String::with_capacity(capacity);
    while token.len() < length {
        let bytes: [u8; CHUNK_BYTES] = rand::random();
        URL_SAFE_NO_PAD.encode_string(bytes, &mut token);
    }
    // Base64 output is pure ASCII, so truncating at any byte index lands on a char
    // boundary.
    token.truncate(length);
    token
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::LazyLock;

    use super::{SecureToken, SecureTokenConfig, SecureTokenError, has_secure_token};

    /// Returns true when `token` only uses the URL-safe base64 alphabet.
    fn is_url_safe(token: &str) -> bool {
        token
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_'))
    }

    #[derive(Debug, Default)]
    struct ApiKeyRecord {
        token: Option<String>,
        recovery_token: Option<String>,
    }

    static TOKEN_CONFIGS: LazyLock<Vec<SecureTokenConfig>> = LazyLock::new(|| {
        vec![
            has_secure_token("token"),
            has_secure_token("recovery_token").length(12),
        ]
    });

    impl SecureToken for ApiKeyRecord {
        fn secure_token_configurations() -> &'static [SecureTokenConfig] {
            TOKEN_CONFIGS.as_slice()
        }

        fn get_secure_token(&self, field: &str) -> Option<&str> {
            match field {
                "token" => self.token.as_deref(),
                "recovery_token" => self.recovery_token.as_deref(),
                _ => None,
            }
        }

        fn set_secure_token(&mut self, field: &str, token: String) {
            match field {
                "token" => self.token = Some(token),
                "recovery_token" => self.recovery_token = Some(token),
                _ => {}
            }
        }
    }

    #[test]
    fn ensure_secure_tokens_generates_missing_tokens() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("tokens should be generated");

        let token = record.token.as_deref().expect("token should be set");
        let recovery = record
            .recovery_token
            .as_deref()
            .expect("recovery token should be set");
        // "token" uses the default length; "recovery_token" overrides it.
        assert_eq!(token.len(), 24);
        assert_eq!(recovery.len(), 12);
        assert!(is_url_safe(token));
        assert!(is_url_safe(recovery));
    }

    #[test]
    fn ensure_secure_tokens_preserves_unique_existing_tokens() {
        let mut record = ApiKeyRecord {
            token: Some("existing-token".to_owned()),
            recovery_token: None,
        };
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("existing token should be preserved");

        assert_eq!(record.token.as_deref(), Some("existing-token"));
        assert!(record.recovery_token.is_some());
    }

    #[test]
    fn ensure_secure_tokens_rejects_duplicate_existing_tokens() {
        let mut record = ApiKeyRecord {
            token: Some("taken".to_owned()),
            recovery_token: None,
        };
        let existing = HashSet::from(["taken".to_owned()]);

        assert_eq!(
            record.ensure_secure_tokens(&existing),
            Err(SecureTokenError::DuplicateToken("token".to_owned()))
        );
    }

    #[test]
    fn ensure_secure_tokens_rejects_tokens_repeated_across_fields() {
        let mut record = ApiKeyRecord {
            token: Some("same".to_owned()),
            recovery_token: Some("same".to_owned()),
        };

        assert_eq!(
            record.ensure_secure_tokens(&HashSet::new()),
            Err(SecureTokenError::DuplicateToken("recovery_token".to_owned()))
        );
    }

    #[test]
    fn regenerate_token_replaces_existing_value() {
        let mut record = ApiKeyRecord {
            token: Some("current".to_owned()),
            recovery_token: None,
        };
        let token = record
            .regenerate_token("token", &HashSet::new())
            .expect("token should regenerate");

        assert_eq!(record.token.as_deref(), Some(token.as_str()));
        assert_ne!(token, "current");
        assert_eq!(token.len(), 24);
    }

    #[test]
    fn regenerate_token_honors_length_override() {
        let mut record = ApiKeyRecord::default();
        let token = record
            .regenerate_token("recovery_token", &HashSet::new())
            .expect("recovery token should regenerate");

        assert_eq!(token.len(), 12);
        assert_eq!(record.recovery_token.as_deref(), Some(token.as_str()));
        assert!(record.token.is_none());
    }

    #[test]
    fn regenerate_token_rejects_unknown_fields() {
        let mut record = ApiKeyRecord::default();
        assert_eq!(
            record.regenerate_token("missing", &HashSet::new()),
            Err(SecureTokenError::UnknownField("missing".to_owned()))
        );
    }

    #[test]
    fn metadata_builder_preserves_length_overrides() {
        let config = has_secure_token("auth_token").length(10);
        assert_eq!(config.field, "auth_token");
        assert_eq!(config.length, 10);
    }

    #[test]
    fn metadata_builder_defaults_and_clamps_length() {
        assert_eq!(has_secure_token("auth_token").length, 24);
        assert_eq!(has_secure_token("auth_token").length(0).length, 1);
    }

    #[test]
    fn generate_token_returns_32_url_safe_chars() {
        let token = <ApiKeyRecord as SecureToken>::generate_token();

        assert_eq!(token.len(), 32);
        assert!(is_url_safe(&token));
    }

    #[test]
    fn ensure_secure_tokens_still_honors_length_overrides() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("tokens should be generated");

        assert_eq!(record.recovery_token.as_deref().map(str::len), Some(12));
    }
}