1use anyhow::{Context, Result};
2use log::{debug, error, info};
3use std::path::{Path, PathBuf};
4use std::process::Command;
5
/// Outcome of a successful `sign_certificate` call.
#[derive(Debug)]
pub struct SignResult {
    /// Location the signed certificate was written to (the purple-managed path
    /// derived from the host alias by `cert_path_for`).
    pub cert_path: PathBuf,
}
11
/// Freshness of a certificate file, as reported by `check_cert_validity`.
#[derive(Debug, Clone, PartialEq)]
pub enum CertStatus {
    /// The certificate is usable. Certificates without an end time report
    /// `i64::MAX` in every field.
    Valid {
        /// Unix timestamp (seconds) at which the certificate is estimated to expire.
        expires_at: i64,
        /// Seconds of validity left at the moment of the check.
        remaining_secs: i64,
        /// Length of the certificate's validity window in seconds.
        total_secs: i64,
    },
    /// The validity window has elapsed, or the file's modification time was unreadable.
    Expired,
    /// No file exists at the checked path.
    Missing,
    /// The certificate could not be inspected; the payload explains why.
    Invalid(String),
}
26
/// Renew a certificate once fewer than this many seconds of validity remain.
/// `needs_renewal` halves the window instead for certificates whose whole
/// lifetime is no longer than this.
pub const RENEWAL_THRESHOLD_SECS: i64 = 300;

/// Seconds a computed certificate status may be reused.
/// NOTE(review): not referenced in this file; presumably consumed by a status cache elsewhere.
pub const CERT_STATUS_CACHE_TTL_SECS: u64 = 300;

