use async_trait::async_trait;
use klauthed_core::time::Timestamp;
use sqlx::AnyPool;
use sqlx::Row;
use crate::error::DataError;
use crate::outbox::{Outbox, OutboxEntry, OutboxId};
/// SQL-backed [`Outbox`] implementation that runs its statements through an
/// [`AnyPool`].
///
/// `Debug` is derived so the outbox can be embedded in other `Debug` types and
/// logged while diagnosing problems.
#[derive(Clone, Debug)]
pub struct SqlOutbox {
    /// Connection pool every statement is executed against.
    pool: AnyPool,
    /// Name of the table the generated statements target.
    table: String,
}
impl SqlOutbox {
    /// Table name used by [`SqlOutbox::new`].
    pub const DEFAULT_TABLE: &'static str = "outbox";

    /// DDL that creates the `outbox` table if it does not exist yet.
    ///
    /// The JSON payload and both timestamps are stored as `TEXT`; the
    /// `published` flag is an `INTEGER` that defaults to `0`.
    pub const CREATE_TABLE_SQL: &'static str = "\
CREATE TABLE IF NOT EXISTS outbox (
id TEXT NOT NULL PRIMARY KEY,
aggregate_type TEXT NOT NULL,
aggregate_id TEXT NOT NULL,
event_type TEXT NOT NULL,
sequence BIGINT NOT NULL,
payload TEXT NOT NULL,
occurred_at TEXT NOT NULL,
published INTEGER NOT NULL DEFAULT 0,
published_at TEXT
)";

    /// Creates an outbox over `pool` that targets [`Self::DEFAULT_TABLE`].
    pub fn new(pool: AnyPool) -> Self {
        Self { pool, table: Self::DEFAULT_TABLE.to_owned() }
    }

    /// Returns the connection pool this outbox executes against.
    pub fn pool(&self) -> &AnyPool {
        &self.pool
    }

    /// Executes [`Self::CREATE_TABLE_SQL`] against the pool.
    ///
    /// The statement uses `CREATE TABLE IF NOT EXISTS`, so calling this more
    /// than once is harmless.
    ///
    /// # Errors
    ///
    /// Returns a [`DataError`] when the statement fails to execute.
    pub async fn ensure_schema(&self) -> Result<(), DataError> {
        sqlx::query(Self::CREATE_TABLE_SQL).execute(&self.pool).await?;
        Ok(())
    }

    /// Builds the shared `SELECT ... FROM <table>` prefix. The column list
    /// covers every column `row_to_entry` reads.
    fn select_prefix(&self) -> String {
        format!(
            "SELECT id, aggregate_type, aggregate_id, event_type, sequence, \
             payload, occurred_at, published, published_at FROM {}",
            self.table
        )
    }

    /// Fetches up to `limit` unpublished entries in ascending `sequence`
    /// order, appending `FOR UPDATE SKIP LOCKED` to the query.
    ///
    /// NOTE(review): the statement is executed directly on the pool and no
    /// explicit transaction is opened here, so any row locks are presumably
    /// released as soon as the statement finishes — confirm that callers do
    /// not rely on the locks outliving this call.
    ///
    /// # Errors
    ///
    /// Returns a [`DataError`] when the query fails or a row cannot be decoded.
    #[cfg(feature = "postgres")]
    pub async fn fetch_unpublished_skip_locked(
        &self,
        limit: usize,
    ) -> Result<Vec<OutboxEntry>, DataError> {
        // A plain `as i64` cast wraps huge limits (e.g. `usize::MAX`) to a
        // negative number, which PostgreSQL rejects. Clamp instead so that
        // "as many as possible" keeps working.
        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
        let sql = format!(
            "{prefix} WHERE published = 0 ORDER BY sequence ASC LIMIT {limit} FOR UPDATE SKIP LOCKED",
            prefix = self.select_prefix(),
        );
        let rows = sqlx::query(sqlx::AssertSqlSafe(&*sql)).fetch_all(&self.pool).await?;
        rows.iter().map(row_to_entry).collect()
    }
}
fn row_to_entry(row: &sqlx::any::AnyRow) -> Result<OutboxEntry, DataError> {
let id_str: String = row.try_get("id")?;
let id: OutboxId = id_str
.parse()
.map_err(|e| DataError::Outbox(format!("invalid outbox id '{id_str}': {e}")))?;
let payload_str: String = row.try_get("payload")?;
let payload: serde_json::Value = serde_json::from_str(&payload_str)
.map_err(|e| DataError::Outbox(format!("invalid stored payload json: {e}")))?;
let occurred_at_str: String = row.try_get("occurred_at")?;
let occurred_at = parse_timestamp(&occurred_at_str)?;
let published_at_str: Option<String> = row.try_get("published_at")?;
let published_at = match published_at_str {
Some(s) => Some(parse_timestamp(&s)?),
None => None,
};
let sequence: i64 = row.try_get("sequence")?;
let published: i64 = row.try_get("published")?;
Ok(OutboxEntry {
id,
aggregate_type: row.try_get("aggregate_type")?,
aggregate_id: row.try_get("aggregate_id")?,
event_type: row.try_get("event_type")?,
sequence: sequence as u64,
payload,
occurred_at,
published: published != 0,
published_at,
})
}
/// Parses a stored timestamp string by wrapping it in a JSON string value and
/// deserializing that into a [`Timestamp`].
///
/// # Errors
///
/// Returns `DataError::Outbox` naming the offending text when deserialization
/// fails.
fn parse_timestamp(s: &str) -> Result<Timestamp, DataError> {
    let as_json = serde_json::Value::String(s.to_owned());
    match serde_json::from_value::<Timestamp>(as_json) {
        Ok(timestamp) => Ok(timestamp),
        Err(e) => Err(DataError::Outbox(format!("invalid stored timestamp '{s}': {e}"))),
    }
}
#[async_trait]
impl Outbox for SqlOutbox {
async fn enqueue(&self, entries: Vec<OutboxEntry>) -> Result<(), DataError> {
if entries.is_empty() {
return Ok(());
}
let sql = format!(
"INSERT INTO {} \
(id, aggregate_type, aggregate_id, event_type, sequence, payload, occurred_at, published, published_at) \
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.table
);
let mut tx = self.pool.begin().await?;
for entry in entries {
let payload = serde_json::to_string(&entry.payload).map_err(|e| {
DataError::Outbox(format!("failed to serialize outbox payload: {e}"))
})?;
let published_at = entry.published_at.map(|t| t.to_rfc3339());
sqlx::query(sqlx::AssertSqlSafe(&*sql))
.bind(entry.id.to_string())
.bind(entry.aggregate_type)
.bind(entry.aggregate_id)
.bind(entry.event_type)
.bind(entry.sequence as i64)
.bind(payload)
.bind(entry.occurred_at.to_rfc3339())
.bind(i64::from(entry.published))
.bind(published_at)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
async fn fetch_unpublished(&self, limit: usize) -> Result<Vec<OutboxEntry>, DataError> {
let sql = format!(
"{prefix} WHERE published = 0 ORDER BY sequence ASC LIMIT {limit}",
prefix = self.select_prefix(),
limit = limit as i64,
);
let rows = sqlx::query(sqlx::AssertSqlSafe(&*sql)).fetch_all(&self.pool).await?;
rows.iter().map(row_to_entry).collect()
}
async fn mark_published(&self, ids: &[OutboxId]) -> Result<(), DataError> {
if ids.is_empty() {
return Ok(());
}
let now = Timestamp::now().to_rfc3339();
let sql = format!(
"UPDATE {} SET published = 1, published_at = ? WHERE id = ? AND published = 0",
self.table
);
let mut tx = self.pool.begin().await?;
for id in ids {
sqlx::query(sqlx::AssertSqlSafe(&*sql))
.bind(now.clone())
.bind(id.to_string())
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
}
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;
    use klauthed_core::domain::{DomainEvent, EventEnvelope};
    use klauthed_core::id::Id;
    use serde::Serialize;
    use std::borrow::Cow;

    /// Minimal event payload used to build outbox entries.
    #[derive(Debug, Serialize)]
    struct Opened {
        owner: String,
    }

    impl DomainEvent for Opened {
        fn event_type(&self) -> &'static str {
            "account.opened"
        }
    }

    /// Builds an outbox entry for aggregate `acct-1` with sequence `seq`.
    fn entry(seq: u64) -> OutboxEntry {
        let envelope = EventEnvelope {
            event_id: Id::new(),
            event_type: Cow::Borrowed("account.opened"),
            aggregate_id: "acct-1".to_owned(),
            aggregate_type: Cow::Borrowed("account"),
            sequence: seq,
            occurred_at: Timestamp::from_unix_millis(1_000 + seq as i64),
            payload: Opened { owner: format!("owner-{seq}") },
        };
        OutboxEntry::from_envelope(&envelope).unwrap()
    }

    /// Creates an outbox over an in-memory SQLite database with the schema in
    /// place. The pool is capped at one connection.
    async fn memory_outbox() -> SqlOutbox {
        sqlx::any::install_default_drivers();
        let pool = sqlx::pool::PoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("connect in-memory sqlite");
        let outbox = SqlOutbox::new(pool);
        outbox.ensure_schema().await.expect("ensure schema");
        outbox
    }

    /// Reads the stored row for `id` regardless of its published state.
    async fn stored_entry(outbox: &SqlOutbox, id: OutboxId) -> OutboxEntry {
        let sql = format!("{} WHERE id = ?", outbox.select_prefix());
        let row = sqlx::query(sqlx::AssertSqlSafe(&*sql))
            .bind(id.to_string())
            .fetch_one(outbox.pool())
            .await
            .expect("fetch stored outbox row");
        row_to_entry(&row).expect("decode stored outbox row")
    }

    #[tokio::test]
    async fn ensure_schema_is_idempotent() {
        let outbox = memory_outbox().await;
        outbox.ensure_schema().await.unwrap();
        assert!(outbox.fetch_unpublished(10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn enqueue_fetch_mark_round_trip_over_any_sqlite() {
        let outbox = memory_outbox().await;
        let e1 = entry(1);
        let e2 = entry(2);
        let (id1, id2) = (e1.id, e2.id);
        outbox.enqueue(vec![e1.clone(), e2.clone()]).await.unwrap();

        let unpublished = outbox.fetch_unpublished(10).await.unwrap();
        assert_eq!(unpublished.len(), 2);
        assert_eq!(unpublished[0], e1);
        assert_eq!(unpublished[1].id, id2);
        assert_eq!(unpublished[0].payload["owner"], "owner-1");
        assert!(!unpublished[0].published);

        outbox.mark_published(&[id1]).await.unwrap();
        let remaining = outbox.fetch_unpublished(10).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, id2);

        outbox.mark_published(&[id2]).await.unwrap();
        assert!(outbox.fetch_unpublished(10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn fetch_honors_limit_and_sequence_order() {
        let outbox = memory_outbox().await;
        let entries: Vec<_> = (1..=5).map(entry).collect();
        outbox.enqueue(entries).await.unwrap();

        let two = outbox.fetch_unpublished(2).await.unwrap();
        assert_eq!(two.len(), 2);
        assert_eq!(two[0].sequence, 1);
        assert_eq!(two[1].sequence, 2);
        assert_eq!(outbox.fetch_unpublished(100).await.unwrap().len(), 5);
    }

    #[tokio::test]
    async fn marking_published_stores_published_at() {
        let outbox = memory_outbox().await;
        let e = entry(1);
        let id = e.id;
        outbox.enqueue(vec![e]).await.unwrap();

        outbox.mark_published(&[id]).await.unwrap();
        let first = stored_entry(&outbox, id).await;
        assert!(first.published);
        assert!(first.published_at.is_some());

        // Marking again must not touch the row: the update is guarded by
        // `published = 0`, so the stored entry stays exactly the same.
        outbox.mark_published(&[id]).await.unwrap();
        let second = stored_entry(&outbox, id).await;
        assert_eq!(second, first);
        assert!(outbox.fetch_unpublished(10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn empty_batches_are_noops() {
        let outbox = memory_outbox().await;
        outbox.enqueue(vec![]).await.unwrap();
        outbox.mark_published(&[]).await.unwrap();
        assert!(outbox.fetch_unpublished(10).await.unwrap().is_empty());
    }
}