use std::{
ops::Bound,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
thread::{self, JoinHandle},
time::Duration,
};
use reifydb_core::{
CommitVersion, EncodedKey, Result,
interface::{
Cdc, CdcChange, CdcConsumerId, CdcQueryTransaction, CommandTransaction, Engine as EngineInterface, Key,
KeyKind, MultiVersionCommandTransaction,
},
key::{CdcConsumerKey, EncodableKey},
};
use reifydb_engine::StandardEngine;
use reifydb_sub_api::Priority;
use tracing::{debug, error};
use crate::{CdcCheckpoint, CdcConsume, CdcConsumer};
/// Settings for a [`PollConsumer`].
#[derive(Debug, Clone)]
pub struct PollConsumerConfig {
/// Identifies the consumer; the checkpoint key is encoded from it and it is included in log output.
pub consumer_id: CdcConsumerId,
/// How long the polling loop sleeps after each poll attempt.
pub poll_interval: Duration,
/// Priority assigned to the consumer; `new` sets it to `Priority::Normal`.
/// NOTE(review): not read anywhere in this file — presumably consumed by the scheduler; confirm.
pub priority: Priority,
/// Optional cap on how many versions past the checkpoint one batch may span; `None` means no cap.
pub max_batch_size: Option<u64>,
}
impl PollConsumerConfig {
    /// Builds a configuration with the default priority, `Priority::Normal`.
    ///
    /// Use [`PollConsumerConfig::with_priority`] to pick a different priority.
    pub fn new(consumer_id: CdcConsumerId, poll_interval: Duration, max_batch_size: Option<u64>) -> Self {
        let priority = Priority::Normal;
        Self {
            consumer_id,
            poll_interval,
            priority,
            max_batch_size,
        }
    }

