1use anyhow::{anyhow, Result};
16use chrono::Utc;
17use log::{debug, error, info, trace, warn};
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::RwLock;
22use tokio::time::{interval, sleep};
23
24use super::connection::ReplicationConnection;
25use super::decoder::PgOutputDecoder;
26use super::protocol::BackendMessage;
27use super::types::{StandbyStatusUpdate, WalMessage};
28use super::PostgresSourceConfig;
29use drasi_core::models::{Element, ElementMetadata, ElementReference, SourceChange};
30use drasi_lib::channels::{ComponentEventSender, ComponentStatus, SourceEvent, SourceEventWrapper};
31use drasi_lib::sources::base::SourceBase;
32
33pub struct ReplicationStream {
34 config: PostgresSourceConfig,
35 source_id: String,
36 connection: Option<ReplicationConnection>,
37 decoder: PgOutputDecoder,
38 dispatchers: Arc<
39 RwLock<
40 Vec<Box<dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync>>,
41 >,
42 >,
43 #[allow(dead_code)]
44 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
45 status: Arc<RwLock<ComponentStatus>>,
46 current_lsn: u64,
47 last_feedback_time: std::time::Instant,
48 pending_transaction: Option<Vec<SourceChange>>,
49 relations: HashMap<u32, RelationMapping>,
50 table_primary_keys: Arc<RwLock<HashMap<String, Vec<String>>>>,
51}
52
/// Naming data recorded for a relation when its `Relation` WAL message is processed.
struct RelationMapping {
    /// Label attached to every element produced from this relation.
    label: String,
    /// Table name from the relation message. Stored but not read in this file.
    #[allow(dead_code)]
    table_name: String,
    /// Schema (namespace) from the relation message. Stored but not read in this file.
    #[allow(dead_code)]
    schema_name: String,
}
60
61impl ReplicationStream {
62 pub fn new(
63 config: PostgresSourceConfig,
64 source_id: String,
65 dispatchers: Arc<
66 RwLock<
67 Vec<
68 Box<
69 dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync,
70 >,
71 >,
72 >,
73 >,
74 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
75 status: Arc<RwLock<ComponentStatus>>,
76 ) -> Self {
77 Self {
78 config,
79 source_id,
80 connection: None,
81 decoder: PgOutputDecoder::new(),
82 dispatchers,
83 status_tx,
84 status,
85 current_lsn: 0,
86 last_feedback_time: std::time::Instant::now(),
87 pending_transaction: None,
88 relations: HashMap::new(),
89 table_primary_keys: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92
93 pub async fn run(&mut self) -> Result<()> {
98 info!("Starting replication stream for source {}", self.source_id);
99
100 self.connect_and_setup().await?;
102
103 let mut keepalive_interval = interval(Duration::from_secs(10));
105
106 loop {
107 {
109 let status = self.status.read().await;
110 if *status == ComponentStatus::Stopping || *status == ComponentStatus::Stopped {
111 info!("Received stop signal, shutting down replication");
112 break;
113 }
114 }
115
116 tokio::select! {
117 result = self.read_next_message() => {
119 match result {
120 Ok(Some(msg)) => {
121 if let Err(e) = self.handle_message(msg).await {
122 error!("Error handling message: {e}");
123 if let Err(e) = self.recover_connection().await {
125 error!("Failed to recover connection: {e}");
126 return Err(e);
127 }
128 }
129 }
130 Ok(None) => {
131 }
133 Err(e) => {
134 error!("Error reading message: {e}");
135 if let Err(e) = self.recover_connection().await {
137 error!("Failed to recover connection: {e}");
138 return Err(e);
139 }
140 }
141 }
142 }
143
144 _ = keepalive_interval.tick() => {
146 if let Err(e) = self.send_feedback(false).await {
147 warn!("Failed to send keepalive: {e}");
148 }
149 }
150 }
151 }
152
153 self.shutdown().await?;
155 Ok(())
156 }
157
158 async fn connect_and_setup(&mut self) -> Result<()> {
159 info!("Connecting to PostgreSQL for replication");
160
161 let mut conn = ReplicationConnection::connect(
163 &self.config.host,
164 self.config.port,
165 &self.config.database,
166 &self.config.user,
167 &self.config.password,
168 )
169 .await?;
170
171 let system_info = conn.identify_system().await?;
173 info!("Connected to PostgreSQL system: {system_info:?}");
174
175 let slot_info = conn
177 .create_replication_slot(&self.config.slot_name, false)
178 .await?;
179 info!("Using replication slot: {slot_info:?}");
180
181 if !slot_info.consistent_point.is_empty() && slot_info.consistent_point != "0/0" {
183 self.current_lsn = parse_lsn(&slot_info.consistent_point)?;
184 } else {
185 self.current_lsn = 0;
187 }
188
189 let mut options = HashMap::new();
191 options.insert("proto_version".to_string(), "1".to_string());
192 options.insert(
193 "publication_names".to_string(),
194 self.config.publication_name.clone(),
195 );
196
197 conn.start_replication(&self.config.slot_name, Some(self.current_lsn), options)
199 .await?;
200
201 self.connection = Some(conn);
202 info!("Replication started from LSN: {:x}", self.current_lsn);
203
204 Ok(())
205 }
206
207 async fn read_next_message(&mut self) -> Result<Option<BackendMessage>> {
208 if let Some(conn) = &mut self.connection {
209 match tokio::time::timeout(Duration::from_millis(100), conn.read_replication_message())
211 .await
212 {
213 Ok(Ok(msg)) => Ok(Some(msg)),
214 Ok(Err(e)) => Err(e),
215 Err(_) => Ok(None), }
217 } else {
218 Err(anyhow!("No connection available"))
219 }
220 }
221
222 async fn handle_message(&mut self, msg: BackendMessage) -> Result<()> {
223 match msg {
224 BackendMessage::CopyData(data) => {
225 self.handle_copy_data(&data).await?;
226 }
227 BackendMessage::PrimaryKeepaliveMessage {
228 wal_end,
229 timestamp: _,
230 reply,
231 } => {
232 self.current_lsn = wal_end;
233 if reply == 1 {
234 self.send_feedback(true).await?;
235 }
236 }
237 BackendMessage::ErrorResponse(err) => {
238 error!("Server error: {}", err.message);
239 return Err(anyhow!("Server error: {}", err.message));
240 }
241 _ => {
242 trace!("Ignoring message: {msg:?}");
243 }
244 }
245 Ok(())
246 }
247
248 async fn handle_copy_data(&mut self, data: &[u8]) -> Result<()> {
249 if data.is_empty() {
250 return Ok(());
251 }
252
253 let msg_type = data[0];
255
256 match msg_type {
257 b'w' => {
258 self.handle_xlog_data(&data[1..]).await?;
260 }
261 b'k' => {
262 self.handle_keepalive(&data[1..]).await?;
264 }
265 _ => {
266 warn!("Unknown copy data message type: 0x{msg_type:02x}");
267 }
268 }
269
270 Ok(())
271 }
272
273 async fn handle_xlog_data(&mut self, data: &[u8]) -> Result<()> {
274 if data.len() < 24 {
275 return Err(anyhow!("XLogData message too short: {} bytes", data.len()));
276 }
277
278 let _start_lsn = u64::from_be_bytes([
280 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
281 ]);
282 let end_lsn = u64::from_be_bytes([
283 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
284 ]);
285 let _timestamp = i64::from_be_bytes([
286 data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
287 ]);
288
289 self.current_lsn = end_lsn;
291
292 let wal_data = &data[24..];
294
295 if !wal_data.is_empty() {
297 let msg_type = wal_data[0];
298 debug!(
299 "Attempting to decode WAL message type: {} ({}), data length: {}",
300 msg_type as char,
301 msg_type,
302 wal_data.len()
303 );
304 }
305
306 match self.decoder.decode_message(wal_data) {
307 Ok(Some(wal_msg)) => {
308 self.process_wal_message(wal_msg).await?;
309 }
310 Ok(None) => {
311 }
313 Err(e) => {
314 if !wal_data.is_empty() {
316 debug!(
317 "Failed to decode WAL message type {} ({}): {}, data length: {}",
318 wal_data[0] as char,
319 wal_data[0],
320 e,
321 wal_data.len()
322 );
323 }
324 }
326 }
327
328 if self.last_feedback_time.elapsed() > Duration::from_secs(5) {
330 self.send_feedback(false).await?;
331 }
332
333 Ok(())
334 }
335
336 async fn handle_keepalive(&mut self, data: &[u8]) -> Result<()> {
337 if data.len() < 17 {
338 return Err(anyhow!("Keepalive message too short"));
339 }
340
341 let wal_end = u64::from_be_bytes([
342 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
343 ]);
344 let reply = data[16];
345
346 self.current_lsn = wal_end;
347
348 if reply == 1 {
349 self.send_feedback(true).await?;
350 }
351
352 Ok(())
353 }
354
355 async fn process_wal_message(&mut self, msg: WalMessage) -> Result<()> {
356 match msg {
357 WalMessage::Begin(_) => {
358 self.pending_transaction = Some(Vec::new());
360 }
361 WalMessage::Commit(tx_info) => {
362 if let Some(changes) = self.pending_transaction.take() {
364 for change in changes {
366 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
368 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
369
370 let wrapper = SourceEventWrapper::with_profiling(
371 self.source_id.clone(),
372 SourceEvent::Change(change),
373 chrono::Utc::now(),
374 profiling,
375 );
376
377 if let Err(e) = SourceBase::dispatch_from_task(
379 self.dispatchers.clone(),
380 wrapper.clone(),
381 &self.source_id,
382 )
383 .await
384 {
385 debug!(
386 "[{}] Failed to dispatch change (no subscribers): {}",
387 self.source_id, e
388 );
389 }
390 }
391 debug!(
392 "Committed transaction {} with LSN {:x}",
393 tx_info.xid, tx_info.commit_lsn
394 );
395 }
396 }
397 WalMessage::Relation(relation) => {
398 let label = relation.name.clone();
401 self.relations.insert(
402 relation.id,
403 RelationMapping {
404 table_name: relation.name.clone(),
405 schema_name: relation.namespace.clone(),
406 label,
407 },
408 );
409
410 }
413 WalMessage::Insert { relation_id, tuple } => {
414 if let Some(change) = self.convert_insert(relation_id, tuple).await? {
415 if let Some(tx) = &mut self.pending_transaction {
416 tx.push(change);
417 } else {
418 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
420 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
421
422 let wrapper = SourceEventWrapper::with_profiling(
423 self.source_id.clone(),
424 SourceEvent::Change(change),
425 chrono::Utc::now(),
426 profiling,
427 );
428
429 if let Err(e) = SourceBase::dispatch_from_task(
431 self.dispatchers.clone(),
432 wrapper.clone(),
433 &self.source_id,
434 )
435 .await
436 {
437 debug!(
438 "[{}] Failed to dispatch change (no subscribers): {}",
439 self.source_id, e
440 );
441 }
442 }
443 }
444 }
445 WalMessage::Update {
446 relation_id,
447 old_tuple,
448 new_tuple,
449 } => {
450 if let Some(change) = self
451 .convert_update(relation_id, old_tuple, new_tuple)
452 .await?
453 {
454 if let Some(tx) = &mut self.pending_transaction {
455 tx.push(change);
456 } else {
457 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
458 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
459
460 let wrapper = SourceEventWrapper::with_profiling(
461 self.source_id.clone(),
462 SourceEvent::Change(change),
463 chrono::Utc::now(),
464 profiling,
465 );
466
467 if let Err(e) = SourceBase::dispatch_from_task(
469 self.dispatchers.clone(),
470 wrapper.clone(),
471 &self.source_id,
472 )
473 .await
474 {
475 debug!(
476 "[{}] Failed to dispatch change (no subscribers): {}",
477 self.source_id, e
478 );
479 }
480 }
481 }
482 }
483 WalMessage::Delete {
484 relation_id,
485 old_tuple,
486 } => {
487 if let Some(change) = self.convert_delete(relation_id, old_tuple).await? {
488 if let Some(tx) = &mut self.pending_transaction {
489 tx.push(change);
490 } else {
491 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
492 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
493
494 let wrapper = SourceEventWrapper::with_profiling(
495 self.source_id.clone(),
496 SourceEvent::Change(change),
497 chrono::Utc::now(),
498 profiling,
499 );
500
501 if let Err(e) = SourceBase::dispatch_from_task(
503 self.dispatchers.clone(),
504 wrapper.clone(),
505 &self.source_id,
506 )
507 .await
508 {
509 debug!(
510 "[{}] Failed to dispatch change (no subscribers): {}",
511 self.source_id, e
512 );
513 }
514 }
515 }
516 }
517 WalMessage::Truncate { relation_ids } => {
518 warn!("Truncate not yet implemented for relations: {relation_ids:?}");
519 }
520 }
521 Ok(())
522 }
523
524 async fn convert_insert(
525 &self,
526 relation_id: u32,
527 tuple: Vec<super::types::PostgresValue>,
528 ) -> Result<Option<SourceChange>> {
529 let relation = self
531 .decoder
532 .get_relation(relation_id)
533 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
534
535 let mapping = self
536 .relations
537 .get(&relation_id)
538 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
539
540 let mut properties = drasi_core::models::ElementPropertyMap::new();
542 for (i, value) in tuple.iter().enumerate() {
543 if let Some(column) = relation.columns.get(i) {
544 let json_value = value.to_json();
545 if !json_value.is_null() {
546 properties.insert(
547 &column.name,
548 drasi_lib::sources::manager::convert_json_to_element_value(&json_value)?,
549 );
550 }
551 }
552 }
553
554 let element_id = self.generate_element_id(relation, &tuple).await?;
556
557 let element = Element::Node {
559 metadata: ElementMetadata {
560 reference: ElementReference::new(&self.source_id, &element_id),
561 labels: Arc::from([Arc::from(mapping.label.as_str())]),
562 effective_from: Utc::now().timestamp_millis() as u64,
563 },
564 properties,
565 };
566
567 Ok(Some(SourceChange::Insert { element }))
568 }
569
570 async fn convert_update(
571 &self,
572 relation_id: u32,
573 old_tuple: Option<Vec<super::types::PostgresValue>>,
574 new_tuple: Vec<super::types::PostgresValue>,
575 ) -> Result<Option<SourceChange>> {
576 let relation = self
577 .decoder
578 .get_relation(relation_id)
579 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
580
581 let mapping = self
582 .relations
583 .get(&relation_id)
584 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
585
586 let element_id = self.generate_element_id(relation, &new_tuple).await?;
588
589 if old_tuple.is_none() {
590 warn!("UPDATE without old tuple for relation {relation_id}, preserving UPDATE");
591 }
592
593 let mut after_properties = drasi_core::models::ElementPropertyMap::new();
596
597 for (i, column) in relation.columns.iter().enumerate() {
599 if let Some(value) = new_tuple.get(i) {
600 let json_value = value.to_json();
601 if !json_value.is_null() {
602 after_properties.insert(
603 &column.name,
604 drasi_lib::sources::manager::convert_json_to_element_value(&json_value)?,
605 );
606 }
607 }
608 }
609
610 let after_element = Element::Node {
611 metadata: ElementMetadata {
612 reference: ElementReference::new(&self.source_id, &element_id),
613 labels: Arc::from([Arc::from(mapping.label.as_str())]),
614 effective_from: Utc::now().timestamp_millis() as u64,
615 },
616 properties: after_properties,
617 };
618
619 Ok(Some(SourceChange::Update {
620 element: after_element,
621 }))
622 }
623
624 async fn convert_delete(
625 &self,
626 relation_id: u32,
627 old_tuple: Vec<super::types::PostgresValue>,
628 ) -> Result<Option<SourceChange>> {
629 let relation = self
630 .decoder
631 .get_relation(relation_id)
632 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
633
634 let mapping = self
635 .relations
636 .get(&relation_id)
637 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
638
639 let element_id = self.generate_element_id(relation, &old_tuple).await?;
640
641 Ok(Some(SourceChange::Delete {
642 metadata: ElementMetadata {
643 reference: ElementReference::new(&self.source_id, &element_id),
644 labels: Arc::from([Arc::from(mapping.label.as_str())]),
645 effective_from: Utc::now().timestamp_millis() as u64,
646 },
647 }))
648 }
649
650 async fn generate_element_id(
662 &self,
663 relation: &super::types::RelationInfo,
664 tuple: &[super::types::PostgresValue],
665 ) -> Result<String> {
666 let table_name = if relation.namespace == "public" {
668 relation.name.clone()
669 } else {
670 format!("{}.{}", relation.namespace, relation.name)
671 };
672
673 let primary_keys = self.table_primary_keys.read().await;
675 let pk_columns = primary_keys.get(&table_name);
676
677 let configured_keys = self
679 .config
680 .table_keys
681 .iter()
682 .find(|tk| tk.table == table_name)
683 .map(|tk| &tk.key_columns);
684
685 let key_columns = configured_keys.or(pk_columns);
687
688 if let Some(keys) = key_columns {
689 let mut key_parts = Vec::new();
690
691 for (i, column) in relation.columns.iter().enumerate() {
692 if keys.contains(&column.name) {
693 if let Some(value) = tuple.get(i) {
694 let json_val = value.to_json();
695 if !json_val.is_null() {
696 let val_str = json_val.to_string();
698 let cleaned = val_str.trim_matches('"');
699 key_parts.push(cleaned.to_string());
700 }
701 }
702 }
703 }
704
705 if !key_parts.is_empty() {
706 return Ok(format!("{}:{}", table_name, key_parts.join("_")));
708 }
709 }
710
711 warn!("No primary key value found for table '{table_name}'. Consider adding 'table_keys' configuration.");
713 Ok(format!("{}:{}", table_name, uuid::Uuid::new_v4()))
715 }
716
717 async fn send_feedback(&mut self, reply_requested: bool) -> Result<()> {
718 if let Some(conn) = &mut self.connection {
719 let status = StandbyStatusUpdate {
720 write_lsn: self.current_lsn,
721 flush_lsn: self.current_lsn,
722 apply_lsn: self.current_lsn,
723 reply_requested,
724 };
725
726 conn.send_standby_status(status).await?;
727 self.last_feedback_time = std::time::Instant::now();
728 trace!("Sent feedback with LSN: {:x}", self.current_lsn);
729 }
730
731 Ok(())
732 }
733
734 #[allow(dead_code)]
735 async fn check_stop_signal(&self) -> bool {
736 let status = self.status.read().await;
737 *status == ComponentStatus::Stopping || *status == ComponentStatus::Stopped
738 }
739
740 async fn recover_connection(&mut self) -> Result<()> {
741 warn!("Attempting to recover connection");
742
743 if let Some(conn) = self.connection.take() {
745 let _ = conn.close().await;
746 }
747
748 sleep(Duration::from_secs(5)).await;
750
751 self.connect_and_setup().await?;
753
754 info!("Connection recovered successfully");
755 Ok(())
756 }
757
758 async fn shutdown(&mut self) -> Result<()> {
759 info!("Shutting down replication stream");
760
761 let _ = self.send_feedback(false).await;
763
764 if let Some(conn) = self.connection.take() {
766 conn.close().await?;
767 }
768
769 Ok(())
770 }
771}
772
773fn parse_lsn(lsn_str: &str) -> Result<u64> {
774 let parts: Vec<&str> = lsn_str.split('/').collect();
775 if parts.len() != 2 {
776 return Err(anyhow!("Invalid LSN format: {lsn_str}"));
777 }
778
779 let high = u64::from_str_radix(parts[0], 16)?;
780 let low = u64::from_str_radix(parts[1], 16)?;
781
782 Ok((high << 32) | low)
783}
784
#[cfg(test)]
mod tests {
    use chrono::Utc;
    use drasi_core::models::validate_effective_from;

    /// A millisecond timestamp, as stamped on emitted elements, must validate.
    #[test]
    fn effective_from_uses_milliseconds() {
        let effective_from = Utc::now().timestamp_millis() as u64;
        let validation = validate_effective_from(effective_from);
        assert!(
            validation.is_ok(),
            "Postgres CDC effective_from ({effective_from}) should be in millisecond range"
        );
    }

    /// A nanosecond timestamp must be refused by the validator.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let bad_effective_from = Utc::now().timestamp_nanos_opt().unwrap() as u64;
        let validation = validate_effective_from(bad_effective_from);
        assert!(
            validation.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected"
        );
    }
}