1use anyhow::Result;
5use tokio_postgres::Client;
6
/// Outcome of a single pre-flight check, shown as one line in the report.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Short identifier of the check (e.g. a tool name or `"replication"`).
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// One-line human-readable summary of the outcome.
    pub message: String,
    /// Optional extra context printed beneath the message.
    pub details: Option<String>,
}

impl CheckResult {
    /// Builds a successful check with no details attached.
    pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::build(name.into(), true, message.into())
    }

    /// Builds a failed check with no details attached.
    pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::build(name.into(), false, message.into())
    }

    /// Returns this check with `details` set, replacing any earlier details.
    pub fn with_details(self, details: impl Into<String>) -> Self {
        Self {
            details: Some(details.into()),
            ..self
        }
    }

    /// Shared constructor behind [`pass`](Self::pass) and [`fail`](Self::fail).
    fn build(name: String, passed: bool, message: String) -> Self {
        Self {
            name,
            passed,
            message,
            details: None,
        }
    }
}
40
/// A problem found during pre-flight checks that must be resolved, together
/// with suggested remedies.
///
/// Any issue stored in [`PreflightResult::issues`] makes
/// [`PreflightResult::all_passed`] return `false`.
#[derive(Debug, Clone)]
pub struct PreflightIssue {
    /// Short headline for the problem.
    pub title: String,
    /// What went wrong or why it matters (often the underlying error text).
    pub explanation: String,
    /// Suggested remediation steps, printed as a bullet list.
    pub fixes: Vec<String>,
}
48
49#[derive(Debug, Default)]
51pub struct PreflightResult {
52 pub local_env: Vec<CheckResult>,
53 pub network: Vec<CheckResult>,
54 pub source_permissions: Vec<CheckResult>,
55 pub target_permissions: Vec<CheckResult>,
56 pub issues: Vec<PreflightIssue>,
57 pub tool_version_incompatible: bool,
59 pub local_pg_version: Option<u32>,
60 pub source_pg_version: Option<u32>,
61}
62
63impl PreflightResult {
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn all_passed(&self) -> bool {
69 self.issues.is_empty()
70 }
71
72 pub fn failed_count(&self) -> usize {
73 self.issues.len()
74 }
75
76 pub fn print(&self) {
78 println!();
79 println!("Pre-flight Checks");
80 println!("{}", "═".repeat(61));
81 println!();
82
83 if !self.local_env.is_empty() {
84 println!("Local Environment:");
85 for check in &self.local_env {
86 let icon = if check.passed { "✓" } else { "✗" };
87 println!(" {} {}", icon, check.message);
88 if let Some(ref details) = check.details {
89 println!(" {}", details);
90 }
91 }
92 println!();
93 }
94
95 if !self.network.is_empty() {
96 println!("Network Connectivity:");
97 for check in &self.network {
98 let icon = if check.passed { "✓" } else { "✗" };
99 println!(" {} {}", icon, check.message);
100 if let Some(ref details) = check.details {
101 println!(" {}", details);
102 }
103 }
104 println!();
105 }
106
107 if !self.source_permissions.is_empty() {
108 println!("Source Permissions:");
109 for check in &self.source_permissions {
110 let icon = if check.passed { "✓" } else { "✗" };
111 println!(" {} {}", icon, check.message);
112 if let Some(ref details) = check.details {
113 println!(" {}", details);
114 }
115 }
116 println!();
117 }
118
119 if !self.target_permissions.is_empty() {
120 println!("Target Permissions:");
121 for check in &self.target_permissions {
122 let icon = if check.passed { "✓" } else { "✗" };
123 println!(" {} {}", icon, check.message);
124 if let Some(ref details) = check.details {
125 println!(" {}", details);
126 }
127 }
128 println!();
129 }
130
131 println!("{}", "═".repeat(61));
132 if self.all_passed() {
133 println!("PASSED: All pre-flight checks successful");
134 } else {
135 println!("FAILED: {} issue(s) must be resolved", self.failed_count());
136 println!();
137 for (i, issue) in self.issues.iter().enumerate() {
138 println!("Issue {}: {}", i + 1, issue.title);
139 println!(" {}", issue.explanation);
140 println!();
141 println!(" Fix options:");
142 for fix in &issue.fixes {
143 println!(" • {}", fix);
144 }
145 println!();
146 }
147 }
148 }
149}
150
151pub async fn run_preflight_checks(
163 source_url: &str,
164 target_url: &str,
165 _databases: Option<&[String]>,
166) -> Result<PreflightResult> {
167 let mut result = PreflightResult::new();
168
169 check_local_environment(&mut result);
171
172 let source_client_url = check_network_connectivity(&mut result, source_url, "source").await?;
174 let target_client_url = check_network_connectivity(&mut result, target_url, "target").await?;
175
176 if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
178 check_version_compatibility(&mut result);
179 }
180
181 if let Some(url) = source_client_url {
183 match crate::postgres::connect_with_retry(&url).await {
184 Ok(client) => {
185 check_source_permissions(&mut result, &client).await;
186 }
188 Err(e) => {
189 result.source_permissions.push(CheckResult::fail(
190 "connection",
191 format!("Failed to re-establish connection to source for permission checks: {}", e),
192 ));
193 result.issues.push(PreflightIssue {
194 title: "Source connection for permissions failed".to_string(),
195 explanation: e.to_string(),
196 fixes: vec!["Ensure source database is accessible".to_string()],
197 });
198 }
199 }
200 }
201
202 if let Some(url) = target_client_url {
204 match crate::postgres::connect_with_retry(&url).await {
205 Ok(client) => {
206 check_target_permissions(&mut result, &client).await;
207 }
209 Err(e) => {
210 result.target_permissions.push(CheckResult::fail(
211 "connection",
212 format!("Failed to re-establish connection to target for permission checks: {}", e),
213 ));
214 result.issues.push(PreflightIssue {
215 title: "Target connection for permissions failed".to_string(),
216 explanation: e.to_string(),
217 fixes: vec!["Ensure target database is accessible".to_string()],
218 });
219 }
220 }
221 }
222
223 Ok(result)
224}
225
226fn check_local_environment(result: &mut PreflightResult) {
227 let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
228 let mut missing = Vec::new();
229
230 for tool in tools {
231 match which::which(tool) {
232 Ok(path) => {
233 let path_str = path.display().to_string();
234 match crate::utils::get_pg_tool_version(tool) {
235 Ok(version) => {
236 if tool == "pg_dump" {
237 result.local_pg_version = Some(version);
238 }
239 result.local_env.push(
240 CheckResult::pass(tool, format!("{} found", tool))
241 .with_details(format!("{} ({})", path_str, version)),
242 );
243 }
244 Err(_) => {
245 result.local_env.push(
246 CheckResult::pass(tool, format!("{} found", tool))
247 .with_details(path_str),
248 );
249 }
250 }
251 }
252 Err(_) => {
253 missing.push(tool);
254 result.local_env.push(CheckResult::fail(
255 tool,
256 format!("{} not found in PATH", tool),
257 ));
258 }
259 }
260 }
261
262 if !missing.is_empty() {
263 result.issues.push(PreflightIssue {
264 title: "Missing PostgreSQL client tools".to_string(),
265 explanation: format!("Required tools not found: {}", missing.join(", ")),
266 fixes: vec![
267 "Ubuntu: sudo apt install postgresql-client-17".to_string(),
268 "macOS: brew install postgresql@17".to_string(),
269 "RHEL: sudo dnf install postgresql17".to_string(),
270 ],
271 });
272 }
273}
274
275async fn check_network_connectivity(
276 result: &mut PreflightResult,
277 db_url: &str,
278 db_type: &str, ) -> Result<Option<String>> {
280 match crate::postgres::connect_with_retry(db_url).await {
281 Ok(client) => {
282 if db_type == "source" {
284 if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
285 let version_str: String = row.get(0);
286 if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
287 result.source_pg_version = Some(version);
288 }
289 }
290 }
291 result
292 .network
293 .push(CheckResult::pass(db_type, format!("{} database reachable", db_type)));
294 Ok(Some(db_url.to_string())) }
296 Err(e) => {
297 result.network.push(CheckResult::fail(
298 db_type,
299 format!("Cannot connect to {}: {}", db_type, e),
300 ));
301 result.issues.push(PreflightIssue {
302 title: format!("{} database unreachable", db_type),
303 explanation: e.to_string(),
304 fixes: vec![
305 "Verify connection string is correct".to_string(),
306 "Check network connectivity to database host".to_string(),
307 "Ensure firewall allows PostgreSQL port (5432)".to_string(),
308 ],
309 });
310 Ok(None) }
312 }
313}
314
315fn check_version_compatibility(result: &mut PreflightResult) {
316 let local = result.local_pg_version.unwrap();
317 let server = result.source_pg_version.unwrap();
318
319 if local < server {
320 result.tool_version_incompatible = true;
321 result.local_env.push(CheckResult::fail(
322 "version",
323 format!("pg_dump version {} < source server {}", local, server),
324 ));
325 result.issues.push(PreflightIssue {
326 title: "PostgreSQL version mismatch".to_string(),
327 explanation: format!(
328 "Local pg_dump ({}) cannot dump from server ({})",
329 local, server
330 ),
331 fixes: vec![
332 format!("Install PostgreSQL {} client tools:", server),
333 format!(" Ubuntu: sudo apt install postgresql-client-{}", server),
334 format!(" macOS: brew install postgresql@{}", server),
335 "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
336 ],
337 });
338 } else {
339 result.local_env.push(CheckResult::pass(
340 "version",
341 format!("pg_dump version {} >= source server {}", local, server),
342 ));
343 }
344}
345
346async fn check_source_permissions(result: &mut PreflightResult, client: &Client) {
347 match crate::postgres::check_source_privileges(client).await {
349 Ok(privs) => {
350 if privs.can_replicate() {
351 let method = if privs.has_rds_replication {
352 "Has rds_replication role (AWS RDS)"
353 } else if privs.is_superuser {
354 "Has superuser privilege"
355 } else {
356 "Has REPLICATION privilege"
357 };
358 result
359 .source_permissions
360 .push(CheckResult::pass("replication", method));
361 } else {
362 result.source_permissions.push(CheckResult::fail(
363 "replication",
364 "Missing REPLICATION privilege",
365 ));
366 result.issues.push(PreflightIssue {
367 title: "Missing REPLICATION privilege".to_string(),
368 explanation: "Required for continuous sync".to_string(),
369 fixes: vec![
370 "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;".to_string(),
371 "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
372 ],
373 });
374 }
375 }
376 Err(e) => {
377 result.source_permissions.push(CheckResult::fail(
378 "privileges",
379 format!("Failed to check: {}", e),
380 ));
381 }
382 }
383
384 match crate::postgres::check_table_select_permissions(client).await {
386 Ok(perms) => {
387 if perms.all_accessible() {
388 result.source_permissions.push(CheckResult::pass(
389 "select",
390 format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
391 ));
392 } else {
393 let inaccessible = &perms.inaccessible_tables;
394 let count = inaccessible.len();
395 let preview: Vec<&str> = inaccessible.iter().take(5).map(|s| s.as_str()).collect();
396 let details = if count > 5 {
397 format!("{}, ... ({} more)", preview.join(", "), count - 5)
398 } else {
399 preview.join(", ")
400 };
401
402 result.source_permissions.push(
403 CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
404 .with_details(details),
405 );
406 result.issues.push(PreflightIssue {
407 title: "Missing table permissions".to_string(),
408 explanation: format!("User needs SELECT on {} tables", count),
409 fixes: vec![
410 "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
411 .to_string(),
412 ],
413 });
414 }
415 }
416 Err(e) => {
417 result.source_permissions.push(CheckResult::fail(
418 "select",
419 format!("Failed to check table permissions: {}", e),
420 ));
421 }
422 }
423}
424
425async fn check_target_permissions(result: &mut PreflightResult, client: &Client) {
426 match crate::postgres::check_target_privileges(client).await {
427 Ok(privs) => {
428 if privs.has_create_db || privs.is_superuser {
429 result
430 .target_permissions
431 .push(CheckResult::pass("createdb", "Can create databases"));
432 } else {
433 result
434 .target_permissions
435 .push(CheckResult::fail("createdb", "Cannot create databases"));
436 result.issues.push(PreflightIssue {
437 title: "Missing CREATEDB privilege".to_string(),
438 explanation: "Cannot create databases on target".to_string(),
439 fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
440 });
441 }
442
443 if privs.can_replicate() {
444 result.target_permissions.push(CheckResult::pass(
445 "subscription",
446 "Can create subscriptions",
447 ));
448 } else {
449 result.target_permissions.push(CheckResult::fail(
450 "subscription",
451 "Cannot create subscriptions",
452 ));
453 }
454 }
455 Err(e) => {
456 result.target_permissions.push(CheckResult::fail(
457 "privileges",
458 format!("Failed to check: {}", e),
459 ));
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_result_pass() {
        let check = CheckResult::pass("test", "Test passed");
        assert!(check.passed);
        assert_eq!(check.name, "test");
        assert_eq!(check.message, "Test passed");
        assert!(check.details.is_none());
    }

    #[test]
    fn test_check_result_fail() {
        let check = CheckResult::fail("test", "Test failed");
        assert!(!check.passed);
        assert_eq!(check.name, "test");
        assert_eq!(check.message, "Test failed");
        assert!(check.details.is_none());
    }

    #[test]
    fn test_check_result_with_details() {
        let check = CheckResult::pass("test", "Test passed").with_details("Some details");
        assert_eq!(check.details, Some("Some details".to_string()));
        // Attaching details must leave the other fields untouched.
        assert!(check.passed);
        assert_eq!(check.name, "test");
        assert_eq!(check.message, "Test passed");
    }

    #[test]
    fn test_check_result_with_details_replaces_previous() {
        let check = CheckResult::fail("test", "Test failed")
            .with_details("first")
            .with_details("second");
        assert!(!check.passed);
        assert_eq!(check.details.as_deref(), Some("second"));
    }

    #[test]
    fn test_preflight_result_empty_passes() {
        let result = PreflightResult::new();
        assert!(result.all_passed());
        assert_eq!(result.failed_count(), 0);
        assert!(!result.tool_version_incompatible);
        assert!(result.local_pg_version.is_none());
        assert!(result.source_pg_version.is_none());
    }

    #[test]
    fn test_preflight_result_with_issues() {
        let mut result = PreflightResult::new();
        result.issues.push(PreflightIssue {
            title: "Test issue".to_string(),
            explanation: "Test".to_string(),
            fixes: vec![],
        });
        assert!(!result.all_passed());
        assert_eq!(result.failed_count(), 1);
    }

    #[test]
    fn test_failed_check_without_issue_does_not_fail_result() {
        // Pass/fail is decided solely by `issues`, not by individual checks.
        let mut result = PreflightResult::new();
        result
            .local_env
            .push(CheckResult::fail("psql", "psql not found in PATH"));
        assert!(result.all_passed());
        assert_eq!(result.failed_count(), 0);
    }

    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let issue = PreflightIssue {
            title: "Test".to_string(),
            explanation: "Details".to_string(),
            fixes: vec!["Fix 1".to_string(), "Fix 2".to_string()],
        };
        assert_eq!(issue.fixes.len(), 2);
        assert_eq!(issue.fixes[0], "Fix 1");
        assert_eq!(issue.fixes[1], "Fix 2");
    }
}