1use anyhow::{Context, Result};
2use log::{debug, error, info};
3use std::path::{Path, PathBuf};
4use std::process::Command;
5
/// Outcome of a successful Vault SSH signing request.
#[derive(Debug)]
pub struct SignResult {
    /// Location the signed certificate was written to.
    pub cert_path: PathBuf,
}
11
/// Classification of a certificate file, as produced by [`check_cert_validity`].
#[derive(Debug, Clone, PartialEq)]
pub enum CertStatus {
    /// The certificate is usable. Certificates that never expire carry
    /// `i64::MAX` in all three fields.
    Valid {
        /// Expiry time in seconds since the Unix epoch.
        expires_at: i64,
        /// Seconds of validity left.
        remaining_secs: i64,
        /// Length of the validity window in seconds.
        total_secs: i64,
    },
    /// The validity window has elapsed, or the file's age could not be determined.
    Expired,
    /// No file exists at the checked path.
    Missing,
    /// The certificate could not be inspected or parsed; the string explains why.
    Invalid(String),
}
26
/// Renew once fewer than this many seconds of validity remain.
///
/// [`needs_renewal`] uses half the window instead for certificates whose
/// entire validity window is at most this long.
pub const RENEWAL_THRESHOLD_SECS: i64 = 5 * 60;

/// Five minutes, in seconds. Not referenced in this file; the name suggests a
/// cache lifetime for certificate-status results (NOTE(review): confirm with callers).
pub const CERT_STATUS_CACHE_TTL_SECS: u64 = 5 * 60;

