use std::io;
use std::process::ExitStatus;
use tracing::info;
use crate::error::{MigrationError, Result};
use crate::tls::connect_with_sslmode;
/// External PostgreSQL client binaries this crate shells out to; each is
/// probed with `--version` by `verify_pg_tools_installed`.
pub const REQUIRED_TOOLS: &[&str] = &["pg_dump", "pg_restore"];
/// Check that every required PostgreSQL client binary is installed and
/// runnable by probing `<tool> --version` for each entry in
/// [`REQUIRED_TOOLS`]. Returns the first failure as an error.
pub async fn verify_pg_tools_installed() -> Result<()> {
    for &tool in REQUIRED_TOOLS {
        let probe = spawn_version_check(tool).await;
        classify_version_check(tool, probe)?;
    }
    Ok(())
}
/// Spawn `<tool> --version` with all stdio streams discarded and wait
/// for it to exit, yielding the raw exit status (or the spawn error).
async fn spawn_version_check(tool: &str) -> std::result::Result<ExitStatus, io::Error> {
    use std::process::Stdio;
    use tokio::process::Command;
    let mut probe = Command::new(tool);
    probe
        .arg("--version")
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());
    probe.status().await
}
/// Translate the outcome of a `--version` probe into `Ok(())` or a
/// `MigrationError::missing_tool` that explains why the tool is unusable.
pub(crate) fn classify_version_check(
    tool: &str,
    outcome: std::result::Result<ExitStatus, io::Error>,
) -> Result<()> {
    let status = match outcome {
        Ok(status) => status,
        // NotFound from spawn means the binary is absent from $PATH;
        // include the PATH value to make the failure easy to diagnose.
        Err(e) if e.kind() == io::ErrorKind::NotFound => {
            let path = std::env::var("PATH").unwrap_or_default();
            return Err(MigrationError::missing_tool(
                tool,
                format!("not found in $PATH (PATH={path})"),
            ));
        }
        // Any other spawn failure (e.g. permissions) is also fatal.
        Err(e) => {
            return Err(MigrationError::missing_tool(
                tool,
                format!("failed to spawn `{tool} --version`: {e}"),
            ));
        }
    };
    if status.success() {
        Ok(())
    } else {
        Err(MigrationError::missing_tool(
            tool,
            format!("`{tool} --version` exited with status {status}"),
        ))
    }
}
/// Return a configuration error unless `publication` exists on the
/// source server reachable through `source_conn`.
pub async fn verify_publication_exists(source_conn: &str, publication: &str) -> Result<()> {
    let client = connect_with_sslmode(source_conn).await?;
    let row = client
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_publication WHERE pubname = $1)",
            &[&publication],
        )
        .await?;
    if row.get::<_, bool>(0) {
        Ok(())
    } else {
        Err(MigrationError::config(format!(
            "publication `{publication}` does not exist on the source. \
             Run `CREATE PUBLICATION {publication} FOR ALL TABLES;` \
             (or a more targeted `FOR TABLE ...`) before retrying."
        )))
    }
}
/// Confirm the source server is configured for logical replication.
///
/// Checks that `wal_level` is `logical` and that `max_replication_slots`
/// and `max_wal_senders` are both greater than zero, returning a
/// configuration error with remediation advice otherwise.
pub async fn verify_source_logical_replication_ready(source_conn: &str) -> Result<()> {
    let client = connect_with_sslmode(source_conn).await?;
    let wal_level: String = client
        .query_one("SELECT current_setting('wal_level')", &[])
        .await?
        .get(0);
    if wal_level != "logical" {
        return Err(MigrationError::config(format!(
            "the source server has `wal_level = '{wal_level}'`. \
             Online migrations require `wal_level = 'logical'`. \
             Set it via `ALTER SYSTEM SET wal_level = 'logical';` \
             and restart the source server (this GUC is not reloadable)."
        )));
    }
    // Both limits must be > 0; the error text below explains the fix.
    for guc in ["max_replication_slots", "max_wal_senders"] {
        let raw: String = client
            .query_one("SELECT current_setting($1)::text", &[&guc])
            .await?
            .get(0);
        let value: i64 = raw.trim().parse().map_err(|_| {
            MigrationError::config(format!(
                "could not parse `{guc}` value `{raw}` as an integer"
            ))
        })?;
        if value <= 0 {
            return Err(MigrationError::config(format!(
                "the source server has `{guc} = {value}`. \
                 Online migrations require `{guc} > 0`. \
                 Raise it (PostgreSQL recommends >= 4) and restart \
                 the source server."
            )));
        }
    }
    info!("source is configured for logical replication (wal_level=logical)");
    Ok(())
}
/// Quote a possibly schema-qualified name for safe use in SQL.
///
/// Only the first `.` separates schema from table (via `splitn(2, ...)`),
/// so additional dots stay inside the table identifier. Empty components
/// (leading/trailing dot, empty string) are rejected.
pub fn quote_qualified_name(name: &str) -> Result<String> {
    let pieces: Vec<&str> = name.splitn(2, '.').collect();
    if pieces.contains(&"") {
        return Err(MigrationError::config(format!(
            "invalid qualified name: `{name}` (empty component)"
        )));
    }
    let mut quoted = Vec::with_capacity(pieces.len());
    for piece in pieces {
        quoted.push(pg_walstream::quote_ident(piece)?);
    }
    Ok(quoted.join("."))
}
/// Build the `CREATE PUBLICATION ...` statement for `publication`.
///
/// When both `tables` and `schemas` are empty the publication covers
/// everything (`FOR ALL TABLES`); otherwise the scope combines
/// `TABLE ...` and/or `TABLES IN SCHEMA ...` clauses. All identifiers
/// are quoted.
pub fn build_create_publication_sql(
    publication: &str,
    tables: &[String],
    schemas: &[String],
) -> Result<String> {
    let pub_ident = pg_walstream::quote_ident(publication)?;
    if tables.is_empty() && schemas.is_empty() {
        return Ok(format!("CREATE PUBLICATION {pub_ident} FOR ALL TABLES"));
    }
    let mut clauses = Vec::with_capacity(2);
    if !tables.is_empty() {
        let mut quoted = Vec::with_capacity(tables.len());
        for table in tables {
            quoted.push(quote_qualified_name(table)?);
        }
        clauses.push(format!("TABLE {}", quoted.join(", ")));
    }
    if !schemas.is_empty() {
        let mut quoted = Vec::with_capacity(schemas.len());
        for schema in schemas {
            quoted.push(pg_walstream::quote_ident(schema)?);
        }
        clauses.push(format!("TABLES IN SCHEMA {}", quoted.join(", ")));
    }
    Ok(format!(
        "CREATE PUBLICATION {pub_ident} FOR {}",
        clauses.join(", ")
    ))
}
/// Drop from `tables` every entry named in `exclude_tables` and every
/// entry whose schema prefix (text before the first `.`) appears in
/// `exclude_schemas`. The relative order of survivors is preserved.
pub fn filter_tables_by_exclusions(
    tables: &[String],
    exclude_tables: &[String],
    exclude_schemas: &[String],
) -> Vec<String> {
    let mut kept = Vec::with_capacity(tables.len());
    for table in tables {
        if exclude_tables.contains(table) {
            continue;
        }
        // split('.').next() always yields at least one piece, so an
        // unqualified name is compared against the schema list whole.
        let schema = table.split('.').next().unwrap_or("");
        if exclude_schemas.iter().any(|ex| ex == schema) {
            continue;
        }
        kept.push(table.clone());
    }
    kept
}
/// List every ordinary or partitioned table (relkind 'r'/'p') on the
/// connected server as `schema.table` strings, skipping the system and
/// temp schemas, then apply the caller's exclusion lists.
async fn fetch_published_tables(
    client: &tokio_postgres::Client,
    exclude_tables: &[String],
    exclude_schemas: &[String],
) -> Result<Vec<String>> {
    let rows = client
        .query(
            "SELECT n.nspname::text, c.relname::text \
             FROM pg_class c \
             JOIN pg_namespace n ON n.oid = c.relnamespace \
             WHERE c.relkind IN ('r', 'p') \
             AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \
             AND n.nspname NOT LIKE 'pg_temp_%' \
             AND n.nspname NOT LIKE 'pg_toast_temp_%'",
            &[],
        )
        .await?;
    let mut all_tables = Vec::with_capacity(rows.len());
    for row in &rows {
        let schema: &str = row.get(0);
        let table: &str = row.get(1);
        all_tables.push(format!("{schema}.{table}"));
    }
    Ok(filter_tables_by_exclusions(
        &all_tables,
        exclude_tables,
        exclude_schemas,
    ))
}
/// Create `publication` on the source server if it is not already there.
///
/// Returns `Ok(true)` when this call created the publication and
/// `Ok(false)` when one with the same name already existed — in which
/// case its scope is NOT compared against the requested tables/schemas.
pub async fn ensure_publication_exists(
    source_conn: &str,
    publication: &str,
    tables: &[String],
    schemas: &[String],
    exclude_tables: &[String],
    exclude_schemas: &[String],
) -> Result<bool> {
    let client = connect_with_sslmode(source_conn).await?;
    let row = client
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_publication WHERE pubname = $1)",
            &[&publication],
        )
        .await?;
    let exists: bool = row.get(0);
    if exists {
        info!(publication, "publication already exists on source");
        return Ok(false);
    }
    let has_exclusions = !exclude_tables.is_empty() || !exclude_schemas.is_empty();
    let has_includes = !tables.is_empty() || !schemas.is_empty();
    // Resolve the publication scope:
    // - exclusions without includes: enumerate every user table on the
    //   source and publish whatever survives the exclusion lists;
    // - exclusions with includes: filter the caller's include lists;
    // - otherwise: pass the include lists through unchanged (both empty
    //   yields `FOR ALL TABLES` in build_create_publication_sql).
    let (effective_tables, effective_schemas): (Vec<String>, Vec<String>) = if has_exclusions
        && !has_includes
    {
        let resolved = fetch_published_tables(&client, exclude_tables, exclude_schemas).await?;
        (resolved, Vec::new())
    } else if has_exclusions && has_includes {
        let filtered_tables = filter_tables_by_exclusions(tables, exclude_tables, exclude_schemas);
        let filtered_schemas: Vec<String> = schemas
            .iter()
            .filter(|s| !exclude_schemas.iter().any(|ex| ex == *s))
            .cloned()
            .collect();
        (filtered_tables, filtered_schemas)
    } else {
        (tables.to_vec(), schemas.to_vec())
    };
    let sql = build_create_publication_sql(publication, &effective_tables, &effective_schemas)?;
    info!(publication, sql = %sql, "auto-creating publication on source");
    client.batch_execute(&sql).await?;
    info!(publication, "publication created successfully");
    Ok(true)
}
/// Derive a connection string that targets the `postgres` maintenance
/// database on the same server: the database path segment is replaced
/// with `/postgres` while user info, host, port, and query parameters
/// are preserved.
///
/// URLs with no `/` after the host (no database segment) are returned
/// unchanged. NOTE(review): non-URL `keyword=value` strings also pass
/// through unchanged — confirm all callers use URL-style strings.
pub fn maintenance_connection_string(conn: &str) -> String {
    // Split off the query string once so the host/path parsing below
    // has a single code path for both `...?params` and bare URLs.
    // (The original duplicated the entire parse across both cases.)
    let (base, query) = match conn.find('?') {
        Some(q) => conn.split_at(q),
        None => (conn, ""),
    };
    // Skip the scheme, then skip past the last `@` of the user-info
    // section (rfind, so an unencoded `@` in the password does not end
    // the authority early).
    let scheme_end = base.find("://").map(|i| i + 3).unwrap_or(0);
    let host_start = base[scheme_end..]
        .rfind('@')
        .map(|i| scheme_end + i + 1)
        .unwrap_or(scheme_end);
    // The first `/` after the host starts the database path segment.
    match base[host_start..].find('/') {
        Some(slash) => format!("{}/postgres{}", &base[..host_start + slash], query),
        None => conn.to_string(),
    }
}
/// Make sure `db_name` exists on the target server, creating it when
/// missing. Connects through the `postgres` maintenance database
/// (via `maintenance_connection_string`) because the target database
/// itself may not exist yet.
pub async fn ensure_target_database_exists(target_conn: &str, db_name: &str) -> Result<()> {
    let maint_conn = maintenance_connection_string(target_conn);
    let client = connect_with_sslmode(&maint_conn).await?;
    let row = client
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)",
            &[&db_name],
        )
        .await?;
    if row.get::<_, bool>(0) {
        info!(database = db_name, "target database already exists");
        return Ok(());
    }
    info!(database = db_name, "creating target database");
    // The database name is quoted via quote_ident before interpolation.
    let create_sql = format!("CREATE DATABASE {}", pg_walstream::quote_ident(db_name)?);
    client.batch_execute(&create_sql).await?;
    info!(database = db_name, "target database created");
    Ok(())
}
/// Fail with a configuration error if the target server lists
/// `pglogical` in `shared_preload_libraries`, which this tool treats as
/// incompatible with native logical-replication apply workers.
pub async fn ensure_pglogical_not_interfering(target_conn: &str) -> Result<()> {
    let client = connect_with_sslmode(target_conn).await?;
    let row = client
        .query_one("SELECT current_setting('shared_preload_libraries')", &[])
        .await?;
    let libs: &str = row.get(0);
    // The setting is a comma-separated list; compare trimmed entries.
    let has_pglogical = libs.split(',').map(str::trim).any(|lib| lib == "pglogical");
    if has_pglogical {
        return Err(MigrationError::config(
            "the target server has `pglogical` in `shared_preload_libraries`. \
             This is known to prevent native PostgreSQL logical-replication apply \
             workers from starting (the workers crash silently on launch). \
             Remove `pglogical` from `shared_preload_libraries` and restart the \
             server before retrying."
                .to_string(),
        ));
    }
    info!("pglogical is not in shared_preload_libraries — native logical replication will work");
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::os::unix::process::ExitStatusExt;

    /// Exit status for a process that exited cleanly with code 0.
    fn ok_status() -> ExitStatus {
        ExitStatus::from_raw(0)
    }

    /// Exit status for a process that exited with code 1. `from_raw`
    /// takes a wait(2)-style status word, where the exit code occupies
    /// the high byte — hence the `<< 8`. (Also fixes the original's
    /// brace jammed onto the expression line.)
    fn fail_status() -> ExitStatus {
        ExitStatus::from_raw(1 << 8)
    }

    // --- classify_version_check / tool discovery ---------------------

    #[test]
    fn classify_ok_when_version_succeeds() {
        assert!(classify_version_check("pg_dump", Ok(ok_status())).is_ok());
    }

    #[test]
    fn classify_missing_tool_when_not_found() {
        let err = classify_version_check("pg_dump", Err(io::Error::from(io::ErrorKind::NotFound)))
            .unwrap_err();
        match err {
            MigrationError::MissingTool { tool, reason } => {
                assert_eq!(tool, "pg_dump");
                assert!(reason.contains("not found in $PATH"));
            }
            other => panic!("expected MissingTool, got {other:?}"),
        }
    }

    #[test]
    fn classify_missing_tool_when_version_exits_nonzero() {
        let err = classify_version_check("pg_restore", Ok(fail_status())).unwrap_err();
        match err {
            MigrationError::MissingTool { tool, reason } => {
                assert_eq!(tool, "pg_restore");
                assert!(reason.contains("--version"));
            }
            other => panic!("expected MissingTool, got {other:?}"),
        }
    }

    #[test]
    fn classify_missing_tool_for_other_io_errors() {
        let err = classify_version_check(
            "pg_dump",
            Err(io::Error::from(io::ErrorKind::PermissionDenied)),
        )
        .unwrap_err();
        match err {
            MigrationError::MissingTool { tool, reason } => {
                assert_eq!(tool, "pg_dump");
                assert!(reason.contains("failed to spawn"));
            }
            other => panic!("expected MissingTool, got {other:?}"),
        }
    }

    #[test]
    fn missing_tool_error_message_includes_install_hint() {
        let err = MigrationError::missing_tool("pg_dump", "not found in $PATH");
        let msg = err.to_string();
        assert!(msg.contains("pg_dump"));
        assert!(msg.contains("not installed or not on $PATH"));
        assert!(msg.contains("postgresql-client"));
    }

    #[test]
    fn required_tools_includes_pg_dump_and_pg_restore() {
        assert!(REQUIRED_TOOLS.contains(&"pg_dump"));
        assert!(REQUIRED_TOOLS.contains(&"pg_restore"));
    }

    // Smoke test only: the result is deliberately ignored, so this
    // passes whether or not the tools are installed in the test env.
    #[tokio::test]
    async fn verify_pg_tools_passes_in_test_env() {
        let _ = verify_pg_tools_installed().await;
    }

    // --- maintenance_connection_string -------------------------------

    #[test]
    fn maintenance_conn_swaps_database_name() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432/mydb?sslmode=require"),
            "postgresql://u:p@host:5432/postgres?sslmode=require"
        );
    }

    #[test]
    fn maintenance_conn_no_query_params() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432/mydb"),
            "postgresql://u:p@host:5432/postgres"
        );
    }

    #[test]
    fn maintenance_conn_preserves_multiple_query_params() {
        assert_eq!(
            maintenance_connection_string(
                "postgresql://u:p@host/db1?sslmode=require&connect_timeout=10"
            ),
            "postgresql://u:p@host/postgres?sslmode=require&connect_timeout=10"
        );
    }

    #[test]
    fn maintenance_conn_handles_no_password() {
        assert_eq!(
            maintenance_connection_string("postgresql://u@host/db1?sslmode=require"),
            "postgresql://u@host/postgres?sslmode=require"
        );
    }

    #[test]
    fn maintenance_conn_no_slash_after_host_returns_unchanged() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432"),
            "postgresql://u:p@host:5432"
        );
    }

    #[test]
    fn maintenance_conn_no_slash_after_host_with_query_returns_unchanged() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432?sslmode=require"),
            "postgresql://u:p@host:5432?sslmode=require"
        );
    }

    #[test]
    fn maintenance_conn_no_auth() {
        assert_eq!(
            maintenance_connection_string("postgresql://host:5432/mydb"),
            "postgresql://host:5432/postgres"
        );
    }

    #[test]
    fn maintenance_conn_password_with_at_sign() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p%40ss@host/db?sslmode=require"),
            "postgresql://u:p%40ss@host/postgres?sslmode=require"
        );
    }

    #[test]
    fn maintenance_conn_port_only_no_database() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432"),
            "postgresql://u:p@host:5432"
        );
    }

    #[test]
    fn maintenance_conn_empty_database() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432/"),
            "postgresql://u:p@host:5432/postgres"
        );
    }

    #[test]
    fn maintenance_conn_empty_database_with_query() {
        assert_eq!(
            maintenance_connection_string("postgresql://u:p@host:5432/?sslmode=require"),
            "postgresql://u:p@host:5432/postgres?sslmode=require"
        );
    }

    #[test]
    fn classify_version_check_ok_success_returns_ok() {
        let result = classify_version_check("tool_x", Ok(ok_status()));
        assert!(result.is_ok());
    }

    #[test]
    fn required_tools_length() {
        assert_eq!(REQUIRED_TOOLS.len(), 2);
    }

    // --- build_create_publication_sql --------------------------------

    #[test]
    fn build_publication_sql_all_tables() {
        let sql = build_create_publication_sql("my_pub", &[], &[]).unwrap();
        assert_eq!(sql, "CREATE PUBLICATION \"my_pub\" FOR ALL TABLES");
    }

    #[test]
    fn build_publication_sql_specific_tables() {
        let tables = vec!["public.users".to_string(), "public.orders".to_string()];
        let sql = build_create_publication_sql("my_pub", &tables, &[]).unwrap();
        assert_eq!(
            sql,
            "CREATE PUBLICATION \"my_pub\" FOR TABLE \"public\".\"users\", \"public\".\"orders\""
        );
    }

    #[test]
    fn build_publication_sql_specific_schemas() {
        let schemas = vec!["public".to_string(), "app".to_string()];
        let sql = build_create_publication_sql("my_pub", &[], &schemas).unwrap();
        assert_eq!(
            sql,
            "CREATE PUBLICATION \"my_pub\" FOR TABLES IN SCHEMA \"public\", \"app\""
        );
    }

    #[test]
    fn build_publication_sql_combines_tables_and_schemas() {
        let tables = vec!["public.users".to_string()];
        let schemas = vec!["app".to_string()];
        let sql = build_create_publication_sql("my_pub", &tables, &schemas).unwrap();
        assert_eq!(
            sql,
            "CREATE PUBLICATION \"my_pub\" FOR TABLE \"public\".\"users\", TABLES IN SCHEMA \"app\""
        );
    }

    #[test]
    fn build_publication_sql_quotes_special_chars() {
        let sql = build_create_publication_sql("pub\"name", &[], &[]).unwrap();
        assert!(sql.contains("\"pub\"\"name\""));
    }

    // --- quote_qualified_name ----------------------------------------

    #[test]
    fn quote_qualified_name_unqualified() {
        let result = quote_qualified_name("users").unwrap();
        assert_eq!(result, "\"users\"");
    }

    #[test]
    fn quote_qualified_name_schema_qualified() {
        let result = quote_qualified_name("public.users").unwrap();
        assert_eq!(result, "\"public\".\"users\"");
    }

    #[test]
    fn quote_qualified_name_special_chars() {
        let result = quote_qualified_name("my schema.my table").unwrap();
        assert_eq!(result, "\"my schema\".\"my table\"");
    }

    #[test]
    fn quote_qualified_name_dot_in_table_part() {
        let result = quote_qualified_name("public.my.table").unwrap();
        assert_eq!(result, "\"public\".\"my.table\"");
    }

    #[test]
    fn quote_qualified_name_rejects_trailing_dot() {
        let result = quote_qualified_name("public.");
        assert!(result.is_err());
    }

    #[test]
    fn quote_qualified_name_rejects_leading_dot() {
        let result = quote_qualified_name(".table");
        assert!(result.is_err());
    }

    // --- filter_tables_by_exclusions ---------------------------------

    #[test]
    fn filter_tables_excludes_by_table_name() {
        let tables = vec![
            "public.users".into(),
            "public.orders".into(),
            "public.large_logs".into(),
        ];
        let result = filter_tables_by_exclusions(&tables, &["public.large_logs".into()], &[]);
        assert_eq!(result, vec!["public.users", "public.orders"]);
    }

    #[test]
    fn filter_tables_excludes_by_schema() {
        let tables = vec![
            "public.users".into(),
            "audit.events".into(),
            "audit.actions".into(),
            "app.config".into(),
        ];
        let result = filter_tables_by_exclusions(&tables, &[], &["audit".into()]);
        assert_eq!(result, vec!["public.users", "app.config"]);
    }

    #[test]
    fn filter_tables_excludes_both_table_and_schema() {
        let tables = vec![
            "public.users".into(),
            "public.large_logs".into(),
            "audit.events".into(),
            "app.config".into(),
        ];
        let result =
            filter_tables_by_exclusions(&tables, &["public.large_logs".into()], &["audit".into()]);
        assert_eq!(result, vec!["public.users", "app.config"]);
    }

    #[test]
    fn filter_tables_no_exclusions_returns_all() {
        let tables = vec!["public.users".into(), "public.orders".into()];
        let result = filter_tables_by_exclusions(&tables, &[], &[]);
        assert_eq!(result, tables);
    }

    #[test]
    fn filter_tables_empty_input() {
        let result: Vec<String> =
            filter_tables_by_exclusions(&[], &["public.x".into()], &["audit".into()]);
        assert!(result.is_empty());
    }

    #[test]
    fn filter_tables_exclude_all_matches_returns_empty() {
        let tables = vec!["audit.x".into(), "audit.y".into()];
        let result = filter_tables_by_exclusions(&tables, &[], &["audit".into()]);
        assert!(result.is_empty());
    }

    #[test]
    fn filter_tables_exclude_nonexistent_is_noop() {
        let tables = vec!["public.users".into()];
        let result = filter_tables_by_exclusions(
            &tables,
            &["public.nonexistent".into()],
            &["no_such_schema".into()],
        );
        assert_eq!(result, vec!["public.users"]);
    }

    // --- composition of filtering + SQL building ---------------------

    #[test]
    fn filter_then_build_sql_excludes_correctly() {
        let all_tables: Vec<String> = vec![
            "public.users".into(),
            "public.orders".into(),
            "audit.logs".into(),
            "temp.scratch".into(),
        ];
        let filtered =
            filter_tables_by_exclusions(&all_tables, &["public.orders".into()], &["audit".into()]);
        let sql = build_create_publication_sql("my_pub", &filtered, &[]).unwrap();
        assert_eq!(
            sql,
            "CREATE PUBLICATION \"my_pub\" FOR TABLE \"public\".\"users\", \"temp\".\"scratch\""
        );
        assert!(!sql.contains("orders"));
        assert!(!sql.contains("audit"));
    }

    #[test]
    fn filter_schemas_from_include_list() {
        let schemas: Vec<String> = ["public", "audit", "app"]
            .iter()
            .map(|s| (*s).into())
            .collect();
        let exclude_schemas: Vec<String> = ["audit"].iter().map(|s| (*s).into()).collect();
        let filtered: Vec<String> = schemas
            .iter()
            .filter(|s| !exclude_schemas.iter().any(|ex| ex == *s))
            .cloned()
            .collect();
        assert_eq!(filtered, vec!["public", "app"]);
        let sql = build_create_publication_sql("p", &[], &filtered).unwrap();
        assert!(sql.contains("\"public\""));
        assert!(sql.contains("\"app\""));
        assert!(!sql.contains("\"audit\""));
    }
}