use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use crate::out::event::*;
use crate::{
config::*,
handle::{OwnedTaskHandle, spawn_supervised},
sequence::EventSequence,
};
/// Client-side handle to a `PersistentOutboxEventCache`.
///
/// Handles are created via `PersistentOutboxEventCache::handle`; each one
/// carries its own broadcast subscription for live events and can request a
/// backfill of older events over the shared unbounded request channel.
pub struct CacheHandle<P>
where
    P: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    // Shared with the cache loop; highest persisted sequence observed so far.
    highest_known_sequence: Arc<AtomicU64>,
    // `Some` until `persistent_event_stream` takes it — a handle yields at
    // most one live stream (second call panics).
    persistent_event_receiver: Option<broadcast::Receiver<Arc<PersistentOutboxEvent<P>>>>,
    // Requests are (start_after, reply sender) pairs serviced by the cache loop.
    backfill_request:
        mpsc::UnboundedSender<(EventSequence, mpsc::Sender<Arc<PersistentOutboxEvent<P>>>)>,
    // Capacity of the bounded reply channel created per backfill request.
    backfill_buffer_size: usize,
}
impl<P> CacheHandle<P>
where
    P: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Highest persisted event sequence this cache has observed so far.
    pub fn latest_known_persisted(&self) -> EventSequence {
        let raw = self.highest_known_sequence.load(Ordering::Relaxed);
        EventSequence::from(raw)
    }

    /// Stream of live persisted events from the cache loop's broadcast.
    ///
    /// # Panics
    /// Panics if called a second time on the same handle: the underlying
    /// broadcast receiver is moved out on the first call.
    pub fn persistent_event_stream(&mut self) -> BroadcastStream<Arc<PersistentOutboxEvent<P>>> {
        let receiver = self
            .persistent_event_receiver
            .take()
            .expect("receiver already taken");
        BroadcastStream::new(receiver)
    }

    /// Ask the cache loop for events persisted strictly after `start_after`
    /// and return a stream yielding them. If the cache loop is gone the send
    /// error is ignored and the returned stream simply ends immediately.
    pub fn request_old_persistent_events(
        &self,
        start_after: EventSequence,
    ) -> ReceiverStream<Arc<PersistentOutboxEvent<P>>> {
        let (reply_tx, reply_rx) = mpsc::channel(self.backfill_buffer_size);
        let _ = self.backfill_request.send((start_after, reply_tx));
        ReceiverStream::new(reply_rx)
    }
}
/// In-memory cache of persisted outbox events, plus the supervised background
/// task that keeps it current from Postgres notifications, cache fills, and
/// backfills.
///
/// `Tables` selects the concrete database access (bounded to
/// `crate::tables::MailboxTables` on the impl); it is carried here only via
/// `PhantomData`.
#[derive(Debug)]
pub struct PersistentOutboxEventCache<P, Tables>
where
    P: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    // Shared monotonic high-water mark of persisted sequences.
    highest_known_sequence: Arc<AtomicU64>,
    // Live, in-order broadcast of events to all `CacheHandle` subscribers.
    persistent_event_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    // Handles push (start_after, reply sender) pairs here for backfill.
    backfill_request_send:
        mpsc::UnboundedSender<(EventSequence, mpsc::Sender<Arc<PersistentOutboxEvent<P>>>)>,
    // Capacity used for per-request backfill reply channels.
    backfill_buffer_size: usize,
    // Ingest channel: producers, gap fills, and backfills feed events to the loop.
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    // Keeps the cache-loop task owned by this struct; see `OwnedTaskHandle`
    // for its drop semantics (presumably stops the task — confirm there).
    _cache_loop_handle: OwnedTaskHandle,
    _phantom: std::marker::PhantomData<Tables>,
}
impl<P, Tables> PersistentOutboxEventCache<P, Tables>
where
P: Serialize + DeserializeOwned + Send + Sync + 'static,
Tables: crate::tables::MailboxTables,
{
pub fn handle(&self) -> CacheHandle<P> {
CacheHandle {
highest_known_sequence: self.highest_known_sequence.clone(),
persistent_event_receiver: Some(self.persistent_event_sender.subscribe()),
backfill_request: self.backfill_request_send.clone(),
backfill_buffer_size: self.backfill_buffer_size,
}
}
/// Sender used by producers to inject freshly persisted events into the
/// cache loop.
pub fn cache_fill_sender(&self) -> broadcast::Sender<Arc<PersistentOutboxEvent<P>>> {
    let sender = &self.cache_fill_sender;
    sender.clone()
}
/// Build the cache: read the highest persisted sequence from the database,
/// wire up the channels, and start the supervised cache-loop task.
///
/// # Errors
/// Returns the underlying `sqlx::Error` if the initial sequence query or the
/// cache-loop startup fails.
pub async fn init(
    pool: &sqlx::PgPool,
    config: &MailboxConfig,
    persistent_notification_rx: mpsc::Receiver<sqlx::postgres::PgNotification>,
) -> Result<Self, sqlx::Error> {
    // Seed the shared high-water mark from the database before starting the loop.
    let initial = Tables::highest_known_persistent_sequence(pool).await?;
    let highest_known_sequence = Arc::new(AtomicU64::from(initial));

    let (persistent_event_sender, _initial_rx) = broadcast::channel(config.event_buffer_size);
    let (cache_fill_send, cache_fill_recv) = broadcast::channel(config.event_buffer_size);
    let (backfill_send, backfill_recv) = mpsc::unbounded_channel();

    let cache_loop_handle = Self::spawn_cache_loop(
        pool,
        config,
        persistent_event_sender.clone(),
        highest_known_sequence.clone(),
        backfill_recv,
        cache_fill_recv,
        cache_fill_send.clone(),
        persistent_notification_rx,
    )
    .await?;

    Ok(Self {
        highest_known_sequence,
        persistent_event_sender,
        backfill_request_send: backfill_send,
        backfill_buffer_size: config.event_buffer_size,
        cache_fill_sender: cache_fill_send,
        _cache_loop_handle: cache_loop_handle,
        _phantom: std::marker::PhantomData,
    })
}
/// Insert `event` into the immutable cache map, then broadcast whatever
/// now-contiguous run of events follows `last_broadcast_sequence`.
///
/// Returns the updated cache and the new last-broadcast sequence. Events at
/// or below the eviction threshold (highest known minus `cache_size`) are
/// dropped without touching the cache at all.
fn insert_into_cache_and_maybe_broadcast(
    cache: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    event: Arc<PersistentOutboxEvent<P>>,
    highest_known_sequence: &AtomicU64,
    persistent_event_sender: &broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    mut last_broadcast_sequence: EventSequence,
    cache_size: usize,
) -> (
    im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    EventSequence,
) {
    use std::ops::Bound;
    let sequence = event.sequence;
    let highest_known = highest_known_sequence.load(Ordering::Relaxed);
    // Anything older than the cache window is not worth storing.
    let threshold = highest_known.saturating_sub(cache_size as u64);
    if u64::from(sequence) <= threshold {
        return (cache, last_broadcast_sequence);
    }
    // Advance the shared high-water mark; fetch_max keeps it monotonic even
    // if events arrive out of order.
    highest_known_sequence.fetch_max(u64::from(sequence), Ordering::AcqRel);
    // First writer wins: keep any existing entry for this sequence.
    let cache = cache.alter(|existing| existing.or(Some(event)), sequence);
    // Broadcast the contiguous run starting right after the last broadcast
    // sequence; stop at the first gap or when every receiver is gone.
    for (seq, evt) in cache.range((Bound::Excluded(last_broadcast_sequence), Bound::Unbounded))
    {
        if *seq != last_broadcast_sequence.next() {
            // Gap: the next expected sequence is missing from the cache.
            // The caller's gap-fill housekeeping will fetch it from the DB.
            record_sequence_gap(
                u64::from(last_broadcast_sequence),
                u64::from(*seq),
                highest_known_sequence.load(Ordering::Relaxed),
            );
            break;
        }
        if persistent_event_sender.send(evt.clone()).is_err() {
            // No subscribers right now; stop broadcasting but keep caching so
            // a later subscriber can backfill.
            record_no_receivers(u64::from(*seq));
            break;
        }
        last_broadcast_sequence = *seq;
    }
    (cache, last_broadcast_sequence)
}
async fn fill_gap(
pool: sqlx::PgPool,
from_sequence: EventSequence,
cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
buffer_size: usize,
) {
if let Ok(events) = Tables::load_next_page::<P>(&pool, from_sequence, buffer_size).await {
for event in events {
let _ = cache_fill_sender.send(Arc::new(event));
}
}
}
/// Service one backfill request: stream events persisted after `start_after`
/// to `sender`, first paging from the database until the next needed sequence
/// is covered by `cache_snapshot`, then draining the rest from the snapshot.
///
/// Pages loaded from the database are also pushed into `cache_fill_sender`
/// so the cache loop can absorb them. Stops early when the requester drops
/// its receiver, the database errors, or a page comes back empty.
async fn handle_backfill_request(
    pool: sqlx::PgPool,
    start_after: EventSequence,
    sender: mpsc::Sender<Arc<PersistentOutboxEvent<P>>>,
    cache_snapshot: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    highest: EventSequence,
    buffer_size: usize,
) {
    use std::ops::Bound;
    let mut current_sequence = start_after;
    // Phase 1: page from the database while the next needed sequence is not
    // yet present in the snapshot.
    while current_sequence < highest {
        let next_needed = current_sequence.next();
        if cache_snapshot.contains_key(&next_needed) {
            // From here on the snapshot holds everything we still need.
            break;
        }
        match Tables::load_next_page::<P>(&pool, current_sequence, buffer_size).await {
            // Empty page: nothing more persisted after this point.
            Ok(events) if events.is_empty() => break,
            Ok(events) => {
                for event in events {
                    let seq = event.sequence;
                    let event = Arc::new(event);
                    // Feed the cache loop too; a send error only means the
                    // loop has shut down, so it is ignored.
                    let _ = cache_fill_sender.send(event.clone());
                    if sender.send(event).await.is_err() {
                        // Requester went away; abandon the backfill.
                        return;
                    }
                    current_sequence = seq;
                }
            }
            Err(e) => {
                record_backfill_failed(&e, u64::from(current_sequence));
                break;
            }
        }
    }
    // Phase 2: stream whatever the snapshot holds strictly after the last
    // sequence we delivered.
    for (_, event) in
        cache_snapshot.range((Bound::Excluded(current_sequence), Bound::Unbounded))
    {
        if sender.send(event.clone()).await.is_err() {
            return;
        }
    }
}
/// Fetch exactly the event with `sequence` from the database and, if found,
/// push it into the cache-fill channel. Used when a notification arrived with
/// its payload omitted and the event is not already cached.
async fn fetch_event_by_sequence(
    pool: sqlx::PgPool,
    sequence: EventSequence,
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
) {
    // Pages are keyed "strictly after", so start one before the target.
    let start_after = EventSequence::from(u64::from(sequence).saturating_sub(1));
    let Ok(events) = Tables::load_next_page::<P>(&pool, start_after, 1).await else {
        return;
    };
    if let Some(event) = events.into_iter().next() {
        // Guard against the page skipping over a hole at `sequence`.
        if event.sequence == sequence {
            let _ = cache_fill_sender.send(Arc::new(event));
        }
    }
}
/// Parse one LISTEN/NOTIFY payload.
///
/// Returns the full event when the payload carries it inline. When the
/// payload was omitted (too large for NOTIFY, presumably — confirm with the
/// publisher) and the event is not already cached, a background fetch by
/// sequence is spawned and `None` is returned.
fn handle_notification(
    pool: &sqlx::PgPool,
    payload: &str,
    cache: &im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>>,
    cache_fill_sender: &broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
) -> Option<PersistentOutboxEvent<P>> {
    // Minimal header view of the payload: just enough to decide what to do.
    #[derive(serde::Deserialize)]
    struct NotificationHeader {
        sequence: EventSequence,
        #[serde(default)]
        payload_omitted: bool,
    }
    let header: NotificationHeader = serde_json::from_str(payload).ok()?;
    if !header.payload_omitted {
        // Payload is inline: re-parse the same string as the full event.
        return serde_json::from_str(payload).ok();
    }
    // Payload omitted: fetch from the database unless we already have it.
    if !cache.contains_key(&header.sequence) {
        tokio::spawn(Self::fetch_event_by_sequence(
            pool.clone(),
            header.sequence,
            cache_fill_sender.clone(),
        ));
    }
    None
}
/// Start the supervised task that owns and maintains the in-memory cache.
///
/// The loop multiplexes three inputs with `biased` (source-order) priority:
/// 1. backfill requests from handles — answered on a spawned task against an
///    immutable snapshot of the cache;
/// 2. cache-fill events (producers, gap fills, backfills) — inserted into the
///    cache and broadcast in contiguous sequence order;
/// 3. Postgres LISTEN/NOTIFY payloads — parsed and inserted likewise.
///
/// After each iteration it (a) starts a database gap fill when the next
/// sequence to broadcast is known to exist but missing from the cache,
/// debounced to once per second per stuck position, and (b) trims the cache
/// back toward `low_water` entries once it exceeds `high_water`.
///
/// The `Result` return type mirrors the caller's error flow; nothing fallible
/// happens here before the task is spawned.
#[allow(clippy::too_many_arguments)]
async fn spawn_cache_loop(
    pool: &sqlx::PgPool,
    config: &MailboxConfig,
    persistent_event_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    highest_known_sequence: Arc<AtomicU64>,
    mut backfill_request: mpsc::UnboundedReceiver<(
        EventSequence,
        mpsc::Sender<Arc<PersistentOutboxEvent<P>>>,
    )>,
    mut cache_fill_receiver: broadcast::Receiver<Arc<PersistentOutboxEvent<P>>>,
    cache_fill_sender: broadcast::Sender<Arc<PersistentOutboxEvent<P>>>,
    mut notification_receiver: mpsc::Receiver<sqlx::postgres::PgNotification>,
) -> Result<OwnedTaskHandle, sqlx::Error> {
    let pool = pool.clone();
    let cache_size = config.event_cache_size;
    // Hysteresis band for trimming: grow to high_water, then cut to low_water.
    // NOTE(review): assumes event_cache_trim_percent < 100, otherwise the
    // low_water subtraction underflows usize — confirm config validation.
    let high_water = cache_size * (100 + config.event_cache_trim_percent as usize) / 100;
    let low_water = cache_size * (100 - config.event_cache_trim_percent as usize) / 100;
    let initial_sequence = EventSequence::from(highest_known_sequence.load(Ordering::Relaxed));
    let handle = spawn_supervised("obix::persistent_cache_loop", async move {
        // Persistent (immutable) map: cloning it for backfill snapshots is cheap.
        let mut persistent_cache: im::OrdMap<EventSequence, Arc<PersistentOutboxEvent<P>>> =
            im::OrdMap::new();
        let mut last_broadcast_sequence = initial_sequence;
        // Debounce marker for gap fills: (position we were stuck at, start time).
        let mut gap_fill_in_progress_for: Option<(EventSequence, std::time::Instant)> = None;
        loop {
            tokio::select! {
                // Arms are polled in source order: backfill requests first,
                // then cache fills, then notifications.
                biased;
                result = backfill_request.recv() => {
                    match result {
                        Some((start_after, sender)) => {
                            // Serve the request off-loop against a snapshot so
                            // a slow requester never blocks the cache loop.
                            let cache_snapshot = persistent_cache.clone();
                            let highest = EventSequence::from(
                                highest_known_sequence.load(Ordering::Relaxed)
                            );
                            tokio::spawn(Self::handle_backfill_request(
                                pool.clone(),
                                start_after,
                                sender,
                                cache_snapshot,
                                cache_fill_sender.clone(),
                                highest,
                                cache_size,
                            ));
                        }
                        None => {
                            record_backfill_channel_closed();
                            break;
                        }
                    }
                    // Nothing changed in the cache; skip the housekeeping below.
                    continue;
                }
                result = cache_fill_receiver.recv() => {
                    match result {
                        Ok(event) => {
                            (persistent_cache, last_broadcast_sequence) =
                                Self::insert_into_cache_and_maybe_broadcast(
                                    persistent_cache,
                                    event,
                                    &highest_known_sequence,
                                    &persistent_event_sender,
                                    last_broadcast_sequence,
                                    cache_size,
                                );
                            // Drain anything already queued so housekeeping
                            // runs once per batch rather than per event.
                            while let Ok(event) = cache_fill_receiver.try_recv() {
                                (persistent_cache, last_broadcast_sequence) =
                                    Self::insert_into_cache_and_maybe_broadcast(
                                        persistent_cache,
                                        event,
                                        &highest_known_sequence,
                                        &persistent_event_sender,
                                        last_broadcast_sequence,
                                        cache_size,
                                    );
                            }
                        }
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            // We dropped `n` fill events; the gap-fill pass
                            // below will re-fetch whatever we now miss.
                            record_cache_fill_lagged(
                                n,
                                u64::from(last_broadcast_sequence),
                                highest_known_sequence.load(Ordering::Relaxed),
                            );
                            continue;
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            record_cache_fill_closed();
                            break;
                        }
                    }
                }
                result = notification_receiver.recv() => {
                    match result {
                        Some(notification) => {
                            if let Some(event) = Self::handle_notification(
                                &pool,
                                notification.payload(),
                                &persistent_cache,
                                &cache_fill_sender,
                            ) {
                                (persistent_cache, last_broadcast_sequence) =
                                    Self::insert_into_cache_and_maybe_broadcast(
                                        persistent_cache,
                                        Arc::new(event),
                                        &highest_known_sequence,
                                        &persistent_event_sender,
                                        last_broadcast_sequence,
                                        cache_size,
                                    );
                            }
                            // Batch-drain pending notifications as well.
                            while let Ok(notification) = notification_receiver.try_recv() {
                                if let Some(event) = Self::handle_notification(
                                    &pool,
                                    notification.payload(),
                                    &persistent_cache,
                                    &cache_fill_sender,
                                ) {
                                    (persistent_cache, last_broadcast_sequence) =
                                        Self::insert_into_cache_and_maybe_broadcast(
                                            persistent_cache,
                                            Arc::new(event),
                                            &highest_known_sequence,
                                            &persistent_event_sender,
                                            last_broadcast_sequence,
                                            cache_size,
                                        );
                                }
                            }
                        }
                        None => {
                            record_notification_channel_closed();
                            break;
                        }
                    }
                }
            }
            // Housekeeping 1: if the next sequence to broadcast exists per the
            // high-water mark but is not cached, fetch it from the database —
            // retrying at most once per second for the same stuck position.
            let next_needed = last_broadcast_sequence.next();
            let highest = highest_known_sequence.load(Ordering::Relaxed);
            if u64::from(next_needed) <= highest && !persistent_cache.contains_key(&next_needed)
            {
                let should_fill = match gap_fill_in_progress_for {
                    Some((seq, started)) if seq == last_broadcast_sequence => {
                        // Still stuck at the same spot: retry only after 1s.
                        started.elapsed() > std::time::Duration::from_secs(1)
                    }
                    _ => true,
                };
                if should_fill {
                    gap_fill_in_progress_for =
                        Some((last_broadcast_sequence, std::time::Instant::now()));
                    tokio::spawn(Self::fill_gap(
                        pool.clone(),
                        last_broadcast_sequence,
                        cache_fill_sender.clone(),
                        cache_size,
                    ));
                }
            } else {
                // Either fully caught up or the gap closed; clear the debounce.
                gap_fill_in_progress_for = None;
            }
            // Housekeeping 2: trim the oldest entries once the cache exceeds
            // the high-water mark, keeping roughly the newest low_water entries.
            if persistent_cache.len() > high_water {
                let to_remove = persistent_cache.len() - low_water;
                // NOTE(review): im::OrdMap::split excludes the split key from
                // both halves, so this drops to_remove + 1 entries — confirm
                // the off-by-one is acceptable.
                if let Some((&split_key, _)) = persistent_cache.iter().nth(to_remove) {
                    let (_, right) = persistent_cache.split(&split_key);
                    persistent_cache = right;
                }
            }
        }
    });
    Ok(OwnedTaskHandle::new(handle))
}
}
/// Warn-span: a gap was observed between the last broadcast sequence and the
/// next sequence present in the cache. `tracing::instrument` records the
/// arguments as span fields; the empty body is intentional.
#[tracing::instrument(name = "obix.persistent_cache.sequence_gap", level = "warn")]
fn record_sequence_gap(last_broadcast_sequence: u64, next_in_cache: u64, highest_known: u64) {}
/// Warn-span: a live broadcast send failed because no subscribers are attached.
#[tracing::instrument(name = "obix.persistent_cache.no_receivers", level = "warn")]
fn record_no_receivers(sequence: u64) {}
/// Warn-span: a database page load during a backfill failed. `skip_all` plus
/// explicit `fields` records the error via Display instead of Debug-capturing
/// the arguments.
#[tracing::instrument(
    name = "obix.persistent_cache.backfill_failed",
    level = "warn",
    skip_all,
    fields(error = %error, current_sequence = current_sequence),
)]
fn record_backfill_failed(error: &sqlx::Error, current_sequence: u64) {}
/// Error-span: every backfill request sender was dropped; the cache loop exits.
#[tracing::instrument(
    name = "obix.persistent_cache.backfill_channel_closed",
    level = "error",
    fields(otel.status_code = "ERROR"),
)]
fn record_backfill_channel_closed() {}
/// Error-span: the cache-fill broadcast receiver lagged and `dropped` events
/// were lost; the loop relies on gap filling to recover them.
#[tracing::instrument(
    name = "obix.persistent_cache.cache_fill_lagged",
    level = "error",
    fields(otel.status_code = "ERROR"),
)]
fn record_cache_fill_lagged(dropped: u64, last_broadcast_sequence: u64, highest_known: u64) {}
/// Error-span: every cache-fill sender was dropped; the cache loop exits.
#[tracing::instrument(
    name = "obix.persistent_cache.cache_fill_closed",
    level = "error",
    fields(otel.status_code = "ERROR"),
)]
fn record_cache_fill_closed() {}
/// Error-span: the Postgres notification channel closed; the cache loop exits.
#[tracing::instrument(
    name = "obix.persistent_cache.notification_channel_closed",
    level = "error",
    fields(otel.status_code = "ERROR"),
)]
fn record_notification_channel_closed() {}