1use anyhow::{anyhow, Result};
16use bytes::Bytes;
17use chrono::Utc;
18use log::{debug, error, info, trace, warn};
19use std::collections::HashMap;
20use std::sync::atomic::Ordering;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::{oneshot, RwLock};
24use tokio::time::{interval, sleep};
25
26use super::connection::{parse_lsn, ReplicationConnection};
27use super::decoder::PgOutputDecoder;
28use super::protocol::BackendMessage;
29use super::types::{StandbyStatusUpdate, WalMessage};
30use super::{PostgresSourceConfig, ReplayState};
31use drasi_core::models::{Element, ElementMetadata, ElementReference, SourceChange};
32use drasi_lib::channels::{ComponentStatus, SourceEvent, SourceEventWrapper};
33use drasi_lib::component_graph::ComponentStatusHandle;
34use drasi_lib::sources::base::SourceBase;
35
36pub struct ReplicationStream {
37 config: PostgresSourceConfig,
38 source_id: String,
39 connection: Option<ReplicationConnection>,
40 decoder: PgOutputDecoder,
41 #[allow(dead_code)]
42 status_handle: ComponentStatusHandle,
43 base: SourceBase,
44 replay_state: Arc<ReplayState>,
45 read_lsn: u64,
46 start_lsn: Option<u64>,
47 last_feedback_time: std::time::Instant,
48 pending_transaction: Option<Vec<(SourceChange, u64)>>,
49 relations: HashMap<u32, RelationMapping>,
50 table_primary_keys: Arc<RwLock<HashMap<String, Vec<String>>>>,
51}
52
/// Per-relation metadata captured from a pgoutput `Relation` message.
struct RelationMapping {
    /// Graph label applied to every element produced from this relation's rows.
    label: String,
    /// Unqualified table name; stored but not read anywhere yet (hence the allow).
    #[allow(dead_code)]
    table_name: String,
    /// Schema/namespace of the table; stored but not read anywhere yet (hence the allow).
    #[allow(dead_code)]
    schema_name: String,
}
60
61impl ReplicationStream {
62 pub(crate) fn new(
63 config: PostgresSourceConfig,
64 source_id: String,
65 status_handle: ComponentStatusHandle,
66 base: SourceBase,
67 replay_state: Arc<ReplayState>,
68 start_lsn: Option<u64>,
69 ) -> Self {
70 Self {
71 config,
72 source_id,
73 connection: None,
74 decoder: PgOutputDecoder::new(),
75 status_handle,
76 base,
77 replay_state,
78 read_lsn: 0,
79 start_lsn,
80 last_feedback_time: std::time::Instant::now(),
81 pending_transaction: None,
82 relations: HashMap::new(),
83 table_primary_keys: Arc::new(RwLock::new(HashMap::new())),
84 }
85 }
86
    /// Runs the replication loop until the component is asked to stop.
    ///
    /// First connects and starts replication. The setup outcome is reported through
    /// `ready_tx` (if given), so a caller can wait until the stream is live. The loop
    /// then checks the component status, then races reading the next server message
    /// against a 10-second keepalive tick that sends a standby status update. A failure
    /// while reading or handling a message triggers `recover_connection`. If that
    /// fails, its error is returned.
    ///
    /// # Errors
    /// Returns the setup error, a reconnect error, or an error from `shutdown`.
    pub async fn run(
        &mut self,
        ready_tx: Option<oneshot::Sender<std::result::Result<(), String>>>,
    ) -> Result<()> {
        info!("Starting replication stream for source {}", self.source_id);

        // Report setup failure to the waiter (`{:#}` includes the anyhow context
        // chain) before propagating it.
        if let Err(error) = self.connect_and_setup().await {
            if let Some(tx) = ready_tx {
                let _ = tx.send(Err(format!("{error:#}")));
            }
            return Err(error);
        }
        // The receiver may already be gone; a failed send is deliberately ignored.
        if let Some(tx) = ready_tx {
            let _ = tx.send(Ok(()));
        }

        let mut keepalive_interval = interval(Duration::from_secs(10));

        loop {
            {
                // Checked once per iteration; reads time out after 100 ms
                // (see `read_next_message`), so this runs regularly.
                let status = self.status_handle.get_status().await;
                if status == ComponentStatus::Stopping || status == ComponentStatus::Stopped {
                    info!("Received stop signal, shutting down replication");
                    break;
                }
            }

            // NOTE(review): whichever branch loses the race is dropped mid-flight. The
            // read branch is also abandoned on its own 100 ms timeout. Both rely on
            // `ReplicationConnection::read_replication_message` being cancel-safe (not
            // losing a partially read frame) — confirm in the connection module.
            tokio::select! {
                result = self.read_next_message() => {
                    match result {
                        Ok(Some(msg)) => {
                            if let Err(e) = self.handle_message(msg).await {
                                error!("Error handling message: {e}");
                                // One reconnect attempt; give up if it fails.
                                if let Err(e) = self.recover_connection().await {
                                    error!("Failed to recover connection: {e}");
                                    return Err(e);
                                }
                            }
                        }
                        Ok(None) => {
                            // Read timed out with no data; loop to re-check status.
                        }
                        Err(e) => {
                            error!("Error reading message: {e}");
                            if let Err(e) = self.recover_connection().await {
                                error!("Failed to recover connection: {e}");
                                return Err(e);
                            }
                        }
                    }
                }

                _ = keepalive_interval.tick() => {
                    // Periodic status update; failures are logged, not fatal.
                    if let Err(e) = self.send_feedback(false).await {
                        warn!("Failed to send keepalive: {e}");
                    }
                }
            }
        }

        self.shutdown().await?;
        Ok(())
    }
162
163 async fn connect_and_setup(&mut self) -> Result<()> {
164 info!("Connecting to PostgreSQL for replication");
165
166 let mut conn = ReplicationConnection::connect(
168 &self.config.host,
169 self.config.port,
170 &self.config.database,
171 &self.config.user,
172 &self.config.password,
173 )
174 .await?;
175
176 let system_info = conn.identify_system().await?;
178 info!("Connected to PostgreSQL system: {system_info:?}");
179
180 let slot_info = conn
182 .create_replication_slot(&self.config.slot_name, false)
183 .await?;
184 info!("Using replication slot: {slot_info:?}");
185
186 let slot_lsn =
189 if !slot_info.consistent_point.is_empty() && slot_info.consistent_point != "0/0" {
190 parse_lsn(&slot_info.consistent_point)?
191 } else {
192 0
193 };
194 self.read_lsn = self.start_lsn.unwrap_or(slot_lsn);
195 self.replay_state
196 .read_lsn
197 .store(self.read_lsn, Ordering::Release);
198
199 let mut options = HashMap::new();
201 options.insert("proto_version".to_string(), "1".to_string());
202 options.insert(
203 "publication_names".to_string(),
204 self.config.publication_name.clone(),
205 );
206
207 conn.start_replication(&self.config.slot_name, Some(self.read_lsn), options)
209 .await?;
210
211 self.connection = Some(conn);
212 info!(
213 "Replication started from read LSN {:x} (slot watermark {:x})",
214 self.read_lsn, slot_lsn
215 );
216
217 Ok(())
218 }
219
220 async fn read_next_message(&mut self) -> Result<Option<BackendMessage>> {
221 if let Some(conn) = &mut self.connection {
222 match tokio::time::timeout(Duration::from_millis(100), conn.read_replication_message())
224 .await
225 {
226 Ok(Ok(msg)) => Ok(Some(msg)),
227 Ok(Err(e)) => Err(e),
228 Err(_) => Ok(None), }
230 } else {
231 Err(anyhow!("No connection available"))
232 }
233 }
234
235 async fn handle_message(&mut self, msg: BackendMessage) -> Result<()> {
236 match msg {
237 BackendMessage::CopyData(data) => {
238 self.handle_copy_data(&data).await?;
239 }
240 BackendMessage::PrimaryKeepaliveMessage {
241 wal_end,
242 timestamp: _,
243 reply,
244 } => {
245 self.read_lsn = wal_end;
246 self.replay_state
247 .read_lsn
248 .store(self.read_lsn, Ordering::Release);
249 if reply == 1 {
250 self.send_feedback(true).await?;
251 }
252 }
253 BackendMessage::ErrorResponse(err) => {
254 error!("Server error: {}", err.message);
255 return Err(anyhow!("Server error: {}", err.message));
256 }
257 _ => {
258 trace!("Ignoring message: {msg:?}");
259 }
260 }
261 Ok(())
262 }
263
264 async fn handle_copy_data(&mut self, data: &[u8]) -> Result<()> {
265 if data.is_empty() {
266 return Ok(());
267 }
268
269 let msg_type = data[0];
271
272 match msg_type {
273 b'w' => {
274 self.handle_xlog_data(&data[1..]).await?;
276 }
277 b'k' => {
278 self.handle_keepalive(&data[1..]).await?;
280 }
281 _ => {
282 warn!("Unknown copy data message type: 0x{msg_type:02x}");
283 }
284 }
285
286 Ok(())
287 }
288
289 async fn handle_xlog_data(&mut self, data: &[u8]) -> Result<()> {
290 if data.len() < 24 {
291 return Err(anyhow!("XLogData message too short: {} bytes", data.len()));
292 }
293
294 let _start_lsn = u64::from_be_bytes([
296 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
297 ]);
298 let end_lsn = u64::from_be_bytes([
299 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
300 ]);
301 let _timestamp = i64::from_be_bytes([
302 data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
303 ]);
304
305 self.read_lsn = end_lsn;
307 self.replay_state
308 .read_lsn
309 .store(self.read_lsn, Ordering::Release);
310
311 let wal_data = &data[24..];
313
314 if !wal_data.is_empty() {
316 let msg_type = wal_data[0];
317 debug!(
318 "Attempting to decode WAL message type: {} ({}), data length: {}",
319 msg_type as char,
320 msg_type,
321 wal_data.len()
322 );
323 }
324
325 match self.decoder.decode_message(wal_data) {
326 Ok(Some(wal_msg)) => {
327 self.process_wal_message(wal_msg).await?;
328 }
329 Ok(None) => {
330 }
332 Err(e) => {
333 if !wal_data.is_empty() {
335 debug!(
336 "Failed to decode WAL message type {} ({}): {}, data length: {}",
337 wal_data[0] as char,
338 wal_data[0],
339 e,
340 wal_data.len()
341 );
342 }
343 }
345 }
346
347 if self.last_feedback_time.elapsed() > Duration::from_secs(5) {
349 self.send_feedback(false).await?;
350 }
351
352 Ok(())
353 }
354
355 async fn handle_keepalive(&mut self, data: &[u8]) -> Result<()> {
356 if data.len() < 17 {
357 return Err(anyhow!("Keepalive message too short"));
358 }
359
360 let wal_end = u64::from_be_bytes([
361 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
362 ]);
363 let reply = data[16];
364
365 self.read_lsn = wal_end;
366 self.replay_state
367 .read_lsn
368 .store(self.read_lsn, Ordering::Release);
369
370 if reply == 1 {
371 self.send_feedback(true).await?;
372 }
373
374 Ok(())
375 }
376
377 async fn process_wal_message(&mut self, msg: WalMessage) -> Result<()> {
378 match msg {
379 WalMessage::Begin(_) => {
380 self.pending_transaction = Some(Vec::new());
382 }
383 WalMessage::Commit(tx_info) => {
384 if let Some(changes) = self.pending_transaction.take() {
387 for (change, _) in changes {
388 self.dispatch_change(change, tx_info.commit_lsn).await;
389 }
390 debug!(
391 "Committed transaction {} with LSN {:x}",
392 tx_info.xid, tx_info.commit_lsn
393 );
394 }
395 }
396 WalMessage::Relation(relation) => {
397 let label = relation.name.clone();
400 self.relations.insert(
401 relation.id,
402 RelationMapping {
403 table_name: relation.name.clone(),
404 schema_name: relation.namespace.clone(),
405 label,
406 },
407 );
408
409 }
412 WalMessage::Insert { relation_id, tuple } => {
413 if let Some(change) = self.convert_insert(relation_id, tuple).await? {
414 if let Some(tx) = &mut self.pending_transaction {
415 tx.push((change, self.read_lsn));
416 } else {
417 self.dispatch_change(change, self.read_lsn).await;
418 }
419 }
420 }
421 WalMessage::Update {
422 relation_id,
423 old_tuple,
424 new_tuple,
425 } => {
426 if let Some(change) = self
427 .convert_update(relation_id, old_tuple, new_tuple)
428 .await?
429 {
430 if let Some(tx) = &mut self.pending_transaction {
431 tx.push((change, self.read_lsn));
432 } else {
433 self.dispatch_change(change, self.read_lsn).await;
434 }
435 }
436 }
437 WalMessage::Delete {
438 relation_id,
439 old_tuple,
440 } => {
441 if let Some(change) = self.convert_delete(relation_id, old_tuple).await? {
442 if let Some(tx) = &mut self.pending_transaction {
443 tx.push((change, self.read_lsn));
444 } else {
445 self.dispatch_change(change, self.read_lsn).await;
446 }
447 }
448 }
449 WalMessage::Truncate { relation_ids } => {
450 warn!("Truncate not yet implemented for relations: {relation_ids:?}");
451 }
452 }
453 Ok(())
454 }
455
456 async fn convert_insert(
457 &self,
458 relation_id: u32,
459 tuple: Vec<super::types::PostgresValue>,
460 ) -> Result<Option<SourceChange>> {
461 let relation = self
463 .decoder
464 .get_relation(relation_id)
465 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
466
467 let mapping = self
468 .relations
469 .get(&relation_id)
470 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
471
472 let mut properties = drasi_core::models::ElementPropertyMap::new();
474 for (i, value) in tuple.iter().enumerate() {
475 if let Some(column) = relation.columns.get(i) {
476 let json_value = value.to_json();
477 if !json_value.is_null() {
478 properties.insert(
479 &column.name,
480 drasi_lib::sources::manager::convert_json_to_element_value(&json_value),
481 );
482 }
483 }
484 }
485
486 let element_id = self.generate_element_id(relation, &tuple).await?;
488
489 let element = Element::Node {
491 metadata: ElementMetadata {
492 reference: ElementReference::new(&self.source_id, &element_id),
493 labels: Arc::from([Arc::from(mapping.label.as_str())]),
494 effective_from: Utc::now().timestamp_millis() as u64,
495 },
496 properties,
497 };
498
499 Ok(Some(SourceChange::Insert { element }))
500 }
501
502 async fn convert_update(
503 &self,
504 relation_id: u32,
505 old_tuple: Option<Vec<super::types::PostgresValue>>,
506 new_tuple: Vec<super::types::PostgresValue>,
507 ) -> Result<Option<SourceChange>> {
508 let relation = self
509 .decoder
510 .get_relation(relation_id)
511 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
512
513 let mapping = self
514 .relations
515 .get(&relation_id)
516 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
517
518 let element_id = self.generate_element_id(relation, &new_tuple).await?;
520
521 if old_tuple.is_none() {
522 warn!("UPDATE without old tuple for relation {relation_id}, preserving UPDATE");
523 }
524
525 let mut after_properties = drasi_core::models::ElementPropertyMap::new();
528
529 for (i, column) in relation.columns.iter().enumerate() {
531 if let Some(value) = new_tuple.get(i) {
532 let json_value = value.to_json();
533 if !json_value.is_null() {
534 after_properties.insert(
535 &column.name,
536 drasi_lib::sources::manager::convert_json_to_element_value(&json_value),
537 );
538 }
539 }
540 }
541
542 let after_element = Element::Node {
543 metadata: ElementMetadata {
544 reference: ElementReference::new(&self.source_id, &element_id),
545 labels: Arc::from([Arc::from(mapping.label.as_str())]),
546 effective_from: Utc::now().timestamp_millis() as u64,
547 },
548 properties: after_properties,
549 };
550
551 Ok(Some(SourceChange::Update {
552 element: after_element,
553 }))
554 }
555
556 async fn convert_delete(
557 &self,
558 relation_id: u32,
559 old_tuple: Vec<super::types::PostgresValue>,
560 ) -> Result<Option<SourceChange>> {
561 let relation = self
562 .decoder
563 .get_relation(relation_id)
564 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
565
566 let mapping = self
567 .relations
568 .get(&relation_id)
569 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
570
571 let element_id = self.generate_element_id(relation, &old_tuple).await?;
572
573 Ok(Some(SourceChange::Delete {
574 metadata: ElementMetadata {
575 reference: ElementReference::new(&self.source_id, &element_id),
576 labels: Arc::from([Arc::from(mapping.label.as_str())]),
577 effective_from: Utc::now().timestamp_millis() as u64,
578 },
579 }))
580 }
581
582 async fn generate_element_id(
594 &self,
595 relation: &super::types::RelationInfo,
596 tuple: &[super::types::PostgresValue],
597 ) -> Result<String> {
598 let table_name = if relation.namespace == "public" {
600 relation.name.clone()
601 } else {
602 format!("{}.{}", relation.namespace, relation.name)
603 };
604
605 let primary_keys = self.table_primary_keys.read().await;
607 let pk_columns = primary_keys.get(&table_name);
608
609 let configured_keys = self
611 .config
612 .table_keys
613 .iter()
614 .find(|tk| tk.table == table_name)
615 .map(|tk| &tk.key_columns);
616
617 let key_columns = configured_keys.or(pk_columns);
619
620 if let Some(keys) = key_columns {
621 let mut key_parts = Vec::new();
622
623 for (i, column) in relation.columns.iter().enumerate() {
624 if keys.contains(&column.name) {
625 if let Some(value) = tuple.get(i) {
626 let json_val = value.to_json();
627 if !json_val.is_null() {
628 let val_str = json_val.to_string();
630 let cleaned = val_str.trim_matches('"');
631 key_parts.push(cleaned.to_string());
632 }
633 }
634 }
635 }
636
637 if !key_parts.is_empty() {
638 return Ok(format!("{}:{}", table_name, key_parts.join("_")));
640 }
641 }
642
643 warn!("No primary key value found for table '{table_name}'. Consider adding 'table_keys' configuration.");
645 Ok(format!("{}:{}", table_name, uuid::Uuid::new_v4()))
647 }
648
649 async fn send_feedback(&mut self, reply_requested: bool) -> Result<()> {
650 if let Some(conn) = &mut self.connection {
651 let confirmed_lsn = match self.base.compute_confirmed_source_position().await {
658 Some(bytes) if bytes.len() == 8 => {
659 let arr: [u8; 8] = bytes[..8].try_into().expect("length already checked");
660 u64::from_be_bytes(arr)
661 }
662 Some(bytes) => {
663 warn!(
664 "[{}] Confirmed source position has unexpected length {} (expected 8); \
665 not advancing flush_lsn",
666 self.source_id,
667 bytes.len()
668 );
669 0
670 }
671 None => 0, };
673
674 let fence = self.replay_state.effective_flush_fence();
678 let (effective_lsn, was_clamped) = if fence < u64::MAX && confirmed_lsn > fence {
679 (fence, true)
680 } else {
681 (confirmed_lsn, false)
682 };
683
684 let status = StandbyStatusUpdate {
685 write_lsn: self.read_lsn,
686 flush_lsn: effective_lsn,
687 apply_lsn: effective_lsn,
688 reply_requested,
689 };
690
691 conn.send_standby_status(status).await?;
692 self.last_feedback_time = std::time::Instant::now();
693
694 if !was_clamped && effective_lsn > 0 {
700 if let Some(confirmed_seq) = self.base.compute_confirmed_position().await {
701 self.base.prune_position_map(confirmed_seq).await;
702 }
703 }
704
705 trace!(
706 "[{}] Sent feedback: write_lsn={:x}, flush_lsn={:x}{}",
707 self.source_id,
708 self.read_lsn,
709 effective_lsn,
710 if was_clamped { " (fenced)" } else { "" }
711 );
712 }
713
714 Ok(())
715 }
716
717 async fn dispatch_change(&self, change: SourceChange, lsn: u64) {
722 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
723 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
724
725 let mut wrapper = SourceEventWrapper::with_profiling(
726 self.source_id.clone(),
727 SourceEvent::Change(change),
728 chrono::Utc::now(),
729 profiling,
730 );
731
732 wrapper.set_source_position(Bytes::from(lsn.to_be_bytes().to_vec()));
734
735 if let Err(e) = self.base.dispatch_event(wrapper).await {
737 debug!(
738 "[{}] Failed to dispatch change (no subscribers): {}",
739 self.source_id, e
740 );
741 }
742 }
743
744 #[allow(dead_code)]
745 async fn check_stop_signal(&self) -> bool {
746 let status = self.status_handle.get_status().await;
747 status == ComponentStatus::Stopping || status == ComponentStatus::Stopped
748 }
749
750 async fn recover_connection(&mut self) -> Result<()> {
751 warn!("Attempting to recover connection");
752
753 if let Some(conn) = self.connection.take() {
755 let _ = conn.close().await;
756 }
757
758 sleep(Duration::from_secs(5)).await;
760
761 self.connect_and_setup().await?;
763
764 info!("Connection recovered successfully");
765 Ok(())
766 }
767
768 async fn shutdown(&mut self) -> Result<()> {
769 info!("Shutting down replication stream");
770
771 let _ = self.send_feedback(false).await;
773
774 if let Some(conn) = self.connection.take() {
776 conn.close().await?;
777 }
778
779 Ok(())
780 }
781}
782
#[cfg(test)]
mod tests {
    use chrono::Utc;
    use drasi_core::models::validate_effective_from;

    /// The `effective_from` values built in this module use `timestamp_millis`; the
    /// validator must accept that unit.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = Utc::now().timestamp_millis() as u64;
        let outcome = validate_effective_from(effective_from);
        assert!(
            outcome.is_ok(),
            "Postgres CDC effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// A nanosecond timestamp is the mistake the validator is meant to catch.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let nanos = Utc::now().timestamp_nanos_opt().unwrap();
        let bad_effective_from = nanos as u64;
        let outcome = validate_effective_from(bad_effective_from);
        assert!(
            outcome.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}