Skip to main content

rivet_cli/preflight/
doctor.rs

1use crate::config::{Config, DestinationType, SourceType};
2use crate::error::Result;
3
4pub fn doctor(config_path: &str) -> Result<()> {
5    println!("rivet doctor: verifying auth for config '{}'", config_path);
6    println!();
7
8    let config = match Config::load(config_path) {
9        Ok(c) => {
10            println!("[OK]  Config parsed successfully");
11            c
12        }
13        Err(e) => {
14            println!("[FAIL] Config error: {}", e);
15            return Err(e);
16        }
17    };
18
19    let mut all_ok = true;
20
21    match check_source_auth(&config) {
22        Ok(()) => println!("[OK]  Source auth ({:?})", config.source.source_type),
23        Err(e) => {
24            all_ok = false;
25            let category = categorize_source_error(&e);
26            println!("[FAIL] Source {}: {}", category, e);
27        }
28    }
29
30    let mut seen_destinations: Vec<String> = Vec::new();
31    for export in &config.exports {
32        let dest_key = format!(
33            "{:?}:{}:{}",
34            export.destination.destination_type,
35            export.destination.bucket.as_deref().unwrap_or("-"),
36            export.destination.endpoint.as_deref().unwrap_or("-"),
37        );
38        if seen_destinations.contains(&dest_key) {
39            continue;
40        }
41        seen_destinations.push(dest_key);
42
43        let label = match export.destination.destination_type {
44            DestinationType::Local => format!(
45                "Local({})",
46                export.destination.path.as_deref().unwrap_or(".")
47            ),
48            DestinationType::S3 => format!(
49                "S3({})",
50                export.destination.bucket.as_deref().unwrap_or("?")
51            ),
52            DestinationType::Gcs => format!(
53                "GCS({})",
54                export.destination.bucket.as_deref().unwrap_or("?")
55            ),
56            DestinationType::Stdout => {
57                log::info!("  Stdout: no auth check needed");
58                continue;
59            }
60        };
61
62        match check_destination_auth(&export.destination) {
63            Ok(()) => println!("[OK]  Destination {}", label),
64            Err(e) => {
65                all_ok = false;
66                let category = categorize_dest_error(&e, &export.destination);
67                println!("[FAIL] Destination {} -- {}: {}", label, category, e);
68            }
69        }
70    }
71
72    println!();
73    if all_ok {
74        println!("All checks passed.");
75    } else {
76        println!("Some checks failed. Fix the issues above before running exports.");
77    }
78
79    Ok(())
80}
81
82fn check_source_auth(config: &Config) -> Result<()> {
83    let url = config.source.resolve_url()?;
84    match config.source.source_type {
85        SourceType::Postgres => {
86            let mut client = postgres::Client::connect(&url, postgres::NoTls)?;
87            client.simple_query("SELECT 1")?;
88            Ok(())
89        }
90        SourceType::Mysql => {
91            let opts = mysql::Opts::from_url(&url)?;
92            let pool = mysql::Pool::new(opts)?;
93            let mut conn = pool.get_conn()?;
94            use mysql::prelude::Queryable;
95            conn.query_drop("SELECT 1")?;
96            Ok(())
97        }
98    }
99}
100
101fn check_destination_auth(dest: &crate::config::DestinationConfig) -> Result<()> {
102    use crate::destination::create_destination;
103    let d = create_destination(dest)?;
104    let probe_key = ".rivet_doctor_probe";
105    let tmp = std::env::temp_dir().join(probe_key);
106    std::fs::write(&tmp, b"ok")?;
107    match d.write(&tmp, probe_key) {
108        Ok(()) => {
109            log::debug!("doctor: probe write succeeded, cleaning up");
110        }
111        Err(e) => {
112            let _ = std::fs::remove_file(&tmp);
113            return Err(e);
114        }
115    }
116    let _ = std::fs::remove_file(&tmp);
117    Ok(())
118}
119
120pub(super) fn categorize_source_error(err: &anyhow::Error) -> &'static str {
121    let msg = err.to_string().to_lowercase();
122    if msg.contains("password") || msg.contains("authentication") || msg.contains("access denied") {
123        "auth error"
124    } else if msg.contains("connect")
125        || msg.contains("refused")
126        || msg.contains("timed out")
127        || msg.contains("could not translate host")
128        || msg.contains("name or service not known")
129    {
130        "connectivity error"
131    } else {
132        "error"
133    }
134}
135
136pub(super) fn categorize_dest_error(
137    err: &anyhow::Error,
138    dest: &crate::config::DestinationConfig,
139) -> &'static str {
140    let msg = err.to_string().to_lowercase();
141    if msg.contains("credential")
142        || msg.contains("permission denied")
143        || msg.contains("access denied")
144        || msg.contains("unauthorized")
145        || msg.contains("forbidden")
146        || msg.contains("invalid_grant")
147        || msg.contains("token")
148    {
149        "auth error"
150    } else if msg.contains("not found") || msg.contains("nosuchbucket") || msg.contains("404") {
151        match dest.destination_type {
152            DestinationType::S3 => "bucket not found",
153            DestinationType::Gcs => "bucket not found",
154            DestinationType::Local | DestinationType::Stdout => "path not found",
155        }
156    } else if msg.contains("connect")
157        || msg.contains("refused")
158        || msg.contains("timed out")
159        || msg.contains("dns")
160        || msg.contains("endpoint")
161    {
162        "connectivity error"
163    } else {
164        "error"
165    }
166}