use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use narwhal_core::PreConnectStep;
use thiserror::Error;
use tokio::process::Command;
use tokio::time::timeout;
/// Timeout, in seconds, applied to a step that does not set its own `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u32 = 30;
/// Failure modes of a single pre-connect step.
///
/// Every variant carries the command line of the step that failed so the
/// rendered message identifies which step went wrong.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum PreConnectError {
/// Waiting for the command's output produced an I/O error.
/// Fields: the command line and the underlying I/O error.
#[error("pre-connect step `{0}`: spawn failed: {1}")]
Spawn(String, std::io::Error),
/// The command ran but its exit status was not a success.
#[error("pre-connect step `{command}`: exited with status {status} — {stderr}")]
NonZero {
// Command line of the failing step.
command: String,
// Exit code of the process; `-1` when the status carries no exit code.
status: i32,
// Captured stderr, lossily decoded as UTF-8 and trimmed.
stderr: String,
},
/// The command did not finish within its time limit.
/// Fields: the command line and the limit in seconds.
#[error("pre-connect step `{0}`: timed out after {1}s")]
Timeout(String, u32),
/// The command succeeded but its stdout could not be decoded as UTF-8.
#[error("pre-connect step `{0}`: stdout was not valid UTF-8")]
NotUtf8(String),
}
/// Runs every pre-connect step in order and collects the outputs they save.
///
/// For each step that succeeds and names a `save_output_to` key, the step's
/// trimmed stdout is stored in the returned map under that key.
///
/// # Errors
///
/// Returns the error of the first failing step whose `required` flag is set.
/// A failing step that is not required only emits a warning and the sequence
/// carries on with the next step.
pub async fn run_pre_connect(
    steps: &[PreConnectStep],
) -> Result<HashMap<String, String>, PreConnectError> {
    let mut captured = HashMap::new();
    for step in steps {
        let stdout = match run_one(step).await {
            Ok(stdout) => stdout,
            // A required step failing aborts the whole sequence.
            Err(error) if step.required => return Err(error),
            Err(error) => {
                tracing::warn!(
                    target: "narwhal::pre_connect",
                    error = %error,
                    "non-required pre-connect step failed; continuing"
                );
                continue;
            }
        };
        // Steps without a target key still run; their output is discarded.
        if let Some(key) = step.save_output_to.as_deref() {
            captured.insert(key.to_owned(), stdout);
        }
    }
    Ok(captured)
}
/// Executes one step through the platform shell and returns its trimmed stdout.
///
/// The step's own `timeout_secs` is used when present, otherwise
/// `DEFAULT_TIMEOUT_SECS`.
///
/// # Errors
///
/// * `Timeout` when the command does not finish within the limit.
/// * `Spawn` when waiting for the command's output yields an I/O error.
/// * `NonZero` when the exit status is not a success; stderr is decoded
///   lossily and trimmed, and a missing exit code is reported as `-1`.
/// * `NotUtf8` when a successful command's stdout is not valid UTF-8.
async fn run_one(step: &PreConnectStep) -> Result<String, PreConnectError> {
    let limit_secs = step.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
    let limit = Duration::from_secs(u64::from(limit_secs));
    let mut shell = shell_command(&step.command);
    let finished = timeout(limit, shell.output())
        .await
        .map_err(|_| PreConnectError::Timeout(step.command.clone(), limit_secs))?
        .map_err(|io| PreConnectError::Spawn(step.command.clone(), io))?;
    if finished.status.success() {
        return match String::from_utf8(finished.stdout) {
            Ok(text) => Ok(text.trim().to_owned()),
            Err(_) => Err(PreConnectError::NotUtf8(step.command.clone())),
        };
    }
    Err(PreConnectError::NonZero {
        command: step.command.clone(),
        status: finished.status.code().unwrap_or(-1),
        stderr: String::from_utf8_lossy(&finished.stderr).trim().to_owned(),
    })
}
/// Builds the Unix shell invocation for a pre-connect command line.
///
/// The line is passed to `sh -c`, stdin is connected to the null device and
/// kill-on-drop is enabled on the command.
#[cfg(unix)]
fn shell_command(line: &str) -> Command {
    let mut shell = Command::new("sh");
    shell.args(["-c", line]);
    shell.stdin(Stdio::null());
    shell.kill_on_drop(true);
    shell
}
/// Builds the Windows shell invocation for a pre-connect command line.
///
/// The line is passed to `cmd /C`, stdin is connected to the null device and
/// kill-on-drop is enabled on the command.
// NOTE(review): `arg` applies the standard Windows argument quoting, while
// `cmd.exe` parses its command line with its own rules, so lines containing
// double quotes may not reach the shell unchanged. Confirm, and consider
// `raw_arg` if the tokio version in use provides it.
#[cfg(windows)]
fn shell_command(line: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", line]);
    shell.stdin(Stdio::null());
    shell.kill_on_drop(true);
    shell
}
/// Expands `${preconnect:NAME}` placeholders in the connection parameters.
///
/// The `host`, `database`, `username` and `path` fields are processed in that
/// order, followed by every value in `options`. Each placeholder is replaced
/// by the value stored under `NAME` in `vars`.
///
/// # Errors
///
/// Stops at the first field that fails to expand and returns its
/// [`SubstitutionError`]; fields processed before that point stay expanded.
pub fn substitute_pre_connect(
    params: &mut narwhal_core::ConnectionParams,
    vars: &HashMap<String, String>,
) -> Result<(), SubstitutionError> {
    let slots = [
        &mut params.host,
        &mut params.database,
        &mut params.username,
        &mut params.path,
    ];
    for slot in slots {
        substitute_opt(slot, vars)?;
    }
    for entry in params.options.values_mut() {
        *entry = substitute_str(entry, vars)?;
    }
    Ok(())
}
pub fn substitute_password(
password: Option<String>,
vars: &HashMap<String, String>,
) -> Result<Option<String>, SubstitutionError> {
let Some(pw) = password else {
return Ok(None);
};
if !pw.contains("${preconnect:") {
return Ok(Some(pw));
}
Ok(Some(substitute_str(&pw, vars)?))
}
/// Errors raised while expanding `${preconnect:NAME}` placeholders.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum SubstitutionError {
/// A placeholder names a key that is absent from the variable map.
/// The payload is the key that was looked up.
#[error(
"connection params reference `${{preconnect:{0}}}` but no pre-connect step saved that key"
)]
MissingVar(String),
/// A `${preconnect:` opener has no `}` after it.
/// The payload is the string that was being expanded.
#[error("`${{preconnect:…}}` placeholder is missing a closing brace in: {0}")]
Unterminated(String),
}
/// Expands placeholders in an optional string slot, in place.
///
/// A `None` slot is left untouched. A `Some` slot is overwritten with its
/// expanded form only when expansion succeeds.
///
/// # Errors
///
/// Propagates the [`SubstitutionError`] from expansion. On error the slot
/// keeps its original value, so a failed substitution never erases the
/// caller's data.
fn substitute_opt(
    slot: &mut Option<String>,
    vars: &HashMap<String, String>,
) -> Result<(), SubstitutionError> {
    // Borrow instead of `take()`: taking first would leave `None` behind if
    // the `?` below returns early.
    if let Some(current) = slot.as_deref() {
        let expanded = substitute_str(current, vars)?;
        *slot = Some(expanded);
    }
    Ok(())
}
/// Replaces every `${preconnect:NAME}` in `input` with `vars[NAME]`.
///
/// Text outside placeholders is copied through verbatim. Inserted values are
/// not scanned again, so a value that itself contains a placeholder is left
/// as is.
///
/// # Errors
///
/// * [`SubstitutionError::Unterminated`] when an opener has no `}` after it;
///   the payload is the full `input`.
/// * [`SubstitutionError::MissingVar`] when `NAME` is not a key of `vars`.
fn substitute_str(
    input: &str,
    vars: &HashMap<String, String>,
) -> Result<String, SubstitutionError> {
    let mut expanded = String::with_capacity(input.len());
    let mut remaining = input;
    while let Some((literal, tail)) = remaining.split_once("${preconnect:") {
        expanded.push_str(literal);
        let Some((key, after)) = tail.split_once('}') else {
            return Err(SubstitutionError::Unterminated(input.into()));
        };
        match vars.get(key) {
            Some(value) => expanded.push_str(value),
            None => return Err(SubstitutionError::MissingVar(key.into())),
        }
        remaining = after;
    }
    expanded.push_str(remaining);
    Ok(expanded)
}
// Unit tests for the step runner and placeholder substitution.
//
// Step commands are limited to `echo` and `exit N`, which both `sh -c` and
// `cmd /C` understand, so the runner tests pass on every platform that has a
// `shell_command` implementation. The timeout test needs `sleep` and is
// therefore restricted to unix.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds connection params whose only populated field is `host`.
    fn params_with_host(host: &str) -> narwhal_core::ConnectionParams {
        narwhal_core::ConnectionParams::with(|p| {
            p.host = Some(host.into());
        })
    }

    #[tokio::test]
    async fn empty_steps_yield_empty_map() {
        let vars = run_pre_connect(&[]).await.unwrap();
        assert!(vars.is_empty());
    }

    #[tokio::test]
    async fn captures_stdout_into_named_var() {
        let steps = vec![
            PreConnectStep::new("echo hello")
                .with_save_output_to("GREETING")
                .with_timeout_secs(5),
        ];
        let vars = run_pre_connect(&steps).await.unwrap();
        assert_eq!(vars.get("GREETING").map(String::as_str), Some("hello"));
    }

    #[tokio::test]
    async fn step_without_save_output_runs_but_does_not_populate_map() {
        // `exit 0` succeeds under both `sh` and `cmd`, unlike `true`.
        let steps = vec![PreConnectStep::new("exit 0").with_timeout_secs(5)];
        let vars = run_pre_connect(&steps).await.unwrap();
        assert!(vars.is_empty());
    }

    #[tokio::test]
    async fn required_failure_aborts_sequence() {
        // `exit 1` fails deterministically under both `sh` and `cmd`.
        let steps = vec![
            PreConnectStep::new("exit 1").with_timeout_secs(5),
            PreConnectStep::new("echo should-not-run")
                .with_save_output_to("X")
                .with_timeout_secs(5),
        ];
        let err = run_pre_connect(&steps).await.unwrap_err();
        assert!(matches!(err, PreConnectError::NonZero { .. }));
    }

    #[tokio::test]
    async fn non_required_failure_continues() {
        let steps = vec![
            PreConnectStep::new("exit 1")
                .with_save_output_to("UNSET")
                .with_timeout_secs(5)
                .with_required(false),
            PreConnectStep::new("echo ok")
                .with_save_output_to("OK")
                .with_timeout_secs(5),
        ];
        let vars = run_pre_connect(&steps).await.unwrap();
        assert!(!vars.contains_key("UNSET"));
        assert_eq!(vars.get("OK").map(String::as_str), Some("ok"));
    }

    // `sleep` is not a `cmd.exe` command, so this test only makes sense on unix.
    #[cfg(unix)]
    #[tokio::test]
    async fn timeout_returns_timeout_error() {
        let steps = vec![PreConnectStep::new("sleep 5").with_timeout_secs(1)];
        let err = run_pre_connect(&steps).await.unwrap_err();
        assert!(matches!(err, PreConnectError::Timeout(_, 1)));
    }

    #[test]
    fn substitute_replaces_placeholder() {
        let mut params = params_with_host("${preconnect:HOST}.example.com");
        params.database = Some("plain".into());
        let mut vars = HashMap::new();
        vars.insert("HOST".into(), "db01".into());
        substitute_pre_connect(&mut params, &vars).unwrap();
        assert_eq!(params.host.as_deref(), Some("db01.example.com"));
        assert_eq!(params.database.as_deref(), Some("plain"));
    }

    #[test]
    fn substitute_errors_on_missing_var() {
        let mut params = params_with_host("${preconnect:NOPE}");
        let vars = HashMap::new();
        let err = substitute_pre_connect(&mut params, &vars).unwrap_err();
        assert!(matches!(err, SubstitutionError::MissingVar(ref n) if n == "NOPE"));
    }

    #[test]
    fn substitute_errors_on_unterminated_placeholder() {
        let mut params = params_with_host("${preconnect:HOST");
        let mut vars = HashMap::new();
        vars.insert("HOST".to_owned(), "db01".to_owned());
        let err = substitute_pre_connect(&mut params, &vars).unwrap_err();
        assert!(matches!(err, SubstitutionError::Unterminated(_)));
    }

    #[test]
    fn substitute_password_none_passes_through() {
        let vars = HashMap::new();
        assert_eq!(substitute_password(None, &vars).unwrap(), None);
    }

    #[test]
    fn substitute_password_without_placeholder_passes_through() {
        let vars = HashMap::new();
        let out = substitute_password(Some("plain-secret".into()), &vars).unwrap();
        assert_eq!(out.as_deref(), Some("plain-secret"));
    }

    #[test]
    fn substitute_password_expands_placeholder() {
        let mut vars = HashMap::new();
        vars.insert("VAULT_PASS".into(), "hunter2".into());
        let out = substitute_password(Some("${preconnect:VAULT_PASS}".into()), &vars).unwrap();
        assert_eq!(out.as_deref(), Some("hunter2"));
    }

    #[test]
    fn substitute_password_errors_on_missing_var() {
        let vars = HashMap::new();
        let err = substitute_password(Some("${preconnect:GONE}".into()), &vars).unwrap_err();
        assert!(matches!(err, SubstitutionError::MissingVar(ref n) if n == "GONE"));
    }

    #[test]
    fn substitute_passthrough_when_no_placeholder() {
        let mut params = params_with_host("plain.example.com");
        let vars = HashMap::new();
        substitute_pre_connect(&mut params, &vars).unwrap();
        assert_eq!(params.host.as_deref(), Some("plain.example.com"));
    }
}