1use anyhow::{Context, Result};
2use log::{debug, error, info};
3use std::path::{Path, PathBuf};
4use std::process::Command;
5
/// Result of a successful [`sign_certificate`] call.
#[derive(Debug)]
pub struct SignResult {
    /// Path the freshly signed certificate was written to
    /// (always `cert_path_for(alias)` for the alias that was signed).
    pub cert_path: PathBuf,
}
11
/// Freshness of an on-disk SSH certificate as reported by [`check_cert_validity`].
#[derive(Debug, Clone, PartialEq)]
pub enum CertStatus {
    /// The certificate is currently usable.
    ///
    /// For certificates without an end time every field is `i64::MAX`.
    Valid {
        /// Estimated expiry as Unix seconds (`now + remaining_secs` at check time).
        expires_at: i64,
        /// Seconds left before expiry.
        remaining_secs: i64,
        /// Length of the certificate's validity window in seconds.
        total_secs: i64,
    },
    /// The validity window has already elapsed.
    Expired,
    /// No certificate file exists at the checked path.
    Missing,
    /// The certificate could not be inspected; the payload explains why.
    Invalid(String),
}
26
/// Re-sign once fewer than this many seconds remain (see [`needs_renewal`],
/// which halves the certificate lifetime instead for very short-lived certs).
pub const RENEWAL_THRESHOLD_SECS: i64 = 5 * 60;

/// Cache lifetime, in seconds, for a computed certificate status.
/// NOTE(review): not referenced in this module; presumably consumed by callers — confirm.
pub const CERT_STATUS_CACHE_TTL_SECS: u64 = 5 * 60;

