database_replicator/
preflight.rs

// ABOUTME: Pre-flight validation checks for replication prerequisites
// ABOUTME: Validates local environment, network connectivity, and database permissions
3
4use anyhow::Result;
5
/// Outcome of one individual pre-flight check.
///
/// `name` is a short key identifying the check (for example `"pg_dump"` or
/// `"source"`), `message` is the one-line text shown to the user, and
/// `details` is an optional second line of context.
#[derive(Debug, Clone)]
pub struct CheckResult {
    pub name: String,
    pub passed: bool,
    pub message: String,
    pub details: Option<String>,
}

impl CheckResult {
    /// Shared constructor behind [`CheckResult::pass`] and [`CheckResult::fail`].
    fn new_check(name: impl Into<String>, passed: bool, message: impl Into<String>) -> Self {
        CheckResult {
            name: name.into(),
            passed,
            message: message.into(),
            details: None,
        }
    }

    /// Build a successful check with no details attached.
    pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::new_check(name, true, message)
    }

    /// Build a failed check with no details attached.
    pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::new_check(name, false, message)
    }

    /// Return this check with `details` set, replacing any previous details.
    pub fn with_details(self, details: impl Into<String>) -> Self {
        Self {
            details: Some(details.into()),
            ..self
        }
    }
}
39
/// A problem found during pre-flight that must be resolved, with suggested remediations.
#[derive(Debug, Clone)]
pub struct PreflightIssue {
    /// Short headline describing the problem.
    pub title: String,
    /// One-line explanation of what went wrong.
    pub explanation: String,
    /// Suggested ways to resolve the problem, in display order.
    pub fixes: Vec<String>,
}
47
48/// Complete pre-flight results
49#[derive(Debug, Default)]
50pub struct PreflightResult {
51    pub local_env: Vec<CheckResult>,
52    pub network: Vec<CheckResult>,
53    pub source_permissions: Vec<CheckResult>,
54    pub target_permissions: Vec<CheckResult>,
55    pub issues: Vec<PreflightIssue>,
56    /// True if pg_dump version < source server version
57    pub tool_version_incompatible: bool,
58    pub local_pg_version: Option<u32>,
59    pub source_pg_version: Option<u32>,
60}
61
62impl PreflightResult {
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    pub fn all_passed(&self) -> bool {
68        self.issues.is_empty()
69    }
70
71    pub fn failed_count(&self) -> usize {
72        self.issues.len()
73    }
74
75    /// Print formatted output
76    pub fn print(&self) {
77        println!();
78        println!("Pre-flight Checks");
79        println!("{}", "═".repeat(61));
80        println!();
81
82        if !self.local_env.is_empty() {
83            println!("Local Environment:");
84            for check in &self.local_env {
85                let icon = if check.passed { "✓" } else { "✗" };
86                println!("  {} {}", icon, check.message);
87                if let Some(ref details) = check.details {
88                    println!("      {}", details);
89                }
90            }
91            println!();
92        }
93
94        if !self.network.is_empty() {
95            println!("Network Connectivity:");
96            for check in &self.network {
97                let icon = if check.passed { "✓" } else { "✗" };
98                println!("  {} {}", icon, check.message);
99                if let Some(ref details) = check.details {
100                    println!("      {}", details);
101                }
102            }
103            println!();
104        }
105
106        if !self.source_permissions.is_empty() {
107            println!("Source Permissions:");
108            for check in &self.source_permissions {
109                let icon = if check.passed { "✓" } else { "✗" };
110                println!("  {} {}", icon, check.message);
111                if let Some(ref details) = check.details {
112                    println!("      {}", details);
113                }
114            }
115            println!();
116        }
117
118        if !self.target_permissions.is_empty() {
119            println!("Target Permissions:");
120            for check in &self.target_permissions {
121                let icon = if check.passed { "✓" } else { "✗" };
122                println!("  {} {}", icon, check.message);
123                if let Some(ref details) = check.details {
124                    println!("      {}", details);
125                }
126            }
127            println!();
128        }
129
130        println!("{}", "═".repeat(61));
131        if self.all_passed() {
132            println!("PASSED: All pre-flight checks successful");
133        } else {
134            println!("FAILED: {} issue(s) must be resolved", self.failed_count());
135            println!();
136            for (i, issue) in self.issues.iter().enumerate() {
137                println!("Issue {}: {}", i + 1, issue.title);
138                println!("  {}", issue.explanation);
139                println!();
140                println!("  Fix options:");
141                for fix in &issue.fixes {
142                    println!("    • {}", fix);
143                }
144                println!();
145            }
146        }
147    }
148}
149
150/// Run all pre-flight checks
151///
152/// # Arguments
153///
154/// * `source_url` - PostgreSQL connection string for source
155/// * `target_url` - PostgreSQL connection string for target
156/// * `databases` - Optional list of databases to check permissions for
157///
158/// # Returns
159///
160/// PreflightResult containing all check results
161pub async fn run_preflight_checks(
162    source_url: &str,
163    target_url: &str,
164    _databases: Option<&[String]>,
165) -> Result<PreflightResult> {
166    let mut result = PreflightResult::new();
167
168    // 1. Check local environment (pg_dump, pg_restore, etc.)
169    check_local_environment(&mut result);
170
171    // 2. Check network connectivity
172    check_network_connectivity(&mut result, source_url, target_url).await;
173
174    // 3. Check version compatibility (only if we could connect and have local version)
175    if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
176        check_version_compatibility(&mut result);
177    }
178
179    // 4. Check source permissions
180    if result
181        .network
182        .iter()
183        .any(|c| c.name == "source" && c.passed)
184    {
185        check_source_permissions(&mut result, source_url).await;
186    }
187
188    // 5. Check target permissions
189    if result
190        .network
191        .iter()
192        .any(|c| c.name == "target" && c.passed)
193    {
194        check_target_permissions(&mut result, target_url).await;
195    }
196
197    Ok(result)
198}
199
200fn check_local_environment(result: &mut PreflightResult) {
201    let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
202    let mut missing = Vec::new();
203
204    for tool in tools {
205        match which::which(tool) {
206            Ok(path) => {
207                let path_str = path.display().to_string();
208                match crate::utils::get_pg_tool_version(tool) {
209                    Ok(version) => {
210                        if tool == "pg_dump" {
211                            result.local_pg_version = Some(version);
212                        }
213                        result.local_env.push(
214                            CheckResult::pass(tool, format!("{} found", tool))
215                                .with_details(format!("{} ({})", path_str, version)),
216                        );
217                    }
218                    Err(_) => {
219                        result.local_env.push(
220                            CheckResult::pass(tool, format!("{} found", tool))
221                                .with_details(path_str),
222                        );
223                    }
224                }
225            }
226            Err(_) => {
227                missing.push(tool);
228                result.local_env.push(CheckResult::fail(
229                    tool,
230                    format!("{} not found in PATH", tool),
231                ));
232            }
233        }
234    }
235
236    if !missing.is_empty() {
237        result.issues.push(PreflightIssue {
238            title: "Missing PostgreSQL client tools".to_string(),
239            explanation: format!("Required tools not found: {}", missing.join(", ")),
240            fixes: vec![
241                "Ubuntu: sudo apt install postgresql-client-17".to_string(),
242                "macOS: brew install postgresql@17".to_string(),
243                "RHEL: sudo dnf install postgresql17".to_string(),
244            ],
245        });
246    }
247}
248
249async fn check_network_connectivity(
250    result: &mut PreflightResult,
251    source_url: &str,
252    target_url: &str,
253) {
254    // Check source
255    match crate::postgres::connect_with_retry(source_url).await {
256        Ok(client) => {
257            // Also get server version while connected
258            if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
259                let version_str: String = row.get(0);
260                if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
261                    result.source_pg_version = Some(version);
262                }
263            }
264            result
265                .network
266                .push(CheckResult::pass("source", "Source database reachable"));
267        }
268        Err(e) => {
269            result.network.push(CheckResult::fail(
270                "source",
271                format!("Cannot connect to source: {}", e),
272            ));
273            result.issues.push(PreflightIssue {
274                title: "Source database unreachable".to_string(),
275                explanation: e.to_string(),
276                fixes: vec![
277                    "Verify connection string is correct".to_string(),
278                    "Check network connectivity to database host".to_string(),
279                    "Ensure firewall allows PostgreSQL port (5432)".to_string(),
280                ],
281            });
282        }
283    }
284
285    // Check target
286    match crate::postgres::connect_with_retry(target_url).await {
287        Ok(_) => {
288            result
289                .network
290                .push(CheckResult::pass("target", "Target database reachable"));
291        }
292        Err(e) => {
293            result.network.push(CheckResult::fail(
294                "target",
295                format!("Cannot connect to target: {}", e),
296            ));
297            result.issues.push(PreflightIssue {
298                title: "Target database unreachable".to_string(),
299                explanation: e.to_string(),
300                fixes: vec![
301                    "Verify connection string is correct".to_string(),
302                    "Check network connectivity to database host".to_string(),
303                ],
304            });
305        }
306    }
307}
308
309fn check_version_compatibility(result: &mut PreflightResult) {
310    let local = result.local_pg_version.unwrap();
311    let server = result.source_pg_version.unwrap();
312
313    if local < server {
314        result.tool_version_incompatible = true;
315        result.local_env.push(CheckResult::fail(
316            "version",
317            format!("pg_dump version {} < source server {}", local, server),
318        ));
319        result.issues.push(PreflightIssue {
320            title: "PostgreSQL version mismatch".to_string(),
321            explanation: format!(
322                "Local pg_dump ({}) cannot dump from server ({})",
323                local, server
324            ),
325            fixes: vec![
326                format!("Install PostgreSQL {} client tools:", server),
327                format!("  Ubuntu: sudo apt install postgresql-client-{}", server),
328                format!("  macOS: brew install postgresql@{}", server),
329                "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
330            ],
331        });
332    } else {
333        result.local_env.push(CheckResult::pass(
334            "version",
335            format!("pg_dump version {} >= source server {}", local, server),
336        ));
337    }
338}
339
340async fn check_source_permissions(result: &mut PreflightResult, source_url: &str) {
341    if let Ok(client) = crate::postgres::connect_with_retry(source_url).await {
342        // Check REPLICATION privilege (or AWS RDS rds_replication role)
343        match crate::postgres::check_source_privileges(&client).await {
344            Ok(privs) => {
345                if privs.can_replicate() {
346                    let method = if privs.has_rds_replication {
347                        "Has rds_replication role (AWS RDS)"
348                    } else if privs.is_superuser {
349                        "Has superuser privilege"
350                    } else {
351                        "Has REPLICATION privilege"
352                    };
353                    result
354                        .source_permissions
355                        .push(CheckResult::pass("replication", method));
356                } else {
357                    result.source_permissions.push(CheckResult::fail(
358                        "replication",
359                        "Missing REPLICATION privilege",
360                    ));
361                    result.issues.push(PreflightIssue {
362                        title: "Missing REPLICATION privilege".to_string(),
363                        explanation: "Required for continuous sync".to_string(),
364                        fixes: vec![
365                            "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;"
366                                .to_string(),
367                            "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
368                        ],
369                    });
370                }
371            }
372            Err(e) => {
373                result.source_permissions.push(CheckResult::fail(
374                    "privileges",
375                    format!("Failed to check: {}", e),
376                ));
377            }
378        }
379
380        // Check table SELECT permissions
381        match crate::postgres::check_table_select_permissions(&client).await {
382            Ok(perms) => {
383                if perms.all_accessible() {
384                    result.source_permissions.push(CheckResult::pass(
385                        "select",
386                        format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
387                    ));
388                } else {
389                    let inaccessible = &perms.inaccessible_tables;
390                    let count = inaccessible.len();
391                    let preview: Vec<&str> =
392                        inaccessible.iter().take(5).map(|s| s.as_str()).collect();
393                    let details = if count > 5 {
394                        format!("{}, ... ({} more)", preview.join(", "), count - 5)
395                    } else {
396                        preview.join(", ")
397                    };
398
399                    result.source_permissions.push(
400                        CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
401                            .with_details(details),
402                    );
403                    result.issues.push(PreflightIssue {
404                        title: "Missing table permissions".to_string(),
405                        explanation: format!("User needs SELECT on {} tables", count),
406                        fixes: vec![
407                            "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
408                                .to_string(),
409                        ],
410                    });
411                }
412            }
413            Err(e) => {
414                result.source_permissions.push(CheckResult::fail(
415                    "select",
416                    format!("Failed to check table permissions: {}", e),
417                ));
418            }
419        }
420    }
421}
422
423async fn check_target_permissions(result: &mut PreflightResult, target_url: &str) {
424    if let Ok(client) = crate::postgres::connect_with_retry(target_url).await {
425        match crate::postgres::check_target_privileges(&client).await {
426            Ok(privs) => {
427                if privs.has_create_db || privs.is_superuser {
428                    result
429                        .target_permissions
430                        .push(CheckResult::pass("createdb", "Can create databases"));
431                } else {
432                    result
433                        .target_permissions
434                        .push(CheckResult::fail("createdb", "Cannot create databases"));
435                    result.issues.push(PreflightIssue {
436                        title: "Missing CREATEDB privilege".to_string(),
437                        explanation: "Cannot create databases on target".to_string(),
438                        fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
439                    });
440                }
441
442                if privs.can_replicate() {
443                    result.target_permissions.push(CheckResult::pass(
444                        "subscription",
445                        "Can create subscriptions",
446                    ));
447                } else {
448                    result.target_permissions.push(CheckResult::fail(
449                        "subscription",
450                        "Cannot create subscriptions",
451                    ));
452                }
453            }
454            Err(e) => {
455                result.target_permissions.push(CheckResult::fail(
456                    "privileges",
457                    format!("Failed to check: {}", e),
458                ));
459            }
460        }
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_result_pass() {
        let check = CheckResult::pass("test", "Test passed");
        assert!(check.passed);
        assert_eq!(check.name, "test");
    }

    #[test]
    fn test_check_result_fail() {
        let check = CheckResult::fail("test", "Test failed");
        assert!(!check.passed);
    }

    #[test]
    fn test_check_result_with_details() {
        let check = CheckResult::pass("test", "Test passed").with_details("Some details");
        assert_eq!(check.details, Some("Some details".to_string()));
    }

    #[test]
    fn test_check_result_fail_with_details_stays_failed() {
        let check = CheckResult::fail("test", "Test failed").with_details("why");
        assert!(!check.passed);
        assert_eq!(check.message, "Test failed");
        assert_eq!(check.details.as_deref(), Some("why"));
    }

    #[test]
    fn test_preflight_result_empty_passes() {
        let result = PreflightResult::new();
        assert!(result.all_passed());
        assert_eq!(result.failed_count(), 0);
    }

    #[test]
    fn test_preflight_result_with_issues() {
        let mut result = PreflightResult::new();
        result.issues.push(PreflightIssue {
            title: "Test issue".to_string(),
            explanation: "Test".to_string(),
            fixes: vec![],
        });
        assert!(!result.all_passed());
        assert_eq!(result.failed_count(), 1);
    }

    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let issue = PreflightIssue {
            title: "Test".to_string(),
            explanation: "Details".to_string(),
            fixes: vec!["Fix 1".to_string(), "Fix 2".to_string()],
        };
        assert_eq!(issue.fixes.len(), 2);
    }

    #[test]
    fn test_version_compatibility_older_pg_dump_fails() {
        let mut result = PreflightResult::new();
        result.local_pg_version = Some(15);
        result.source_pg_version = Some(16);

        check_version_compatibility(&mut result);

        assert!(result.tool_version_incompatible);
        assert!(!result.all_passed());
        assert_eq!(result.failed_count(), 1);
        assert_eq!(result.issues[0].title, "PostgreSQL version mismatch");
        assert_eq!(result.issues[0].fixes[0], "Install PostgreSQL 16 client tools:");
        assert!(result
            .local_env
            .iter()
            .any(|c| c.name == "version" && !c.passed));
    }

    #[test]
    fn test_version_compatibility_same_version_passes() {
        let mut result = PreflightResult::new();
        result.local_pg_version = Some(16);
        result.source_pg_version = Some(16);

        check_version_compatibility(&mut result);

        assert!(!result.tool_version_incompatible);
        assert!(result.all_passed());
        assert!(result
            .local_env
            .iter()
            .any(|c| c.name == "version" && c.passed));
    }

    #[test]
    fn test_version_compatibility_newer_pg_dump_passes() {
        let mut result = PreflightResult::new();
        result.local_pg_version = Some(17);
        result.source_pg_version = Some(14);

        check_version_compatibility(&mut result);

        assert!(!result.tool_version_incompatible);
        assert!(result.all_passed());
    }

    #[test]
    fn test_print_does_not_panic_with_checks_and_issues() {
        let mut result = PreflightResult::new();
        result
            .local_env
            .push(CheckResult::pass("pg_dump", "pg_dump found").with_details("/usr/bin/pg_dump"));
        result
            .network
            .push(CheckResult::fail("source", "Cannot connect to source: refused"));
        result.issues.push(PreflightIssue {
            title: "Source database unreachable".to_string(),
            explanation: "refused".to_string(),
            fixes: vec!["Verify connection string is correct".to_string()],
        });
        result.print();
        assert_eq!(result.failed_count(), 1);
    }
}