use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::resolve::resolve_env_vars;
use crate::tuning::{TuningConfig, TuningProfile};
// Connection settings for a database source.
//
// A source is described either by a URL (`url`, `url_env` or `url_file`) or by
// structured fields (`host`, `user`, `database`, ...). `resolve_url` rejects
// configs that mix the two styles. Keys that are not listed here are rejected
// during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[serde(deny_unknown_fields)]
pub struct SourceConfig {
// Database engine; appears under the key `type` in the serialized form.
#[serde(rename = "type")]
pub source_type: SourceType,
// Connection URL written inline in the config. `redact_for_artifact` masks
// its userinfo part.
pub url: Option<String>,
// Name of the environment variable that holds the connection URL.
pub url_env: Option<String>,
// Path of a file whose trimmed contents are the connection URL.
pub url_file: Option<String>,
// Host name for structured config (required by `build_url_from_fields`).
pub host: Option<String>,
// Port for structured config; when absent the engine default is used
// (5432 for Postgres, 3306 for MySQL, 1433 for SQL Server).
pub port: Option<u16>,
// User name for structured config (required by `build_url_from_fields`).
pub user: Option<String>,
// Inline password. It is passed through `resolve_env_vars` before use, a
// warning recommending `password_env` is logged, and `redact_for_artifact`
// strips it.
pub password: Option<String>,
// Name of the environment variable that holds the password. Mutually
// exclusive with `password`.
pub password_env: Option<String>,
// Database name for structured config (required by `build_url_from_fields`).
pub database: Option<String>,
// Optional deployment environment of the source (see `SourceEnvironment`).
// Not interpreted in this file beyond `SourceEnvironment::default_profile`.
#[serde(default)]
pub environment: Option<SourceEnvironment>,
// Optional tuning settings; consumed outside this file.
#[serde(default)]
pub tuning: Option<TuningConfig>,
// Optional TLS settings (see `TlsConfig`); consumed outside this file.
#[serde(default)]
pub tls: Option<TlsConfig>,
}
// Kind of deployment a source points at. Serialized in lowercase
// (`local`, `replica`, `production`).
#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SourceEnvironment {
// Maps to `TuningProfile::Fast` in `default_profile`.
Local,
// Maps to `TuningProfile::Balanced` in `default_profile`.
Replica,
// Maps to `TuningProfile::Balanced` in `default_profile`.
Production,
}
impl SourceEnvironment {
    /// Maps each environment to its default tuning profile: `Fast` for
    /// `Local`, `Balanced` for `Replica` and `Production`.
    pub fn default_profile(self) -> TuningProfile {
        match self {
            Self::Production | Self::Replica => TuningProfile::Balanced,
            Self::Local => TuningProfile::Fast,
        }
    }
}
// TLS settings for a source connection. Every field may be omitted; unknown
// keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Default)]
#[serde(deny_unknown_fields)]
pub struct TlsConfig {
// TLS mode; falls back to `TlsMode::VerifyFull` (the enum's default) when
// omitted.
#[serde(default)]
pub mode: TlsMode,
// Path to a CA file. NOTE(review): presumably the certificate bundle used to
// verify the server; it is consumed outside this file — confirm there.
pub ca_file: Option<String>,
// Defaults to `false`. NOTE(review): presumably turns off certificate
// validation when `true` — confirm where the TLS connector is built.
#[serde(default)]
pub accept_invalid_certs: bool,
// Defaults to `false`. NOTE(review): presumably turns off hostname
// verification when `true` — confirm where the TLS connector is built.
#[serde(default)]
pub accept_invalid_hostnames: bool,
}
// TLS mode for a source connection. Serialized in kebab-case (`disable`,
// `require`, `verify-ca`, `verify-full`); the default is `VerifyFull`.
//
// NOTE(review): the variant names mirror PostgreSQL `sslmode` values; what each
// mode actually verifies is implemented outside this file — confirm there.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
// The only mode for which `is_enforced` returns `false`.
Disable,
Require,
VerifyCa,
// Used when no mode is configured.
#[default]
VerifyFull,
}
impl TlsMode {
    /// Returns `true` for every mode except `Disable`.
    pub fn is_enforced(self) -> bool {
        // Exhaustive on purpose: a new variant must decide which side it is on.
        match self {
            TlsMode::Disable => false,
            TlsMode::Require | TlsMode::VerifyCa | TlsMode::VerifyFull => true,
        }
    }
}
impl SourceConfig {
pub fn redact_for_artifact(&self) -> (Self, bool) {
let mut out = self.clone();
let mut redacted = false;
if out.password.is_some() {
out.password = None;
redacted = true;
}
if let Some(ref raw) = out.url
&& let Some((userinfo_end, scheme_end)) = find_userinfo(raw)
{
let mut s = String::with_capacity(raw.len());
s.push_str(&raw[..scheme_end]); s.push_str("REDACTED");
s.push_str(&raw[userinfo_end..]); out.url = Some(s);
redacted = true;
}
(out, redacted)
}
/// Reports whether any structured connection field is set: `host`, `port`,
/// `user`, `database`, `password` or `password_env`.
pub(crate) fn has_structured_fields(&self) -> bool {
    // `port` counts too: `resolve_url` names it as a structured field in its
    // "not both" error, and leaving it out would let a `port` written next to
    // a URL be silently ignored.
    self.host.is_some()
        || self.port.is_some()
        || self.user.is_some()
        || self.database.is_some()
        || self.password.is_some()
        || self.password_env.is_some()
}
/// Reports whether any URL-style field (`url`, `url_env`, `url_file`) is set.
pub(crate) fn has_url_fields(&self) -> bool {
    [&self.url, &self.url_env, &self.url_file]
        .iter()
        .any(|field| field.is_some())
}
fn build_url_from_fields(&self) -> crate::error::Result<String> {
let host = self.host.as_deref().ok_or_else(|| {
anyhow::anyhow!(
"source: structured config is missing 'host'.\n Hint: add `host: localhost` (or your DB host) under `source:` in rivet.yaml.\n Or switch to URL-based config: `url_env: DATABASE_URL`."
)
})?;
let user = self.user.as_deref().ok_or_else(|| {
anyhow::anyhow!(
"source: structured config is missing 'user'.\n Hint: add `user: <username>` under `source:` in rivet.yaml."
)
})?;
let database = self.database.as_deref().ok_or_else(|| {
anyhow::anyhow!(
"source: structured config is missing 'database'.\n Hint: add `database: <dbname>` under `source:` in rivet.yaml."
)
})?;
let password: zeroize::Zeroizing<String> =
zeroize::Zeroizing::new(match (&self.password, &self.password_env) {
(Some(_), Some(_)) => {
anyhow::bail!("source: specify 'password' or 'password_env', not both");
}
(Some(p), None) => {
static WARNED: std::sync::Once = std::sync::Once::new();
WARNED.call_once(|| {
log::warn!(
"source config contains plaintext password -- consider using password_env"
);
});
resolve_env_vars(p)?
}
(None, Some(env)) => std::env::var(env).map_err(|_| {
anyhow::anyhow!(
"source: env var '{0}' is not set (referenced by password_env).\n Hint: export the value before running, e.g.\n export {0}='your-database-password'",
env
)
})?,
(None, None) => String::new(),
});
let default_port = match self.source_type {
SourceType::Postgres => 5432,
SourceType::Mysql => 3306,
SourceType::Mssql => 1433,
};
let port = self.port.unwrap_or(default_port);
let scheme = match self.source_type {
SourceType::Postgres => "postgresql",
SourceType::Mysql => "mysql",
SourceType::Mssql => "sqlserver",
};
if password.is_empty() {
Ok(format!(
"{}://{}@{}:{}/{}",
scheme, user, host, port, database
))
} else {
Ok(format!(
"{}://{}:{}@{}:{}/{}",
scheme,
user,
password.as_str(),
host,
port,
database
))
}
}
/// Resolves the connection URL for this source.
///
/// Exactly one connection method must be configured:
/// - structured fields, which are delegated to `build_url_from_fields`;
/// - `url`, used as written;
/// - `url_env`, read from the named environment variable;
/// - `url_file`, read from the named file and trimmed.
///
/// URL-based values are then passed through `resolve_env_vars`. When an inline
/// `url` ends up carrying a non-empty password, a warning is logged once per
/// process.
///
/// # Errors
///
/// Fails when URL fields and structured fields are mixed, when zero or several
/// URL fields are set, when the environment variable or file cannot be read,
/// or when `resolve_env_vars` fails.
pub fn resolve_url(&self) -> crate::error::Result<String> {
    if self.has_url_fields() && self.has_structured_fields() {
        anyhow::bail!(
            "source: pick either URL-based config (url/url_env/url_file) OR structured fields (host/user/database/port/password_env), not both.\n Hint: remove whichever block you don't want; mixing the two is ambiguous."
        );
    }
    if self.has_structured_fields() {
        return self.build_url_from_fields();
    }

    // Only an inline `url:` lives in the config file itself, so it is the only
    // origin that triggers the plaintext-password warning below.
    let (raw, is_inline) = match (&self.url, &self.url_env, &self.url_file) {
        (Some(u), None, None) => (u.clone(), true),
        (None, Some(env), None) => (
            std::env::var(env).map_err(|_| {
                anyhow::anyhow!(
                    "source: env var '{0}' is not set (referenced by url_env).\n Hint: export the value before running, e.g.\n export {0}='postgresql://user:pass@host:5432/dbname'\n Or change `url_env: {0}` in your config to a different env var name.",
                    env
                )
            })?,
            false,
        ),
        (None, None, Some(file)) => (
            std::fs::read_to_string(file)
                .map_err(|e| {
                    anyhow::anyhow!(
                        "source: cannot read url_file '{}': {}.\n Hint: ensure the file exists and is readable; the file should contain only the URL on a single line.",
                        file,
                        e
                    )
                })?
                .trim()
                .to_string(),
            false,
        ),
        _ => anyhow::bail!(
            "source: configure exactly one connection method:\n url_env: DATABASE_URL (URL from env var — recommended)\n url: 'postgresql://user:pass@host:5432/db' (inline — not recommended for committed configs)\n url_file: /etc/rivet/source.url (URL from file — rotation-friendly)\n host/user/database/... (structured fields under `source:`)"
        ),
    };
    let resolved = resolve_env_vars(&raw)?;

    if is_inline {
        // Look at the userinfo only. The scheme separator `://` always contains
        // a colon, so it must be skipped or every `user@host` URL would look
        // like it carries a password.
        let after_scheme = resolved
            .split_once("://")
            .map_or(resolved.as_str(), |(_, rest)| rest);
        let authority = after_scheme.split('/').next().unwrap_or(after_scheme);
        let has_password = authority
            .rsplit_once('@')
            .and_then(|(userinfo, _host)| userinfo.split_once(':'))
            .is_some_and(|(_user, secret)| !secret.is_empty());

        if has_password {
            static WARNED: std::sync::Once = std::sync::Once::new();
            WARNED.call_once(|| {
                log::warn!(
                    "source: inline `url:` in YAML contains a plaintext password — \
                     move it to `url_env: DATABASE_URL` (or `url_file:`) to keep \
                     credentials out of committed configs"
                );
            });
        }
    }
    Ok(resolved)
}
}
// Supported database engines. Serialized in lowercase (`postgres`, `mysql`,
// `mssql`). `build_url_from_fields` derives the URL scheme and default port
// from this value.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
// Scheme `postgresql`, default port 5432.
Postgres,
// Scheme `mysql`, default port 3306.
Mysql,
// Scheme `sqlserver`, default port 1433.
Mssql,
}
/// Locates the userinfo section (`user[:password]`) of a URL.
///
/// Returns `Some((userinfo_end, scheme_end))`, both byte offsets into `raw`:
/// `scheme_end` is the index just past `://`, where the userinfo starts, and
/// `userinfo_end` is the index of the `@` that terminates it. Returns `None`
/// when `raw` has no `://`, or when no `@` occurs before the first `/` that
/// follows the scheme.
fn find_userinfo(raw: &str) -> Option<(usize, usize)> {
    let scheme_end = raw.find("://")? + 3;
    let rest = &raw[scheme_end..];

    // Everything up to the first `/` is the authority; an `@` further right
    // belongs to the path or query and is not a userinfo delimiter.
    let authority = rest.find('/').map_or(rest, |slash| &rest[..slash]);

    // Take the LAST `@`: a password may itself contain an unescaped `@`, and
    // stopping at the first one would leave its tail outside the span.
    let at = authority.rfind('@')?;
    Some((scheme_end + at, scheme_end))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config for `source_type` with every optional field unset.
    fn make_source(source_type: SourceType) -> SourceConfig {
        SourceConfig {
            source_type,
            url: None,
            url_env: None,
            url_file: None,
            host: None,
            port: None,
            user: None,
            password: None,
            password_env: None,
            database: None,
            environment: None,
            tuning: None,
            tls: None,
        }
    }

    /// Builds a config whose only populated optional field is an inline `url`.
    fn source_with_url(source_type: SourceType, url: &str) -> SourceConfig {
        SourceConfig {
            url: Some(url.into()),
            ..make_source(source_type)
        }
    }

    // --- TlsMode -----------------------------------------------------------

    #[test]
    fn tls_mode_disable_not_enforced() {
        assert!(!TlsMode::Disable.is_enforced());
    }

    #[test]
    fn tls_mode_require_is_enforced() {
        for mode in [TlsMode::Require, TlsMode::VerifyCa, TlsMode::VerifyFull] {
            assert!(mode.is_enforced());
        }
    }

    // --- redact_for_artifact -------------------------------------------------

    #[test]
    fn redact_plaintext_password() {
        let src = SourceConfig {
            password: Some("s3cr3t".into()),
            ..make_source(SourceType::Postgres)
        };
        let (redacted, flag) = src.redact_for_artifact();
        assert!(flag, "redaction should be flagged");
        assert!(
            redacted.password.is_none(),
            "plaintext password must be stripped"
        );
    }

    #[test]
    fn redact_url_with_password() {
        let src = source_with_url(
            SourceType::Postgres,
            "postgresql://user:hunter2@db.example.com:5432/app",
        );
        let (redacted, flag) = src.redact_for_artifact();
        assert!(flag, "URL redaction flagged");
        let url = redacted.url.unwrap();
        assert!(!url.contains("hunter2"), "password must not appear: {url}");
        assert!(url.contains("REDACTED"), "placeholder must appear: {url}");
        assert!(url.contains("@db.example.com"), "host retained: {url}");
    }

    #[test]
    fn redact_url_without_at_sign_not_flagged() {
        let src = source_with_url(SourceType::Postgres, "postgresql://db.example.com:5432/app");
        let (_, flag) = src.redact_for_artifact();
        assert!(!flag, "URL with no userinfo must not be flagged");
    }

    #[test]
    fn redact_url_with_user_but_no_password_is_flagged() {
        let src = source_with_url(
            SourceType::Postgres,
            "postgresql://user@db.example.com:5432/app",
        );
        let (redacted, flag) = src.redact_for_artifact();
        assert!(flag, "bare user@ is still userinfo and gets redacted");
        let url = redacted.url.unwrap();
        assert!(url.contains("REDACTED"), "userinfo replaced: {url}");
        assert!(!url.contains("user@"), "bare username removed: {url}");
    }

    #[test]
    fn redact_env_var_reference_kept_intact() {
        let src = SourceConfig {
            url_env: Some("DB_URL".into()),
            password_env: Some("DB_PASS".into()),
            ..make_source(SourceType::Mysql)
        };
        let (redacted, flag) = src.redact_for_artifact();
        assert!(!flag, "env var references are not secrets");
        assert_eq!(redacted.url_env.as_deref(), Some("DB_URL"));
        assert_eq!(redacted.password_env.as_deref(), Some("DB_PASS"));
    }

    #[test]
    fn redact_mysql_url_with_password() {
        let src = source_with_url(SourceType::Mysql, "mysql://root:pass@127.0.0.1:3306/mydb");
        let (redacted, flag) = src.redact_for_artifact();
        assert!(flag);
        let url = redacted.url.unwrap();
        assert!(url.contains("REDACTED"), "{url}");
        assert!(!url.contains("pass"), "{url}");
    }

    // --- resolve_url ---------------------------------------------------------

    #[test]
    fn resolve_url_from_structured_fields_postgres() {
        let src = SourceConfig {
            host: Some("pg.internal".into()),
            user: Some("alice".into()),
            database: Some("warehouse".into()),
            port: Some(5433),
            ..make_source(SourceType::Postgres)
        };
        let url = src.resolve_url().unwrap();
        assert_eq!(url, "postgresql://alice@pg.internal:5433/warehouse");
    }

    #[test]
    fn resolve_url_from_structured_fields_defaults_port() {
        let src = SourceConfig {
            host: Some("my.internal".into()),
            user: Some("bob".into()),
            database: Some("orders".into()),
            ..make_source(SourceType::Mysql)
        };
        let url = src.resolve_url().unwrap();
        assert_eq!(url, "mysql://bob@my.internal:3306/orders");
    }

    #[test]
    fn resolve_url_direct_url_passthrough() {
        let src = source_with_url(
            SourceType::Postgres,
            "postgresql://carol@pg.example.com:5432/db",
        );
        let url = src.resolve_url().unwrap();
        assert_eq!(url, "postgresql://carol@pg.example.com:5432/db");
    }

    #[test]
    fn resolve_url_rejects_mixed_url_and_structured() {
        let src = SourceConfig {
            host: Some("other".into()),
            ..source_with_url(SourceType::Postgres, "postgresql://carol@pg.example.com/db")
        };
        let err = src.resolve_url().unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("URL-based") || msg.contains("structured"),
            "{msg}"
        );
    }

    #[test]
    fn resolve_url_rejects_missing_host() {
        let src = SourceConfig {
            user: Some("alice".into()),
            database: Some("warehouse".into()),
            ..make_source(SourceType::Postgres)
        };
        let err = src.resolve_url().unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("host"), "{msg}");
    }

    // --- find_userinfo -------------------------------------------------------

    #[test]
    fn find_userinfo_detects_password_in_url() {
        let found = find_userinfo("postgresql://user:pass@host/db");
        assert!(found.is_some(), "should detect user:pass@");
    }

    #[test]
    fn find_userinfo_no_password_no_at_returns_none() {
        assert!(find_userinfo("postgresql://host/db").is_none());
    }

    #[test]
    fn find_userinfo_user_only_at_sign_matches() {
        let found = find_userinfo("postgresql://user@host/db");
        assert!(found.is_some(), "bare user@ should match");
    }

    #[test]
    fn find_userinfo_no_at_sign_returns_none() {
        assert!(find_userinfo("postgresql://db.example.com:5432/app").is_none());
    }
}