    /// Returns the same configuration with `priority` replaced.
    pub fn with_priority(self, priority: Priority) -> Self {
        Self {
            priority,
            ..self
        }
    }
}
/// CDC consumer that polls the engine for new changes on a dedicated worker thread.
pub struct PollConsumer<F: CdcConsume> {
/// Engine handle; taken out by `start()` and moved into the worker thread.
engine: Option<StandardEngine>,
/// User-supplied consume callback; taken out by `start()` and moved into the worker thread.
consumer: Option<Box<F>>,
/// Polling settings; cloned into the worker thread on start.
config: PollConsumerConfig,
/// State shared between this handle and the worker thread (checkpoint key and running flag).
state: Arc<ConsumerState>,
/// Join handle of the worker thread; set by `start()` and joined by `stop()`.
worker: Option<JoinHandle<()>>,
}
/// State shared between a `PollConsumer` handle and its worker thread.
struct ConsumerState {
/// Encoded `CdcConsumerKey` under which this consumer's checkpoint is fetched and persisted.
consumer_key: EncodedKey,
/// `true` while the polling loop should keep running; the loop exits once it reads `false`.
running: AtomicBool,
}
impl<C: CdcConsume> PollConsumer<C> {
pub fn new(config: PollConsumerConfig, engine: StandardEngine, consume: C) -> Self {
let consumer_key = CdcConsumerKey {
consumer: config.consumer_id.clone(),
}
.encode();
Self {
engine: Some(engine),
consumer: Some(Box::new(consume)),
config,
state: Arc::new(ConsumerState {
consumer_key,
running: AtomicBool::new(false),
}),
worker: None,
}
}
fn consume_batch(
state: &ConsumerState,
engine: &StandardEngine,
consumer: &C,
max_batch_size: Option<u64>,
) -> Result<Option<(CommitVersion, u64)>> {
let current_version = engine.current_version()?;
let target_version = CommitVersion(current_version.0.saturating_sub(1));
let _ = engine.try_wait_for_watermark(target_version, Duration::from_millis(50));
let safe_version = engine.done_until();
let mut transaction = engine.begin_command()?;
let checkpoint = CdcCheckpoint::fetch(&mut transaction, &state.consumer_key)?;
if safe_version <= checkpoint {
transaction.rollback()?;
return Ok(None);
}
let transactions = fetch_cdcs_until(&mut transaction, checkpoint, safe_version, max_batch_size)?;
if transactions.is_empty() {
transaction.rollback()?;
return Ok(None);
}
let latest_version = transactions.iter().map(|tx| tx.version).max().unwrap_or(checkpoint);
let relevant_transactions = transactions
.into_iter()
.filter(|tx| {
tx.changes.iter().any(|change| match &change.change {
CdcChange::Insert {
key,
..
}
| CdcChange::Update {
key,
..
}
| CdcChange::Delete {
key,
..
} => {
if let Some(kind) = Key::kind(key) {
matches!(
kind,
KeyKind::Row
| KeyKind::Flow | KeyKind::FlowNode
| KeyKind::FlowNodeByFlow | KeyKind::FlowEdge
| KeyKind::FlowEdgeByFlow | KeyKind::NamespaceFlow
)
} else {
false
}
}
})
})
.collect::<Vec<_>>();
if !relevant_transactions.is_empty() {
consumer.consume(&mut transaction, relevant_transactions)?;
}
CdcCheckpoint::persist(&mut transaction, &state.consumer_key, latest_version)?;
let current_version = transaction.commit()?;
let lag = current_version.0.saturating_sub(latest_version.0);
Ok(Some((latest_version, lag)))
}
fn polling_loop(
config: &PollConsumerConfig,
engine: StandardEngine,
consumer: Box<C>,
state: Arc<ConsumerState>,
) {
debug!("[Consumer {:?}] Started polling with interval {:?}", config.consumer_id, config.poll_interval);
while state.running.load(Ordering::Acquire) {
match Self::consume_batch(&state, &engine, &consumer, config.max_batch_size) {
Ok(Some((processed_version, lag))) => {
debug!("processed {} with lag {}", processed_version, lag);
thread::sleep(config.poll_interval);
}
Ok(None) => {
thread::sleep(config.poll_interval);
}
Err(error) => {
error!("[Consumer {:?}] Error consuming events: {}", config.consumer_id, error);
thread::sleep(config.poll_interval);
}
}
}
debug!("[Consumer {:?}] Stopped", config.consumer_id);
}
}
impl<F: CdcConsume> CdcConsumer for PollConsumer<F> {
    /// Spawns the worker thread that runs the polling loop.
    ///
    /// # Panics
    ///
    /// Panics if a worker handle is already present, or if the engine or consumer has already
    /// been moved out by an earlier start.
    fn start(&mut self) -> Result<()> {
        assert!(self.worker.is_none(), "start() can only be called once");

        // Raise the flag before spawning so the loop sees `running == true` on its first check.
        let already_running = self.state.running.swap(true, Ordering::AcqRel);
        if already_running {
            return Ok(());
        }

        let engine = self.engine.take().expect("engine already consumed");
        let consumer = self.consumer.take().expect("consumer already consumed");
        let shared = Arc::clone(&self.state);
        let settings = self.config.clone();

        let handle = thread::spawn(move || Self::polling_loop(&settings, engine, consumer, shared));
        self.worker = Some(handle);
        Ok(())
    }

    /// Clears the running flag and waits for the worker thread to finish.
    ///
    /// Does nothing when the consumer was not running.
    ///
    /// # Panics
    ///
    /// Panics if joining the worker thread fails (i.e. the worker itself panicked).
    fn stop(&mut self) -> Result<()> {
        let was_running = self.state.running.swap(false, Ordering::AcqRel);
        if was_running {
            if let Some(handle) = self.worker.take() {
                handle.join().expect("Failed to join consumer thread");
            }
        }
        Ok(())
    }

    /// Reports the current value of the shared running flag.
    fn is_running(&self) -> bool {
        self.state.running.load(Ordering::Acquire)
    }
}
/// Collects the CDC entries with versions in `(since_version, last]`.
///
/// `last` is `until_version`, or `since_version + max_batch_size` when a batch size is given and
/// that value is smaller. The addition saturates instead of overflowing.
///
/// NOTE(review): `Some(0)` collapses the range to `(since_version, since_version]`, which is
/// empty, so a consumer configured that way would presumably never advance — confirm callers
/// never pass 0.
fn fetch_cdcs_until(
    txn: &mut impl CommandTransaction,
    since_version: CommitVersion,
    until_version: CommitVersion,
    max_batch_size: Option<u64>,
) -> Result<Vec<Cdc>> {
    let last_version = max_batch_size.map_or(until_version, |size| {
        CommitVersion(since_version.0.saturating_add(size)).min(until_version)
    });
    txn.with_cdc_query(|cdc| {
        let entries = cdc.range(Bound::Excluded(since_version), Bound::Included(last_version))?;
        Ok(entries.collect::<Vec<_>>())
    })
}