rustrails-record 0.1.2

ORM layer (ActiveRecord equivalent)
Documentation
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use hmac::{Hmac, Mac};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, Iterable};
use sha2::Sha256;

use crate::{Record, RecordError, querying::AsyncQuerying};
use rustrails_support::{database, runtime};

/// HMAC instantiated with SHA-256; used below to sign and verify record-ID tokens.
type HmacSha256 = Hmac<Sha256>;

/// Adds tamper-proof signed identifiers to persisted records.
///
/// A signed ID is the record's decimal ID, a `.`, and an HMAC-SHA256 signature of
/// that ID keyed by [`SignedId::signing_secret`]. The whole thing is encoded as
/// URL-safe base64 without padding.
///
/// The token is *signed, not encrypted*: anyone who base64-decodes it can read the ID.
/// The signature only prevents forging or altering it.
///
/// NOTE(review): the signed payload is the bare ID. It carries no model name, purpose,
/// or expiry, so a token minted for one model's record `n` also resolves record `n` of
/// any other model that shares the same secret. Consider overriding `signing_secret`
/// per model if that matters.
// NOTE(review): presumably needed because the where-clauses below name `AsyncQuerying`
// (and/or `Record`), which may be less visible than this public trait — confirm.
#[allow(private_bounds)]
pub trait SignedId: Record {
    /// Returns a signed, tamper-proof token encoding this record's ID.
    ///
    /// # Errors
    ///
    /// Returns [`RecordError::NotSaved`] when the record has no ID yet (`self.id()` is
    /// `None`). Any error raised while signing is propagated unchanged.
    fn signed_id(&self) -> Result<String, RecordError> {
        let Some(id) = self.id() else {
            return Err(RecordError::NotSaved);
        };
        let secret = Self::signing_secret();
        sign_id(id, &secret)
    }

    /// Finds a record by its signed ID token.
    ///
    /// The token's signature is verified against [`SignedId::signing_secret`] before
    /// any database access happens.
    ///
    /// # Errors
    ///
    /// Returns `RecordError::Invalid` when the token is malformed, was altered, or was
    /// signed with a different secret. Otherwise it returns whatever `Self::find`
    /// yields for the decoded ID (presumably a not-found error when no row matches —
    /// confirm against `AsyncQuerying`).
    async fn find_signed(token: &str, db: &DatabaseConnection) -> Result<Self, RecordError>
    where
        Self: AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let secret = Self::signing_secret();
        let id = verify_signed_id(token, &secret)?;
        Self::find(id, db).await
    }

    /// Finds a record by its signed ID token using the thread-local database connection.
    ///
    /// Synchronous wrapper around [`SignedId::find_signed`]. It obtains the connection
    /// via `database::with_db` and drives the future with `runtime::block_on`.
    ///
    /// NOTE(review): because this blocks on a future, it is likely unsafe to call from
    /// inside an async task — confirm against `runtime::block_on`.
    ///
    /// # Errors
    ///
    /// Same as [`SignedId::find_signed`].
    fn find_signed_sync(token: &str) -> Result<Self, RecordError>
    where
        Self: AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|conn| {
            let lookup = Self::find_signed(token, conn);
            runtime::block_on(lookup)
        })
    }

    /// Returns the secret key used for signing.
    ///
    /// Reads `RUSTRAILS_SECRET_KEY_BASE` on every call. If that variable is unset, or
    /// its value is not valid Unicode, this falls back to the hard-coded
    /// `"development-secret"`.
    ///
    /// Tokens minted with that fallback are forgeable by anyone who can read this
    /// source, so ensure the variable is set in any deployed environment.
    fn signing_secret() -> String {
        match std::env::var("RUSTRAILS_SECRET_KEY_BASE") {
            Ok(secret) => secret,
            Err(_) => "development-secret".to_owned(),
        }
    }
}

/// Every [`Record`] gets signed-ID support through the trait's default methods.
impl<R> SignedId for R
where
    R: Record,
{
}

