use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use crate::out::event::*;
use crate::{config::*, handle::OwnedTaskHandle, sequence::EventSequence};
/// Consumer-side handle onto the persistent outbox event cache.
///
/// Carries a shared view of the highest known sequence, a take-once receiver
/// for live events, and the channel used to request older events.
pub struct CacheHandle<P: Serialize + DeserializeOwned + Send + Sync + 'static> {
    /// Counter shared with the cache that tracks the highest sequence seen.
    highest_known_sequence: Arc<AtomicU64>,
    /// Live-event receiver; becomes `None` once it has been turned into a stream.
    persistent_event_receiver: Option<broadcast::Receiver<Arc<PersistentOutboxEvent<P>>>>,
    /// Carries `(start_after, reply_sender)` backfill requests to the cache loop.
    backfill_request:
        mpsc::UnboundedSender<(EventSequence, mpsc::Sender<Arc<PersistentOutboxEvent<P>>>)>,
    /// Capacity of the reply channel created for each backfill request.
    backfill_buffer_size: usize,
}
impl<P> CacheHandle<P>
where
    P: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Returns the highest persisted event sequence currently recorded in the
    /// shared counter (read with relaxed ordering).
    pub fn latest_known_persisted(&self) -> EventSequence {
        let raw = self.highest_known_sequence.load(Ordering::Relaxed);
        EventSequence::from(raw)
    }

    /// Converts this handle's broadcast receiver into a stream of live events.
    ///
    /// # Panics
    ///
    /// Panics with `"receiver already taken"` when called more than once on
    /// the same handle, because the receiver is moved out on the first call.
    pub fn persistent_event_stream(&mut self) -> BroadcastStream<Arc<PersistentOutboxEvent<P>>> {
        let receiver = self
            .persistent_event_receiver
            .take()
            .expect("receiver already taken");
        BroadcastStream::new(receiver)
    }

    /// Requests events with a sequence after `start_after` and returns the
    /// stream on which they will be delivered.
    ///
    /// A fresh bounded channel of `backfill_buffer_size` slots is created per
    /// request. Submitting the request is best-effort: a failed send is
    /// ignored, which drops the sending half of the reply channel.
    pub fn request_old_persistent_events(
        &self,
        start_after: EventSequence,
    ) -> ReceiverStream<Arc<PersistentOutboxEvent<P>>> {
        let (reply_tx, reply_rx) = mpsc::channel(self.backfill_buffer_size);
        let _ = self.backfill_request.send((start_after, reply_tx));
        ReceiverStream::new(reply_rx)
    }
}
/// Owner of the persistent outbox event cache.
///
/// Holds the sending halves of every channel the background cache loop
/// listens on, together with the handle of that loop's task.
#[derive(Debug)]
pub struct PersistentOutboxEventCache<
    P: Serialize + DeserializeOwned + Send + Sync + 'static,
    Tables,
