use std::io;
use std::process::ExitStatus;
use tracing::info;
use crate::error::{MigrationError, Result};
use crate::tls::connect_with_sslmode;
/// Names of the external PostgreSQL client executables that
/// [`verify_pg_tools_installed`] probes with `--version`.
pub const REQUIRED_TOOLS: &[&str] = &["pg_dump", "pg_restore"];
/// Checks that every executable listed in [`REQUIRED_TOOLS`] can be run.
///
/// Each tool is probed in order with `<tool> --version`; the first probe that
/// does not succeed aborts the check.
///
/// # Errors
///
/// Returns whatever error `classify_version_check` produces for the first
/// tool whose probe failed to spawn or exited unsuccessfully.
pub async fn verify_pg_tools_installed() -> Result<()> {
    for &tool in REQUIRED_TOOLS {
        // Spawn first, classify second: the classifier is a pure function and
        // turns the raw spawn outcome into a crate-level error.
        classify_version_check(tool, spawn_version_check(tool).await)?;
    }
    Ok(())
}
/// Runs `<tool> --version` and returns the raw outcome of the spawn.
///
/// stdin, stdout and stderr of the child are all redirected to the null
/// device, so only the exit status (or the spawn error) is observable.
async fn spawn_version_check(tool: &str) -> std::result::Result<ExitStatus, io::Error> {
    use std::process::Stdio;
    use tokio::process::Command;

    let mut probe = Command::new(tool);
    probe
        .arg("--version")
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());
    probe.status().await
}
/// Converts the outcome of a `<tool> --version` probe into a crate result.
///
/// * A successful exit status yields `Ok(())`.
/// * Every other outcome yields `MigrationError::missing_tool` for `tool`,
///   with a reason that distinguishes a non-zero exit, a `NotFound` spawn
///   error (the current `PATH` is included in the reason), and any other
///   spawn error.
pub(crate) fn classify_version_check(
    tool: &str,
    outcome: std::result::Result<ExitStatus, io::Error>,
) -> Result<()> {
    // Work out *why* the tool is unusable; the success case leaves early.
    let reason = match outcome {
        Ok(s) if s.success() => return Ok(()),
        Ok(s) => format!("`{tool} --version` exited with status {s}"),
        Err(e) if e.kind() == io::ErrorKind::NotFound => {
            // An unset or non-Unicode PATH is reported as an empty string.
            let path = std::env::var("PATH").unwrap_or_default();
            format!("not found in $PATH (PATH={path})")
        }
        Err(e) => format!("failed to spawn `{tool} --version`: {e}"),
    };
    Err(MigrationError::missing_tool(tool, reason))
}
/// Confirms that `publication` is present in `pg_publication` on the source.
///
/// # Errors
///
/// Propagates connection and query errors, and returns a
/// `MigrationError::config` error (with a `CREATE PUBLICATION` hint) when the
/// publication is not found.
pub async fn verify_publication_exists(source_conn: &str, publication: &str) -> Result<()> {
    let source = connect_with_sslmode(source_conn).await?;
    let found: bool = source
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_publication WHERE pubname = $1)",
            &[&publication],
        )
        .await?
        .get(0);

    if found {
        return Ok(());
    }
    Err(MigrationError::config(format!(
        "publication `{publication}` does not exist on the source. \
         Run `CREATE PUBLICATION {publication} FOR ALL TABLES;` \
         (or a more targeted `FOR TABLE ...`) before retrying."
    )))
}
/// Checks the source server settings that logical replication depends on.
///
/// The check passes only when `wal_level` is `logical` and both
/// `max_replication_slots` and `max_wal_senders` parse as integers greater
/// than zero.
///
/// # Errors
///
/// Propagates connection and query errors. Returns a
/// `MigrationError::config` error when `wal_level` is not `logical`, when one
/// of the numeric settings cannot be parsed as an integer, or when one of
/// them is zero or negative.
pub async fn verify_source_logical_replication_ready(source_conn: &str) -> Result<()> {
    let source = connect_with_sslmode(source_conn).await?;

    let wal_level: String = source
        .query_one("SELECT current_setting('wal_level')", &[])
        .await?
        .get(0);
    if wal_level != "logical" {
        return Err(MigrationError::config(format!(
            "the source server has `wal_level = '{wal_level}'`. \
             Online migrations require `wal_level = 'logical'`. \
             Set it via `ALTER SYSTEM SET wal_level = 'logical';` \
             and restart the source server (this GUC is not reloadable)."
        )));
    }

    for guc in ["max_replication_slots", "max_wal_senders"] {
        // The setting is read as text and parsed locally so that a malformed
        // value becomes a configuration error.
        let raw: String = source
            .query_one("SELECT current_setting($1)::text", &[&guc])
            .await?
            .get(0);
        let parsed = match raw.trim().parse::<i64>() {
            Ok(value) => value,
            Err(_) => {
                return Err(MigrationError::config(format!(
                    "could not parse `{guc}` value `{raw}` as an integer"
                )));
            }
        };
        if parsed <= 0 {
            return Err(MigrationError::config(format!(
                "the source server has `{guc} = {parsed}`. \
                 Online migrations require `{guc} > 0`. \
                 Raise it (PostgreSQL recommends >= 4) and restart \
                 the source server."
            )));
        }
    }

    info!("source is configured for logical replication (wal_level=logical)");
    Ok(())
}
/// Rewrites a connection URI so that it targets the `postgres` maintenance
/// database, keeping credentials, host, port and query parameters intact.
///
/// The string is split at the first `?`; everything from there on is carried
/// over verbatim. In the part before it, the host section starts after the
/// `://` scheme separator (if any) and after the last `@` (so a password may
/// itself contain `@` or `/`). The first `/` after that marks the database
/// name, which is replaced by `postgres`.
///
/// When the URI has a scheme but no `/database` component at all
/// (e.g. `postgresql://host:5432`), `/postgres` is inserted before the query
/// string, matching how `postgresql://host:5432/` is already handled.
///
/// Input with neither a scheme nor a `/` (for example a `key=value` style
/// string) is returned unchanged.
pub fn maintenance_connection_string(conn: &str) -> String {
    // Everything from the first `?` onward is the query string.
    let query_start = conn.find('?').unwrap_or(conn.len());
    let (head, query) = conn.split_at(query_start);

    // Look for the scheme only in front of the query: a `://` inside a query
    // value must not be mistaken for the scheme separator (doing so used to
    // produce an out-of-order slice and panic).
    let scheme = head.find("://");
    let scheme_end = scheme.map_or(0, |i| i + 3);

    // Skip the `user[:password]@` prefix, if present.
    let host_start = head[scheme_end..]
        .rfind('@')
        .map_or(scheme_end, |i| scheme_end + i + 1);

    match head[host_start..].find('/') {
        // Replace whatever database name follows the host.
        Some(slash) => format!("{}/postgres{}", &head[..host_start + slash], query),
        // A real URI without a database component: add one.
        None if scheme.is_some() => format!("{head}/postgres{query}"),
        // Not recognisable as a URI; leave it alone rather than guess.
        None => conn.to_string(),
    }
}
/// Creates database `db_name` on the target server unless it already exists.
///
/// The connection is made with the string produced by
/// [`maintenance_connection_string`] rather than with `target_conn` itself.
/// Existence is looked up in `pg_database`; when the database is absent a
/// `CREATE DATABASE` statement is issued with the name passed through
/// `pg_walstream::quote_ident`.
///
/// # Errors
///
/// Propagates connection, query, identifier-quoting and statement errors.
pub async fn ensure_target_database_exists(target_conn: &str, db_name: &str) -> Result<()> {
    let maintenance = connect_with_sslmode(&maintenance_connection_string(target_conn)).await?;

    let already_present: bool = maintenance
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)",
            &[&db_name],
        )
        .await?
        .get(0);
    if already_present {
        info!(database = db_name, "target database already exists");
        return Ok(());
    }

    info!(database = db_name, "creating target database");
    let quoted = pg_walstream::quote_ident(db_name)?;
    maintenance
        .batch_execute(&format!("CREATE DATABASE {quoted}"))
        .await?;
    info!(database = db_name, "target database created");
    Ok(())
}
/// Rejects a target server whose `shared_preload_libraries` lists `pglogical`.
///
/// The setting is split on commas and each entry is trimmed before being
/// compared with the exact name `pglogical`.
///
/// # Errors
///
/// Propagates connection and query errors, and returns a
/// `MigrationError::config` error when `pglogical` is among the preloaded
/// libraries.
pub async fn ensure_pglogical_not_interfering(target_conn: &str) -> Result<()> {
    let target = connect_with_sslmode(target_conn).await?;
    let setting = target
        .query_one("SELECT current_setting('shared_preload_libraries')", &[])
        .await?;
    let libs: &str = setting.get(0);

    let pglogical_preloaded = libs.split(',').map(str::trim).any(|lib| lib == "pglogical");
    if !pglogical_preloaded {
        info!("pglogical is not in shared_preload_libraries — native logical replication will work");
        return Ok(());
    }

    Err(MigrationError::config(
        "the target server has `pglogical` in `shared_preload_libraries`. \
         This is known to prevent native PostgreSQL logical-replication apply \
         workers from starting (the workers crash silently on launch). \
         Remove `pglogical` from `shared_preload_libraries` and restart the \
         server before retrying."
            .to_string(),
    ))
}
// Unit tests for the pre-flight checks that need neither a database nor the
// PostgreSQL client tools: probe classification and connection-string rewriting.
#[cfg(test)]
mod tests {
use super::*;
// `from_raw` comes from the Unix-only extension trait, so these helpers
// build exit statuses from raw Unix wait statuses.
use std::os::unix::process::ExitStatusExt;
// A wait status of 0: the process exited successfully.
fn ok_status() -> ExitStatus {
ExitStatus::from_raw(0)
}
// In a Unix wait status the exit code is stored above the low 8 bits, so
// `1 << 8` represents a normal exit with code 1.
fn fail_status() -> ExitStatus {
ExitStatus::from_raw(1 << 8) }
// A successful `--version` exit is accepted.
#[test]
fn classify_ok_when_version_succeeds() {
assert!(classify_version_check("pg_dump", Ok(ok_status())).is_ok());
}
// A `NotFound` spawn error is reported as a missing tool with the $PATH reason.
#[test]
fn classify_missing_tool_when_not_found() {
let err = classify_version_check("pg_dump", Err(io::Error::from(io::ErrorKind::NotFound)))
.unwrap_err();
match err {
MigrationError::MissingTool { tool, reason } => {
assert_eq!(tool, "pg_dump");
assert!(reason.contains("not found in $PATH"));
}
other => panic!("expected MissingTool, got {other:?}"),
}
}
// A tool that starts but exits non-zero is also reported as missing; the
// reason mentions the `--version` invocation.
#[test]
fn classify_missing_tool_when_version_exits_nonzero() {
let err = classify_version_check("pg_restore", Ok(fail_status())).unwrap_err();
match err {
MigrationError::MissingTool { tool, reason } => {
assert_eq!(tool, "pg_restore");
assert!(reason.contains("--version"));
}
other => panic!("expected MissingTool, got {other:?}"),
}
}
// Spawn errors other than `NotFound` take the generic "failed to spawn" reason.
#[test]
fn classify_missing_tool_for_other_io_errors() {
let err = classify_version_check(
"pg_dump",
Err(io::Error::from(io::ErrorKind::PermissionDenied)),
)
.unwrap_err();
match err {
MigrationError::MissingTool { tool, reason } => {
assert_eq!(tool, "pg_dump");
assert!(reason.contains("failed to spawn"));
}
other => panic!("expected MissingTool, got {other:?}"),
}
}
// Pins three substrings of the rendered `missing_tool` error message: the
// tool name, the "not installed or not on $PATH" wording and the
// `postgresql-client` package name.
#[test]
fn missing_tool_error_message_includes_install_hint() {
let err = MigrationError::missing_tool("pg_dump", "not found in $PATH");
let msg = err.to_string();
assert!(msg.contains("pg_dump"));
assert!(msg.contains("not installed or not on $PATH"));
assert!(msg.contains("postgresql-client"));
}
// Guards against either tool being dropped from the required list.
#[test]
fn required_tools_includes_pg_dump_and_pg_restore() {
assert!(REQUIRED_TOOLS.contains(&"pg_dump"));
assert!(REQUIRED_TOOLS.contains(&"pg_restore"));
}
// Smoke test only: the result is deliberately discarded, so this passes
// whether or not the tools are installed and merely checks that the call
// completes without panicking.
#[tokio::test]
async fn verify_pg_tools_passes_in_test_env() {
let _ = verify_pg_tools_installed().await;
}
// The database name is replaced while the query string is kept.
#[test]
fn maintenance_conn_swaps_database_name() {
assert_eq!(
maintenance_connection_string("postgresql://u:p@host:5432/mydb?sslmode=require"),
"postgresql://u:p@host:5432/postgres?sslmode=require"
);
}
// Same replacement when the URI has no query string.
#[test]
fn maintenance_conn_no_query_params() {
assert_eq!(
maintenance_connection_string("postgresql://u:p@host:5432/mydb"),
"postgresql://u:p@host:5432/postgres"
);
}
// All `&`-separated query parameters are carried over untouched.
#[test]
fn maintenance_conn_preserves_multiple_query_params() {
assert_eq!(
maintenance_connection_string(
"postgresql://u:p@host/db1?sslmode=require&connect_timeout=10"
),
"postgresql://u:p@host/postgres?sslmode=require&connect_timeout=10"
);
}
// A user without a password (`u@host`) is handled like full credentials.
#[test]
fn maintenance_conn_handles_no_password() {
assert_eq!(
maintenance_connection_string("postgresql://u@host/db1?sslmode=require"),
"postgresql://u@host/postgres?sslmode=require"
);
}
}