1#![allow(unexpected_cfgs)]
16
17pub mod config;
173pub mod connection;
174pub mod decoder;
175pub mod descriptor;
176pub mod protocol;
177pub mod scram;
178pub mod stream;
179pub mod types;
180
181pub use config::{PostgresSourceConfig, SslMode, TableKeyConfig};
182
183use anyhow::{anyhow, Result};
184use async_trait::async_trait;
185use bytes::Bytes;
186use drasi_lib::schema::{
187 normalize_table_label, NodeSchema, PropertySchema, PropertyType, SourceSchema,
188};
189use log::{debug, error, info};
190use postgres_native_tls::MakeTlsConnector;
191use std::collections::HashMap;
192use std::sync::atomic::{AtomicU64, Ordering};
193use std::sync::Arc;
194use tokio::sync::{oneshot, Mutex};
195
196use drasi_lib::channels::{ComponentStatus, DispatchMode, *};
197use drasi_lib::sources::base::{SourceBase, SourceBaseParams};
198use drasi_lib::{Source, SourceError};
199use tracing::Instrument;
200
/// Replay bookkeeping shared between the source and its replication task.
///
/// All fields are atomics so they can be read and written through an `Arc` without locking.
pub(crate) struct ReplayState {
    /// Latest LSN recorded as read. It is written outside this file (presumably by the stream
    /// module — TODO confirm). `subscribe` treats `0` as "nothing read yet".
    pub(crate) read_lsn: AtomicU64,
    /// Active flush fence LSN, or `u64::MAX` when no fence is installed.
    pub(crate) flush_fence_lsn: AtomicU64,
    /// Unix time in seconds at which the current fence was installed. It is used to expire stale
    /// fences after [`FENCE_TIMEOUT_SECS`].
    pub(crate) fence_set_epoch_secs: AtomicU64,
}

/// Maximum age, in seconds, of a flush fence before `effective_flush_fence` discards it.
const FENCE_TIMEOUT_SECS: u64 = 60;

impl Default for ReplayState {
    /// Starts with nothing read, no fence installed, and a zero fence timestamp.
    fn default() -> Self {
        Self {
            read_lsn: AtomicU64::new(0),
            flush_fence_lsn: AtomicU64::new(ReplayState::NO_FENCE),
            fence_set_epoch_secs: AtomicU64::new(0),
        }
    }
}

impl ReplayState {
    /// Sentinel stored in `flush_fence_lsn` when no fence is active.
    const NO_FENCE: u64 = u64::MAX;

    /// Current wall-clock time in whole seconds since the Unix epoch. A clock set before the
    /// epoch yields `0`.
    fn now_epoch_secs() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Returns the last LSN recorded in `read_lsn`.
    fn current_read_lsn(&self) -> u64 {
        self.read_lsn.load(Ordering::Acquire)
    }

    /// Installs a flush fence at `lsn` and stamps it with the current time.
    fn set_flush_fence(&self, lsn: u64) {
        // The timestamp must be published *before* the fence. A reader that sees the new fence
        // through the Acquire load in `effective_flush_fence` is then guaranteed to also see this
        // timestamp. Otherwise it could pair the fresh fence with a stale (or zero) timestamp and
        // expire it immediately.
        self.fence_set_epoch_secs
            .store(Self::now_epoch_secs(), Ordering::Release);
        self.flush_fence_lsn.store(lsn, Ordering::Release);
    }

    /// Removes any active flush fence.
    fn clear_flush_fence(&self) {
        self.flush_fence_lsn.store(Self::NO_FENCE, Ordering::Release);
    }

