mire 0.2.0

A small, generic PostgreSQL event-sourcing library: append-only event streams, aggregates with optimistic concurrency, and subscription-based projections (requires tokio + sqlx)
Documentation
//! C2/LEASE-1: the transactional handler API gives effectively-exactly-once
//! read models. A handler's writes share the fenced checkpoint's transaction,
//! so a superseded (fenced) worker's writes roll back instead of corrupting
//! the read model.
//!
//! Gated on `DATABASE_URL`.

use std::time::Duration;

use mire::lease::{AcquireOutcome, checkpoint, try_acquire};
use mire::{
    Aggregate, EventData, EventStore, HandledEvent, ProjectionRunner, TransactionalEventHandler,
};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;

mod common;

/// Shared test fixture: an [`EventStore`] from the `common` test module.
///
/// Every test in this file treats `None` as "`DATABASE_URL` not set" and skips.
async fn store() -> Option<EventStore> {
    common::store().await
}

/// Extracts the fence token from a lease acquisition the test expects to win.
///
/// # Panics
/// Panics with "expected acquire" on `AcquireOutcome::Held`.
fn fence_of(o: AcquireOutcome) -> i64 {
    let AcquireOutcome::Acquired { fence_token } = o else {
        panic!("expected acquire");
    };
    fence_token
}

/// The core mechanism: a fenced checkpoint in the same transaction as a
/// read-model write rolls the write back; a valid one commits both. This is
/// exactly what the runner's transactional dispatch does.
///
/// The read model is a uniquely named scratch table. It is dropped after the
/// last assertion passes, so repeated runs don't accumulate tables. A failing
/// run leaves it behind for inspection.
#[tokio::test]
async fn fenced_transactional_checkpoint_rolls_back_writes() {
    let Some(store) = store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let pool = store.pool();
    let sub = format!("tx-fence-{}", Uuid::new_v4());
    let rm = format!("rm_{}", Uuid::new_v4().simple());
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!(
        "CREATE TABLE IF NOT EXISTS {rm} (k TEXT PRIMARY KEY, v BIGINT NOT NULL)"
    )))
    .execute(pool)
    .await
    .unwrap();

    // A acquires (fence 1); it then stalls. B steals (fence 2), now published
    // to es_subscriptions — so A's fence 1 is stale.
    // The 150 ms sleep outlasts A's 100 ms lease so that B can take it over.
    let stale = fence_of(
        try_acquire(pool, &sub, "A", Duration::from_millis(100))
            .await
            .unwrap(),
    );
    tokio::time::sleep(Duration::from_millis(150)).await;
    let live = fence_of(
        try_acquire(pool, &sub, "B", Duration::from_secs(30))
            .await
            .unwrap(),
    );
    assert!(live > stale);

    // Stale worker A: write + fenced checkpoint in one tx → fenced → rollback.
    let mut tx = pool.begin().await.unwrap();
    sqlx::query(sqlx::AssertSqlSafe(format!(
        "INSERT INTO {rm} (k, v) VALUES ('x', 7)"
    )))
    .execute(&mut *tx)
    .await
    .unwrap();
    let ok = checkpoint(&mut *tx, &sub, stale, 10, 5).await.unwrap();
    assert!(!ok, "stale fence must be rejected");
    tx.rollback().await.unwrap();

    let n: i64 = sqlx::query_scalar(sqlx::AssertSqlSafe(format!("SELECT count(*) FROM {rm}")))
        .fetch_one(pool)
        .await
        .unwrap();
    assert_eq!(n, 0, "a fenced worker's read-model write must roll back");

    // Live worker B: write + valid checkpoint in one tx → commit.
    let mut tx = pool.begin().await.unwrap();
    sqlx::query(sqlx::AssertSqlSafe(format!(
        "INSERT INTO {rm} (k, v) VALUES ('x', 7)"
    )))
    .execute(&mut *tx)
    .await
    .unwrap();
    let ok = checkpoint(&mut *tx, &sub, live, 10, 5).await.unwrap();
    assert!(ok, "the live leader's checkpoint must apply");
    tx.commit().await.unwrap();

    let n: i64 = sqlx::query_scalar(sqlx::AssertSqlSafe(format!("SELECT count(*) FROM {rm}")))
        .fetch_one(pool)
        .await
        .unwrap();
    assert_eq!(n, 1, "the live leader's write commits with the checkpoint");

    // Clean up the per-run scratch table so repeated runs don't leak tables.
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!("DROP TABLE IF EXISTS {rm}")))
        .execute(pool)
        .await
        .unwrap();
}

// --- runner happy-path: subscribe_transactional builds a read model --------

/// The only event [`Counter`] streams carry: an increment of `by`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
enum CounterEvent {
    Incremented { by: i64 },
}

