database_replicator/commands/
validate.rs

1// ABOUTME: Pre-flight validation command for migration readiness
2// ABOUTME: Checks connectivity, privileges, and version compatibility
3
4use crate::{migration, postgres, utils};
5use anyhow::{bail, Context, Result};
6
7/// Pre-flight validation command for migration readiness
8///
9/// Performs comprehensive validation before migration:
10/// - Checks for required PostgreSQL client tools (pg_dump, pg_dumpall, psql)
11/// - Validates connection string format
12/// - Tests connectivity to both source and target databases
13/// - Discovers and filters databases based on criteria
14/// - Shows which databases will be replicated
15/// - Verifies source user has REPLICATION privilege
16/// - Verifies target user has CREATEDB privilege
17/// - Confirms PostgreSQL major versions match
18/// - Validates extension compatibility and preload requirements
19///
20/// # Arguments
21///
22/// * `source_url` - PostgreSQL connection string for source database
23/// * `target_url` - PostgreSQL connection string for target (Seren) database
24/// * `filter` - Replication filter for database and table selection
25///
26/// # Returns
27///
28/// Returns `Ok(())` if all validation checks pass.
29///
30/// # Errors
31///
32/// This function will return an error if:
33/// - Required PostgreSQL tools are not installed
34/// - Connection strings are invalid
35/// - Cannot connect to source or target database
36/// - No databases match filter criteria
37/// - Source user lacks REPLICATION privilege
38/// - Target user lacks CREATEDB privilege
39/// - PostgreSQL major versions don't match
40///
41/// # Examples
42///
43/// ```no_run
44/// # use anyhow::Result;
45/// # use database_replicator::commands::validate;
46/// # use database_replicator::filters::ReplicationFilter;
47/// # async fn example() -> Result<()> {
48/// // Validate all databases
49/// validate(
50///     "postgresql://user:pass@source.example.com/postgres",
51///     "postgresql://user:pass@target.example.com/postgres",
52///     ReplicationFilter::empty()
53/// ).await?;
54///
55/// // Validate only specific databases
56/// let filter = ReplicationFilter::new(
57///     Some(vec!["mydb".to_string(), "analytics".to_string()]),
58///     None,
59///     None,
60///     None,
61/// )?;
62/// validate(
63///     "postgresql://user:pass@source.example.com/postgres",
64///     "postgresql://user:pass@target.example.com/postgres",
65///     filter
66/// ).await?;
67/// # Ok(())
68/// # }
69/// ```
70pub async fn validate(
71    source_url: &str,
72    target_url: &str,
73    filter: crate::filters::ReplicationFilter,
74) -> Result<()> {
75    tracing::info!("Starting validation...");
76
77    // Step 0a: Check for required tools
78    tracing::info!("Checking for required PostgreSQL client tools...");
79    utils::check_required_tools().context("Required tools check failed")?;
80    tracing::info!("✓ Required tools found (pg_dump, pg_dumpall, psql)");
81
82    // Step 0b: Validate connection strings
83    tracing::info!("Validating connection strings...");
84    utils::validate_connection_string(source_url).context("Invalid source connection string")?;
85    utils::validate_connection_string(target_url).context("Invalid target connection string")?;
86    tracing::info!("✓ Connection strings are valid");
87
88    // Step 0c: Ensure source and target are different
89    tracing::info!("Verifying source and target are different databases...");
90    utils::validate_source_target_different(source_url, target_url)
91        .context("Source and target validation failed")?;
92    tracing::info!("✓ Source and target are different databases");
93
94    // Step 1: Connect to source
95    tracing::info!("Connecting to source database...");
96    let source_client = postgres::connect(source_url)
97        .await
98        .context("Failed to connect to source database")?;
99    tracing::info!("✓ Connected to source");
100
101    // Step 2: Discover and filter databases
102    tracing::info!("Discovering databases on source...");
103    let all_databases = migration::list_databases(&source_client)
104        .await
105        .context("Failed to list databases on source")?;
106
107    // Apply filtering rules
108    let databases: Vec<_> = all_databases
109        .into_iter()
110        .filter(|db| filter.should_replicate_database(&db.name))
111        .collect();
112
113    if databases.is_empty() {
114        if filter.is_empty() {
115            bail!(
116                "No user databases found on source. Only template databases exist.\n\
117                 Cannot proceed with migration - source appears empty."
118            );
119        } else {
120            bail!(
121                "No databases matched the filter criteria.\n\
122                 Check your --include-databases or --exclude-databases settings.\n\
123                 Available databases: {}",
124                migration::list_databases(&source_client)
125                    .await?
126                    .iter()
127                    .map(|db| &db.name)
128                    .cloned()
129                    .collect::<Vec<_>>()
130                    .join(", ")
131            );
132        }
133    }
134
135    tracing::info!("✓ Found {} database(s) to replicate:", databases.len());
136    for db in &databases {
137        tracing::info!("  - {}", db.name);
138    }
139
140    // Show table filtering info if applicable
141    if filter.include_tables().is_some() || filter.exclude_tables().is_some() {
142        tracing::info!("  Table filtering is active - only filtered tables will be replicated");
143    }
144
145    // Step 3: Connect to target
146    tracing::info!("Connecting to target database...");
147    let target_client = postgres::connect(target_url)
148        .await
149        .context("Failed to connect to target database")?;
150    tracing::info!("✓ Connected to target");
151
152    // Step 4: Check source privileges
153    tracing::info!("Checking source privileges...");
154    let source_privs = postgres::check_source_privileges(&source_client).await?;
155    if !source_privs.has_replication && !source_privs.is_superuser {
156        bail!("Source user lacks REPLICATION privilege. Grant with: ALTER USER <user> WITH REPLICATION;");
157    }
158    tracing::info!("✓ Source has replication privileges");
159
160    // Step 5: Check target privileges
161    tracing::info!("Checking target privileges...");
162    let target_privs = postgres::check_target_privileges(&target_client).await?;
163    if !target_privs.has_create_db && !target_privs.is_superuser {
164        bail!(
165            "Target user lacks CREATE DATABASE privilege. Grant with: ALTER USER <user> CREATEDB;"
166        );
167    }
168    if !target_privs.has_create_role && !target_privs.is_superuser {
169        tracing::warn!("⚠ Target user lacks CREATE ROLE privilege. Role migration may fail.");
170    }
171    tracing::info!("✓ Target has sufficient privileges");
172
173    // Step 5a: Check target wal_level for logical replication
174    tracing::info!("Checking target wal_level setting...");
175    let target_wal_level = postgres::check_wal_level(&target_client).await?;
176    if target_wal_level == "logical" {
177        tracing::info!("✓ Target wal_level is set to 'logical' (logical replication supported)");
178    } else {
179        tracing::warn!(
180            "⚠ Target wal_level is set to '{}', but 'logical' is required for continuous sync",
181            target_wal_level
182        );
183        tracing::warn!("  Continuous replication (subscriptions) will not be possible");
184        tracing::warn!("  You can still perform initial snapshot replication");
185    }
186
187    // Step 6: Check PostgreSQL versions
188    tracing::info!("Checking PostgreSQL versions...");
189    let source_version = get_pg_version(&source_client).await?;
190    let target_version = get_pg_version(&target_client).await?;
191
192    if source_version.major != target_version.major {
193        bail!(
194            "PostgreSQL major version mismatch: source={}.{}, target={}.{}. Logical replication requires same major version.",
195            source_version.major, source_version.minor,
196            target_version.major, target_version.minor
197        );
198    }
199    tracing::info!(
200        "✓ Version compatibility confirmed (both {}.{})",
201        source_version.major,
202        source_version.minor
203    );
204
205    // Step 7: Check extension compatibility
206    tracing::info!("Checking extension compatibility...");
207    check_extension_compatibility(&source_client, &target_client).await?;
208    tracing::info!("✓ Extension compatibility confirmed");
209
210    tracing::info!("");
211    tracing::info!("✅ Validation complete - ready for migration");
212    tracing::info!("");
213    tracing::info!(
214        "The following {} database(s) will be replicated:",
215        databases.len()
216    );
217    for db in &databases {
218        tracing::info!("  ✓ {}", db.name);
219    }
220    Ok(())
221}
222
/// PostgreSQL server version reduced to its first two numeric components.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PgVersion {
    /// First dot-separated component of the reported version (e.g. `16` in `16.2`).
    major: u32,
    /// Second dot-separated component of the reported version (e.g. `2` in `16.2`).
    minor: u32,
}
227
228async fn get_pg_version(client: &tokio_postgres::Client) -> Result<PgVersion> {
229    let row = client
230        .query_one("SHOW server_version", &[])
231        .await
232        .context("Failed to get PostgreSQL version")?;
233
234    let version_str: String = row.get(0);
235
236    // Parse version string like "16.2 (Debian 16.2-1.pgdg120+1)"
237    let parts: Vec<&str> = version_str
238        .split_whitespace()
239        .next()
240        .unwrap_or("0.0")
241        .split('.')
242        .collect();
243
244    let major = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0);
245    let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
246
247    Ok(PgVersion { major, minor })
248}
249
250async fn check_extension_compatibility(
251    source_client: &tokio_postgres::Client,
252    target_client: &tokio_postgres::Client,
253) -> Result<()> {
254    // Get installed extensions from source
255    let source_extensions = postgres::get_installed_extensions(source_client)
256        .await
257        .context("Failed to get source extensions")?;
258
259    // If no extensions on source (besides plpgsql), skip checks
260    if source_extensions.is_empty() {
261        tracing::info!("  No extensions found on source database");
262        return Ok(());
263    }
264
265    tracing::info!(
266        "  Found {} extension(s) on source: {}",
267        source_extensions.len(),
268        source_extensions
269            .iter()
270            .map(|e| &e.name)
271            .cloned()
272            .collect::<Vec<_>>()
273            .join(", ")
274    );
275
276    // Get available extensions on target
277    let target_available = postgres::get_available_extensions(target_client)
278        .await
279        .context("Failed to get target available extensions")?;
280
281    // Get preloaded libraries on target
282    let target_preloaded = postgres::get_preloaded_libraries(target_client)
283        .await
284        .context("Failed to get target preloaded libraries")?;
285
286    let mut errors = Vec::new();
287    let mut warnings = Vec::new();
288
289    // Check each source extension
290    for source_ext in &source_extensions {
291        // Check if extension is available on target
292        let target_ext = target_available.iter().find(|e| e.name == source_ext.name);
293
294        match target_ext {
295            None => {
296                errors.push(format!(
297                    "Extension '{}' (version {}) is required but not available on target",
298                    source_ext.name, source_ext.version
299                ));
300            }
301            Some(target) => {
302                // Check if extension requires preloading
303                if postgres::requires_preload(&source_ext.name) {
304                    let is_preloaded = target_preloaded.iter().any(|lib| lib == &source_ext.name);
305
306                    if !is_preloaded {
307                        errors.push(format!(
308                            "Extension '{}' requires preloading but is not in shared_preload_libraries on target. \
309                             Add to postgresql.conf: shared_preload_libraries = '{}' and restart PostgreSQL.",
310                            source_ext.name, source_ext.name
311                        ));
312                    }
313                }
314
315                // Warn on version mismatch
316                if let Some(target_version) = &target.default_version {
317                    let source_major = source_ext.version.split('.').next().unwrap_or("0");
318                    let target_major = target_version.split('.').next().unwrap_or("0");
319
320                    if source_major != target_major {
321                        warnings.push(format!(
322                            "Extension '{}' version mismatch: source={}, target default={}",
323                            source_ext.name, source_ext.version, target_version
324                        ));
325                    }
326                }
327            }
328        }
329    }
330
331    // Report warnings
332    for warning in &warnings {
333        tracing::warn!("  ⚠ {}", warning);
334    }
335
336    // Report errors and fail if any
337    if !errors.is_empty() {
338        tracing::error!("Extension compatibility check failed:");
339        for error in &errors {
340            tracing::error!("  ✗ {}", error);
341        }
342        bail!("Target database is missing required extensions or configuration. See errors above.");
343    }
344
345    Ok(())
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the source and target URLs needed by the `#[ignore]`d integration tests.
    ///
    /// Panics with a message naming the missing variable, so a misconfigured
    /// environment is obvious from the test output.
    fn integration_urls() -> (String, String) {
        let source = std::env::var("TEST_SOURCE_URL")
            .expect("TEST_SOURCE_URL must be set to run integration tests");
        let target = std::env::var("TEST_TARGET_URL")
            .expect("TEST_TARGET_URL must be set to run integration tests");
        (source, target)
    }

    /// Integration test: validation passes against real databases with no filter.
    #[tokio::test]
    #[ignore]
    async fn test_validate_with_valid_databases_succeeds() {
        let (source, target) = integration_urls();

        let filter = crate::filters::ReplicationFilter::empty();
        let result = validate(&source, &target, filter).await;
        assert!(result.is_ok(), "Validate failed: {:?}", result.err());
    }

    /// A malformed source URL must make validation fail.
    #[tokio::test]
    async fn test_validate_with_invalid_source_fails() {
        let filter = crate::filters::ReplicationFilter::empty();
        let result = validate("invalid-url", "postgresql://localhost/db", filter).await;
        assert!(
            result.is_err(),
            "Validate should fail for an invalid source URL"
        );
    }

    /// Integration test: validation passes when only the `postgres` database is included.
    #[tokio::test]
    #[ignore]
    async fn test_validate_with_database_filter() {
        let (source, target) = integration_urls();

        // Create filter that includes only postgres database
        let filter = crate::filters::ReplicationFilter::new(
            Some(vec!["postgres".to_string()]),
            None,
            None,
            None,
        )
        .expect("Failed to create filter");

        let result = validate(&source, &target, filter).await;
        assert!(
            result.is_ok(),
            "Validate with database filter failed: {:?}",
            result.err()
        );
    }

    /// Integration test: a filter that selects nothing produces the "no match" error.
    #[tokio::test]
    #[ignore]
    async fn test_validate_with_no_matching_databases_fails() {
        let (source, target) = integration_urls();

        // Create filter that matches no databases
        let filter = crate::filters::ReplicationFilter::new(
            Some(vec!["nonexistent_database".to_string()]),
            None,
            None,
            None,
        )
        .expect("Failed to create filter");

        let result = validate(&source, &target, filter).await;
        assert!(
            result.is_err(),
            "Validate should fail when no databases match filter"
        );

        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("No databases matched"),
            "Error message should indicate no databases matched, got: {}",
            message
        );
    }
}