    /// Returns the active flush fence, or `u64::MAX` when none is set.
    ///
    /// A fence older than [`FENCE_TIMEOUT_SECS`] is treated as expired: it is cleared and
    /// `u64::MAX` is returned.
    fn effective_flush_fence(&self) -> u64 {
        let fence = self.flush_fence_lsn.load(Ordering::Acquire);
        if fence == Self::NO_FENCE {
            return Self::NO_FENCE;
        }

        let set_secs = self.fence_set_epoch_secs.load(Ordering::Acquire);
        if Self::now_epoch_secs().saturating_sub(set_secs) <= FENCE_TIMEOUT_SECS {
            return fence;
        }

        // Expired. Clear it only if it is still the fence we examined, so that a fence installed
        // concurrently by `set_flush_fence` is not wiped out by a blind store.
        match self.flush_fence_lsn.compare_exchange(
            fence,
            Self::NO_FENCE,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => Self::NO_FENCE,
            // The fence changed after we loaded it. `set_flush_fence` stamps the time before
            // publishing, so the newer value is current.
            Err(current) => current,
        }
    }
}
283
/// PostgreSQL source that streams table changes through a replication slot and publication
/// (see `PostgresSourceConfig::slot_name` / `publication_name`).
pub struct PostgresReplicationSource {
    /// Shared source plumbing: id, status, replication task handle, dispatch settings and
    /// bootstrap provider.
    base: SourceBase,
    /// Connection and replication settings supplied at construction.
    config: PostgresSourceConfig,
    /// Schema discovered by introspection in `start`; reset to `None` in `stop`.
    cached_schema: Arc<std::sync::RwLock<Option<SourceSchema>>>,
    /// Read position and flush-fence state shared with the spawned replication task.
    replay_state: Arc<ReplayState>,
    /// Async mutex held for the whole of `subscribe` so replay restarts do not interleave.
    subscribe_lock: Mutex<()>,
}
306
307fn postgres_type_to_property_type(data_type: &str) -> Option<PropertyType> {
308 match data_type {
309 "smallint" | "integer" | "bigint" => Some(PropertyType::Integer),
310 "real" | "double precision" | "numeric" | "decimal" => Some(PropertyType::Float),
311 "boolean" => Some(PropertyType::Boolean),
312 "timestamp without time zone"
313 | "timestamp with time zone"
314 | "date"
315 | "time without time zone"
316 | "time with time zone" => Some(PropertyType::Timestamp),
317 "json" | "jsonb" => Some(PropertyType::Json),
318 "character" | "character varying" | "text" | "uuid" | "bytea" => Some(PropertyType::String),
319 _ => None,
320 }
321}
322
323async fn introspect_postgres_schema(config: &PostgresSourceConfig) -> Result<Option<SourceSchema>> {
324 if config.tables.is_empty() {
325 return Ok(None);
326 }
327
328 let mut pg_config = tokio_postgres::Config::new();
329 pg_config.host(&config.host);
330 pg_config.port(config.port);
331 pg_config.dbname(&config.database);
332 pg_config.user(&config.user);
333 if !config.password.is_empty() {
334 pg_config.password(&config.password);
335 }
336
337 let client = match config.ssl_mode {
338 SslMode::Require => {
339 pg_config.ssl_mode(tokio_postgres::config::SslMode::Require);
340 let tls_connector = native_tls::TlsConnector::builder()
341 .danger_accept_invalid_hostnames(false)
342 .danger_accept_invalid_certs(false)
343 .build()
344 .map_err(|e| anyhow!("Failed to create TLS connector: {e}"))?;
345 let connector = MakeTlsConnector::new(tls_connector);
346
347 debug!("Schema introspection: connecting with SSL (require)");
348 let (client, connection) = pg_config.connect(connector).await?;
349 tokio::spawn(async move {
350 if let Err(e) = connection.await {
351 log::warn!("PostgreSQL schema introspection connection closed: {e}");
352 }
353 });
354 client
355 }
356 SslMode::Prefer => {
357 let tls_connector = native_tls::TlsConnector::builder()
359 .danger_accept_invalid_hostnames(false)
360 .danger_accept_invalid_certs(false)
361 .build()
362 .map_err(|e| anyhow!("Failed to create TLS connector: {e}"))?;
363 let connector = MakeTlsConnector::new(tls_connector);
364
365 pg_config.ssl_mode(tokio_postgres::config::SslMode::Prefer);
366 debug!("Schema introspection: connecting with SSL (prefer)");
367 let (client, connection) = pg_config.connect(connector).await?;
368 tokio::spawn(async move {
369 if let Err(e) = connection.await {
370 log::warn!("PostgreSQL schema introspection connection closed: {e}");
371 }
372 });
373 client
374 }
375 SslMode::Disable => {
376 debug!("Schema introspection: connecting without SSL");
377 let (client, connection) = pg_config.connect(tokio_postgres::NoTls).await?;
378 tokio::spawn(async move {
379 if let Err(e) = connection.await {
380 log::warn!("PostgreSQL schema introspection connection closed: {e}");
381 }
382 });
383 client
384 }
385 };
386
387 let mut nodes = Vec::new();
388
389 for table in &config.tables {
390 let (schema_name, table_name) = table
391 .split_once('.')
392 .map(|(schema, name)| (schema.to_string(), name.to_string()))
393 .unwrap_or_else(|| ("public".to_string(), table.to_string()));
394
395 let rows = client
396 .query(
397 "SELECT column_name, data_type \
398 FROM information_schema.columns \
399 WHERE table_schema = $1 AND table_name = $2 \
400 ORDER BY ordinal_position",
401 &[&schema_name, &table_name],
402 )
403 .await?;
404
405 let properties = rows
406 .into_iter()
407 .map(|row| PropertySchema {
408 name: row.get::<_, String>(0),
409 data_type: postgres_type_to_property_type(&row.get::<_, String>(1)),
410 description: None,
411 })
412 .collect();
413
414 nodes.push(NodeSchema {
415 label: normalize_table_label(&table_name),
416 properties,
417 });
418 }
419
420 Ok(Some(SourceSchema {
421 nodes,
422 relations: Vec::new(),
423 }))
424}
425
426impl PostgresReplicationSource {
427 pub fn builder(id: impl Into<String>) -> PostgresSourceBuilder {
446 PostgresSourceBuilder::new(id)
447 }
448
449 pub fn new(id: impl Into<String>, config: PostgresSourceConfig) -> Result<Self> {
491 let id = id.into();
492 let params = SourceBaseParams::new(id);
493 Ok(Self {
494 base: SourceBase::new(params)?,
495 config,
496 cached_schema: Arc::new(std::sync::RwLock::new(None)),
497 replay_state: Arc::new(ReplayState::default()),
498 subscribe_lock: Mutex::new(()),
499 })
500 }
501
502 pub fn with_dispatch(
507 id: impl Into<String>,
508 config: PostgresSourceConfig,
509 dispatch_mode: Option<DispatchMode>,
510 dispatch_buffer_capacity: Option<usize>,
511 ) -> Result<Self> {
512 let id = id.into();
513 let mut params = SourceBaseParams::new(id);
514 if let Some(mode) = dispatch_mode {
515 params = params.with_dispatch_mode(mode);
516 }
517 if let Some(capacity) = dispatch_buffer_capacity {
518 params = params.with_dispatch_buffer_capacity(capacity);
519 }
520 Ok(Self {
521 base: SourceBase::new(params)?,
522 config,
523 cached_schema: Arc::new(std::sync::RwLock::new(None)),
524 replay_state: Arc::new(ReplayState::default()),
525 subscribe_lock: Mutex::new(()),
526 })
527 }
528}
529
530impl PostgresReplicationSource {
531 async fn spawn_replication_task(&self, start_lsn: Option<u64>) -> Result<()> {
535 let config = self.config.clone();
536 let source_id = self.base.id.clone();
537 let reporter = self.base.status_handle();
538 let base = self.base.clone_shared();
539 let replay_state = self.replay_state.clone();
540 let (ready_tx, ready_rx) = oneshot::channel::<std::result::Result<(), String>>();
541
542 let instance_id = self
543 .base
544 .context()
545 .await
546 .map(|c| c.instance_id)
547 .unwrap_or_default();
548
549 let source_id_for_span = source_id.clone();
550 let span = tracing::info_span!(
551 "postgres_replication_task",
552 instance_id = %instance_id,
553 component_id = %source_id_for_span,
554 component_type = "source",
555 start_lsn = ?start_lsn
556 );
557
558 let task = tokio::spawn(
559 async move {
560 info!("Starting replication for source {source_id}");
561
562 let mut stream = stream::ReplicationStream::new(
563 config,
564 source_id.clone(),
565 reporter.clone(),
566 base,
567 replay_state,
568 start_lsn,
569 );
570
571 if let Err(e) = stream.run(Some(ready_tx)).await {
572 error!("Replication task failed for {source_id}: {e}");
573 reporter
574 .set_status(
575 ComponentStatus::Error,
576 Some(format!("Replication failed: {e}")),
577 )
578 .await;
579 }
580 }
581 .instrument(span),
582 );
583
584 *self.base.task_handle.write().await = Some(task);
585
586 match ready_rx.await {
587 Ok(Ok(())) => Ok(()),
588 Ok(Err(message)) => {
589 let _ = self.base.task_handle.write().await.take();
590 Err(anyhow!(
591 "Failed to establish PostgreSQL replication: {message}"
592 ))
593 }
594 Err(_) => {
595 let _ = self.base.task_handle.write().await.take();
596 Err(anyhow!(
597 "PostgreSQL replication task exited before confirming startup"
598 ))
599 }
600 }
601 }
602
603 async fn abort_replication_task(&self) {
604 if let Some(task) = self.base.task_handle.write().await.take() {
605 task.abort();
606 let _ = task.await;
607 }
608 }
609
610 async fn pause_replication_for_restart(&self, start_lsn: u64) {
611 info!(
612 "Pausing PostgreSQL source '{}' before replay from requested LSN {:x}",
613 self.base.id, start_lsn
614 );
615
616 self.base
617 .set_status(
618 ComponentStatus::Starting,
619 Some(format!(
620 "Rewinding PostgreSQL replication to LSN {start_lsn:x}"
621 )),
622 )
623 .await;
624
625 self.abort_replication_task().await;
626
627 self.base.clear_sequence_position_map().await;
633
634 self.replay_state.set_flush_fence(start_lsn);
639 }
640
641 async fn resume_replication_from(&self, start_lsn: u64) -> Result<()> {
642 self.spawn_replication_task(Some(start_lsn)).await?;
643
644 self.base
645 .set_status(
646 ComponentStatus::Running,
647 Some(format!(
648 "PostgreSQL replication resumed from LSN {start_lsn:x}"
649 )),
650 )
651 .await;
652
653 Ok(())
654 }
655
656 async fn restart_replication_from(&self, start_lsn: u64) -> Result<()> {
657 info!(
658 "Restarting PostgreSQL source '{}' from requested LSN {:x}",
659 self.base.id, start_lsn
660 );
661
662 self.pause_replication_for_restart(start_lsn).await;
663 self.resume_replication_from(start_lsn).await
664 }
665
666 async fn get_earliest_available_lsn(&self) -> Result<u64> {
669 let mut conn = connection::ReplicationConnection::connect(
670 &self.config.host,
671 self.config.port,
672 &self.config.database,
673 &self.config.user,
674 &self.config.password,
675 )
676 .await?;
677
678 let _ = conn.identify_system().await?;
679 let slot_info = conn
680 .get_replication_slot_info(&self.config.slot_name)
681 .await?;
682 let _ = conn.close().await;
683
684 let lsn_str = slot_info
687 .restart_lsn
688 .as_deref()
689 .unwrap_or(&slot_info.consistent_point);
690
691 if lsn_str.is_empty() || lsn_str == "0/0" {
692 Ok(0)
693 } else {
694 connection::parse_lsn(lsn_str)
695 }
696 }
697}
698
699#[async_trait]
700impl Source for PostgresReplicationSource {
701 fn id(&self) -> &str {
702 &self.base.id
703 }
704
705 fn type_name(&self) -> &str {
706 "postgres"
707 }
708
709 fn properties(&self) -> HashMap<String, serde_json::Value> {
710 use crate::descriptor::PostgresSourceConfigDto;
711
712 self.base
713 .properties_or_serialize(&PostgresSourceConfigDto::from(&self.config))
714 }
715
716 fn auto_start(&self) -> bool {
717 self.base.get_auto_start()
718 }
719
720 fn describe_schema(&self) -> Option<SourceSchema> {
721 self.cached_schema
722 .read()
723 .ok()
724 .and_then(|schema| schema.clone())
725 .or_else(|| {
726 if self.config.tables.is_empty() {
727 None
728 } else {
729 Some(SourceSchema {
730 nodes: self
731 .config
732 .tables
733 .iter()
734 .map(|table| NodeSchema::new(normalize_table_label(table)))
735 .collect(),
736 relations: Vec::new(),
737 })
738 }
739 })
740 }
741
742 async fn start(&self) -> Result<()> {
743 if self.base.get_status().await == ComponentStatus::Running {
744 return Ok(());
745 }
746
747 self.base.set_status(ComponentStatus::Starting, None).await;
748 info!("Starting PostgreSQL replication source: {}", self.base.id);
749
750 match introspect_postgres_schema(&self.config).await {
751 Ok(Some(schema)) => {
752 if let Ok(mut cached) = self.cached_schema.write() {
753 *cached = Some(schema);
754 }
755 }
756 Ok(None) => {}
757 Err(e) => {
758 log::warn!(
759 "Failed to introspect PostgreSQL schema for '{}': {e}",
760 self.base.id
761 );
762 }
763 }
764
765 self.spawn_replication_task(None).await?;
766 self.base
767 .set_status(
768 ComponentStatus::Running,
769 Some("PostgreSQL replication started".to_string()),
770 )
771 .await;
772
773 Ok(())
774 }
775
776 async fn stop(&self) -> Result<()> {
777 if self.base.get_status().await != ComponentStatus::Running {
778 return Ok(());
779 }
780
781 info!("Stopping PostgreSQL replication source: {}", self.base.id);
782
783 self.base.set_status(ComponentStatus::Stopping, None).await;
784
785 self.abort_replication_task().await;
786
787 if let Ok(mut cached) = self.cached_schema.write() {
789 *cached = None;
790 }
791
792 self.base
793 .set_status(
794 ComponentStatus::Stopped,
795 Some("PostgreSQL replication stopped".to_string()),
796 )
797 .await;
798
799 Ok(())
800 }
801
802 async fn status(&self) -> ComponentStatus {
803 self.base.get_status().await
804 }
805
806 async fn subscribe(
807 &self,
808 settings: drasi_lib::config::SourceSubscriptionSettings,
809 ) -> Result<SubscriptionResponse> {
810 let _guard = self.subscribe_lock.lock().await;
813
814 let mut restart_from = None;
815 let mut pause_before_subscribe = false;
816
817 if let Some(ref resume_bytes) = settings.resume_from {
818 let resume_lsn = if resume_bytes.len() == 8 {
820 let bytes: [u8; 8] = resume_bytes[..8]
821 .try_into()
822 .expect("length already verified as 8");
823 u64::from_be_bytes(bytes)
824 } else {
825 return Err(anyhow!(
826 "Invalid resume_from position: expected 8 bytes, got {}",
827 resume_bytes.len()
828 ));
829 };
830
831 let earliest_available = self.get_earliest_available_lsn().await?;
832 if resume_lsn < earliest_available {
833 return Err(SourceError::PositionUnavailable {
834 source_id: self.base.id.clone(),
835 requested: resume_bytes.clone(),
836 earliest_available: Some(Bytes::from(
837 earliest_available.to_be_bytes().to_vec(),
838 )),
839 }
840 .into());
841 }
842
843 let read_lsn = self.replay_state.current_read_lsn();
844 let is_running = self.base.get_status().await == ComponentStatus::Running;
845
846 if !is_running || read_lsn == 0 || resume_lsn < read_lsn {
847 restart_from = Some(resume_lsn);
848 pause_before_subscribe = is_running;
849 }
850 }
851
852 if let Some(start_lsn) = restart_from.filter(|_| pause_before_subscribe) {
853 self.pause_replication_for_restart(start_lsn).await;
857 }
858
859 let response = match self
860 .base
861 .subscribe_with_bootstrap(&settings, "PostgreSQL")
862 .await
863 {
864 Ok(response) => response,
865 Err(err) => {
866 if pause_before_subscribe {
867 self.base
868 .set_status(
869 ComponentStatus::Error,
870 Some(format!("Failed to register replay subscription: {err}")),
871 )
872 .await;
873 }
874 return Err(err);
875 }
876 };
877
878 if let Some(start_lsn) = restart_from {
879 if pause_before_subscribe {
880 self.resume_replication_from(start_lsn).await?;
881 } else {
882 self.restart_replication_from(start_lsn).await?;
883 }
884 }
885
886 Ok(response)
887 }
888
889 fn as_any(&self) -> &dyn std::any::Any {
890 self
891 }
892
893 async fn initialize(&self, context: drasi_lib::context::SourceRuntimeContext) {
894 self.base.initialize(context).await;
895 self.base
897 .set_position_comparator(drasi_lib::sources::ByteLexPositionComparator)
898 .await;
899 }
900
901 async fn remove_position_handle(&self, query_id: &str) {
902 self.base.remove_position_handle(query_id).await;
903 }
904
905 async fn on_subscriptions_complete(&self) {
906 self.replay_state.clear_flush_fence();
909 }
910
911 async fn set_bootstrap_provider(
912 &self,
913 provider: Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>,
914 ) {
915 self.base.set_bootstrap_provider(provider).await;
916 }
917}
918
/// Fluent builder for [`PostgresReplicationSource`].
///
/// `PostgresSourceBuilder::new` sets the defaults: host `localhost`, port `5432`, slot
/// `drasi_slot`, publication `drasi_publication`, and auto-start enabled.
pub struct PostgresSourceBuilder {
    /// Source id passed to `SourceBaseParams`.
    id: String,
    host: String,
    port: u16,
    database: String,
    user: String,
    /// Only sent to the server (during schema introspection) when non-empty.
    password: String,
    /// Tables to replicate; `schema.table` or a bare name.
    tables: Vec<String>,
    slot_name: String,
    publication_name: String,
    ssl_mode: SslMode,
    table_keys: Vec<TableKeyConfig>,
    /// `None` keeps the `SourceBaseParams` default.
    dispatch_mode: Option<DispatchMode>,
    /// `None` keeps the `SourceBaseParams` default.
    dispatch_buffer_capacity: Option<usize>,
    bootstrap_provider: Option<Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>>,
    auto_start: bool,
}
958
959impl PostgresSourceBuilder {
960 pub fn new(id: impl Into<String>) -> Self {
962 Self {
963 id: id.into(),
964 host: "localhost".to_string(),
965 port: 5432,
966 database: String::new(),
967 user: String::new(),
968 password: String::new(),
969 tables: Vec::new(),
970 slot_name: "drasi_slot".to_string(),
971 publication_name: "drasi_publication".to_string(),
972 ssl_mode: SslMode::default(),
973 table_keys: Vec::new(),
974 dispatch_mode: None,
975 dispatch_buffer_capacity: None,
976 bootstrap_provider: None,
977 auto_start: true,
978 }
979 }
980
981 pub fn with_host(mut self, host: impl Into<String>) -> Self {
983 self.host = host.into();
984 self
985 }
986
987 pub fn with_port(mut self, port: u16) -> Self {
989 self.port = port;
990 self
991 }
992
993 pub fn with_database(mut self, database: impl Into<String>) -> Self {
995 self.database = database.into();
996 self
997 }
998
999 pub fn with_user(mut self, user: impl Into<String>) -> Self {
1001 self.user = user.into();
1002 self
1003 }
1004
1005 pub fn with_password(mut self, password: impl Into<String>) -> Self {
1007 self.password = password.into();
1008 self
1009 }
1010
1011 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
1013 self.tables = tables;
1014 self
1015 }
1016
1017 pub fn add_table(mut self, table: impl Into<String>) -> Self {
1019 self.tables.push(table.into());
1020 self
1021 }
1022
1023 pub fn with_slot_name(mut self, slot_name: impl Into<String>) -> Self {
1025 self.slot_name = slot_name.into();
1026 self
1027 }
1028
1029 pub fn with_publication_name(mut self, publication_name: impl Into<String>) -> Self {
1031 self.publication_name = publication_name.into();
1032 self
1033 }
1034
1035 pub fn with_ssl_mode(mut self, ssl_mode: SslMode) -> Self {
1037 self.ssl_mode = ssl_mode;
1038 self
1039 }
1040
1041 pub fn with_table_keys(mut self, table_keys: Vec<TableKeyConfig>) -> Self {
1043 self.table_keys = table_keys;
1044 self
1045 }
1046
1047 pub fn add_table_key(mut self, table_key: TableKeyConfig) -> Self {
1049 self.table_keys.push(table_key);
1050 self
1051 }
1052
1053 pub fn with_dispatch_mode(mut self, mode: DispatchMode) -> Self {
1055 self.dispatch_mode = Some(mode);
1056 self
1057 }
1058
1059 pub fn with_dispatch_buffer_capacity(mut self, capacity: usize) -> Self {
1061 self.dispatch_buffer_capacity = Some(capacity);
1062 self
1063 }
1064
1065 pub fn with_bootstrap_provider(
1067 mut self,
1068 provider: impl drasi_lib::bootstrap::BootstrapProvider + 'static,
1069 ) -> Self {
1070 self.bootstrap_provider = Some(Box::new(provider));
1071 self
1072 }
1073
1074 pub fn with_auto_start(mut self, auto_start: bool) -> Self {
1079 self.auto_start = auto_start;
1080 self
1081 }
1082
1083 pub fn with_config(mut self, config: PostgresSourceConfig) -> Self {
1085 self.host = config.host;
1086 self.port = config.port;
1087 self.database = config.database;
1088 self.user = config.user;
1089 self.password = config.password;
1090 self.tables = config.tables;
1091 self.slot_name = config.slot_name;
1092 self.publication_name = config.publication_name;
1093 self.ssl_mode = config.ssl_mode;
1094 self.table_keys = config.table_keys;
1095 self
1096 }
1097
1098 pub fn build(self) -> Result<PostgresReplicationSource> {
1104 let config = PostgresSourceConfig {
1105 host: self.host,
1106 port: self.port,
1107 database: self.database,
1108 user: self.user,
1109 password: self.password,
1110 tables: self.tables,
1111 slot_name: self.slot_name,
1112 publication_name: self.publication_name,
1113 ssl_mode: self.ssl_mode,
1114 table_keys: self.table_keys,
1115 };
1116
1117 let mut params = SourceBaseParams::new(&self.id).with_auto_start(self.auto_start);
1118 if let Some(mode) = self.dispatch_mode {
1119 params = params.with_dispatch_mode(mode);
1120 }
1121 if let Some(capacity) = self.dispatch_buffer_capacity {
1122 params = params.with_dispatch_buffer_capacity(capacity);
1123 }
1124 if let Some(provider) = self.bootstrap_provider {
1125 params = params.with_bootstrap_provider(provider);
1126 }
1127
1128 Ok(PostgresReplicationSource {
1129 base: SourceBase::new(params)?,
1130 config,
1131 cached_schema: Arc::new(std::sync::RwLock::new(None)),
1132 replay_state: Arc::new(ReplayState::default()),
1133 subscribe_lock: Mutex::new(()),
1134 })
1135 }
1136}
1137
1138#[cfg(test)]
1139mod tests {
1140 use super::*;
1141
1142 mod construction {
1143 use super::*;
1144
1145 #[test]
1146 fn test_builder_with_valid_config() {
1147 let source = PostgresSourceBuilder::new("test-source")
1148 .with_database("testdb")
1149 .with_user("testuser")
1150 .build();
1151 assert!(source.is_ok());
1152 }
1153
1154 #[test]
1155 fn test_builder_with_custom_config() {
1156 let source = PostgresSourceBuilder::new("pg-source")
1157 .with_host("192.168.1.100")
1158 .with_port(5433)
1159 .with_database("production")
1160 .with_user("admin")
1161 .with_password("secret")
1162 .build()
1163 .unwrap();
1164 assert_eq!(source.id(), "pg-source");
1165 }
1166
1167 #[test]
1168 fn test_with_dispatch_creates_source() {
1169 let config = PostgresSourceConfig {
1170 host: "localhost".to_string(),
1171 port: 5432,
1172 database: "testdb".to_string(),
1173 user: "testuser".to_string(),
1174 password: String::new(),
1175 tables: Vec::new(),
1176 slot_name: "drasi_slot".to_string(),
1177 publication_name: "drasi_publication".to_string(),
1178 ssl_mode: SslMode::default(),
1179 table_keys: Vec::new(),
1180 };
1181 let source = PostgresReplicationSource::with_dispatch(
1182 "dispatch-source",
1183 config,
1184 Some(DispatchMode::Channel),
1185 Some(2000),
1186 );
1187 assert!(source.is_ok());
1188 assert_eq!(source.unwrap().id(), "dispatch-source");
1189 }
1190 }
1191
1192 mod properties {
1193 use super::*;
1194
1195 #[test]
1196 fn test_id_returns_correct_value() {
1197 let source = PostgresSourceBuilder::new("my-pg-source")
1198 .with_database("db")
1199 .with_user("user")
1200 .build()
1201 .unwrap();
1202 assert_eq!(source.id(), "my-pg-source");
1203 }
1204
1205 #[test]
1206 fn test_type_name_returns_postgres() {
1207 let source = PostgresSourceBuilder::new("test")
1208 .with_database("db")
1209 .with_user("user")
1210 .build()
1211 .unwrap();
1212 assert_eq!(source.type_name(), "postgres");
1213 }
1214
1215 #[test]
1216 fn test_properties_contains_connection_info() {
1217 let source = PostgresSourceBuilder::new("test")
1218 .with_host("db.example.com")
1219 .with_port(5433)
1220 .with_database("mydb")
1221 .with_user("app_user")
1222 .with_password("secret")
1223 .with_tables(vec!["users".to_string()])
1224 .build()
1225 .unwrap();
1226 let props = source.properties();
1227
1228 assert_eq!(
1229 props.get("host"),
1230 Some(&serde_json::Value::String("db.example.com".to_string()))
1231 );
1232 assert_eq!(
1233 props.get("port"),
1234 Some(&serde_json::Value::Number(5433.into()))
1235 );
1236 assert_eq!(
1237 props.get("database"),
1238 Some(&serde_json::Value::String("mydb".to_string()))
1239 );
1240 assert_eq!(
1241 props.get("user"),
1242 Some(&serde_json::Value::String("app_user".to_string()))
1243 );
1244 }
1245
1246 #[test]
1247 fn test_properties_includes_password() {
1248 let source = PostgresSourceBuilder::new("test")
1249 .with_database("db")
1250 .with_user("user")
1251 .with_password("super_secret_password")
1252 .build()
1253 .unwrap();
1254 let props = source.properties();
1255
1256 assert_eq!(
1258 props.get("password"),
1259 Some(&serde_json::Value::String(
1260 "super_secret_password".to_string()
1261 ))
1262 );
1263 }
1264
1265 #[test]
1266 fn test_properties_includes_tables() {
1267 let source = PostgresSourceBuilder::new("test")
1268 .with_database("db")
1269 .with_user("user")
1270 .with_tables(vec!["users".to_string(), "orders".to_string()])
1271 .build()
1272 .unwrap();
1273 let props = source.properties();
1274
1275 let tables = props.get("tables").unwrap().as_array().unwrap();
1276 assert_eq!(tables.len(), 2);
1277 assert_eq!(tables[0], "users");
1278 assert_eq!(tables[1], "orders");
1279 }
1280
1281 #[test]
1282 fn test_describe_schema_falls_back_to_configured_tables() {
1283 let source = PostgresSourceBuilder::new("test")
1284 .with_database("db")
1285 .with_user("user")
1286 .with_tables(vec!["public.users".to_string(), "orders".to_string()])
1287 .build()
1288 .unwrap();
1289
1290 let schema = source
1291 .describe_schema()
1292 .expect("configured postgres tables should produce fallback schema");
1293
1294 assert_eq!(schema.nodes.len(), 2);
1295 assert!(schema.nodes.iter().any(|node| node.label == "users"));
1296 assert!(schema.nodes.iter().any(|node| node.label == "orders"));
1297 }
1298
1299 #[test]
1300 fn test_postgres_type_to_property_type_integer() {
1301 assert_eq!(
1302 postgres_type_to_property_type("integer"),
1303 Some(PropertyType::Integer)
1304 );
1305 assert_eq!(
1306 postgres_type_to_property_type("bigint"),
1307 Some(PropertyType::Integer)
1308 );
1309 assert_eq!(
1310 postgres_type_to_property_type("smallint"),
1311 Some(PropertyType::Integer)
1312 );
1313 }
1314
1315 #[test]
1316 fn test_postgres_type_to_property_type_float() {
1317 assert_eq!(
1318 postgres_type_to_property_type("double precision"),
1319 Some(PropertyType::Float)
1320 );
1321 assert_eq!(
1322 postgres_type_to_property_type("real"),
1323 Some(PropertyType::Float)
1324 );
1325 assert_eq!(
1326 postgres_type_to_property_type("numeric"),
1327 Some(PropertyType::Float)
1328 );
1329 assert_eq!(
1330 postgres_type_to_property_type("decimal"),
1331 Some(PropertyType::Float)
1332 );
1333 }
1334
1335 #[test]
1336 fn test_postgres_type_to_property_type_boolean() {
1337 assert_eq!(
1338 postgres_type_to_property_type("boolean"),
1339 Some(PropertyType::Boolean)
1340 );
1341 }
1342
1343 #[test]
1344 fn test_postgres_type_to_property_type_timestamp() {
1345 assert_eq!(
1346 postgres_type_to_property_type("timestamp with time zone"),
1347 Some(PropertyType::Timestamp)
1348 );
1349 assert_eq!(
1350 postgres_type_to_property_type("timestamp without time zone"),
1351 Some(PropertyType::Timestamp)
1352 );
1353 assert_eq!(
1354 postgres_type_to_property_type("date"),
1355 Some(PropertyType::Timestamp)
1356 );
1357 }
1358
1359 #[test]
1360 fn test_postgres_type_to_property_type_json() {
1361 assert_eq!(
1362 postgres_type_to_property_type("json"),
1363 Some(PropertyType::Json)
1364 );
1365 assert_eq!(
1366 postgres_type_to_property_type("jsonb"),
1367 Some(PropertyType::Json)
1368 );
1369 }
1370
1371 #[test]
1372 fn test_postgres_type_to_property_type_string() {
1373 assert_eq!(
1374 postgres_type_to_property_type("character varying"),
1375 Some(PropertyType::String)
1376 );
1377 assert_eq!(
1378 postgres_type_to_property_type("text"),
1379 Some(PropertyType::String)
1380 );
1381 assert_eq!(
1382 postgres_type_to_property_type("uuid"),
1383 Some(PropertyType::String)
1384 );
1385 }
1386
1387 #[test]
1388 fn test_postgres_type_to_property_type_unknown_returns_none() {
1389 assert_eq!(postgres_type_to_property_type("point"), None);
1390 assert_eq!(postgres_type_to_property_type("polygon"), None);
1391 assert_eq!(postgres_type_to_property_type("cidr"), None);
1392 }
1393 }
1394
1395 mod lifecycle {
1396 use super::*;
1397
1398 struct TestSecretResolver;
1400
1401 #[async_trait::async_trait]
1402 impl drasi_plugin_sdk::resolver::ValueResolver for TestSecretResolver {
1403 async fn resolve_to_string(
1404 &self,
1405 value: &drasi_plugin_sdk::ConfigValue<String>,
1406 ) -> Result<String, drasi_plugin_sdk::resolver::ResolverError> {
1407 match value {
1408 drasi_plugin_sdk::ConfigValue::Secret { name } => {
1409 Ok(format!("resolved-secret-{name}"))
1410 }
1411 _ => Err(drasi_plugin_sdk::resolver::ResolverError::WrongResolverType),
1412 }
1413 }
1414 }
1415
1416 fn ensure_test_secret_resolver() {
1417 drasi_plugin_sdk::resolver::register_secret_resolver(std::sync::Arc::new(
1418 TestSecretResolver,
1419 ));
1420 }
1421
1422 #[tokio::test]
1423 async fn test_descriptor_preserves_secret_envelope() {
1424 use crate::descriptor::PostgresSourceDescriptor;
1425 use drasi_lib::sources::Source;
1426 use drasi_plugin_sdk::descriptor::SourcePluginDescriptor;
1427
1428 ensure_test_secret_resolver();
1429
1430 let config_json = serde_json::json!({
1431 "host": "db.example.com",
1432 "port": 5432,
1433 "database": "mydb",
1434 "user": "app_user",
1435 "password": {
1436 "kind": "Secret",
1437 "name": "db-password"
1438 },
1439 "tables": ["users"],
1440 "slotName": "drasi_slot",
1441 "publicationName": "drasi_pub"
1442 });
1443
1444 let descriptor = PostgresSourceDescriptor;
1445 let source = descriptor
1446 .create_source("pg-secret-test", &config_json, true)
1447 .await
1448 .expect("descriptor should create source");
1449
1450 let props = source.properties();
1451
1452 let password = props.get("password").expect("password must be present");
1454 assert!(
1455 password.is_object(),
1456 "password should be Secret envelope, got: {password}"
1457 );
1458 assert_eq!(
1459 password.get("kind").and_then(|v| v.as_str()),
1460 Some("Secret"),
1461 "envelope kind must be Secret"
1462 );
1463 assert_eq!(
1464 password.get("name").and_then(|v| v.as_str()),
1465 Some("db-password"),
1466 "secret name must be preserved"
1467 );
1468
1469 let props_str = serde_json::to_string(&props).unwrap();
1471 assert!(
1472 !props_str.contains("resolved-secret-db-password"),
1473 "resolved secret must not appear in properties"
1474 );
1475
1476 assert!(
1478 props.contains_key("slotName"),
1479 "expected camelCase 'slotName', got keys: {:?}",
1480 props.keys().collect::<Vec<_>>()
1481 );
1482 assert!(
1483 props.contains_key("publicationName"),
1484 "expected camelCase 'publicationName'"
1485 );
1486 }
1487
1488 #[tokio::test]
1489 async fn test_initial_status_is_stopped() {
1490 let source = PostgresSourceBuilder::new("test")
1491 .with_database("db")
1492 .with_user("user")
1493 .build()
1494 .unwrap();
1495 assert_eq!(source.status().await, ComponentStatus::Stopped);
1496 }
1497
1498 #[test]
1499 fn test_builder_fallback_produces_camel_case() {
1500 use drasi_lib::sources::Source;
1501
1502 let source = PostgresSourceBuilder::new("pg-fallback")
1503 .with_host("myhost.example.com")
1504 .with_port(5433)
1505 .with_database("mydb")
1506 .with_user("admin")
1507 .with_password("secret123")
1508 .with_ssl_mode(SslMode::Require)
1509 .with_slot_name("custom_slot")
1510 .with_publication_name("custom_pub")
1511 .build()
1512 .unwrap();
1513
1514 let props = source.properties();
1515
1516 assert!(
1518 props.contains_key("slotName"),
1519 "expected camelCase 'slotName', got keys: {:?}",
1520 props.keys().collect::<Vec<_>>()
1521 );
1522 assert!(
1523 props.contains_key("publicationName"),
1524 "expected camelCase 'publicationName'"
1525 );
1526 assert!(
1527 props.contains_key("sslMode"),
1528 "expected camelCase 'sslMode'"
1529 );
1530
1531 assert!(
1533 !props.contains_key("slot_name"),
1534 "should not have snake_case 'slot_name'"
1535 );
1536 assert!(
1537 !props.contains_key("publication_name"),
1538 "should not have snake_case 'publication_name'"
1539 );
1540
1541 assert_eq!(
1543 props.get("host").and_then(|v| v.as_str()),
1544 Some("myhost.example.com")
1545 );
1546 assert_eq!(props.get("port").and_then(|v| v.as_u64()), Some(5433));
1547 assert_eq!(props.get("database").and_then(|v| v.as_str()), Some("mydb"));
1548 assert_eq!(
1549 props.get("password").and_then(|v| v.as_str()),
1550 Some("secret123")
1551 );
1552 }
1553
1554 #[tokio::test]
1555 async fn test_pause_replication_for_restart_aborts_existing_task() {
1556 let source = PostgresSourceBuilder::new("test")
1557 .with_database("db")
1558 .with_user("user")
1559 .build()
1560 .unwrap();
1561
1562 source.base.set_status(ComponentStatus::Running, None).await;
1563
1564 let task = tokio::spawn(async {
1565 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1566 });
1567 *source.base.task_handle.write().await = Some(task);
1568
1569 source.pause_replication_for_restart(42).await;
1570
1571 assert!(source.base.task_handle.read().await.is_none());
1572 assert_eq!(source.status().await, ComponentStatus::Starting);
1573 }
1574
1575 #[test]
1576 fn test_supports_replay_returns_true() {
1577 let source = PostgresSourceBuilder::new("test")
1578 .with_database("db")
1579 .with_user("user")
1580 .build()
1581 .unwrap();
1582 assert!(source.supports_replay());
1583 }
1584 }
1585
1586 mod subscribe {
1587 use super::*;
1588 use drasi_lib::config::SourceSubscriptionSettings;
1589 use std::collections::HashSet;
1590
1591 #[tokio::test]
1592 async fn test_malformed_resume_from_rejected() {
1593 let source = PostgresSourceBuilder::new("test-source")
1594 .with_database("testdb")
1595 .with_user("testuser")
1596 .build()
1597 .unwrap();
1598
1599 let bad_position = bytes::Bytes::from(vec![0u8; 4]);
1601 let settings = SourceSubscriptionSettings {
1602 source_id: "test-source".to_string(),
1603 query_id: "q-bad-position".to_string(),
1604 enable_bootstrap: false,
1605 nodes: HashSet::new(),
1606 relations: HashSet::new(),
1607 resume_from: Some(bad_position),
1608 request_position_handle: false,
1609 last_sequence: None,
1610 };
1611
1612 let result = source.subscribe(settings).await;
1613 assert!(result.is_err());
1614 let err_msg = format!("{}", result.err().unwrap());
1615 assert!(
1616 err_msg.contains("expected 8 bytes"),
1617 "Error should mention expected byte length, got: {err_msg}"
1618 );
1619 }
1620 }
1621
1622 mod builder {
1623 use super::*;
1624
1625 #[test]
1626 fn test_postgres_builder_defaults() {
1627 let source = PostgresSourceBuilder::new("test").build().unwrap();
1628 assert_eq!(source.config.host, "localhost");
1629 assert_eq!(source.config.port, 5432);
1630 assert_eq!(source.config.slot_name, "drasi_slot");
1631 assert_eq!(source.config.publication_name, "drasi_publication");
1632 }
1633
1634 #[test]
1635 fn test_postgres_builder_custom_values() {
1636 let source = PostgresSourceBuilder::new("test")
1637 .with_host("db.example.com")
1638 .with_port(5433)
1639 .with_database("production")
1640 .with_user("app_user")
1641 .with_password("secret")
1642 .with_tables(vec!["users".to_string(), "orders".to_string()])
1643 .build()
1644 .unwrap();
1645
1646 assert_eq!(source.config.host, "db.example.com");
1647 assert_eq!(source.config.port, 5433);
1648 assert_eq!(source.config.database, "production");
1649 assert_eq!(source.config.user, "app_user");
1650 assert_eq!(source.config.password, "secret");
1651 assert_eq!(source.config.tables.len(), 2);
1652 assert_eq!(source.config.tables[0], "users");
1653 assert_eq!(source.config.tables[1], "orders");
1654 }
1655
1656 #[test]
1657 fn test_builder_add_table() {
1658 let source = PostgresSourceBuilder::new("test")
1659 .add_table("table1")
1660 .add_table("table2")
1661 .add_table("table3")
1662 .build()
1663 .unwrap();
1664
1665 assert_eq!(source.config.tables.len(), 3);
1666 assert_eq!(source.config.tables[0], "table1");
1667 assert_eq!(source.config.tables[1], "table2");
1668 assert_eq!(source.config.tables[2], "table3");
1669 }
1670
1671 #[test]
1672 fn test_builder_slot_and_publication() {
1673 let source = PostgresSourceBuilder::new("test")
1674 .with_slot_name("custom_slot")
1675 .with_publication_name("custom_pub")
1676 .build()
1677 .unwrap();
1678
1679 assert_eq!(source.config.slot_name, "custom_slot");
1680 assert_eq!(source.config.publication_name, "custom_pub");
1681 }
1682
1683 #[test]
1684 fn test_builder_id() {
1685 let source = PostgresReplicationSource::builder("my-pg-source")
1686 .with_database("db")
1687 .with_user("user")
1688 .build()
1689 .unwrap();
1690
1691 assert_eq!(source.base.id, "my-pg-source");
1692 }
1693 }
1694
1695 mod config {
1696 use super::*;
1697
1698 #[test]
1699 fn test_config_serialization() {
1700 let config = PostgresSourceConfig {
1701 host: "localhost".to_string(),
1702 port: 5432,
1703 database: "testdb".to_string(),
1704 user: "testuser".to_string(),
1705 password: String::new(),
1706 tables: Vec::new(),
1707 slot_name: "drasi_slot".to_string(),
1708 publication_name: "drasi_publication".to_string(),
1709 ssl_mode: SslMode::default(),
1710 table_keys: Vec::new(),
1711 };
1712
1713 let json = serde_json::to_string(&config).unwrap();
1714 let deserialized: PostgresSourceConfig = serde_json::from_str(&json).unwrap();
1715
1716 assert_eq!(config, deserialized);
1717 }
1718
1719 #[test]
1720 fn test_config_deserialization_with_required_fields() {
1721 let json = r#"{
1722 "database": "mydb",
1723 "user": "myuser"
1724 }"#;
1725 let config: PostgresSourceConfig = serde_json::from_str(json).unwrap();
1726
1727 assert_eq!(config.database, "mydb");
1728 assert_eq!(config.user, "myuser");
1729 assert_eq!(config.host, "localhost"); assert_eq!(config.port, 5432); assert_eq!(config.slot_name, "drasi_slot"); }
1733
1734 #[test]
1735 fn test_config_deserialization_full() {
1736 let json = r#"{
1737 "host": "db.prod.internal",
1738 "port": 5433,
1739 "database": "production",
1740 "user": "replication_user",
1741 "password": "secret",
1742 "tables": ["accounts", "transactions"],
1743 "slot_name": "prod_slot",
1744 "publication_name": "prod_publication"
1745 }"#;
1746 let config: PostgresSourceConfig = serde_json::from_str(json).unwrap();
1747
1748 assert_eq!(config.host, "db.prod.internal");
1749 assert_eq!(config.port, 5433);
1750 assert_eq!(config.database, "production");
1751 assert_eq!(config.user, "replication_user");
1752 assert_eq!(config.password, "secret");
1753 assert_eq!(config.tables, vec!["accounts", "transactions"]);
1754 assert_eq!(config.slot_name, "prod_slot");
1755 assert_eq!(config.publication_name, "prod_publication");
1756 }
1757 }
1758}
1759
// Only when the `dynamic-plugin` feature is enabled, export this crate through
// `drasi_plugin_sdk::export_plugin!` under the id "postgres-source". It
// registers exactly one source descriptor (`descriptor::PostgresSourceDescriptor`)
// and no reaction or bootstrap descriptors.
// NOTE(review): `core_version`, `lib_version` and `plugin_version` all take this
// crate's own CARGO_PKG_VERSION. Confirm the core/lib versions are meant to
// track this crate rather than the drasi-lib / core version it links against.
#[cfg(feature = "dynamic-plugin")]
drasi_plugin_sdk::export_plugin!(
    plugin_id = "postgres-source",
    core_version = env!("CARGO_PKG_VERSION"),
    lib_version = env!("CARGO_PKG_VERSION"),
    plugin_version = env!("CARGO_PKG_VERSION"),
    source_descriptors = [descriptor::PostgresSourceDescriptor],
    reaction_descriptors = [],
    bootstrap_descriptors = [],
);