1use anyhow::{anyhow, Result};
16use chrono::Utc;
17use log::{debug, error, info, trace, warn};
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::RwLock;
22use tokio::time::{interval, sleep};
23
24use super::connection::ReplicationConnection;
25use super::decoder::PgOutputDecoder;
26use super::protocol::BackendMessage;
27use super::types::{StandbyStatusUpdate, WalMessage};
28use super::PostgresSourceConfig;
29use drasi_core::models::{Element, ElementMetadata, ElementReference, SourceChange};
30use drasi_lib::channels::{ComponentStatus, SourceEvent, SourceEventWrapper};
31use drasi_lib::component_graph::ComponentStatusHandle;
32use drasi_lib::sources::base::SourceBase;
33
34pub struct ReplicationStream {
35 config: PostgresSourceConfig,
36 source_id: String,
37 connection: Option<ReplicationConnection>,
38 decoder: PgOutputDecoder,
39 dispatchers: Arc<
40 RwLock<
41 Vec<Box<dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync>>,
42 >,
43 >,
44 #[allow(dead_code)]
45 status_handle: ComponentStatusHandle,
46 current_lsn: u64,
47 last_feedback_time: std::time::Instant,
48 pending_transaction: Option<Vec<SourceChange>>,
49 relations: HashMap<u32, RelationMapping>,
50 table_primary_keys: Arc<RwLock<HashMap<String, Vec<String>>>>,
51}
52
/// Naming information for one replicated relation (table), keyed by the
/// relation id PostgreSQL assigns in the pgoutput stream.
struct RelationMapping {
    #[allow(dead_code)]
    table_name: String,
    #[allow(dead_code)]
    schema_name: String,
    // Node label applied to emitted elements; currently the bare table name.
    label: String,
}
60
61impl ReplicationStream {
62 pub fn new(
63 config: PostgresSourceConfig,
64 source_id: String,
65 dispatchers: Arc<
66 RwLock<
67 Vec<
68 Box<
69 dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync,
70 >,
71 >,
72 >,
73 >,
74 status_handle: ComponentStatusHandle,
75 ) -> Self {
76 Self {
77 config,
78 source_id,
79 connection: None,
80 decoder: PgOutputDecoder::new(),
81 dispatchers,
82 status_handle,
83 current_lsn: 0,
84 last_feedback_time: std::time::Instant::now(),
85 pending_transaction: None,
86 relations: HashMap::new(),
87 table_primary_keys: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub async fn run(&mut self) -> Result<()> {
96 info!("Starting replication stream for source {}", self.source_id);
97
98 self.connect_and_setup().await?;
100
101 let mut keepalive_interval = interval(Duration::from_secs(10));
103
104 loop {
105 {
107 let status = self.status_handle.get_status().await;
108 if status == ComponentStatus::Stopping || status == ComponentStatus::Stopped {
109 info!("Received stop signal, shutting down replication");
110 break;
111 }
112 }
113
114 tokio::select! {
115 result = self.read_next_message() => {
117 match result {
118 Ok(Some(msg)) => {
119 if let Err(e) = self.handle_message(msg).await {
120 error!("Error handling message: {e}");
121 if let Err(e) = self.recover_connection().await {
123 error!("Failed to recover connection: {e}");
124 return Err(e);
125 }
126 }
127 }
128 Ok(None) => {
129 }
131 Err(e) => {
132 error!("Error reading message: {e}");
133 if let Err(e) = self.recover_connection().await {
135 error!("Failed to recover connection: {e}");
136 return Err(e);
137 }
138 }
139 }
140 }
141
142 _ = keepalive_interval.tick() => {
144 if let Err(e) = self.send_feedback(false).await {
145 warn!("Failed to send keepalive: {e}");
146 }
147 }
148 }
149 }
150
151 self.shutdown().await?;
153 Ok(())
154 }
155
156 async fn connect_and_setup(&mut self) -> Result<()> {
157 info!("Connecting to PostgreSQL for replication");
158
159 let mut conn = ReplicationConnection::connect(
161 &self.config.host,
162 self.config.port,
163 &self.config.database,
164 &self.config.user,
165 &self.config.password,
166 )
167 .await?;
168
169 let system_info = conn.identify_system().await?;
171 info!("Connected to PostgreSQL system: {system_info:?}");
172
173 let slot_info = conn
175 .create_replication_slot(&self.config.slot_name, false)
176 .await?;
177 info!("Using replication slot: {slot_info:?}");
178
179 if !slot_info.consistent_point.is_empty() && slot_info.consistent_point != "0/0" {
181 self.current_lsn = parse_lsn(&slot_info.consistent_point)?;
182 } else {
183 self.current_lsn = 0;
185 }
186
187 let mut options = HashMap::new();
189 options.insert("proto_version".to_string(), "1".to_string());
190 options.insert(
191 "publication_names".to_string(),
192 self.config.publication_name.clone(),
193 );
194
195 conn.start_replication(&self.config.slot_name, Some(self.current_lsn), options)
197 .await?;
198
199 self.connection = Some(conn);
200 info!("Replication started from LSN: {:x}", self.current_lsn);
201
202 Ok(())
203 }
204
205 async fn read_next_message(&mut self) -> Result<Option<BackendMessage>> {
206 if let Some(conn) = &mut self.connection {
207 match tokio::time::timeout(Duration::from_millis(100), conn.read_replication_message())
209 .await
210 {
211 Ok(Ok(msg)) => Ok(Some(msg)),
212 Ok(Err(e)) => Err(e),
213 Err(_) => Ok(None), }
215 } else {
216 Err(anyhow!("No connection available"))
217 }
218 }
219
220 async fn handle_message(&mut self, msg: BackendMessage) -> Result<()> {
221 match msg {
222 BackendMessage::CopyData(data) => {
223 self.handle_copy_data(&data).await?;
224 }
225 BackendMessage::PrimaryKeepaliveMessage {
226 wal_end,
227 timestamp: _,
228 reply,
229 } => {
230 self.current_lsn = wal_end;
231 if reply == 1 {
232 self.send_feedback(true).await?;
233 }
234 }
235 BackendMessage::ErrorResponse(err) => {
236 error!("Server error: {}", err.message);
237 return Err(anyhow!("Server error: {}", err.message));
238 }
239 _ => {
240 trace!("Ignoring message: {msg:?}");
241 }
242 }
243 Ok(())
244 }
245
246 async fn handle_copy_data(&mut self, data: &[u8]) -> Result<()> {
247 if data.is_empty() {
248 return Ok(());
249 }
250
251 let msg_type = data[0];
253
254 match msg_type {
255 b'w' => {
256 self.handle_xlog_data(&data[1..]).await?;
258 }
259 b'k' => {
260 self.handle_keepalive(&data[1..]).await?;
262 }
263 _ => {
264 warn!("Unknown copy data message type: 0x{msg_type:02x}");
265 }
266 }
267
268 Ok(())
269 }
270
271 async fn handle_xlog_data(&mut self, data: &[u8]) -> Result<()> {
272 if data.len() < 24 {
273 return Err(anyhow!("XLogData message too short: {} bytes", data.len()));
274 }
275
276 let _start_lsn = u64::from_be_bytes([
278 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
279 ]);
280 let end_lsn = u64::from_be_bytes([
281 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
282 ]);
283 let _timestamp = i64::from_be_bytes([
284 data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
285 ]);
286
287 self.current_lsn = end_lsn;
289
290 let wal_data = &data[24..];
292
293 if !wal_data.is_empty() {
295 let msg_type = wal_data[0];
296 debug!(
297 "Attempting to decode WAL message type: {} ({}), data length: {}",
298 msg_type as char,
299 msg_type,
300 wal_data.len()
301 );
302 }
303
304 match self.decoder.decode_message(wal_data) {
305 Ok(Some(wal_msg)) => {
306 self.process_wal_message(wal_msg).await?;
307 }
308 Ok(None) => {
309 }
311 Err(e) => {
312 if !wal_data.is_empty() {
314 debug!(
315 "Failed to decode WAL message type {} ({}): {}, data length: {}",
316 wal_data[0] as char,
317 wal_data[0],
318 e,
319 wal_data.len()
320 );
321 }
322 }
324 }
325
326 if self.last_feedback_time.elapsed() > Duration::from_secs(5) {
328 self.send_feedback(false).await?;
329 }
330
331 Ok(())
332 }
333
334 async fn handle_keepalive(&mut self, data: &[u8]) -> Result<()> {
335 if data.len() < 17 {
336 return Err(anyhow!("Keepalive message too short"));
337 }
338
339 let wal_end = u64::from_be_bytes([
340 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
341 ]);
342 let reply = data[16];
343
344 self.current_lsn = wal_end;
345
346 if reply == 1 {
347 self.send_feedback(true).await?;
348 }
349
350 Ok(())
351 }
352
353 async fn process_wal_message(&mut self, msg: WalMessage) -> Result<()> {
354 match msg {
355 WalMessage::Begin(_) => {
356 self.pending_transaction = Some(Vec::new());
358 }
359 WalMessage::Commit(tx_info) => {
360 if let Some(changes) = self.pending_transaction.take() {
362 for change in changes {
364 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
366 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
367
368 let wrapper = SourceEventWrapper::with_profiling(
369 self.source_id.clone(),
370 SourceEvent::Change(change),
371 chrono::Utc::now(),
372 profiling,
373 );
374
375 if let Err(e) = SourceBase::dispatch_from_task(
377 self.dispatchers.clone(),
378 wrapper.clone(),
379 &self.source_id,
380 )
381 .await
382 {
383 debug!(
384 "[{}] Failed to dispatch change (no subscribers): {}",
385 self.source_id, e
386 );
387 }
388 }
389 debug!(
390 "Committed transaction {} with LSN {:x}",
391 tx_info.xid, tx_info.commit_lsn
392 );
393 }
394 }
395 WalMessage::Relation(relation) => {
396 let label = relation.name.clone();
399 self.relations.insert(
400 relation.id,
401 RelationMapping {
402 table_name: relation.name.clone(),
403 schema_name: relation.namespace.clone(),
404 label,
405 },
406 );
407
408 }
411 WalMessage::Insert { relation_id, tuple } => {
412 if let Some(change) = self.convert_insert(relation_id, tuple).await? {
413 if let Some(tx) = &mut self.pending_transaction {
414 tx.push(change);
415 } else {
416 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
418 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
419
420 let wrapper = SourceEventWrapper::with_profiling(
421 self.source_id.clone(),
422 SourceEvent::Change(change),
423 chrono::Utc::now(),
424 profiling,
425 );
426
427 if let Err(e) = SourceBase::dispatch_from_task(
429 self.dispatchers.clone(),
430 wrapper.clone(),
431 &self.source_id,
432 )
433 .await
434 {
435 debug!(
436 "[{}] Failed to dispatch change (no subscribers): {}",
437 self.source_id, e
438 );
439 }
440 }
441 }
442 }
443 WalMessage::Update {
444 relation_id,
445 old_tuple,
446 new_tuple,
447 } => {
448 if let Some(change) = self
449 .convert_update(relation_id, old_tuple, new_tuple)
450 .await?
451 {
452 if let Some(tx) = &mut self.pending_transaction {
453 tx.push(change);
454 } else {
455 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
456 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
457
458 let wrapper = SourceEventWrapper::with_profiling(
459 self.source_id.clone(),
460 SourceEvent::Change(change),
461 chrono::Utc::now(),
462 profiling,
463 );
464
465 if let Err(e) = SourceBase::dispatch_from_task(
467 self.dispatchers.clone(),
468 wrapper.clone(),
469 &self.source_id,
470 )
471 .await
472 {
473 debug!(
474 "[{}] Failed to dispatch change (no subscribers): {}",
475 self.source_id, e
476 );
477 }
478 }
479 }
480 }
481 WalMessage::Delete {
482 relation_id,
483 old_tuple,
484 } => {
485 if let Some(change) = self.convert_delete(relation_id, old_tuple).await? {
486 if let Some(tx) = &mut self.pending_transaction {
487 tx.push(change);
488 } else {
489 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
490 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
491
492 let wrapper = SourceEventWrapper::with_profiling(
493 self.source_id.clone(),
494 SourceEvent::Change(change),
495 chrono::Utc::now(),
496 profiling,
497 );
498
499 if let Err(e) = SourceBase::dispatch_from_task(
501 self.dispatchers.clone(),
502 wrapper.clone(),
503 &self.source_id,
504 )
505 .await
506 {
507 debug!(
508 "[{}] Failed to dispatch change (no subscribers): {}",
509 self.source_id, e
510 );
511 }
512 }
513 }
514 }
515 WalMessage::Truncate { relation_ids } => {
516 warn!("Truncate not yet implemented for relations: {relation_ids:?}");
517 }
518 }
519 Ok(())
520 }
521
522 async fn convert_insert(
523 &self,
524 relation_id: u32,
525 tuple: Vec<super::types::PostgresValue>,
526 ) -> Result<Option<SourceChange>> {
527 let relation = self
529 .decoder
530 .get_relation(relation_id)
531 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
532
533 let mapping = self
534 .relations
535 .get(&relation_id)
536 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
537
538 let mut properties = drasi_core::models::ElementPropertyMap::new();
540 for (i, value) in tuple.iter().enumerate() {
541 if let Some(column) = relation.columns.get(i) {
542 let json_value = value.to_json();
543 if !json_value.is_null() {
544 properties.insert(
545 &column.name,
546 drasi_lib::sources::manager::convert_json_to_element_value(&json_value),
547 );
548 }
549 }
550 }
551
552 let element_id = self.generate_element_id(relation, &tuple).await?;
554
555 let element = Element::Node {
557 metadata: ElementMetadata {
558 reference: ElementReference::new(&self.source_id, &element_id),
559 labels: Arc::from([Arc::from(mapping.label.as_str())]),
560 effective_from: Utc::now().timestamp_millis() as u64,
561 },
562 properties,
563 };
564
565 Ok(Some(SourceChange::Insert { element }))
566 }
567
568 async fn convert_update(
569 &self,
570 relation_id: u32,
571 old_tuple: Option<Vec<super::types::PostgresValue>>,
572 new_tuple: Vec<super::types::PostgresValue>,
573 ) -> Result<Option<SourceChange>> {
574 let relation = self
575 .decoder
576 .get_relation(relation_id)
577 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
578
579 let mapping = self
580 .relations
581 .get(&relation_id)
582 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
583
584 let element_id = self.generate_element_id(relation, &new_tuple).await?;
586
587 if old_tuple.is_none() {
588 warn!("UPDATE without old tuple for relation {relation_id}, preserving UPDATE");
589 }
590
591 let mut after_properties = drasi_core::models::ElementPropertyMap::new();
594
595 for (i, column) in relation.columns.iter().enumerate() {
597 if let Some(value) = new_tuple.get(i) {
598 let json_value = value.to_json();
599 if !json_value.is_null() {
600 after_properties.insert(
601 &column.name,
602 drasi_lib::sources::manager::convert_json_to_element_value(&json_value),
603 );
604 }
605 }
606 }
607
608 let after_element = Element::Node {
609 metadata: ElementMetadata {
610 reference: ElementReference::new(&self.source_id, &element_id),
611 labels: Arc::from([Arc::from(mapping.label.as_str())]),
612 effective_from: Utc::now().timestamp_millis() as u64,
613 },
614 properties: after_properties,
615 };
616
617 Ok(Some(SourceChange::Update {
618 element: after_element,
619 }))
620 }
621
622 async fn convert_delete(
623 &self,
624 relation_id: u32,
625 old_tuple: Vec<super::types::PostgresValue>,
626 ) -> Result<Option<SourceChange>> {
627 let relation = self
628 .decoder
629 .get_relation(relation_id)
630 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
631
632 let mapping = self
633 .relations
634 .get(&relation_id)
635 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
636
637 let element_id = self.generate_element_id(relation, &old_tuple).await?;
638
639 Ok(Some(SourceChange::Delete {
640 metadata: ElementMetadata {
641 reference: ElementReference::new(&self.source_id, &element_id),
642 labels: Arc::from([Arc::from(mapping.label.as_str())]),
643 effective_from: Utc::now().timestamp_millis() as u64,
644 },
645 }))
646 }
647
648 async fn generate_element_id(
660 &self,
661 relation: &super::types::RelationInfo,
662 tuple: &[super::types::PostgresValue],
663 ) -> Result<String> {
664 let table_name = if relation.namespace == "public" {
666 relation.name.clone()
667 } else {
668 format!("{}.{}", relation.namespace, relation.name)
669 };
670
671 let primary_keys = self.table_primary_keys.read().await;
673 let pk_columns = primary_keys.get(&table_name);
674
675 let configured_keys = self
677 .config
678 .table_keys
679 .iter()
680 .find(|tk| tk.table == table_name)
681 .map(|tk| &tk.key_columns);
682
683 let key_columns = configured_keys.or(pk_columns);
685
686 if let Some(keys) = key_columns {
687 let mut key_parts = Vec::new();
688
689 for (i, column) in relation.columns.iter().enumerate() {
690 if keys.contains(&column.name) {
691 if let Some(value) = tuple.get(i) {
692 let json_val = value.to_json();
693 if !json_val.is_null() {
694 let val_str = json_val.to_string();
696 let cleaned = val_str.trim_matches('"');
697 key_parts.push(cleaned.to_string());
698 }
699 }
700 }
701 }
702
703 if !key_parts.is_empty() {
704 return Ok(format!("{}:{}", table_name, key_parts.join("_")));
706 }
707 }
708
709 warn!("No primary key value found for table '{table_name}'. Consider adding 'table_keys' configuration.");
711 Ok(format!("{}:{}", table_name, uuid::Uuid::new_v4()))
713 }
714
715 async fn send_feedback(&mut self, reply_requested: bool) -> Result<()> {
716 if let Some(conn) = &mut self.connection {
717 let status = StandbyStatusUpdate {
718 write_lsn: self.current_lsn,
719 flush_lsn: self.current_lsn,
720 apply_lsn: self.current_lsn,
721 reply_requested,
722 };
723
724 conn.send_standby_status(status).await?;
725 self.last_feedback_time = std::time::Instant::now();
726 trace!("Sent feedback with LSN: {:x}", self.current_lsn);
727 }
728
729 Ok(())
730 }
731
732 #[allow(dead_code)]
733 async fn check_stop_signal(&self) -> bool {
734 let status = self.status_handle.get_status().await;
735 status == ComponentStatus::Stopping || status == ComponentStatus::Stopped
736 }
737
738 async fn recover_connection(&mut self) -> Result<()> {
739 warn!("Attempting to recover connection");
740
741 if let Some(conn) = self.connection.take() {
743 let _ = conn.close().await;
744 }
745
746 sleep(Duration::from_secs(5)).await;
748
749 self.connect_and_setup().await?;
751
752 info!("Connection recovered successfully");
753 Ok(())
754 }
755
756 async fn shutdown(&mut self) -> Result<()> {
757 info!("Shutting down replication stream");
758
759 let _ = self.send_feedback(false).await;
761
762 if let Some(conn) = self.connection.take() {
764 conn.close().await?;
765 }
766
767 Ok(())
768 }
769}
770
771fn parse_lsn(lsn_str: &str) -> Result<u64> {
772 let parts: Vec<&str> = lsn_str.split('/').collect();
773 if parts.len() != 2 {
774 return Err(anyhow!("Invalid LSN format: {lsn_str}"));
775 }
776
777 let high = u64::from_str_radix(parts[0], 16)?;
778 let low = u64::from_str_radix(parts[1], 16)?;
779
780 Ok((high << 32) | low)
781}
782
#[cfg(test)]
mod tests {
    use chrono::Utc;
    use drasi_core::models::validate_effective_from;

    // Elements emitted by this stream stamp effective_from with
    // timestamp_millis(); the validator must accept that scale.
    #[test]
    fn effective_from_uses_milliseconds() {
        let now_ms = Utc::now().timestamp_millis() as u64;
        let outcome = validate_effective_from(now_ms);
        assert!(
            outcome.is_ok(),
            "Postgres CDC effective_from ({effective_from}) should be in millisecond range",
            effective_from = now_ms
        );
    }

    // A nanosecond-scale stamp is six orders of magnitude too large and must
    // be refused by the validator.
    #[test]
    fn effective_from_rejects_nanoseconds_pattern() {
        let now_ns = Utc::now().timestamp_nanos_opt().unwrap() as u64;
        let outcome = validate_effective_from(now_ns);
        assert!(
            outcome.is_err(),
            "Nanosecond timestamp ({bad_effective_from}) should be rejected",
            bad_effective_from = now_ns
        );
    }
}