/// Produces the signed token for `id`.
///
/// The token is `base64url_nopad("<decimal id>.<raw HMAC-SHA256 of the decimal id>")`,
/// keyed with `secret`.
///
/// # Errors
///
/// Propagates the error from `sign_hmac` if the key cannot be constructed.
fn sign_id(id: i64, secret: &str) -> Result<String, RecordError> {
    let message = id.to_string();
    let mac = sign_hmac(secret.as_bytes(), message.as_bytes())?;
    // Layout: ID digits, one `.` separator, then the raw signature bytes.
    let payload: Vec<u8> = message
        .bytes()
        .chain(std::iter::once(b'.'))
        .chain(mac)
        .collect();
    Ok(URL_SAFE_NO_PAD.encode(payload))
}

/// Verifies a token produced by `sign_id` and returns the ID it encodes.
///
/// The signature is checked *before* the ID bytes are parsed, so unauthenticated
/// input never reaches the integer parser.
///
/// # Errors
///
/// Returns `RecordError::Invalid("invalid signed id")` in any of these cases:
/// - the token is not URL-safe unpadded base64;
/// - it lacks a `.` separator;
/// - the signature part is empty;
/// - the signature does not match;
/// - the signed bytes are not a decimal `i64`.
///
/// Any key-construction error from `verify_hmac` is propagated as-is.
fn verify_signed_id(token: &str, secret: &str) -> Result<i64, RecordError> {
    let raw = match URL_SAFE_NO_PAD.decode(token) {
        Ok(bytes) => bytes,
        Err(_) => return Err(invalid_signed_id_error()),
    };

    // The ID is written as decimal digits (optionally `-`), so the first `.` is the
    // separator even if the signature bytes happen to contain another `.`.
    let Some(dot) = raw.iter().position(|&byte| byte == b'.') else {
        return Err(invalid_signed_id_error());
    };
    let (id_part, dot_and_mac) = raw.split_at(dot);
    // `dot_and_mac` starts with the `.` found above, so skipping it is in bounds.
    let mac = &dot_and_mac[1..];
    if mac.is_empty() {
        return Err(invalid_signed_id_error());
    }

    verify_hmac(secret.as_bytes(), id_part, mac)?;

    std::str::from_utf8(id_part)
        .ok()
        .and_then(|digits| digits.parse::<i64>().ok())
        .ok_or_else(invalid_signed_id_error)
}

fn sign_hmac(secret: &[u8], payload: &[u8]) -> Result<Vec<u8>, RecordError> {
    let mut mac = HmacSha256::new_from_slice(secret)
        .map_err(|_| RecordError::Invalid("invalid signing secret".to_owned()))?;
    mac.update(payload);
    Ok(mac.finalize().into_bytes().to_vec())
}

fn verify_hmac(secret: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), RecordError> {
    let mut mac = HmacSha256::new_from_slice(secret)
        .map_err(|_| RecordError::Invalid("invalid signing secret".to_owned()))?;
    mac.update(payload);
    mac.verify_slice(signature)
        .map_err(|_| invalid_signed_id_error())
}

/// Builds the single error value reported for every token-validation failure: bad
/// encoding, missing separator, empty or mismatched signature, or a non-numeric ID.
fn invalid_signed_id_error() -> RecordError {
    let message = String::from("invalid signed id");
    RecordError::Invalid(message)
}

#[cfg(test)]
mod tests {
    use super::{SignedId, sign_id, verify_signed_id};
    use crate::{
        Querying, RecordError,
        base::test_support::{TestUser, seed_users, test_user},
    };
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use rustrails_support::{database, runtime};
    use sea_orm::{ConnectionTrait, Schema};

