1use anyhow::Result;
5use tokio_postgres::Client;
6
/// Outcome of a single pre-flight check, rendered as one line of the report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckResult {
    /// Short identifier of the check (e.g. a tool name, `"version"`, `"source"`).
    pub name: String,
    /// `true` when the check succeeded; selects the ✓/✗ icon when printed.
    pub passed: bool,
    /// Human-readable summary printed next to the icon.
    pub message: String,
    /// Optional extra line printed beneath the message.
    pub details: Option<String>,
}
15
16impl CheckResult {
17 pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
18 Self {
19 name: name.into(),
20 passed: true,
21 message: message.into(),
22 details: None,
23 }
24 }
25
26 pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
27 Self {
28 name: name.into(),
29 passed: false,
30 message: message.into(),
31 details: None,
32 }
33 }
34
35 pub fn with_details(mut self, details: impl Into<String>) -> Self {
36 self.details = Some(details.into());
37 self
38 }
39}
40
/// A blocking problem found during pre-flight, with suggested remedies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreflightIssue {
    /// Short headline, printed as `Issue N: <title>`.
    pub title: String,
    /// Explanation of what is wrong, printed under the title.
    pub explanation: String,
    /// Suggested fixes, printed as a bulleted list.
    pub fixes: Vec<String>,
}
48
49#[derive(Debug, Default)]
51pub struct PreflightResult {
52 pub local_env: Vec<CheckResult>,
53 pub network: Vec<CheckResult>,
54 pub source_permissions: Vec<CheckResult>,
55 pub target_permissions: Vec<CheckResult>,
56 pub issues: Vec<PreflightIssue>,
57 pub tool_version_incompatible: bool,
59 pub local_pg_version: Option<u32>,
60 pub source_pg_version: Option<u32>,
61}
62
63impl PreflightResult {
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn all_passed(&self) -> bool {
69 self.issues.is_empty()
70 }
71
72 pub fn failed_count(&self) -> usize {
73 self.issues.len()
74 }
75
76 pub fn print(&self) {
78 println!();
79 println!("Pre-flight Checks");
80 println!("{}", "═".repeat(61));
81 println!();
82
83 if !self.local_env.is_empty() {
84 println!("Local Environment:");
85 for check in &self.local_env {
86 let icon = if check.passed { "✓" } else { "✗" };
87 println!(" {} {}", icon, check.message);
88 if let Some(ref details) = check.details {
89 println!(" {}", details);
90 }
91 }
92 println!();
93 }
94
95 if !self.network.is_empty() {
96 println!("Network Connectivity:");
97 for check in &self.network {
98 let icon = if check.passed { "✓" } else { "✗" };
99 println!(" {} {}", icon, check.message);
100 if let Some(ref details) = check.details {
101 println!(" {}", details);
102 }
103 }
104 println!();
105 }
106
107 if !self.source_permissions.is_empty() {
108 println!("Source Permissions:");
109 for check in &self.source_permissions {
110 let icon = if check.passed { "✓" } else { "✗" };
111 println!(" {} {}", icon, check.message);
112 if let Some(ref details) = check.details {
113 println!(" {}", details);
114 }
115 }
116 println!();
117 }
118
119 if !self.target_permissions.is_empty() {
120 println!("Target Permissions:");
121 for check in &self.target_permissions {
122 let icon = if check.passed { "✓" } else { "✗" };
123 println!(" {} {}", icon, check.message);
124 if let Some(ref details) = check.details {
125 println!(" {}", details);
126 }
127 }
128 println!();
129 }
130
131 println!("{}", "═".repeat(61));
132 if self.all_passed() {
133 println!("PASSED: All pre-flight checks successful");
134 } else {
135 println!("FAILED: {} issue(s) must be resolved", self.failed_count());
136 println!();
137 for (i, issue) in self.issues.iter().enumerate() {
138 println!("Issue {}: {}", i + 1, issue.title);
139 println!(" {}", issue.explanation);
140 println!();
141 println!(" Fix options:");
142 for fix in &issue.fixes {
143 println!(" • {}", fix);
144 }
145 println!();
146 }
147 }
148 }
149}
150
151pub async fn run_preflight_checks(
163 source_url: &str,
164 target_url: &str,
165 _databases: Option<&[String]>,
166) -> Result<PreflightResult> {
167 let mut result = PreflightResult::new();
168
169 check_local_environment(&mut result);
171
172 let clients = check_network_connectivity(&mut result, source_url, target_url).await;
174
175 if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
177 check_version_compatibility(&mut result);
178 }
179
180 if let Some(ref client) = clients.source {
182 check_source_permissions(&mut result, client).await;
183 }
184
185 if let Some(ref client) = clients.target {
187 check_target_permissions(&mut result, client).await;
188 }
189
190 Ok(result)
191}
192
193fn check_local_environment(result: &mut PreflightResult) {
194 let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
195 let mut missing = Vec::new();
196
197 for tool in tools {
198 match which::which(tool) {
199 Ok(path) => {
200 let path_str = path.display().to_string();
201 match crate::utils::get_pg_tool_version(tool) {
202 Ok(version) => {
203 if tool == "pg_dump" {
204 result.local_pg_version = Some(version);
205 }
206 result.local_env.push(
207 CheckResult::pass(tool, format!("{} found", tool))
208 .with_details(format!("{} ({})", path_str, version)),
209 );
210 }
211 Err(_) => {
212 result.local_env.push(
213 CheckResult::pass(tool, format!("{} found", tool))
214 .with_details(path_str),
215 );
216 }
217 }
218 }
219 Err(_) => {
220 missing.push(tool);
221 result.local_env.push(CheckResult::fail(
222 tool,
223 format!("{} not found in PATH", tool),
224 ));
225 }
226 }
227 }
228
229 if !missing.is_empty() {
230 result.issues.push(PreflightIssue {
231 title: "Missing PostgreSQL client tools".to_string(),
232 explanation: format!("Required tools not found: {}", missing.join(", ")),
233 fixes: vec![
234 "Ubuntu: sudo apt install postgresql-client-17".to_string(),
235 "macOS: brew install postgresql@17".to_string(),
236 "RHEL: sudo dnf install postgresql17".to_string(),
237 ],
238 });
239 }
240}
241
/// Connections opened by `check_network_connectivity`, returned so the
/// permission checks can reuse them.
///
/// A field is `None` when connecting to that database failed.
#[derive(Default)]
struct ConnectivityClients {
    /// Client connected to the source database, if it was reachable.
    source: Option<Client>,
    /// Client connected to the target database, if it was reachable.
    target: Option<Client>,
}
247
248async fn check_network_connectivity(
249 result: &mut PreflightResult,
250 source_url: &str,
251 target_url: &str,
252) -> ConnectivityClients {
253 let mut clients = ConnectivityClients::default();
254
255 match crate::postgres::connect_with_retry(source_url).await {
257 Ok(client) => {
258 if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
260 let version_str: String = row.get(0);
261 if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
262 result.source_pg_version = Some(version);
263 }
264 }
265 result
266 .network
267 .push(CheckResult::pass("source", "Source database reachable"));
268 clients.source = Some(client);
269 }
270 Err(e) => {
271 result.network.push(CheckResult::fail(
272 "source",
273 format!("Cannot connect to source: {}", e),
274 ));
275 result.issues.push(PreflightIssue {
276 title: "Source database unreachable".to_string(),
277 explanation: e.to_string(),
278 fixes: vec![
279 "Verify connection string is correct".to_string(),
280 "Check network connectivity to database host".to_string(),
281 "Ensure firewall allows PostgreSQL port (5432)".to_string(),
282 ],
283 });
284 }
285 }
286
287 match crate::postgres::connect_with_retry(target_url).await {
289 Ok(client) => {
290 result
291 .network
292 .push(CheckResult::pass("target", "Target database reachable"));
293 clients.target = Some(client);
294 }
295 Err(e) => {
296 result.network.push(CheckResult::fail(
297 "target",
298 format!("Cannot connect to target: {}", e),
299 ));
300 result.issues.push(PreflightIssue {
301 title: "Target database unreachable".to_string(),
302 explanation: e.to_string(),
303 fixes: vec![
304 "Verify connection string is correct".to_string(),
305 "Check network connectivity to database host".to_string(),
306 ],
307 });
308 }
309 }
310
311 clients
312}
313
314fn check_version_compatibility(result: &mut PreflightResult) {
315 let local = result.local_pg_version.unwrap();
316 let server = result.source_pg_version.unwrap();
317
318 if local < server {
319 result.tool_version_incompatible = true;
320 result.local_env.push(CheckResult::fail(
321 "version",
322 format!("pg_dump version {} < source server {}", local, server),
323 ));
324 result.issues.push(PreflightIssue {
325 title: "PostgreSQL version mismatch".to_string(),
326 explanation: format!(
327 "Local pg_dump ({}) cannot dump from server ({})",
328 local, server
329 ),
330 fixes: vec![
331 format!("Install PostgreSQL {} client tools:", server),
332 format!(" Ubuntu: sudo apt install postgresql-client-{}", server),
333 format!(" macOS: brew install postgresql@{}", server),
334 "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
335 ],
336 });
337 } else {
338 result.local_env.push(CheckResult::pass(
339 "version",
340 format!("pg_dump version {} >= source server {}", local, server),
341 ));
342 }
343}
344
345async fn check_source_permissions(result: &mut PreflightResult, client: &Client) {
346 match crate::postgres::check_source_privileges(client).await {
348 Ok(privs) => {
349 if privs.can_replicate() {
350 let method = if privs.has_rds_replication {
351 "Has rds_replication role (AWS RDS)"
352 } else if privs.is_superuser {
353 "Has superuser privilege"
354 } else {
355 "Has REPLICATION privilege"
356 };
357 result
358 .source_permissions
359 .push(CheckResult::pass("replication", method));
360 } else {
361 result.source_permissions.push(CheckResult::fail(
362 "replication",
363 "Missing REPLICATION privilege",
364 ));
365 result.issues.push(PreflightIssue {
366 title: "Missing REPLICATION privilege".to_string(),
367 explanation: "Required for continuous sync".to_string(),
368 fixes: vec![
369 "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;".to_string(),
370 "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
371 ],
372 });
373 }
374 }
375 Err(e) => {
376 result.source_permissions.push(CheckResult::fail(
377 "privileges",
378 format!("Failed to check: {}", e),
379 ));
380 }
381 }
382
383 match crate::postgres::check_table_select_permissions(client).await {
385 Ok(perms) => {
386 if perms.all_accessible() {
387 result.source_permissions.push(CheckResult::pass(
388 "select",
389 format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
390 ));
391 } else {
392 let inaccessible = &perms.inaccessible_tables;
393 let count = inaccessible.len();
394 let preview: Vec<&str> = inaccessible.iter().take(5).map(|s| s.as_str()).collect();
395 let details = if count > 5 {
396 format!("{}, ... ({} more)", preview.join(", "), count - 5)
397 } else {
398 preview.join(", ")
399 };
400
401 result.source_permissions.push(
402 CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
403 .with_details(details),
404 );
405 result.issues.push(PreflightIssue {
406 title: "Missing table permissions".to_string(),
407 explanation: format!("User needs SELECT on {} tables", count),
408 fixes: vec![
409 "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
410 .to_string(),
411 ],
412 });
413 }
414 }
415 Err(e) => {
416 result.source_permissions.push(CheckResult::fail(
417 "select",
418 format!("Failed to check table permissions: {}", e),
419 ));
420 }
421 }
422}
423
424async fn check_target_permissions(result: &mut PreflightResult, client: &Client) {
425 match crate::postgres::check_target_privileges(client).await {
426 Ok(privs) => {
427 if privs.has_create_db || privs.is_superuser {
428 result
429 .target_permissions
430 .push(CheckResult::pass("createdb", "Can create databases"));
431 } else {
432 result
433 .target_permissions
434 .push(CheckResult::fail("createdb", "Cannot create databases"));
435 result.issues.push(PreflightIssue {
436 title: "Missing CREATEDB privilege".to_string(),
437 explanation: "Cannot create databases on target".to_string(),
438 fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
439 });
440 }
441
442 if privs.can_replicate() {
443 result.target_permissions.push(CheckResult::pass(
444 "subscription",
445 "Can create subscriptions",
446 ));
447 } else {
448 result.target_permissions.push(CheckResult::fail(
449 "subscription",
450 "Cannot create subscriptions",
451 ));
452 }
453 }
454 Err(e) => {
455 result.target_permissions.push(CheckResult::fail(
456 "privileges",
457 format!("Failed to check: {}", e),
458 ));
459 }
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_result_pass() {
        let check = CheckResult::pass("test", "Test passed");
        assert!(check.passed);
        assert_eq!(check.name, "test");
    }

    #[test]
    fn test_check_result_fail() {
        let check = CheckResult::fail("test", "Test failed");
        assert!(!check.passed);
        assert_eq!(check.details, None);
    }

    #[test]
    fn test_check_result_with_details() {
        let check = CheckResult::pass("test", "Test passed").with_details("Some details");
        assert_eq!(check.details, Some("Some details".to_string()));
    }

    #[test]
    fn test_preflight_result_empty_passes() {
        let result = PreflightResult::new();
        assert!(result.all_passed());
        assert_eq!(result.failed_count(), 0);
    }

    #[test]
    fn test_preflight_result_with_issues() {
        let mut result = PreflightResult::new();
        result.issues.push(PreflightIssue {
            title: "Test issue".to_string(),
            explanation: "Test".to_string(),
            fixes: vec![],
        });
        assert!(!result.all_passed());
        assert_eq!(result.failed_count(), 1);
    }

    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let issue = PreflightIssue {
            title: "Test".to_string(),
            explanation: "Details".to_string(),
            fixes: vec!["Fix 1".to_string(), "Fix 2".to_string()],
        };
        assert_eq!(issue.fixes.len(), 2);
    }

    /// A pg_dump older than the source server must be flagged as a blocking issue.
    #[test]
    fn test_version_compatibility_flags_older_pg_dump() {
        let mut result = PreflightResult::new();
        result.local_pg_version = Some(15);
        result.source_pg_version = Some(16);
        check_version_compatibility(&mut result);
        assert!(result.tool_version_incompatible);
        assert_eq!(result.failed_count(), 1);
        assert_eq!(result.local_env.len(), 1);
        assert_eq!(result.local_env[0].name, "version");
        assert!(!result.local_env[0].passed);
    }

    /// Equal or newer pg_dump versions are accepted without issues.
    #[test]
    fn test_version_compatibility_accepts_equal_or_newer_pg_dump() {
        for local in [16u32, 17] {
            let mut result = PreflightResult::new();
            result.local_pg_version = Some(local);
            result.source_pg_version = Some(16);
            check_version_compatibility(&mut result);
            assert!(!result.tool_version_incompatible);
            assert!(result.all_passed());
            assert!(result.local_env[0].passed);
        }
    }

    /// Printing a populated report (checks with details and issues with fixes)
    /// must not panic.
    #[test]
    fn test_print_populated_report() {
        let mut result = PreflightResult::new();
        result
            .local_env
            .push(CheckResult::pass("psql", "psql found").with_details("/usr/bin/psql"));
        result
            .network
            .push(CheckResult::fail("target", "Cannot connect to target: refused"));
        result.issues.push(PreflightIssue {
            title: "Target database unreachable".to_string(),
            explanation: "refused".to_string(),
            fixes: vec!["Verify connection string is correct".to_string()],
        });
        result.print();
        assert_eq!(result.failed_count(), 1);
    }
}