use std::time::Duration;
use mire::{Aggregate, EventData, EventStore, EventStoreError, Snapshot};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
mod common;
/// Events recorded against the test [`Counter`] aggregate.
///
/// Serialized as an internally tagged enum: the variant name goes in a `"type"` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
enum CounterEvent {
    /// When applied, adds `by` to the counter's running total.
    Incremented { by: i64 },
}
impl EventData for CounterEvent {
    /// Returns the fixed type name stored alongside each event variant.
    fn event_type(&self) -> &'static str {
        match *self {
            Self::Incremented { .. } => "counter.incremented",
        }
    }
}
/// Minimal test aggregate: the sum of every increment applied to it.
#[derive(Default, Debug, Deserialize, Serialize)]
struct Counter {
    /// Running sum of all `Incremented { by }` events applied so far.
    total: i64,
}
impl Aggregate for Counter {
    type Event = CounterEvent;

    /// Category name for this aggregate's streams.
    fn stream_category() -> &'static str {
        "counter"
    }

    /// Folds a single event into the in-memory state.
    fn apply(&mut self, event: &CounterEvent) {
        // `CounterEvent` has exactly one variant, so this pattern is irrefutable.
        // Adding a variant turns it into a compile error, just as a `match` would.
        let CounterEvent::Incremented { by } = event;
        self.total += *by;
    }
}
/// Snapshot configuration for [`Counter`].
impl Snapshot for Counter {
    /// Snapshot cadence. The exact unit (events vs. versions) is defined by `mire`
    /// — NOTE(review): confirm against the library.
    const SNAPSHOT_FREQUENCY: i64 = 2;
    /// Version tag for the serialized snapshot layout.
    const SNAPSHOT_VERSION: i32 = 1;
}
/// Returns the shared test store from `common::store()`, if one is available.
///
/// Every test in this file prints `skipping: DATABASE_URL not set` and returns
/// early when this yields `None`, so the suite is skipped rather than failed.
async fn maybe_store() -> Option<EventStore> {
    common::store().await
}
/// Records one `Incremented` event per entry of `increments` on the counter with
/// `id`, then saves it.
///
/// The aggregate is loaded with `load_or_default`, so a fresh id starts from
/// `Counter::default()`.
///
/// # Panics
/// Panics if loading or saving fails.
async fn seed_counter(store: &EventStore, id: &str, increments: &[i64]) {
    let mut aggregate = store.load_or_default::<Counter>(id).await.unwrap();
    for &by in increments {
        aggregate.record(CounterEvent::Incremented { by });
    }
    store.save(&mut aggregate).await.unwrap();
}
/// While transaction A holds a `load_for_update` lock on a stream, transaction B's
/// `load_for_update` on the same stream must fail with SQLSTATE 55P03
/// (`lock_not_available`) once B's 300 ms `lock_timeout` expires.
#[tokio::test]
async fn load_for_update_blocks_concurrent_lockers() {
    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let stream_key = Uuid::new_v4().to_string();
    seed_counter(&store, &stream_key, &[3, 7]).await;

    // Transaction A takes the lock and keeps it until the very end of the test.
    let mut holder = store.begin_transaction().await.unwrap();
    let held_root = holder
        .load_for_update::<Counter>(&stream_key)
        .await
        .unwrap();
    assert!(held_root.is_some(), "A should observe the seeded stream");

    // Transaction B runs on its own task with a short lock timeout, so the
    // contention shows up as an error instead of an indefinite wait.
    let contender = tokio::spawn({
        let stream_key = stream_key.clone();
        let store = store.clone();
        async move {
            let mut contender_scope = store.begin_transaction().await.unwrap();
            sqlx::query("SET LOCAL lock_timeout = '300ms'")
                .execute(&mut **contender_scope.tx())
                .await
                .unwrap();
            let attempt = contender_scope
                .load_for_update::<Counter>(&stream_key)
                .await;
            // Best-effort cleanup: after a lock failure the transaction is already aborted.
            let _ = contender_scope.rollback().await;
            attempt
        }
    });

    let Err(err) = contender.await.unwrap() else {
        panic!("expected B to be blocked by A's lock; instead succeeded");
    };
    let EventStoreError::Database(sqlx_err) = err else {
        panic!("expected EventStoreError::Database for blocked B, got: {err}");
    };
    let code = match sqlx_err.as_database_error() {
        Some(db_err) => db_err.code().map(|c| c.into_owned()),
        None => None,
    };
    assert_eq!(
        code.as_deref(),
        Some("55P03"),
        "expected SQLSTATE 55P03 (lock_not_available); got: {sqlx_err}",
    );
    holder.rollback().await.unwrap();
}
/// `load_for_update` on a stream that does not exist must return `None` without
/// taking a lock that could stop another session from creating that stream.
///
/// Two checks, both made while the locking transaction is still open:
/// 1. the session holds no `tuple` or `transactionid` entries in `pg_locks`;
/// 2. a separate pooled connection can insert the `es_streams` row for the same
///    stream id within one second.
#[tokio::test]
async fn load_for_update_on_missing_stream_takes_no_lock() {
    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let id = Uuid::new_v4().to_string();
    // NOTE(review): assumes the store derives stream ids as "<category>-<id>". The
    // inserter below only probes the row `load_for_update` would touch if this matches.
    let stream_id = format!("counter-{id}");
    let mut scope = store.begin_transaction().await.unwrap();
    let result = scope.load_for_update::<Counter>(&id).await.unwrap();
    assert!(
        result.is_none(),
        "load_for_update on missing stream must return None"
    );
    let pid: i32 = sqlx::query_scalar("SELECT pg_backend_pid()")
        .fetch_one(&mut **scope.tx())
        .await
        .unwrap();
    // A row lock taken by SELECT ... FOR UPDATE is stored in the tuple header. The
    // holder therefore gets no `tuple` entry in pg_locks; those appear only
    // transiently while another session waits on the row. Counting `tuple` alone
    // would essentially never fail. Locking a row does, however, assign the
    // transaction a permanent xid, and every assigned xid shows up as an
    // ExclusiveLock on the session's own `transactionid`. Counting that entry makes
    // this check able to detect a row lock (or any write) by this transaction.
    let row_lock_markers: i64 = sqlx::query_scalar(
        "SELECT COUNT(*)::bigint FROM pg_locks
         WHERE pid = $1 AND locktype IN ('tuple', 'transactionid')",
    )
    .bind(pid)
    .fetch_one(store.pool())
    .await
    .unwrap();
    assert_eq!(
        row_lock_markers, 0,
        "load_for_update on a missing stream held {row_lock_markers} tuple/transactionid lock(s) for pid {pid}",
    );
    // Behavioral check: creating the stream from another connection must not wait
    // on this still-open transaction.
    let store_clone = store.clone();
    let stream_id_clone = stream_id.clone();
    let inserter = tokio::time::timeout(Duration::from_secs(1), async move {
        sqlx::query(
            "INSERT INTO es_streams (stream_id, stream_category, stream_version)
             VALUES ($1, $2, 0)",
        )
        .bind(&stream_id_clone)
        .bind("counter")
        .execute(store_clone.pool())
        .await
    })
    .await;
    inserter
        .expect("inserter blocked — load_for_update is over-locking")
        .expect("inserter SQL failed");
    scope.rollback().await.unwrap();
}
/// With no competing transaction, `load` and `load_for_update` must rebuild the
/// same state and version for a seeded stream. Each read runs in its own
/// transaction, which is rolled back before the next one begins.
#[tokio::test]
async fn load_and_load_for_update_agree_without_contention() {
    const EXPECTED_TOTAL: i64 = 11; // 4 + 6 + 1
    const EXPECTED_VERSION: i64 = 3; // three seeded events

    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let key = Uuid::new_v4().to_string();
    seed_counter(&store, &key, &[4, 6, 1]).await;

    let mut read_scope = store.begin_transaction().await.unwrap();
    let unlocked = read_scope
        .load::<Counter>(&key)
        .await
        .unwrap()
        .expect("scope_a should see the seeded stream");
    read_scope.rollback().await.unwrap();

    let mut lock_scope = store.begin_transaction().await.unwrap();
    let locked = lock_scope
        .load_for_update::<Counter>(&key)
        .await
        .unwrap()
        .expect("scope_b should see the seeded stream");
    lock_scope.rollback().await.unwrap();

    // Checked in the same order as before: plain load first, then the locked load.
    for (total, version) in [
        (unlocked.state.total, unlocked.version),
        (locked.state.total, locked.version),
    ] {
        assert_eq!(total, EXPECTED_TOTAL);
        assert_eq!(version, EXPECTED_VERSION);
    }
}
/// Within one transaction, both `load` and `load_for_update` must report a stream
/// that was never written as `Ok(None)`.
#[tokio::test]
async fn load_and_load_for_update_return_none_on_missing_stream() {
    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let absent_key = Uuid::new_v4().to_string();
    let mut scope = store.begin_transaction().await.unwrap();

    let plain = scope.load::<Counter>(&absent_key).await.unwrap();
    assert!(plain.is_none());

    let locked = scope
        .load_for_update::<Counter>(&absent_key)
        .await
        .unwrap();
    assert!(locked.is_none());

    scope.rollback().await.unwrap();
}
/// Rolling back the transaction that took a `load_for_update` lock must release it.
/// A later transaction with a 300 ms `lock_timeout` can then lock the same stream
/// and read its state.
#[tokio::test]
async fn rollback_releases_load_for_update_lock() {
    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let key = Uuid::new_v4().to_string();
    seed_counter(&store, &key, &[5]).await;

    // First transaction: take the lock, then roll back immediately.
    let mut first = store.begin_transaction().await.unwrap();
    let _ = first.load_for_update::<Counter>(&key).await.unwrap();
    first.rollback().await.unwrap();

    // Second transaction: with a short timeout, a lingering lock would surface as an error.
    let mut second = store.begin_transaction().await.unwrap();
    sqlx::query("SET LOCAL lock_timeout = '300ms'")
        .execute(&mut **second.tx())
        .await
        .unwrap();
    let relocked = second
        .load_for_update::<Counter>(&key)
        .await
        .expect("B's lock should succeed after A's rollback")
        .expect("stream should exist");
    assert_eq!(relocked.state.total, 5);
    second.rollback().await.unwrap();
}
/// `load_snapshotted` inside a transaction scope must return the same total,
/// version and stream id as the store-level `load_snapshotted`. Both must also
/// reflect all three recorded increments.
#[tokio::test]
async fn load_snapshotted_inside_scope_matches_off_scope() {
    let Some(store) = maybe_store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let key = Uuid::new_v4().to_string();

    // Save after two events, then again after a third. NOTE(review): presumably
    // this leaves a snapshot plus at least one trailing event, given
    // SNAPSHOT_FREQUENCY = 2 — confirm against `save_snapshotting`.
    let mut counter = store.load_or_default::<Counter>(&key).await.unwrap();
    for by in [10, 20] {
        counter.record(CounterEvent::Incremented { by });
    }
    store.save_snapshotting(&mut counter).await.unwrap();
    counter.record(CounterEvent::Incremented { by: 5 });
    store.save_snapshotting(&mut counter).await.unwrap();

    let off_scope = store
        .load_snapshotted::<Counter>(&key)
        .await
        .unwrap()
        .expect("off-scope snapshot load");

    let mut scope = store.begin_transaction().await.unwrap();
    let in_scope = scope
        .load_snapshotted::<Counter>(&key)
        .await
        .unwrap()
        .expect("in-scope snapshot load");
    scope.rollback().await.unwrap();

    // First the two paths must agree with each other...
    assert_eq!(off_scope.state.total, in_scope.state.total);
    assert_eq!(off_scope.version, in_scope.version);
    assert_eq!(off_scope.stream_id, in_scope.stream_id);
    // ...then with the expected values: 10 + 20 + 5 across three events.
    assert_eq!(off_scope.state.total, 35);
    assert_eq!(off_scope.version, 3);
}