use std::time::Duration;
use mire::lease::{AcquireOutcome, checkpoint, try_acquire};
use mire::{
Aggregate, EventData, EventStore, HandledEvent, ProjectionRunner, TransactionalEventHandler,
};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
mod common;
/// Returns the shared test [`EventStore`] provided by the `common` test module.
///
/// Every test in this file treats `None` as "no `DATABASE_URL` configured" and
/// skips itself instead of failing.
async fn store() -> Option<EventStore> {
    common::store().await
}
/// Extracts the fence token from a lease acquisition attempt.
///
/// # Panics
///
/// Panics with `"expected acquire"` when the lease was already held by
/// another owner, i.e. the outcome was not `Acquired`.
fn fence_of(o: AcquireOutcome) -> i64 {
    let AcquireOutcome::Acquired { fence_token } = o else {
        panic!("expected acquire");
    };
    fence_token
}
/// A worker holding a stale fence token must have its transactional checkpoint
/// rejected, so the read-model write made in the same transaction can be rolled
/// back. The live leader's identical write must commit together with its checkpoint.
///
/// Skips (returns early) when no test database is configured.
#[tokio::test]
async fn fenced_transactional_checkpoint_rolls_back_writes() {
    let Some(store) = store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let pool = store.pool();
    let sub = format!("tx-fence-{}", Uuid::new_v4());
    // Per-run table name so concurrent or repeated runs never share rows.
    let rm = format!("rm_{}", Uuid::new_v4().simple());
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!(
        "CREATE TABLE IF NOT EXISTS {rm} (k TEXT PRIMARY KEY, v BIGINT NOT NULL)"
    )))
    .execute(pool)
    .await
    .unwrap();

    // Worker A takes a 100ms lease, then we sleep past that TTL so worker B can
    // acquire the same subscription. B's fence token must be strictly newer.
    let stale = fence_of(
        try_acquire(pool, &sub, "A", Duration::from_millis(100))
            .await
            .unwrap(),
    );
    tokio::time::sleep(Duration::from_millis(150)).await;
    let live = fence_of(
        try_acquire(pool, &sub, "B", Duration::from_secs(30))
            .await
            .unwrap(),
    );
    assert!(live > stale);

    // Stale worker: write, try to checkpoint, get rejected, roll everything back.
    let mut tx = pool.begin().await.unwrap();
    sqlx::query(sqlx::AssertSqlSafe(format!(
        "INSERT INTO {rm} (k, v) VALUES ('x', 7)"
    )))
    .execute(&mut *tx)
    .await
    .unwrap();
    let ok = checkpoint(&mut *tx, &sub, stale, 10, 5).await.unwrap();
    assert!(!ok, "stale fence must be rejected");
    tx.rollback().await.unwrap();
    let n: i64 = sqlx::query_scalar(sqlx::AssertSqlSafe(format!("SELECT count(*) FROM {rm}")))
        .fetch_one(pool)
        .await
        .unwrap();
    assert_eq!(n, 0, "a fenced worker's read-model write must roll back");

    // Live leader: the same write now commits alongside an accepted checkpoint.
    let mut tx = pool.begin().await.unwrap();
    sqlx::query(sqlx::AssertSqlSafe(format!(
        "INSERT INTO {rm} (k, v) VALUES ('x', 7)"
    )))
    .execute(&mut *tx)
    .await
    .unwrap();
    let ok = checkpoint(&mut *tx, &sub, live, 10, 5).await.unwrap();
    assert!(ok, "the live leader's checkpoint must apply");
    tx.commit().await.unwrap();
    let n: i64 = sqlx::query_scalar(sqlx::AssertSqlSafe(format!("SELECT count(*) FROM {rm}")))
        .fetch_one(pool)
        .await
        .unwrap();

    // Drop the per-run table before the final assertion so it does not accumulate
    // in the test database across runs. An earlier failed assertion still leaves
    // the table behind, which can help when debugging.
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!("DROP TABLE IF EXISTS {rm}")))
        .execute(pool)
        .await
        .unwrap();
    assert_eq!(n, 1, "the live leader's write commits with the checkpoint");
}
/// Events recorded by the test `Counter` aggregate.
///
/// Serialized with serde's internal tagging: the variant name is stored under
/// the `"type"` key alongside the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
enum CounterEvent {
    /// The counter was increased by `by`.
    Incremented { by: i64 },
}