/// Seconds to back off after a certificate error.
/// NOTE(review): not referenced in this file; presumably consumed by callers elsewhere.
pub const CERT_ERROR_BACKOFF_SECS: u64 = 30;
42
/// Returns `true` when `s` is acceptable as a Vault SSH role path for `vault write`.
///
/// A valid role:
/// - is 1 to 128 characters long;
/// - uses only ASCII letters, digits, `/`, `_` and `-`;
/// - does not start with `-`.
pub fn is_valid_role(s: &str) -> bool {
    !s.is_empty()
        && s.len() <= 128
        // The role is passed to `vault write` as the first positional argument,
        // right after `-field=signed_key`. A leading '-' would most likely be
        // parsed by the CLI as another flag (e.g. `-tls-skip-verify`) rather than
        // as the role path, so reject it.
        && !s.starts_with('-')
        && s
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '/' || c == '_' || c == '-')
}
51
/// Cheap sanity check for a Vault address before it is exported as `VAULT_ADDR`.
///
/// After trimming surrounding whitespace, the address must be non-empty, at most
/// 512 bytes, and contain no control or whitespace characters. The URL itself is
/// not parsed.
pub fn is_valid_vault_addr(s: &str) -> bool {
    let candidate = s.trim();
    if candidate.is_empty() || candidate.len() > 512 {
        return false;
    }
    candidate
        .chars()
        .all(|ch| !ch.is_control() && !ch.is_whitespace())
}
64
/// Normalizes a user-supplied Vault address into a full URL.
///
/// The input is trimmed first. Then:
/// - An `http://` or `https://` prefix (case-insensitive) is kept as written.
/// - Any other `scheme://` address is returned untouched.
/// - A scheme-less address gets `https://` prepended.
/// - If the authority (the text before the first `/` after the scheme) has no
///   port, `:8200` is inserted after it.
///
/// For bracketed IPv6 hosts, only a `:` after the closing `]` counts as a port.
pub fn normalize_vault_addr(s: &str) -> String {
    let input = s.trim();
    let lowered = input.to_ascii_lowercase();
    let (url, prefix_len) = if lowered.starts_with("https://") {
        (input.to_string(), "https://".len())
    } else if lowered.starts_with("http://") {
        (input.to_string(), "http://".len())
    } else if input.contains("://") {
        // Unknown scheme: leave it to the vault CLI to interpret.
        return input.to_string();
    } else {
        (format!("https://{}", input), "https://".len())
    };

    let rest = &url[prefix_len..];
    let host_end = rest.find('/').unwrap_or(rest.len());
    let authority = &rest[..host_end];
    // Skip the colons inside an IPv6 literal such as `[::1]`.
    let port_region = match authority.rfind(']') {
        Some(idx) => &authority[idx..],
        None => authority,
    };
    if port_region.contains(':') {
        return url;
    }

    let (head, tail) = url.split_at(prefix_len + host_end);
    format!("{}:8200{}", head, tail)
}
106
/// Produces a short, log-safe summary of vault CLI stderr.
///
/// Lines mentioning tokens, secrets, `X-Vault-*` headers, cookies or
/// authorization (case-insensitive) are dropped. The remaining lines are joined
/// with spaces and trimmed.
///
/// If nothing is left, a generic message is returned instead. Summaries longer
/// than 200 characters are cut to 200 characters and suffixed with `...`.
pub fn scrub_vault_stderr(raw: &str) -> String {
    const SENSITIVE: [&str; 5] = ["token", "secret", "x-vault-", "cookie", "authorization"];

    let kept: Vec<&str> = raw
        .lines()
        .filter(|line| {
            let lower = line.to_ascii_lowercase();
            SENSITIVE.iter().all(|needle| !lower.contains(needle))
        })
        .collect();
    let joined = kept.join(" ");
    let summary = joined.trim();

    if summary.is_empty() {
        return "Vault SSH signing failed. Check vault status and policy".to_string();
    }
    // The 201st character exists only when the summary is longer than 200 chars.
    match summary.char_indices().nth(200) {
        Some((cut, _)) => format!("{}...", &summary[..cut]),
        None => summary.to_string(),
    }
}
133
134pub fn cert_path_for(alias: &str) -> Result<PathBuf> {
136 anyhow::ensure!(
137 !alias.is_empty()
138 && !alias.contains('/')
139 && !alias.contains('\\')
140 && !alias.contains(':')
141 && !alias.contains('\0')
142 && !alias.contains(".."),
143 "Invalid alias for cert path: '{}'",
144 alias
145 );
146 let dir = dirs::home_dir()
147 .context("Could not determine home directory")?
148 .join(".purple/certs");
149 Ok(dir.join(format!("{}-cert.pub", alias)))
150}
151
152pub fn resolve_cert_path(alias: &str, certificate_file: &str) -> Result<PathBuf> {
155 if !certificate_file.is_empty() {
156 let expanded = if let Some(rest) = certificate_file.strip_prefix("~/") {
157 if let Some(home) = dirs::home_dir() {
158 home.join(rest)
159 } else {
160 PathBuf::from(certificate_file)
161 }
162 } else {
163 PathBuf::from(certificate_file)
164 };
165 Ok(expanded)
166 } else {
167 cert_path_for(alias)
168 }
169}
170
171pub fn sign_certificate(
181 role: &str,
182 pubkey_path: &Path,
183 alias: &str,
184 vault_addr: Option<&str>,
185) -> Result<SignResult> {
186 if !pubkey_path.exists() {
187 anyhow::bail!(
188 "Public key not found: {}. Set IdentityFile on the host or ensure ~/.ssh/id_ed25519.pub exists.",
189 pubkey_path.display()
190 );
191 }
192
193 if !is_valid_role(role) {
194 anyhow::bail!("Invalid Vault SSH role: '{}'", role);
195 }
196
197 let cert_dest = cert_path_for(alias)?;
198
199 if let Some(parent) = cert_dest.parent() {
200 std::fs::create_dir_all(parent)
201 .with_context(|| format!("Failed to create {}", parent.display()))?;
202 }
203
204 let pubkey_str = pubkey_path.to_str().context(
208 "public key path contains non-UTF8 bytes; vault CLI requires a valid UTF-8 path",
209 )?;
210 if pubkey_str.contains('=') {
217 anyhow::bail!(
218 "Public key path '{}' contains '=' which is not supported by the Vault CLI argument format. Rename the key file or directory.",
219 pubkey_str
220 );
221 }
222 let pubkey_arg = format!("public_key=@{}", pubkey_str);
223 debug!(
224 "[external] Vault sign request: addr={} role={}",
225 vault_addr.unwrap_or("<env>"),
226 role
227 );
228 let mut cmd = Command::new("vault");
229 cmd.args(["write", "-field=signed_key", role, &pubkey_arg]);
230 if let Some(addr) = vault_addr {
237 anyhow::ensure!(
238 is_valid_vault_addr(addr),
239 "Invalid VAULT_ADDR '{}' for role '{}'. Check the Vault SSH Address field.",
240 addr,
241 role
242 );
243 cmd.env("VAULT_ADDR", addr);
244 }
245 let mut child = cmd
246 .stdout(std::process::Stdio::piped())
247 .stderr(std::process::Stdio::piped())
248 .spawn()
249 .context("Failed to run vault CLI. Is vault installed and in PATH?")?;
250
251 let stdout_handle = child.stdout.take();
255 let stderr_handle = child.stderr.take();
256 let stdout_thread = std::thread::spawn(move || -> Vec<u8> {
257 let mut buf = Vec::new();
258 if let Some(mut h) = stdout_handle {
259 let _ = std::io::Read::read_to_end(&mut h, &mut buf);
260 }
261 buf
262 });
263 let stderr_thread = std::thread::spawn(move || -> Vec<u8> {
264 let mut buf = Vec::new();
265 if let Some(mut h) = stderr_handle {
266 let _ = std::io::Read::read_to_end(&mut h, &mut buf);
267 }
268 buf
269 });
270
271 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
275 let status = loop {
276 match child.try_wait() {
277 Ok(Some(s)) => break s,
278 Ok(None) => {
279 if std::time::Instant::now() >= deadline {
280 let _ = child.kill();
281 let _ = child.wait();
282 error!(
287 "[external] Vault unreachable: {}: timed out after 30s",
288 vault_addr.unwrap_or("<env>")
289 );
290 anyhow::bail!("Vault SSH timed out. Server unreachable.");
291 }
292 std::thread::sleep(std::time::Duration::from_millis(100));
293 }
294 Err(e) => {
295 let _ = child.kill();
296 let _ = child.wait();
297 anyhow::bail!("Failed to wait for vault CLI: {}", e);
298 }
299 }
300 };
301
302 let stdout_bytes = stdout_thread.join().unwrap_or_default();
303 let stderr_bytes = stderr_thread.join().unwrap_or_default();
304 let output = std::process::Output {
305 status,
306 stdout: stdout_bytes,
307 stderr: stderr_bytes,
308 };
309
310 if !output.status.success() {
311 let stderr = String::from_utf8_lossy(&output.stderr);
312 if stderr.contains("permission denied") || stderr.contains("403") {
313 error!(
314 "[external] Vault auth failed: permission denied (role={} addr={})",
315 role,
316 vault_addr.unwrap_or("<env>")
317 );
318 anyhow::bail!("Vault SSH permission denied. Check token and policy.");
319 }
320 if stderr.contains("missing client token") || stderr.contains("token expired") {
321 error!(
322 "[external] Vault auth failed: token missing or expired (role={} addr={})",
323 role,
324 vault_addr.unwrap_or("<env>")
325 );
326 anyhow::bail!("Vault SSH token missing or expired. Run `vault login`.");
327 }
328 if stderr.contains("connection refused") {
331 error!(
332 "[external] Vault unreachable: {}: connection refused",
333 vault_addr.unwrap_or("<env>")
334 );
335 anyhow::bail!("Vault SSH connection refused.");
336 }
337 if stderr.contains("i/o timeout") || stderr.contains("dial tcp") {
338 error!(
339 "[external] Vault unreachable: {}: connection timed out",
340 vault_addr.unwrap_or("<env>")
341 );
342 anyhow::bail!("Vault SSH connection timed out.");
343 }
344 if stderr.contains("no such host") {
345 error!(
346 "[external] Vault unreachable: {}: no such host",
347 vault_addr.unwrap_or("<env>")
348 );
349 anyhow::bail!("Vault SSH host not found.");
350 }
351 if stderr.contains("server gave HTTP response to HTTPS client") {
352 error!(
353 "[external] Vault unreachable: {}: server returned HTTP on HTTPS connection",
354 vault_addr.unwrap_or("<env>")
355 );
356 anyhow::bail!("Vault SSH server uses HTTP, not HTTPS. Set address to http://.");
357 }
358 if stderr.contains("certificate signed by unknown authority")
359 || stderr.contains("tls:")
360 || stderr.contains("x509:")
361 {
362 error!(
363 "[external] Vault unreachable: {}: TLS error",
364 vault_addr.unwrap_or("<env>")
365 );
366 anyhow::bail!("Vault SSH TLS error. Check certificate or use http://.");
367 }
368 error!(
369 "[external] Vault SSH signing failed: {}",
370 scrub_vault_stderr(&stderr)
371 );
372 anyhow::bail!("Vault SSH failed: {}", scrub_vault_stderr(&stderr));
373 }
374
375 let signed_key = String::from_utf8_lossy(&output.stdout).trim().to_string();
376 if signed_key.is_empty() {
377 anyhow::bail!("Vault returned empty certificate for role '{}'", role);
378 }
379
380 crate::fs_util::atomic_write(&cert_dest, signed_key.as_bytes())
381 .with_context(|| format!("Failed to write certificate to {}", cert_dest.display()))?;
382
383 info!("Vault SSH certificate signed for {}", alias);
384 Ok(SignResult {
385 cert_path: cert_dest,
386 })
387}
388
389pub fn check_cert_validity(cert_path: &Path) -> CertStatus {
397 if !cert_path.exists() {
398 return CertStatus::Missing;
399 }
400
401 let output = match Command::new("ssh-keygen")
402 .args(["-L", "-f"])
403 .arg(cert_path)
404 .output()
405 {
406 Ok(o) => o,
407 Err(e) => return CertStatus::Invalid(format!("Failed to run ssh-keygen: {}", e)),
408 };
409
410 if !output.status.success() {
411 return CertStatus::Invalid("ssh-keygen could not read certificate".to_string());
412 }
413
414 let stdout = String::from_utf8_lossy(&output.stdout);
415
416 for line in stdout.lines() {
418 let t = line.trim();
419 if t == "Valid: forever" || t.starts_with("Valid: from ") && t.ends_with(" to forever") {
420 return CertStatus::Valid {
421 expires_at: i64::MAX,
422 remaining_secs: i64::MAX,
423 total_secs: i64::MAX,
424 };
425 }
426 }
427
428 for line in stdout.lines() {
429 if let Some((from, to)) = parse_valid_line(line) {
430 let ttl = to - from; if ttl <= 0 {
435 return CertStatus::Invalid(
436 "certificate has non-positive validity window".to_string(),
437 );
438 }
439
440 let signed_at = match std::fs::metadata(cert_path)
442 .and_then(|m| m.modified())
443 .ok()
444 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
445 {
446 Some(d) => d.as_secs() as i64,
447 None => {
448 return CertStatus::Expired;
450 }
451 };
452
453 let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
454 Ok(d) => d.as_secs() as i64,
455 Err(_) => {
456 return CertStatus::Invalid("system clock before unix epoch".to_string());
457 }
458 };
459
460 let elapsed = now - signed_at;
461 let remaining = ttl - elapsed;
462
463 if remaining <= 0 {
464 return CertStatus::Expired;
465 }
466 let expires_at = now + remaining;
467 return CertStatus::Valid {
468 expires_at,
469 remaining_secs: remaining,
470 total_secs: ttl,
471 };
472 }
473 }
474
475 CertStatus::Invalid("No Valid: line found in certificate".to_string())
476}
477
478fn parse_valid_line(line: &str) -> Option<(i64, i64)> {
480 let trimmed = line.trim();
481 let rest = trimmed.strip_prefix("Valid:")?;
482 let rest = rest.trim();
483 let rest = rest.strip_prefix("from ")?;
484 let (from_str, rest) = rest.split_once(" to ")?;
485 let to_str = rest.trim();
486
487 let from = parse_ssh_datetime(from_str)?;
488 let to = parse_ssh_datetime(to_str)?;
489 Some((from, to))
490}
491
/// Parses a `YYYY-MM-DDTHH:MM:SS` timestamp (surrounding whitespace allowed,
/// trailing text after the 19th byte ignored) into seconds since the Unix epoch.
///
/// No timezone offset is applied; the fields are interpreted as if they were UTC.
/// Months and time fields are range-checked, and days are checked only against
/// 1..=31. Returns `None` on any format or range error.
fn parse_ssh_datetime(s: &str) -> Option<i64> {
    let text = s.trim();
    if text.len() < 19 {
        return None;
    }

    let bytes = text.as_bytes();
    let separators = [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':')];
    if separators.iter().any(|&(idx, want)| bytes[idx] != want) {
        return None;
    }

    let field = |start: usize, end: usize| -> Option<i64> { text.get(start..end)?.parse().ok() };
    let year = field(0, 4)?;
    let month = field(5, 7)?;
    let day = field(8, 10)?;
    let hour = field(11, 13)?;
    let minute = field(14, 16)?;
    let second = field(17, 19)?;

    let date_ok = (1..=12).contains(&month) && (1..=31).contains(&day);
    let time_ok = (0..=23).contains(&hour) && (0..=59).contains(&minute) && (0..=59).contains(&second);
    if !(date_ok && time_ok) {
        return None;
    }

    // Days-from-civil (H. Hinnant): shift the year to start in March so the leap
    // day is last, then count whole 400-year eras plus the day within the era.
    let (shifted_year, shifted_month) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = shifted_year.div_euclid(400);
    let year_of_era = shifted_year.rem_euclid(400);
    let day_of_year = (153 * shifted_month + 2) / 5 + day - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    let days_since_epoch = era * 146097 + day_of_era - 719468;

    Some(days_since_epoch * 86400 + hour * 3600 + minute * 60 + second)
}
540
541pub fn needs_renewal(status: &CertStatus) -> bool {
548 match status {
549 CertStatus::Missing | CertStatus::Expired | CertStatus::Invalid(_) => true,
550 CertStatus::Valid {
551 remaining_secs,
552 total_secs,
553 ..
554 } => {
555 let threshold = if *total_secs > 0 && *total_secs <= RENEWAL_THRESHOLD_SECS {
556 *total_secs / 2
557 } else {
558 RENEWAL_THRESHOLD_SECS
559 };
560 *remaining_secs < threshold
561 }
562 }
563}
564
565pub fn ensure_cert(
568 role: &str,
569 pubkey_path: &Path,
570 alias: &str,
571 certificate_file: &str,
572 vault_addr: Option<&str>,
573) -> Result<PathBuf> {
574 let check_path = resolve_cert_path(alias, certificate_file)?;
575 let status = check_cert_validity(&check_path);
576
577 if !needs_renewal(&status) {
578 info!("Vault SSH certificate cache hit for {}", alias);
579 return Ok(check_path);
580 }
581
582 let result = sign_certificate(role, pubkey_path, alias, vault_addr)?;
583 Ok(result.cert_path)
584}
585
586pub fn resolve_pubkey_path(identity_file: &str) -> Result<PathBuf> {
593 let home = dirs::home_dir().context("Could not determine home directory")?;
594 let fallback = home.join(".ssh/id_ed25519.pub");
595
596 if identity_file.is_empty() {
597 return Ok(fallback);
598 }
599
600 let expanded = if let Some(rest) = identity_file.strip_prefix("~/") {
601 home.join(rest)
602 } else {
603 PathBuf::from(identity_file)
604 };
605
606 let canonical_home = match std::fs::canonicalize(&home) {
612 Ok(p) => p,
613 Err(_) => return Ok(fallback),
614 };
615 if expanded.exists() {
616 match std::fs::canonicalize(&expanded) {
617 Ok(canonical) if canonical.starts_with(&canonical_home) => {}
618 _ => return Ok(fallback),
619 }
620 } else if !expanded.starts_with(&home) {
621 return Ok(fallback);
622 }
623
624 if expanded.extension().is_some_and(|ext| ext == "pub") {
625 Ok(expanded)
626 } else {
627 let mut s = expanded.into_os_string();
628 s.push(".pub");
629 Ok(PathBuf::from(s))
630 }
631}
632
633pub fn resolve_vault_role(
636 host_vault_ssh: Option<&str>,
637 provider_name: Option<&str>,
638 provider_config: &crate::providers::config::ProviderConfig,
639) -> Option<String> {
640 if let Some(role) = host_vault_ssh {
641 if !role.is_empty() {
642 return Some(role.to_string());
643 }
644 }
645
646 if let Some(name) = provider_name {
647 if let Some(section) = provider_config.section(name) {
648 if !section.vault_role.is_empty() {
649 return Some(section.vault_role.clone());
650 }
651 }
652 }
653
654 None
655}
656
657pub fn resolve_vault_addr(
670 host_vault_addr: Option<&str>,
671 provider_name: Option<&str>,
672 provider_config: &crate::providers::config::ProviderConfig,
673) -> Option<String> {
674 if let Some(addr) = host_vault_addr {
675 let trimmed = addr.trim();
676 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
677 return Some(normalize_vault_addr(trimmed));
678 }
679 }
680
681 if let Some(name) = provider_name {
682 if let Some(section) = provider_config.section(name) {
683 let trimmed = section.vault_addr.trim();
684 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
685 return Some(normalize_vault_addr(trimmed));
686 }
687 }
688 }
689
690 None
691}
692
/// Renders a remaining lifetime for display.
///
/// Returns `"expired"` for non-positive input, `"<m>m"` when under an hour, and
/// `"<h>h <m>m"` otherwise. Leftover seconds are truncated.
pub fn format_remaining(remaining_secs: i64) -> String {
    if remaining_secs <= 0 {
        return "expired".to_string();
    }
    match (remaining_secs / 3600, remaining_secs % 3600 / 60) {
        (0, mins) => format!("{}m", mins),
        (hours, mins) => format!("{}h {}m", hours, mins),
    }
}
706
707#[cfg(test)]
708#[path = "vault_ssh_tests.rs"]
709mod tests;