rustrails-record 0.1.2

ORM layer (ActiveRecord equivalent)
Documentation
use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};

use crate::base::{Record, RecordError};

/// Counter cache maintenance for [`Record`] types.
///
/// Implementors declare which columns act as counter caches. The provided
/// methods then adjust those columns in the database with a single `UPDATE`
/// statement, without loading the record first.
pub trait CounterCache: Record {
    /// Lists every counter cache column this record exposes.
    ///
    /// Each entry is a `(column, association)` pair. Within this module only
    /// the `column` half is consulted, to validate calls to
    /// [`increment_counter`](Self::increment_counter) and
    /// [`decrement_counter`](Self::decrement_counter).
    fn counter_cache_columns() -> &'static [(&'static str, &'static str)];

    /// Adds one to `column` on the row whose primary key equals `id`.
    ///
    /// # Errors
    ///
    /// Returns [`RecordError::Invalid`] when `column` is not listed in
    /// [`counter_cache_columns`](Self::counter_cache_columns), returns
    /// [`RecordError::NotFound`] when no row matches `id`, and propagates any
    /// database failure.
    async fn increment_counter(
        column: &str,
        id: i64,
        db: &DatabaseConnection,
    ) -> Result<(), RecordError> {
        const STEP: i64 = 1;
        update_counter::<Self>(column, id, STEP, db).await
    }

    /// Subtracts one from `column` on the row whose primary key equals `id`.
    ///
    /// The counter is not clamped, so it may become negative.
    ///
    /// # Errors
    ///
    /// Returns [`RecordError::Invalid`] when `column` is not listed in
    /// [`counter_cache_columns`](Self::counter_cache_columns), returns
    /// [`RecordError::NotFound`] when no row matches `id`, and propagates any
    /// database failure.
    async fn decrement_counter(
        column: &str,
        id: i64,
        db: &DatabaseConnection,
    ) -> Result<(), RecordError> {
        const STEP: i64 = -1;
        update_counter::<Self>(column, id, STEP, db).await
    }
}

async fn update_counter<T: CounterCache>(
    column: &str,
    id: i64,
    delta: i64,
    db: &DatabaseConnection,
) -> Result<(), RecordError> {
    validate_counter_column::<T>(column)?;

    let sql = format!(
        "UPDATE {table} SET {column} = COALESCE({column}, 0) + ? WHERE {primary_key} = ?",
        table = T::table_name(),
        primary_key = T::primary_key_name(),
    );
    let result = db
        .execute_raw(Statement::from_sql_and_values(
            db.get_database_backend(),
            sql,
            [delta.into(), id.into()],
        ))
        .await?;

    if result.rows_affected() == 0 {
        return Err(RecordError::NotFound);
    }

    Ok(())
}

