#[cfg(test)]
mod tests;
pub(crate) mod id_indexer;
use crate::{Error, Migrator, PgEventId};
use async_trait::async_trait;
use disintegrate::{Event, EventListener, StreamItem, StreamQuery};
use disintegrate_serde::Serde;
use futures::future::join_all;
use futures::{try_join, Future, StreamExt};
use sqlx::{Postgres, Row, Transaction};
use std::error::Error as StdError;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::event_store::PgEventStore;
/// Runs a set of [`EventListener`]s against a [`PgEventStore`].
///
/// Listeners are added with [`register_listener`](Self::register_listener) and run with
/// [`start`](Self::start) or [`start_with_shutdown`](Self::start_with_shutdown).
pub struct PgEventListener<E, S>
where
    E: Event + Clone,
    S: Serde<E> + Send + Sync,
{
    // One executor per registered listener.
    executors: Vec<Box<dyn EventListenerExecutor<E> + Send + Sync>>,
    event_store: PgEventStore<E, S>,
    // When true, `start` runs `Migrator::init_listener` first (field name is a typo of "initialize").
    intialize: bool,
    // Cloned into every executor and into the notification watcher; cancelling it stops them.
    shutdown_token: CancellationToken,
}
impl<E, S> PgEventListener<E, S>
where
    E: Event + Clone + Send + Sync + 'static,
    S: Serde<E> + Clone + Send + Sync + 'static,
{
    /// Creates a listener runner for `event_store` with no registered listeners.
    ///
    /// Listener schema initialization is enabled by default; see
    /// [`uninitialized`](Self::uninitialized).
    pub fn builder(event_store: PgEventStore<E, S>) -> Self {
        Self {
            event_store,
            executors: vec![],
            shutdown_token: CancellationToken::new(),
            intialize: true,
        }
    }

    /// Disables running `Migrator::init_listener` when [`start`](Self::start) is called.
    pub fn uninitialized(mut self) -> Self {
        self.intialize = false;
        self
    }

    /// Registers `event_listener`, to be initialized and run with `config` when the runner starts.
    ///
    /// The listener handles events of type `QE`, converted from the store's events `E`.
    pub fn register_listener<QE, L>(
        mut self,
        event_listener: L,
        config: PgEventListenerConfig<impl Retry<L::Error> + Send + Sync + Clone + 'static>,
    ) -> Self
    where
        L: EventListener<PgEventId, QE> + 'static,
        QE: TryFrom<E> + Into<E> + Event + Send + Sync + Clone + 'static,
        <QE as TryFrom<E>>::Error: StdError + Send + Sync,
        L::Error: Send + Sync + 'static,
    {
        self.executors.push(Box::new(PgEventListerExecutor::new(
            self.event_store.clone(),
            event_listener,
            self.shutdown_token.clone(),
            config,
        )));
        self
    }

    /// Initializes and runs every registered listener. Returns once all spawned tasks have finished.
    ///
    /// The steps run in this order:
    /// 1. The listener schema is initialized, unless disabled with
    ///    [`uninitialized`](Self::uninitialized).
    /// 2. Every listener's `init` runs.
    /// 3. The listener tasks are spawned. This happens only after all initialization has
    ///    succeeded, so an initialization failure never leaves tasks running in the background.
    ///
    /// If at least one listener has notifications enabled, an extra task relays Postgres
    /// `new_events` notifications to those listeners.
    ///
    /// # Errors
    ///
    /// Returns an error if the migration or a listener's initialization fails. Errors returned
    /// by the spawned tasks themselves are not reported.
    pub async fn start(self) -> Result<(), Error> {
        if self.intialize {
            Migrator::new(self.event_store.clone())
                .init_listener()
                .await?;
        }
        // Initialize every executor before spawning any task. Otherwise a failing `init` would
        // return early while already-spawned tasks keep running detached, with nobody left
        // to cancel them.
        for executor in &self.executors {
            executor.init().await?;
        }
        let mut handles = Vec::with_capacity(self.executors.len() + 1);
        let mut wakers = vec![];
        for executor in self.executors {
            let (waker, task) = executor.run();
            if let Some(waker) = waker {
                wakers.push(waker);
            }
            handles.push(task);
        }
        if !wakers.is_empty() {
            let pool = self.event_store.pool.clone();
            let shutdown = self.shutdown_token.clone();
            let watch_new_events = tokio::spawn(async move {
                // Outer loop: (re)connect and subscribe again after a receive error.
                loop {
                    let mut listener = sqlx::postgres::PgListener::connect_with(&pool).await?;
                    listener.listen("new_events").await?;
                    loop {
                        tokio::select! {
                            msg = listener.try_recv() => {
                                match msg {
                                    Ok(Some(notification)) => {
                                        for waker in &wakers {
                                            waker.wake(notification.payload());
                                        }
                                    },
                                    // Per sqlx docs, `Ok(None)` means the connection was lost;
                                    // the next `try_recv` reconnects.
                                    Ok(None) => {},
                                    Err(err @ sqlx::Error::PoolClosed) => return Err(Error::Database(err)),
                                    Err(_) => break,
                                }
                            }
                            _ = shutdown.cancelled() => return Ok::<(), Error>(()),
                        }
                    }
                }
            });
            handles.push(watch_new_events);
        }
        join_all(handles).await;
        Ok(())
    }

    /// Like [`start`](Self::start), but cancels every task once `shutdown` completes.
    ///
    /// On success, this returns only after both `start` has finished and `shutdown` has
    /// completed.
    ///
    /// # Errors
    ///
    /// Returns the error from [`start`](Self::start). In that case the `shutdown` future is
    /// dropped.
    pub async fn start_with_shutdown<F: Future<Output = ()> + Send + 'static>(
        self,
        shutdown: F,
    ) -> Result<(), Error> {
        let shutdown_token = self.shutdown_token.clone();
        let shutdown_handle = async move {
            shutdown.await;
            shutdown_token.cancel();
            Ok::<(), Error>(())
        };
        try_join!(self.start(), shutdown_handle).map(|_| ())
    }
}
/// The step of a listener processing round at which an error occurred.
#[derive(Debug)]
pub enum PgEventListenerErrorKind<HE> {
    /// Beginning the database transaction failed.
    InitTransaction { source: Error },
    /// Reading (and row-locking) the listener's `event_listener` row failed.
    AcquireLock { source: Error },
    /// Fetching the next item from the event stream failed.
    FetchNextEvent {
        source: Error,
        /// Progress reached before the failure.
        last_processed_event_id: PgEventId,
    },
    /// The listener's handler returned an error.
    Handler {
        source: HE,
        /// Progress reached before the event whose handling failed.
        last_processed_event_id: PgEventId,
    },
    /// Saving the listener's progress or committing the transaction failed.
    ReleaseLock {
        source: Error,
        /// The progress that could not be saved.
        last_processed_event_id: PgEventId,
    },
}
/// An error raised while a listener processes events; it is handed to the listener's
/// [`Retry`] policy.
#[derive(Debug)]
pub struct PgEventListenerError<HE> {
    /// What failed, and the progress made before the failure where applicable.
    pub kind: PgEventListenerErrorKind<HE>,
    /// The `id()` of the listener that failed.
    pub listener_id: String,
}
/// The decision a [`Retry`] policy makes after a failed processing round.
pub enum RetryAction {
    /// Stop the listener task.
    Abort,
    /// Wait for `duration`, then run the round again.
    Wait { duration: Duration },
}
/// Decides how a listener reacts to a failed processing round.
///
/// `attempts` is the number of retries already made for the current round (0 on the first
/// failure). Any `Fn(PgEventListenerError<HE>, usize) -> RetryAction` closure is a `Retry`
/// policy.
pub trait Retry<HE> {
    fn retry(&self, error: PgEventListenerError<HE>, attempts: usize) -> RetryAction;
}
/// A [`Retry`] policy that gives up on the first failure, stopping the listener.
///
/// This is the policy installed by [`PgEventListenerConfig::poller`].
#[derive(Clone, Copy, Default)]
pub struct AbortRetry;
impl<HE> Retry<HE> for AbortRetry {
    /// Always answers [`RetryAction::Abort`], whatever the error or attempt count.
    fn retry(&self, _: PgEventListenerError<HE>, _: usize) -> RetryAction {
        RetryAction::Abort
    }
}
/// Lets any closure or function with the signature
/// `Fn(PgEventListenerError<HE>, usize) -> RetryAction` be used as a [`Retry`] policy.
impl<HE, T> Retry<HE> for T
where
    T: Fn(PgEventListenerError<HE>, usize) -> RetryAction,
{
    fn retry(&self, error: PgEventListenerError<HE>, attempts: usize) -> RetryAction {
        (*self)(error, attempts)
    }
}
/// Per-listener execution settings: polling interval, batch size, notification wake-ups and
/// retry policy `R`.
pub struct PgEventListenerConfig<R> {
    // Interval between polling rounds.
    poll: Duration,
    // Maximum number of stream items handled in one processing round.
    fetch_size: usize,
    // Whether matching `new_events` notifications also trigger a round.
    notifier_enabled: bool,
    // Policy consulted when a round fails.
    retry: R,
}
impl<R> Clone for PgEventListenerConfig<R>
where
    R: Clone,
{
    /// Copies the plain settings and clones the retry policy.
    fn clone(&self) -> Self {
        let Self {
            poll,
            fetch_size,
            notifier_enabled,
            retry,
        } = self;
        Self {
            retry: retry.clone(),
            notifier_enabled: *notifier_enabled,
            fetch_size: *fetch_size,
            poll: *poll,
        }
    }
}
impl PgEventListenerConfig<AbortRetry> {
    /// Creates a configuration that polls every `poll`.
    ///
    /// Defaults: no limit on items per round, notifications disabled, and the
    /// [`AbortRetry`] policy (stop on the first failure).
    pub fn poller(poll: Duration) -> PgEventListenerConfig<AbortRetry> {
        let retry = AbortRetry;
        let notifier_enabled = false;
        // `usize::MAX` effectively removes the per-round limit.
        let fetch_size = usize::MAX;
        Self {
            poll,
            fetch_size,
            notifier_enabled,
            retry,
        }
    }
}
impl<R> PgEventListenerConfig<R> {
    /// Limits how many stream items a single processing round may handle.
    pub fn fetch_size(self, fetch_size: usize) -> Self {
        Self { fetch_size, ..self }
    }

