use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
use crate::base::{Record, RecordError};
/// Counter-cache support for a [`Record`]: named integer columns that are adjusted in
/// place with a single SQL `UPDATE`, without loading the record first.
pub trait CounterCache: Record {
    /// The counter columns this record exposes, as `(column, _)` pairs.
    ///
    /// Only the first element of each pair (the column name) is consulted by the
    /// helpers in this module; a column missing from this list is rejected before any
    /// SQL is built.
    /// NOTE(review): the second element looks like an association name — confirm.
    fn counter_cache_columns() -> &'static [(&'static str, &'static str)];

    /// Adds one to `column` on the row whose primary key equals `id`.
    ///
    /// # Errors
    ///
    /// Returns [`RecordError::Invalid`] for an undeclared column, [`RecordError::NotFound`]
    /// when no row matches `id`, or the database error raised while executing the update.
    async fn increment_counter(
        column: &str,
        id: i64,
        db: &DatabaseConnection,
    ) -> Result<(), RecordError> {
        let step: i64 = 1;
        update_counter::<Self>(column, id, step, db).await
    }

    /// Subtracts one from `column` on the row whose primary key equals `id`.
    ///
    /// The counter is not clamped, so it may go below zero.
    ///
    /// # Errors
    ///
    /// Same as [`CounterCache::increment_counter`].
    async fn decrement_counter(
        column: &str,
        id: i64,
        db: &DatabaseConnection,
    ) -> Result<(), RecordError> {
        let step: i64 = -1;
        update_counter::<Self>(column, id, step, db).await
    }
}
async fn update_counter<T: CounterCache>(
column: &str,
id: i64,
delta: i64,
db: &DatabaseConnection,
) -> Result<(), RecordError> {
validate_counter_column::<T>(column)?;
let sql = format!(
"UPDATE {table} SET {column} = COALESCE({column}, 0) + ? WHERE {primary_key} = ?",
table = T::table_name(),
primary_key = T::primary_key_name(),
);
let result = db
.execute_raw(Statement::from_sql_and_values(
db.get_database_backend(),
sql,
[delta.into(), id.into()],
))
.await?;
if result.rows_affected() == 0 {
return Err(RecordError::NotFound);
}
Ok(())
}
/// Checks that `column` is one of the counter columns declared by
/// [`CounterCache::counter_cache_columns`].
///
/// Only the first element of each declared pair (the column name) is compared; the match
/// is exact and case-sensitive.
///
/// # Errors
///
/// Returns [`RecordError::Invalid`] naming the offending column when it is not declared.
fn validate_counter_column<T: CounterCache>(column: &str) -> Result<(), RecordError> {
    let declared = T::counter_cache_columns()
        .iter()
        .map(|&(name, _)| name)
        .any(|name| name == column);

    if !declared {
        return Err(RecordError::Invalid(format!(
            "unknown counter cache column: {column}"
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveModelTrait, ActiveValue::NotSet, ActiveValue::Set, ConnectionTrait, Database,
        EntityTrait, Schema,
    };

    use super::CounterCache;
    use crate::base::{Record, RecordError, RecordState};

    /// Minimal SeaORM entity backing the counter tests: one non-null and one nullable
    /// counter column.
    mod counter_record {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counter_records")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub comments_count: i32,
            pub views_count: Option<i32>,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    const COUNTER_COLUMNS: [(&str, &str); 2] =
        [("comments_count", "comments"), ("views_count", "views")];

    /// Primary key of the single row inserted by `setup_db`. The tests rely on the first
    /// insert into a fresh in-memory table receiving id 1.
    const SEED_ID: i32 = 1;

    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct CounterRecord {
        id: Option<i64>,
        comments_count: i32,
        views_count: Option<i32>,
        state: RecordState,
    }

    impl Record for CounterRecord {
        type Entity = counter_record::Entity;

        fn table_name() -> &'static str {
            "counter_records"
        }

        fn id(&self) -> Option<i64> {
            self.id
        }

        fn record_state(&self) -> RecordState {
            self.state
        }

        fn set_record_state(&mut self, state: RecordState) {
            self.state = state;
        }

        fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
            Self {
                id: Some(i64::from(model.id)),
                comments_count: model.comments_count,
                views_count: model.views_count,
                state: RecordState::Persisted,
            }
        }

        fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
            counter_record::ActiveModel {
                id: match self.id.and_then(|value| i32::try_from(value).ok()) {
                    Some(value) => Set(value),
                    None => NotSet,
                },
                comments_count: Set(self.comments_count),
                views_count: Set(self.views_count),
            }
        }
    }

    impl CounterCache for CounterRecord {
        fn counter_cache_columns() -> &'static [(&'static str, &'static str)] {
            &COUNTER_COLUMNS
        }
    }

    /// Creates an in-memory SQLite database with the `counter_records` table and one seed
    /// row (`comments_count = 2`, `views_count` left unset and therefore NULL).
    async fn setup_db() -> sea_orm::DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let backend = db.get_database_backend();
        let schema = Schema::new(backend);
        db.execute(&schema.create_table_from_entity(counter_record::Entity))
            .await
            .expect("counter_records table should be created");
        counter_record::ActiveModel {
            comments_count: Set(2),
            ..Default::default()
        }
        .insert(&db)
        .await
        .expect("seed row should insert");
        db
    }

    /// Reloads the seed row so tests can assert on its current counter values.
    async fn fetch_seed_row(db: &sea_orm::DatabaseConnection) -> counter_record::Model {
        counter_record::Entity::find_by_id(SEED_ID)
            .one(db)
            .await
            .expect("query should succeed")
            .expect("row should exist")
    }

    #[tokio::test]
    async fn increment_counter_updates_column() {
        let db = setup_db().await;
        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("increment should succeed");
        assert_eq!(fetch_seed_row(&db).await.comments_count, 3);
    }

    #[tokio::test]
    async fn decrement_counter_updates_column() {
        let db = setup_db().await;
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("decrement should succeed");
        assert_eq!(fetch_seed_row(&db).await.comments_count, 1);
    }

    #[tokio::test]
    async fn unknown_counter_column_is_rejected() {
        let db = setup_db().await;
        let error = CounterRecord::increment_counter("likes_count", 1, &db)
            .await
            .expect_err("unknown counter should fail");
        assert!(
            matches!(error, RecordError::Invalid(message) if message.contains("likes_count"))
        );
    }

    #[tokio::test]
    async fn increment_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;
        let error = CounterRecord::increment_counter("comments_count", 404, &db)
            .await
            .expect_err("missing row should fail");
        assert!(matches!(error, RecordError::NotFound));
    }

    #[tokio::test]
    async fn decrement_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;
        let error = CounterRecord::decrement_counter("comments_count", 404, &db)
            .await
            .expect_err("missing row should fail");
        assert!(matches!(error, RecordError::NotFound));
    }

    #[tokio::test]
    async fn repeated_increments_accumulate_on_existing_counter() {
        let db = setup_db().await;
        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("first increment should succeed");
        CounterRecord::increment_counter("comments_count", 1, &db)
            .await
            .expect("second increment should succeed");
        assert_eq!(fetch_seed_row(&db).await.comments_count, 4);
    }

    #[tokio::test]
    async fn decrement_counter_can_cross_zero() {
        let db = setup_db().await;
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("first decrement should succeed");
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("second decrement should succeed");
        CounterRecord::decrement_counter("comments_count", 1, &db)
            .await
            .expect("third decrement should succeed");
        assert_eq!(fetch_seed_row(&db).await.comments_count, -1);
    }

    #[tokio::test]
    async fn increment_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;
        CounterRecord::increment_counter("views_count", 1, &db)
            .await
            .expect("increment should succeed for null counter");
        assert_eq!(fetch_seed_row(&db).await.views_count, Some(1));
    }

    #[tokio::test]
    async fn decrement_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;
        CounterRecord::decrement_counter("views_count", 1, &db)
            .await
            .expect("decrement should succeed for null counter");
        assert_eq!(fetch_seed_row(&db).await.views_count, Some(-1));
    }
}