impl EventData for CounterEvent {
    fn event_type(&self) -> &'static str {
        // Exhaustive match: adding a variant forces choosing its type name.
        match self {
            Self::Incremented { .. } => "counter.incremented",
        }
    }
}

/// Stateless aggregate whose only job is to place [`CounterEvent`]s in the
/// `txcounter` stream category; applying an event changes nothing.
#[derive(Default)]
struct Counter;

impl Aggregate for Counter {
    type Event = CounterEvent;

    fn stream_category() -> &'static str {
        "txcounter"
    }

    /// No in-memory state to fold; the totals are built by the projection.
    fn apply(&mut self, _event: &Self::Event) {}
}

/// Transactional projection keeping a running `total` per stream in `table`.
///
/// It writes through the connection handed in by the runner, so the upsert
/// shares a transaction with the fenced checkpoint.
struct TxTotals {
    /// Scratch table with columns `(stream_id TEXT PRIMARY KEY, total BIGINT)`.
    table: String,
}

impl TransactionalEventHandler for TxTotals {
    type Aggregate = Counter;

    async fn handle(
        &self,
        event: HandledEvent<CounterEvent>,
        conn: &mut sqlx::PgConnection,
    ) -> anyhow::Result<()> {
        let CounterEvent::Incremented { by: delta } = event.event;
        // First event for a stream inserts its row; later ones add to it.
        let upsert = format!(
            "INSERT INTO {table} (stream_id, total) VALUES ($1, $2) \
             ON CONFLICT (stream_id) DO UPDATE SET total = {table}.total + EXCLUDED.total",
            table = self.table
        );
        sqlx::query(sqlx::AssertSqlSafe(upsert))
            .bind(event.stream_id())
            .bind(delta)
            .execute(&mut *conn)
            .await?;
        Ok(())
    }
}

/// End-to-end happy path: a runner subscribed via `subscribe_transactional`
/// feeds one stream's two events (+4, +6) to [`TxTotals`]. The resulting
/// read-model row must hold exactly 10, and must still hold 10 after the
/// runner shuts down.
#[tokio::test]
async fn subscribe_transactional_builds_read_model() {
    let Some(store) = store().await else {
        eprintln!("skipping: DATABASE_URL not set");
        return;
    };
    let pool = store.pool().clone();
    let suffix = Uuid::new_v4().simple().to_string();
    let table = format!("txtotals_{suffix}");
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!(
        "CREATE TABLE IF NOT EXISTS {table} (stream_id TEXT PRIMARY KEY, total BIGINT NOT NULL)"
    )))
    .execute(&pool)
    .await
    .unwrap();

    let id = Uuid::new_v4().to_string();
    let mut counter = store.load_or_default::<Counter>(&id).await.unwrap();
    counter.record(CounterEvent::Incremented { by: 4 });
    counter.record(CounterEvent::Incremented { by: 6 });
    store.save(&mut counter).await.unwrap();

    let runner = ProjectionRunner::builder(store.clone())
        .poll_interval(Duration::from_millis(20))
        .subscribe_transactional(
            format!("txtotals-{suffix}"),
            TxTotals {
                table: table.clone(),
            },
        )
        .build();

    let token = CancellationToken::new();
    let handle = {
        let token = token.clone();
        tokio::spawn(async move { runner.run(token).await })
    };

    // Wait for the *expected* total, not merely for the row to appear. If the
    // runner commits the two events in separate transactions, the row can
    // briefly read 4. Breaking on any `Some` would race the second commit and
    // make the test flaky. 400 polls × 20 ms caps the wait at roughly 8 s.
    let mut total = None;
    for _ in 0..400 {
        total = sqlx::query_scalar::<_, i64>(sqlx::AssertSqlSafe(format!(
            "SELECT total FROM {table} WHERE stream_id = $1"
        )))
        .bind(&counter.stream_id)
        .fetch_optional(&pool)
        .await
        .unwrap();
        if total == Some(10) {
            break;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
    token.cancel();
    handle.await.unwrap().unwrap();
    assert_eq!(total, Some(10), "read model did not reach the expected total in time");

    // Re-read once the runner has stopped. Any event applied again after the
    // poll saw 10 would show up here as an overshoot.
    let settled = sqlx::query_scalar::<_, i64>(sqlx::AssertSqlSafe(format!(
        "SELECT total FROM {table} WHERE stream_id = $1"
    )))
    .bind(&counter.stream_id)
    .fetch_optional(&pool)
    .await
    .unwrap();
    assert_eq!(settled, Some(10), "read model changed after reaching the expected total");

    // Clean up the per-run scratch table so repeated runs don't leak tables.
    sqlx::raw_sql(sqlx::AssertSqlSafe(format!("DROP TABLE IF EXISTS {table}")))
        .execute(&pool)
        .await
        .unwrap();
}