use std::path::PathBuf;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::error::{MigrationError, Result};
/// Top-level configuration for one migration run.
///
/// This is the serde surface for the tool's config file / CLI. Fields with
/// `#[serde(default…)]` remain optional in serialized configs; the
/// `default_true` helper keeps several flags enabled unless explicitly
/// turned off.
///
/// NOTE(review): the `no_*` / `split_sections` / `dump_compress` flags
/// presumably map onto `pg_dump` / `pg_restore` CLI switches of the same
/// name — confirm in the command-builder code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationConfig {
    /// Offline (dump/restore) or online (additionally uses [`OnlineOptions`]).
    pub mode: MigrationMode,
    /// Endpoint the data is migrated from.
    pub source: EndpointConfig,
    /// Endpoint the data is migrated to.
    pub target: EndpointConfig,
    /// Whether the dump covers schema, data, or both.
    pub dump_scope: DumpScope,
    /// Drop existing objects on the target before restoring.
    pub drop_target_first: bool,
    /// Parallelism level; must be >= 1 (enforced by `validate`).
    pub jobs: usize,
    /// Schema name filters — presumably an include-list (empty = no filter);
    /// confirm against the dump/restore builder.
    pub schemas: Vec<String>,
    /// Table name filters — presumably an include-list; confirm as above.
    pub tables: Vec<String>,
    /// Schemas to exclude from the migration.
    #[serde(default)]
    pub exclude_schemas: Vec<String>,
    /// Tables to exclude from the migration.
    #[serde(default)]
    pub exclude_tables: Vec<String>,
    /// Logical-replication settings used when `mode == Online`.
    pub online: OnlineOptions,
    /// Keep going past restore errors instead of aborting.
    #[serde(default)]
    pub allow_restore_errors: bool,
    #[serde(default = "default_true")]
    pub no_publications: bool,
    #[serde(default = "default_true")]
    pub no_subscriptions: bool,
    #[serde(default = "default_true")]
    pub split_sections: bool,
    /// Dump compression spec (defaults to `lz4:1`); `None` means no
    /// compression setting is applied.
    #[serde(default = "default_compress")]
    pub dump_compress: Option<String>,
    #[serde(default = "default_true")]
    pub no_sync: bool,
    #[serde(default = "default_true")]
    pub no_comments: bool,
    #[serde(default = "default_true")]
    pub no_security_labels: bool,
    #[serde(default)]
    pub no_table_access_method: bool,
    /// Resume a previous run; requires `dump_path` (enforced by `validate`).
    #[serde(default)]
    pub resume: bool,
    /// Optional state file used with `resume`.
    /// NOTE(review): not checked by `validate` — confirm where it is consumed.
    #[serde(default)]
    pub resume_file: Option<PathBuf>,
    /// Where the dump archive is written to / read from.
    #[serde(default)]
    pub dump_path: Option<PathBuf>,
    /// Verbose logging/output.
    pub verbose: bool,
}
/// Serde default helper for boolean flags that should start enabled.
fn default_true() -> bool {
    true
}
/// Serde default helper: the dump compression spec used when none is given.
fn default_compress() -> Option<String> {
    Some(String::from("lz4:1"))
}
/// Default worker count: the detected core count, falling back to 4 when
/// detection fails, and always capped to the `1..=8` range.
pub fn default_jobs() -> usize {
    std::thread::available_parallelism()
        .map_or(4, std::num::NonZeroUsize::get)
        .clamp(1, 8)
}
impl Default for MigrationConfig {
fn default() -> Self {
Self {
mode: MigrationMode::Offline,
source: EndpointConfig::default(),
target: EndpointConfig::default(),
dump_scope: DumpScope::All,
drop_target_first: false,
jobs: default_jobs(),
schemas: Vec::new(),
tables: Vec::new(),
exclude_schemas: Vec::new(),
exclude_tables: Vec::new(),
online: OnlineOptions::default(),
allow_restore_errors: false,
no_publications: true,
no_subscriptions: true,
split_sections: true,
dump_compress: default_compress(),
no_sync: true,
no_comments: true,
no_security_labels: true,
no_table_access_method: false,
resume: false,
resume_file: None,
dump_path: None,
verbose: false,
}
}
}
impl MigrationConfig {
    /// Rejects configurations that cannot produce a usable migration run.
    ///
    /// # Errors
    /// Returns a config error (via `MigrationError::config`) when a
    /// connection string is empty, `jobs` is zero, or `resume` is set
    /// without `dump_path`; in online mode, [`OnlineOptions::validate`]
    /// errors are propagated as well.
    pub fn validate(&self) -> Result<()> {
        if self.source.connection_string.is_empty() {
            return Err(MigrationError::config("source connection string is empty"));
        }
        if self.target.connection_string.is_empty() {
            return Err(MigrationError::config("target connection string is empty"));
        }
        if self.jobs < 1 {
            return Err(MigrationError::config("jobs must be >= 1"));
        }
        // A resumed run must point at the same archive as the first run.
        let resume_without_archive = self.resume && self.dump_path.is_none();
        if resume_without_archive {
            return Err(MigrationError::config(
                "--resume requires --dump-path so subsequent runs target the same archive",
            ));
        }
        if matches!(self.mode, MigrationMode::Online) {
            self.online.validate()?;
        }
        Ok(())
    }
}
/// Overall strategy for the migration run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationMode {
    /// One-shot dump/restore only.
    Offline,
    /// Additionally applies [`OnlineOptions`] (replication slot,
    /// publication, subscription, and cutover settings).
    Online,
}
/// Which parts of the database the dump covers; mapped to `pg_dump` CLI
/// flags by [`DumpScope::pg_dump_flag`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DumpScope {
    /// Schema and data (no extra flag).
    All,
    /// Schema definitions only (`--schema-only`).
    SchemaOnly,
    /// Table contents only (`--data-only`).
    DataOnly,
}
impl DumpScope {
    /// Returns the `pg_dump` flag selecting this scope, or `None` when the
    /// full dump (schema + data) is wanted.
    pub fn pg_dump_flag(self) -> Option<&'static str> {
        match self {
            Self::SchemaOnly => Some("--schema-only"),
            Self::DataOnly => Some("--data-only"),
            Self::All => None,
        }
    }
}
/// A parsed PostgreSQL endpoint, normally produced by
/// [`EndpointConfig::parse`].
///
/// `connection_string` keeps the original URI verbatim; the remaining
/// fields are its extracted components (percent-encoding is not decoded).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EndpointConfig {
    /// The full connection URI as supplied by the user.
    pub connection_string: String,
    pub host: String,
    /// TCP port; defaults to 5432 when absent from the URI.
    pub port: u16,
    pub database: String,
    pub user: String,
    /// Empty string when the URI carries no password.
    pub password: String,
}
impl EndpointConfig {
    /// Parses a `postgres://` / `postgresql://` connection URI into its
    /// components, keeping the original string alongside them.
    ///
    /// Components are taken verbatim from the URL; percent-encoded values
    /// are NOT decoded here.
    ///
    /// # Errors
    /// Returns [`MigrationError::InvalidConnectionString`] when the URI
    /// fails to parse, uses a non-PostgreSQL scheme, or has no host.
    pub fn parse(conn: &str) -> Result<Self> {
        let url = Url::parse(conn)
            .map_err(|e| MigrationError::InvalidConnectionString(format!("{conn}: {e}")))?;
        if !matches!(url.scheme(), "postgres" | "postgresql") {
            return Err(MigrationError::InvalidConnectionString(format!(
                "unsupported scheme `{}`",
                url.scheme()
            )));
        }
        let host = url
            .host_str()
            .ok_or_else(|| {
                MigrationError::InvalidConnectionString(format!("{conn}: missing host"))
            })?
            .to_string();
        // PostgreSQL's well-known default port when the URI omits one.
        let port = url.port().unwrap_or(5432);
        // `Url::path()` never includes the query string (that is exposed
        // separately via `Url::query()`), so stripping the leading `/` is
        // all that is needed; the previous `.split('?')` here was dead code.
        let database = url.path().trim_start_matches('/').to_string();
        let user = url.username().to_string();
        let password = url.password().unwrap_or("").to_string();
        Ok(Self {
            connection_string: conn.to_string(),
            host,
            port,
            database,
            user,
            password,
        })
    }

