1use chrono::{DateTime, Utc};
3use serde::{de::Error as _, Deserialize, Serialize};
4use sqlx::{postgres::PgListener, PgPool};
5use tokio::{
6 sync::{
7 broadcast::{self, error::RecvError},
8 RwLock,
9 },
10 task,
11};
12use tracing::instrument;
13#[cfg(feature = "otel")]
14use tracing_opentelemetry::OpenTelemetrySpanExt;
15
16use std::{
17 collections::HashMap,
18 sync::{
19 atomic::{AtomicBool, Ordering},
20 Arc,
21 },
22};
23
24use crate::{
25 balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
26};
27
28pub struct EventSubscriberOpts {
30 pub close_on_lag: bool,
31 pub buffer: usize,
32 pub after_id: Option<SqlxLedgerEventId>,
33 pub batch_size: i64,
34}
35impl Default for EventSubscriberOpts {
36 fn default() -> Self {
37 Self {
38 close_on_lag: false,
39 buffer: 100,
40 after_id: None,
41 batch_size: 1000,
42 }
43 }
44}
45
/// Handle for subscribing to ledger events: the full stream, all events of one
/// journal, or the balance updates of one (journal, account) pair.
///
/// Cloning is cheap: clones share the same channels and closed flag through `Arc`s.
#[derive(Debug, Clone)]
pub struct EventSubscriber {
    // Capacity used for channels created by `account_balance` / `journal`.
    buffer: usize,
    // Set once the subscriber stops delivering events (the upstream stream ended, a lag
    // occurred with `close_on_lag`, or forwarding an event failed); new subscriptions
    // are rejected from then on.
    closed: Arc<AtomicBool>,
    // Per-(journal, account) senders; `BalanceUpdated` events for that pair are
    // forwarded to them. The nested `Arc<RwLock<HashMap<..>>>` type trips
    // `clippy::type_complexity`, hence the allow.
    #[allow(clippy::type_complexity)]
    balance_receivers:
        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
    // Per-journal senders; every event of that journal is forwarded to them.
    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
    // Receiver on the full event stream, only used as a template for `resubscribe`
    // in `all()`.
    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
}
57
58impl EventSubscriber {
59 pub(crate) async fn connect(
60 pool: &PgPool,
61 EventSubscriberOpts {
62 close_on_lag,
63 buffer,
64 after_id: start_id,
65 batch_size,
66 }: EventSubscriberOpts,
67 ) -> Result<Self, SqlxLedgerError> {
68 let closed = Arc::new(AtomicBool::new(false));
69 let mut incoming = subscribe(
70 pool.clone(),
71 Arc::clone(&closed),
72 buffer,
73 start_id,
74 batch_size,
75 )
76 .await?;
77 let all = Arc::new(incoming.resubscribe());
78 let balance_receivers = Arc::new(RwLock::new(HashMap::<
79 (JournalId, AccountId),
80 broadcast::Sender<SqlxLedgerEvent>,
81 >::new()));
82 let journal_receivers = Arc::new(RwLock::new(HashMap::<
83 JournalId,
84 broadcast::Sender<SqlxLedgerEvent>,
85 >::new()));
86 let inner_balance_receivers = Arc::clone(&balance_receivers);
87 let inner_journal_receivers = Arc::clone(&journal_receivers);
88 let inner_closed = Arc::clone(&closed);
89 tokio::spawn(async move {
90 loop {
91 match incoming.recv().await {
92 Ok(event) => {
93 let journal_id = event.journal_id();
94 if let Some(journal_receivers) =
95 inner_journal_receivers.read().await.get(&journal_id)
96 {
97 let _ = journal_receivers.send(event.clone());
98 }
99 if let Some(account_id) = event.account_id() {
100 let receivers = inner_balance_receivers.read().await;
101 if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
102 let _ = receiver.send(event);
103 }
104 }
105 }
106 Err(RecvError::Lagged(_)) => {
107 if close_on_lag {
108 inner_closed.store(true, Ordering::SeqCst);
109 inner_balance_receivers.write().await.clear();
110 inner_journal_receivers.write().await.clear();
111 }
112 }
113 Err(RecvError::Closed) => {
114 tracing::warn!("Event subscriber closed");
115 inner_closed.store(true, Ordering::SeqCst);
116 inner_balance_receivers.write().await.clear();
117 inner_journal_receivers.write().await.clear();
118 break;
119 }
120 }
121 }
122 });
123 Ok(Self {
124 buffer,
125 closed,
126 balance_receivers,
127 journal_receivers,
128 all,
129 })
130 }
131
132 pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
133 let recv = self.all.resubscribe();
134 if self.closed.load(Ordering::SeqCst) {
135 return Err(SqlxLedgerError::EventSubscriberClosed);
136 }
137 Ok(recv)
138 }
139
140 pub async fn account_balance(
141 &self,
142 journal_id: JournalId,
143 account_id: AccountId,
144 ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
145 let mut listeners = self.balance_receivers.write().await;
146 let mut ret = None;
147 let sender = listeners
148 .entry((journal_id, account_id))
149 .or_insert_with(|| {
150 let (sender, recv) = broadcast::channel(self.buffer);
151 ret = Some(recv);
152 sender
153 });
154 let ret = ret.unwrap_or_else(|| sender.subscribe());
155 if self.closed.load(Ordering::SeqCst) {
156 listeners.remove(&(journal_id, account_id));
157 return Err(SqlxLedgerError::EventSubscriberClosed);
158 }
159 Ok(ret)
160 }
161
162 pub async fn journal(
163 &self,
164 journal_id: JournalId,
165 ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
166 let mut listeners = self.journal_receivers.write().await;
167 let mut ret = None;
168 let sender = listeners.entry(journal_id).or_insert_with(|| {
169 let (sender, recv) = broadcast::channel(self.buffer);
170 ret = Some(recv);
171 sender
172 });
173 let ret = ret.unwrap_or_else(|| sender.subscribe());
174 if self.closed.load(Ordering::SeqCst) {
175 listeners.remove(&journal_id);
176 return Err(SqlxLedgerError::EventSubscriberClosed);
177 }
178 Ok(ret)
179 }
180}
181
182#[derive(
183 sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
184)]
185#[serde(transparent)]
186#[sqlx(transparent)]
187pub struct SqlxLedgerEventId(i64);
188impl SqlxLedgerEventId {
189 pub const BEGIN: Self = Self(0);
190}
191
192impl From<i64> for SqlxLedgerEventId {
193 fn from(value: i64) -> Self {
194 Self(value)
195 }
196}
197
/// A ledger event as delivered to subscribers.
///
/// Deserialized through `EventRaw` (`#[serde(try_from = "EventRaw")]`), which decodes
/// the raw `data` payload according to `type`.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    /// Row id in `sqlx_ledger_events`; subscribers receive events in ascending id order.
    pub id: SqlxLedgerEventId,
    /// Decoded payload; when deserialized, its variant matches `r#type`.
    pub data: SqlxLedgerEventData,
    /// Kind of event.
    pub r#type: SqlxLedgerEventType,
    /// Value of the event's `recorded_at` column.
    pub recorded_at: DateTime<Utc>,
    /// OpenTelemetry context of the tracing span that was current when the event was
    /// received by the subscriber (only with the `otel` feature).
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}
209
210impl SqlxLedgerEvent {
211 #[cfg(feature = "otel")]
212 fn record_otel_context(&mut self) {
213 self.otel_context = tracing::Span::current().context();
214 }
215
216 #[cfg(not(feature = "otel"))]
217 fn record_otel_context(&mut self) {}
218}
219
220impl SqlxLedgerEvent {
221 pub fn journal_id(&self) -> JournalId {
222 match &self.data {
223 SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
224 SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
225 SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
226 }
227 }
228
229 pub fn account_id(&self) -> Option<AccountId> {
230 match &self.data {
231 SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
232 _ => None,
233 }
234 }
235}
236
/// Payload of a `SqlxLedgerEvent`, one variant per `SqlxLedgerEventType`.
#[derive(Debug, Clone, Serialize, Deserialize)]
// The variants hold `BalanceDetails` / `Transaction` by value and differ in size. Boxing
// them would change the public variant types, so the size lint is allowed instead.
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    /// A balance changed; also routed to per-(journal, account) subscribers.
    BalanceUpdated(BalanceDetails),
    /// A transaction was created.
    TransactionCreated(Transaction),
    /// A transaction was updated.
    TransactionUpdated(Transaction),
}
245
/// Kind of a ledger event; selects which `SqlxLedgerEventData` variant the raw `data`
/// payload is decoded into.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    BalanceUpdated,
    TransactionCreated,
    TransactionUpdated,
}
253
254pub(crate) async fn subscribe(
255 pool: PgPool,
256 closed: Arc<AtomicBool>,
257 buffer: usize,
258 after_id: Option<SqlxLedgerEventId>,
259 batch_size: i64,
260) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
261 let mut listener = PgListener::connect_with(&pool).await?;
262 listener.listen("sqlx_ledger_events").await?;
263 let (snd, recv) = broadcast::channel(buffer);
264 let mut reload = after_id.is_some();
265 task::spawn(async move {
266 let mut num_errors: u32 = 0;
267 let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
268 loop {
269 if reload {
270 let batch_result = async {
273 loop {
274 let rows = sqlx::query!(
275 r#"SELECT json_build_object(
276 'id', id,
277 'type', type,
278 'data', data,
279 'recorded_at', recorded_at
280 ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id LIMIT $2"#,
281 last_id.0,
282 batch_size
283 )
284 .fetch_all(&pool)
285 .await?;
286
287 let is_last_batch = (rows.len() as i64) < batch_size;
288
289 for row in rows {
290 let event: Result<SqlxLedgerEvent, _> =
291 serde_json::from_value(row.payload);
292 if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
293 .is_err()
294 {
295 return Err::<(), SqlxLedgerError>(
296 SqlxLedgerError::EventSubscriberClosed,
297 );
298 }
299 }
300
301 if is_last_batch {
302 break;
303 }
304 }
305 Ok(())
306 }
307 .await;
308
309 match batch_result {
310 Ok(()) => {
311 num_errors = 0;
312 reload = false;
313 }
314 Err(SqlxLedgerError::EventSubscriberClosed) => {
315 closed.store(true, Ordering::SeqCst);
316 break;
317 }
318 Err(e) => {
319 num_errors += 1;
320 let delay = backoff_delay(num_errors);
321 tracing::error!(
322 "Error fetching events (attempt {}): {}. Retrying in {:?}",
323 num_errors,
324 e,
325 delay
326 );
327 tokio::time::sleep(delay).await;
328 continue;
329 }
330 }
331 }
332 if closed.load(Ordering::Relaxed) {
333 break;
334 }
335 loop {
336 match listener.recv().await {
337 Ok(notification) => {
338 let event: Result<SqlxLedgerEvent, _> =
339 serde_json::from_str(notification.payload());
340 if let Err(e) = &event {
341 if e.to_string().contains("data field missing") {
342 reload = true;
343 break;
344 }
345 }
346 match sqlx_ledger_notification_received(event, &snd, &mut last_id, reload) {
347 Ok(false) => {
348 reload = true;
349 break;
350 }
351 Ok(_) => num_errors = 0,
352 Err(_) => {
353 closed.store(true, Ordering::SeqCst);
354 break;
355 }
356 }
357 }
358 Err(e) => {
359 num_errors += 1;
364 let delay = backoff_delay(num_errors);
365 tracing::warn!(
366 "PgListener recv error (attempt {}): {}. Retrying in {:?}",
367 num_errors,
368 e,
369 delay
370 );
371 tokio::time::sleep(delay).await;
372 reload = true;
373 break;
374 }
375 }
376 }
377 }
378 let _ = listener.unlisten("sqlx_ledger_events").await;
379 });
380 Ok(recv)
381}
382
/// Delay to wait before retrying after `num_errors` consecutive failures.
///
/// Starts at 1s for zero failures and doubles per failure (2s, 4s, 8s, ...). It is
/// capped at 32s once `num_errors` reaches 5.
fn backoff_delay(num_errors: u32) -> std::time::Duration {
    const MAX_EXPONENT: u32 = 5;
    let exponent = std::cmp::min(num_errors, MAX_EXPONENT);
    std::time::Duration::from_secs(2u64.pow(exponent))
}
386
387#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
388fn sqlx_ledger_notification_received(
389 event: Result<SqlxLedgerEvent, serde_json::Error>,
390 sender: &broadcast::Sender<SqlxLedgerEvent>,
391 last_id: &mut SqlxLedgerEventId,
392 ignore_gap: bool,
393) -> Result<bool, SqlxLedgerError> {
394 let mut event = event?;
395 event.record_otel_context();
396 let id = event.id;
397 if id <= *last_id {
398 return Ok(true);
399 }
400 if !ignore_gap && last_id.0 + 1 != id.0 {
401 return Ok(false);
402 }
403 sender.send(event)?;
404 *last_id = id;
405 Ok(true)
406}
407
/// Untyped shape in which events arrive, both from `NOTIFY` payloads and from the
/// replay query's `json_build_object(...)`. Converted into `SqlxLedgerEvent` via
/// `TryFrom`, which decodes `data` according to `type`.
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    // Optional so that a payload without `data` still deserializes and the `TryFrom`
    // conversion can report "data field missing"; the subscriber reacts to that error by
    // reloading the event from the table.
    // NOTE(review): presumably the trigger omits `data` when the payload would be too
    // large for NOTIFY. Confirm against the migration.
    #[serde(default)]
    data: Option<serde_json::Value>,
    // Selects which `SqlxLedgerEventData` variant `data` is decoded into.
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}
416
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use rust_decimal::Decimal;
    use tokio::sync::broadcast;

    use crate::balance::BalanceDetails;
    use crate::{CorrelationId, Currency, EntryId, TransactionId, TxTemplateId};

    /// Builds a `BalanceUpdated` event with the given id and newly generated
    /// journal/account/entry ids.
    fn make_balance_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        let journal_id = JournalId::new();
        let account_id = AccountId::new();
        let entry_id = EntryId::new();
        let currency: Currency = "USD".parse().unwrap();

        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::BalanceUpdated(BalanceDetails {
                journal_id,
                account_id,
                entry_id,
                currency,
                settled_dr_balance: Decimal::ZERO,
                settled_cr_balance: Decimal::ZERO,
                settled_entry_id: entry_id,
                settled_modified_at: now,
                pending_dr_balance: Decimal::ZERO,
                pending_cr_balance: Decimal::ZERO,
                pending_entry_id: entry_id,
                pending_modified_at: now,
                encumbered_dr_balance: Decimal::ZERO,
                encumbered_cr_balance: Decimal::ZERO,
                encumbered_entry_id: entry_id,
                encumbered_modified_at: now,
                version: 1,
                modified_at: now,
                created_at: now,
            }),
            r#type: SqlxLedgerEventType::BalanceUpdated,
            recorded_at: now,
            // `otel_context` only exists with the `otel` feature; without this field the
            // tests fail to compile when that feature is enabled.
            #[cfg(feature = "otel")]
            otel_context: opentelemetry::Context::new(),
        }
    }

    /// Builds a `TransactionCreated` event with the given id and newly generated ids.
    fn make_transaction_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::TransactionCreated(Transaction {
                id: TransactionId::new(),
                version: 1,
                journal_id: JournalId::new(),
                tx_template_id: TxTemplateId::new(),
                effective: now.date_naive(),
                correlation_id: CorrelationId::new(),
                external_id: "test-ext".to_string(),
                description: None,
                metadata_json: None,
                created_at: now,
                modified_at: now,
            }),
            r#type: SqlxLedgerEventType::TransactionCreated,
            recorded_at: now,
            #[cfg(feature = "otel")]
            otel_context: opentelemetry::Context::new(),
        }
    }

    #[test]
    fn notification_received_sends_event_and_updates_last_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        let event = make_balance_event(1);

        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));

        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(1));
    }

    #[test]
    fn notification_received_skips_duplicate_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(5);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(5));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_skips_duplicate_even_when_ignoring_gaps() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(5);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, true);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(5));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_skips_older_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(10);

        let event = make_balance_event(3);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(10));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_detects_gap_when_ignore_gap_false() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(!result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_ignores_gap_when_flag_set() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, true);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(5));

        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(5));
    }

    #[test]
    fn notification_received_propagates_deserialization_error() {
        let (sender, _recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        let deser_err: Result<SqlxLedgerEvent, _> = serde_json::from_str::<SqlxLedgerEvent>("{}");
        assert!(deser_err.is_err());

        let result = sqlx_ledger_notification_received(deser_err, &sender, &mut last_id, false);
        assert!(result.is_err());
        assert_eq!(last_id, SqlxLedgerEventId(0));
    }

    #[test]
    fn notification_received_errors_when_no_receivers() {
        let (sender, recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        // Dropping the only receiver makes `send` fail.
        drop(recv);
        let mut last_id = SqlxLedgerEventId(0);

        let event = make_balance_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_err());
        assert_eq!(last_id, SqlxLedgerEventId(0));
    }

    #[test]
    fn notification_received_sequential_ids() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        for i in 1..=5 {
            let event = make_balance_event(i);
            let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
            assert!(result.unwrap());
            assert_eq!(last_id, SqlxLedgerEventId(i));
        }

        for i in 1..=5 {
            let received = recv.try_recv().unwrap();
            assert_eq!(received.id, SqlxLedgerEventId(i));
        }
    }

    #[test]
    fn notification_received_handles_transaction_event() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        let event = make_transaction_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.unwrap());
        assert_eq!(last_id, SqlxLedgerEventId(1));

        let received = recv.try_recv().unwrap();
        assert_eq!(received.r#type, SqlxLedgerEventType::TransactionCreated);
    }

    #[test]
    fn event_accessors_for_balance_event() {
        let event = make_balance_event(1);
        let (journal_id, account_id) = match &event.data {
            SqlxLedgerEventData::BalanceUpdated(details) => {
                (details.journal_id, details.account_id)
            }
            other => panic!("expected BalanceUpdated, got {other:?}"),
        };
        assert_eq!(event.journal_id(), journal_id);
        assert_eq!(event.account_id(), Some(account_id));
    }

    #[test]
    fn event_accessors_for_transaction_event() {
        let event = make_transaction_event(1);
        let journal_id = match &event.data {
            SqlxLedgerEventData::TransactionCreated(tx) => tx.journal_id,
            other => panic!("expected TransactionCreated, got {other:?}"),
        };
        assert_eq!(event.journal_id(), journal_id);
        assert_eq!(event.account_id(), None);
    }

    #[test]
    fn notification_received_data_field_missing_error_string() {
        // The subscriber relies on this exact wording to trigger a reload.
        let raw_json =
            r#"{"id": 1, "type": "BalanceUpdated", "recorded_at": "2024-01-01T00:00:00Z"}"#;
        let result: Result<SqlxLedgerEvent, _> = serde_json::from_str(raw_json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("data field missing"),
            "Expected 'data field missing' in error: {err_msg}"
        );
    }

    #[test]
    fn backoff_delay_caps_at_32_seconds() {
        use std::time::Duration;
        assert_eq!(backoff_delay(0), Duration::from_secs(1));
        assert_eq!(backoff_delay(1), Duration::from_secs(2));
        assert_eq!(backoff_delay(2), Duration::from_secs(4));
        assert_eq!(backoff_delay(3), Duration::from_secs(8));
        assert_eq!(backoff_delay(4), Duration::from_secs(16));
        assert_eq!(backoff_delay(5), Duration::from_secs(32));
        assert_eq!(backoff_delay(6), Duration::from_secs(32));
        assert_eq!(backoff_delay(10), Duration::from_secs(32));
        assert_eq!(backoff_delay(100), Duration::from_secs(32));
    }
}
649
650impl TryFrom<EventRaw> for SqlxLedgerEvent {
651 type Error = serde_json::Error;
652
653 fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
654 let data_value = value
655 .data
656 .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
657
658 let data = match value.r#type {
659 SqlxLedgerEventType::BalanceUpdated => {
660 SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
661 }
662 SqlxLedgerEventType::TransactionCreated => {
663 SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
664 }
665 SqlxLedgerEventType::TransactionUpdated => {
666 SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
667 }
668 };
669
670 Ok(SqlxLedgerEvent {
671 id: value.id,
672 data,
673 r#type: value.r#type,
674 recorded_at: value.recorded_at,
675 #[cfg(feature = "otel")]
676 otel_context: tracing::Span::current().context(),
677 })
678 }
679}