1use anyhow::Result;
5
/// Outcome of a single pre-flight check.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Identifier of the check; code in this file looks results up by it
    /// (for example `"source"` and `"target"`).
    pub name: String,
    /// `true` when the check succeeded.
    pub passed: bool,
    /// One-line summary of the outcome.
    pub message: String,
    /// Optional extra context; `None` until set with [`CheckResult::with_details`].
    pub details: Option<String>,
}

impl CheckResult {
    /// Shared constructor: a result with the given outcome and no details.
    fn with_outcome(name: String, passed: bool, message: String) -> Self {
        CheckResult {
            name,
            passed,
            message,
            details: None,
        }
    }

    /// Builds a successful result named `name` carrying `message`.
    pub fn pass(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_outcome(name.into(), true, message.into())
    }

    /// Builds a failed result named `name` carrying `message`.
    pub fn fail(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_outcome(name.into(), false, message.into())
    }

    /// Consumes the result and returns it with `details` attached,
    /// replacing any details that were already present.
    pub fn with_details(self, details: impl Into<String>) -> Self {
        Self {
            details: Some(details.into()),
            ..self
        }
    }
}
39
/// A problem found during pre-flight, paired with suggested remedies.
///
/// Any issue recorded on a `PreflightResult` makes `all_passed()` return
/// `false`, and `print()` reports it as something that must be resolved.
#[derive(Debug, Clone)]
pub struct PreflightIssue {
    /// Headline; printed as `Issue N: <title>`.
    pub title: String,
    /// Longer description printed on the line after the title.
    pub explanation: String,
    /// Suggested remedies, printed one per bullet under "Fix options:".
    pub fixes: Vec<String>,
}
47
/// Aggregated outcome of all pre-flight checks, grouped by category.
#[derive(Debug, Default)]
pub struct PreflightResult {
    /// Checks printed under "Local Environment:" (client tools and the
    /// tool-vs-server version comparison).
    pub local_env: Vec<CheckResult>,
    /// Checks printed under "Network Connectivity:"; entries are named
    /// `"source"` and `"target"`.
    pub network: Vec<CheckResult>,
    /// Checks printed under "Source Permissions:".
    pub source_permissions: Vec<CheckResult>,
    /// Checks printed under "Target Permissions:".
    pub target_permissions: Vec<CheckResult>,
    /// Problems that must be resolved; `all_passed()` is `true` only when
    /// this list is empty.
    pub issues: Vec<PreflightIssue>,
    /// Set to `true` when the local `pg_dump` version is lower than the
    /// source server version.
    pub tool_version_incompatible: bool,
    /// Version reported for the local `pg_dump`, when it could be determined.
    pub local_pg_version: Option<u32>,
    /// Version parsed from `SHOW server_version` on the source, when available.
    pub source_pg_version: Option<u32>,
}
61
62impl PreflightResult {
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 pub fn all_passed(&self) -> bool {
68 self.issues.is_empty()
69 }
70
71 pub fn failed_count(&self) -> usize {
72 self.issues.len()
73 }
74
75 pub fn print(&self) {
77 println!();
78 println!("Pre-flight Checks");
79 println!("{}", "═".repeat(61));
80 println!();
81
82 if !self.local_env.is_empty() {
83 println!("Local Environment:");
84 for check in &self.local_env {
85 let icon = if check.passed { "✓" } else { "✗" };
86 println!(" {} {}", icon, check.message);
87 if let Some(ref details) = check.details {
88 println!(" {}", details);
89 }
90 }
91 println!();
92 }
93
94 if !self.network.is_empty() {
95 println!("Network Connectivity:");
96 for check in &self.network {
97 let icon = if check.passed { "✓" } else { "✗" };
98 println!(" {} {}", icon, check.message);
99 if let Some(ref details) = check.details {
100 println!(" {}", details);
101 }
102 }
103 println!();
104 }
105
106 if !self.source_permissions.is_empty() {
107 println!("Source Permissions:");
108 for check in &self.source_permissions {
109 let icon = if check.passed { "✓" } else { "✗" };
110 println!(" {} {}", icon, check.message);
111 if let Some(ref details) = check.details {
112 println!(" {}", details);
113 }
114 }
115 println!();
116 }
117
118 if !self.target_permissions.is_empty() {
119 println!("Target Permissions:");
120 for check in &self.target_permissions {
121 let icon = if check.passed { "✓" } else { "✗" };
122 println!(" {} {}", icon, check.message);
123 if let Some(ref details) = check.details {
124 println!(" {}", details);
125 }
126 }
127 println!();
128 }
129
130 println!("{}", "═".repeat(61));
131 if self.all_passed() {
132 println!("PASSED: All pre-flight checks successful");
133 } else {
134 println!("FAILED: {} issue(s) must be resolved", self.failed_count());
135 println!();
136 for (i, issue) in self.issues.iter().enumerate() {
137 println!("Issue {}: {}", i + 1, issue.title);
138 println!(" {}", issue.explanation);
139 println!();
140 println!(" Fix options:");
141 for fix in &issue.fixes {
142 println!(" • {}", fix);
143 }
144 println!();
145 }
146 }
147 }
148}
149
150pub async fn run_preflight_checks(
162 source_url: &str,
163 target_url: &str,
164 _databases: Option<&[String]>,
165) -> Result<PreflightResult> {
166 let mut result = PreflightResult::new();
167
168 check_local_environment(&mut result);
170
171 check_network_connectivity(&mut result, source_url, target_url).await;
173
174 if result.local_pg_version.is_some() && result.source_pg_version.is_some() {
176 check_version_compatibility(&mut result);
177 }
178
179 if result
181 .network
182 .iter()
183 .any(|c| c.name == "source" && c.passed)
184 {
185 check_source_permissions(&mut result, source_url).await;
186 }
187
188 if result
190 .network
191 .iter()
192 .any(|c| c.name == "target" && c.passed)
193 {
194 check_target_permissions(&mut result, target_url).await;
195 }
196
197 Ok(result)
198}
199
200fn check_local_environment(result: &mut PreflightResult) {
201 let tools = ["pg_dump", "pg_dumpall", "pg_restore", "psql"];
202 let mut missing = Vec::new();
203
204 for tool in tools {
205 match which::which(tool) {
206 Ok(path) => {
207 let path_str = path.display().to_string();
208 match crate::utils::get_pg_tool_version(tool) {
209 Ok(version) => {
210 if tool == "pg_dump" {
211 result.local_pg_version = Some(version);
212 }
213 result.local_env.push(
214 CheckResult::pass(tool, format!("{} found", tool))
215 .with_details(format!("{} ({})", path_str, version)),
216 );
217 }
218 Err(_) => {
219 result.local_env.push(
220 CheckResult::pass(tool, format!("{} found", tool))
221 .with_details(path_str),
222 );
223 }
224 }
225 }
226 Err(_) => {
227 missing.push(tool);
228 result.local_env.push(CheckResult::fail(
229 tool,
230 format!("{} not found in PATH", tool),
231 ));
232 }
233 }
234 }
235
236 if !missing.is_empty() {
237 result.issues.push(PreflightIssue {
238 title: "Missing PostgreSQL client tools".to_string(),
239 explanation: format!("Required tools not found: {}", missing.join(", ")),
240 fixes: vec![
241 "Ubuntu: sudo apt install postgresql-client-17".to_string(),
242 "macOS: brew install postgresql@17".to_string(),
243 "RHEL: sudo dnf install postgresql17".to_string(),
244 ],
245 });
246 }
247}
248
249async fn check_network_connectivity(
250 result: &mut PreflightResult,
251 source_url: &str,
252 target_url: &str,
253) {
254 match crate::postgres::connect_with_retry(source_url).await {
256 Ok(client) => {
257 if let Ok(row) = client.query_one("SHOW server_version", &[]).await {
259 let version_str: String = row.get(0);
260 if let Ok(version) = crate::utils::parse_pg_version_string(&version_str) {
261 result.source_pg_version = Some(version);
262 }
263 }
264 result
265 .network
266 .push(CheckResult::pass("source", "Source database reachable"));
267 }
268 Err(e) => {
269 result.network.push(CheckResult::fail(
270 "source",
271 format!("Cannot connect to source: {}", e),
272 ));
273 result.issues.push(PreflightIssue {
274 title: "Source database unreachable".to_string(),
275 explanation: e.to_string(),
276 fixes: vec![
277 "Verify connection string is correct".to_string(),
278 "Check network connectivity to database host".to_string(),
279 "Ensure firewall allows PostgreSQL port (5432)".to_string(),
280 ],
281 });
282 }
283 }
284
285 match crate::postgres::connect_with_retry(target_url).await {
287 Ok(_) => {
288 result
289 .network
290 .push(CheckResult::pass("target", "Target database reachable"));
291 }
292 Err(e) => {
293 result.network.push(CheckResult::fail(
294 "target",
295 format!("Cannot connect to target: {}", e),
296 ));
297 result.issues.push(PreflightIssue {
298 title: "Target database unreachable".to_string(),
299 explanation: e.to_string(),
300 fixes: vec![
301 "Verify connection string is correct".to_string(),
302 "Check network connectivity to database host".to_string(),
303 ],
304 });
305 }
306 }
307}
308
309fn check_version_compatibility(result: &mut PreflightResult) {
310 let local = result.local_pg_version.unwrap();
311 let server = result.source_pg_version.unwrap();
312
313 if local < server {
314 result.tool_version_incompatible = true;
315 result.local_env.push(CheckResult::fail(
316 "version",
317 format!("pg_dump version {} < source server {}", local, server),
318 ));
319 result.issues.push(PreflightIssue {
320 title: "PostgreSQL version mismatch".to_string(),
321 explanation: format!(
322 "Local pg_dump ({}) cannot dump from server ({})",
323 local, server
324 ),
325 fixes: vec![
326 format!("Install PostgreSQL {} client tools:", server),
327 format!(" Ubuntu: sudo apt install postgresql-client-{}", server),
328 format!(" macOS: brew install postgresql@{}", server),
329 "Or use SerenAI cloud execution (recommended for SerenDB targets)".to_string(),
330 ],
331 });
332 } else {
333 result.local_env.push(CheckResult::pass(
334 "version",
335 format!("pg_dump version {} >= source server {}", local, server),
336 ));
337 }
338}
339
340async fn check_source_permissions(result: &mut PreflightResult, source_url: &str) {
341 if let Ok(client) = crate::postgres::connect_with_retry(source_url).await {
342 match crate::postgres::check_source_privileges(&client).await {
344 Ok(privs) => {
345 if privs.can_replicate() {
346 let method = if privs.has_rds_replication {
347 "Has rds_replication role (AWS RDS)"
348 } else if privs.is_superuser {
349 "Has superuser privilege"
350 } else {
351 "Has REPLICATION privilege"
352 };
353 result
354 .source_permissions
355 .push(CheckResult::pass("replication", method));
356 } else {
357 result.source_permissions.push(CheckResult::fail(
358 "replication",
359 "Missing REPLICATION privilege",
360 ));
361 result.issues.push(PreflightIssue {
362 title: "Missing REPLICATION privilege".to_string(),
363 explanation: "Required for continuous sync".to_string(),
364 fixes: vec![
365 "Standard PostgreSQL: ALTER USER <username> WITH REPLICATION;"
366 .to_string(),
367 "AWS RDS: GRANT rds_replication TO <username>;".to_string(),
368 ],
369 });
370 }
371 }
372 Err(e) => {
373 result.source_permissions.push(CheckResult::fail(
374 "privileges",
375 format!("Failed to check: {}", e),
376 ));
377 }
378 }
379
380 match crate::postgres::check_table_select_permissions(&client).await {
382 Ok(perms) => {
383 if perms.all_accessible() {
384 result.source_permissions.push(CheckResult::pass(
385 "select",
386 format!("Has SELECT on all {} tables", perms.accessible_tables.len()),
387 ));
388 } else {
389 let inaccessible = &perms.inaccessible_tables;
390 let count = inaccessible.len();
391 let preview: Vec<&str> =
392 inaccessible.iter().take(5).map(|s| s.as_str()).collect();
393 let details = if count > 5 {
394 format!("{}, ... ({} more)", preview.join(", "), count - 5)
395 } else {
396 preview.join(", ")
397 };
398
399 result.source_permissions.push(
400 CheckResult::fail("select", format!("Missing SELECT on {} tables", count))
401 .with_details(details),
402 );
403 result.issues.push(PreflightIssue {
404 title: "Missing table permissions".to_string(),
405 explanation: format!("User needs SELECT on {} tables", count),
406 fixes: vec![
407 "Run: GRANT SELECT ON ALL TABLES IN SCHEMA public TO <username>;"
408 .to_string(),
409 ],
410 });
411 }
412 }
413 Err(e) => {
414 result.source_permissions.push(CheckResult::fail(
415 "select",
416 format!("Failed to check table permissions: {}", e),
417 ));
418 }
419 }
420 }
421}
422
423async fn check_target_permissions(result: &mut PreflightResult, target_url: &str) {
424 if let Ok(client) = crate::postgres::connect_with_retry(target_url).await {
425 match crate::postgres::check_target_privileges(&client).await {
426 Ok(privs) => {
427 if privs.has_create_db || privs.is_superuser {
428 result
429 .target_permissions
430 .push(CheckResult::pass("createdb", "Can create databases"));
431 } else {
432 result
433 .target_permissions
434 .push(CheckResult::fail("createdb", "Cannot create databases"));
435 result.issues.push(PreflightIssue {
436 title: "Missing CREATEDB privilege".to_string(),
437 explanation: "Cannot create databases on target".to_string(),
438 fixes: vec!["Run: ALTER USER <username> CREATEDB;".to_string()],
439 });
440 }
441
442 if privs.can_replicate() {
443 result.target_permissions.push(CheckResult::pass(
444 "subscription",
445 "Can create subscriptions",
446 ));
447 } else {
448 result.target_permissions.push(CheckResult::fail(
449 "subscription",
450 "Cannot create subscriptions",
451 ));
452 }
453 }
454 Err(e) => {
455 result.target_permissions.push(CheckResult::fail(
456 "privileges",
457 format!("Failed to check: {}", e),
458 ));
459 }
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    // `pass` yields a passed result that keeps the given name.
    #[test]
    fn test_check_result_pass() {
        let outcome = CheckResult::pass("test", "Test passed");
        assert!(outcome.passed);
        assert_eq!(outcome.name, "test");
    }

    // `fail` yields a result that is not passed.
    #[test]
    fn test_check_result_fail() {
        let outcome = CheckResult::fail("test", "Test failed");
        assert!(!outcome.passed);
    }

    // `with_details` stores the supplied text.
    #[test]
    fn test_check_result_with_details() {
        let outcome = CheckResult::pass("test", "Test passed").with_details("Some details");
        let expected = Some("Some details".to_string());
        assert_eq!(outcome.details, expected);
    }

    // A fresh result has no issues and therefore passes.
    #[test]
    fn test_preflight_result_empty_passes() {
        let report = PreflightResult::new();
        assert!(report.all_passed());
        assert_eq!(report.failed_count(), 0);
    }

    // One recorded issue flips the result to failed with a count of one.
    #[test]
    fn test_preflight_result_with_issues() {
        let mut report = PreflightResult::new();
        let issue = PreflightIssue {
            title: "Test issue".to_string(),
            explanation: "Test".to_string(),
            fixes: vec![],
        };
        report.issues.push(issue);
        assert!(!report.all_passed());
        assert_eq!(report.failed_count(), 1);
    }

    // An issue keeps every fix it was built with.
    #[test]
    fn test_preflight_issue_multiple_fixes() {
        let fixes = vec!["Fix 1".to_string(), "Fix 2".to_string()];
        let issue = PreflightIssue {
            title: "Test".to_string(),
            explanation: "Details".to_string(),
            fixes,
        };
        assert_eq!(issue.fixes.len(), 2);
    }
}