use std::collections::{BTreeMap, HashMap};
use sqlx::{PgPool, Row};
use crate::{EventStore, EventStoreError, RecordedEvent};
/// Read position in the global log, as `(transaction_id, global_position)`
/// (see `cursor_of`).
///
/// Tuples compare lexicographically, so cursors order by transaction id first
/// and by global position second.
type Cursor = (u64, i64);
/// Returns the cursor one global position earlier within the same transaction.
///
/// Only the position component is decremented; the transaction id is carried
/// over unchanged.
fn just_before(c: Cursor) -> Cursor {
    let (transaction_id, position) = c;
    (transaction_id, position - 1)
}
fn cursor_of(e: &RecordedEvent) -> Cursor {
(e.transaction_id, e.global_position)
}
/// A polling subscription over the event store that releases events to the
/// caller in per-stream version order and tracks a checkpoint cursor in the
/// `es_subscriptions` table.
pub struct Subscription {
/// Subscription identifier; the `subscription_id` key in `es_subscriptions`.
id: String,
/// Event store the raw batches are read from.
store: EventStore,
/// Pool used for the checkpoint row and for stream-version lookups in `es_events`.
db: PgPool,
/// Cursor of the last raw event read; the next read asks for events after it.
poll_cursor: Cursor,
/// Passed as the final argument of every store read (presumably a row limit — confirm
/// against `EventStore`).
batch_size: i64,
/// Cursor the session started from (the stored checkpoint, or `(0, 0)` after `reset`);
/// used as the upper bound when looking up a stream's baseline version.
session_start: Cursor,
/// Last delivered `stream_version` per stream id.
last_delivered: HashMap<String, i64>,
/// Events whose version is ahead of the next expected one for their stream,
/// keyed (and therefore ordered) by cursor.
held: BTreeMap<Cursor, RecordedEvent>,
/// Cursors of the events returned by the most recent poll, in delivery order.
last_batch_cursors: Vec<Cursor>,
}
impl Subscription {
pub async fn create(
store: EventStore,
db: PgPool,
id: impl Into<String>,
batch_size: i64,
) -> Result<Self, EventStoreError> {
let id = id.into();
sqlx::query(
"INSERT INTO es_subscriptions (subscription_id, last_position)
VALUES ($1, 0)
ON CONFLICT (subscription_id) DO NOTHING",
)
.bind(&id)
.execute(&db)
.await?;
let row = sqlx::query(
"SELECT last_position, last_transaction_id::text AS last_transaction_id
FROM es_subscriptions WHERE subscription_id = $1",
)
.bind(&id)
.fetch_one(&db)
.await?;
let last_position: i64 = row.get("last_position");
let last_transaction_id: String = row.get("last_transaction_id");
let cursor: Cursor = (last_transaction_id.parse().unwrap_or(0), last_position);
Ok(Self {
id,
store,
db,
poll_cursor: cursor,
batch_size,
session_start: cursor,
last_delivered: HashMap::new(),
held: BTreeMap::new(),
last_batch_cursors: Vec::new(),
})
}
/// Reads the next batch from the store after the current poll cursor and runs
/// it through `gate`, returning the events that are ready for delivery.
///
/// # Errors
/// Propagates any error from the store read or from `gate`.
pub async fn poll(&mut self) -> Result<Vec<RecordedEvent>, EventStoreError> {
    let (after_txid, after_position) = self.poll_cursor;
    let fetched = self
        .store
        .read_all_after(after_txid, after_position, self.batch_size)
        .await?;
    self.gate(fetched).await
}
/// Like `poll`, but asks the store only for events of `category`.
///
/// The read starts after the same poll cursor that `poll` uses, and the result
/// goes through the same `gate` step.
///
/// # Errors
/// Propagates any error from the store read or from `gate`.
pub async fn poll_category(
    &mut self,
    category: &str,
) -> Result<Vec<RecordedEvent>, EventStoreError> {
    let (after_txid, after_position) = self.poll_cursor;
    let fetched = self
        .store
        .read_category_after(category, after_txid, after_position, self.batch_size)
        .await?;
    self.gate(fetched).await
}
/// Filters a freshly read batch so that each stream's events are released in
/// strictly consecutive version order.
///
/// For every event, compared against the last delivered version of its stream:
/// * exactly the next version: delivered, together with any held successors;
/// * a later version: parked in `held` until the gap is filled;
/// * an already delivered version: rejected with `OrderingViolation`.
///
/// On success the poll cursor points at the last event of `raw` (in the order the
/// store returned it) and `last_batch_cursors` lists the returned events.
///
/// # Errors
/// * Any error from the baseline-version lookup. The poll cursor is left
///   untouched in that case, so the same batch is read again on the next poll.
/// * `OrderingViolation` when an event's version is not above the last delivered
///   one. Events delivered earlier in the same batch are dropped with the error,
///   although their versions have already been recorded as delivered.
///
/// # Panics
/// Panics if a stream in the batch has no baseline entry after
/// `init_unknown_streams`, which would be a bug in that method.
async fn gate(
    &mut self,
    raw: Vec<RecordedEvent>,
) -> Result<Vec<RecordedEvent>, EventStoreError> {
    // Remember where this read ended, but only commit it once the fallible
    // lookup below has succeeded; advancing first would skip the whole batch
    // if that query failed.
    let read_up_to = raw.last().map(cursor_of);
    let mut batch = raw;
    batch.sort_by_key(|e| e.global_position);
    self.init_unknown_streams(&batch).await?;
    if let Some(cursor) = read_up_to {
        self.poll_cursor = cursor;
    }
    let mut out: Vec<RecordedEvent> = Vec::with_capacity(batch.len());
    for event in batch {
        let last = *self
            .last_delivered
            .get(&event.stream_id)
            .expect("init_unknown_streams populated every batch stream");
        let expected = last + 1;
        if event.stream_version == expected {
            self.deliver(event, &mut out);
        } else if event.stream_version > expected {
            // A predecessor is still missing; park the event until it shows up.
            self.held.insert(cursor_of(&event), event);
        } else {
            return Err(EventStoreError::OrderingViolation {
                subscription_id: self.id.clone(),
                stream_id: event.stream_id.clone(),
                version: event.stream_version,
                last_delivered: last,
            });
        }
    }
    self.last_batch_cursors = out.iter().map(cursor_of).collect();
    Ok(out)
}
/// Appends `event` to `out`, then drains from `held` every event of the same
/// stream that continues the version sequence without a gap.
///
/// The version of the last event appended becomes the stream's entry in
/// `last_delivered`.
fn deliver(&mut self, event: RecordedEvent, out: &mut Vec<RecordedEvent>) {
    let stream = event.stream_id.clone();
    let mut version = event.stream_version;
    out.push(event);
    loop {
        // Look for the held event that directly follows the one just appended.
        let successor = self.held.iter().find_map(|(cursor, parked)| {
            let continues = parked.stream_id == stream && parked.stream_version == version + 1;
            continues.then_some(*cursor)
        });
        let Some(key) = successor else {
            break;
        };
        let released = self.held.remove(&key).expect("key just found");
        version = released.stream_version;
        out.push(released);
    }
    self.last_delivered.insert(stream, version);
}
/// Ensures `last_delivered` has an entry for every stream appearing in `batch`.
///
/// For streams not yet tracked, one query fetches the highest `stream_version`
/// recorded in `es_events` at or before the session-start cursor. Streams with
/// no such event get a baseline of 0.
///
/// # Errors
/// Returns an error if the query fails; `last_delivered` is not modified then.
async fn init_unknown_streams(
    &mut self,
    batch: &[RecordedEvent],
) -> Result<(), EventStoreError> {
    let mut unseen: Vec<&str> = Vec::new();
    for event in batch {
        let stream = event.stream_id.as_str();
        if !self.last_delivered.contains_key(stream) {
            unseen.push(stream);
        }
    }
    unseen.sort_unstable();
    unseen.dedup();
    if unseen.is_empty() {
        return Ok(());
    }
    let stream_ids: Vec<String> = unseen.iter().map(|&s| s.to_owned()).collect();
    let (start_txid, start_position) = self.session_start;
    let baselines: Vec<(String, i64)> = sqlx::query_as(
        "SELECT stream_id, max(stream_version)
FROM es_events
WHERE stream_id = ANY($1)
AND ((transaction_id = $2::text::xid8 AND global_position <= $3)
OR transaction_id < $2::text::xid8)
GROUP BY stream_id",
    )
    .bind(stream_ids)
    .bind(start_txid.to_string())
    .bind(start_position)
    .fetch_all(&self.db)
    .await?;
    // Default every new stream to 0 first, then let the query results override.
    self.last_delivered
        .extend(unseen.iter().map(|&s| (s.to_owned(), 0)));
    self.last_delivered.extend(baselines);
    Ok(())
}
pub fn checkpoint_cursor_after(&self, handled: usize) -> Cursor {
let mut safe = self.poll_cursor;
if let Some((first_held, _)) = self.held.iter().next() {
safe = safe.min(just_before(*first_held));
}
if let Some(min_rest) = self.last_batch_cursors.get(handled..).and_then(|rest| {
rest.iter().min().copied()
}) {
safe = safe.min(just_before(min_rest));
}
safe
}
/// Persists the checkpoint cursor, treating the whole last batch as handled.
///
/// The update only matches a row whose `fence_token` is 0. The number of
/// affected rows is not inspected, so a non-matching row still yields `Ok(())`.
///
/// # Errors
/// Returns an error if the UPDATE statement fails.
pub async fn checkpoint(&self) -> Result<(), EventStoreError> {
    let handled = self.last_batch_cursors.len();
    let (txid, pos) = self.checkpoint_cursor_after(handled);
    let update = sqlx::query(
        "UPDATE es_subscriptions
SET last_position = $1, last_transaction_id = $2::text::xid8, updated_at = now()
WHERE subscription_id = $3 AND fence_token = 0",
    )
    .bind(pos)
    .bind(txid.to_string())
    .bind(&self.id);
    update.execute(&self.db).await?;
    Ok(())
}
/// Rewinds the subscription to the beginning of the log, both in the database
/// and in memory.
///
/// The stored checkpoint is zeroed first; in-memory state (cursors, per-stream
/// versions, held events, last batch) is cleared only after that write succeeds.
/// The update only matches a row whose `fence_token` is 0, and the number of
/// affected rows is not inspected.
///
/// # Errors
/// Returns an error if the UPDATE statement fails. In-memory state is left
/// untouched in that case, so the subscription is never half-reset.
pub async fn reset(&mut self) -> Result<(), EventStoreError> {
    // Persist first: if this fails, the caller gets the error and the in-memory
    // state still matches what was last stored.
    sqlx::query(
        "UPDATE es_subscriptions
SET last_position = 0, last_transaction_id = '0'::xid8, updated_at = now()
WHERE subscription_id = $1 AND fence_token = 0",
    )
    .bind(&self.id)
    .execute(&self.db)
    .await?;
    self.poll_cursor = (0, 0);
    self.session_start = (0, 0);
    self.last_delivered.clear();
    self.held.clear();
    self.last_batch_cursors.clear();
    Ok(())
}
/// Returns the global-position component of the current poll cursor.
pub fn position(&self) -> i64 {
    let (_, global_position) = self.poll_cursor;
    global_position
}
pub fn held_count(&self) -> usize {
self.held.len()
}
}