use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, RecordBatch};
use arrow::compute::concat_batches;
use arrow_schema::SchemaRef;
use async_stream::try_stream;
use chrono::DateTime;
use futures::StreamExt;

use crate::source::mutable::MutableTableRegistry;
use crate::store::mutable::definition::MutableTableId;
use crate::tenant::TenantId;
use crate::trigger::broker::TriggerBroker;
use crate::trigger::error::TriggerError;
use crate::trigger::ids::SubscriptionId;
use crate::trigger::offset::Offset;
use crate::trigger::predicate::Predicate;
use crate::trigger::subscription::{DeliveredBatch, Subscription};
use crate::trigger::topic::{TopicDefinition, OFFSET_COLUMN, PRODUCED_AT_COLUMN, ROW_INDEX_COLUMN};
/// Consumes a trigger topic: replays stored events from the topic's mutable
/// backing table, then follows the broker's live stream.
pub struct Subscriber {
    /// Source of live events for a topic.
    broker: Arc<dyn TriggerBroker>,
    /// Registry used to scan a topic's backing table and to read the
    /// currently bound tenant.
    mutable: Arc<MutableTableRegistry>,
}
impl Subscriber {
    /// Builds a subscriber from a live-event `broker` and the `mutable` table
    /// registry that holds each topic's backing table.
    pub fn new(broker: Arc<dyn TriggerBroker>, mutable: Arc<MutableTableRegistry>) -> Self {
        Self { broker, mutable }
    }

    /// Subscribes to `topic` as the tenant currently bound on the mutable
    /// registry; otherwise identical to [`Subscriber::subscribe_scoped`].
    ///
    /// # Errors
    ///
    /// Same as [`Subscriber::subscribe_scoped`].
    pub async fn subscribe(
        &self,
        topic: &TopicDefinition,
        predicate: Predicate,
        from_offset: Option<Offset>,
    ) -> Result<Subscription, TriggerError> {
        let tenant = self.mutable.binding().current_tenant();
        self.subscribe_scoped(topic, tenant, predicate, from_offset).await
    }

    /// Subscribes to `topic` on behalf of an explicit `tenant`.
    ///
    /// With `from_offset = Some(_)`, events stored in the topic's backing table
    /// from that offset onward are scanned, filtered through `predicate`, and
    /// yielded first. The broker's live stream follows. Any live event whose
    /// offset is not strictly greater than the highest offset yielded so far is
    /// skipped, so an event present in both the replay and the live stream is
    /// delivered once.
    ///
    /// NOTE(review): the live subscription is opened only after the replay scan
    /// has completed. Events appended in between are covered only if the broker
    /// honours `from_offset` for the live stream — confirm against the broker.
    ///
    /// # Errors
    ///
    /// Fails if the backing-table lookup, the replay scan, predicate evaluation,
    /// or the broker subscription fails. Errors produced later by the live
    /// stream are yielded through the returned [`Subscription`] instead.
    pub async fn subscribe_scoped(
        &self,
        topic: &TopicDefinition,
        tenant: Option<TenantId>,
        predicate: Predicate,
        from_offset: Option<Offset>,
    ) -> Result<Subscription, TriggerError> {
        let replay_delivered = self
            .drain_replay(topic, tenant, &predicate, from_offset)
            .await?;
        let last_replayed = replay_delivered.iter().map(|d| d.offset.value()).max();
        let mut live = self
            .broker
            .subscribe(topic.id, predicate, from_offset)
            .await?;
        let stream = try_stream! {
            for delivered in replay_delivered {
                yield delivered;
            }
            // Only forward live events that move past everything already yielded;
            // this drops the overlap between the replay and the live stream.
            let mut last_yielded = last_replayed;
            while let Some(item) = live.next().await {
                let delivered = item?;
                let offset = delivered.offset.value();
                if last_yielded.is_none_or(|seen| offset > seen) {
                    last_yielded = Some(offset);
                    yield delivered;
                }
            }
        };
        Ok(Subscription::new(SubscriptionId::new(), Box::pin(stream)))
    }

    /// Returns the stored events for `topic` from `from_offset` onward that pass
    /// `predicate`, for the tenant currently bound on the mutable registry,
    /// without opening a live subscription.
    ///
    /// # Errors
    ///
    /// Same as [`Subscriber::replay_only_scoped`].
    pub async fn replay_only(
        &self,
        topic: &TopicDefinition,
        predicate: Predicate,
        from_offset: Option<Offset>,
    ) -> Result<Vec<DeliveredBatch>, TriggerError> {
        let tenant = self.mutable.binding().current_tenant();
        self.replay_only_scoped(topic, tenant, predicate, from_offset)
            .await
    }

    /// Returns the stored events for `topic` from `from_offset` onward that pass
    /// `predicate`, on behalf of an explicit `tenant`. No live subscription is
    /// opened. Returns an empty list when `from_offset` is `None`.
    ///
    /// # Errors
    ///
    /// Fails if the backing-table lookup, the scan, the batch layout, or
    /// predicate evaluation fails.
    pub async fn replay_only_scoped(
        &self,
        topic: &TopicDefinition,
        tenant: Option<TenantId>,
        predicate: Predicate,
        from_offset: Option<Offset>,
    ) -> Result<Vec<DeliveredBatch>, TriggerError> {
        self.drain_replay(topic, tenant, &predicate, from_offset)
            .await
    }