    /// Runs `test` on a fresh thread that owns its own runtime and in-memory SQLite
    /// database, with the `test_users` table created (and seeded when `seed` is true).
    ///
    /// A dedicated thread is used so the sync APIs under test see a connection set up
    /// only for this test. Any panic inside `test` fails the calling test via `join`.
    fn run_sync_test(seed: bool, test: impl FnOnce() + Send + 'static) {
        std::thread::spawn(move || {
            let _runtime = runtime::init_runtime();
            database::establish("sqlite::memory:")
                .expect("sqlite in-memory connection should succeed");
            runtime::block_on(async {
                let db = database::db();
                let schema = Schema::new(db.get_database_backend());
                db.execute(&schema.create_table_from_entity(test_user::Entity))
                    .await
                    .expect("test_users table should be created");
                if seed {
                    seed_users(&db).await;
                }
            });
            test();
        })
        .join()
        .expect("signed-id sync test should not panic");
    }

    /// Asserts that `verify_signed_id` rejects `token` with `RecordError::Invalid`.
    fn assert_rejected(token: &str, secret: &str) {
        let error = verify_signed_id(token, secret).expect_err("token should be rejected");
        assert!(matches!(error, RecordError::Invalid(_)));
    }

    #[tokio::test]
    async fn signed_id_generates_non_empty_token_for_persisted_record() {
        let db = crate::base::test_support::setup_db().await;
        let user = seed_users(&db).await.remove(0);

        let token = user.signed_id().expect("signed_id should succeed");

        assert!(!token.is_empty());
    }

    #[tokio::test]
    async fn find_signed_recovers_original_record() {
        let db = crate::base::test_support::setup_db().await;
        let original = seed_users(&db).await.remove(1);
        let token = original.signed_id().expect("signed_id should succeed");

        let found = TestUser::find_signed(&token, &db)
            .await
            .expect("find_signed should succeed");

        assert_eq!(found, original);
    }

    #[tokio::test]
    async fn find_signed_rejects_tampered_tokens() {
        let db = crate::base::test_support::setup_db().await;
        let user = seed_users(&db).await.remove(0);
        let mut token = user.signed_id().expect("signed_id should succeed");
        token.push('x');

        let error = TestUser::find_signed(&token, &db)
            .await
            .expect_err("tampered token should be rejected");

        assert!(matches!(error, RecordError::Invalid(_)));
    }

    #[test]
    fn find_signed_sync_recovers_original_record() {
        run_sync_test(true, || {
            let original = TestUser::find_sync(2).expect("seeded record should exist");
            let token = original.signed_id().expect("signed_id should succeed");

            let found =
                TestUser::find_signed_sync(&token).expect("find_signed_sync should succeed");

            assert_eq!(found, original);
        });
    }

    #[test]
    fn sign_and_verify_round_trip_extreme_ids() {
        for id in [0, 1, -1, i64::MIN, i64::MAX] {
            let token = sign_id(id, "secret").expect("signing should succeed");
            let recovered = verify_signed_id(&token, "secret").expect("token should verify");
            assert_eq!(recovered, id);
        }
    }

    #[test]
    fn verify_rejects_token_signed_with_other_secret() {
        let token = sign_id(42, "secret-a").expect("signing should succeed");
        assert_rejected(&token, "secret-b");
    }

    #[test]
    fn verify_rejects_malformed_tokens() {
        // Not URL-safe base64 at all.
        assert_rejected("not base64!", "secret");
        // Valid base64, but no `.` separator.
        assert_rejected(&URL_SAFE_NO_PAD.encode(b"42"), "secret");
        // Separator present, but the signature is empty.
        assert_rejected(&URL_SAFE_NO_PAD.encode(b"42."), "secret");
    }

    #[test]
    fn verify_rejects_signature_moved_to_another_id() {
        let token = sign_id(42, "secret").expect("signing should succeed");
        let decoded = URL_SAFE_NO_PAD
            .decode(&token)
            .expect("freshly signed token should be valid base64");
        // Keep ".<signature for 42>" but claim it belongs to ID 43.
        let mut forged = b"43".to_vec();
        forged.extend_from_slice(&decoded[2..]);
        assert_rejected(&URL_SAFE_NO_PAD.encode(forged), "secret");
    }
}