    /// Returns the connection string with the password masked as `****`,
    /// suitable for logging. Only the first `:{password}@` occurrence is
    /// replaced; strings without a password pass through unchanged.
    pub fn redacted(&self) -> String {
        if self.password.is_empty() {
            self.connection_string.clone()
        } else {
            self.connection_string
                .replacen(&format!(":{}@", self.password), ":****@", 1)
        }
    }
}
/// Settings for online (logical-replication-assisted) migrations.
///
/// NOTE(review): slot/publication/subscription names presumably refer to
/// PostgreSQL logical-replication objects created on the source/target —
/// confirm against the replication setup code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineOptions {
    /// Replication slot name; must be non-empty (see `validate`).
    pub slot_name: String,
    /// Publication name; must be non-empty (see `validate`).
    pub publication: String,
    /// Logical replication protocol version; must be in 1..=4 (see
    /// `validate`).
    pub protocol_version: u32,
    /// Subscription name; defaults to `pg_dbmigrator_sub`.
    #[serde(default = "default_subscription_name")]
    pub subscription_name: String,
    /// Optional override for the connection string the subscription uses to
    /// reach the source.
    #[serde(default)]
    pub subscription_source_conn: Option<String>,
    /// Drop the subscription when cutover completes (default: true).
    #[serde(default = "default_true")]
    pub drop_subscription_on_cutover: bool,
    /// Force cleanup of pre-existing replication objects.
    /// NOTE(review): exact semantics live in the setup code — confirm.
    #[serde(default)]
    pub force_clean: bool,
    /// Synchronize sequences at cutover time (default: true).
    #[serde(default = "default_true")]
    pub sync_sequences_on_cutover: bool,
    /// Apply-loop timing settings.
    pub apply: ReplicationApplyConfig,
    /// Cutover polling/lag settings.
    pub cutover: CutoverConfig,
}
/// Serde default helper: subscription name used when none is configured.
fn default_subscription_name() -> String {
    String::from("pg_dbmigrator_sub")
}
impl Default for OnlineOptions {
fn default() -> Self {
Self {
slot_name: "pg_dbmigrator_slot".to_string(),
publication: "pg_dbmigrator_pub".to_string(),
protocol_version: 2,
subscription_name: default_subscription_name(),
subscription_source_conn: None,
drop_subscription_on_cutover: true,
force_clean: false,
sync_sequences_on_cutover: true,
apply: ReplicationApplyConfig::default(),
cutover: CutoverConfig::default(),
}
}
}
impl OnlineOptions {
    /// Validates replication-related settings and delegates cutover checks
    /// to [`CutoverConfig::validate`].
    ///
    /// # Errors
    /// Returns a config error when any required name is empty or the
    /// protocol version is outside `1..=4`.
    pub fn validate(&self) -> Result<()> {
        if self.slot_name.is_empty() {
            return Err(MigrationError::config("slot_name must not be empty"));
        }
        if self.publication.is_empty() {
            return Err(MigrationError::config("publication must not be empty"));
        }
        if !(1..=4).contains(&self.protocol_version) {
            return Err(MigrationError::config("protocol_version must be in 1..=4"));
        }
        if self.subscription_name.is_empty() {
            return Err(MigrationError::config(
                "subscription_name must not be empty",
            ));
        }
        self.cutover.validate()
    }
}
/// Timing settings for the replication apply loop.
///
/// The `Duration` fields are serialized as plain whole seconds via the
/// `humantime_serde_workaround` module below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicationApplyConfig {
    /// How often to send standby feedback.
    /// NOTE(review): presumably keepalive/LSN feedback on the replication
    /// connection — confirm in the apply-loop code.
    #[serde(with = "humantime_serde_workaround")]
    pub feedback_interval: Duration,
    /// Timeout for establishing the replication connection.
    #[serde(with = "humantime_serde_workaround")]
    pub connection_timeout: Duration,
    /// How often to run health checks.
    #[serde(with = "humantime_serde_workaround")]
    pub health_check_interval: Duration,
    /// Optional hard cap on total runtime, in seconds; `None` = unlimited.
    pub max_runtime_seconds: Option<u64>,
}
impl Default for ReplicationApplyConfig {
fn default() -> Self {
Self {
feedback_interval: Duration::from_secs(10),
connection_timeout: Duration::from_secs(30),
health_check_interval: Duration::from_secs(60),
max_runtime_seconds: None,
}
}
}
/// Polling and lag-threshold settings for the cutover phase.
///
/// `Duration` fields are serialized as plain whole seconds via
/// `humantime_serde_workaround`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CutoverConfig {
    /// Regular polling interval; must be > 0 (see `validate`).
    #[serde(with = "humantime_serde_workaround")]
    pub poll_interval: Duration,
    /// Faster polling interval (defaults to 1 s); must be > 0.
    /// NOTE(review): presumably used once lag is close to the threshold —
    /// confirm in the cutover loop.
    #[serde(with = "humantime_serde_workaround", default = "default_fast_poll")]
    pub fast_poll_interval: Duration,
    /// Replication-lag threshold in bytes.
    /// NOTE(review): presumably the lag level at which cutover may proceed —
    /// confirm in the cutover loop.
    pub lag_threshold_bytes: u64,
}
/// Serde default helper: fast-poll interval used when none is configured.
fn default_fast_poll() -> Duration {
    Duration::from_millis(1_000)
}
impl Default for CutoverConfig {
fn default() -> Self {
Self {
poll_interval: Duration::from_secs(5),
fast_poll_interval: default_fast_poll(),
lag_threshold_bytes: 8 * 1024,
}
}
}
impl CutoverConfig {
    /// Ensures both polling intervals are non-zero.
    ///
    /// # Errors
    /// Returns a config error naming the first offending interval
    /// (`poll_interval` is checked before `fast_poll_interval`).
    pub fn validate(&self) -> Result<()> {
        if !self.poll_interval.is_zero() && !self.fast_poll_interval.is_zero() {
            return Ok(());
        }
        if self.poll_interval.is_zero() {
            Err(MigrationError::config("cutover.poll_interval must be > 0"))
        } else {
            Err(MigrationError::config(
                "cutover.fast_poll_interval must be > 0",
            ))
        }
    }
}
/// (De)serializes `std::time::Duration` as a plain integer number of whole
/// seconds, for use with `#[serde(with = "...")]`.
///
/// NOTE(review): despite the `humantime` name, no humantime string format is
/// involved — the wire format is a bare `u64` of seconds, and sub-second
/// precision is dropped on serialization.
mod humantime_serde_workaround {
    use std::time::Duration;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Writes the duration as `u64` whole seconds.
    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
        let secs = d.as_secs();
        secs.serialize(s)
    }

    /// Reads a `u64` and interprets it as whole seconds.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
        u64::deserialize(d).map(Duration::from_secs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- EndpointConfig::parse ------------------------------------------

    #[test]
    fn parses_basic_uri() {
        let ep =
            EndpointConfig::parse("postgresql://alice:s3cret@db.example:5433/app").expect("parse");
        assert_eq!(ep.host, "db.example");
        assert_eq!(ep.port, 5433);
        assert_eq!(ep.database, "app");
        assert_eq!(ep.user, "alice");
        assert_eq!(ep.password, "s3cret");
    }

    #[test]
    fn parses_default_port() {
        // No port in the URI → PostgreSQL default 5432.
        let ep = EndpointConfig::parse("postgres://u@h/db").unwrap();
        assert_eq!(ep.port, 5432);
        assert_eq!(ep.user, "u");
        assert!(ep.password.is_empty());
    }

    #[test]
    fn rejects_bad_scheme() {
        let err = EndpointConfig::parse("mysql://u@h/db").unwrap_err();
        assert!(matches!(err, MigrationError::InvalidConnectionString(_)));
    }

    #[test]
    fn rejects_missing_host() {
        let err = EndpointConfig::parse("postgresql:///db").unwrap_err();
        assert!(matches!(err, MigrationError::InvalidConnectionString(_)));
    }

    // --- EndpointConfig::redacted ---------------------------------------

    #[test]
    fn redacted_masks_password() {
        let ep = EndpointConfig::parse("postgresql://u:topsecret@h/db").unwrap();
        let redacted = ep.redacted();
        assert!(!redacted.contains("topsecret"));
        assert!(redacted.contains(":****@"));
    }

    #[test]
    fn redacted_passthrough_when_no_password() {
        let ep = EndpointConfig::parse("postgresql://u@h/db").unwrap();
        assert_eq!(ep.redacted(), "postgresql://u@h/db");
    }

    // --- MigrationConfig / OnlineOptions / CutoverConfig validation -----

    #[test]
    fn validate_rejects_zero_jobs() {
        let cfg = MigrationConfig {
            source: EndpointConfig::parse("postgres://u@s/db").unwrap(),
            target: EndpointConfig::parse("postgres://u@t/db").unwrap(),
            jobs: 0,
            ..MigrationConfig::default()
        };
        let err = cfg.validate().unwrap_err();
        assert!(matches!(err, MigrationError::Config(_)));
    }

    #[test]
    fn online_options_validate() {
        let mut opts = OnlineOptions::default();
        assert!(opts.validate().is_ok());
        opts.slot_name.clear();
        assert!(opts.validate().is_err());
        opts.slot_name = "s".into();
        opts.publication.clear();
        assert!(opts.validate().is_err());
        opts.publication = "p".into();
        opts.protocol_version = 99;
        assert!(opts.validate().is_err());
    }

    #[test]
    fn dump_scope_flag_mapping() {
        assert_eq!(DumpScope::All.pg_dump_flag(), None);
        assert_eq!(DumpScope::SchemaOnly.pg_dump_flag(), Some("--schema-only"));
        assert_eq!(DumpScope::DataOnly.pg_dump_flag(), Some("--data-only"));
    }

    #[test]
    fn cutover_config_default_is_valid() {
        let c = CutoverConfig::default();
        assert!(c.validate().is_ok());
        assert_eq!(c.lag_threshold_bytes, 8 * 1024);
    }

    #[test]
    fn cutover_config_rejects_zero_poll_interval() {
        let c = CutoverConfig {
            poll_interval: Duration::from_secs(0),
            ..CutoverConfig::default()
        };
        let err = c.validate().unwrap_err();
        assert!(matches!(err, MigrationError::Config(_)));
    }

    #[test]
    fn online_options_validate_propagates_cutover_error() {
        let opts = OnlineOptions {
            cutover: CutoverConfig {
                poll_interval: Duration::from_secs(0),
                ..CutoverConfig::default()
            },
            ..OnlineOptions::default()
        };
        assert!(opts.validate().is_err());
    }

    // --- Defaults --------------------------------------------------------

    #[test]
    fn online_options_default_subscription_name() {
        let opts = OnlineOptions::default();
        assert_eq!(opts.subscription_name, "pg_dbmigrator_sub");
        assert!(opts.drop_subscription_on_cutover);
    }

    #[test]
    fn online_options_reject_empty_subscription_name() {
        let opts = OnlineOptions {
            subscription_name: String::new(),
            ..OnlineOptions::default()
        };
        assert!(opts.validate().is_err());
    }

    #[test]
    fn online_options_default_syncs_sequences_on_cutover() {
        let opts = OnlineOptions::default();
        assert!(opts.sync_sequences_on_cutover);
    }

    #[test]
    fn migration_config_default_has_empty_exclude_lists() {
        let cfg = MigrationConfig::default();
        assert!(cfg.exclude_schemas.is_empty());
        assert!(cfg.exclude_tables.is_empty());
    }
}