impl EventData for CounterEvent {
    /// Stable event-type name stored with each event.
    ///
    /// Matched exhaustively so that adding a variant forces choosing its name.
    fn event_type(&self) -> &'static str {
        match self {
            CounterEvent::Incremented { .. } => "counter.incremented",
        }
    }
}
/// Stateless aggregate used only to append `CounterEvent`s to streams in the
/// `"txcounter"` category. It folds no state; totals live in the read model.
#[derive(Default)]
struct Counter;

impl Aggregate for Counter {
    type Event = CounterEvent;

    fn stream_category() -> &'static str {
        "txcounter"
    }

    fn apply(&mut self, _event: &Self::Event) {
        // Intentionally empty: this aggregate keeps no in-memory state.
    }
}
/// Transactional read-model handler that keeps a running `total` per stream.
struct TxTotals {
    /// Name of the read-model table. The test creates it with columns
    /// `(stream_id TEXT PRIMARY KEY, total BIGINT NOT NULL)`.
    table: String,
}

impl TransactionalEventHandler for TxTotals {
    type Aggregate = Counter;

    /// Adds the event's increment to its stream's row, inserting the row on
    /// first sight. The write goes through the connection supplied by the runner
    /// rather than a pool, so it participates in whatever transaction the runner
    /// holds.
    async fn handle(
        &self,
        event: HandledEvent<CounterEvent>,
        conn: &mut sqlx::PgConnection,
    ) -> anyhow::Result<()> {
        let CounterEvent::Incremented { by: delta } = event.event;
        let upsert = format!(
            "INSERT INTO {0} (stream_id, total) VALUES ($1, $2) \
             ON CONFLICT (stream_id) DO UPDATE SET total = {0}.total + EXCLUDED.total",
            self.table
        );
        sqlx::query(sqlx::AssertSqlSafe(upsert))
            .bind(event.stream_id())
            .bind(delta)
            .execute(&mut *conn)
            .await?;
        Ok(())
    }
}
/// End-to-end check: events saved through the store are applied by a
/// transactional subscription (`TxTotals`). The read model must reach the sum of
/// the recorded increments.
///
/// Skips (returns early) when no test database is configured.
#[tokio::test]
async fn subscribe_transactional_builds_read_model() {
    let Some(store) = store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let pool = store.pool().clone();
    let suffix = Uuid::new_v4().simple().to_string();
    let table = format!("txtotals_{suffix}");
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!(
        "CREATE TABLE IF NOT EXISTS {table} (stream_id TEXT PRIMARY KEY, total BIGINT NOT NULL)"
    )))
    .execute(&pool)
    .await
    .unwrap();

    let id = Uuid::new_v4().to_string();
    let mut counter = store.load_or_default::<Counter>(&id).await.unwrap();
    counter.record(CounterEvent::Incremented { by: 4 });
    counter.record(CounterEvent::Incremented { by: 6 });
    store.save(&mut counter).await.unwrap();
    let expected_total: i64 = 4 + 6;

    let runner = ProjectionRunner::builder(store.clone())
        .poll_interval(Duration::from_millis(20))
        .subscribe_transactional(
            format!("txtotals-{suffix}"),
            TxTotals {
                table: table.clone(),
            },
        )
        .build();
    let token = CancellationToken::new();
    let handle = {
        let token = token.clone();
        tokio::spawn(async move { runner.run(token).await })
    };

    // Wait for the final total, not merely for the row to exist. If the runner
    // commits the two events in separate transactions, the row first appears
    // with total = 4, and stopping there would make the test flaky. Stopping at
    // ">= expected" also lets an overshoot (double application) fail fast.
    let mut total = None;
    for _ in 0..400 {
        total = sqlx::query_scalar::<_, i64>(sqlx::AssertSqlSafe(format!(
            "SELECT total FROM {table} WHERE stream_id = $1"
        )))
        .bind(&counter.stream_id)
        .fetch_optional(&pool)
        .await
        .unwrap();
        if matches!(total, Some(t) if t >= expected_total) {
            break;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }

    // Stop the runner before dropping its table so no in-flight write races the
    // DROP. Drop before asserting so a failed assertion still cleans up.
    token.cancel();
    let run_result = handle.await;
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!("DROP TABLE IF EXISTS {table}")))
        .execute(&pool)
        .await
        .unwrap();
    run_result.unwrap().unwrap();
    assert_eq!(total, Some(expected_total));
}