1use anyhow::{anyhow, Result};
16use chrono::Utc;
17use log::{debug, error, info, trace, warn};
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::RwLock;
22use tokio::time::{interval, sleep};
23
24use super::connection::ReplicationConnection;
25use super::decoder::PgOutputDecoder;
26use super::protocol::BackendMessage;
27use super::types::{StandbyStatusUpdate, WalMessage};
28use super::PostgresSourceConfig;
29use drasi_core::models::{Element, ElementMetadata, ElementReference, SourceChange};
30use drasi_lib::channels::{ComponentEventSender, ComponentStatus, SourceEvent, SourceEventWrapper};
31use drasi_lib::sources::base::SourceBase;
32
33pub struct ReplicationStream {
34 config: PostgresSourceConfig,
35 source_id: String,
36 connection: Option<ReplicationConnection>,
37 decoder: PgOutputDecoder,
38 dispatchers: Arc<
39 RwLock<
40 Vec<Box<dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync>>,
41 >,
42 >,
43 #[allow(dead_code)]
44 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
45 status: Arc<RwLock<ComponentStatus>>,
46 current_lsn: u64,
47 last_feedback_time: std::time::Instant,
48 pending_transaction: Option<Vec<SourceChange>>,
49 relations: HashMap<u32, RelationMapping>,
50 table_primary_keys: Arc<RwLock<HashMap<String, Vec<String>>>>,
51}
52
/// Naming information captured from a pgoutput Relation message.
struct RelationMapping {
    // Recorded for diagnostics; not read yet.
    #[allow(dead_code)]
    table_name: String,
    // Recorded for diagnostics; not read yet.
    #[allow(dead_code)]
    schema_name: String,
    /// Label applied to every element produced from this relation.
    label: String,
}
60
61impl ReplicationStream {
62 pub fn new(
63 config: PostgresSourceConfig,
64 source_id: String,
65 dispatchers: Arc<
66 RwLock<
67 Vec<
68 Box<
69 dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync,
70 >,
71 >,
72 >,
73 >,
74 status_tx: Arc<RwLock<Option<ComponentEventSender>>>,
75 status: Arc<RwLock<ComponentStatus>>,
76 ) -> Self {
77 Self {
78 config,
79 source_id,
80 connection: None,
81 decoder: PgOutputDecoder::new(),
82 dispatchers,
83 status_tx,
84 status,
85 current_lsn: 0,
86 last_feedback_time: std::time::Instant::now(),
87 pending_transaction: None,
88 relations: HashMap::new(),
89 table_primary_keys: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92
93 pub async fn run(&mut self) -> Result<()> {
98 info!("Starting replication stream for source {}", self.source_id);
99
100 self.connect_and_setup().await?;
102
103 let mut keepalive_interval = interval(Duration::from_secs(10));
105
106 loop {
107 {
109 let status = self.status.read().await;
110 if *status == ComponentStatus::Stopping || *status == ComponentStatus::Stopped {
111 info!("Received stop signal, shutting down replication");
112 break;
113 }
114 }
115
116 tokio::select! {
117 result = self.read_next_message() => {
119 match result {
120 Ok(Some(msg)) => {
121 if let Err(e) = self.handle_message(msg).await {
122 error!("Error handling message: {e}");
123 if let Err(e) = self.recover_connection().await {
125 error!("Failed to recover connection: {e}");
126 return Err(e);
127 }
128 }
129 }
130 Ok(None) => {
131 }
133 Err(e) => {
134 error!("Error reading message: {e}");
135 if let Err(e) = self.recover_connection().await {
137 error!("Failed to recover connection: {e}");
138 return Err(e);
139 }
140 }
141 }
142 }
143
144 _ = keepalive_interval.tick() => {
146 if let Err(e) = self.send_feedback(false).await {
147 warn!("Failed to send keepalive: {e}");
148 }
149 }
150 }
151 }
152
153 self.shutdown().await?;
155 Ok(())
156 }
157
158 async fn connect_and_setup(&mut self) -> Result<()> {
159 info!("Connecting to PostgreSQL for replication");
160
161 let mut conn = ReplicationConnection::connect(
163 &self.config.host,
164 self.config.port,
165 &self.config.database,
166 &self.config.user,
167 &self.config.password,
168 )
169 .await?;
170
171 let system_info = conn.identify_system().await?;
173 info!("Connected to PostgreSQL system: {system_info:?}");
174
175 let slot_info = conn
177 .create_replication_slot(&self.config.slot_name, false)
178 .await?;
179 info!("Using replication slot: {slot_info:?}");
180
181 if !slot_info.consistent_point.is_empty() && slot_info.consistent_point != "0/0" {
183 self.current_lsn = parse_lsn(&slot_info.consistent_point)?;
184 } else {
185 self.current_lsn = 0;
187 }
188
189 let mut options = HashMap::new();
191 options.insert("proto_version".to_string(), "1".to_string());
192 options.insert(
193 "publication_names".to_string(),
194 self.config.publication_name.clone(),
195 );
196
197 conn.start_replication(&self.config.slot_name, Some(self.current_lsn), options)
199 .await?;
200
201 self.connection = Some(conn);
202 info!("Replication started from LSN: {:x}", self.current_lsn);
203
204 Ok(())
205 }
206
207 async fn read_next_message(&mut self) -> Result<Option<BackendMessage>> {
208 if let Some(conn) = &mut self.connection {
209 match tokio::time::timeout(Duration::from_millis(100), conn.read_replication_message())
211 .await
212 {
213 Ok(Ok(msg)) => Ok(Some(msg)),
214 Ok(Err(e)) => Err(e),
215 Err(_) => Ok(None), }
217 } else {
218 Err(anyhow!("No connection available"))
219 }
220 }
221
222 async fn handle_message(&mut self, msg: BackendMessage) -> Result<()> {
223 match msg {
224 BackendMessage::CopyData(data) => {
225 self.handle_copy_data(&data).await?;
226 }
227 BackendMessage::PrimaryKeepaliveMessage {
228 wal_end,
229 timestamp: _,
230 reply,
231 } => {
232 self.current_lsn = wal_end;
233 if reply == 1 {
234 self.send_feedback(true).await?;
235 }
236 }
237 BackendMessage::ErrorResponse(err) => {
238 error!("Server error: {}", err.message);
239 return Err(anyhow!("Server error: {}", err.message));
240 }
241 _ => {
242 trace!("Ignoring message: {msg:?}");
243 }
244 }
245 Ok(())
246 }
247
248 async fn handle_copy_data(&mut self, data: &[u8]) -> Result<()> {
249 if data.is_empty() {
250 return Ok(());
251 }
252
253 let msg_type = data[0];
255
256 match msg_type {
257 b'w' => {
258 self.handle_xlog_data(&data[1..]).await?;
260 }
261 b'k' => {
262 self.handle_keepalive(&data[1..]).await?;
264 }
265 _ => {
266 warn!("Unknown copy data message type: 0x{msg_type:02x}");
267 }
268 }
269
270 Ok(())
271 }
272
273 async fn handle_xlog_data(&mut self, data: &[u8]) -> Result<()> {
274 if data.len() < 24 {
275 return Err(anyhow!("XLogData message too short: {} bytes", data.len()));
276 }
277
278 let _start_lsn = u64::from_be_bytes([
280 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
281 ]);
282 let end_lsn = u64::from_be_bytes([
283 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
284 ]);
285 let _timestamp = i64::from_be_bytes([
286 data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
287 ]);
288
289 self.current_lsn = end_lsn;
291
292 let wal_data = &data[24..];
294
295 if !wal_data.is_empty() {
297 let msg_type = wal_data[0];
298 debug!(
299 "Attempting to decode WAL message type: {} ({}), data length: {}",
300 msg_type as char,
301 msg_type,
302 wal_data.len()
303 );
304 }
305
306 match self.decoder.decode_message(wal_data) {
307 Ok(Some(wal_msg)) => {
308 self.process_wal_message(wal_msg).await?;
309 }
310 Ok(None) => {
311 }
313 Err(e) => {
314 if !wal_data.is_empty() {
316 debug!(
317 "Failed to decode WAL message type {} ({}): {}, data length: {}",
318 wal_data[0] as char,
319 wal_data[0],
320 e,
321 wal_data.len()
322 );
323 }
324 }
326 }
327
328 if self.last_feedback_time.elapsed() > Duration::from_secs(5) {
330 self.send_feedback(false).await?;
331 }
332
333 Ok(())
334 }
335
336 async fn handle_keepalive(&mut self, data: &[u8]) -> Result<()> {
337 if data.len() < 17 {
338 return Err(anyhow!("Keepalive message too short"));
339 }
340
341 let wal_end = u64::from_be_bytes([
342 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
343 ]);
344 let reply = data[16];
345
346 self.current_lsn = wal_end;
347
348 if reply == 1 {
349 self.send_feedback(true).await?;
350 }
351
352 Ok(())
353 }
354
355 async fn process_wal_message(&mut self, msg: WalMessage) -> Result<()> {
356 match msg {
357 WalMessage::Begin(_) => {
358 self.pending_transaction = Some(Vec::new());
360 }
361 WalMessage::Commit(tx_info) => {
362 if let Some(changes) = self.pending_transaction.take() {
364 for change in changes {
366 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
368 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
369
370 let wrapper = SourceEventWrapper::with_profiling(
371 self.source_id.clone(),
372 SourceEvent::Change(change),
373 chrono::Utc::now(),
374 profiling,
375 );
376
377 if let Err(e) = SourceBase::dispatch_from_task(
379 self.dispatchers.clone(),
380 wrapper.clone(),
381 &self.source_id,
382 )
383 .await
384 {
385 debug!(
386 "[{}] Failed to dispatch change (no subscribers): {}",
387 self.source_id, e
388 );
389 }
390 }
391 debug!(
392 "Committed transaction {} with LSN {:x}",
393 tx_info.xid, tx_info.commit_lsn
394 );
395 }
396 }
397 WalMessage::Relation(relation) => {
398 let label = relation.name.clone();
401 self.relations.insert(
402 relation.id,
403 RelationMapping {
404 table_name: relation.name.clone(),
405 schema_name: relation.namespace.clone(),
406 label,
407 },
408 );
409
410 }
413 WalMessage::Insert { relation_id, tuple } => {
414 if let Some(change) = self.convert_insert(relation_id, tuple).await? {
415 if let Some(tx) = &mut self.pending_transaction {
416 tx.push(change);
417 } else {
418 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
420 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
421
422 let wrapper = SourceEventWrapper::with_profiling(
423 self.source_id.clone(),
424 SourceEvent::Change(change),
425 chrono::Utc::now(),
426 profiling,
427 );
428
429 if let Err(e) = SourceBase::dispatch_from_task(
431 self.dispatchers.clone(),
432 wrapper.clone(),
433 &self.source_id,
434 )
435 .await
436 {
437 debug!(
438 "[{}] Failed to dispatch change (no subscribers): {}",
439 self.source_id, e
440 );
441 }
442 }
443 }
444 }
445 WalMessage::Update {
446 relation_id,
447 old_tuple,
448 new_tuple,
449 } => {
450 if let Some(change) = self
451 .convert_update(relation_id, old_tuple, new_tuple)
452 .await?
453 {
454 if let Some(tx) = &mut self.pending_transaction {
455 tx.push(change);
456 } else {
457 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
458 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
459
460 let wrapper = SourceEventWrapper::with_profiling(
461 self.source_id.clone(),
462 SourceEvent::Change(change),
463 chrono::Utc::now(),
464 profiling,
465 );
466
467 if let Err(e) = SourceBase::dispatch_from_task(
469 self.dispatchers.clone(),
470 wrapper.clone(),
471 &self.source_id,
472 )
473 .await
474 {
475 debug!(
476 "[{}] Failed to dispatch change (no subscribers): {}",
477 self.source_id, e
478 );
479 }
480 }
481 }
482 }
483 WalMessage::Delete {
484 relation_id,
485 old_tuple,
486 } => {
487 if let Some(change) = self.convert_delete(relation_id, old_tuple).await? {
488 if let Some(tx) = &mut self.pending_transaction {
489 tx.push(change);
490 } else {
491 let mut profiling = drasi_lib::profiling::ProfilingMetadata::new();
492 profiling.source_send_ns = Some(drasi_lib::profiling::timestamp_ns());
493
494 let wrapper = SourceEventWrapper::with_profiling(
495 self.source_id.clone(),
496 SourceEvent::Change(change),
497 chrono::Utc::now(),
498 profiling,
499 );
500
501 if let Err(e) = SourceBase::dispatch_from_task(
503 self.dispatchers.clone(),
504 wrapper.clone(),
505 &self.source_id,
506 )
507 .await
508 {
509 debug!(
510 "[{}] Failed to dispatch change (no subscribers): {}",
511 self.source_id, e
512 );
513 }
514 }
515 }
516 }
517 WalMessage::Truncate { relation_ids } => {
518 warn!("Truncate not yet implemented for relations: {relation_ids:?}");
519 }
520 }
521 Ok(())
522 }
523
524 async fn convert_insert(
525 &self,
526 relation_id: u32,
527 tuple: Vec<super::types::PostgresValue>,
528 ) -> Result<Option<SourceChange>> {
529 let relation = self
531 .decoder
532 .get_relation(relation_id)
533 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
534
535 let mapping = self
536 .relations
537 .get(&relation_id)
538 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
539
540 let mut properties = drasi_core::models::ElementPropertyMap::new();
542 for (i, value) in tuple.iter().enumerate() {
543 if let Some(column) = relation.columns.get(i) {
544 let json_value = value.to_json();
545 if !json_value.is_null() {
546 properties.insert(
547 &column.name,
548 drasi_lib::sources::manager::convert_json_to_element_value(&json_value)?,
549 );
550 }
551 }
552 }
553
554 let element_id = self.generate_element_id(relation, &tuple).await?;
556
557 let element = Element::Node {
559 metadata: ElementMetadata {
560 reference: ElementReference::new(&self.source_id, &element_id),
561 labels: Arc::from([Arc::from(mapping.label.as_str())]),
562 effective_from: Utc::now().timestamp_millis() as u64,
563 },
564 properties,
565 };
566
567 Ok(Some(SourceChange::Insert { element }))
568 }
569
570 async fn convert_update(
571 &self,
572 relation_id: u32,
573 old_tuple: Option<Vec<super::types::PostgresValue>>,
574 new_tuple: Vec<super::types::PostgresValue>,
575 ) -> Result<Option<SourceChange>> {
576 let relation = self
577 .decoder
578 .get_relation(relation_id)
579 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
580
581 let mapping = self
582 .relations
583 .get(&relation_id)
584 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
585
586 let element_id = self.generate_element_id(relation, &new_tuple).await?;
588
589 let Some(_old_tuple) = old_tuple else {
591 warn!("UPDATE without old tuple for relation {relation_id}, treating as INSERT");
592 return self.convert_insert(relation_id, new_tuple).await;
593 };
594
595 let mut after_properties = drasi_core::models::ElementPropertyMap::new();
598
599 for (i, column) in relation.columns.iter().enumerate() {
601 if let Some(value) = new_tuple.get(i) {
602 let json_value = value.to_json();
603 if !json_value.is_null() {
604 after_properties.insert(
605 &column.name,
606 drasi_lib::sources::manager::convert_json_to_element_value(&json_value)?,
607 );
608 }
609 }
610 }
611
612 let after_element = Element::Node {
617 metadata: ElementMetadata {
618 reference: ElementReference::new(&self.source_id, &element_id),
619 labels: Arc::from([Arc::from(mapping.label.as_str())]),
620 effective_from: Utc::now().timestamp_millis() as u64,
621 },
622 properties: after_properties,
623 };
624
625 Ok(Some(SourceChange::Update {
626 element: after_element,
627 }))
628 }
629
630 async fn convert_delete(
631 &self,
632 relation_id: u32,
633 old_tuple: Vec<super::types::PostgresValue>,
634 ) -> Result<Option<SourceChange>> {
635 let relation = self
636 .decoder
637 .get_relation(relation_id)
638 .ok_or_else(|| anyhow!("Unknown relation {relation_id}"))?;
639
640 let mapping = self
641 .relations
642 .get(&relation_id)
643 .ok_or_else(|| anyhow!("No mapping for relation {relation_id}"))?;
644
645 let element_id = self.generate_element_id(relation, &old_tuple).await?;
646
647 Ok(Some(SourceChange::Delete {
648 metadata: ElementMetadata {
649 reference: ElementReference::new(&self.source_id, &element_id),
650 labels: Arc::from([Arc::from(mapping.label.as_str())]),
651 effective_from: Utc::now().timestamp_millis() as u64,
652 },
653 }))
654 }
655
656 async fn generate_element_id(
668 &self,
669 relation: &super::types::RelationInfo,
670 tuple: &[super::types::PostgresValue],
671 ) -> Result<String> {
672 let table_name = if relation.namespace == "public" {
674 relation.name.clone()
675 } else {
676 format!("{}.{}", relation.namespace, relation.name)
677 };
678
679 let primary_keys = self.table_primary_keys.read().await;
681 let pk_columns = primary_keys.get(&table_name);
682
683 let configured_keys = self
685 .config
686 .table_keys
687 .iter()
688 .find(|tk| tk.table == table_name)
689 .map(|tk| &tk.key_columns);
690
691 let key_columns = configured_keys.or(pk_columns);
693
694 if let Some(keys) = key_columns {
695 let mut key_parts = Vec::new();
696
697 for (i, column) in relation.columns.iter().enumerate() {
698 if keys.contains(&column.name) {
699 if let Some(value) = tuple.get(i) {
700 let json_val = value.to_json();
701 if !json_val.is_null() {
702 let val_str = json_val.to_string();
704 let cleaned = val_str.trim_matches('"');
705 key_parts.push(cleaned.to_string());
706 }
707 }
708 }
709 }
710
711 if !key_parts.is_empty() {
712 return Ok(format!("{}:{}", table_name, key_parts.join("_")));
714 }
715 }
716
717 warn!("No primary key value found for table '{table_name}'. Consider adding 'table_keys' configuration.");
719 Ok(format!("{}:{}", table_name, uuid::Uuid::new_v4()))
721 }
722
723 async fn send_feedback(&mut self, reply_requested: bool) -> Result<()> {
724 if let Some(conn) = &mut self.connection {
725 let status = StandbyStatusUpdate {
726 write_lsn: self.current_lsn,
727 flush_lsn: self.current_lsn,
728 apply_lsn: self.current_lsn,
729 reply_requested,
730 };
731
732 conn.send_standby_status(status).await?;
733 self.last_feedback_time = std::time::Instant::now();
734 trace!("Sent feedback with LSN: {:x}", self.current_lsn);
735 }
736
737 Ok(())
738 }
739
740 #[allow(dead_code)]
741 async fn check_stop_signal(&self) -> bool {
742 let status = self.status.read().await;
743 *status == ComponentStatus::Stopping || *status == ComponentStatus::Stopped
744 }
745
746 async fn recover_connection(&mut self) -> Result<()> {
747 warn!("Attempting to recover connection");
748
749 if let Some(conn) = self.connection.take() {
751 let _ = conn.close().await;
752 }
753
754 sleep(Duration::from_secs(5)).await;
756
757 self.connect_and_setup().await?;
759
760 info!("Connection recovered successfully");
761 Ok(())
762 }
763
764 async fn shutdown(&mut self) -> Result<()> {
765 info!("Shutting down replication stream");
766
767 let _ = self.send_feedback(false).await;
769
770 if let Some(conn) = self.connection.take() {
772 conn.close().await?;
773 }
774
775 Ok(())
776 }
777}
778
779fn parse_lsn(lsn_str: &str) -> Result<u64> {
780 let parts: Vec<&str> = lsn_str.split('/').collect();
781 if parts.len() != 2 {
782 return Err(anyhow!("Invalid LSN format: {lsn_str}"));
783 }
784
785 let high = u64::from_str_radix(parts[0], 16)?;
786 let low = u64::from_str_radix(parts[1], 16)?;
787
788 Ok((high << 32) | low)
789}