/// Back-off, in seconds, after a certificate error.
/// NOTE(review): not referenced in this module; presumably consumed by callers — confirm.
pub const CERT_ERROR_BACKOFF_SECS: u64 = 30;
42
/// Returns `true` if `s` is an acceptable Vault SSH role path.
///
/// A role must be 1..=128 bytes of ASCII alphanumerics, `/`, `_` or `-`.
/// It must not start with `-`: the role is passed to the `vault` CLI as a
/// positional argument, so a leading dash would be parsed as a CLI flag
/// (e.g. `-tls-skip-verify`) instead of a path.
pub fn is_valid_role(s: &str) -> bool {
    !s.is_empty()
        && s.len() <= 128
        && !s.starts_with('-')
        && s
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '/' || c == '_' || c == '-')
}
51
/// Cheap sanity check for a Vault address before it is placed in `VAULT_ADDR`.
///
/// After trimming, the address must be non-empty, at most 512 bytes, and free of
/// control and whitespace characters. The URL syntax itself is not validated.
pub fn is_valid_vault_addr(s: &str) -> bool {
    let candidate = s.trim();
    if candidate.is_empty() || candidate.len() > 512 {
        return false;
    }
    candidate
        .chars()
        .all(|ch| !ch.is_control() && !ch.is_whitespace())
}
64
/// Canonicalizes a Vault address so that it always carries a scheme and a port.
///
/// - A bare host (no scheme) gets `https://` and port 8200, Vault's conventional
///   listener port.
/// - `http://` / `https://` addresses without a port get 80 / 443. The scheme match
///   is case-insensitive, and the original casing is preserved.
/// - Any other `scheme://` address (e.g. `unix://`) is returned trimmed but
///   otherwise untouched.
/// - An existing port is kept. For bracketed IPv6 hosts only a `:` after the `]`
///   counts as a port.
///
/// The port is inserted right after the authority. The authority ends at the first
/// `/`, `?` or `#`, so paths, queries and fragments stay after the port. An empty
/// authority is left as-is rather than producing a bare `:port`.
pub fn normalize_vault_addr(s: &str) -> String {
    let trimmed = s.trim();
    let lower = trimmed.to_ascii_lowercase();

    // Lowercasing is ASCII-only, so byte offsets in `lower` match `trimmed`.
    let (with_scheme, scheme_len, default_port) = if lower.starts_with("https://") {
        (trimmed.to_string(), 8, 443)
    } else if lower.starts_with("http://") {
        (trimmed.to_string(), 7, 80)
    } else if trimmed.contains("://") {
        return trimmed.to_string();
    } else {
        (format!("https://{}", trimmed), 8, 8200)
    };

    let after_scheme = &with_scheme[scheme_len..];
    let authority_len = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let authority = &after_scheme[..authority_len];

    let has_port = match authority.rfind(']') {
        // IPv6 literal: colons inside the brackets are part of the address.
        Some(bracket_end) => authority[bracket_end..].contains(':'),
        None => authority.contains(':'),
    };
    if has_port || authority.is_empty() {
        return with_scheme;
    }

    let split_at = scheme_len + authority_len;
    format!(
        "{}:{}{}",
        &with_scheme[..split_at],
        default_port,
        &with_scheme[split_at..]
    )
}
116
/// Turns raw `vault` CLI stderr into a short, log-safe summary.
///
/// Lines mentioning tokens, secrets, `X-Vault-*` headers, cookies or authorization
/// are dropped (case-insensitive). The remaining lines are joined with single
/// spaces, trimmed, and capped at 200 characters (with a trailing `...` when cut).
/// A generic message is returned if nothing survives.
pub fn scrub_vault_stderr(raw: &str) -> String {
    const SENSITIVE: [&str; 5] = ["token", "secret", "x-vault-", "cookie", "authorization"];
    const MAX_CHARS: usize = 200;

    let kept: Vec<&str> = raw
        .lines()
        .filter(|line| {
            let folded = line.to_ascii_lowercase();
            !SENSITIVE.iter().any(|needle| folded.contains(needle))
        })
        .collect();
    let joined = kept.join(" ");
    let summary = joined.trim();

    if summary.is_empty() {
        return "Vault SSH signing failed. Check vault status and policy".to_string();
    }
    // `nth(MAX_CHARS)` exists only when there are more than MAX_CHARS chars;
    // its byte offset is exactly where the first MAX_CHARS chars end.
    match summary.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &summary[..cut]),
        None => summary.to_string(),
    }
}
143
144pub fn cert_path_for(alias: &str) -> Result<PathBuf> {
146 anyhow::ensure!(
147 !alias.is_empty()
148 && !alias.contains('/')
149 && !alias.contains('\\')
150 && !alias.contains(':')
151 && !alias.contains('\0')
152 && !alias.contains(".."),
153 "Invalid alias for cert path: '{}'",
154 alias
155 );
156 let dir = dirs::home_dir()
157 .context("Could not determine home directory")?
158 .join(".purple/certs");
159 Ok(dir.join(format!("{}-cert.pub", alias)))
160}
161
162pub fn resolve_cert_path(alias: &str, certificate_file: &str) -> Result<PathBuf> {
165 if !certificate_file.is_empty() {
166 let expanded = if let Some(rest) = certificate_file.strip_prefix("~/") {
167 if let Some(home) = dirs::home_dir() {
168 home.join(rest)
169 } else {
170 PathBuf::from(certificate_file)
171 }
172 } else {
173 PathBuf::from(certificate_file)
174 };
175 Ok(expanded)
176 } else {
177 cert_path_for(alias)
178 }
179}
180
181pub fn sign_certificate(
191 role: &str,
192 pubkey_path: &Path,
193 alias: &str,
194 vault_addr: Option<&str>,
195) -> Result<SignResult> {
196 if !pubkey_path.exists() {
197 anyhow::bail!(
198 "Public key not found: {}. Set IdentityFile on the host or ensure ~/.ssh/id_ed25519.pub exists.",
199 pubkey_path.display()
200 );
201 }
202
203 if !is_valid_role(role) {
204 anyhow::bail!("Invalid Vault SSH role: '{}'", role);
205 }
206
207 let cert_dest = cert_path_for(alias)?;
208
209 if let Some(parent) = cert_dest.parent() {
210 std::fs::create_dir_all(parent)
211 .with_context(|| format!("Failed to create {}", parent.display()))?;
212 }
213
214 let pubkey_str = pubkey_path.to_str().context(
218 "public key path contains non-UTF8 bytes; vault CLI requires a valid UTF-8 path",
219 )?;
220 if pubkey_str.contains('=') {
227 anyhow::bail!(
228 "Public key path '{}' contains '=' which is not supported by the Vault CLI argument format. Rename the key file or directory.",
229 pubkey_str
230 );
231 }
232 let pubkey_arg = format!("public_key=@{}", pubkey_str);
233 debug!(
234 "[external] Vault sign request: addr={} role={}",
235 vault_addr.unwrap_or("<env>"),
236 role
237 );
238 let mut cmd = Command::new("vault");
239 cmd.args(["write", "-field=signed_key", role, &pubkey_arg]);
240 if let Some(addr) = vault_addr {
247 anyhow::ensure!(
248 is_valid_vault_addr(addr),
249 "Invalid VAULT_ADDR '{}' for role '{}'. Check the Vault SSH Address field.",
250 addr,
251 role
252 );
253 cmd.env("VAULT_ADDR", addr);
254 }
255 let mut child = cmd
256 .stdout(std::process::Stdio::piped())
257 .stderr(std::process::Stdio::piped())
258 .spawn()
259 .context("Failed to run vault CLI. Is vault installed and in PATH?")?;
260
261 let stdout_handle = child.stdout.take();
265 let stderr_handle = child.stderr.take();
266 let stdout_thread = std::thread::spawn(move || -> Vec<u8> {
267 let mut buf = Vec::new();
268 if let Some(mut h) = stdout_handle {
269 if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
270 log::warn!("[external] Failed to read vault stdout pipe: {e}");
271 }
272 }
273 buf
274 });
275 let stderr_thread = std::thread::spawn(move || -> Vec<u8> {
276 let mut buf = Vec::new();
277 if let Some(mut h) = stderr_handle {
278 if let Err(e) = std::io::Read::read_to_end(&mut h, &mut buf) {
279 log::warn!("[external] Failed to read vault stderr pipe: {e}");
280 }
281 }
282 buf
283 });
284
285 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
289 let status = loop {
290 match child.try_wait() {
291 Ok(Some(s)) => break s,
292 Ok(None) => {
293 if std::time::Instant::now() >= deadline {
294 let _ = child.kill();
295 let _ = child.wait();
296 error!(
301 "[external] Vault unreachable: {}: timed out after 30s",
302 vault_addr.unwrap_or("<env>")
303 );
304 anyhow::bail!("Vault SSH timed out. Server unreachable.");
305 }
306 std::thread::sleep(std::time::Duration::from_millis(100));
307 }
308 Err(e) => {
309 let _ = child.kill();
310 let _ = child.wait();
311 anyhow::bail!("Failed to wait for vault CLI: {}", e);
312 }
313 }
314 };
315
316 let stdout_bytes = stdout_thread.join().unwrap_or_default();
317 let stderr_bytes = stderr_thread.join().unwrap_or_default();
318 let output = std::process::Output {
319 status,
320 stdout: stdout_bytes,
321 stderr: stderr_bytes,
322 };
323
324 if !output.status.success() {
325 let stderr = String::from_utf8_lossy(&output.stderr);
326 if stderr.contains("permission denied") || stderr.contains("403") {
327 error!(
328 "[external] Vault auth failed: permission denied (role={} addr={})",
329 role,
330 vault_addr.unwrap_or("<env>")
331 );
332 anyhow::bail!("Vault SSH permission denied. Check token and policy.");
333 }
334 if stderr.contains("missing client token") || stderr.contains("token expired") {
335 error!(
336 "[external] Vault auth failed: token missing or expired (role={} addr={})",
337 role,
338 vault_addr.unwrap_or("<env>")
339 );
340 anyhow::bail!("Vault SSH token missing or expired. Run `vault login`.");
341 }
342 if stderr.contains("connection refused") {
345 error!(
346 "[external] Vault unreachable: {}: connection refused",
347 vault_addr.unwrap_or("<env>")
348 );
349 anyhow::bail!("Vault SSH connection refused.");
350 }
351 if stderr.contains("i/o timeout") || stderr.contains("dial tcp") {
352 error!(
353 "[external] Vault unreachable: {}: connection timed out",
354 vault_addr.unwrap_or("<env>")
355 );
356 anyhow::bail!("Vault SSH connection timed out.");
357 }
358 if stderr.contains("no such host") {
359 error!(
360 "[external] Vault unreachable: {}: no such host",
361 vault_addr.unwrap_or("<env>")
362 );
363 anyhow::bail!("Vault SSH host not found.");
364 }
365 if stderr.contains("server gave HTTP response to HTTPS client") {
366 error!(
367 "[external] Vault unreachable: {}: server returned HTTP on HTTPS connection",
368 vault_addr.unwrap_or("<env>")
369 );
370 anyhow::bail!("Vault SSH server uses HTTP, not HTTPS. Set address to http://.");
371 }
372 if stderr.contains("certificate signed by unknown authority")
373 || stderr.contains("tls:")
374 || stderr.contains("x509:")
375 {
376 error!(
377 "[external] Vault unreachable: {}: TLS error",
378 vault_addr.unwrap_or("<env>")
379 );
380 anyhow::bail!("Vault SSH TLS error. Check certificate or use http://.");
381 }
382 error!(
383 "[external] Vault SSH signing failed: {}",
384 scrub_vault_stderr(&stderr)
385 );
386 anyhow::bail!("Vault SSH failed: {}", scrub_vault_stderr(&stderr));
387 }
388
389 let signed_key = String::from_utf8_lossy(&output.stdout).trim().to_string();
390 if signed_key.is_empty() {
391 anyhow::bail!("Vault returned empty certificate for role '{}'", role);
392 }
393
394 crate::fs_util::atomic_write(&cert_dest, signed_key.as_bytes())
395 .with_context(|| format!("Failed to write certificate to {}", cert_dest.display()))?;
396
397 info!("Vault SSH certificate signed for {}", alias);
398 Ok(SignResult {
399 cert_path: cert_dest,
400 })
401}
402
403pub fn check_cert_validity(cert_path: &Path) -> CertStatus {
411 if !cert_path.exists() {
412 return CertStatus::Missing;
413 }
414
415 let output = match Command::new("ssh-keygen")
416 .args(["-L", "-f"])
417 .arg(cert_path)
418 .output()
419 {
420 Ok(o) => o,
421 Err(e) => return CertStatus::Invalid(format!("Failed to run ssh-keygen: {}", e)),
422 };
423
424 if !output.status.success() {
425 return CertStatus::Invalid("ssh-keygen could not read certificate".to_string());
426 }
427
428 let stdout = String::from_utf8_lossy(&output.stdout);
429
430 for line in stdout.lines() {
432 let t = line.trim();
433 if t == "Valid: forever" || t.starts_with("Valid: from ") && t.ends_with(" to forever") {
434 return CertStatus::Valid {
435 expires_at: i64::MAX,
436 remaining_secs: i64::MAX,
437 total_secs: i64::MAX,
438 };
439 }
440 }
441
442 for line in stdout.lines() {
443 if let Some((from, to)) = parse_valid_line(line) {
444 let ttl = to - from; if ttl <= 0 {
449 return CertStatus::Invalid(
450 "certificate has non-positive validity window".to_string(),
451 );
452 }
453
454 let signed_at = match std::fs::metadata(cert_path)
456 .and_then(|m| m.modified())
457 .ok()
458 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
459 {
460 Some(d) => d.as_secs() as i64,
461 None => {
462 return CertStatus::Expired;
464 }
465 };
466
467 let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
468 Ok(d) => d.as_secs() as i64,
469 Err(_) => {
470 return CertStatus::Invalid("system clock before unix epoch".to_string());
471 }
472 };
473
474 let elapsed = now - signed_at;
475 let remaining = ttl - elapsed;
476
477 if remaining <= 0 {
478 return CertStatus::Expired;
479 }
480 let expires_at = now + remaining;
481 return CertStatus::Valid {
482 expires_at,
483 remaining_secs: remaining,
484 total_secs: ttl,
485 };
486 }
487 }
488
489 CertStatus::Invalid("No Valid: line found in certificate".to_string())
490}
491
492fn parse_valid_line(line: &str) -> Option<(i64, i64)> {
494 let trimmed = line.trim();
495 let rest = trimmed.strip_prefix("Valid:")?;
496 let rest = rest.trim();
497 let rest = rest.strip_prefix("from ")?;
498 let (from_str, rest) = rest.split_once(" to ")?;
499 let to_str = rest.trim();
500
501 let from = parse_ssh_datetime(from_str)?;
502 let to = parse_ssh_datetime(to_str)?;
503 Some((from, to))
504}
505
/// Parses a `YYYY-MM-DDTHH:MM:SS` timestamp (trailing characters ignored) into
/// seconds since the Unix epoch, with no timezone adjustment.
///
/// Month, day, hour, minute and second ranges are checked, but the day is not
/// checked against the length of the month. The date conversion is the
/// proleptic-Gregorian "days from civil" algorithm.
fn parse_ssh_datetime(s: &str) -> Option<i64> {
    let s = s.trim();
    if s.len() < 19 {
        return None;
    }

    let bytes = s.as_bytes();
    let separators = [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':')];
    if !separators.iter().all(|&(idx, want)| bytes[idx] == want) {
        return None;
    }

    let field = |range: std::ops::Range<usize>| -> Option<i64> { s.get(range)?.parse().ok() };
    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    let date_ok = (1..=12).contains(&month) && (1..=31).contains(&day);
    let time_ok = (0..=23).contains(&hour) && (0..=59).contains(&min) && (0..=59).contains(&sec);
    if !(date_ok && time_ok) {
        return None;
    }

    // Shift to a March-based year so the leap day falls at the end of the year.
    let (shifted_year, march_month) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = shifted_year.div_euclid(400);
    let year_of_era = shifted_year.rem_euclid(400);
    let day_of_year = (153 * march_month + 2) / 5 + day - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    let days_since_epoch = era * 146_097 + day_of_era - 719_468;

    Some(days_since_epoch * 86_400 + hour * 3_600 + min * 60 + sec)
}
554
555pub fn needs_renewal(status: &CertStatus) -> bool {
562 match status {
563 CertStatus::Missing | CertStatus::Expired | CertStatus::Invalid(_) => true,
564 CertStatus::Valid {
565 remaining_secs,
566 total_secs,
567 ..
568 } => {
569 let threshold = if *total_secs > 0 && *total_secs <= RENEWAL_THRESHOLD_SECS {
570 *total_secs / 2
571 } else {
572 RENEWAL_THRESHOLD_SECS
573 };
574 *remaining_secs < threshold
575 }
576 }
577}
578
579pub fn ensure_cert(
582 role: &str,
583 pubkey_path: &Path,
584 alias: &str,
585 certificate_file: &str,
586 vault_addr: Option<&str>,
587) -> Result<PathBuf> {
588 let check_path = resolve_cert_path(alias, certificate_file)?;
589 let status = check_cert_validity(&check_path);
590
591 if !needs_renewal(&status) {
592 info!("Vault SSH certificate cache hit for {}", alias);
593 return Ok(check_path);
594 }
595
596 let result = sign_certificate(role, pubkey_path, alias, vault_addr)?;
597 Ok(result.cert_path)
598}
599
600pub fn resolve_pubkey_path(identity_file: &str) -> Result<PathBuf> {
607 let home = dirs::home_dir().context("Could not determine home directory")?;
608 let fallback = home.join(".ssh/id_ed25519.pub");
609
610 if identity_file.is_empty() {
611 return Ok(fallback);
612 }
613
614 let expanded = if let Some(rest) = identity_file.strip_prefix("~/") {
615 home.join(rest)
616 } else {
617 PathBuf::from(identity_file)
618 };
619
620 let canonical_home = match std::fs::canonicalize(&home) {
626 Ok(p) => p,
627 Err(_) => return Ok(fallback),
628 };
629 if expanded.exists() {
630 match std::fs::canonicalize(&expanded) {
631 Ok(canonical) if canonical.starts_with(&canonical_home) => {}
632 _ => return Ok(fallback),
633 }
634 } else if !expanded.starts_with(&home) {
635 return Ok(fallback);
636 }
637
638 if expanded.extension().is_some_and(|ext| ext == "pub") {
639 Ok(expanded)
640 } else {
641 let mut s = expanded.into_os_string();
642 s.push(".pub");
643 Ok(PathBuf::from(s))
644 }
645}
646
647pub fn resolve_vault_role(
650 host_vault_ssh: Option<&str>,
651 provider_name: Option<&str>,
652 provider_config: &crate::providers::config::ProviderConfig,
653) -> Option<String> {
654 if let Some(role) = host_vault_ssh {
655 if !role.is_empty() {
656 return Some(role.to_string());
657 }
658 }
659
660 if let Some(name) = provider_name {
661 if let Some(section) = provider_config.section(name) {
662 if !section.vault_role.is_empty() {
663 return Some(section.vault_role.clone());
664 }
665 }
666 }
667
668 None
669}
670
671pub fn resolve_vault_addr(
684 host_vault_addr: Option<&str>,
685 provider_name: Option<&str>,
686 provider_config: &crate::providers::config::ProviderConfig,
687) -> Option<String> {
688 if let Some(addr) = host_vault_addr {
689 let trimmed = addr.trim();
690 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
691 return Some(normalize_vault_addr(trimmed));
692 }
693 }
694
695 if let Some(name) = provider_name {
696 if let Some(section) = provider_config.section(name) {
697 let trimmed = section.vault_addr.trim();
698 if !trimmed.is_empty() && is_valid_vault_addr(trimmed) {
699 return Some(normalize_vault_addr(trimmed));
700 }
701 }
702 }
703
704 None
705}
706
/// Formats a remaining lifetime for display.
///
/// Returns `"expired"` for zero or negative input, `"<h>h <m>m"` when at least an
/// hour remains, and `"<m>m"` otherwise. Seconds are truncated, so under a minute
/// shows as `"0m"`.
pub fn format_remaining(remaining_secs: i64) -> String {
    match remaining_secs {
        secs if secs <= 0 => "expired".to_string(),
        secs => {
            let (hours, mins) = (secs / 3600, secs % 3600 / 60);
            if hours == 0 {
                format!("{mins}m")
            } else {
                format!("{hours}h {mins}m")
            }
        }
    }
}
720
// Unit tests for this module live in a separate file, `vault_ssh_tests.rs`,
// which is loaded through the `#[path]` attribute and compiled only for test builds.
#[cfg(test)]
#[path = "vault_ssh_tests.rs"]
mod tests;