use super::coordinator::coordinate_event_channel;
use crate::consumer::{AggregateConsumerStore, RootConsumerStore};
use crate::error::{Error, Result};
use crate::TimesourceEventPayload;
use futures::stream::{BoxStream, StreamExt};
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::fmt::Debug;
use std::time::Duration;
use timesource_core::event::Persisted;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
use super::super::listener::Listener;
use super::super::store::ConsumerStore;
/// Builder for a [`TransientConsumer`] backed by a Postgres event store.
///
/// Defaults (see [`TransientConsumerBuilder::new`]): event and notification
/// buffer capacities of 100 and a 60-second backup polling frequency.
pub struct TransientConsumerBuilder<'a> {
    // Capacity of the mpsc channel buffering events handed to the consumer stream.
    event_buffer_capacity: usize,
    // Capacity of the channel buffering raw Postgres LISTEN/NOTIFY notifications.
    notification_buffer_capacity: usize,
    // How often to fall back to polling the store (backup for missed notifications).
    polling_freq: Duration,
    // Postgres connection string, borrowed for the builder's lifetime.
    dsn: &'a str,
}
impl<'a> TransientConsumerBuilder<'a> {
    /// Creates a builder for the given Postgres DSN with default settings:
    /// event and notification buffer capacities of 100 and a 60-second
    /// backup polling frequency.
    pub fn new(dsn: &'a str) -> Self {
        Self {
            event_buffer_capacity: 100,
            notification_buffer_capacity: 100,
            polling_freq: Duration::from_secs(60),
            dsn,
        }
    }

    /// Sets how often the consumer polls the store as a backup in case a
    /// Postgres notification is missed.
    pub fn with_backup_polling_frequency(mut self, duration: Duration) -> Self {
        self.polling_freq = duration;
        self
    }

    /// Sets the capacity of the channel buffering events delivered to the
    /// consumer stream.
    pub fn with_event_buffer_capacity(mut self, capacity: usize) -> Self {
        self.event_buffer_capacity = capacity;
        self
    }

    /// Sets the capacity of the channel buffering raw Postgres notifications.
    pub fn with_notification_buffer_capacity(mut self, capacity: usize) -> Self {
        self.notification_buffer_capacity = capacity;
        self
    }

    /// Builds a consumer that receives events for *all* roots of the named
    /// aggregate type.
    ///
    /// # Errors
    ///
    /// Fails if the pool cannot connect to the DSN or if no aggregate type
    /// with `aggregate_type_name` exists in the store.
    pub async fn aggregate_build<Event>(
        &self,
        // NOTE: `&str`, not `&'a str` — the name does not need to outlive the
        // builder's DSN borrow.
        aggregate_type_name: &str,
    ) -> Result<TransientConsumer<AggregateConsumerStore<Event>>>
    where
        Event: TimesourceEventPayload + 'static + Send + Sync + Debug,
    {
        let pool = self.pool().await?;
        let aggregate_type_id = self.aggregate_type_id(aggregate_type_name, &pool).await?;
        let store = AggregateConsumerStore::new_anonymous(aggregate_type_id, pool.clone());
        // Notification channel is keyed by the aggregate type id alone, so
        // events for every root of this type are received.
        let listener = Listener::new(
            self.notification_buffer_capacity,
            aggregate_type_id.to_string().into(),
            pool,
        );
        Ok(TransientConsumer {
            event_buffer_capacity: self.event_buffer_capacity,
            listener,
            polling_freq: self.polling_freq,
            store,
        })
    }

