rivet_cli/preflight/
doctor.rs1use crate::config::{Config, DestinationType, SourceType};
2use crate::error::Result;
3
4pub fn doctor(config_path: &str) -> Result<()> {
5 println!("rivet doctor: verifying auth for config '{}'", config_path);
6 println!();
7
8 let config = match Config::load(config_path) {
9 Ok(c) => {
10 println!("[OK] Config parsed successfully");
11 c
12 }
13 Err(e) => {
14 println!("[FAIL] Config error: {}", e);
15 return Err(e);
16 }
17 };
18
19 let mut all_ok = true;
20
21 match check_source_auth(&config) {
22 Ok(()) => println!("[OK] Source auth ({:?})", config.source.source_type),
23 Err(e) => {
24 all_ok = false;
25 let category = categorize_source_error(&e);
26 println!("[FAIL] Source {}: {}", category, e);
27 }
28 }
29
30 let mut seen_destinations: Vec<String> = Vec::new();
31 for export in &config.exports {
32 let dest_key = format!(
33 "{:?}:{}:{}",
34 export.destination.destination_type,
35 export.destination.bucket.as_deref().unwrap_or("-"),
36 export.destination.endpoint.as_deref().unwrap_or("-"),
37 );
38 if seen_destinations.contains(&dest_key) {
39 continue;
40 }
41 seen_destinations.push(dest_key);
42
43 let label = match export.destination.destination_type {
44 DestinationType::Local => format!(
45 "Local({})",
46 export.destination.path.as_deref().unwrap_or(".")
47 ),
48 DestinationType::S3 => format!(
49 "S3({})",
50 export.destination.bucket.as_deref().unwrap_or("?")
51 ),
52 DestinationType::Gcs => format!(
53 "GCS({})",
54 export.destination.bucket.as_deref().unwrap_or("?")
55 ),
56 DestinationType::Stdout => {
57 log::info!(" Stdout: no auth check needed");
58 continue;
59 }
60 };
61
62 match check_destination_auth(&export.destination) {
63 Ok(()) => println!("[OK] Destination {}", label),
64 Err(e) => {
65 all_ok = false;
66 let category = categorize_dest_error(&e, &export.destination);
67 println!("[FAIL] Destination {} -- {}: {}", label, category, e);
68 }
69 }
70 }
71
72 println!();
73 if all_ok {
74 println!("All checks passed.");
75 } else {
76 println!("Some checks failed. Fix the issues above before running exports.");
77 }
78
79 Ok(())
80}
81
82fn check_source_auth(config: &Config) -> Result<()> {
83 let url = config.source.resolve_url()?;
84 match config.source.source_type {
85 SourceType::Postgres => {
86 let mut client = postgres::Client::connect(&url, postgres::NoTls)?;
87 client.simple_query("SELECT 1")?;
88 Ok(())
89 }
90 SourceType::Mysql => {
91 let opts = mysql::Opts::from_url(&url)?;
92 let pool = mysql::Pool::new(opts)?;
93 let mut conn = pool.get_conn()?;
94 use mysql::prelude::Queryable;
95 conn.query_drop("SELECT 1")?;
96 Ok(())
97 }
98 }
99}
100
101fn check_destination_auth(dest: &crate::config::DestinationConfig) -> Result<()> {
102 use crate::destination::create_destination;
103 let d = create_destination(dest)?;
104 let probe_key = ".rivet_doctor_probe";
105 let tmp = std::env::temp_dir().join(probe_key);
106 std::fs::write(&tmp, b"ok")?;
107 match d.write(&tmp, probe_key) {
108 Ok(()) => {
109 log::debug!("doctor: probe write succeeded, cleaning up");
110 }
111 Err(e) => {
112 let _ = std::fs::remove_file(&tmp);
113 return Err(e);
114 }
115 }
116 let _ = std::fs::remove_file(&tmp);
117 Ok(())
118}
119
120pub(super) fn categorize_source_error(err: &anyhow::Error) -> &'static str {
121 let msg = err.to_string().to_lowercase();
122 if msg.contains("password") || msg.contains("authentication") || msg.contains("access denied") {
123 "auth error"
124 } else if msg.contains("connect")
125 || msg.contains("refused")
126 || msg.contains("timed out")
127 || msg.contains("could not translate host")
128 || msg.contains("name or service not known")
129 {
130 "connectivity error"
131 } else {
132 "error"
133 }
134}
135
136pub(super) fn categorize_dest_error(
137 err: &anyhow::Error,
138 dest: &crate::config::DestinationConfig,
139) -> &'static str {
140 let msg = err.to_string().to_lowercase();
141 if msg.contains("credential")
142 || msg.contains("permission denied")
143 || msg.contains("access denied")
144 || msg.contains("unauthorized")
145 || msg.contains("forbidden")
146 || msg.contains("invalid_grant")
147 || msg.contains("token")
148 {
149 "auth error"
150 } else if msg.contains("not found") || msg.contains("nosuchbucket") || msg.contains("404") {
151 match dest.destination_type {
152 DestinationType::S3 => "bucket not found",
153 DestinationType::Gcs => "bucket not found",
154 DestinationType::Local | DestinationType::Stdout => "path not found",
155 }
156 } else if msg.contains("connect")
157 || msg.contains("refused")
158 || msg.contains("timed out")
159 || msg.contains("dns")
160 || msg.contains("endpoint")
161 {
162 "connectivity error"
163 } else {
164 "error"
165 }
166}