    /// Scans the topic's backing table from `from_offset` onward, regroups the
    /// rows into per-offset events, and keeps those that pass `predicate`.
    ///
    /// The backing-table id is validated even when `from_offset` is `None`. In
    /// that case nothing is scanned and an empty list is returned.
    async fn drain_replay(
        &self,
        topic: &TopicDefinition,
        tenant: Option<TenantId>,
        predicate: &Predicate,
        from_offset: Option<Offset>,
    ) -> Result<Vec<DeliveredBatch>, TriggerError> {
        let backing_id = MutableTableId::new(topic.backing_table_name())
            .map_err(|e| TriggerError::Catalog(e.to_string()))?;
        let user_schema = Arc::clone(&topic.schema);
        let replay_batches = match from_offset {
            Some(off) => {
                // Scan "after" one below the requested offset so that offset itself is
                // included. Stored offsets are Int64: a requested offset beyond i64::MAX
                // lies past every stored row, so clamp (scan after i64::MAX, i.e. nothing)
                // instead of letting an `as` cast wrap negative and replay the whole table.
                let scan_after_value =
                    i64::try_from(off.value()).map_or(i64::MAX, |v| v.saturating_sub(1));
                let mut stream = self
                    .mutable
                    .scan_after_for_tenant(&backing_id, scan_after_value, tenant)
                    .await
                    .map_err(TriggerError::BackingTable)?;
                let mut batches: Vec<RecordBatch> = Vec::new();
                while let Some(batch) = stream.next().await {
                    batches.push(batch.map_err(TriggerError::BackingTable)?);
                }
                batches
            }
            None => Vec::new(),
        };
        let replay_events = group_replay_batches(&replay_batches, &user_schema)?;
        let mut delivered: Vec<DeliveredBatch> = Vec::with_capacity(replay_events.len());
        for event in replay_events {
            // `evaluate` yields `None` when the predicate rejects the event entirely.
            if let Some(filtered) = predicate.evaluate(&event.batch)? {
                delivered.push(DeliveredBatch {
                    offset: event.offset,
                    produced_at: event.produced_at,
                    batch: filtered,
                });
            }
        }
        Ok(delivered)
    }
}
struct ReplayEvent {
offset: Offset,
produced_at: chrono::DateTime<chrono::Utc>,
batch: RecordBatch,
}
fn group_replay_batches(
batches: &[RecordBatch],
user_schema: &SchemaRef,
) -> Result<Vec<ReplayEvent>, TriggerError> {
let mut events: Vec<ReplayEvent> = Vec::new();
let user_field_count = user_schema.fields().len();
for batch in batches {
let offset_idx = batch
.schema()
.index_of(OFFSET_COLUMN)
.map_err(|_| TriggerError::Catalog("backing table missing _offset".into()))?;
let row_idx_idx = batch
.schema()
.index_of(ROW_INDEX_COLUMN)
.map_err(|_| TriggerError::Catalog("backing table missing _row_idx".into()))?;
let produced_idx = batch
.schema()
.index_of(PRODUCED_AT_COLUMN)
.map_err(|_| TriggerError::Catalog("backing table missing _produced_at".into()))?;
let offsets = batch
.column(offset_idx)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| TriggerError::Catalog("_offset column must be Int64".into()))?;
let _row_indices = batch
.column(row_idx_idx)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| TriggerError::Catalog("_row_idx column must be Int64".into()))?;
let produced = batch
.column(produced_idx)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| TriggerError::Catalog("_produced_at column must be Int64".into()))?;
let mut user_indices: Vec<usize> = Vec::with_capacity(user_field_count);
for f in user_schema.fields() {
let i = batch.schema().index_of(f.name()).map_err(|_| {
TriggerError::Catalog(format!("backing table missing topic column '{}'", f.name()))
})?;
user_indices.push(i);
}
let mut start = 0usize;
while start < batch.num_rows() {
let off = offsets.value(start);
let mut end = start + 1;
while end < batch.num_rows() && offsets.value(end) == off {
end += 1;
}
let slice_len = end - start;
let produced_at_micros = produced.value(start);
let produced_at =
DateTime::from_timestamp_micros(produced_at_micros).ok_or_else(|| {
TriggerError::Catalog(format!(
"_produced_at out of range: {produced_at_micros}"
))
})?;
let columns: Vec<ArrayRef> = user_indices
.iter()
.map(|&i| batch.column(i).slice(start, slice_len))
.collect();
let event_batch = RecordBatch::try_new(Arc::clone(user_schema), columns)
.map_err(|e| TriggerError::Catalog(e.to_string()))?;
events.push(ReplayEvent {
offset: Offset::new(off as u64, produced_at),
produced_at,
batch: event_batch,
});
start = end;
}
}
Ok(events)
}