    /// Builds a consumer scoped to a single aggregate root (`root_id`) of the
    /// named aggregate type.
    ///
    /// # Errors
    ///
    /// Fails if the pool cannot connect to the DSN or if no aggregate type
    /// with `aggregate_type_name` exists in the store.
    pub async fn aggregate_root_build<Event>(
        &self,
        // NOTE: `&str`, not `&'a str` — the name does not need to outlive the
        // builder's DSN borrow.
        aggregate_type_name: &str,
        root_id: Uuid,
    ) -> Result<TransientConsumer<RootConsumerStore<Event>>>
    where
        Event: TimesourceEventPayload + 'static + Send + Sync + Debug,
    {
        let pool = self.pool().await?;
        let aggregate_type_id = self.aggregate_type_id(aggregate_type_name, &pool).await?;
        let store = RootConsumerStore::new_anonymous(root_id, aggregate_type_id, pool.clone());
        // Notification channel is keyed by both the type id and the root id,
        // scoping delivery to this single root.
        let channel = format!("{}::{}", aggregate_type_id, root_id);
        let listener = Listener::new(self.notification_buffer_capacity, channel.into(), pool);
        Ok(TransientConsumer {
            event_buffer_capacity: self.event_buffer_capacity,
            listener,
            polling_freq: self.polling_freq,
            store,
        })
    }

    /// Connects a fresh connection pool to the configured DSN.
    async fn pool(&self) -> Result<PgPool> {
        Ok(PgPoolOptions::new().connect(self.dsn).await?)
    }

    /// Resolves an aggregate type name to its numeric id via the
    /// `queries/aggregate_type/id.sql` query.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidData`] when the name is unknown (the query
    /// yields a NULL scalar), or the underlying sqlx error on query failure.
    async fn aggregate_type_id(&self, aggregate_type_name: &str, pool: &PgPool) -> Result<i32> {
        Ok(
            sqlx::query_file_scalar!("queries/aggregate_type/id.sql", aggregate_type_name)
                .fetch_one(pool)
                .await?
                .ok_or_else(|| Error::InvalidData("Unable to get aggregate type id".into()))?,
        )
    }
}
/// A consumer that streams persisted events out of a [`ConsumerStore`],
/// combining Postgres LISTEN/NOTIFY delivery with periodic backup polling.
///
/// Constructed via [`TransientConsumerBuilder`]; start consuming with
/// [`TransientConsumer::resume`].
pub struct TransientConsumer<Store>
where
    Store: ConsumerStore,
{
    // Capacity of the mpsc channel buffering events for the returned stream.
    event_buffer_capacity: usize,
    // Interval of the backup poll used when notifications are missed.
    polling_freq: Duration,
    // Postgres LISTEN/NOTIFY subscription for this consumer's channel.
    listener: Listener,
    // Store the coordinator reads persisted events from.
    store: Store,
}
impl<Store> TransientConsumer<Store>
where
    Store: ConsumerStore,
{
    /// Starts consuming and returns a boxed stream of persisted events.
    ///
    /// Subscribes to the listener's Postgres notification channel, then
    /// spawns a background task that coordinates notifications with backup
    /// polling and forwards events into a bounded channel. The returned
    /// stream reads from that channel and terminates once the sending side
    /// is dropped.
    ///
    /// # Errors
    ///
    /// Fails if the Postgres `LISTEN` subscription cannot be established.
    pub async fn resume(
        &self,
    ) -> Result<BoxStream<'_, Result<Persisted<<Store as ConsumerStore>::Event>>>> {
        let (buffer_tx, buffer_rx) = mpsc::channel(self.event_buffer_capacity);
        let pg_notification_rx = self.listener.listen().await?;
        let polling_freq = self.polling_freq;
        let store = self.store.clone();
        tokio::spawn(async move {
            // The coordinator runs in its own task so a panic surfaces to
            // this supervising task as a JoinError and can be logged instead
            // of dying silently.
            let task = tokio::spawn(coordinate_event_channel(
                polling_freq,
                pg_notification_rx,
                store,
                // Move the sender in (previously `buffer_tx.clone()`, which
                // kept a redundant sender alive here): the stream now ends as
                // soon as the coordinator drops its sender, instead of only
                // after this supervising task has finished joining/logging.
                buffer_tx,
            ));
            if let Err(error) = task.await {
                // NOTE(review): `error!` has no visible `use` in this file —
                // presumably a crate-level macro import; confirm.
                error!(?error, "Consumer task crashed. Unable to recover");
            }
        });
        Ok(ReceiverStream::new(buffer_rx).boxed())
    }
}