use std::collections::HashMap;
use std::io::Write;
use anyhow::{Context, Result};
use colored::Colorize;
use tsafe_cli::cli::{SshAction, SshKeyType};
use tsafe_core::{audit::AuditEntry, errors::SafeError};
use crate::helpers::*;
/// Load the private key stored as secret `key` in `profile` into the running
/// ssh-agent by piping it to `ssh-add -` (so the key material never touches disk).
///
/// # Errors
/// Fails if the vault cannot be opened or the secret is missing, if `ssh-add`
/// cannot be spawned or waited on, if it exits unsuccessfully, or if writing the
/// key to its stdin fails.
pub(crate) fn cmd_ssh_add(profile: &str, key: &str) -> Result<()> {
    let vault = open_vault(profile)?;
    let value = vault.get(key).map_err(|e| match &e {
        SafeError::SecretNotFound { .. } => anyhow::anyhow!("secret '{key}' not found"),
        other => anyhow::anyhow!("{other}"),
    })?;
    let mut child = std::process::Command::new("ssh-add")
        .arg("-")
        .stdin(std::process::Stdio::piped())
        .spawn()
        .context("could not start ssh-add — is OpenSSH installed?")?;
    // Move the pipe out of `child` so it is closed (EOF for ssh-add) as soon as this
    // scope ends, whether or not the write succeeded.
    let write_result = {
        let mut stdin = child
            .stdin
            .take()
            .context("ssh-add stdin was not captured")?;
        stdin.write_all(value.as_bytes())
    };
    // Always reap the child before reporting anything, so a failed write does not
    // leave a zombie behind. If ssh-add itself failed (e.g. it rejected the key and
    // closed its stdin, causing the write to hit EPIPE), its exit status is the more
    // useful error, so it is checked first.
    let status = child.wait().context("failed to wait for ssh-add")?;
    if !status.success() {
        anyhow::bail!("ssh-add exited with status {status}");
    }
    write_result.context("failed to write key to ssh-add stdin")?;
    audit(profile)
        .append(&AuditEntry::success(profile, "ssh-add", Some(key)))
        .ok();
    println!("{} Key '{}' added to ssh-agent", "✓".green(), key);
    Ok(())
}
/// Import an SSH private key file from `path` into the vault of `profile`.
///
/// The secret is stored under `name` when given, otherwise under the file's base
/// name (or `ssh_key` if that is not valid UTF-8). User `tags` are parsed and the
/// entry is always tagged `type=ssh` so `ssh list` and the agent pick it up.
///
/// # Errors
/// Fails if the file cannot be read, does not contain a `PRIVATE KEY` marker,
/// or the vault cannot be opened or written.
pub(crate) fn cmd_ssh_import(
    profile: &str,
    path: &str,
    name: Option<&str>,
    tags: Vec<String>,
) -> Result<()> {
    let contents = std::fs::read_to_string(path)
        .with_context(|| format!("could not read SSH key file: {path}"))?;
    // Cheap sanity check shared with the other ssh commands in this module.
    anyhow::ensure!(
        contents.contains("PRIVATE KEY"),
        "file does not appear to be an SSH private key: {path}"
    );

    let secret_name = match name {
        Some(explicit) => explicit,
        None => std::path::Path::new(path)
            .file_name()
            .and_then(std::ffi::OsStr::to_str)
            .unwrap_or("ssh_key"),
    };

    let mut secret_tags = parse_tags_map(&tags);
    // Inserted last so it always wins over a user-supplied `type` tag.
    secret_tags.insert("type".into(), "ssh".into());

    let mut vault = open_vault(profile)?;
    vault.set(secret_name, &contents, secret_tags)?;
    audit(profile)
        .append(&AuditEntry::success(profile, "ssh-import", Some(secret_name)))
        .ok();
    println!("{} SSH key imported as '{}'", "✓".green(), secret_name);
    Ok(())
}
/// Print the SSH keys stored in `profile`, with their recorded key type.
///
/// A secret counts as an SSH key if it is tagged `type=ssh`; untagged secrets are
/// decrypted and checked for a `PRIVATE KEY` marker as a fallback. The key type
/// comes from the `ssh_key_type` tag, or `unknown` when absent.
pub(crate) fn cmd_ssh_list(profile: &str) -> Result<()> {
    let vault = open_vault(profile)?;

    // Tag check first; only fall back to decrypting the value when the tag is missing.
    let looks_like_ssh = |name: &str| -> bool {
        let tagged = vault.file().secrets.get(name).map_or(false, |entry| {
            entry.tags.get("type").map(String::as_str) == Some("ssh")
        });
        tagged
            || vault
                .get(name)
                .map_or(false, |value| value.contains("PRIVATE KEY"))
    };

    let ssh_keys: Vec<&str> = vault
        .list()
        .iter()
        .copied()
        .filter(|&name| looks_like_ssh(name))
        .collect();

    if ssh_keys.is_empty() {
        println!("No SSH keys found in profile '{profile}'.");
        println!(" Import one with: tsafe ssh-import ~/.ssh/id_ed25519");
        println!(" Or generate: tsafe ssh generate <key-name>");
        return Ok(());
    }

    println!("SSH keys in profile '{profile}':");
    for name in &ssh_keys {
        let key_type = vault
            .file()
            .secrets
            .get(*name)
            .and_then(|entry| entry.tags.get("ssh_key_type"))
            .map_or("unknown", String::as_str);
        println!(" {name} ({key_type})");
    }
    println!(
        "\n {} key(s). Use `tsafe ssh public-key <name>` to extract a public key.",
        ssh_keys.len()
    );
    Ok(())
}
/// Print the OpenSSH-format public key derived from the private key stored as
/// secret `key` in `profile`.
///
/// # Errors
/// Fails if the secret is missing, lacks a `PRIVATE KEY` marker, cannot be parsed
/// as an OpenSSH private key, or its public key cannot be encoded.
pub(crate) fn cmd_ssh_public_key(profile: &str, key: &str) -> Result<()> {
    use ssh_key::PrivateKey;

    let vault = open_vault(profile)?;
    let pem = match vault.get(key) {
        Ok(secret) => secret,
        Err(SafeError::SecretNotFound { .. }) => {
            return Err(anyhow::anyhow!(
                "secret '{key}' not found in profile '{profile}'\n\
                 List SSH keys with: tsafe ssh list"
            ));
        }
        Err(other) => return Err(anyhow::anyhow!("{other}")),
    };

    anyhow::ensure!(
        pem.contains("PRIVATE KEY"),
        "secret '{key}' does not appear to be an SSH private key \
         (missing 'PRIVATE KEY' header)"
    );

    let parsed = PrivateKey::from_openssh(pem.as_bytes())
        .with_context(|| format!("failed to parse SSH private key '{key}'"))?;
    let openssh_pubkey = parsed
        .public_key()
        .to_openssh()
        .context("failed to encode public key")?;
    println!("{openssh_pubkey}");

    audit(profile)
        .append(&AuditEntry::success(profile, "ssh-public-key", Some(key)))
        .ok();
    Ok(())
}
/// Generate a new SSH key pair, store the private key in `profile` as `key`, and
/// print the public key.
///
/// * `key_type` / `bits` — Ed25519, or RSA with `bits` (must be at least 2048;
///   ignored for Ed25519).
/// * `comment` — key comment; defaults to `key`. An explicit comment is also
///   recorded in the `ssh_comment` tag.
/// * `print_pubkey` — the public key is printed either way; when `false` an extra
///   hint explains how to extract it again later.
///
/// The stored secret is tagged `type=ssh` and `ssh_key_type=ed25519|rsa`.
pub(crate) fn cmd_ssh_generate(
    profile: &str,
    key: &str,
    key_type: SshKeyType,
    bits: u32,
    comment: Option<&str>,
    print_pubkey: bool,
) -> Result<()> {
    use ssh_key::{Algorithm, LineEnding, PrivateKey};

    let mut rng = rand::thread_rng();
    // Produce the tag label and the key together so they can never disagree.
    let (type_str, mut private_key) = match key_type {
        SshKeyType::Ed25519 => (
            "ed25519",
            PrivateKey::random(&mut rng, Algorithm::Ed25519)
                .context("failed to generate Ed25519 key")?,
        ),
        SshKeyType::Rsa => {
            use ssh_key::private::{KeypairData, RsaKeypair};
            let bit_size = bits as usize;
            anyhow::ensure!(bit_size >= 2048, "RSA key size must be at least 2048 bits");
            let rsa_keypair =
                RsaKeypair::random(&mut rng, bit_size).context("failed to generate RSA key")?;
            let rsa_key = PrivateKey::new(KeypairData::from(rsa_keypair), "")
                .context("failed to construct RSA private key")?;
            ("rsa", rsa_key)
        }
    };

    private_key.set_comment(comment.unwrap_or(key));
    let pem = private_key
        .to_openssh(LineEnding::LF)
        .context("failed to serialize private key to OpenSSH PEM")?;

    let mut tag_map = HashMap::from([
        ("type".to_string(), "ssh".to_string()),
        ("ssh_key_type".to_string(), type_str.to_string()),
    ]);
    if let Some(explicit_comment) = comment {
        tag_map.insert("ssh_comment".to_string(), explicit_comment.to_string());
    }

    let mut vault = open_vault(profile)?;
    vault.set(key, &pem, tag_map)?;
    audit(profile)
        .append(&AuditEntry::success(profile, "ssh-generate", Some(key)))
        .ok();

    let pubkey = private_key
        .public_key()
        .to_openssh()
        .context("failed to encode public key")?;
    println!(
        "{} Generated {type_str} SSH key pair, private key stored as '{key}'",
        "✓".green()
    );
    println!();
    println!("Public key:");
    println!("{pubkey}");
    println!();
    println!(" Copy the public key above to your authorized_keys or Git hosting service.");
    if !print_pubkey {
        println!(" Extract it again at any time with: tsafe ssh public-key {key}");
    }
    Ok(())
}
/// Print an `~/.ssh/config` snippet that points `IdentityAgent` at the socket
/// named by the `TSAFE_AUTH_SOCK` environment variable.
///
/// `host` is used as the `Host` pattern and defaults to `*` (every host).
pub(crate) fn cmd_ssh_config(host: Option<&str>) -> Result<()> {
    let header = "# ~/.ssh/config snippet — generated by tsafe ssh config";
    let host_pattern = match host {
        Some(pattern) => pattern,
        None => "*",
    };
    println!("{header}");
    println!("Host {host_pattern}");
    // ssh_config(5) treats an IdentityAgent value starting with `$` as the name of
    // an environment variable holding the socket path.
    println!(" IdentityAgent $TSAFE_AUTH_SOCK");
    Ok(())
}
/// Run the tsafe SSH agent for `profile`.
///
/// On Unix this delegates to `cmd_ssh_agent_unix`. On other platforms it prints an
/// explanation to stderr and terminates the process with exit code 1, because the
/// agent requires a Unix domain socket.
pub(crate) fn cmd_ssh_agent(profile: &str, ttl: Option<&str>, sock: Option<&str>) -> Result<()> {
    #[cfg(not(unix))]
    {
        // Silence unused-parameter warnings on platforms without the agent.
        let _ = (profile, ttl, sock);
        eprintln!(
            "tsafe ssh-agent is not supported on Windows (Unix socket required).\n\
             Use OpenSSH's built-in ssh-agent on Windows."
        );
        std::process::exit(1);
    }
    #[cfg(unix)]
    {
        cmd_ssh_agent_unix(profile, ttl, sock)
    }
}
/// Parse a compact duration such as `8h`, `90m`, `1h30m` or `45s`.
///
/// The string is a sequence of components, each a decimal number immediately
/// followed by a unit (`h`, `m` or `s`); component values are summed. Surrounding
/// whitespace is ignored.
///
/// # Errors
/// Returns an error when a unit is not preceded by a number, a unit is unknown,
/// the string ends in a number with no unit, the total overflows `u64` seconds,
/// or the total is zero.
#[cfg(unix)]
fn parse_ttl_duration(s: &str) -> Result<std::time::Duration> {
    let s = s.trim();
    let mut total_secs: u64 = 0;
    let mut num_buf = String::new();
    for ch in s.chars() {
        if ch.is_ascii_digit() {
            num_buf.push(ch);
            continue;
        }
        let n: u64 = num_buf
            .parse()
            .with_context(|| format!("invalid duration component in '{s}'"))?;
        num_buf.clear();
        let unit_secs: u64 = match ch {
            'h' => 3600,
            'm' => 60,
            's' => 1,
            other => anyhow::bail!("unknown duration unit '{other}' in '{s}' (use h/m/s)"),
        };
        // The input is user-supplied: plain `*`/`+=` would panic in debug builds
        // and silently wrap in release builds on absurdly large values.
        total_secs = n
            .checked_mul(unit_secs)
            .and_then(|component| total_secs.checked_add(component))
            .with_context(|| format!("duration '{s}' is too large"))?;
    }
    if !num_buf.is_empty() {
        anyhow::bail!("trailing number without unit in duration '{s}' (use h/m/s)");
    }
    if total_secs == 0 {
        anyhow::bail!("duration must be positive");
    }
    Ok(std::time::Duration::from_secs(total_secs))
}
/// Choose the Unix socket path for the agent.
///
/// Precedence: an explicit `sock_override`, then a non-empty `TSAFE_AUTH_SOCK`,
/// then `$XDG_RUNTIME_DIR/tsafe-ssh-agent.sock` when `XDG_RUNTIME_DIR` is
/// non-empty, and finally `/tmp/tsafe-ssh-agent-<uid>.sock`.
#[cfg(unix)]
fn resolve_socket_path(sock_override: Option<&str>) -> std::path::PathBuf {
    use std::path::PathBuf;

    if let Some(p) = sock_override {
        return PathBuf::from(p);
    }
    // Treat empty (or non-UTF-8) variables as unset. In particular, an empty
    // XDG_RUNTIME_DIR would otherwise produce a relative path in the current directory.
    let non_empty_env = |name: &str| std::env::var(name).ok().filter(|v| !v.is_empty());
    if let Some(v) = non_empty_env("TSAFE_AUTH_SOCK") {
        return PathBuf::from(v);
    }
    if let Some(xdg) = non_empty_env("XDG_RUNTIME_DIR") {
        return PathBuf::from(xdg).join("tsafe-ssh-agent.sock");
    }
    // SAFETY: getuid() takes no arguments, has no preconditions and always succeeds.
    let uid = unsafe { libc::getuid() };
    // NOTE(review): /tmp is shared between users, so this fallback path is predictable.
    // The caller restricts the socket's permissions after binding, but a per-user
    // 0700 directory would be stronger.
    PathBuf::from(format!("/tmp/tsafe-ssh-agent-{uid}.sock"))
}
#[cfg(unix)]
fn cmd_ssh_agent_unix(profile: &str, ttl: Option<&str>, sock: Option<&str>) -> Result<()> {
use ssh_key::PrivateKey;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
let ttl_duration: Duration = ttl
.map(parse_ttl_duration)
.transpose()?
.unwrap_or(Duration::from_secs(8 * 3600));
let expires_at = Instant::now() + ttl_duration;
let sock_path = resolve_socket_path(sock);
let loaded_keys: Vec<LoadedKey> = {
let vault = open_vault(profile)?;
let mut keys = Vec::new();
for key_name in vault.list() {
let is_ssh = vault
.file()
.secrets
.get(key_name)
.map(|e| e.tags.get("type").map(String::as_str) == Some("ssh"))
.unwrap_or(false);
if !is_ssh {
continue;
}
let pem = match vault.get(key_name) {
Ok(v) => v,
Err(_) => continue,
};
let comment = vault
.file()
.secrets
.get(key_name)
.and_then(|e| e.tags.get("ssh_comment").cloned())
.unwrap_or_else(|| key_name.to_owned());
match PrivateKey::from_openssh(pem.as_bytes()) {
Ok(mut pk) => {
if pk.comment().is_empty() {
pk.set_comment(&comment);
}
let public_key = pk.public_key().key_data().clone();
keys.push(LoadedKey {
public_key,
private_key: pk,
comment,
expires_at,
});
}
Err(e) => {
tracing::warn!(key = key_name, error = %e, "skipping unparseable SSH key");
}
}
}
keys
};
println!(
"TSAFE_AUTH_SOCK={}; export TSAFE_AUTH_SOCK",
sock_path.display()
);
if sock_path.exists() {
std::fs::remove_file(&sock_path)
.with_context(|| format!("could not remove stale socket: {}", sock_path.display()))?;
}
let _sock_guard = SockGuard(sock_path.clone());
let keys_arc = Arc::new(Mutex::new(loaded_keys));
let agent = TsafeAgent { keys: keys_arc };
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("failed to start tokio runtime")?;
rt.block_on(async move {
use ssh_agent_lib::agent::listen;
use tokio::net::UnixListener;
let listener = UnixListener::bind(&sock_path)
.with_context(|| format!("could not bind Unix socket: {}", sock_path.display()))?;
tracing::info!(
sock = %sock_path.display(),
keys = agent.keys.lock().unwrap().len(),
"tsafe SSH agent listening"
);
listen(listener, agent).await.context("agent error")
})
}
/// A private key held in memory by the agent, plus the metadata needed to
/// advertise it and enforce its lifetime.
///
/// Deliberately not `Debug`, so the private key cannot end up in logs by accident.
#[cfg(unix)]
struct LoadedKey {
    /// Comment returned next to the public key when identities are listed.
    comment: String,
    /// Public half, advertised to clients and matched against sign requests.
    public_key: ssh_key::public::KeyData,
    /// Private half used to produce signatures.
    private_key: ssh_key::PrivateKey,
    /// Once `Instant::now()` reaches this, the key is neither listed nor used to sign.
    expires_at: std::time::Instant,
}
/// Deletes the wrapped socket path when dropped, so returning from the agent
/// (normally or with an error) does not leave a stale socket file behind.
///
/// Drop does not run if the process is killed by an unhandled signal.
#[cfg(unix)]
struct SockGuard(std::path::PathBuf);
#[cfg(unix)]
impl Drop for SockGuard {
    fn drop(&mut self) {
        // Best effort: the file may already be gone, and Drop cannot report errors.
        std::fs::remove_file(&self.0).ok();
    }
}
/// Session state for the tsafe SSH agent: a shared table of loaded keys.
///
/// Cloning only clones the `Arc`, so every clone refers to the same key table.
#[derive(Clone)]
#[cfg(unix)]
struct TsafeAgent {
    /// All keys loaded at startup; expired entries stay here but are filtered out on use.
    keys: std::sync::Arc<std::sync::Mutex<Vec<LoadedKey>>>,
}
/// ssh-agent protocol handler serving the keys loaded from the vault.
///
/// Only identity listing and signing are implemented here; any other requests are
/// presumably handled by `Session`'s default methods — confirm against ssh_agent_lib.
#[cfg(unix)]
#[ssh_agent_lib::async_trait]
impl ssh_agent_lib::agent::Session for TsafeAgent {
    /// Return the public key and comment of every loaded key whose TTL has not elapsed.
    async fn request_identities(
        &mut self,
    ) -> std::result::Result<Vec<ssh_agent_lib::proto::Identity>, ssh_agent_lib::error::AgentError>
    {
        use ssh_agent_lib::error::AgentError;
        use ssh_agent_lib::proto::Identity;
        use std::time::Instant;
        let now = Instant::now();
        // A poisoned lock is reported to the client as an agent error instead of panicking.
        let keys = self.keys.lock().map_err(AgentError::other)?;
        let identities = keys
            .iter()
            // Expired keys are hidden from the listing but remain in the table.
            .filter(|k| k.expires_at > now)
            .map(|k| Identity {
                pubkey: k.public_key.clone(),
                comment: k.comment.clone(),
            })
            .collect();
        Ok(identities)
    }
    /// Sign `request.data` with the first loaded key whose public key equals
    /// `request.pubkey`.
    ///
    /// Errors with "key not found" when no key matches, or "key has expired" when
    /// the matching key's TTL has elapsed.
    ///
    /// NOTE(review): `request.flags` is never consulted. For RSA keys, clients use
    /// these flags to request rsa-sha2-256 vs rsa-sha2-512. Confirm that the hash
    /// `try_sign` picks is acceptable to the clients that use this agent.
    async fn sign(
        &mut self,
        request: ssh_agent_lib::proto::SignRequest,
    ) -> std::result::Result<ssh_key::Signature, ssh_agent_lib::error::AgentError> {
        use signature::Signer as _;
        use ssh_agent_lib::error::AgentError;
        use std::time::Instant;
        let now = Instant::now();
        let keys = self.keys.lock().map_err(AgentError::other)?;
        let loaded = keys
            .iter()
            .find(|k| k.public_key == request.pubkey)
            .ok_or_else(|| AgentError::other("key not found"))?;
        // Expiry is checked after matching, so an expired key gets its own error message.
        if loaded.expires_at <= now {
            return Err(AgentError::other("key has expired"));
        }
        loaded
            .private_key
            .try_sign(&request.data)
            .map_err(AgentError::other)
    }
}
/// Dispatch a `tsafe ssh <action>` subcommand to its handler for `profile`.
pub(crate) fn cmd_ssh(profile: &str, action: SshAction) -> Result<()> {
    match action {
        SshAction::Agent { ttl, sock } => cmd_ssh_agent(profile, ttl.as_deref(), sock.as_deref()),
        SshAction::Config { host } => cmd_ssh_config(host.as_deref()),
        SshAction::Generate {
            key,
            r#type: key_type,
            bits,
            comment,
            print,
        } => cmd_ssh_generate(profile, &key, key_type, bits, comment.as_deref(), print),
        SshAction::PublicKey { key } => cmd_ssh_public_key(profile, &key),
        SshAction::List => cmd_ssh_list(profile),
    }
}