    /// Also triggers processing rounds on matching `new_events` notifications, in addition
    /// to polling.
    pub fn with_notifier(self) -> Self {
        Self {
            notifier_enabled: true,
            ..self
        }
    }

    /// Replaces the retry policy and keeps every other setting.
    pub fn with_retry<R1>(self, retry: R1) -> PgEventListenerConfig<R1> {
        let Self {
            poll,
            fetch_size,
            notifier_enabled,
            ..
        } = self;
        PgEventListenerConfig {
            poll,
            fetch_size,
            notifier_enabled,
            retry,
        }
    }
}
/// Whether an executor task keeps running after a processing round.
enum ListenerExecutionControl {
    // Keep waiting for the next poll tick or wake-up.
    Continue,
    // End the executor task.
    Stop,
}
/// Type-erased interface over one registered listener. It lets listeners with different
/// event types be stored together in `PgEventListener`.
#[async_trait]
trait EventListenerExecutor<E: Event + Clone> {
    /// Prepares the listener's persisted state before it is run.
    async fn init(&self) -> Result<(), Error>;
    /// Spawns the listener task. Also returns a waker when notifications are enabled for it.
    fn run(&self) -> (Option<ExecutorWaker<E>>, JoinHandle<Result<(), Error>>);
}
/// Runs a single [`EventListener`] `L`, whose events `QE` are converted from the store's
/// events `E`, using retry policy `R`.
struct PgEventListerExecutor<L, QE, E, S, R>
where
    QE: TryFrom<E> + Event + Send + Sync + Clone,
    <QE as TryFrom<E>>::Error: Send + Sync,
    E: Event + Clone + Sync + Send,
    S: Serde<E> + Clone + Send + Sync,
    L: EventListener<PgEventId, QE>,
    R: Retry<L::Error>,
    L::Error: Send + Sync + 'static,
{
    event_store: PgEventStore<E, S>,
    // Shared through `Arc` so clones of the executor (e.g. the spawned task) use the same handler.
    event_handler: Arc<L>,
    config: PgEventListenerConfig<R>,
    // Signalled by `ExecutorWaker` to trigger a processing round before the next poll tick.
    wake_channel: (watch::Sender<bool>, watch::Receiver<bool>),
    shutdown_token: CancellationToken,
    _event_store_events: PhantomData<E>,
    _event_listener_events: PhantomData<QE>,
}
impl<L, QE, E, S, R> PgEventListerExecutor<L, QE, E, S, R>
where
    E: Event + Clone + Sync + Send + 'static,
    S: Serde<E> + Clone + Send + Sync + 'static,
    QE: TryFrom<E> + Event + 'static + Send + Sync + Clone,
    <QE as TryFrom<E>>::Error: StdError + 'static + Send + Sync,
    L: EventListener<PgEventId, QE> + 'static,
    R: Retry<L::Error> + Send + Sync + 'static,
    L::Error: Send + Sync + 'static,
{
    /// Creates an executor that runs `event_handler` against `event_store` with `config`,
    /// stopping when `shutdown_token` is cancelled.
    pub fn new(
        event_store: PgEventStore<E, S>,
        event_handler: L,
        shutdown_token: CancellationToken,
        config: PgEventListenerConfig<R>,
    ) -> Self {
        Self {
            event_store,
            event_handler: Arc::new(event_handler),
            config,
            wake_channel: watch::channel(true),
            shutdown_token,
            _event_store_events: PhantomData,
            _event_listener_events: PhantomData,
        }
    }

    /// Reads this listener's `last_processed_event_id` and locks its `event_listener` row
    /// until `tx` ends.
    ///
    /// Returns `None` when the row does not exist or is already locked by another transaction.
    /// Because of `FOR UPDATE SKIP LOCKED`, concurrent rounds of the same listener are
    /// skipped instead of blocking.
    async fn acquire_listener(
        &self,
        tx: &mut Transaction<'_, Postgres>,
    ) -> Result<Option<PgEventId>, sqlx::Error> {
        Ok(sqlx::query("SELECT last_processed_event_id FROM event_listener WHERE id = $1 FOR UPDATE SKIP LOCKED")
            .bind(self.event_handler.id())
            .fetch_optional(&mut **tx)
            .await?
            .map(|r| r.get(0)))
    }

    /// Saves the progress carried by `result`, commits `tx`, then returns `result`'s error
    /// if there is one.
    ///
    /// Progress is saved on success and on `FetchNextEvent`/`Handler` errors, so events
    /// handled before a failure are not handled again. Any other error is returned
    /// immediately, and `tx` is dropped without committing.
    async fn release_listener(
        &self,
        result: Result<PgEventId, PgEventListenerError<L::Error>>,
        mut tx: Transaction<'_, Postgres>,
    ) -> Result<(), PgEventListenerError<L::Error>> {
        let last_processed_event_id = match result {
            Ok(last_processed_event_id) => last_processed_event_id,
            Err(PgEventListenerError {
                kind:
                    PgEventListenerErrorKind::FetchNextEvent {
                        last_processed_event_id,
                        ..
                    }
                    | PgEventListenerErrorKind::Handler {
                        last_processed_event_id,
                        ..
                    },
                ..
            }) => last_processed_event_id,
            Err(e) => return Err(e),
        };
        // NOTE(review): if the fetch failed because of a database error, Postgres may already
        // have aborted `tx`. In that case this UPDATE fails, and the caller sees `ReleaseLock`
        // instead of the original `FetchNextEvent` error.
        sqlx::query(
            "UPDATE event_listener SET last_processed_event_id = $1, updated_at = now() WHERE id = $2",
        )
        .bind(last_processed_event_id)
        .bind(self.event_handler.id())
        .execute(&mut *tx)
        .await
        .map_err(|e| PgEventListenerError::<L::Error> {
            kind: PgEventListenerErrorKind::ReleaseLock {
                source: e.into(),
                last_processed_event_id,
            },
            listener_id: self.event_handler.id().to_string(),
        })?;
        tx.commit()
            .await
            .map_err(|e| PgEventListenerError::<L::Error> {
                kind: PgEventListenerErrorKind::ReleaseLock {
                    source: e.into(),
                    last_processed_event_id,
                },
                listener_id: self.event_handler.id().to_string(),
            })?;
        result.map(|_| ())
    }

    /// Handles the listener's events, starting from `last_processed_event_id` (via
    /// `change_origin`). At most `config.fetch_size` stream items are read, end marker
    /// included.
    ///
    /// Returns the id of the last handled event, or the end marker's id when the end of the
    /// stream is reached. Stops early, keeping the progress made so far, once shutdown is
    /// requested.
    ///
    /// # Errors
    ///
    /// `FetchNextEvent` and `Handler` errors carry the progress made before the failure.
    async fn handle_events_from(
        &self,
        mut last_processed_event_id: PgEventId,
        tx: &mut Transaction<'_, Postgres>,
    ) -> Result<PgEventId, PgEventListenerError<L::Error>> {
        let query = self
            .event_handler
            .query()
            .clone()
            .change_origin(last_processed_event_id);
        let mut stream = self
            .event_store
            .stream_with(&mut **tx, &query)
            .take(self.config.fetch_size);
        while let Some(item) = stream.next().await {
            let item = item.map_err(|e| PgEventListenerError::<L::Error> {
                kind: PgEventListenerErrorKind::FetchNextEvent {
                    source: e,
                    last_processed_event_id,
                },
                listener_id: self.event_handler.id().to_string(),
            })?;
            let event_id = item.id();
            match item {
                StreamItem::End(_) => {
                    last_processed_event_id = event_id;
                    break;
                }
                StreamItem::Event(event) => {
                    self.event_handler
                        .handle(event)
                        .await
                        .map_err(|e| PgEventListenerError {
                            kind: PgEventListenerErrorKind::Handler {
                                source: e,
                                last_processed_event_id,
                            },
                            listener_id: self.event_handler.id().to_string(),
                        })?;
                    last_processed_event_id = event_id;
                }
            }
            if self.shutdown_token.is_cancelled() {
                break;
            }
        }
        Ok(last_processed_event_id)
    }

    /// Runs one processing round inside a single transaction: lock the listener row, handle
    /// the pending events, then save progress and commit.
    ///
    /// If the listener row is missing or locked by another transaction, the round is skipped
    /// and `Ok(())` is returned.
    pub async fn try_execute(&self) -> Result<(), PgEventListenerError<L::Error>> {
        let mut tx = self
            .event_store
            .pool
            .begin()
            .await
            .map_err(|e| PgEventListenerError {
                kind: PgEventListenerErrorKind::InitTransaction { source: e.into() },
                listener_id: self.event_handler.id().to_string(),
            })?;
        let Some(last_processed_id) = self
            .acquire_listener(&mut tx)
            .await
            .map_err(|e| PgEventListenerError {
                kind: PgEventListenerErrorKind::AcquireLock { source: e.into() },
                listener_id: self.event_handler.id().to_string(),
            })?
        else {
            return Ok(());
        };
        let result = self.handle_events_from(last_processed_id, &mut tx).await;
        self.release_listener(result, tx).await
    }

    /// Runs processing rounds until one succeeds, and consults the retry policy after each
    /// failure.
    ///
    /// Returns `Continue` after a successful round. Returns `Stop` when the policy aborts, or
    /// when shutdown is requested while waiting to retry.
    async fn execute(&self) -> ListenerExecutionControl {
        let mut attempts = 0;
        loop {
            match self.try_execute().await {
                Ok(_) => break ListenerExecutionControl::Continue,
                Err(err) => match self.config.retry.retry(err, attempts) {
                    RetryAction::Abort => break ListenerExecutionControl::Stop,
                    RetryAction::Wait { duration } => {
                        attempts += 1;
                        // Wait out the back-off, but give up as soon as shutdown is requested.
                        // Otherwise a policy that keeps answering `Wait` (e.g. while the
                        // database is unreachable) would keep this task alive after shutdown.
                        let shutdown_requested = tokio::select! {
                            biased;
                            _ = self.shutdown_token.cancelled() => true,
                            _ = tokio::time::sleep(duration) => false,
                        };
                        if shutdown_requested {
                            break ListenerExecutionControl::Stop;
                        }
                    }
                },
            }
        }
    }

    /// Spawns the task that runs this executor's processing rounds.
    ///
    /// A round runs on every `config.poll` tick and whenever the wake channel is signalled.
    /// The task ends with `Ok(())` when the shutdown token is cancelled, or when a round
    /// returns [`ListenerExecutionControl::Stop`].
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime. `tokio::time::interval` also panics if
    /// `config.poll` is zero.
    pub fn spawn_task(self) -> JoinHandle<Result<(), Error>> {
        let shutdown = self.shutdown_token.clone();
        let mut poll = tokio::time::interval(self.config.poll);
        poll.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // Receiver end of the wake channel; `ExecutorWaker` holds the sender.
        let mut wake_rx = self.wake_channel.1.clone();
        tokio::spawn(async move {
            loop {
                let outcome = tokio::select! {
                    Ok(()) = wake_rx.changed() => self.execute().await,
                    _ = poll.tick() => self.execute().await,
                    _ = shutdown.cancelled() => return Ok::<(), Error>(()),
                };
                match outcome {
                    ListenerExecutionControl::Continue => {}
                    ListenerExecutionControl::Stop => break,
                }
            }
            Ok(())
        })
    }
}
#[async_trait]
impl<L, QE, E, S, R> EventListenerExecutor<E> for PgEventListerExecutor<L, QE, E, S, R>
where
E: Event + Clone + Sync + Send + 'static,
S: Serde<E> + Clone + Send + Sync + 'static,
QE: TryFrom<E> + Into<E> + Event + 'static + Send + Sync + Clone,
<QE as TryFrom<E>>::Error: StdError + 'static + Send + Sync,
L: EventListener<PgEventId, QE> + 'static,
R: Retry<L::Error> + Clone + Send + Sync + 'static,
L::Error: Send + Sync + 'static,
{
async fn init(&self) -> Result<(), Error> {
let mut tx = self.event_store.pool.begin().await?;
sqlx::query("INSERT INTO event_listener (id, last_processed_event_id) VALUES ($1, 0) ON CONFLICT (id) DO NOTHING")
.bind(self.event_handler.id())
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
fn run(&self) -> (Option<ExecutorWaker<E>>, JoinHandle<Result<(), Error>>) {
let waker = if self.config.notifier_enabled {
Some(ExecutorWaker {
wake_tx: self.wake_channel.0.clone(),
query: self.event_handler.query().cast().clone(),
})
} else {
None
};
(waker, self.clone().spawn_task())
}
}
impl<L, QE, E, S, R> Clone for PgEventListerExecutor<L, QE, E, S, R>
where
    QE: TryFrom<E> + Event + Send + Sync + Clone,
    <QE as TryFrom<E>>::Error: Send + Sync,
    E: Event + Clone + Sync + Send,
    S: Serde<E> + Clone + Send + Sync,
    L: EventListener<PgEventId, QE>,
    R: Retry<L::Error> + Clone,
    L::Error: Send + Sync + 'static,
{
    /// Produces a second handle to the same executor.
    ///
    /// The handler is shared through its `Arc`, and the wake-channel endpoints are cloned,
    /// so both copies use the same channel.
    fn clone(&self) -> Self {
        let Self {
            event_store,
            event_handler,
            config,
            wake_channel,
            shutdown_token,
            ..
        } = self;
        Self {
            event_store: event_store.clone(),
            event_handler: Arc::clone(event_handler),
            config: config.clone(),
            wake_channel: wake_channel.clone(),
            shutdown_token: shutdown_token.clone(),
            _event_store_events: PhantomData,
            _event_listener_events: PhantomData,
        }
    }
}
/// Wakes one executor when an incoming notification matches its listener's query.
struct ExecutorWaker<E: Event + Clone> {
    // Sender half of the executor's wake channel.
    wake_tx: watch::Sender<bool>,
    // The listener's query, cast to the store's event type `E` for matching.
    query: StreamQuery<PgEventId, E>,
}
impl<E: Event + Clone> ExecutorWaker<E> {
    /// Signals the executor if `event` matches this listener's query; otherwise does nothing.
    ///
    /// `event` is presumably the payload of a `new_events` notification; confirm against
    /// the callers.
    fn wake(&self, event: &str) {
        if !self.query.matches_event(event) {
            return;
        }
        // `send_replace` always stores the value and notifies receivers. Unlike `send`, it
        // does not fail when no receiver is alive.
        self.wake_tx.send_replace(true);
    }
}