use chrono::{DateTime, Utc};
use serde::{de::Error as _, Deserialize, Serialize};
use sqlx::{postgres::PgListener, PgPool};
use tokio::{
sync::{
broadcast::{self, error::RecvError},
RwLock,
},
task,
};
use tracing::instrument;
#[cfg(feature = "otel")]
use tracing_opentelemetry::OpenTelemetrySpanExt;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use crate::{
balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
};
/// Configuration for creating an [`EventSubscriber`].
pub struct EventSubscriberOpts {
    /// When true, the subscriber permanently shuts down if its internal
    /// broadcast channel lags (i.e. events were dropped) instead of
    /// continuing with a gap.
    pub close_on_lag: bool,
    /// Capacity of each broadcast channel handed out to consumers.
    pub buffer: usize,
    /// If set, replay persisted events with ids greater than this value
    /// before streaming live notifications.
    pub after_id: Option<SqlxLedgerEventId>,
    /// Number of rows fetched per query while replaying historical events.
    pub batch_size: i64,
}
impl Default for EventSubscriberOpts {
fn default() -> Self {
Self {
close_on_lag: false,
buffer: 100,
after_id: None,
batch_size: 1000,
}
}
}
/// Fan-out hub for ledger events: hands out broadcast receivers for all
/// events, for a single journal, or for a (journal, account) pair.
#[derive(Debug, Clone)]
pub struct EventSubscriber {
    // Channel capacity used when lazily creating per-journal /
    // per-(journal, account) broadcast channels.
    buffer: usize,
    // Set to true once the underlying event stream has shut down (or lagged
    // with `close_on_lag`); all accessor methods then return an error.
    closed: Arc<AtomicBool>,
    #[allow(clippy::type_complexity)]
    balance_receivers:
        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
    // Kept alive solely so `all()` can produce fresh receivers via
    // `resubscribe()`.
    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
}
impl EventSubscriber {
    /// Connects the subscriber and spawns the fan-out task.
    ///
    /// `subscribe` yields a single stream of events (optionally replaying
    /// history after `after_id`); the spawned task forwards each event to
    /// the matching per-journal and per-(journal, account) channels.
    pub(crate) async fn connect(
        pool: &PgPool,
        EventSubscriberOpts {
            close_on_lag,
            buffer,
            after_id: start_id,
            batch_size,
        }: EventSubscriberOpts,
    ) -> Result<Self, SqlxLedgerError> {
        let closed = Arc::new(AtomicBool::new(false));
        let mut incoming = subscribe(
            pool.clone(),
            Arc::clone(&closed),
            buffer,
            start_id,
            batch_size,
        )
        .await?;
        // Retained so `all()` can mint new receivers later via `resubscribe`.
        let all = Arc::new(incoming.resubscribe());
        let balance_receivers = Arc::new(RwLock::new(HashMap::<
            (JournalId, AccountId),
            broadcast::Sender<SqlxLedgerEvent>,
        >::new()));
        let journal_receivers = Arc::new(RwLock::new(HashMap::<
            JournalId,
            broadcast::Sender<SqlxLedgerEvent>,
        >::new()));
        let inner_balance_receivers = Arc::clone(&balance_receivers);
        let inner_journal_receivers = Arc::clone(&journal_receivers);
        let inner_closed = Arc::clone(&closed);
        // Fan-out task: routes every incoming event to interested channels.
        tokio::spawn(async move {
            loop {
                match incoming.recv().await {
                    Ok(event) => {
                        let journal_id = event.journal_id();
                        if let Some(journal_receivers) =
                            inner_journal_receivers.read().await.get(&journal_id)
                        {
                            // Send failures (no live receivers) are ignored.
                            let _ = journal_receivers.send(event.clone());
                        }
                        // Only BalanceUpdated events carry an account id.
                        if let Some(account_id) = event.account_id() {
                            let receivers = inner_balance_receivers.read().await;
                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
                                let _ = receiver.send(event);
                            }
                        }
                    }
                    Err(RecvError::Lagged(_)) => {
                        // Events were dropped because we fell behind. With
                        // `close_on_lag` the subscriber is marked closed and
                        // dropping the senders closes all handed-out
                        // receivers; otherwise the gap is silently accepted.
                        if close_on_lag {
                            inner_closed.store(true, Ordering::SeqCst);
                            inner_balance_receivers.write().await.clear();
                            inner_journal_receivers.write().await.clear();
                            // NOTE(review): the loop keeps draining events
                            // after closing here — confirm a `break` isn't
                            // intended, as in the `Closed` arm below.
                        }
                    }
                    Err(RecvError::Closed) => {
                        tracing::warn!("Event subscriber closed");
                        inner_closed.store(true, Ordering::SeqCst);
                        inner_balance_receivers.write().await.clear();
                        inner_journal_receivers.write().await.clear();
                        break;
                    }
                }
            }
        });
        Ok(Self {
            buffer,
            closed,
            balance_receivers,
            journal_receivers,
            all,
        })
    }

    /// Returns a receiver for every ledger event, or
    /// [`SqlxLedgerError::EventSubscriberClosed`] once the stream has shut
    /// down.
    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let recv = self.all.resubscribe();
        // Closed is checked after resubscribing so the flag is observed as
        // late as possible.
        if self.closed.load(Ordering::SeqCst) {
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(recv)
    }

    /// Returns a receiver of events for one (journal, account) pair,
    /// lazily creating the underlying channel on first use.
    pub async fn account_balance(
        &self,
        journal_id: JournalId,
        account_id: AccountId,
    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let mut listeners = self.balance_receivers.write().await;
        let mut ret = None;
        let sender = listeners
            .entry((journal_id, account_id))
            .or_insert_with(|| {
                // First subscriber for this key: create the channel and keep
                // its initial receiver so no events are missed.
                let (sender, recv) = broadcast::channel(self.buffer);
                ret = Some(recv);
                sender
            });
        let ret = ret.unwrap_or_else(|| sender.subscribe());
        // If the stream closed meanwhile, undo the insertion and report it.
        if self.closed.load(Ordering::SeqCst) {
            listeners.remove(&(journal_id, account_id));
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(ret)
    }

    /// Returns a receiver of all events for one journal, lazily creating
    /// the underlying channel on first use.
    pub async fn journal(
        &self,
        journal_id: JournalId,
    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let mut listeners = self.journal_receivers.write().await;
        let mut ret = None;
        let sender = listeners.entry(journal_id).or_insert_with(|| {
            // First subscriber for this journal: create channel + receiver.
            let (sender, recv) = broadcast::channel(self.buffer);
            ret = Some(recv);
            sender
        });
        let ret = ret.unwrap_or_else(|| sender.subscribe());
        // If the stream closed meanwhile, undo the insertion and report it.
        if self.closed.load(Ordering::SeqCst) {
            listeners.remove(&journal_id);
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(ret)
    }
}
/// Monotonically increasing identifier of a persisted ledger event.
///
/// Transparent newtype over the `BIGINT` id column: serializes as a bare
/// i64 in JSON and maps directly to the SQL type.
// `Ord` added: the type already derives `Eq` + `PartialOrd`, ids are
// compared (`<=` in gap detection) and sorted (`ORDER BY id`), and an i64
// newtype is totally ordered — deriving `Ord` is strictly additive.
#[derive(
    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Ord, Eq, Hash, Copy,
)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct SqlxLedgerEventId(i64);
impl SqlxLedgerEventId {
    /// Sentinel id (0) meaning "before the first event"; pass as `after_id`
    /// to replay the full event history.
    pub const BEGIN: Self = Self(0);
}
impl From<i64> for SqlxLedgerEventId {
fn from(value: i64) -> Self {
Self(value)
}
}
/// A single ledger event as delivered to subscribers.
///
/// Deserialized via [`EventRaw`] (`#[serde(try_from = "EventRaw")]`) so the
/// type-tagged `data` payload can be decoded in a second step.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    /// Monotonically increasing event id; also the gap-detection cursor.
    pub id: SqlxLedgerEventId,
    /// Decoded, type-specific payload.
    pub data: SqlxLedgerEventData,
    /// Discriminant matching the variant held in `data`.
    pub r#type: SqlxLedgerEventType,
    /// When the event was recorded.
    pub recorded_at: DateTime<Utc>,
    /// Trace context captured on receipt (otel builds only).
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}
impl SqlxLedgerEvent {
    /// Attaches the current tracing span's OpenTelemetry context to the
    /// event so consumers can link their processing back to this trace.
    #[cfg(feature = "otel")]
    fn record_otel_context(&mut self) {
        self.otel_context = tracing::Span::current().context();
    }
    /// No-op when the `otel` feature is disabled.
    #[cfg(not(feature = "otel"))]
    fn record_otel_context(&mut self) {}
}
impl SqlxLedgerEvent {
    /// Journal this event belongs to (present on every payload variant).
    pub fn journal_id(&self) -> JournalId {
        match &self.data {
            SqlxLedgerEventData::BalanceUpdated(balance) => balance.journal_id,
            SqlxLedgerEventData::TransactionCreated(tx)
            | SqlxLedgerEventData::TransactionUpdated(tx) => tx.journal_id,
        }
    }

