1use std::{
5 ops::Bound,
6 sync::{
7 Arc,
8 atomic::{AtomicBool, Ordering},
9 },
10 thread::{self, JoinHandle},
11 time::Duration,
12};
13
14use reifydb_core::{
15 CommitVersion, EncodedKey, Result,
16 interface::{
17 Cdc, CdcChange, CdcConsumerId, CdcQueryTransaction, CommandTransaction, Engine as EngineInterface, Key,
18 KeyKind, MultiVersionCommandTransaction,
19 },
20 key::{CdcConsumerKey, EncodableKey},
21};
22use reifydb_engine::StandardEngine;
23use reifydb_sub_api::Priority;
24use tracing::{debug, error};
25
26use crate::{CdcCheckpoint, CdcConsume, CdcConsumer};
27
28#[derive(Debug, Clone)]
30pub struct PollConsumerConfig {
31 pub consumer_id: CdcConsumerId,
33 pub poll_interval: Duration,
35 pub priority: Priority,
37 pub max_batch_size: Option<u64>,
39}
40
41impl PollConsumerConfig {
42 pub fn new(consumer_id: CdcConsumerId, poll_interval: Duration, max_batch_size: Option<u64>) -> Self {
43 Self {
44 consumer_id,
45 poll_interval,
46 priority: Priority::Normal,
47 max_batch_size,
48 }
49 }
50
51 pub fn with_priority(mut self, priority: Priority) -> Self {
52 self.priority = priority;
53 self
54 }
55}
56
57pub struct PollConsumer<F: CdcConsume> {
58 engine: Option<StandardEngine>,
59 consumer: Option<Box<F>>,
60 config: PollConsumerConfig,
61 state: Arc<ConsumerState>,
62 worker: Option<JoinHandle<()>>,
63}
64
65struct ConsumerState {
66 consumer_key: EncodedKey,
67 running: AtomicBool,
68}
69
70impl<C: CdcConsume> PollConsumer<C> {
71 pub fn new(config: PollConsumerConfig, engine: StandardEngine, consume: C) -> Self {
72 let consumer_key = CdcConsumerKey {
73 consumer: config.consumer_id.clone(),
74 }
75 .encode();
76
77 Self {
78 engine: Some(engine),
79 consumer: Some(Box::new(consume)),
80 config,
81 state: Arc::new(ConsumerState {
82 consumer_key,
83 running: AtomicBool::new(false),
84 }),
85 worker: None,
86 }
87 }
88
89 fn consume_batch(
90 state: &ConsumerState,
91 engine: &StandardEngine,
92 consumer: &C,
93 max_batch_size: Option<u64>,
94 ) -> Result<Option<(CommitVersion, u64)>> {
95 let current_version = engine.current_version()?;
98 let target_version = CommitVersion(current_version.0.saturating_sub(1));
99
100 let _ = engine.try_wait_for_watermark(target_version, Duration::from_millis(50));
102
103 let safe_version = engine.done_until();
105
106 let mut transaction = engine.begin_command()?;
107
108 let checkpoint = CdcCheckpoint::fetch(&mut transaction, &state.consumer_key)?;
109
110 if safe_version <= checkpoint {
112 transaction.rollback()?;
113 return Ok(None);
114 }
115
116 let transactions = fetch_cdcs_until(&mut transaction, checkpoint, safe_version, max_batch_size)?;
118 if transactions.is_empty() {
119 transaction.rollback()?;
120 return Ok(None);
121 }
122
123 let latest_version = transactions.iter().map(|tx| tx.version).max().unwrap_or(checkpoint);
124
125 let relevant_transactions = transactions
128 .into_iter()
129 .filter(|tx| {
130 tx.changes.iter().any(|change| match &change.change {
131 CdcChange::Insert {
132 key,
133 ..
134 }
135 | CdcChange::Update {
136 key,
137 ..
138 }
139 | CdcChange::Delete {
140 key,
141 ..
142 } => {
143 if let Some(kind) = Key::kind(key) {
144 matches!(
145 kind,
146 KeyKind::Row
147 | KeyKind::Flow | KeyKind::FlowNode
148 | KeyKind::FlowNodeByFlow | KeyKind::FlowEdge
149 | KeyKind::FlowEdgeByFlow | KeyKind::NamespaceFlow
150 )
151 } else {
152 false
153 }
154 }
155 })
156 })
157 .collect::<Vec<_>>();
158
159 if !relevant_transactions.is_empty() {
160 consumer.consume(&mut transaction, relevant_transactions)?;
161 }
162
163 CdcCheckpoint::persist(&mut transaction, &state.consumer_key, latest_version)?;
164 let current_version = transaction.commit()?;
165
166 let lag = current_version.0.saturating_sub(latest_version.0);
167
168 Ok(Some((latest_version, lag)))
169 }
170
171 fn polling_loop(
172 config: &PollConsumerConfig,
173 engine: StandardEngine,
174 consumer: Box<C>,
175 state: Arc<ConsumerState>,
176 ) {
177 debug!("[Consumer {:?}] Started polling with interval {:?}", config.consumer_id, config.poll_interval);
178
179 while state.running.load(Ordering::Acquire) {
180 match Self::consume_batch(&state, &engine, &consumer, config.max_batch_size) {
181 Ok(Some((processed_version, lag))) => {
182 debug!("processed {} with lag {}", processed_version, lag);
183 thread::sleep(config.poll_interval);
184 }
185 Ok(None) => {
186 thread::sleep(config.poll_interval);
188 }
189 Err(error) => {
190 error!("[Consumer {:?}] Error consuming events: {}", config.consumer_id, error);
191 thread::sleep(config.poll_interval);
193 }
194 }
195 }
196
197 debug!("[Consumer {:?}] Stopped", config.consumer_id);
198 }
199}
200
201impl<F: CdcConsume> CdcConsumer for PollConsumer<F> {
202 fn start(&mut self) -> Result<()> {
203 assert!(self.worker.is_none(), "start() can only be called once");
204
205 if self.state.running.swap(true, Ordering::AcqRel) {
206 return Ok(());
207 }
208
209 let engine = self.engine.take().expect("engine already consumed");
210
211 let consumer = self.consumer.take().expect("consumer already consumed");
212
213 let state = Arc::clone(&self.state);
214 let config = self.config.clone();
215
216 self.worker = Some(thread::spawn(move || {
217 Self::polling_loop(&config, engine, consumer, state);
218 }));
219
220 Ok(())
221 }
222
223 fn stop(&mut self) -> Result<()> {
224 if !self.state.running.swap(false, Ordering::AcqRel) {
225 return Ok(());
226 }
227
228 if let Some(worker) = self.worker.take() {
229 worker.join().expect("Failed to join consumer thread");
230 }
231
232 Ok(())
233 }
234
235 fn is_running(&self) -> bool {
236 self.state.running.load(Ordering::Acquire)
237 }
238}
239
240fn fetch_cdcs_until(
241 txn: &mut impl CommandTransaction,
242 since_version: CommitVersion,
243 until_version: CommitVersion,
244 max_batch_size: Option<u64>,
245) -> Result<Vec<Cdc>> {
246 let upper_bound = match max_batch_size {
247 Some(size) => {
248 let batch_limit = CommitVersion(since_version.0.saturating_add(size));
249 Bound::Included(batch_limit.min(until_version))
250 }
251 None => Bound::Included(until_version),
252 };
253 txn.with_cdc_query(|cdc| Ok(cdc.range(Bound::Excluded(since_version), upper_bound)?.collect::<Vec<_>>()))
254}