#![allow(unexpected_cfgs)]
pub mod config;
pub mod connection;
pub mod decoder;
pub mod descriptor;
pub mod protocol;
pub mod scram;
pub mod stream;
pub mod types;
pub use config::{PostgresSourceConfig, SslMode, TableKeyConfig};
use anyhow::Result;
use async_trait::async_trait;
use log::{error, info};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use drasi_lib::channels::{DispatchMode, *};
use drasi_lib::component_graph::ComponentStatusHandle;
use drasi_lib::sources::base::{SourceBase, SourceBaseParams};
use drasi_lib::Source;
use tracing::Instrument;
/// PostgreSQL replication source component.
///
/// Combines the shared [`SourceBase`] plumbing (id, status, dispatchers, task
/// handle, bootstrap provider) with the PostgreSQL-specific settings that are
/// handed to the replication stream when the source is started.
pub struct PostgresReplicationSource {
    /// Shared source infrastructure: id, lifecycle status, dispatchers and the
    /// handle of the spawned replication task.
    base: SourceBase,
    /// Connection/replication settings; cloned into the replication task on `start`.
    config: PostgresSourceConfig,
}
impl PostgresReplicationSource {
    /// Returns a [`PostgresSourceBuilder`] for a source with the given id.
    pub fn builder(id: impl Into<String>) -> PostgresSourceBuilder {
        PostgresSourceBuilder::new(id)
    }

    /// Creates a source from an existing configuration, keeping the default
    /// dispatch settings of [`SourceBaseParams`].
    ///
    /// # Errors
    /// Propagates any error returned by [`SourceBase::new`].
    pub fn new(id: impl Into<String>, config: PostgresSourceConfig) -> Result<Self> {
        Self::with_dispatch(id, config, None, None)
    }