    /// Account affected by the event; only balance updates carry one.
    pub fn account_id(&self) -> Option<AccountId> {
        if let SqlxLedgerEventData::BalanceUpdated(balance) = &self.data {
            Some(balance.account_id)
        } else {
            None
        }
    }
}
/// Payload of a ledger event, tagged by [`SqlxLedgerEventType`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    /// An account balance changed.
    BalanceUpdated(BalanceDetails),
    /// A transaction was created.
    TransactionCreated(Transaction),
    /// An existing transaction was updated.
    TransactionUpdated(Transaction),
}
/// Discriminant naming which [`SqlxLedgerEventData`] variant an event
/// carries; serialized as the `type` field of the event payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    BalanceUpdated,
    TransactionCreated,
    TransactionUpdated,
}
/// Opens a `LISTEN sqlx_ledger_events` connection and spawns a task that
/// pushes every event into the returned broadcast channel.
///
/// The task alternates between two phases:
/// * reload (catch-up): page through the `sqlx_ledger_events` table in
///   `batch_size`-row batches starting after `last_id`;
/// * live: stream NOTIFY payloads, falling back to reload whenever an id
///   gap, a missing `data` field, or a listener error is encountered.
pub(crate) async fn subscribe(
    pool: PgPool,
    closed: Arc<AtomicBool>,
    buffer: usize,
    after_id: Option<SqlxLedgerEventId>,
    batch_size: i64,
) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
    let mut listener = PgListener::connect_with(&pool).await?;
    listener.listen("sqlx_ledger_events").await?;
    let (snd, recv) = broadcast::channel(buffer);
    // Start in catch-up mode only when the caller asked to replay history.
    let mut reload = after_id.is_some();
    task::spawn(async move {
        let mut num_errors: u32 = 0;
        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
        loop {
            if reload {
                // Catch-up phase: page through the table until a short
                // batch signals we've reached the end.
                let batch_result = async {
                    loop {
                        let rows = sqlx::query!(
                            r#"SELECT json_build_object(
                                'id', id,
                                'type', type,
                                'data', data,
                                'recorded_at', recorded_at
                            ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id LIMIT $2"#,
                            last_id.0,
                            batch_size
                        )
                        .fetch_all(&pool)
                        .await?;
                        let is_last_batch = (rows.len() as i64) < batch_size;
                        for row in rows {
                            let event: Result<SqlxLedgerEvent, _> =
                                serde_json::from_value(row.payload);
                            // ignore_gap = true: replayed rows are ordered by
                            // id but may legitimately skip values.
                            if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
                                .is_err()
                            {
                                // Deserialization failed or every receiver is
                                // gone — shut the whole task down.
                                return Err::<(), SqlxLedgerError>(
                                    SqlxLedgerError::EventSubscriberClosed,
                                );
                            }
                        }
                        if is_last_batch {
                            break;
                        }
                    }
                    Ok(())
                }
                .await;
                match batch_result {
                    Ok(()) => {
                        num_errors = 0;
                        reload = false;
                    }
                    Err(SqlxLedgerError::EventSubscriberClosed) => {
                        closed.store(true, Ordering::SeqCst);
                        break;
                    }
                    Err(e) => {
                        // Transient DB error: back off exponentially and
                        // retry the catch-up from `last_id`.
                        num_errors += 1;
                        let delay = backoff_delay(num_errors);
                        tracing::error!(
                            "Error fetching events (attempt {}): {}. Retrying in {:?}",
                            num_errors,
                            e,
                            delay
                        );
                        tokio::time::sleep(delay).await;
                        continue;
                    }
                }
            }
            // NOTE(review): stores use SeqCst but this load is Relaxed —
            // likely fine for a shutdown flag, but confirm it's intentional.
            if closed.load(Ordering::Relaxed) {
                break;
            }
            // Live phase: stream NOTIFY payloads until an error or an id gap
            // forces another reload pass.
            loop {
                match listener.recv().await {
                    Ok(notification) => {
                        let event: Result<SqlxLedgerEvent, _> =
                            serde_json::from_str(notification.payload());
                        if let Err(e) = &event {
                            // The error string is produced by
                            // `TryFrom<EventRaw>` when a notification arrives
                            // without its `data` payload (presumably when the
                            // payload was too large to notify — confirm);
                            // re-fetch the full row from the table instead.
                            if e.to_string().contains("data field missing") {
                                reload = true;
                                break;
                            }
                        }
                        match sqlx_ledger_notification_received(event, &snd, &mut last_id, reload) {
                            Ok(false) => {
                                // Gap in event ids — replay from the table.
                                reload = true;
                                break;
                            }
                            Ok(_) => num_errors = 0,
                            Err(_) => {
                                closed.store(true, Ordering::SeqCst);
                                break;
                            }
                        }
                    }
                    Err(e) => {
                        num_errors += 1;
                        let delay = backoff_delay(num_errors);
                        tracing::warn!(
                            "PgListener recv error (attempt {}): {}. Retrying in {:?}",
                            num_errors,
                            e,
                            delay
                        );
                        tokio::time::sleep(delay).await;
                        // Notifications may have been missed while the
                        // listener was down, so run a catch-up pass.
                        reload = true;
                        break;
                    }
                }
            }
        }
        // Best-effort cleanup; the connection is dropped with the task.
        let _ = listener.unlisten("sqlx_ledger_events").await;
    });
    Ok(recv)
}
/// Exponential retry backoff: 2^n seconds, with the exponent clamped to 5
/// so the delay never exceeds 32 seconds.
fn backoff_delay(num_errors: u32) -> std::time::Duration {
    let exponent = num_errors.min(5);
    std::time::Duration::from_secs(2u64.pow(exponent))
}
/// Validates and forwards a single decoded event.
///
/// Returns `Ok(true)` when the event was forwarded or skipped as a
/// duplicate, `Ok(false)` when a gap in ids was detected (the caller should
/// reload missed events from the database), and `Err` when deserialization
/// failed or no receiver remains to send to.
#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
fn sqlx_ledger_notification_received(
    event: Result<SqlxLedgerEvent, serde_json::Error>,
    sender: &broadcast::Sender<SqlxLedgerEvent>,
    last_id: &mut SqlxLedgerEventId,
    ignore_gap: bool,
) -> Result<bool, SqlxLedgerError> {
    let mut event = event?;
    event.record_otel_context();
    let incoming = event.id;
    if incoming <= *last_id {
        // Duplicate or stale id — drop silently, cursor unchanged.
        Ok(true)
    } else if !ignore_gap && incoming.0 != last_id.0 + 1 {
        // Non-contiguous id: tell the caller to replay missed events.
        Ok(false)
    } else {
        sender.send(event)?;
        *last_id = incoming;
        Ok(true)
    }
}
/// Intermediate wire form of an event; `data` is kept as raw JSON (and is
/// optional) so [`SqlxLedgerEvent`]'s `TryFrom` impl can decode it based on
/// the `type` tag, or fail with "data field missing" when absent.
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    // Defaults to None when the notification payload omits `data`.
    #[serde(default)]
    data: Option<serde_json::Value>,
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
    // Bool results are checked with `assert!` / `assert!(!..)` instead of
    // `assert_eq!(_, true/false)` (clippy::bool_assert_comparison).
    use super::*;
    use chrono::Utc;
    use rust_decimal::Decimal;
    use tokio::sync::broadcast;
    use crate::balance::BalanceDetails;
    use crate::{CorrelationId, Currency, EntryId, TransactionId, TxTemplateId};

    /// Builds a `BalanceUpdated` event with the given id and zeroed balances.
    fn make_balance_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        let journal_id = JournalId::new();
        let account_id = AccountId::new();
        let entry_id = EntryId::new();
        let currency: Currency = "USD".parse().unwrap();
        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::BalanceUpdated(BalanceDetails {
                journal_id,
                account_id,
                entry_id,
                currency,
                settled_dr_balance: Decimal::ZERO,
                settled_cr_balance: Decimal::ZERO,
                settled_entry_id: entry_id,
                settled_modified_at: now,
                pending_dr_balance: Decimal::ZERO,
                pending_cr_balance: Decimal::ZERO,
                pending_entry_id: entry_id,
                pending_modified_at: now,
                encumbered_dr_balance: Decimal::ZERO,
                encumbered_cr_balance: Decimal::ZERO,
                encumbered_entry_id: entry_id,
                encumbered_modified_at: now,
                version: 1,
                modified_at: now,
                created_at: now,
            }),
            r#type: SqlxLedgerEventType::BalanceUpdated,
            recorded_at: now,
        }
    }

    /// Builds a `TransactionCreated` event with the given id.
    fn make_transaction_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::TransactionCreated(Transaction {
                id: TransactionId::new(),
                version: 1,
                journal_id: JournalId::new(),
                tx_template_id: TxTemplateId::new(),
                effective: now.date_naive(),
                correlation_id: CorrelationId::new(),
                external_id: "test-ext".to_string(),
                description: None,
                metadata_json: None,
                created_at: now,
                modified_at: now,
            }),
            r#type: SqlxLedgerEventType::TransactionCreated,
            recorded_at: now,
        }
    }

    #[test]
    fn notification_received_sends_event_and_updates_last_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        let event = make_balance_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));
        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(1));
    }

    #[test]
    fn notification_received_skips_duplicate_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(5);
        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(5));
        // Duplicate must not be forwarded.
        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_skips_older_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(10);
        let event = make_balance_event(3);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(10));
        // Stale event must not be forwarded.
        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_detects_gap_when_ignore_gap_false() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);
        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_ok());
        // Ok(false) signals the gap; cursor and channel stay untouched.
        assert!(!result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));
        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_ignores_gap_when_flag_set() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);
        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, true);
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(5));
        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(5));
    }

    #[test]
    fn notification_received_propagates_deserialization_error() {
        let (sender, _recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        let deser_err: Result<SqlxLedgerEvent, _> = serde_json::from_str::<SqlxLedgerEvent>("{}");
        assert!(deser_err.is_err());
        let result = sqlx_ledger_notification_received(deser_err, &sender, &mut last_id, false);
        assert!(result.is_err());
        assert_eq!(last_id, SqlxLedgerEventId(0));
    }

    #[test]
    fn notification_received_errors_when_no_receivers() {
        let (sender, recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        // Dropping the only receiver makes `send` fail.
        drop(recv);
        let mut last_id = SqlxLedgerEventId(0);
        let event = make_balance_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_err());
    }

    #[test]
    fn notification_received_sequential_ids() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        for i in 1..=5 {
            let event = make_balance_event(i);
            let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
            assert!(result.is_ok());
            assert!(result.unwrap());
            assert_eq!(last_id, SqlxLedgerEventId(i));
        }
        // Events arrive in order.
        for i in 1..=5 {
            let received = recv.try_recv().unwrap();
            assert_eq!(received.id, SqlxLedgerEventId(i));
        }
    }

    #[test]
    fn notification_received_handles_transaction_event() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        let event = make_transaction_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));
        let received = recv.try_recv().unwrap();
        assert!(matches!(
            received.r#type,
            SqlxLedgerEventType::TransactionCreated
        ));
    }

    #[test]
    fn notification_received_data_field_missing_error_string() {
        // `subscribe()` matches on this exact error string to trigger a
        // reload, so pin it here.
        let raw_json =
            r#"{"id": 1, "type": "BalanceUpdated", "recorded_at": "2024-01-01T00:00:00Z"}"#;
        let result: Result<SqlxLedgerEvent, _> = serde_json::from_str(raw_json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("data field missing"),
            "Expected 'data field missing' in error: {err_msg}"
        );
    }

    #[test]
    fn backoff_delay_caps_at_32_seconds() {
        use std::time::Duration;
        assert_eq!(backoff_delay(1), Duration::from_secs(2));
        assert_eq!(backoff_delay(2), Duration::from_secs(4));
        assert_eq!(backoff_delay(3), Duration::from_secs(8));
        assert_eq!(backoff_delay(4), Duration::from_secs(16));
        assert_eq!(backoff_delay(5), Duration::from_secs(32));
        assert_eq!(backoff_delay(6), Duration::from_secs(32));
        assert_eq!(backoff_delay(10), Duration::from_secs(32));
        assert_eq!(backoff_delay(100), Duration::from_secs(32));
    }
}
impl TryFrom<EventRaw> for SqlxLedgerEvent {
    type Error = serde_json::Error;

    /// Finishes deserialization by decoding the type-dependent `data`
    /// payload selected by the event's `type` tag.
    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
        // IMPORTANT: `subscribe()` matches on this exact error string to
        // detect a notification that arrived without its payload and fall
        // back to re-reading the row from the database — do not reword it.
        let data_value = value
            .data
            .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
        let data = match value.r#type {
            SqlxLedgerEventType::BalanceUpdated => {
                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
            }
            SqlxLedgerEventType::TransactionCreated => {
                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
            }
            SqlxLedgerEventType::TransactionUpdated => {
                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
            }
        };
        Ok(SqlxLedgerEvent {
            id: value.id,
            data,
            r#type: value.r#type,
            recorded_at: value.recorded_at,
            // In otel builds, capture the current span's context so
            // consumers can link processing back to the producing trace.
            #[cfg(feature = "otel")]
            otel_context: tracing::Span::current().context(),
        })
    }
}