> {
    /// Highest event sequence seen so far; shared with every `CacheHandle`.
    highest_known_sequence: Arc<AtomicU64>,
    /// Broadcast channel on which the cache loop publishes live events.
    persistent_event_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    /// Sender cloned into each handle for submitting backfill requests.
    backfill_request_send:
        mpsc::UnboundedSender<(EventSequence, mpsc::Sender<Arc<PersistentOutboxEvent<P>>>)>,
    /// Reply-channel capacity handed to each `CacheHandle`.
    backfill_buffer_size: usize,
    /// Broadcast channel used to feed already-loaded events into the cache loop.
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    /// Handle of the spawned cache loop task.
    // NOTE(review): presumably ties the task's lifetime to this struct — confirm
    // against `OwnedTaskHandle`.
    _cache_loop_handle: OwnedTaskHandle,
    /// Marker for the table set; no value of `Tables` is stored.
    _phantom: std::marker::PhantomData<Tables>,
}
impl<P, Tables> PersistentOutboxEventCache<P, Tables>
where
P: Serialize + DeserializeOwned + Send + Sync + 'static,
Tables: crate::tables::MailboxTables,
{
pub fn handle(&self) -> CacheHandle<P> {
CacheHandle {
highest_known_sequence: self.highest_known_sequence.clone(),
persistent_event_receiver: Some(self.persistent_event_sender.subscribe()),
backfill_request: self.backfill_request_send.clone(),
backfill_buffer_size: self.backfill_buffer_size,
}
}
/// Returns a clone of the sender used to push events into the cache loop.
pub fn cache_fill_sender(&self) -> broadcast::Sender<Arc<PersistentOutboxEvent<P>>> {
    let sender = &self.cache_fill_sender;
    sender.clone()
}
/// Creates the cache: sets up its channels, seeds the highest known sequence
/// from the database, and spawns the background cache loop.
///
/// # Errors
///
/// Returns the `sqlx::Error` from reading the highest persisted sequence or
/// from spawning the cache loop.
pub async fn init(
    pool: &sqlx::PgPool,
    config: &MailboxConfig,
    persistent_notification_rx: mpsc::Receiver<sqlx::postgres::PgNotification>,
) -> Result<Self, sqlx::Error> {
    let buffer_size = config.event_buffer_size;

    let (backfill_request_send, backfill_request_recv) = mpsc::unbounded_channel();
    let (cache_fill_sender, cache_fill_receiver) = broadcast::channel(buffer_size);
    // The initial receiver is discarded; subscribers are created via `handle()`.
    let (persistent_event_sender, _) = broadcast::channel(buffer_size);

    let persisted_high = Tables::highest_known_persistent_sequence(pool).await?;
    let highest_known_sequence = Arc::new(AtomicU64::from(persisted_high));

    let cache_loop = Self::spawn_cache_loop(
        pool,
        config,
        persistent_event_sender.clone(),
        Arc::clone(&highest_known_sequence),
        backfill_request_recv,
        cache_fill_receiver,
        cache_fill_sender.clone(),
        persistent_notification_rx,
    )
    .await?;

    Ok(Self {
        highest_known_sequence,
        persistent_event_sender,
        backfill_request_send,
        backfill_buffer_size: buffer_size,
        cache_fill_sender,
        _cache_loop_handle: cache_loop,
        _phantom: std::marker::PhantomData,
    })
}
/// Adds `event` to `cache` (unless it is too old) and broadcasts every cached
/// event that directly continues the already-broadcast run.
///
/// An event is ignored when its sequence is at or below
/// `highest_known - cache_size` (saturating at zero). Otherwise the shared
/// counter is raised to the event's sequence if that is higher, the event is
/// stored unless an entry for that sequence already exists, and the cache is
/// walked from just after `last_broadcast_sequence`, sending events while
/// they are gap-free.
///
/// Returns the updated cache and the sequence of the last event sent.
// NOTE(review): the cursor only advances on a successful send and only across
// contiguous sequences. If the next sequence is never cached (dropped as too
// old, trimmed, or never delivered) broadcasting appears to stop for good —
// confirm that callers recover through backfill.
fn insert_into_cache_and_maybe_broadcast(
    cache: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    event: Arc<PersistentOutboxEvent<P>>,
    highest_known_sequence: &AtomicU64,
    persistent_event_sender: &broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    mut last_broadcast_sequence: EventSequence,
    cache_size: usize,
) -> (
    im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    EventSequence,
) {
    use std::ops::Bound;

    let sequence = event.sequence;
    let raw_sequence = u64::from(sequence);

    // Anything at or below this floor has fallen out of the cache window.
    let eviction_floor = highest_known_sequence
        .load(Ordering::Relaxed)
        .saturating_sub(cache_size as u64);
    if raw_sequence <= eviction_floor {
        return (cache, last_broadcast_sequence);
    }

    highest_known_sequence.fetch_max(raw_sequence, Ordering::AcqRel);
    // Keep an existing entry for this sequence; insert only when absent.
    let cache = cache.alter(|existing| existing.or(Some(event)), sequence);

    for (seq, evt) in cache.range((Bound::Excluded(last_broadcast_sequence), Bound::Unbounded)) {
        let contiguous = *seq == last_broadcast_sequence.next();
        // Short-circuit keeps the send from happening across a gap.
        if !contiguous || persistent_event_sender.send(Arc::clone(evt)).is_err() {
            break;
        }
        last_broadcast_sequence = *seq;
    }

    (cache, last_broadcast_sequence)
}
/// Serves one backfill request: streams events after `start_after` to
/// `sender`, reading pages from the database until the cache snapshot can
/// take over, then draining the snapshot.
///
/// Every event loaded from the database is also offered to the cache loop
/// through `cache_fill_sender` (best-effort). The function returns early as
/// soon as the requester's channel is closed.
// NOTE(review): a failed or empty page load only leaves the paging loop; the
// cached tail after the cursor is still sent, so the requester may observe a
// gap in sequences — confirm that consumers tolerate this.
async fn handle_backfill_request(
    pool: sqlx::PgPool,
    start_after: EventSequence,
    sender: mpsc::Sender<Arc<PersistentOutboxEvent<P>>>,
    cache_snapshot: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    highest: EventSequence,
    buffer_size: usize,
) {
    use std::ops::Bound;

    let mut cursor = start_after;

    // Phase 1: page through the database until the snapshot holds the next event.
    while cursor < highest {
        if cache_snapshot.contains_key(&cursor.next()) {
            break;
        }
        let Ok(page) = Tables::load_next_page::<P>(&pool, cursor, buffer_size).await else {
            break;
        };
        if page.is_empty() {
            break;
        }
        for loaded in page {
            let seq = loaded.sequence;
            let shared = Arc::new(loaded);
            let _ = cache_fill_sender.send(Arc::clone(&shared));
            if sender.send(shared).await.is_err() {
                return;
            }
            cursor = seq;
        }
    }

    // Phase 2: everything the snapshot holds after the cursor.
    for (_, cached) in cache_snapshot.range((Bound::Excluded(cursor), Bound::Unbounded)) {
        if sender.send(Arc::clone(cached)).await.is_err() {
            return;
        }
    }
}
/// Loads the single event with exactly `sequence` from the database and
/// forwards it to the cache loop.
///
/// It asks for a one-event page starting after `sequence - 1` (saturating at
/// zero). A load error, an empty page, or a first event with a different
/// sequence all result in nothing being sent.
async fn fetch_event_by_sequence(
    pool: sqlx::PgPool,
    sequence: EventSequence,
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
) {
    let predecessor = EventSequence::from(u64::from(sequence).saturating_sub(1));

    let Ok(page) = Tables::load_next_page::<P>(&pool, predecessor, 1).await else {
        return;
    };
    let Some(found) = page.into_iter().next() else {
        return;
    };
    if found.sequence == sequence {
        // Best-effort: a send error is ignored.
        let _ = cache_fill_sender.send(Arc::new(found));
    }
}
/// Interprets one notification payload.
///
/// Returns `Some(event)` when the payload carries the full event and parses
/// successfully. Returns `None` when the header or the event fails to parse,
/// or when the payload was omitted; in that last case a detached task is
/// spawned to fetch the event by sequence unless `cache` already holds it.
fn handle_notification(
    pool: &sqlx::PgPool,
    payload: &str,
    cache: &im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    cache_fill_sender: &broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
) -> Option<PersistentOutboxEvent<P>> {
    /// Minimal view of the payload, enough to decide how to handle it.
    #[derive(serde::Deserialize)]
    struct NotificationHeader {
        sequence: EventSequence,
        #[serde(default)]
        payload_omitted: bool,
    }

    let header: NotificationHeader = serde_json::from_str(payload).ok()?;

    if !header.payload_omitted {
        // Full event is inline: parse the same payload as the event itself.
        return serde_json::from_str(payload).ok();
    }

    if !cache.contains_key(&header.sequence) {
        // The fetched event comes back to the cache loop via `cache_fill_sender`.
        tokio::spawn(Self::fetch_event_by_sequence(
            pool.clone(),
            header.sequence,
            cache_fill_sender.clone(),
        ));
    }
    None
}
#[allow(clippy::too_many_arguments)]
async fn spawn_cache_loop(
pool: &sqlx::PgPool,
config: &MailboxConfig,
persistent_event_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
highest_known_sequence: Arc<AtomicU64>,
mut backfill_request: mpsc::UnboundedReceiver<(
EventSequence,
mpsc::Sender<Arc<PersistentOutboxEvent<P>>>,
)>,
mut cache_fill_receiver: broadcast::Receiver<Arc<PersistentOutboxEvent<P>>>,
cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
mut notification_receiver: mpsc::Receiver<sqlx::postgres::PgNotification>,
) -> Result<OwnedTaskHandle, sqlx::Error> {
let pool = pool.clone();
let cache_size = config.event_cache_size;
let high_water = cache_size * (100 + config.event_cache_trim_percent as usize) / 100;
let low_water = cache_size * (100 - config.event_cache_trim_percent as usize) / 100;
let initial_sequence = EventSequence::from(highest_known_sequence.load(Ordering::Relaxed));
let handle = tokio::spawn(async move {
let mut persistent_cache: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>> =
im::OrdMap::new();
let mut last_broadcast_sequence = initial_sequence;
loop {
tokio::select! {
biased;
result = backfill_request.recv() => {
match result {
Some((start_after, sender)) => {
let cache_snapshot = persistent_cache.clone();
let highest = EventSequence::from(
highest_known_sequence.load(Ordering::Relaxed)
);
tokio::spawn(Self::handle_backfill_request(
pool.clone(),
start_after,
sender,
cache_snapshot,
cache_fill_sender.clone(),
highest,
cache_size,
));
}
None => {
break;
}
}
continue;
}
result = cache_fill_receiver.recv() => {
match result {
Ok(event) => {
(persistent_cache, last_broadcast_sequence) =
Self::insert_into_cache_and_maybe_broadcast(
persistent_cache,
event,
&highest_known_sequence,
&persistent_event_sender,
last_broadcast_sequence,
cache_size,
);
while let Ok(event) = cache_fill_receiver.try_recv() {
(persistent_cache, last_broadcast_sequence) =
Self::insert_into_cache_and_maybe_broadcast(
persistent_cache,
event,
&highest_known_sequence,
&persistent_event_sender,
last_broadcast_sequence,
cache_size,
);
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {
continue;
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
result = notification_receiver.recv() => {
match result {
Some(notification) => {
if let Some(event) = Self::handle_notification(
&pool,
notification.payload(),
&persistent_cache,
&cache_fill_sender,
) {
(persistent_cache, last_broadcast_sequence) =
Self::insert_into_cache_and_maybe_broadcast(
persistent_cache,
Arc::new(event),
&highest_known_sequence,
&persistent_event_sender,
last_broadcast_sequence,
cache_size,
);
}
while let Ok(notification) = notification_receiver.try_recv() {
if let Some(event) = Self::handle_notification(
&pool,
notification.payload(),
&persistent_cache,
&cache_fill_sender,
) {
(persistent_cache, last_broadcast_sequence) =
Self::insert_into_cache_and_maybe_broadcast(
persistent_cache,
Arc::new(event),
&highest_known_sequence,
&persistent_event_sender,
last_broadcast_sequence,
cache_size,
);
}
}
}
None => {
break;
}
}
}
}
if persistent_cache.len() > high_water {
let to_remove = persistent_cache.len() - low_water;
if let Some((&split_key, _)) = persistent_cache.iter().nth(to_remove) {
let (_, right) = persistent_cache.split(&split_key);
persistent_cache = right;
}
}
}
});
Ok(OwnedTaskHandle::new(handle))
}
}