/// Thirty seconds. Not referenced in this file; the name suggests a back-off
/// after a certificate error (NOTE(review): confirm with callers).
pub const CERT_ERROR_BACKOFF_SECS: u64 = 30;
42
/// Check that `s` is an acceptable Vault SSH signing role path.
///
/// A valid role:
/// - is 1 to 128 bytes long;
/// - uses only ASCII letters, digits, `/`, `_` and `-`;
/// - does not start with `-`.
///
/// The role is passed to `vault write` as a positional argument, so a leading
/// `-` would be parsed by the CLI as a flag rather than as the path.
pub fn is_valid_role(s: &str) -> bool {
    !s.is_empty()
        && s.len() <= 128
        && !s.starts_with('-')
        && s.chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '/' | '_' | '-'))
}
51
/// Sanity-check a Vault address before it is used as `VAULT_ADDR`.
///
/// After trimming, the address must be non-empty, at most 512 bytes, and free
/// of control and whitespace characters. The value is not parsed as a URL.
pub fn is_valid_vault_addr(s: &str) -> bool {
    let candidate = s.trim();
    if candidate.is_empty() || candidate.len() > 512 {
        return false;
    }
    candidate
        .chars()
        .all(|ch| !ch.is_control() && !ch.is_whitespace())
}
64
/// Normalize a Vault address for use as `VAULT_ADDR`.
///
/// - Surrounding whitespace is trimmed.
/// - Scheme handling:
///   - An address without a scheme gets `https://` prepended.
///   - `http://` and `https://`, matched case-insensitively, are kept as written.
///   - Any other `scheme://...` address is returned trimmed but otherwise unchanged.
/// - Port handling: when the authority has no explicit port, `:8200` is inserted
///   right after it.
///   - The authority ends at the first `/`, `?` or `#`.
///   - User-info before an `@` is ignored when looking for a port.
///   - In a bracketed IPv6 literal, only a `:` after the closing `]` counts as a port.
///
/// ```text
/// vault.example.com      -> https://vault.example.com:8200
/// http://10.0.0.1/v1     -> http://10.0.0.1:8200/v1
/// https://[::1]:9000     -> https://[::1]:9000
/// https://user:pw@vault  -> https://user:pw@vault:8200
/// ```
pub fn normalize_vault_addr(s: &str) -> String {
    let trimmed = s.trim();
    // ASCII lowercasing keeps byte offsets identical to `trimmed`.
    let lower = trimmed.to_ascii_lowercase();
    let (with_scheme, scheme_len) = if lower.starts_with("https://") {
        (trimmed.to_string(), "https://".len())
    } else if lower.starts_with("http://") {
        (trimmed.to_string(), "http://".len())
    } else if trimmed.contains("://") {
        // Unknown scheme: leave it alone.
        return trimmed.to_string();
    } else {
        (format!("https://{}", trimmed), "https://".len())
    };

    let after_scheme = &with_scheme[scheme_len..];
    let authority_len = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let authority = &after_scheme[..authority_len];
    // `user:password@` must not be read as a port separator.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    let has_port = match host_port.rfind(']') {
        // IPv6 literal: only a ':' after the closing bracket is a port.
        Some(bracket_end) => host_port[bracket_end..].contains(':'),
        None => host_port.contains(':'),
    };
    if has_port {
        return with_scheme;
    }

    let path_start = scheme_len + authority_len;
    format!(
        "{}:8200{}",
        &with_scheme[..path_start],
        &with_scheme[path_start..]
    )
}
106
/// Turn raw `vault` CLI stderr into a short message that is safe to show and log.
///
/// - Lines mentioning `token`, `secret`, `x-vault-`, `cookie` or `authorization`
///   (case-insensitively) are dropped.
/// - The remaining lines are joined with single spaces and trimmed.
/// - The result is cut to 200 characters, with `...` appended when cut.
/// - If nothing survives, a generic failure message is returned instead.
pub fn scrub_vault_stderr(raw: &str) -> String {
    const SENSITIVE: [&str; 5] = ["token", "secret", "x-vault-", "cookie", "authorization"];
    const MAX_CHARS: usize = 200;

    let kept: Vec<&str> = raw
        .lines()
        .filter(|line| {
            let lowered = line.to_ascii_lowercase();
            !SENSITIVE.iter().any(|&needle| lowered.contains(needle))
        })
        .collect();
    let joined = kept.join(" ");
    let message = joined.trim();
    if message.is_empty() {
        return "Vault SSH signing failed. Check vault status and policy".to_string();
    }

    // The (MAX_CHARS + 1)-th character exists only when the message is too long.
    match message.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &message[..cut]),
        None => message.to_string(),
    }
}
133
134pub fn cert_path_for(alias: &str) -> Result<PathBuf> {
136 anyhow::ensure!(
137 !alias.is_empty()
138 && !alias.contains('/')
139 && !alias.contains('\\')
140 && !alias.contains(':')
141 && !alias.contains('\0')
142 && !alias.contains(".."),
143 "Invalid alias for cert path: '{}'",
144 alias
145 );
146 let dir = dirs::home_dir()
147 .context("Could not determine home directory")?
148 .join(".purple/certs");
149 Ok(dir.join(format!("{}-cert.pub", alias)))
150}
151
152pub fn resolve_cert_path(alias: &str, certificate_file: &str) -> Result<PathBuf> {
155 if !certificate_file.is_empty() {
156 let expanded = if let Some(rest) = certificate_file.strip_prefix("~/") {
157 if let Some(home) = dirs::home_dir() {
158 home.join(rest)
159 } else {
160 PathBuf::from(certificate_file)
161 }
162 } else {
163 PathBuf::from(certificate_file)
164 };
165 Ok(expanded)
166 } else {
167 cert_path_for(alias)
168 }
169}
170
171pub fn sign_certificate(
181 role: &str,
182 pubkey_path: &Path,
183 alias: &str,
184 vault_addr: Option<&str>,
185) -> Result<SignResult> {
186 if !pubkey_path.exists() {
187 anyhow::bail!(
188 "Public key not found: {}. Set IdentityFile on the host or ensure ~/.ssh/id_ed25519.pub exists.",
189 pubkey_path.display()
190 );
191 }
192
193 if !is_valid_role(role) {
194 anyhow::bail!("Invalid Vault SSH role: '{}'", role);
195 }
196
197 let cert_dest = cert_path_for(alias)?;
198
199 if let Some(parent) = cert_dest.parent() {
200 std::fs::create_dir_all(parent)
201 .with_context(|| format!("Failed to create {}", parent.display()))?;
202 }
203
204 let pubkey_str = pubkey_path.to_str().context(
208 "public key path contains non-UTF8 bytes; vault CLI requires a valid UTF-8 path",
209 )?;
210 if pubkey_str.contains('=') {
217 anyhow::bail!(
218 "Public key path '{}' contains '=' which is not supported by the Vault CLI argument format. Rename the key file or directory.",
219 pubkey_str
220 );
221 }
222 let pubkey_arg = format!("public_key=@{}", pubkey_str);
223 debug!(
224 "[external] Vault sign request: addr={} role={}",
225 vault_addr.unwrap_or("<env>"),
226 role
227 );
228 let mut cmd = Command::new("vault");
229 cmd.args(["write", "-field=signed_key", role, &pubkey_arg]);
230 if let Some(addr) = vault_addr {
237 anyhow::ensure!(
238 is_valid_vault_addr(addr),
239 "Invalid VAULT_ADDR '{}' for role '{}'. Check the Vault SSH Address field.",
240 addr,
241 role
242 );
243 cmd.env("VAULT_ADDR", addr);
244 }
245 let mut child = cmd
246 .stdout(std::process::Stdio::piped())
247 .stderr(std::process::Stdio::piped())
248 .spawn()
249 .context("Failed to run vault CLI. Is vault installed and in PATH?")?;
250
251 let stdout_handle = child.stdout.take();
255 let stderr_handle = child.stderr.take();
256 let stdout_thread = std::thread::spawn(move || -> Vec<u8> {
257 let mut buf = Vec::new();
258 if let Some(mut h) = stdout_handle {
259 if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
260 log::warn!("[external] Failed to read vault stdout pipe: {e}");
261 }
262 }
263 buf
264 });
265 let stderr_thread = std::thread::spawn(move || -> Vec<u8> {
266 let mut buf = Vec::new();
267 if let Some(mut h) = stderr_handle {
268 if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
269 log::warn!("[external] Failed to read vault stderr pipe: {e}");
270 }
271 }
272 buf
273 });
274
275 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
279 let status = loop {
280 match child.try_wait() {
281 Ok(Some(s)) => break s,
282 Ok(None) => {
283 if std::time::Instant::now() >= deadline {
284 let _ = child.kill();
285 let _ = child.wait();
286 error!(
291 "[external] Vault unreachable: {}: timed out after 30s",
292 vault_addr.unwrap_or("<env>")
293 );
294 anyhow::bail!("Vault SSH timed out. Server unreachable.");
295 }
296 std::thread::sleep(std::time::Duration::from_millis(100));
297 }
298 Err(e) => {
299 let _ = child.kill();
300 let _ = child.wait();
301 anyhow::bail!("Failed to wait for vault CLI: {}", e);
302 }
303 }
304 };
305
306 let stdout_bytes = stdout_thread.join().unwrap_or_default();
307 let stderr_bytes = stderr_thread.join().unwrap_or_default();
308 let output = std::process::Output {
309 status,
310 stdout: stdout_bytes,
311 stderr: stderr_bytes,
312 };
313
314 if !output.status.success() {
315 let stderr = String::from_utf8_lossy(&output.stderr);
316 if stderr.contains("permission denied") || stderr.contains("403") {
317 error!(
318 "[external] Vault auth failed: permission denied (role={} addr={})",
319 role,
320 vault_addr.unwrap_or("<env>")
321 );
322 anyhow::bail!("Vault SSH permission denied. Check token and policy.");
323 }
324 if stderr.contains("missing client token") || stderr.contains("token expired") {
325 error!(
326 "[external] Vault auth failed: token missing or expired (role={} addr={})",
327 role,
328 vault_addr.unwrap_or("<env>")
329 );
330 anyhow::bail!("Vault SSH token missing or expired. Run `vault login`.");
331 }
332 if stderr.contains("connection refused") {
335 error!(
336 "[external] Vault unreachable: {}: connection refused",
337 vault_addr.unwrap_or("<env>")
338 );
339 anyhow::bail!("Vault SSH connection refused.");
340 }
341 if stderr.contains("i/o timeout") || stderr.contains("dial tcp") {
342 error!(
343 "[external] Vault unreachable: {}: connection timed out",
344 vault_addr.unwrap_or("<env>")
345 );
346 anyhow::bail!("Vault SSH connection timed out.");
347 }
348 if stderr.contains("no such host") {
349 error!(
350 "[external] Vault unreachable: {}: no such host",
351 vault_addr.unwrap_or("<env>")
352 );
353 anyhow::bail!("Vault SSH host not found.");
354 }
355 if stderr.contains("server gave HTTP response to HTTPS client") {
356 error!(
357 "[external] Vault unreachable: {}: server returned HTTP on HTTPS connection",
358 vault_addr.unwrap_or("<env>")
359 );
360 anyhow::bail!("Vault SSH server uses HTTP, not HTTPS. Set address to http://.");
361 }
362 if stderr.contains("certificate signed by unknown authority")
363 || stderr.contains("tls:")
364 || stderr.contains("x509:")
365 {
366 error!(
367 "[external] Vault unreachable: {}: TLS error",
368 vault_addr.unwrap_or("<env>")
369 );
370 anyhow::bail!("Vault SSH TLS error. Check certificate or use http://.");
371 }
372 error!(
373 "[external] Vault SSH signing failed: {}",
374 scrub_vault_stderr(&stderr)
375 );
376 anyhow::bail!("Vault SSH failed: {}", scrub_vault_stderr(&stderr));
377 }
378
379 let signed_key = String::from_utf8_lossy(&output.stdout).trim().to_string();
380 if signed_key.is_empty() {
381 anyhow::bail!("Vault returned empty certificate for role '{}'", role);
382 }
383
384 crate::fs_util::atomic_write(&cert_dest, signed_key.as_bytes())
385 .with_context(|| format!("Failed to write certificate to {}", cert_dest.display()))?;
386
387 info!("Vault SSH certificate signed for {}", alias);
388 Ok(SignResult {
389 cert_path: cert_dest,
390 })
391}
392
393pub fn check_cert_validity(cert_path: &Path) -> CertStatus {
401 if !cert_path.exists() {
402 return CertStatus::Missing;
403 }
404
405 let output = match Command::new("ssh-keygen")
406 .args(["-L", "-f"])
407 .arg(cert_path)
408 .output()
409 {
410 Ok(o) => o,
411 Err(e) => return CertStatus::Invalid(format!("Failed to run ssh-keygen: {}", e)),
412 };
413
414 if !output.status.success() {
415 return CertStatus::Invalid("ssh-keygen could not read certificate".to_string());
416 }
417
418 let stdout = String::from_utf8_lossy(&output.stdout);
419
420 for line in stdout.lines() {
422 let t = line.trim();
423 if t == "Valid: forever" || t.starts_with("Valid: from ") && t.ends_with(" to forever") {
424 return CertStatus::Valid {
425 expires_at: i64::MAX,
426 remaining_secs: i64::MAX,
427 total_secs: i64::MAX,
428 };
429 }
430 }
431
432 for line in stdout.lines() {
433 if let Some((from, to)) = parse_valid_line(line) {
434 let ttl = to - from; if ttl <= 0 {
439 return CertStatus::Invalid(
440 "certificate has non-positive validity window".to_string(),
441 );
442 }
443
444 let signed_at = match std::fs::metadata(cert_path)
446 .and_then(|m| m.modified())
447 .ok()
448 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
449 {
450 Some(d) => d.as_secs() as i64,
451 None => {
452 return CertStatus::Expired;
454 }
455 };
456
457 let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
458 Ok(d) => d.as_secs() as i64,
459 Err(_) => {
460 return CertStatus::Invalid("system clock before unix epoch".to_string());
461 }
462 };
463
464 let elapsed = now - signed_at;
465 let remaining = ttl - elapsed;
466
467 if remaining <= 0 {
468 return CertStatus::Expired;
469 }
470 let expires_at = now + remaining;
471 return CertStatus::Valid {
472 expires_at,
473 remaining_secs: remaining,
474 total_secs: ttl,
475 };
476 }
477 }
478
479 CertStatus::Invalid("No Valid: line found in certificate".to_string())
480}
481
482fn parse_valid_line(line: &str) -> Option<(i64, i64)> {
484 let trimmed = line.trim();
485 let rest = trimmed.strip_prefix("Valid:")?;
486 let rest = rest.trim();
487 let rest = rest.strip_prefix("from ")?;
488 let (from_str, rest) = rest.split_once(" to ")?;
489 let to_str = rest.trim();
490
491 let from = parse_ssh_datetime(from_str)?;
492 let to = parse_ssh_datetime(to_str)?;
493 Some((from, to))
494}
495
/// Convert an `ssh-keygen` timestamp `YYYY-MM-DDTHH:MM:SS` to seconds since
/// 1970-01-01T00:00:00.
///
/// - Surrounding whitespace is trimmed.
/// - Anything after the first 19 bytes is ignored.
/// - No timezone offset is applied; the fields are converted as written.
///
/// Returns `None` when any of these holds:
/// - a separator is wrong;
/// - a numeric field contains anything other than ASCII digits;
/// - the date does not exist (e.g. `2023-02-29`);
/// - the time is out of range.
///
/// The day count uses Howard Hinnant's `days_from_civil` algorithm.
fn parse_ssh_datetime(s: &str) -> Option<i64> {
    let s = s.trim();
    let b = s.as_bytes();
    if b.len() < 19 {
        return None;
    }

    let separators_ok = [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':')]
        .iter()
        .all(|&(idx, expected)| b[idx] == expected);
    if !separators_ok {
        return None;
    }

    // Digits only: `str::parse` would also accept a leading `+` or `-`.
    let number = |range: std::ops::Range<usize>| -> Option<i64> {
        b[range].iter().try_fold(0i64, |acc, &d| {
            d.is_ascii_digit().then(|| acc * 10 + i64::from(d - b'0'))
        })
    };
    let year = number(0..4)?;
    let month = number(5..7)?;
    let day = number(8..10)?;
    let hour = number(11..13)?;
    let min = number(14..16)?;
    let sec = number(17..19)?;

    if !(1..=12).contains(&month) {
        return None;
    }
    let leap = (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
    let month_len = match month {
        2 if leap => 29,
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        _ => 31,
    };
    if !(1..=month_len).contains(&day) {
        return None;
    }
    if !(0..=23).contains(&hour) || !(0..=59).contains(&min) || !(0..=59).contains(&sec) {
        return None;
    }

    // days_from_civil: shift the year so it starts in March, putting the leap
    // day at the end of the (shifted) year.
    let y = if month <= 2 { year - 1 } else { year };
    let m = if month <= 2 { month + 9 } else { month - 3 };
    let era = y.div_euclid(400);
    let yoe = y.rem_euclid(400);
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
544
545pub fn needs_renewal(status: &CertStatus) -> bool {
552 match status {
553 CertStatus::Missing | CertStatus::Expired | CertStatus::Invalid(_) => true,
554 CertStatus::Valid {
555 remaining_secs,
556 total_secs,
557 ..
558 } => {
559 let threshold = if *total_secs > 0 && *total_secs <= RENEWAL_THRESHOLD_SECS {
560 *total_secs / 2
561 } else {
562 RENEWAL_THRESHOLD_SECS
563 };
564 *remaining_secs < threshold
565 }
566 }
567}
568
569pub fn ensure_cert(
572 role: &str,
573 pubkey_path: &Path,
574 alias: &str,
575 certificate_file: &str,
576 vault_addr: Option<&str>,
577) -> Result<PathBuf> {
578 let check_path = resolve_cert_path(alias, certificate_file)?;
579 let status = check_cert_validity(&check_path);
580
581 if !needs_renewal(&status) {
582 info!("Vault SSH certificate cache hit for {}", alias);
583 return Ok(check_path);
584 }
585
586 let result = sign_certificate(role, pubkey_path, alias, vault_addr)?;
587 Ok(result.cert_path)
588}
589
590pub fn resolve_pubkey_path(identity_file: &str) -> Result<PathBuf> {
597 let home = dirs::home_dir().context("Could not determine home directory")?;
598 let fallback = home.join(".ssh/id_ed25519.pub");
599
600 if identity_file.is_empty() {
601 return Ok(fallback);
602 }
603
604 let expanded = if let Some(rest) = identity_file.strip_prefix("~/") {
605 home.join(rest)
606 } else {
607 PathBuf::from(identity_file)
608 };
609
610 let canonical_home = match std::fs::canonicalize(&home) {
616 Ok(p) => p,
617 Err(_) => return Ok(fallback),
618 };
619 if expanded.exists() {
620 match std::fs::canonicalize(&expanded) {
621 Ok(canonical) if canonical.starts_with(&canonical_home) => {}
622 _ => return Ok(fallback),
623 }
624 } else if !expanded.starts_with(&home) {
625 return Ok(fallback);
626 }
627
628 if expanded.extension().is_some_and(|ext| ext == "pub") {
629 Ok(expanded)
630 } else {
631 let mut s = expanded.into_os_string();
632 s.push(".pub");
633 Ok(PathBuf::from(s))
634 }
635}
636
637pub fn resolve_vault_role(
640 host_vault_ssh: Option<&str>,
641 provider_name: Option<&str>,
642 provider_config: &crate::providers::config::ProviderConfig,
643) -> Option<String> {
644 if let Some(role) = host_vault_ssh {
645 if !role.is_empty() {
646 return Some(role.to_string());
647 }
648 }
649
650 if let Some(name) = provider_name {
651 if let Some(section) = provider_config.section(name) {
652 if !section.vault_role.is_empty() {
653 return Some(section.vault_role.clone());
654 }
655 }
656 }
657
658 None
659}
660
661pub fn resolve_vault_addr(
674 host_vault_addr: Option<&str>,
675 provider_name: Option<&str>,
676 provider_config: &crate::providers::config::ProviderConfig,
677) -> Option<String> {
678 if let Some(addr) = host_vault_addr {
679 let trimmed = addr.trim();
680 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
681 return Some(normalize_vault_addr(trimmed));
682 }
683 }
684
685 if let Some(name) = provider_name {
686 if let Some(section) = provider_config.section(name) {
687 let trimmed = section.vault_addr.trim();
688 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
689 return Some(normalize_vault_addr(trimmed));
690 }
691 }
692 }
693
694 None
695}
696
/// Render a remaining lifetime as `"<h>h <m>m"`, or `"<m>m"` under an hour.
///
/// - Seconds below a whole minute are dropped.
/// - Zero or negative input yields `"expired"`.
pub fn format_remaining(remaining_secs: i64) -> String {
    if remaining_secs <= 0 {
        return "expired".to_string();
    }
    let (hours, leftover) = (remaining_secs / 3600, remaining_secs % 3600);
    let minutes = leftover / 60;
    match hours {
        0 => format!("{}m", minutes),
        _ => format!("{}h {}m", hours, minutes),
    }
}
710
711#[cfg(test)]
712#[path = "vault_ssh_tests.rs"]
713mod tests;