1use anyhow::Result;
5use tokio_postgres::Client;
6
/// Outcome of a single pre-flight check, shown as one line in the report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckResult {
    /// Short identifier of the check (e.g. a tool name or `"replication"`).
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable summary printed next to the pass/fail icon.
    pub message: String,
    /// Optional extra line printed beneath the message.
    pub details: Option<String>,
}

impl CheckResult {
    /// Creates a passing check with no details.
    pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            passed: true,
            message: message.into(),
            details: None,
        }
    }

    /// Creates a failing check with no details.
    pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            passed: false,
            message: message.into(),
            details: None,
        }
    }

    /// Returns this check with `details` attached, replacing any previous details.
    #[must_use]
    pub fn with_details(mut self, details: impl Into<String>) -> Self {
        self.details = Some(details.into());
        self
    }
}
40
/// A blocking problem found during pre-flight, together with remediation hints.
///
/// Any recorded issue makes `PreflightResult::all_passed` return `false`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreflightIssue {
    /// One-line summary of the problem.
    pub title: String,
    /// Longer explanation, often the underlying error message.
    pub explanation: String,
    /// Suggested remedies, printed as a bulleted list.
    pub fixes: Vec<String>,
}
48
49#[derive(Debug, Default)]
51pub struct PreflightResult {
52 pub local_env: Vec<CheckResult>,
53 pub network: Vec<CheckResult>,
54 pub source_permissions: Vec<CheckResult>,
55 pub target_permissions: Vec<CheckResult>,
56 pub issues: Vec<PreflightIssue>,
57 pub tool_version_incompatible: bool,
59 pub local_pg_version: Option<u32>,
60 pub source_pg_version: Option<u32>,
61}
62
63impl PreflightResult {
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn all_passed(&self) -> bool {
69 self.issues.is_empty()
70 }
71
72 pub fn failed_count(&self) -> usize {
73 self.issues.len()
74 }
75
76 pub fn print(&self) {
78 println!();
79 println!("Pre-flight Checks");
80 println!("{}", "═".repeat(61));
81 println!();
82
83 if !self.local_env.is_empty() {
84 println!("Local Environment:");
85 for check in &self.local_env {
86 let icon = if check.passed { "✓" } else { "✗" };
87 println!(" {} {}", icon, check.message);
88 if let Some(ref details) = check.details {
89 println!(" {}", details);
90 }
91 }
92 println!();
93 }
94
95 if !self.network.is_empty() {
96 println!("Network Connectivity:");
97 for check in &self.network {
98 let icon = if check.passed { "✓" } else { "✗" };
99 println!(" {} {}", icon, check.message);
100 if let Some(ref details) = check.details {
101 println!(" {}", details);
102 }
103 }
104 println!();
105 }
106
107 if !self.source_permissions.is_empty() {
108 println!("Source Permissions:");
109 for check in &self.source_permissions {
110 let icon = if check.passed { "✓" } else { "✗" };
111 println!(" {} {}", icon, check.message);
112 if let Some(ref details) = check.details {
113 println!(" {}", details);
114 }
115 }
116 println!();
117 }
118
119 if !self.target_permissions.is_empty() {
120 println!("Target Permissions:");
121 for check in &self.target_permissions {
122 let icon = if check.passed { "✓" } else { "✗" };
123 println!(" {} {}", icon, check.message);
124 if let Some(ref details) = check.details {
125 println!(" {}", details);
126 }
127 }
128 println!();
129 }
130
131 println!("{}", "═".repeat(61));
132 if self.all_passed() {
133 println!("PASSED: All pre-flight checks successful");
134 } else {
135 println!("FAILED: {} issue(s) must be resolved", self.failed_count());
136 println!();
137 for (i, issue) in self.issues.iter().enumerate() {
138 println!("Issue {}: {}", i + 1, issue.title);
139 println!(" {}", issue.explanation);
140 println!();
141 println!(" Fix options:");
142 for fix in &issue.fixes {
143 println!(" • {}", fix);
144 }
145 println!();
146 }
147 }
148 }
149}
150
151pub async fn run_preflight_checks(
163 source_url: &str,
164 target_url: &str,
165 _databases: Option<&[String]>,
166) -> Result<PreflightResult> {
167 let mut result = PreflightResult::new();
168
169 check_local_environment(&mut result);
171
172 let source_client_url = check_network_connectivity(&mut result, source_url, "source").await?;
174 let target_client_url = check_network_connectivity(&mut result, target_url, "target").await?;
175
176 if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
178 check_version_compatibility(&mut result);
179 }
180
181 if let Some(url) = source_client_url {
183 match crate::postgres::connect_with_retry(&url).await {
184 Ok(client) => {
185 check_source_permissions(&mut result, &client).await;
186 }
188 Err(e) => {
189 result.source_permissions.push(CheckResult::fail(
190 "connection",
191 format!(
192 "Failed to re-establish connection to source for permission checks: {}",
193 e
194 ),
195 ));
196 result.issues.push(PreflightIssue {
197 title: "Source connection for permissions failed".to_string(),
198 explanation: e.to_string(),
199 fixes: vec!["Ensure source database is accessible".to_string()],
200 });
201 }
202 }
203 }
204
205 if let Some(url) = target_client_url {
207 match crate::postgres::connect_with_retry(&url).await {
208 Ok(client) => {
209 check_target_permissions(&mut result, &client).await;
210 }
212 Err(e) => {
213 result.target_permissions.push(CheckResult::fail(
214 "connection",
215 format!(
216 "Failed to re-establish connection to target for permission checks: {}",
217 e
218 ),
219 ));
220 result.issues.push(PreflightIssue {
221 title: "Target connection for permissions failed".to_string(),
222 explanation: e.to_string(),
223 fixes: vec!["Ensure target database is accessible".to_string()],
224 });
225 }
226 }
227 }
228
229 Ok(result)
230}
231
232fn check_local_environment(result: &mut PreflightResult) {
233 let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
234 let mut missing = Vec::new();
235
236 for tool in tools {
237 match which::which(tool) {
238 Ok(path) => {
239 let path_str = path.display().to_string();
240 match crate::utils::get_pg_tool_version(tool) {
241 Ok(version) => {
242 if tool == "pg_dump" {
243 result.local_pg_version = Some(version);
244 }
245 result.local_env.push(
246 CheckResult::pass(tool, format!("{} found", tool))
247 .with_details(format!("{} ({})", path_str, version)),
248 );
249 }
250 Err(_) => {
251 result.local_env.push(
252 CheckResult::pass(tool, format!("{} found", tool))
253 .with_details(path_str),
254 );
255 }
256 }
257 }
258 Err(_) => {
259 missing.push(tool);
260 result.local_env.push(CheckResult::fail(
261 tool,
262 format!("{} not found in PATH", tool),
263 ));
264 }
265 }
266 }
267
268 if !missing.is_empty() {
269 result.issues.push(PreflightIssue {
270 title: "Missing PostgreSQL client tools".to_string(),
271 explanation: format!("Required tools not found: {}", missing.join(", ")),
272 fixes: vec![
273 "Ubuntu: sudo apt install postgresql-client-17".to_string(),
274 "macOS: brew install postgresql@17".to_string(),
275 "RHEL: sudo dnf install postgresql17".to_string(),
276 ],
277 });
278 }
279}
280
281async fn check_network_connectivity(
282 result: &mut PreflightResult,
283 db_url: &str,
284 db_type: &str, ) -> Result<Option<String>> {
286 match crate::postgres::connect_with_retry(db_url).await {
287 Ok(client) => {
288 if db_type == "source" {
290 if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
291 let version_str: String = row.get(0);
292 if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
293 result.source_pg_version = Some(version);
294 }
295 }
296 }
297 result.network.push(CheckResult::pass(
298 db_type,
299 format!("{} database reachable", db_type),
300 ));
301 Ok(Some(db_url.to_string())) }
303 Err(e) => {
304 result.network.push(CheckResult::fail(
305 db_type,
306 format!("Cannot connect to {}: {}", db_type, e),
307 ));
308 result.issues.push(PreflightIssue {
309 title: format!("{} database unreachable", db_type),
310 explanation: e.to_string(),
311 fixes: vec![
312 "Verify connection string is correct".to_string(),
313 "Check network connectivity to database host".to_string(),
314 "Ensure firewall allows PostgreSQL port (5432)".to_string(),
315 ],
316 });
317 Ok(None) }
319 }
320}
321
322fn check_version_compatibility(result: &mut PreflightResult) {
323 let local = result.local_pg_version.unwrap();
324 let server = result.source_pg_version.unwrap();
325
326 if local < server {
327 result.tool_version_incompatible = true;
328 result.local_env.push(CheckResult::fail(
329 "version",
330 format!("pg_dump version {} < source server {}", local, server),
331 ));
332 result.issues.push(PreflightIssue {
333 title: "PostgreSQL version mismatch".to_string(),
334 explanation: format!(
335 "Local pg_dump ({}) cannot dump from server ({})",
336 local, server
337 ),
338 fixes: vec![
339 format!("Install PostgreSQL {} client tools:", server),
340 format!(" Ubuntu: sudo apt install postgresql-client-{}", server),
341 format!(" macOS: brew install postgresql@{}", server),
342 "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
343 ],
344 });
345 } else {
346 result.local_env.push(CheckResult::pass(
347 "version",
348 format!("pg_dump version {} >= source server {}", local, server),
349 ));
350 }
351}
352
353async fn check_source_permissions(result: &mut PreflightResult, client: &Client) {
354 match crate::postgres::check_source_privileges(client).await {
356 Ok(privs) => {
357 if privs.can_replicate() {
358 let method = if privs.has_rds_replication {
359 "Has rds_replication role (AWS RDS)"
360 } else if privs.is_superuser {
361 "Has superuser privilege"
362 } else {
363 "Has REPLICATION privilege"
364 };
365 result
366 .source_permissions
367 .push(CheckResult::pass("replication", method));
368 } else {
369 result.source_permissions.push(CheckResult::fail(
370 "replication",
371 "Missing REPLICATION privilege",
372 ));
373 result.issues.push(PreflightIssue {
374 title: "Missing REPLICATION privilege".to_string(),
375 explanation: "Required for continuous sync".to_string(),
376 fixes: vec![
377 "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;".to_string(),
378 "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
379 ],
380 });
381 }
382 }
383 Err(e) => {
384 result.source_permissions.push(CheckResult::fail(
385 "privileges",
386 format!("Failed to check: {}", e),
387 ));
388 }
389 }
390
391 match crate::postgres::check_table_select_permissions(client).await {
393 Ok(perms) => {
394 if perms.all_accessible() {
395 result.source_permissions.push(CheckResult::pass(
396 "select",
397 format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
398 ));
399 } else {
400 let inaccessible = &perms.inaccessible_tables;
401 let count = inaccessible.len();
402 let preview: Vec<&str> = inaccessible.iter().take(5).map(|s| s.as_str()).collect();
403 let details = if count > 5 {
404 format!("{}, ... ({} more)", preview.join(", "), count - 5)
405 } else {
406 preview.join(", ")
407 };
408
409 result.source_permissions.push(
410 CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
411 .with_details(details),
412 );
413 result.issues.push(PreflightIssue {
414 title: "Missing table permissions".to_string(),
415 explanation: format!("User needs SELECT on {} tables", count),
416 fixes: vec![
417 "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
418 .to_string(),
419 ],
420 });
421 }
422 }
423 Err(e) => {
424 result.source_permissions.push(CheckResult::fail(
425 "select",
426 format!("Failed to check table permissions: {}", e),
427 ));
428 }
429 }
430}
431
432async fn check_target_permissions(result: &mut PreflightResult, client: &Client) {
433 match crate::postgres::check_target_privileges(client).await {
434 Ok(privs) => {
435 if privs.has_create_db || privs.is_superuser {
436 result
437 .target_permissions
438 .push(CheckResult::pass("createdb", "Can create databases"));
439 } else {
440 result
441 .target_permissions
442 .push(CheckResult::fail("createdb", "Cannot create databases"));
443 result.issues.push(PreflightIssue {
444 title: "Missing CREATEDB privilege".to_string(),
445 explanation: "Cannot create databases on target".to_string(),
446 fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
447 });
448 }
449
450 if privs.can_replicate() {
451 result.target_permissions.push(CheckResult::pass(
452 "subscription",
453 "Can create subscriptions",
454 ));
455 } else {
456 result.target_permissions.push(CheckResult::fail(
457 "subscription",
458 "Cannot create subscriptions",
459 ));
460 }
461 }
462 Err(e) => {
463 result.target_permissions.push(CheckResult::fail(
464 "privileges",
465 format!("Failed to check: {}", e),
466 ));
467 }
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PreflightIssue` from string slices for the tests below.
    fn make_issue(title: &str, explanation: &str, fixes: &[&str]) -> PreflightIssue {
        PreflightIssue {
            title: title.to_string(),
            explanation: explanation.to_string(),
            fixes: fixes.iter().map(|&fix| String::from(fix)).collect(),
        }
    }

    #[test]
    fn test_check_result_pass() {
        let outcome = CheckResult::pass("test", "Test passed");
        assert_eq!(outcome.name, "test");
        assert!(outcome.passed);
    }

    #[test]
    fn test_check_result_fail() {
        assert!(!CheckResult::fail("test", "Test failed").passed);
    }

    #[test]
    fn test_check_result_with_details() {
        let outcome = CheckResult::pass("test", "Test passed").with_details("Some details");
        assert_eq!(outcome.details.as_deref(), Some("Some details"));
    }

    #[test]
    fn test_preflight_result_empty_passes() {
        let empty = PreflightResult::new();
        assert_eq!(empty.failed_count(), 0);
        assert!(empty.all_passed());
    }

    #[test]
    fn test_preflight_result_with_issues() {
        let mut outcome = PreflightResult::new();
        outcome.issues.push(make_issue("Test issue", "Test", &[]));
        assert_eq!(outcome.failed_count(), 1);
        assert!(!outcome.all_passed());
    }

    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let issue = make_issue("Test", "Details", &["Fix 1", "Fix 2"]);
        assert_eq!(issue.fixes.len(), 2);
    }
}