/// Checks that `column` is one of the counter cache columns declared by `T`.
///
/// Only the column half of each `(column, association)` pair is compared.
/// The check matters because `update_counter` splices the column name directly
/// into the SQL text.
///
/// # Errors
///
/// Returns [`RecordError::Invalid`] naming the rejected column when it is not
/// declared by [`CounterCache::counter_cache_columns`].
fn validate_counter_column<T: CounterCache>(column: &str) -> Result<(), RecordError> {
    let declared = T::counter_cache_columns()
        .iter()
        .map(|&(name, _association)| name)
        .any(|name| name == column);

    if !declared {
        return Err(RecordError::Invalid(format!(
            "unknown counter cache column: {column}"
        )));
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveModelTrait, ActiveValue::NotSet, ActiveValue::Set, ConnectionTrait, Database,
        EntityTrait, Schema,
    };

    use super::CounterCache;
    use crate::base::{Record, RecordState};

    mod counter_record {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counter_records")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub comments_count: i32,
            pub views_count: Option<i32>,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    const COUNTER_COLUMNS: [(&str, &str); 2] =
        [("comments_count", "comments"), ("views_count", "views")];

    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct CounterRecord {
        id: Option<i64>,
        comments_count: i32,
        views_count: Option<i32>,
        state: RecordState,
    }

    impl Record for CounterRecord {
        type Entity = counter_record::Entity;

        fn table_name() -> &'static str {
            "counter_records"
        }

        fn id(&self) -> Option<i64> {
            self.id
        }

        fn record_state(&self) -> RecordState {
            self.state
        }

        fn set_record_state(&mut self, state: RecordState) {
            self.state = state;
        }

        fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
            Self {
                id: Some(i64::from(model.id)),
                comments_count: model.comments_count,
                views_count: model.views_count,
                state: RecordState::Persisted,
            }
        }

        fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
            counter_record::ActiveModel {
                id: match self.id.and_then(|value| i32::try_from(value).ok()) {
                    Some(value) => Set(value),
                    None => NotSet,
                },
                comments_count: Set(self.comments_count),
                views_count: Set(self.views_count),
            }
        }
    }

    impl CounterCache for CounterRecord {
        fn counter_cache_columns() -> &'static [(&'static str, &'static str)] {
            &COUNTER_COLUMNS
        }
    }

    /// Creates an in-memory SQLite database holding the `counter_records` table
    /// and one seed row with `comments_count = 2` and `views_count` left unset.
    /// The tests below rely on that seed row being assigned id 1.
    async fn setup_db() -> sea_orm::DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let backend = db.get_database_backend();
        let schema = Schema::new(backend);
        db.execute(&schema.create_table_from_entity(counter_record::Entity))
            .await
            .expect("counter_records table should be created");
        counter_record::ActiveModel {
            comments_count: Set(2),
            ..Default::default()
        }
        .insert(&db)
        .await
        .expect("seed row should insert");
        db
    }

    /// Reloads the seed row (id 1) created by `setup_db`.
    async fn fetch_seed_row(db: &sea_orm::DatabaseConnection) -> counter_record::Model {
        counter_record::Entity::find_by_id(1)
            .one(db)
            .await
            .expect("query should succeed")
            .expect("row should exist")
    }

    #[tokio::test]
    async fn increment_counter_updates_column() {
        let db = setup_db().await;

        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("increment should succeed");

        assert_eq!(fetch_seed_row(&db).await.comments_count, 3);
    }

    #[tokio::test]
    async fn decrement_counter_updates_column() {
        let db = setup_db().await;

        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("decrement should succeed");

        assert_eq!(fetch_seed_row(&db).await.comments_count, 1);
    }

    #[tokio::test]
    async fn unknown_counter_column_is_rejected() {
        let db = setup_db().await;

        let error = CounterRecord::increment_counter("likes_count", 1, &db)
            .await
            .expect_err("unknown counter should fail");

        assert!(
            matches!(error, crate::RecordError::Invalid(message) if message.contains("likes_count"))
        );
    }

    #[tokio::test]
    async fn rejected_counter_column_leaves_row_untouched() {
        let db = setup_db().await;

        let error = CounterRecord::decrement_counter("likes_count", 1, &db)
            .await
            .expect_err("unknown counter should fail");

        assert!(
            matches!(error, crate::RecordError::Invalid(message) if message.contains("likes_count"))
        );
        // Validation runs before any SQL, so the seed values must be intact.
        let row = fetch_seed_row(&db).await;
        assert_eq!(row.comments_count, 2);
        assert_eq!(row.views_count, None);
    }

    #[tokio::test]
    async fn increment_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let error = CounterRecord::increment_counter("comments_count", 404, &db)
            .await
            .expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn decrement_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let error = CounterRecord::decrement_counter("comments_count", 404, &db)
            .await
            .expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn repeated_increments_accumulate_on_existing_counter() {
        let db = setup_db().await;

        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("first increment should succeed");
        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("second increment should succeed");

        assert_eq!(fetch_seed_row(&db).await.comments_count, 4);
    }

    #[tokio::test]
    async fn decrement_counter_can_cross_zero() {
        let db = setup_db().await;

        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("first decrement should succeed");
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("second decrement should succeed");
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("third decrement should succeed");

        assert_eq!(fetch_seed_row(&db).await.comments_count, -1);
    }

    #[tokio::test]
    async fn increment_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;

        CounterRecord::increment_counter("views_count", 1, &db)
            .await
            .expect("increment should succeed for null counter");

        assert_eq!(fetch_seed_row(&db).await.views_count, Some(1));
    }

    #[tokio::test]
    async fn decrement_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;

        CounterRecord::decrement_counter("views_count", 1, &db)
            .await
            .expect("decrement should succeed for null counter");

        assert_eq!(fetch_seed_row(&db).await.views_count, Some(-1));
    }
}