use crate::config::{Config, DestinationType, SourceType};
use crate::error::Result;
pub fn doctor(config_path: &str) -> Result<()> {
println!("rivet doctor: verifying auth for config '{}'", config_path);
println!();
let config = match Config::load(config_path) {
Ok(c) => {
println!("[OK] Config parsed successfully");
c
}
Err(e) => {
println!("[FAIL] Config error: {}", trim_probe_error(&e));
anyhow::bail!("doctor: config check failed (see [FAIL] above)")
}
};
let mut all_ok = true;
match check_source_auth(&config) {
Ok(()) => {
println!("[OK] Source auth ({:?})", config.source.source_type);
note_mssql_harm_permission(&config);
}
Err(e) => {
all_ok = false;
let category = categorize_source_error(&e);
println!("[FAIL] Source {}: {}", category, trim_probe_error(&e));
if let Some(hint) = source_error_hint(category, &e, &config.source.source_type) {
println!(" Hint: {}", hint);
}
}
}
let mut seen_destinations: Vec<String> = Vec::new();
for export in &config.exports {
let dest_key = super::destination_identity(&export.destination);
if seen_destinations.contains(&dest_key) {
continue;
}
seen_destinations.push(dest_key);
let label = match export.destination.destination_type {
DestinationType::Local => format!(
"Local({})",
export.destination.path.as_deref().unwrap_or(".")
),
DestinationType::S3 => format!(
"S3({})",
export.destination.bucket.as_deref().unwrap_or("?")
),
DestinationType::Gcs => format!(
"GCS({})",
export.destination.bucket.as_deref().unwrap_or("?")
),
DestinationType::Azure => format!(
"Azure({})",
export.destination.bucket.as_deref().unwrap_or("?")
),
DestinationType::Stdout => {
println!("[OK] Destination Stdout (streaming; no preflight needed)");
continue;
}
};
let expanded_dest = crate::plan::build::expand_destination_templates(
export.destination.clone(),
&export.name,
);
match check_destination_auth(&expanded_dest) {
Ok(()) => println!("[OK] Destination {}", label),
Err(e) => {
all_ok = false;
let category = categorize_dest_error(&e, &expanded_dest);
println!(
"[FAIL] Destination {} -- {}: {}",
label,
category,
trim_probe_error(&e)
);
if let Some(hint) = destination_error_hint(category, &expanded_dest) {
println!(" Hint: {}", hint);
}
}
}
}
println!();
if all_ok {
println!("All checks passed.");
println!("Next: rivet check -c {config_path} # column-type & schema report");
Ok(())
} else {
println!("Some checks failed. Fix the issues above before running exports.");
anyhow::bail!("doctor: one or more preflight checks failed (see output above)")
}
}
/// Opens a connection to the configured source to prove the credentials work.
///
/// Postgres and MySQL additionally round-trip a trivial `SELECT 1`. For SQL
/// Server, establishing the connection is the whole check; no query is run.
/// A TLS-disabled warning is emitted (via `warn_if_tls_disabled`) once the
/// URL has resolved.
///
/// # Errors
/// Propagates URL-resolution, connection and query errors unchanged so the
/// caller can categorize them.
fn check_source_auth(config: &Config) -> Result<()> {
    let source = &config.source;
    let url = source.resolve_url()?;
    let tls = source.tls.as_ref();
    crate::source::warn_if_tls_disabled(source);

    match source.source_type {
        SourceType::Postgres => {
            crate::source::postgres::connect_client(&url, tls)?.simple_query("SELECT 1")?;
        }
        SourceType::Mysql => {
            use mysql::prelude::Queryable;
            let pool = crate::source::mysql::connect_pool(&url, tls)?;
            pool.get_conn()?.query_drop("SELECT 1")?;
        }
        SourceType::Mssql => {
            crate::source::mssql::MssqlSource::connect_with_tls(&url, tls)?;
        }
    }
    Ok(())
}
/// Prints an advisory `[note]` when an MSSQL source login lacks
/// `VIEW SERVER STATE`.
///
/// Best-effort and silent otherwise. It does nothing for non-MSSQL sources or
/// when the source URL cannot be resolved. It prints only when
/// `sample_view_server_state` reports `Some(false)`.
fn note_mssql_harm_permission(config: &Config) {
    let source = &config.source;
    if !matches!(source.source_type, SourceType::Mssql) {
        return;
    }
    let has_permission = match source.resolve_url() {
        Ok(url) => crate::source::mssql::sample_view_server_state(&url, source.tls.as_ref()),
        Err(_) => return,
    };
    if matches!(has_permission, Some(false)) {
        println!(
            "[note] Source-harm metrics need VIEW SERVER STATE — this SQL Server login lacks it, \
             so lock-wait metrics will be skipped. Data extraction is unaffected. \
             Grant with: GRANT VIEW SERVER STATE TO [your_login];"
        );
    }
}
/// Verifies that a destination accepts writes by uploading a tiny probe object.
///
/// The payload (`ok`) is first staged in a uniquely named local file under
/// the system temp directory. It is then written to the destination under the
/// fixed key `crate::manifest::DOCTOR_PROBE_FILENAME`. The staged file is
/// removed on every exit path. After a successful write, the destination-side
/// probe is handed to `remove_destination_probe` for cleanup.
///
/// # Errors
/// Returns the error from building the destination client, staging the
/// local file, or performing the probe write.
fn check_destination_auth(dest: &crate::config::DestinationConfig) -> Result<()> {
    use crate::destination::create_destination_for_probe;
    use std::io::Write as _;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Deletes the staged local probe file when dropped, so success and every
    /// early `?` return clean up alike.
    struct StagedProbe(std::path::PathBuf);
    impl Drop for StagedProbe {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    // Distinguishes concurrent probes inside one process (e.g. parallel tests).
    static PROBE_SEQ: AtomicU64 = AtomicU64::new(0);

    let d = create_destination_for_probe(dest)?;
    let probe_key = crate::manifest::DOCTOR_PROBE_FILENAME;

    // A fixed file name in the shared temp dir collides with concurrent doctor
    // runs (one run deletes the file another is about to upload). On multi-user
    // hosts it can also be pre-created or symlinked by someone else. Use a
    // per-process, per-call name, and open it with `create_new`, which refuses
    // any existing file or symlink at that path.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    let staged_path = std::env::temp_dir().join(format!(
        "{probe_key}.{}.{}.{nanos}",
        std::process::id(),
        PROBE_SEQ.fetch_add(1, Ordering::Relaxed),
    ));
    let mut file = std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&staged_path)
        .map_err(|e| {
            // Keep the io::ErrorKind and the OS text (categorization matches
            // on it), but say which local file failed.
            std::io::Error::new(
                e.kind(),
                format!(
                    "could not stage doctor probe at {}: {e}",
                    staged_path.display()
                ),
            )
        })?;
    // Guard only after our own successful create, so a pre-existing file at
    // this path is never deleted.
    let staged = StagedProbe(staged_path);
    file.write_all(b"ok")?;
    drop(file);

    d.write(&staged.0, probe_key)?;
    log::debug!("doctor: probe write succeeded, cleaning up");
    drop(staged);
    remove_destination_probe(dest, probe_key);
    Ok(())
}
/// Base directory of a local destination. This is `path` when set, otherwise
/// `prefix`, otherwise the current directory (`"."`).
fn local_base_path(dest: &crate::config::DestinationConfig) -> String {
    match (&dest.path, &dest.prefix) {
        (Some(explicit), _) => explicit.clone(),
        (None, Some(fallback)) => fallback.clone(),
        (None, None) => String::from("."),
    }
}
/// Best-effort removal of the doctor probe object `probe_key` from `dest`.
///
/// - Local: deletes `<local_base_path>/<probe_key>`. A missing file is
///   ignored; other failures are logged as warnings, never returned.
/// - Stdout: nothing to do.
/// - S3 / GCS / Azure: the probe is left in place and only a debug line is
///   logged.
fn remove_destination_probe(dest: &crate::config::DestinationConfig, probe_key: &str) {
    match dest.destination_type {
        DestinationType::Stdout => {}
        DestinationType::S3 | DestinationType::Gcs | DestinationType::Azure => log::debug!(
            "doctor: destination probe '{probe_key}' left at the {:?} prefix \
             (no object-delete on this backend); manifest reconcile filters it from listings",
            dest.destination_type
        ),
        DestinationType::Local => {
            let base = local_base_path(dest);
            let target = std::path::Path::new(&base).join(probe_key);
            match std::fs::remove_file(&target) {
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                Err(err) => log::warn!(
                    "doctor: could not remove destination probe {} (left at prefix): {err}",
                    target.display()
                ),
                Ok(()) => log::debug!("doctor: removed destination probe {}", target.display()),
            }
        }
    }
}
/// Condenses an error into text suitable for a `[FAIL]` line.
///
/// The full context chain (`{:#}`) is flattened to one line and cut at the
/// earliest marker that introduces an HTTP/SDK dump (`context: {`, `Parts {`,
/// `headers: {`, `, response:`), matched case-insensitively. Trailing spaces
/// and commas are then stripped. When no marker is present, the chain is kept
/// verbatim, newlines included, apart from trimming surrounding whitespace,
/// so multi-line guidance survives. The result is capped at 1200 characters,
/// with `…` appended when truncated.
fn trim_probe_error(err: &anyhow::Error) -> String {
    const MAX: usize = 1200;
    const DUMP_MARKERS: [&str; 6] = [
        ", context: {",
        " context: {",
        " parts {",
        ", headers: {",
        " headers: {",
        ", response:",
    ];

    let chain = format!("{err:#}");
    let one_line = chain.replace(['\n', '\r'], " ");
    // ASCII lowercasing keeps byte offsets, so indices found in `folded` are
    // valid for slicing `one_line`.
    let folded = one_line.to_ascii_lowercase();
    let first_dump = DUMP_MARKERS
        .iter()
        .filter_map(|marker| folded.find(marker))
        .min();

    let condensed = if let Some(at) = first_dump {
        one_line[..at].trim_end_matches([' ', ',']).to_owned()
    } else {
        chain.trim().to_owned()
    };

    if condensed.chars().count() <= MAX {
        return condensed;
    }
    let mut capped: String = condensed.chars().take(MAX).collect();
    capped.push('…');
    capped
}
/// Buckets a source-connection error into `"auth error"`,
/// `"connectivity error"`, or `"error"`.
///
/// The whole context chain (`{:#}`) is searched case-insensitively, so a
/// root cause wrapped in context is still recognized. Auth needles are tested
/// first and win over connectivity needles.
pub(crate) fn categorize_source_error(err: &anyhow::Error) -> &'static str {
    /// True when `haystack` contains any of `needles`.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    // "db error" is what a Postgres server-side failure (e.g. a wrong
    // password) surfaces as at the top level.
    const AUTH: &[&str] = &[
        "password",
        "authentication",
        "access denied",
        "login failed",
        "db error",
    ];
    const CONNECTIVITY: &[&str] = &[
        "connect",
        "refused",
        "timed out",
        "could not translate host",
        "name or service not known",
    ];

    let chain = format!("{err:#}").to_lowercase();
    if mentions_any(&chain, AUTH) {
        return "auth error";
    }
    if mentions_any(&chain, CONNECTIVITY) {
        return "connectivity error";
    }
    "error"
}
/// Buckets a destination-probe error into a category string consumed by
/// `destination_error_hint`.
///
/// The whole context chain (`{:#}`) is searched case-insensitively, and
/// checks run in priority order:
/// 1. `"sas expired"` when the text mentions both "already expired" and "sas";
/// 2. `"permission error"` for a filesystem permission denial on a Local/Stdout
///    destination, so it is not reported as an auth problem;
/// 3. `"auth error"` for credential/permission/403-style failures;
/// 4. a not-found category phrased per backend (bucket / container / path);
/// 5. `"connectivity error"` for network/endpoint failures;
/// 6. `"error"` otherwise.
pub(super) fn categorize_dest_error(
    err: &anyhow::Error,
    dest: &crate::config::DestinationConfig,
) -> &'static str {
    /// True when `haystack` contains any of `needles`.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    const FS_PERMISSION: &[&str] = &["permission denied", "permissiondenied", "os error 13"];
    const AUTH: &[&str] = &[
        "credential",
        "permission denied",
        "permissiondenied",
        "access denied",
        "unauthorized",
        "forbidden",
        "invalid_grant",
        "token",
        "invalidaccesskeyid",
        "403",
    ];
    const MISSING: &[&str] = &["not found", "nosuchbucket", "404"];
    const CONNECTIVITY: &[&str] = &[
        "connect",
        "refused",
        "timed out",
        "dns",
        "endpoint",
        "error sending request",
        "send http request",
    ];

    let chain = format!("{err:#}").to_lowercase();
    let filesystem_backed = matches!(
        dest.destination_type,
        DestinationType::Local | DestinationType::Stdout
    );

    if chain.contains("already expired") && chain.contains("sas") {
        return "sas expired";
    }
    if filesystem_backed && mentions_any(&chain, FS_PERMISSION) {
        return "permission error";
    }
    if mentions_any(&chain, AUTH) {
        return "auth error";
    }
    if mentions_any(&chain, MISSING) {
        return match dest.destination_type {
            DestinationType::S3 | DestinationType::Gcs => "bucket not found",
            DestinationType::Azure => "container not found",
            DestinationType::Local | DestinationType::Stdout => "path not found",
        };
    }
    if mentions_any(&chain, CONNECTIVITY) {
        "connectivity error"
    } else {
        "error"
    }
}
pub(crate) fn source_error_hint(
category: &'static str,
err: &anyhow::Error,
source_type: &crate::config::SourceType,
) -> Option<&'static str> {
use crate::config::SourceType;
let msg = err.to_string().to_lowercase();
if msg.contains("tls")
|| msg.contains("ssl")
|| msg.contains("certificate")
|| msg.contains("handshake")
{
return Some(match source_type {
SourceType::Postgres => {
"TLS handshake failed. Try `tls.mode: prefer` (downgrade gracefully) or set `tls.ca_file: /path/to/ca-bundle.pem` if your DB uses a private CA."
}
SourceType::Mysql => {
"TLS handshake failed. Try `tls.mode: prefer` or set `tls.ca_file: /path/to/ca-bundle.pem` to trust the DB's certificate authority."
}
SourceType::Mssql => {
"TLS handshake failed. SQL Server forces TLS on the login handshake; set `tls.ca_file: /path/to/ca-bundle.pem` to trust a private CA, or `tls.accept_invalid_certs: true` for a self-signed dev cert."
}
});
}
match category {
"auth error" => Some(match source_type {
SourceType::Postgres => {
"Verify the user/password and that pg_hba.conf permits your client IP. The user also needs SELECT on the target tables and USAGE on the schema."
}
SourceType::Mysql => {
"Verify the user/password and that the user has SELECT grants on the target tables. MySQL `GRANT SELECT ON db.* TO 'user'@'host'` plus `FLUSH PRIVILEGES`."
}
SourceType::Mssql => {
"Verify the SQL login/password and that the login maps to a database user with SELECT on the target tables (`GRANT SELECT ON dbo.tbl TO [user]`). Check you are pointed at the right database — contained-DB users and server logins are resolved differently."
}
}),
"connectivity error" => Some(
"Verify host/port reachability from this machine. If the DB is behind a bastion or VPN, ensure the tunnel is up before running rivet. `rivet doctor` must run from the same network as `rivet run` will.",
),
_ => None,
}
}
/// Picks an actionable next step for a failed destination probe.
///
/// `category` is a string produced by `categorize_dest_error`. Most advice
/// also depends on the destination backend. Returns `None` for categories
/// without specific advice (e.g. the generic `"error"`).
pub(super) fn destination_error_hint(
    category: &'static str,
    dest: &crate::config::DestinationConfig,
) -> Option<&'static str> {
    const FS_PERMISSIONS: &str = "Verify filesystem permissions on the destination directory.";

    let backend = &dest.destination_type;
    let hint = match category {
        "sas expired" => {
            "Azure SAS token is expired or near-expiry. Generate a new SAS via `az storage container generate-sas --permissions rwdlc --expiry <future-date>` and re-export AZURE_STORAGE_SAS_TOKEN."
        }
        "permission error" => FS_PERMISSIONS,
        "auth error" => match backend {
            DestinationType::S3 => {
                "Verify AWS credentials resolve (env / profile / instance role) and that the role has s3:PutObject + s3:GetObject + s3:ListBucket on the prefix. See docs/cloud-permissions.md."
            }
            DestinationType::Gcs => {
                "Verify the service account credentials resolve (ADC / env / explicit credentials_file) and that the principal has storage.objects.{create,get,list} on the bucket. See docs/cloud-permissions.md."
            }
            DestinationType::Azure => {
                "Verify Azure credentials. Account-key auth: check account_key_env. SAS auth: regenerate the SAS with rwdlc permissions and a future expiry. See docs/cloud-permissions.md."
            }
            DestinationType::Local | DestinationType::Stdout => FS_PERMISSIONS,
        },
        "bucket not found" | "container not found" => match backend {
            DestinationType::S3 => {
                "Bucket must already exist; rivet does NOT auto-create. `aws s3 mb s3://<bucket>` (with the right region) before running."
            }
            DestinationType::Gcs => {
                "Bucket must already exist; rivet does NOT auto-create. `gcloud storage buckets create gs://<bucket>` before running."
            }
            DestinationType::Azure => {
                "Container must already exist; rivet does NOT auto-create. `az storage container create --account-name <acct> --name <container>` before running."
            }
            DestinationType::Local | DestinationType::Stdout => {
                "Path / bucket / container must already exist."
            }
        },
        "connectivity error" => match backend {
            DestinationType::S3 => {
                "Verify endpoint and region. For non-AWS endpoints (MinIO / R2 / Wasabi) set `endpoint:` explicitly. For AWS, ensure `region:` matches the bucket's region — cross-region writes fail with a confusing redirect error."
            }
            DestinationType::Gcs => {
                "Verify network reachability to storage.googleapis.com. If using a custom endpoint, set `endpoint:` explicitly."
            }
            DestinationType::Azure => {
                "Verify network reachability to <account>.blob.core.windows.net. For Azurite or sovereign clouds, set `endpoint:` explicitly."
            }
            DestinationType::Local | DestinationType::Stdout => {
                "Verify network reachability to the destination."
            }
        },
        "path not found" => {
            "Parent directory must exist. Create it with `mkdir -p` before running, or use a different `path:` in your config."
        }
        _ => return None,
    };
    Some(hint)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roast_doctor_write_probes_each_distinct_local_destination_path() {
        let dir_a = tempfile::tempdir().unwrap();
        let dir_b = tempfile::tempdir().unwrap();
        let config_dir = tempfile::tempdir().unwrap();
        let leaf_a = dir_a.path().join("probe_here");
        let leaf_b = dir_b.path().join("probe_here");
        // Paths go into single-quoted YAML scalars: in double quotes, Windows
        // backslashes (`\U`, `\A`, ...) are parsed as escape sequences, the
        // config fails to load, and the test fails for the wrong reason.
        let yaml = format!(
            r#"
source:
  type: postgres
  url_env: RIVET_ROAST_DOCTOR_DEDUP_UNSET_URL_ENV
exports:
  - name: roast_dest_a
    query: "SELECT 1"
    format: csv
    destination:
      type: local
      path: '{a}'
  - name: roast_dest_b
    query: "SELECT 1"
    format: csv
    destination:
      type: local
      path: '{b}'
"#,
            a = leaf_a.display(),
            b = leaf_b.display(),
        );
        let config_path = config_dir.path().join("rivet.yaml");
        std::fs::write(&config_path, yaml).unwrap();
        // The source check fails (its URL env var is unset), so doctor returns
        // Err; only the destination side-effects are under test here.
        let _ = doctor(config_path.to_str().unwrap());
        let probe = crate::manifest::DOCTOR_PROBE_FILENAME;
        for (label, leaf) in [("first", &leaf_a), ("second", &leaf_b)] {
            assert!(
                leaf.exists(),
                "doctor never write-probed the {label} local destination {} — its dedup key \
                 must include `path`; a key that omits it collapses both local destinations to \
                 one entry and only probes the first, so an unwritable second directory would \
                 pass doctor and fail at run time",
                leaf.display()
            );
            assert!(
                !leaf.join(probe).exists(),
                "doctor left its write-probe `{probe}` at the {label} destination {} \
                 (FINDING #26: it must remove the destination-side probe, not only the local temp)",
                leaf.display()
            );
            assert!(
                std::fs::read_dir(leaf).unwrap().next().is_none(),
                "doctor must leave the {label} destination {} exactly as it created it (empty)",
                leaf.display()
            );
        }
    }

    /// Builds a destination config of type `t` with every other field defaulted.
    fn dest_of(t: DestinationType) -> crate::config::DestinationConfig {
        crate::config::DestinationConfig {
            destination_type: t,
            ..Default::default()
        }
    }

    #[test]
    fn audit_pg_db_error_is_auth_with_hint() {
        let err = anyhow::anyhow!("db error");
        let cat = categorize_source_error(&err);
        assert_eq!(
            cat, "auth error",
            "Postgres wrong-password surfaces as 'db error'; categorizer returned {:?} instead of 'auth error'",
            cat
        );
        let hint = source_error_hint(cat, &err, &SourceType::Postgres);
        assert!(
            hint.is_some(),
            "no actionable hint produced for Postgres 'db error' (category {:?}); operator gets no next step",
            cat
        );
    }

    #[test]
    fn audit_mssql_login_failed_is_auth_with_hint() {
        let err = anyhow::anyhow!("login failed for user 'sa'");
        let cat = categorize_source_error(&err);
        assert_eq!(
            cat, "auth error",
            "MSSQL bad login surfaces as 'Login failed for user ...'; categorizer returned {:?} instead of 'auth error'",
            cat
        );
        let hint = source_error_hint(cat, &err, &SourceType::Mssql);
        assert!(
            hint.is_some(),
            "no actionable hint produced for MSSQL 'login failed for user' (category {:?})",
            cat
        );
    }

    #[test]
    fn audit_mysql_access_denied_is_auth_with_hint() {
        let err = anyhow::anyhow!("access denied for user");
        let cat = categorize_source_error(&err);
        assert_eq!(
            cat, "auth error",
            "MySQL 'access denied for user' must stay auth; categorizer returned {:?}",
            cat
        );
        let hint = source_error_hint(cat, &err, &SourceType::Mysql);
        assert!(
            hint.is_some(),
            "no actionable hint produced for MySQL 'access denied for user' (category {:?})",
            cat
        );
    }

    #[test]
    fn audit_s3_permission_denied_403_is_auth_with_hint() {
        let dest = dest_of(DestinationType::S3);
        let err = anyhow::anyhow!(
            "PermissionDenied at write => InvalidAccessKeyId, status: 403, https://bucket.s3.amazonaws.com/probe"
        );
        let cat = categorize_dest_error(&err, &dest);
        assert_eq!(
            cat, "auth error",
            "S3 'PermissionDenied/InvalidAccessKeyId/403' must categorize as auth; categorizer returned {:?}",
            cat
        );
        let hint = destination_error_hint(cat, &dest);
        assert!(
            hint.is_some(),
            "no actionable hint produced for S3 auth failure (category {:?}); operator gets no next step",
            cat
        );
    }

    #[test]
    fn audit_azure_send_request_error_is_connectivity_with_hint() {
        let dest = dest_of(DestinationType::Azure);
        let err = anyhow::anyhow!(
            "error sending request for url (https://x.blob.core.windows.net/probe)"
        );
        let cat = categorize_dest_error(&err, &dest);
        assert_eq!(
            cat, "connectivity error",
            "Azure 'error sending request for url' must categorize as connectivity; categorizer returned {:?}",
            cat
        );
        let hint = destination_error_hint(cat, &dest);
        assert!(
            hint.is_some(),
            "no actionable hint produced for Azure connectivity failure (category {:?})",
            cat
        );
    }

    #[test]
    fn audit_dest_connection_refused_is_connectivity_with_hint() {
        let dest = dest_of(DestinationType::S3);
        let err = anyhow::anyhow!("connection refused");
        let cat = categorize_dest_error(&err, &dest);
        assert_eq!(
            cat, "connectivity error",
            "'connection refused' must stay connectivity; categorizer returned {:?}",
            cat
        );
        let hint = destination_error_hint(cat, &dest);
        assert!(
            hint.is_some(),
            "no actionable hint produced for 'connection refused' (category {:?})",
            cat
        );
    }

    #[test]
    fn source_pg_nested_password_cause_via_alternate_is_auth() {
        let root = anyhow::anyhow!("password authentication failed for user \"rivet\"");
        let wrapped = root.context("db error");
        assert_eq!(categorize_source_error(&wrapped), "auth error");
        assert_eq!(
            categorize_source_error(&anyhow::anyhow!("db error")),
            "auth error"
        );
    }

    #[test]
    fn source_connection_refused_stays_connectivity_not_auth() {
        let err = anyhow::anyhow!("error connecting to server: Connection refused (os error 61)");
        assert_eq!(categorize_source_error(&err), "connectivity error");
    }

    #[test]
    fn dest_no_space_permissiondenied_is_auth() {
        let dest = dest_of(DestinationType::Gcs);
        let err = anyhow::anyhow!("PermissionDenied (persistent) at write");
        assert_eq!(categorize_dest_error(&err, &dest), "auth error");
        assert!(destination_error_hint("auth error", &dest).is_some());
    }

    #[test]
    fn dest_send_http_request_is_connectivity() {
        let dest = dest_of(DestinationType::S3);
        let err = anyhow::anyhow!("failed to send http request to the store");
        assert_eq!(categorize_dest_error(&err, &dest), "connectivity error");
    }

    #[test]
    fn dest_404_stays_bucket_not_found_after_403_needle_added() {
        let dest = dest_of(DestinationType::S3);
        let err = anyhow::anyhow!("NoSuchBucket, status: 404");
        assert_eq!(categorize_dest_error(&err, &dest), "bucket not found");
    }

    #[test]
    fn remove_destination_probe_local_deletes_the_probe_object() {
        let dir = tempfile::tempdir().unwrap();
        let probe_key = crate::manifest::DOCTOR_PROBE_FILENAME;
        std::fs::write(dir.path().join(probe_key), b"ok").unwrap();
        let dest = crate::config::DestinationConfig {
            destination_type: DestinationType::Local,
            path: Some(dir.path().to_string_lossy().into_owned()),
            ..Default::default()
        };
        remove_destination_probe(&dest, probe_key);
        assert!(
            std::fs::read_dir(dir.path()).unwrap().next().is_none(),
            "destination prefix must be left exactly as doctor found it (empty)"
        );
    }

    #[test]
    fn remove_destination_probe_local_uses_prefix_when_path_unset() {
        let dir = tempfile::tempdir().unwrap();
        let probe_key = crate::manifest::DOCTOR_PROBE_FILENAME;
        std::fs::write(dir.path().join(probe_key), b"ok").unwrap();
        let dest = crate::config::DestinationConfig {
            destination_type: DestinationType::Local,
            prefix: Some(dir.path().to_string_lossy().into_owned()),
            ..Default::default()
        };
        remove_destination_probe(&dest, probe_key);
        assert!(
            !dir.path().join(probe_key).exists(),
            "cleanup must follow the same base-path resolution as the writer"
        );
    }

    #[test]
    fn remove_destination_probe_missing_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let dest = crate::config::DestinationConfig {
            destination_type: DestinationType::Local,
            path: Some(dir.path().to_string_lossy().into_owned()),
            ..Default::default()
        };
        remove_destination_probe(&dest, crate::manifest::DOCTOR_PROBE_FILENAME);
        assert!(std::fs::read_dir(dir.path()).unwrap().next().is_none());
    }

    #[test]
    fn local_permission_denied_is_permission_error_not_auth() {
        let dest = dest_of(DestinationType::Local);
        let err = anyhow::anyhow!("Permission denied (os error 13)");
        let cat = categorize_dest_error(&err, &dest);
        assert_eq!(
            cat, "permission error",
            "a local FS `os error 13` is a directory-permission problem, not auth; got {cat:?}"
        );
        assert!(
            destination_error_hint(cat, &dest).is_some(),
            "permission error must still surface the filesystem-permissions hint"
        );
    }

    #[test]
    fn cloud_permission_denied_stays_auth_error() {
        for t in [
            DestinationType::S3,
            DestinationType::Gcs,
            DestinationType::Azure,
        ] {
            let dest = dest_of(t);
            let err = anyhow::anyhow!("PermissionDenied at write");
            assert_eq!(
                categorize_dest_error(&err, &dest),
                "auth error",
                "cloud permission denial must remain auth for {t:?}"
            );
        }
    }

    #[test]
    fn trim_probe_error_strips_http_response_parts_and_headers() {
        let raw = "PermissionDenied (persistent) at write, context: { uri: https://b.s3.amazonaws.com/probe, response: Parts { status: 403, version: HTTP/1.1, headers: {\"x-amz-request-id\": \"ABC123\", \"content-type\": \"application/xml\"} }, service: s3 } => InvalidAccessKeyId";
        let err = anyhow::anyhow!(raw);
        let out = trim_probe_error(&err);
        assert!(
            !out.contains("Parts {") && !out.to_lowercase().contains("headers: {"),
            "trimmed error still leaks the HTTP response dump: {out:?}"
        );
        assert!(
            !out.contains('\n'),
            "trimmed error must be a single line: {out:?}"
        );
        assert!(
            out.starts_with("PermissionDenied (persistent) at write"),
            "trimmed error must keep the meaningful root-cause prefix: {out:?}"
        );
    }

    #[test]
    fn trim_probe_error_leaves_clean_message_intact() {
        let err = anyhow::anyhow!("error connecting to server: Connection refused (os error 61)");
        assert_eq!(
            trim_probe_error(&err),
            "error connecting to server: Connection refused (os error 61)"
        );
    }

    #[test]
    fn trim_probe_error_caps_unbounded_line() {
        let err = anyhow::anyhow!("x".repeat(5000));
        let out = trim_probe_error(&err);
        assert!(
            out.chars().count() <= 1201,
            "line not capped: {} chars",
            out.chars().count()
        );
        assert!(
            out.ends_with('…'),
            "capped line must signal truncation: {out:?}"
        );
    }

    #[test]
    fn trim_probe_error_preserves_multiline_hint_verbatim() {
        let err = anyhow::anyhow!(
            "chunked mode needs one of:\n - chunk_column: <int col>\n - chunk_by_key: <col>\n - chunk_count: <n>"
        );
        let out = trim_probe_error(&err);
        assert!(out.contains("chunk_column"), "got: {out:?}");
        assert!(out.contains("chunk_by_key"), "got: {out:?}");
        assert!(
            out.contains("chunk_count"),
            "all options preserved: {out:?}"
        );
        assert!(out.contains('\n'), "newlines preserved: {out:?}");
    }

    #[test]
    fn config_load_failure_returns_pointer_not_duplicate_message() {
        let dir = tempfile::tempdir().unwrap();
        let cfg = dir.path().join("rivet.yaml");
        std::fs::write(&cfg, "source: not-a-mapping\n").unwrap();
        let err = doctor(cfg.to_str().unwrap())
            .expect_err("doctor must return Err when the config fails to load");
        let msg = err.to_string();
        assert!(
            msg.contains("doctor: config check failed") && msg.contains("[FAIL]"),
            "returned error must be the one-line pointer (so `main` does not double-print the \
             config error); got {msg:?}"
        );
    }
}