database_replicator/
preflight.rs

1// ABOUTME: Pre-flight validation checks for replication prerequisites
2// ABOUTME: Validates local environment, network connectivity, and database permissions
3
4use anyhow::Result;
5use tokio_postgres::Client;
6
/// Outcome of a single pre-flight check; rendered as one ✓/✗ line in the report.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Short identifier for the check (e.g. `"pg_dump"`, `"replication"`).
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// One-line human-readable summary printed next to the ✓/✗ marker.
    pub message: String,
    /// Optional extra line printed indented beneath `message`.
    pub details: Option<String>,
}

impl CheckResult {
    /// Shared constructor: a result with the given outcome and no details.
    fn from_outcome(name: impl Into<String>, passed: bool, message: impl Into<String>) -> Self {
        CheckResult {
            name: name.into(),
            passed,
            message: message.into(),
            details: None,
        }
    }

    /// Create a successful check result without details.
    pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::from_outcome(name, true, message)
    }

    /// Create a failed check result without details.
    pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::from_outcome(name, false, message)
    }

    /// Builder-style setter: attach (or replace) the detail line.
    pub fn with_details(self, details: impl Into<String>) -> Self {
        CheckResult {
            details: Some(details.into()),
            ..self
        }
    }
}
40
/// A blocking problem found during pre-flight, together with suggested remedies.
///
/// Any entry in `PreflightResult::issues` makes the overall run fail.
#[derive(Debug, Clone)]
pub struct PreflightIssue {
    /// Short headline, printed after "Issue N:".
    pub title: String,
    /// Explanation printed on the line below the title.
    pub explanation: String,
    /// Suggested fixes, each printed as a bullet under "Fix options:".
    pub fixes: Vec<String>,
}
48
49/// Complete pre-flight results
50#[derive(Debug, Default)]
51pub struct PreflightResult {
52    pub local_env: Vec<CheckResult>,
53    pub network: Vec<CheckResult>,
54    pub source_permissions: Vec<CheckResult>,
55    pub target_permissions: Vec<CheckResult>,
56    pub issues: Vec<PreflightIssue>,
57    /// True if pg_dump version < source server version
58    pub tool_version_incompatible: bool,
59    pub local_pg_version: Option<u32>,
60    pub source_pg_version: Option<u32>,
61}
62
63impl PreflightResult {
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    pub fn all_passed(&self) -> bool {
69        self.issues.is_empty()
70    }
71
72    pub fn failed_count(&self) -> usize {
73        self.issues.len()
74    }
75
76    /// Print formatted output
77    pub fn print(&self) {
78        println!();
79        println!("Pre-flight Checks");
80        println!("{}", "═".repeat(61));
81        println!();
82
83        if !self.local_env.is_empty() {
84            println!("Local Environment:");
85            for check in &self.local_env {
86                let icon = if check.passed { "✓" } else { "✗" };
87                println!("  {} {}", icon, check.message);
88                if let Some(ref details) = check.details {
89                    println!("      {}", details);
90                }
91            }
92            println!();
93        }
94
95        if !self.network.is_empty() {
96            println!("Network Connectivity:");
97            for check in &self.network {
98                let icon = if check.passed { "✓" } else { "✗" };
99                println!("  {} {}", icon, check.message);
100                if let Some(ref details) = check.details {
101                    println!("      {}", details);
102                }
103            }
104            println!();
105        }
106
107        if !self.source_permissions.is_empty() {
108            println!("Source Permissions:");
109            for check in &self.source_permissions {
110                let icon = if check.passed { "✓" } else { "✗" };
111                println!("  {} {}", icon, check.message);
112                if let Some(ref details) = check.details {
113                    println!("      {}", details);
114                }
115            }
116            println!();
117        }
118
119        if !self.target_permissions.is_empty() {
120            println!("Target Permissions:");
121            for check in &self.target_permissions {
122                let icon = if check.passed { "✓" } else { "✗" };
123                println!("  {} {}", icon, check.message);
124                if let Some(ref details) = check.details {
125                    println!("      {}", details);
126                }
127            }
128            println!();
129        }
130
131        println!("{}", "═".repeat(61));
132        if self.all_passed() {
133            println!("PASSED: All pre-flight checks successful");
134        } else {
135            println!("FAILED: {} issue(s) must be resolved", self.failed_count());
136            println!();
137            for (i, issue) in self.issues.iter().enumerate() {
138                println!("Issue {}: {}", i + 1, issue.title);
139                println!("  {}", issue.explanation);
140                println!();
141                println!("  Fix options:");
142                for fix in &issue.fixes {
143                    println!("    • {}", fix);
144                }
145                println!();
146            }
147        }
148    }
149}
150
151/// Run all pre-flight checks
152///
153/// # Arguments
154///
155/// * `source_url` - PostgreSQL connection string for source
156/// * `target_url` - PostgreSQL connection string for target
157/// * `databases` - Optional list of databases to check permissions for
158///
159/// # Returns
160///
161/// PreflightResult containing all check results
162pub async fn run_preflight_checks(
163    source_url: &str,
164    target_url: &str,
165    _databases: Option<&[String]>,
166) -> Result<PreflightResult> {
167    let mut result = PreflightResult::new();
168
169    // 1. Check local environment (pg_dump, pg_restore, etc.)
170    check_local_environment(&mut result);
171
172    // 2. Check network connectivity and get server versions. Connections are short-lived.
173    let source_client_url = check_network_connectivity(&mut result, source_url, "source").await?;
174    let target_client_url = check_network_connectivity(&mut result, target_url, "target").await?;
175
176    // 3. Check version compatibility (only if we could connect and have local version)
177    if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
178        check_version_compatibility(&mut result);
179    }
180
181    // 4. Check source permissions using a new, short-lived connection
182    if let Some(url) = source_client_url {
183        match crate::postgres::connect_with_retry(&url).await {
184            Ok(client) => {
185                check_source_permissions(&mut result, &client).await;
186                // client is dropped here, closing the connection
187            }
188            Err(e) => {
189                result.source_permissions.push(CheckResult::fail(
190                    "connection",
191                    format!(
192                        "Failed to re-establish connection to source for permission checks: {}",
193                        e
194                    ),
195                ));
196                result.issues.push(PreflightIssue {
197                    title: "Source connection for permissions failed".to_string(),
198                    explanation: e.to_string(),
199                    fixes: vec!["Ensure source database is accessible".to_string()],
200                });
201            }
202        }
203    }
204
205    // 5. Check target permissions using a new, short-lived connection
206    if let Some(url) = target_client_url {
207        match crate::postgres::connect_with_retry(&url).await {
208            Ok(client) => {
209                check_target_permissions(&mut result, &client).await;
210                // client is dropped here, closing the connection
211            }
212            Err(e) => {
213                result.target_permissions.push(CheckResult::fail(
214                    "connection",
215                    format!(
216                        "Failed to re-establish connection to target for permission checks: {}",
217                        e
218                    ),
219                ));
220                result.issues.push(PreflightIssue {
221                    title: "Target connection for permissions failed".to_string(),
222                    explanation: e.to_string(),
223                    fixes: vec!["Ensure target database is accessible".to_string()],
224                });
225            }
226        }
227    }
228
229    Ok(result)
230}
231
232fn check_local_environment(result: &mut PreflightResult) {
233    let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
234    let mut missing = Vec::new();
235
236    for tool in tools {
237        match which::which(tool) {
238            Ok(path) => {
239                let path_str = path.display().to_string();
240                match crate::utils::get_pg_tool_version(tool) {
241                    Ok(version) => {
242                        if tool == "pg_dump" {
243                            result.local_pg_version = Some(version);
244                        }
245                        result.local_env.push(
246                            CheckResult::pass(tool, format!("{} found", tool))
247                                .with_details(format!("{} ({})", path_str, version)),
248                        );
249                    }
250                    Err(_) => {
251                        result.local_env.push(
252                            CheckResult::pass(tool, format!("{} found", tool))
253                                .with_details(path_str),
254                        );
255                    }
256                }
257            }
258            Err(_) => {
259                missing.push(tool);
260                result.local_env.push(CheckResult::fail(
261                    tool,
262                    format!("{} not found in PATH", tool),
263                ));
264            }
265        }
266    }
267
268    if !missing.is_empty() {
269        result.issues.push(PreflightIssue {
270            title: "Missing PostgreSQL client tools".to_string(),
271            explanation: format!("Required tools not found: {}", missing.join(", ")),
272            fixes: vec![
273                "Ubuntu: sudo apt install postgresql-client-17".to_string(),
274                "macOS: brew install postgresql@17".to_string(),
275                "RHEL: sudo dnf install postgresql17".to_string(),
276            ],
277        });
278    }
279}
280
281async fn check_network_connectivity(
282    result: &mut PreflightResult,
283    db_url: &str,
284    db_type: &str, // "source" or "target"
285) -> Result<Option<String>> {
286    match crate::postgres::connect_with_retry(db_url).await {
287        Ok(client) => {
288            // Also get server version while connected (only for source)
289            if db_type == "source" {
290                if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
291                    let version_str: String = row.get(0);
292                    if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
293                        result.source_pg_version = Some(version);
294                    }
295                }
296            }
297            result.network.push(CheckResult::pass(
298                db_type,
299                format!("{} database reachable", db_type),
300            ));
301            Ok(Some(db_url.to_string())) // Return the URL if connection was successful
302        }
303        Err(e) => {
304            result.network.push(CheckResult::fail(
305                db_type,
306                format!("Cannot connect to {}: {}", db_type, e),
307            ));
308            result.issues.push(PreflightIssue {
309                title: format!("{} database unreachable", db_type),
310                explanation: e.to_string(),
311                fixes: vec![
312                    "Verify connection string is correct".to_string(),
313                    "Check network connectivity to database host".to_string(),
314                    "Ensure firewall allows PostgreSQL port (5432)".to_string(),
315                ],
316            });
317            Ok(None) // Return None if connection failed
318        }
319    }
320}
321
322fn check_version_compatibility(result: &mut PreflightResult) {
323    let local = result.local_pg_version.unwrap();
324    let server = result.source_pg_version.unwrap();
325
326    if local < server {
327        result.tool_version_incompatible = true;
328        result.local_env.push(CheckResult::fail(
329            "version",
330            format!("pg_dump version {} < source server {}", local, server),
331        ));
332        result.issues.push(PreflightIssue {
333            title: "PostgreSQL version mismatch".to_string(),
334            explanation: format!(
335                "Local pg_dump ({}) cannot dump from server ({})",
336                local, server
337            ),
338            fixes: vec![
339                format!("Install PostgreSQL {} client tools:", server),
340                format!("  Ubuntu: sudo apt install postgresql-client-{}", server),
341                format!("  macOS: brew install postgresql@{}", server),
342                "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
343            ],
344        });
345    } else {
346        result.local_env.push(CheckResult::pass(
347            "version",
348            format!("pg_dump version {} >= source server {}", local, server),
349        ));
350    }
351}
352
353async fn check_source_permissions(result: &mut PreflightResult, client: &Client) {
354    // Check REPLICATION privilege (or AWS RDS rds_replication role)
355    match crate::postgres::check_source_privileges(client).await {
356        Ok(privs) => {
357            if privs.can_replicate() {
358                let method = if privs.has_rds_replication {
359                    "Has rds_replication role (AWS RDS)"
360                } else if privs.is_superuser {
361                    "Has superuser privilege"
362                } else {
363                    "Has REPLICATION privilege"
364                };
365                result
366                    .source_permissions
367                    .push(CheckResult::pass("replication", method));
368            } else {
369                result.source_permissions.push(CheckResult::fail(
370                    "replication",
371                    "Missing REPLICATION privilege",
372                ));
373                result.issues.push(PreflightIssue {
374                    title: "Missing REPLICATION privilege".to_string(),
375                    explanation: "Required for continuous sync".to_string(),
376                    fixes: vec![
377                        "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;".to_string(),
378                        "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
379                    ],
380                });
381            }
382        }
383        Err(e) => {
384            result.source_permissions.push(CheckResult::fail(
385                "privileges",
386                format!("Failed to check: {}", e),
387            ));
388        }
389    }
390
391    // Check table SELECT permissions
392    match crate::postgres::check_table_select_permissions(client).await {
393        Ok(perms) => {
394            if perms.all_accessible() {
395                result.source_permissions.push(CheckResult::pass(
396                    "select",
397                    format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
398                ));
399            } else {
400                let inaccessible = &perms.inaccessible_tables;
401                let count = inaccessible.len();
402                let preview: Vec<&str> = inaccessible.iter().take(5).map(|s| s.as_str()).collect();
403                let details = if count > 5 {
404                    format!("{}, ... ({} more)", preview.join(", "), count - 5)
405                } else {
406                    preview.join(", ")
407                };
408
409                result.source_permissions.push(
410                    CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
411                        .with_details(details),
412                );
413                result.issues.push(PreflightIssue {
414                    title: "Missing table permissions".to_string(),
415                    explanation: format!("User needs SELECT on {} tables", count),
416                    fixes: vec![
417                        "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
418                            .to_string(),
419                    ],
420                });
421            }
422        }
423        Err(e) => {
424            result.source_permissions.push(CheckResult::fail(
425                "select",
426                format!("Failed to check table permissions: {}", e),
427            ));
428        }
429    }
430}
431
432async fn check_target_permissions(result: &mut PreflightResult, client: &Client) {
433    match crate::postgres::check_target_privileges(client).await {
434        Ok(privs) => {
435            if privs.has_create_db || privs.is_superuser {
436                result
437                    .target_permissions
438                    .push(CheckResult::pass("createdb", "Can create databases"));
439            } else {
440                result
441                    .target_permissions
442                    .push(CheckResult::fail("createdb", "Cannot create databases"));
443                result.issues.push(PreflightIssue {
444                    title: "Missing CREATEDB privilege".to_string(),
445                    explanation: "Cannot create databases on target".to_string(),
446                    fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
447                });
448            }
449
450            if privs.can_replicate() {
451                result.target_permissions.push(CheckResult::pass(
452                    "subscription",
453                    "Can create subscriptions",
454                ));
455            } else {
456                result.target_permissions.push(CheckResult::fail(
457                    "subscription",
458                    "Cannot create subscriptions",
459                ));
460            }
461        }
462        Err(e) => {
463            result.target_permissions.push(CheckResult::fail(
464                "privileges",
465                format!("Failed to check: {}", e),
466            ));
467        }
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `PreflightIssue` from string slices to keep the tests terse.
    fn make_issue(title: &str, explanation: &str, fixes: &[&str]) -> PreflightIssue {
        PreflightIssue {
            title: title.to_owned(),
            explanation: explanation.to_owned(),
            fixes: fixes.iter().map(|fix| (*fix).to_owned()).collect(),
        }
    }

    #[test]
    fn test_check_result_pass() {
        let CheckResult { name, passed, .. } = CheckResult::pass("test", "Test passed");
        assert!(passed);
        assert_eq!(name, "test");
    }

    #[test]
    fn test_check_result_fail() {
        assert!(!CheckResult::fail("test", "Test failed").passed);
    }

    #[test]
    fn test_check_result_with_details() {
        let check = CheckResult::pass("test", "Test passed").with_details("Some details");
        assert_eq!(check.details.as_deref(), Some("Some details"));
    }

    #[test]
    fn test_preflight_result_empty_passes() {
        let empty = PreflightResult::new();
        assert_eq!(empty.failed_count(), 0);
        assert!(empty.all_passed());
    }

    #[test]
    fn test_preflight_result_with_issues() {
        let mut result = PreflightResult::new();
        result.issues.push(make_issue("Test issue", "Test", &[]));
        assert_eq!(result.failed_count(), 1);
        assert!(!result.all_passed());
    }

    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let issue = make_issue("Test", "Details", &["Fix 1", "Fix 2"]);
        assert_eq!(issue.fixes.len(), 2);
    }
}
522}