    /// Creates a source from an existing configuration, optionally overriding
    /// the dispatch mode and dispatch buffer capacity.
    ///
    /// Passing `None` for either override leaves the [`SourceBaseParams`]
    /// default in place.
    ///
    /// # Errors
    /// Propagates any error returned by [`SourceBase::new`].
    pub fn with_dispatch(
        id: impl Into<String>,
        config: PostgresSourceConfig,
        dispatch_mode: Option<DispatchMode>,
        dispatch_buffer_capacity: Option<usize>,
    ) -> Result<Self> {
        let base_params = SourceBaseParams::new(id.into());
        let base_params = match dispatch_mode {
            Some(mode) => base_params.with_dispatch_mode(mode),
            None => base_params,
        };
        let base_params = match dispatch_buffer_capacity {
            Some(capacity) => base_params.with_dispatch_buffer_capacity(capacity),
            None => base_params,
        };
        let base = SourceBase::new(base_params)?;
        Ok(Self { base, config })
    }
}
#[async_trait]
impl Source for PostgresReplicationSource {
fn id(&self) -> &str {
&self.base.id
}
fn type_name(&self) -> &str {
"postgres"
}
fn properties(&self) -> HashMap<String, serde_json::Value> {
match serde_json::to_value(&self.config) {
Ok(serde_json::Value::Object(mut map)) => {
map.remove("password");
map.into_iter().collect()
}
_ => HashMap::new(),
}
}
fn auto_start(&self) -> bool {
self.base.get_auto_start()
}
async fn start(&self) -> Result<()> {
if self.base.get_status().await == ComponentStatus::Running {
return Ok(());
}
self.base.set_status(ComponentStatus::Starting, None).await;
info!("Starting PostgreSQL replication source: {}", self.base.id);
let config = self.config.clone();
let source_id = self.base.id.clone();
let dispatchers = self.base.dispatchers.clone();
let reporter = self.base.status_handle();
let instance_id = self
.base
.context()
.await
.map(|c| c.instance_id)
.unwrap_or_default();
let source_id_for_span = source_id.clone();
let span = tracing::info_span!(
"postgres_replication_task",
instance_id = %instance_id,
component_id = %source_id_for_span,
component_type = "source"
);
let task = tokio::spawn(
async move {
if let Err(e) =
run_replication(source_id.clone(), config, dispatchers, reporter.clone()).await
{
error!("Replication task failed for {source_id}: {e}");
reporter
.set_status(
ComponentStatus::Error,
Some(format!("Replication failed: {e}")),
)
.await;
}
}
.instrument(span),
);
*self.base.task_handle.write().await = Some(task);
self.base
.set_status(
ComponentStatus::Running,
Some("PostgreSQL replication started".to_string()),
)
.await;
Ok(())
}
async fn stop(&self) -> Result<()> {
if self.base.get_status().await != ComponentStatus::Running {
return Ok(());
}
info!("Stopping PostgreSQL replication source: {}", self.base.id);
self.base.set_status(ComponentStatus::Stopping, None).await;
if let Some(task) = self.base.task_handle.write().await.take() {
task.abort();
}
self.base
.set_status(
ComponentStatus::Stopped,
Some("PostgreSQL replication stopped".to_string()),
)
.await;
Ok(())
}
async fn status(&self) -> ComponentStatus {
self.base.get_status().await
}
async fn subscribe(
&self,
settings: drasi_lib::config::SourceSubscriptionSettings,
) -> Result<SubscriptionResponse> {
self.base
.subscribe_with_bootstrap(&settings, "PostgreSQL")
.await
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn initialize(&self, context: drasi_lib::context::SourceRuntimeContext) {
self.base.initialize(context).await;
}
async fn set_bootstrap_provider(
&self,
provider: Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>,
) {
self.base.set_bootstrap_provider(provider).await;
}
}
/// Drives the PostgreSQL replication stream for `source_id` until it ends.
///
/// Constructs a [`stream::ReplicationStream`] from the configuration,
/// dispatchers and status handle, then awaits its `run` method. Any error from
/// the stream is returned to the caller unchanged.
async fn run_replication(
    source_id: String,
    config: PostgresSourceConfig,
    dispatchers: Arc<
        RwLock<
            Vec<Box<dyn drasi_lib::channels::ChangeDispatcher<SourceEventWrapper> + Send + Sync>>,
        >,
    >,
    status_handle: ComponentStatusHandle,
) -> Result<()> {
    info!("Starting replication for source {source_id}");
    stream::ReplicationStream::new(config, source_id, dispatchers, status_handle)
        .run()
        .await
}
/// Fluent builder for [`PostgresReplicationSource`].
///
/// `PostgresSourceBuilder::new` seeds these defaults:
/// - host `localhost`, port `5432`
/// - slot `drasi_slot`, publication `drasi_publication`
/// - auto-start enabled
/// - every other field empty or `None`
///
/// `build` does not validate the values, so an empty `database` or `user` is
/// accepted.
pub struct PostgresSourceBuilder {
    /// Component id of the source being built.
    id: String,
    /// PostgreSQL server host.
    host: String,
    /// PostgreSQL server port.
    port: u16,
    /// Database name.
    database: String,
    /// User name for authentication.
    user: String,
    /// Password for authentication (stripped from the built source's `properties()`).
    password: String,
    /// Table names copied into the source configuration.
    tables: Vec<String>,
    /// Replication slot name copied into the source configuration.
    slot_name: String,
    /// Publication name copied into the source configuration.
    publication_name: String,
    /// SSL mode used for the connection.
    ssl_mode: SslMode,
    /// Per-table key configuration copied into the source configuration.
    table_keys: Vec<TableKeyConfig>,
    /// Dispatch mode override; `None` keeps the `SourceBaseParams` default.
    dispatch_mode: Option<DispatchMode>,
    /// Dispatch buffer capacity override; `None` keeps the `SourceBaseParams` default.
    dispatch_buffer_capacity: Option<usize>,
    /// Optional bootstrap provider handed to the `SourceBase`.
    bootstrap_provider: Option<Box<dyn drasi_lib::bootstrap::BootstrapProvider + 'static>>,
    /// Whether the built source reports that it should start automatically.
    auto_start: bool,
}
impl PostgresSourceBuilder {
    /// Creates a builder for the source `id`.
    ///
    /// Defaults:
    /// - host `localhost`, port `5432`
    /// - empty database, user and password; no tables
    /// - slot `drasi_slot`, publication `drasi_publication`
    /// - default SSL mode; no table keys
    /// - no dispatch overrides; no bootstrap provider
    /// - auto-start enabled
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            host: String::from("localhost"),
            port: 5432,
            database: String::default(),
            user: String::default(),
            password: String::default(),
            tables: Vec::default(),
            slot_name: String::from("drasi_slot"),
            publication_name: String::from("drasi_publication"),
            ssl_mode: SslMode::default(),
            table_keys: Vec::default(),
            dispatch_mode: None,
            dispatch_buffer_capacity: None,
            bootstrap_provider: None,
            auto_start: true,
        }
    }

    /// Sets the server host.
    pub fn with_host(self, host: impl Into<String>) -> Self {
        Self {
            host: host.into(),
            ..self
        }
    }

    /// Sets the server port.
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Sets the database name.
    pub fn with_database(self, database: impl Into<String>) -> Self {
        Self {
            database: database.into(),
            ..self
        }
    }

    /// Sets the user name.
    pub fn with_user(self, user: impl Into<String>) -> Self {
        Self {
            user: user.into(),
            ..self
        }
    }

    /// Sets the password.
    pub fn with_password(self, password: impl Into<String>) -> Self {
        Self {
            password: password.into(),
            ..self
        }
    }

    /// Replaces the whole table list.
    pub fn with_tables(self, tables: Vec<String>) -> Self {
        Self { tables, ..self }
    }

    /// Appends a single table to the table list.
    pub fn add_table(mut self, table: impl Into<String>) -> Self {
        self.tables.push(table.into());
        self
    }

    /// Sets the replication slot name.
    pub fn with_slot_name(self, slot_name: impl Into<String>) -> Self {
        Self {
            slot_name: slot_name.into(),
            ..self
        }
    }

    /// Sets the publication name.
    pub fn with_publication_name(self, publication_name: impl Into<String>) -> Self {
        Self {
            publication_name: publication_name.into(),
            ..self
        }
    }

    /// Sets the SSL mode.
    pub fn with_ssl_mode(self, ssl_mode: SslMode) -> Self {
        Self { ssl_mode, ..self }
    }

    /// Replaces the whole table-key list.
    pub fn with_table_keys(self, table_keys: Vec<TableKeyConfig>) -> Self {
        Self { table_keys, ..self }
    }

    /// Appends a single table-key entry.
    pub fn add_table_key(mut self, table_key: TableKeyConfig) -> Self {
        self.table_keys.push(table_key);
        self
    }

    /// Overrides the dispatch mode.
    pub fn with_dispatch_mode(self, mode: DispatchMode) -> Self {
        Self {
            dispatch_mode: Some(mode),
            ..self
        }
    }

    /// Overrides the dispatch buffer capacity.
    pub fn with_dispatch_buffer_capacity(self, capacity: usize) -> Self {
        Self {
            dispatch_buffer_capacity: Some(capacity),
            ..self
        }
    }

    /// Attaches a bootstrap provider; it is boxed and handed to the `SourceBase`.
    pub fn with_bootstrap_provider(
        self,
        provider: impl drasi_lib::bootstrap::BootstrapProvider + 'static,
    ) -> Self {
        Self {
            bootstrap_provider: Some(Box::new(provider)),
            ..self
        }
    }

    /// Sets whether the built source should report auto-start.
    pub fn with_auto_start(self, auto_start: bool) -> Self {
        Self { auto_start, ..self }
    }

    /// Copies every connection/replication field from `config`.
    ///
    /// The id, dispatch overrides, bootstrap provider and auto-start flag are
    /// left untouched.
    pub fn with_config(self, config: PostgresSourceConfig) -> Self {
        let PostgresSourceConfig {
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
        } = config;
        Self {
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
            ..self
        }
    }

    /// Consumes the builder and creates the [`PostgresReplicationSource`].
    ///
    /// # Errors
    /// Propagates any error returned by [`SourceBase::new`].
    pub fn build(self) -> Result<PostgresReplicationSource> {
        let Self {
            id,
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
            dispatch_mode,
            dispatch_buffer_capacity,
            bootstrap_provider,
            auto_start,
        } = self;

        let config = PostgresSourceConfig {
            host,
            port,
            database,
            user,
            password,
            tables,
            slot_name,
            publication_name,
            ssl_mode,
            table_keys,
        };

        let base_params = SourceBaseParams::new(&id).with_auto_start(auto_start);
        let base_params = match dispatch_mode {
            Some(mode) => base_params.with_dispatch_mode(mode),
            None => base_params,
        };
        let base_params = match dispatch_buffer_capacity {
            Some(capacity) => base_params.with_dispatch_buffer_capacity(capacity),
            None => base_params,
        };
        let base_params = match bootstrap_provider {
            Some(provider) => base_params.with_bootstrap_provider(provider),
            None => base_params,
        };

        let base = SourceBase::new(base_params)?;
        Ok(PostgresReplicationSource { base, config })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration shared by tests that need an explicit
    /// `PostgresSourceConfig` value.
    fn sample_config() -> PostgresSourceConfig {
        PostgresSourceConfig {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: "testuser".to_string(),
            password: String::new(),
            tables: Vec::new(),
            slot_name: "drasi_slot".to_string(),
            publication_name: "drasi_publication".to_string(),
            ssl_mode: SslMode::default(),
            table_keys: Vec::new(),
        }
    }

    mod construction {
        use super::*;

        #[test]
        fn test_builder_with_valid_config() {
            let source = PostgresSourceBuilder::new("test-source")
                .with_database("testdb")
                .with_user("testuser")
                .build();
            assert!(source.is_ok());
        }

        #[test]
        fn test_builder_with_custom_config() {
            let source = PostgresSourceBuilder::new("pg-source")
                .with_host("192.168.1.100")
                .with_port(5433)
                .with_database("production")
                .with_user("admin")
                .with_password("secret")
                .build()
                .unwrap();
            assert_eq!(source.id(), "pg-source");
        }

        #[test]
        fn test_with_dispatch_creates_source() {
            let source = PostgresReplicationSource::with_dispatch(
                "dispatch-source",
                sample_config(),
                Some(DispatchMode::Channel),
                Some(2000),
            );
            assert!(source.is_ok());
            assert_eq!(source.unwrap().id(), "dispatch-source");
        }

        #[test]
        fn test_new_creates_source_with_given_config() {
            let source =
                PostgresReplicationSource::new("plain-source", sample_config()).unwrap();
            assert_eq!(source.id(), "plain-source");
            assert_eq!(source.config, sample_config());
        }
    }

    mod properties {
        use super::*;

        #[test]
        fn test_id_returns_correct_value() {
            let source = PostgresSourceBuilder::new("my-pg-source")
                .with_database("db")
                .with_user("user")
                .build()
                .unwrap();
            assert_eq!(source.id(), "my-pg-source");
        }

        #[test]
        fn test_type_name_returns_postgres() {
            let source = PostgresSourceBuilder::new("test")
                .with_database("db")
                .with_user("user")
                .build()
                .unwrap();
            assert_eq!(source.type_name(), "postgres");
        }

        #[test]
        fn test_properties_contains_connection_info() {
            let source = PostgresSourceBuilder::new("test")
                .with_host("db.example.com")
                .with_port(5433)
                .with_database("mydb")
                .with_user("app_user")
                .with_password("secret")
                .with_tables(vec!["users".to_string()])
                .build()
                .unwrap();
            let props = source.properties();
            assert_eq!(
                props.get("host"),
                Some(&serde_json::Value::String("db.example.com".to_string()))
            );
            assert_eq!(
                props.get("port"),
                Some(&serde_json::Value::Number(5433.into()))
            );
            assert_eq!(
                props.get("database"),
                Some(&serde_json::Value::String("mydb".to_string()))
            );
            assert_eq!(
                props.get("user"),
                Some(&serde_json::Value::String("app_user".to_string()))
            );
        }

        #[test]
        fn test_properties_does_not_expose_password() {
            let source = PostgresSourceBuilder::new("test")
                .with_database("db")
                .with_user("user")
                .with_password("super_secret_password")
                .build()
                .unwrap();
            let props = source.properties();
            assert!(!props.contains_key("password"));
        }

        #[test]
        fn test_properties_includes_tables() {
            let source = PostgresSourceBuilder::new("test")
                .with_database("db")
                .with_user("user")
                .with_tables(vec!["users".to_string(), "orders".to_string()])
                .build()
                .unwrap();
            let props = source.properties();
            let tables = props.get("tables").unwrap().as_array().unwrap();
            assert_eq!(tables.len(), 2);
            assert_eq!(tables[0], "users");
            assert_eq!(tables[1], "orders");
        }
    }

    mod lifecycle {
        use super::*;

        #[tokio::test]
        async fn test_initial_status_is_stopped() {
            let source = PostgresSourceBuilder::new("test")
                .with_database("db")
                .with_user("user")
                .build()
                .unwrap();
            assert_eq!(source.status().await, ComponentStatus::Stopped);
        }

        #[tokio::test]
        async fn test_stop_when_not_running_is_noop() {
            let source = PostgresSourceBuilder::new("test").build().unwrap();
            assert!(source.stop().await.is_ok());
            assert_eq!(source.status().await, ComponentStatus::Stopped);
        }
    }

    mod builder {
        use super::*;

        #[test]
        fn test_postgres_builder_defaults() {
            let source = PostgresSourceBuilder::new("test").build().unwrap();
            assert_eq!(source.config.host, "localhost");
            assert_eq!(source.config.port, 5432);
            assert_eq!(source.config.slot_name, "drasi_slot");
            assert_eq!(source.config.publication_name, "drasi_publication");
        }

        #[test]
        fn test_postgres_builder_custom_values() {
            let source = PostgresSourceBuilder::new("test")
                .with_host("db.example.com")
                .with_port(5433)
                .with_database("production")
                .with_user("app_user")
                .with_password("secret")
                .with_tables(vec!["users".to_string(), "orders".to_string()])
                .build()
                .unwrap();
            assert_eq!(source.config.host, "db.example.com");
            assert_eq!(source.config.port, 5433);
            assert_eq!(source.config.database, "production");
            assert_eq!(source.config.user, "app_user");
            assert_eq!(source.config.password, "secret");
            assert_eq!(source.config.tables.len(), 2);
            assert_eq!(source.config.tables[0], "users");
            assert_eq!(source.config.tables[1], "orders");
        }

        #[test]
        fn test_builder_add_table() {
            let source = PostgresSourceBuilder::new("test")
                .add_table("table1")
                .add_table("table2")
                .add_table("table3")
                .build()
                .unwrap();
            assert_eq!(source.config.tables.len(), 3);
            assert_eq!(source.config.tables[0], "table1");
            assert_eq!(source.config.tables[1], "table2");
            assert_eq!(source.config.tables[2], "table3");
        }

        #[test]
        fn test_builder_slot_and_publication() {
            let source = PostgresSourceBuilder::new("test")
                .with_slot_name("custom_slot")
                .with_publication_name("custom_pub")
                .build()
                .unwrap();
            assert_eq!(source.config.slot_name, "custom_slot");
            assert_eq!(source.config.publication_name, "custom_pub");
        }

        #[test]
        fn test_builder_id() {
            let source = PostgresReplicationSource::builder("my-pg-source")
                .with_database("db")
                .with_user("user")
                .build()
                .unwrap();
            assert_eq!(source.base.id, "my-pg-source");
        }

        #[test]
        fn test_builder_with_config_copies_all_fields() {
            let mut config = sample_config();
            config.host = "replica.internal".to_string();
            config.port = 6543;
            config.password = "pw".to_string();
            config.tables = vec!["events".to_string()];
            config.slot_name = "slot_x".to_string();
            config.publication_name = "pub_x".to_string();

            let source = PostgresSourceBuilder::new("test")
                .with_config(config.clone())
                .build()
                .unwrap();
            assert_eq!(source.config, config);
        }

        #[test]
        fn test_builder_auto_start_flag() {
            let default_source = PostgresSourceBuilder::new("test").build().unwrap();
            assert!(default_source.auto_start());

            let manual_source = PostgresSourceBuilder::new("test")
                .with_auto_start(false)
                .build()
                .unwrap();
            assert!(!manual_source.auto_start());
        }
    }

    mod config {
        use super::*;

        #[test]
        fn test_config_serialization() {
            let config = sample_config();
            let json = serde_json::to_string(&config).unwrap();
            let deserialized: PostgresSourceConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(config, deserialized);
        }

        #[test]
        fn test_config_deserialization_with_required_fields() {
            let json = r#"{
                "database": "mydb",
                "user": "myuser"
            }"#;
            let config: PostgresSourceConfig = serde_json::from_str(json).unwrap();
            assert_eq!(config.database, "mydb");
            assert_eq!(config.user, "myuser");
            assert_eq!(config.host, "localhost");
            assert_eq!(config.port, 5432);
            assert_eq!(config.slot_name, "drasi_slot");
        }

        #[test]
        fn test_config_deserialization_full() {
            let json = r#"{
                "host": "db.prod.internal",
                "port": 5433,
                "database": "production",
                "user": "replication_user",
                "password": "secret",
                "tables": ["accounts", "transactions"],
                "slot_name": "prod_slot",
                "publication_name": "prod_publication"
            }"#;
            let config: PostgresSourceConfig = serde_json::from_str(json).unwrap();
            assert_eq!(config.host, "db.prod.internal");
            assert_eq!(config.port, 5433);
            assert_eq!(config.database, "production");
            assert_eq!(config.user, "replication_user");
            assert_eq!(config.password, "secret");
            assert_eq!(config.tables, vec!["accounts", "transactions"]);
            assert_eq!(config.slot_name, "prod_slot");
            assert_eq!(config.publication_name, "prod_publication");
        }
    }
}
// Dynamic-plugin entry point. When the `dynamic-plugin` feature is enabled,
// this exports the crate as the "postgres-source" plugin. It provides a single
// source descriptor (`descriptor::PostgresSourceDescriptor`) and no reaction or
// bootstrap descriptors.
// NOTE(review): `core_version`, `lib_version` and `plugin_version` all use this
// crate's own CARGO_PKG_VERSION. Confirm that is what the plugin host expects
// for the core/lib version fields.
#[cfg(feature = "dynamic-plugin")]
drasi_plugin_sdk::export_plugin!(
    plugin_id = "postgres-source",
    core_version = env!("CARGO_PKG_VERSION"),
    lib_version = env!("CARGO_PKG_VERSION"),
    plugin_version = env!("CARGO_PKG_VERSION"),
    source_descriptors = [descriptor::PostgresSourceDescriptor],
    reaction_descriptors = [],
    bootstrap_descriptors = [],
);