use anyhow::{Context, Result};
use log::{debug, error, info};
use std::path::{Path, PathBuf};
use std::process::Command;
/// Outcome of a successful [`sign_certificate`] call.
#[derive(Debug)]
pub struct SignResult {
/// Where the signed certificate was written: the purple-managed path
/// returned by [`cert_path_for`] (`~/.purple/certs/<alias>-cert.pub`).
pub cert_path: PathBuf,
}
/// Validity state of an SSH certificate on disk, as reported by [`check_cert_validity`].
#[derive(Debug, Clone, PartialEq)]
pub enum CertStatus {
/// The certificate is currently usable.
Valid {
/// Estimated expiry time in Unix seconds (`i64::MAX` for certificates without an end time).
expires_at: i64,
/// Estimated seconds until expiry (`i64::MAX` when unbounded).
remaining_secs: i64,
/// Length of the certificate's validity window in seconds (`i64::MAX` when unbounded).
total_secs: i64,
},
/// The validity window has elapsed. Also reported when the file's mtime cannot be read.
Expired,
/// No certificate file exists at the checked path.
Missing,
/// The certificate could not be inspected; carries a human-readable reason.
Invalid(String),
}
/// Renew a valid certificate once fewer than this many seconds remain.
/// Certificates whose whole lifetime is at most this long use half their lifetime
/// as the threshold instead (see [`needs_renewal`]).
pub const RENEWAL_THRESHOLD_SECS: i64 = 300;
/// NOTE(review): not referenced in this module's visible code; presumably how long callers
/// may cache a computed [`CertStatus`] before re-checking — confirm against callers.
pub const CERT_STATUS_CACHE_TTL_SECS: u64 = 300;
/// NOTE(review): not referenced in this module's visible code; presumably a retry back-off
/// after a failed certificate operation — confirm against callers.
pub const CERT_ERROR_BACKOFF_SECS: u64 = 30;
/// Checks that `s` is an acceptable Vault SSH role path (e.g. `ssh-client-signer/sign/dev`).
///
/// A valid role:
/// - is non-empty and at most 128 bytes;
/// - consists only of ASCII alphanumerics, `/`, `_` and `-`;
/// - does not start with `-`.
///
/// The role is passed to `vault write` as its first positional argument. A leading `-`
/// would make the vault CLI parse it as a flag (e.g. `-tls-skip-verify`) instead of a path.
pub fn is_valid_role(s: &str) -> bool {
    !s.is_empty()
        && s.len() <= 128
        && !s.starts_with('-')
        && s
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '/' | '_' | '-'))
}
/// Sanity-checks a Vault address before it is handed to the CLI as `VAULT_ADDR`.
///
/// After trimming surrounding whitespace, the address must be non-empty, at most
/// 512 bytes long, and free of control and whitespace characters. The URL itself
/// is not parsed.
pub fn is_valid_vault_addr(s: &str) -> bool {
    let addr = s.trim();
    if addr.is_empty() || addr.len() > 512 {
        return false;
    }
    addr.chars().all(|c| !(c.is_control() || c.is_whitespace()))
}
/// Normalizes a user-supplied Vault address into `scheme://host:port[/path]` form.
///
/// - No scheme: `https://` is prepended, and port 8200 (Vault's default) is added
///   when none is given.
/// - Explicit `http://` or `https://` (any case) without a port: 80 or 443 is added.
/// - Any other `scheme://` is returned trimmed but otherwise untouched.
///
/// A port is detected by a `:` in the authority; for bracketed IPv6 hosts only the
/// text after the closing `]` is inspected.
pub fn normalize_vault_addr(s: &str) -> String {
    let input = s.trim();
    let folded = input.to_ascii_lowercase();
    let explicit_http = folded.starts_with("http://");
    let explicit_https = folded.starts_with("https://");

    let (url, prefix_len) = if explicit_https {
        (input.to_string(), "https://".len())
    } else if explicit_http {
        (input.to_string(), "http://".len())
    } else if input.contains("://") {
        // Unknown scheme: leave it to the vault CLI to interpret.
        return input.to_string();
    } else {
        (format!("https://{}", input), "https://".len())
    };

    let rest = &url[prefix_len..];
    let authority = rest.split('/').next().unwrap_or(rest);
    // Skip over IPv6 literals like `[::1]` so their colons are not mistaken for a port.
    let host_tail = authority.rfind(']').map_or(authority, |i| &authority[i..]);
    let insert_at = prefix_len + authority.len();
    if host_tail.contains(':') {
        return url;
    }

    let port = match (explicit_http, explicit_https) {
        (true, _) => 80,
        (_, true) => 443,
        _ => 8200,
    };
    let (head, tail) = url.split_at(insert_at);
    format!("{head}:{port}{tail}")
}
/// Condenses vault CLI stderr into a short message that is safe to log and show.
///
/// Any line mentioning (case-insensitively) `token`, `secret`, `x-vault-`, `cookie`
/// or `authorization` is dropped. The remaining lines are joined with single spaces,
/// trimmed, and truncated to 200 characters with a trailing `...`. If nothing is
/// left, a generic hint is returned instead.
pub fn scrub_vault_stderr(raw: &str) -> String {
    const SENSITIVE_MARKERS: [&str; 5] = ["token", "secret", "x-vault-", "cookie", "authorization"];
    const MAX_CHARS: usize = 200;

    let kept: Vec<&str> = raw
        .lines()
        .filter(|line| {
            let lowered = line.to_ascii_lowercase();
            !SENSITIVE_MARKERS.iter().any(|&marker| lowered.contains(marker))
        })
        .collect();
    let joined = kept.join(" ");
    let message = joined.trim();
    if message.is_empty() {
        return "Vault SSH signing failed. Check vault status and policy".to_string();
    }
    // Truncate on a char boundary: the (MAX_CHARS + 1)-th char exists only when too long.
    match message.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &message[..cut]),
        None => message.to_string(),
    }
}
/// Returns the purple-managed certificate path for `alias`:
/// `~/.purple/certs/<alias>-cert.pub`.
///
/// # Errors
/// Fails if `alias` is empty or contains `/`, `\`, `:`, a NUL byte, or `..` (anything
/// that could escape the certificate directory), or if the home directory cannot be
/// determined.
pub fn cert_path_for(alias: &str) -> Result<PathBuf> {
    let has_forbidden_char = alias.chars().any(|c| matches!(c, '/' | '\\' | ':' | '\0'));
    let is_safe = !alias.is_empty() && !has_forbidden_char && !alias.contains("..");
    anyhow::ensure!(is_safe, "Invalid alias for cert path: '{}'", alias);

    let certs_dir = dirs::home_dir()
        .context("Could not determine home directory")?
        .join(".purple/certs");
    Ok(certs_dir.join(format!("{}-cert.pub", alias)))
}
/// Resolves which certificate file to inspect for a host.
///
/// A non-empty `certificate_file` (the host's `CertificateFile`) is used as-is, with a
/// leading `~/` expanded to the home directory when it can be determined. An empty
/// value falls back to the purple-managed path from [`cert_path_for`].
///
/// # Errors
/// Only the fallback path can fail (invalid alias or unknown home directory).
pub fn resolve_cert_path(alias: &str, certificate_file: &str) -> Result<PathBuf> {
    if certificate_file.is_empty() {
        return cert_path_for(alias);
    }
    let expanded = certificate_file
        .strip_prefix("~/")
        .and_then(|rest| dirs::home_dir().map(|home| home.join(rest)))
        .unwrap_or_else(|| PathBuf::from(certificate_file));
    Ok(expanded)
}
pub fn sign_certificate(
role: &str,
pubkey_path: &Path,
alias: &str,
vault_addr: Option<&str>,
) -> Result<SignResult> {
if !pubkey_path.exists() {
anyhow::bail!(
"Public key not found: {}. Set IdentityFile on the host or ensure ~/.ssh/id_ed25519.pub exists.",
pubkey_path.display()
);
}
if !is_valid_role(role) {
anyhow::bail!("Invalid Vault SSH role: '{}'", role);
}
let cert_dest = cert_path_for(alias)?;
if let Some(parent) = cert_dest.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Failed to create {}", parent.display()))?;
}
let pubkey_str = pubkey_path.to_str().context(
"public key path contains non-UTF8 bytes; vault CLI requires a valid UTF-8 path",
)?;
if pubkey_str.contains('=') {
anyhow::bail!(
"Public key path '{}' contains '=' which is not supported by the Vault CLI argument format. Rename the key file or directory.",
pubkey_str
);
}
let pubkey_arg = format!("public_key=@{}", pubkey_str);
debug!(
"[external] Vault sign request: addr={} role={}",
vault_addr.unwrap_or("<env>"),
role
);
let mut cmd = Command::new("vault");
cmd.args(["write", "-field=signed_key", role, &pubkey_arg]);
if let Some(addr) = vault_addr {
anyhow::ensure!(
is_valid_vault_addr(addr),
"Invalid VAULT_ADDR '{}' for role '{}'. Check the Vault SSH Address field.",
addr,
role
);
cmd.env("VAULT_ADDR", addr);
}
let mut child = cmd
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("Failed to run vault CLI. Is vault installed and in PATH?")?;
let stdout_handle = child.stdout.take();
let stderr_handle = child.stderr.take();
let stdout_thread = std::thread::spawn(move || -> Vec<u8> {
let mut buf = Vec::new();
if let Some(mut h) = stdout_handle {
if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
log::warn!("[external] Failed to read vault stdout pipe: {e}");
}
}
buf
});
let stderr_thread = std::thread::spawn(move || -> Vec<u8> {
let mut buf = Vec::new();
if let Some(mut h) = stderr_handle {
if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
log::warn!("[external] Failed to read vault stderr pipe: {e}");
}
}
buf
});
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
let status = loop {
match child.try_wait() {
Ok(Some(s)) => break s,
Ok(None) => {
if std::time::Instant::now() >= deadline {
let _ = child.kill();
let _ = child.wait();
error!(
"[external] Vault unreachable: {}: timed out after 30s",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH timed out. Server unreachable.");
}
std::thread::sleep(std::time::Duration::from_millis(100));
}
Err(e) => {
let _ = child.kill();
let _ = child.wait();
anyhow::bail!("Failed to wait for vault CLI: {}", e);
}
}
};
let stdout_bytes = stdout_thread.join().unwrap_or_default();
let stderr_bytes = stderr_thread.join().unwrap_or_default();
let output = std::process::Output {
status,
stdout: stdout_bytes,
stderr: stderr_bytes,
};
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
if stderr.contains("permission denied") || stderr.contains("403") {
error!(
"[external] Vault auth failed: permission denied (role={} addr={})",
role,
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH permission denied. Check token and policy.");
}
if stderr.contains("missing client token") || stderr.contains("token expired") {
error!(
"[external] Vault auth failed: token missing or expired (role={} addr={})",
role,
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH token missing or expired. Run `vault login`.");
}
if stderr.contains("connection refused") {
error!(
"[external] Vault unreachable: {}: connection refused",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH connection refused.");
}
if stderr.contains("i/o timeout") || stderr.contains("dial tcp") {
error!(
"[external] Vault unreachable: {}: connection timed out",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH connection timed out.");
}
if stderr.contains("no such host") {
error!(
"[external] Vault unreachable: {}: no such host",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH host not found.");
}
if stderr.contains("server gave HTTP response to HTTPS client") {
error!(
"[external] Vault unreachable: {}: server returned HTTP on HTTPS connection",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH server uses HTTP, not HTTPS. Set address to http://.");
}
if stderr.contains("certificate signed by unknown authority")
|| stderr.contains("tls:")
|| stderr.contains("x509:")
{
error!(
"[external] Vault unreachable: {}: TLS error",
vault_addr.unwrap_or("<env>")
);
anyhow::bail!("Vault SSH TLS error. Check certificate or use http://.");
}
error!(
"[external] Vault SSH signing failed: {}",
scrub_vault_stderr(&stderr)
);
anyhow::bail!("Vault SSH failed: {}", scrub_vault_stderr(&stderr));
}
let signed_key = String::from_utf8_lossy(&output.stdout).trim().to_string();
if signed_key.is_empty() {
anyhow::bail!("Vault returned empty certificate for role '{}'", role);
}
crate::fs_util::atomic_write(&cert_dest, signed_key.as_bytes())
.with_context(|| format!("Failed to write certificate to {}", cert_dest.display()))?;
info!("Vault SSH certificate signed for {}", alias);
Ok(SignResult {
cert_path: cert_dest,
})
}
pub fn check_cert_validity(cert_path: &Path) -> CertStatus {
if !cert_path.exists() {
return CertStatus::Missing;
}
let output = match Command::new("ssh-keygen")
.args(["-L", "-f"])
.arg(cert_path)
.output()
{
Ok(o) => o,
Err(e) => return CertStatus::Invalid(format!("Failed to run ssh-keygen: {}", e)),
};
if !output.status.success() {
return CertStatus::Invalid("ssh-keygen could not read certificate".to_string());
}
let stdout = String::from_utf8_lossy(&output.stdout);
for line in stdout.lines() {
let t = line.trim();
if t == "Valid: forever" || t.starts_with("Valid: from ") && t.ends_with(" to forever") {
return CertStatus::Valid {
expires_at: i64::MAX,
remaining_secs: i64::MAX,
total_secs: i64::MAX,
};
}
}
for line in stdout.lines() {
if let Some((from, to)) = parse_valid_line(line) {
let ttl = to - from; if ttl <= 0 {
return CertStatus::Invalid(
"certificate has non-positive validity window".to_string(),
);
}
let signed_at = match std::fs::metadata(cert_path)
.and_then(|m| m.modified())
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
{
Some(d) => d.as_secs() as i64,
None => {
return CertStatus::Expired;
}
};
let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d.as_secs() as i64,
Err(_) => {
return CertStatus::Invalid("system clock before unix epoch".to_string());
}
};
let elapsed = now - signed_at;
let remaining = ttl - elapsed;
if remaining <= 0 {
return CertStatus::Expired;
}
let expires_at = now + remaining;
return CertStatus::Valid {
expires_at,
remaining_secs: remaining,
total_secs: ttl,
};
}
}
CertStatus::Invalid("No Valid: line found in certificate".to_string())
}
/// Parses an `ssh-keygen -L` line of the form `Valid: from <start> to <end>`.
///
/// Returns `(start, end)` as seconds computed by [`parse_ssh_datetime`], or `None` if
/// the line has a different shape or either timestamp fails to parse.
fn parse_valid_line(line: &str) -> Option<(i64, i64)> {
    let window = line
        .trim()
        .strip_prefix("Valid:")?
        .trim()
        .strip_prefix("from ")?;
    let (start, end) = window.split_once(" to ")?;
    let from = parse_ssh_datetime(start)?;
    let to = parse_ssh_datetime(end.trim())?;
    Some((from, to))
}
/// Parses a `YYYY-MM-DDTHH:MM:SS` timestamp (as printed by `ssh-keygen -L`) into
/// seconds since 1970-01-01T00:00:00, treating the fields as UTC.
///
/// Surrounding whitespace is ignored and characters after the 19th are not inspected.
/// Returns `None` for malformed input or out-of-range fields. Day-of-month is only
/// checked against 1..=31. Uses the proleptic-Gregorian "days from civil" algorithm
/// (eras of 400 years, years starting in March).
fn parse_ssh_datetime(s: &str) -> Option<i64> {
    let s = s.trim();
    if s.len() < 19 {
        return None;
    }
    let field = |range: std::ops::Range<usize>| -> Option<i64> { s.get(range)?.parse().ok() };
    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    let bytes = s.as_bytes();
    let separators_ok = [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':')]
        .iter()
        .all(|&(idx, expected)| bytes.get(idx) == Some(&expected));
    if !separators_ok {
        return None;
    }

    let fields_in_range = (1..=12).contains(&month)
        && (1..=31).contains(&day)
        && (0..=23).contains(&hour)
        && (0..=59).contains(&min)
        && (0..=59).contains(&sec);
    if !fields_in_range {
        return None;
    }

    // Shift the year so it starts in March; this puts the leap day at the end.
    let (y, march_month) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let year_of_era = y.rem_euclid(400);
    let day_of_year = (153 * march_month + 2) / 5 + day - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    let days_since_epoch = era * 146_097 + day_of_era - 719_468;

    Some(days_since_epoch * 86_400 + hour * 3_600 + min * 60 + sec)
}
/// Decides whether a certificate should be (re)signed.
///
/// Missing, expired and unreadable certificates always need renewal. A valid one is
/// renewed once fewer than [`RENEWAL_THRESHOLD_SECS`] seconds remain. Certificates whose
/// whole lifetime is at most that threshold use half their lifetime instead, so a
/// freshly signed short-lived certificate is not renewed straight away.
pub fn needs_renewal(status: &CertStatus) -> bool {
    match *status {
        CertStatus::Valid {
            remaining_secs,
            total_secs,
            ..
        } => {
            let short_lived = (1..=RENEWAL_THRESHOLD_SECS).contains(&total_secs);
            let threshold = if short_lived {
                total_secs / 2
            } else {
                RENEWAL_THRESHOLD_SECS
            };
            remaining_secs < threshold
        }
        CertStatus::Missing | CertStatus::Expired | CertStatus::Invalid(_) => true,
    }
}
/// Returns a usable certificate path for `alias`, signing a new certificate only when needed.
///
/// The certificate at [`resolve_cert_path`] (`certificate_file`, or the purple-managed
/// default when empty) is inspected. If [`needs_renewal`] reports it is still good, that
/// path is returned unchanged. Otherwise [`sign_certificate`] runs and the path it wrote
/// is returned.
///
/// NOTE(review): a fresh certificate is always written to the purple-managed path, while
/// the cache check reads `certificate_file` when one is set. If those differ, every call
/// re-signs — confirm whether callers keep them in sync.
///
/// # Errors
/// Propagates path-resolution and signing errors.
pub fn ensure_cert(
    role: &str,
    pubkey_path: &Path,
    alias: &str,
    certificate_file: &str,
    vault_addr: Option<&str>,
) -> Result<PathBuf> {
    let existing = resolve_cert_path(alias, certificate_file)?;
    if needs_renewal(&check_cert_validity(&existing)) {
        return sign_certificate(role, pubkey_path, alias, vault_addr)
            .map(|signed| signed.cert_path);
    }
    info!("Vault SSH certificate cache hit for {}", alias);
    Ok(existing)
}
/// Resolves the public-key file to send to Vault for a host's `IdentityFile`.
///
/// - Empty `identity_file`: `~/.ssh/id_ed25519.pub`.
/// - A leading `~/` is expanded to the home directory.
/// - A path without a `.pub` extension gets `.pub` appended (private key to public key).
///
/// The returned file's contents are uploaded to Vault (`public_key=@path`), so the
/// resolved `.pub` path itself must stay inside the home directory; otherwise the
/// default key is returned instead.
/// - Existing paths are checked after canonicalization, so symlinks are followed.
/// - Non-existent paths are checked lexically and must not contain `..`.
///
/// The default key is also returned if the home directory cannot be canonicalized.
///
/// # Errors
/// Fails only if the home directory cannot be determined.
pub fn resolve_pubkey_path(identity_file: &str) -> Result<PathBuf> {
    let home = dirs::home_dir().context("Could not determine home directory")?;
    let fallback = home.join(".ssh/id_ed25519.pub");
    if identity_file.is_empty() {
        return Ok(fallback);
    }

    let expanded = match identity_file.strip_prefix("~/") {
        Some(rest) => home.join(rest),
        None => PathBuf::from(identity_file),
    };
    // Validate the file that will actually be read, not the private key next to it.
    let pubkey = if expanded.extension().is_some_and(|ext| ext == "pub") {
        expanded
    } else {
        let mut s = expanded.into_os_string();
        s.push(".pub");
        PathBuf::from(s)
    };

    let canonical_home = match std::fs::canonicalize(&home) {
        Ok(p) => p,
        Err(_) => return Ok(fallback),
    };
    let inside_home = if pubkey.exists() {
        std::fs::canonicalize(&pubkey).is_ok_and(|c| c.starts_with(&canonical_home))
    } else {
        // `Path::starts_with` is purely lexical, so `~/../..` would otherwise pass.
        pubkey.starts_with(&home)
            && !pubkey
                .components()
                .any(|c| matches!(c, std::path::Component::ParentDir))
    };
    Ok(if inside_home { pubkey } else { fallback })
}
/// Picks the Vault SSH role for a host.
///
/// A non-empty host-level value (`host_vault_ssh`) wins. Otherwise the non-empty
/// `vault_role` of the host's provider section is used. Returns `None` when neither is set.
pub fn resolve_vault_role(
    host_vault_ssh: Option<&str>,
    provider_name: Option<&str>,
    provider_config: &crate::providers::config::ProviderConfig,
) -> Option<String> {
    let host_role = host_vault_ssh
        .filter(|role| !role.is_empty())
        .map(str::to_string);
    host_role.or_else(|| {
        let section = provider_config.section(provider_name?)?;
        if section.vault_role.is_empty() {
            None
        } else {
            Some(section.vault_role.clone())
        }
    })
}
/// Picks the Vault address for a host and normalizes it with [`normalize_vault_addr`].
///
/// The host-level value wins if it is non-blank and passes [`is_valid_vault_addr`].
/// Otherwise the provider section's `vault_addr` is tried under the same rules.
/// Returns `None` when neither yields a usable address.
pub fn resolve_vault_addr(
    host_vault_addr: Option<&str>,
    provider_name: Option<&str>,
    provider_config: &crate::providers::config::ProviderConfig,
) -> Option<String> {
    /// Returns the normalized address if `raw` is non-blank and passes validation.
    fn usable_addr(raw: &str) -> Option<String> {
        let candidate = raw.trim();
        if candidate.is_empty() || !is_valid_vault_addr(candidate) {
            return None;
        }
        Some(normalize_vault_addr(candidate))
    }

    host_vault_addr.and_then(usable_addr).or_else(|| {
        let section = provider_config.section(provider_name?)?;
        usable_addr(&section.vault_addr)
    })
}
/// Renders a remaining lifetime as `"<h>h <m>m"`, or just `"<m>m"` under one hour.
///
/// Seconds are truncated, and a non-positive value renders as `"expired"`.
pub fn format_remaining(remaining_secs: i64) -> String {
    if remaining_secs <= 0 {
        return "expired".to_string();
    }
    let whole_hours = remaining_secs / 3600;
    let leftover_mins = remaining_secs % 3600 / 60;
    match whole_hours {
        0 => format!("{}m", leftover_mins),
        h => format!("{}h {}m", h, leftover_mins),
    }
}
#[cfg(test)]
#[path = "vault_ssh_tests.rs"]
mod tests;