use anyhow::{Context, Result, bail};
use ssh_key::{Algorithm, LineEnding, PrivateKey, certificate};
use crate::backend::FileInjection;
/// Settings for exposing a sandbox over SSH with certificate-based authentication.
#[derive(Debug, Clone)]
pub struct SshConfig {
    /// Whether SSH access is requested (not read by the functions in this section).
    pub enabled: bool,
    /// Port `sshd` listens on inside the sandbox; rendered into `sshd_config`.
    pub port: u16,
    /// Optional host-side port. Not read here — presumably the host port mapped
    /// onto `port`; confirm with the caller.
    pub host_port: Option<u16>,
    /// Vault server address. When `Some`, `sign_client_key` refuses to sign
    /// locally and asks the caller to use the async Vault path instead.
    pub vault_addr: Option<String>,
    /// Mount path of the Vault SSH secrets engine (`/v1/{mount}/...`).
    pub vault_ssh_mount: String,
    /// Vault SSH role used when signing (`/v1/{mount}/sign/{role}`).
    pub vault_ssh_role: String,
    /// Certificate lifetime as a duration string such as `"30m"`; sent to Vault as `ttl`.
    pub cert_ttl: String,
    /// Unix account created in the sandbox; also written to the principals file.
    pub user: String,
}

impl Default for SshConfig {
    /// SSH disabled on port 22, local CA signing (no Vault address), Vault mount
    /// `ssh` with role `agentkernel-client`, 30-minute certificates, user `sandbox`.
    fn default() -> Self {
        SshConfig {
            enabled: false,
            port: 22,
            host_port: None,
            vault_addr: None,
            vault_ssh_mount: String::from("ssh"),
            vault_ssh_role: String::from("agentkernel-client"),
            cert_ttl: String::from("30m"),
            user: String::from("sandbox"),
        }
    }
}
/// Render the `sshd_config` used inside the sandbox.
///
/// The configuration accepts only user certificates signed by the CA in
/// `/etc/ssh/ca.pub` whose principals match `/etc/ssh/principals`.
/// - Password, keyboard-interactive and root logins are disabled.
/// - `AuthorizedKeysFile none` stops sshd from consulting per-user
///   `authorized_keys` files. Without it, a plain (uncertified) key written into
///   the user's own `~/.ssh` would also grant logins, defeating the
///   certificate-only design and the certificate TTL.
///
/// Only `config.port` is read from `config`.
pub fn generate_sshd_config(config: &SshConfig) -> String {
    format!(
        r#"# agentkernel sshd configuration — certificate-only auth
Port {port}
ListenAddress 0.0.0.0
# Host key
HostKey /etc/ssh/ssh_host_ed25519_key
# Certificate-based authentication
TrustedUserCAKeys /etc/ssh/ca.pub
AuthorizedPrincipalsFile /etc/ssh/principals
PubkeyAuthentication yes
# Never consult per-user authorized_keys: only CA-signed certificates may log in
AuthorizedKeysFile none
# Disable all other auth methods
PasswordAuthentication no
KbdInteractiveAuthentication no
# Disable root login
PermitRootLogin no
# Logging
LogLevel INFO
# Misc hardening
X11Forwarding no
PrintMotd no
AcceptEnv LANG LC_*
"#,
        port = config.port
    )
}
/// Generate a fresh ed25519 keypair for the certificate authority.
///
/// Returns `(private_key, public_key)`:
/// - the private key in OpenSSH PEM form with LF line endings;
/// - the public key as a single OpenSSH line.
///
/// # Errors
/// Fails if key generation or either encoding step fails.
pub fn generate_ca_keypair() -> Result<(String, String)> {
    let ca_key = PrivateKey::random(&mut rand::thread_rng(), Algorithm::Ed25519)
        .context("Failed to generate CA ed25519 keypair")?;
    // NOTE(review): `to_openssh` returns a wrapper type (zeroizing in ssh-key);
    // the plain String copied out here is not wiped on drop — confirm acceptable.
    let private_openssh = ca_key
        .to_openssh(LineEnding::LF)
        .context("Failed to encode CA private key")?
        .to_string();
    let public_openssh = ca_key
        .public_key()
        .to_openssh()
        .context("Failed to encode CA public key")?;
    Ok((private_openssh, public_openssh))
}
/// Generate a fresh ed25519 host keypair for the sandbox's `sshd`.
///
/// Returns `(private_key, public_key)`: an OpenSSH PEM private key (LF line
/// endings) and a one-line OpenSSH public key.
///
/// # Errors
/// Fails if key generation or either encoding step fails.
fn generate_host_keypair() -> Result<(String, String)> {
    let mut entropy = rand::thread_rng();
    let host_key = PrivateKey::random(&mut entropy, Algorithm::Ed25519)
        .context("Failed to generate host ed25519 keypair")?;
    let private_part = host_key
        .to_openssh(LineEnding::LF)
        .context("Failed to encode host private key")?
        .to_string();
    let public_part = host_key
        .public_key()
        .to_openssh()
        .context("Failed to encode host public key")?;
    Ok((private_part, public_part))
}
/// Build the POSIX shell script, injected as `/tmp/start-sshd.sh`, that
/// prepares the sandbox account and launches `sshd` in the background.
///
/// In order, the script:
/// 1. creates `config.user` (busybox `adduser`, falling back to `useradd`);
/// 2. unlocks the account, because OpenSSH without PAM rejects locked accounts
///    even for certificate/pubkey auth;
/// 3. creates `~/.ssh`, `.hushlogin` and a `.profile` that sets the prompt;
/// 4. resets ownership and permissions of the injected sshd files to root
///    (`docker cp` keeps the host UID; sshd `StrictModes` needs root);
/// 5. runs `ssh-keygen -A`, creates the privilege-separation directory and
///    starts `/usr/sbin/sshd -D &`.
///
/// NOTE(review): `config.user` is interpolated unquoted into shell commands and
/// a `sed` expression, so it must be a plain account name.
fn generate_start_sshd_script(config: &SshConfig) -> String {
    let user = config.user.as_str();
    let port = config.port;
    format!(
        r#"#!/bin/sh
set -e
# Create the sandbox user if it doesn't exist
if ! id -u {user} >/dev/null 2>&1; then
adduser -D -h /home/{user} -s /bin/sh {user} 2>/dev/null || \
useradd -m -d /home/{user} -s /bin/sh {user} 2>/dev/null || true
fi
# Unlock the account for SSH cert auth.
# adduser -D creates a locked account (shadow password '!').
# OpenSSH without PAM rejects locked accounts even for cert/pubkey auth.
passwd -u {user} 2>/dev/null || \
sed -i 's/^{user}:!/{user}:/' /etc/shadow 2>/dev/null || true
# Set up .ssh directory and clean login profile
mkdir -p /home/{user}/.ssh
chmod 700 /home/{user}/.ssh
touch /home/{user}/.hushlogin
cat > /home/{user}/.profile << 'PROFILE'
export PS1="agentkernel:$(basename "$PWD")\$ "
PROFILE
chown -R {user}:{user} /home/{user} 2>/dev/null || \
chown -R {user} /home/{user}
# Fix ownership and permissions on sshd files
# (docker cp preserves host UID; sshd StrictModes requires root ownership)
chown root:root /etc/ssh/sshd_config /etc/ssh/ssh_host_ed25519_key \
/etc/ssh/ssh_host_ed25519_key.pub /etc/ssh/ca.pub /etc/ssh/principals
chmod 600 /etc/ssh/ssh_host_ed25519_key
chmod 644 /etc/ssh/ssh_host_ed25519_key.pub
chmod 644 /etc/ssh/ca.pub
chmod 644 /etc/ssh/principals
chmod 644 /etc/ssh/sshd_config
# Generate host keys if sshd expects them (some distros require all types)
ssh-keygen -A 2>/dev/null || true
# Create privilege separation directory
mkdir -p /run/sshd 2>/dev/null || mkdir -p /var/run/sshd 2>/dev/null || true
# Start sshd in the background
/usr/sbin/sshd -f /etc/ssh/sshd_config -D &
echo "sshd started on port {port}"
"#
    )
}
/// Assemble every file that must be copied into the sandbox to run `sshd`
/// with certificate authentication.
///
/// A fresh ed25519 host keypair is generated on each call. The returned list
/// is, in order:
/// - `/etc/ssh/sshd_config`
/// - `/etc/ssh/ca.pub` (`ca_public_key` verbatim)
/// - `/etc/ssh/principals` (`ssh_config.user` plus a newline)
/// - the host private and public keys
/// - `/tmp/start-sshd.sh`
/// - an empty `/home/<user>/.ssh/.keep`
///
/// # Errors
/// Fails only if generating or encoding the host keypair fails.
pub fn sshd_file_injections(
    ca_public_key: &str,
    ssh_config: &SshConfig,
) -> Result<Vec<FileInjection>> {
    let sshd_config = generate_sshd_config(ssh_config);
    let (host_key_private, host_key_public) = generate_host_keypair()?;
    let startup_script = generate_start_sshd_script(ssh_config);
    let keep_marker = format!("/home/{}/.ssh/.keep", ssh_config.user);

    let inject = |dest: &str, content: Vec<u8>| FileInjection {
        content,
        dest: dest.to_string(),
    };

    Ok(vec![
        inject("/etc/ssh/sshd_config", sshd_config.into_bytes()),
        inject("/etc/ssh/ca.pub", ca_public_key.as_bytes().to_vec()),
        inject(
            "/etc/ssh/principals",
            format!("{}\n", ssh_config.user).into_bytes(),
        ),
        inject("/etc/ssh/ssh_host_ed25519_key", host_key_private.into_bytes()),
        inject(
            "/etc/ssh/ssh_host_ed25519_key.pub",
            host_key_public.into_bytes(),
        ),
        inject("/tmp/start-sshd.sh", startup_script.into_bytes()),
        inject(&keep_marker, Vec::new()),
    ])
}
pub fn sign_client_key_local(
ca_private_key: &str,
client_public_key: &str,
principals: &[&str],
ttl_secs: u64,
) -> Result<String> {
let ca_key =
PrivateKey::from_openssh(ca_private_key).context("Failed to parse CA private key")?;
let client_pubkey = ssh_key::PublicKey::from_openssh(client_public_key)
.context("Failed to parse client public key")?;
let mut rng = rand::thread_rng();
let valid_after = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.context("System time before UNIX epoch")?
.as_secs();
let valid_before = valid_after + ttl_secs;
let mut builder = certificate::Builder::new_with_random_nonce(
&mut rng,
client_pubkey.key_data().clone(),
valid_after,
valid_before,
)
.context("Failed to create certificate builder")?;
builder
.cert_type(certificate::CertType::User)
.context("Failed to set cert type")?;
builder
.key_id("agentkernel-client")
.context("Failed to set key id")?;
for principal in principals {
builder
.valid_principal(principal.to_string())
.context("Failed to add principal")?;
}
for ext in &[
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
] {
builder
.extension(*ext, "")
.context("Failed to add extension")?;
}
let cert = builder
.sign(&ca_key)
.context("Failed to sign client certificate")?;
cert.to_openssh().context("Failed to encode certificate")
}
/// Generate a fresh ed25519 keypair for an SSH client.
///
/// Returns `(private_key, public_key)`:
/// - the private key in OpenSSH PEM form (LF line endings);
/// - the public key as a single OpenSSH line, suitable for
///   `sign_client_key_local`.
///
/// # Errors
/// Fails if key generation or either encoding step fails.
pub fn generate_client_keypair() -> Result<(String, String)> {
    let mut rng = rand::thread_rng();
    let keypair = PrivateKey::random(&mut rng, Algorithm::Ed25519)
        .context("Failed to generate client ed25519 keypair")?;

    let encoded_private = keypair
        .to_openssh(LineEnding::LF)
        .context("Failed to encode client private key")?;
    let encoded_public = keypair
        .public_key()
        .to_openssh()
        .context("Failed to encode client public key")?;

    Ok((encoded_private.to_string(), encoded_public))
}
/// Parse a certificate TTL string into whole seconds.
///
/// Accepted forms (surrounding whitespace is ignored):
/// * a bare integer accepted by `u64::from_str`, taken as seconds (`"3600"`).
///   This path accepts `"0"` and returns `0`.
/// * one or more `<digits><unit>` groups with unit `s`, `m` or `h`, summed
///   together (`"30m"`, `"1h30m"`, `"2h15m30s"`).
///
/// # Errors
/// Returns an error for any of the following:
/// - an empty string;
/// - a unit with no preceding number, or an unknown unit;
/// - trailing digits without a unit;
/// - a unit-suffixed value that sums to zero;
/// - a value whose total overflows `u64` seconds.
pub fn parse_ttl_to_secs(ttl: &str) -> Result<u64> {
    let ttl = ttl.trim();
    if ttl.is_empty() {
        bail!("TTL string is empty");
    }
    // Fast path: a bare integer is already a number of seconds.
    if let Ok(secs) = ttl.parse::<u64>() {
        return Ok(secs);
    }

    let mut total: u64 = 0;
    let mut current_num = String::new();
    for ch in ttl.chars() {
        if ch.is_ascii_digit() {
            current_num.push(ch);
            continue;
        }
        if current_num.is_empty() {
            bail!("Invalid TTL format: unexpected '{}' in \"{}\"", ch, ttl);
        }
        let n: u64 = current_num
            .parse()
            .context("Invalid number in TTL string")?;
        current_num.clear();
        let unit_secs: u64 = match ch {
            's' => 1,
            'm' => 60,
            'h' => 3600,
            _ => bail!("Unknown TTL unit '{}' in \"{}\"", ch, ttl),
        };
        // Checked arithmetic: plain `*`/`+` would panic in debug builds and
        // silently wrap in release builds, turning a huge TTL into a tiny one.
        total = n
            .checked_mul(unit_secs)
            .and_then(|secs| total.checked_add(secs))
            .with_context(|| format!("TTL overflows u64 seconds: \"{}\"", ttl))?;
    }

    if !current_num.is_empty() {
        bail!(
            "Invalid TTL format: trailing digits without unit in \"{}\"",
            ttl
        );
    }
    if total == 0 {
        bail!("TTL resolves to 0 seconds: \"{}\"", ttl);
    }
    Ok(total)
}
/// Sign a client public key, choosing the signing backend from `ssh_config`.
///
/// Only local CA signing happens here:
/// - If `ssh_config.vault_addr` is set, this returns an error pointing the
///   caller to the async `sign_client_key_vault`.
/// - Otherwise `ca_private_key` must be `Some`, and the work is delegated to
///   `sign_client_key_local`.
///
/// # Errors
/// Fails when:
/// - Vault is configured;
/// - `ca_private_key` is `None`;
/// - local signing fails.
pub fn sign_client_key(
    ssh_config: &SshConfig,
    ca_private_key: Option<&str>,
    client_public_key: &str,
    principals: &[&str],
    ttl_secs: u64,
) -> Result<String> {
    match (ssh_config.vault_addr.as_ref(), ca_private_key) {
        (Some(_), _) => bail!(
            "Vault SSH signing requires calling sign_client_key_vault() \
             in an async context. Set vault_addr=None to use local CA signing."
        ),
        (None, None) => bail!(
            "Local CA signing requires a CA private key (ca_private_key is None)"
        ),
        (None, Some(ca_key)) => {
            sign_client_key_local(ca_key, client_public_key, principals, ttl_secs)
        }
    }
}
/// Fetch the SSH CA public key from Vault's SSH secrets engine.
///
/// Issues `GET {vault_addr}/v1/{vault_ssh_mount}/config/ca` with the
/// `X-Vault-Token` header and returns the response's `data.public_key` string.
/// Only compiled with the `enterprise` or `nomad` feature.
///
/// # Errors
/// Fails on any of the following:
/// - a transport error;
/// - a non-2xx status (the status and response body are included);
/// - an unparsable JSON body;
/// - a response without a string `data.public_key`.
///
/// NOTE(review): no request timeout is configured on the client, so an
/// unresponsive Vault can stall this call indefinitely.
#[cfg(any(feature = "enterprise", feature = "nomad"))]
#[allow(dead_code)]
pub async fn get_vault_ca_public_key(
    vault_addr: &str,
    vault_token: &str,
    ssh_config: &SshConfig,
) -> Result<String> {
    let base = vault_addr.trim_end_matches('/');
    let endpoint = format!("{}/v1/{}/config/ca", base, ssh_config.vault_ssh_mount);

    let response = reqwest::Client::new()
        .get(&endpoint)
        .header("X-Vault-Token", vault_token)
        .send()
        .await
        .context("Failed to contact Vault for CA public key")?;

    let status = response.status();
    if !status.is_success() {
        let detail = response.text().await.unwrap_or_default();
        bail!("Vault CA public key fetch failed ({}): {}", status, detail);
    }

    let payload: serde_json::Value = response
        .json()
        .await
        .context("Failed to parse Vault CA response")?;
    payload["data"]["public_key"]
        .as_str()
        .map(str::to_owned)
        .ok_or_else(|| anyhow::anyhow!("Vault response missing data.public_key"))
}
/// Ask Vault's SSH secrets engine to sign `client_public_key`.
///
/// POSTs to `{vault_addr}/v1/{vault_ssh_mount}/sign/{vault_ssh_role}`,
/// authenticating with `X-Vault-Token`. The request body carries:
/// - `public_key`: the client public key;
/// - `valid_principals`: `ssh_config.user`;
/// - `ttl`: `ssh_config.cert_ttl`;
/// - `cert_type`: `"user"`.
///
/// Returns the response's `data.signed_key`.
///
/// The HTTP call needs the `enterprise` or `nomad` feature (for `reqwest`).
/// Without either, this function always returns an error.
///
/// # Errors
/// Fails on any of the following:
/// - a transport error;
/// - a non-2xx status (the status and response body are included);
/// - an unparsable JSON body;
/// - a response without a string `data.signed_key`.
#[allow(dead_code)]
pub async fn sign_client_key_vault(
    vault_addr: &str,
    vault_token: &str,
    ssh_config: &SshConfig,
    client_public_key: &str,
) -> Result<String> {
    let sign_url = format!(
        "{}/v1/{}/sign/{}",
        vault_addr.trim_end_matches('/'),
        ssh_config.vault_ssh_mount,
        ssh_config.vault_ssh_role
    );
    let request_body = serde_json::json!({
        "public_key": client_public_key,
        "valid_principals": ssh_config.user,
        "ttl": ssh_config.cert_ttl,
        "cert_type": "user",
    });

    #[cfg(any(feature = "enterprise", feature = "nomad"))]
    {
        let response = reqwest::Client::new()
            .post(&sign_url)
            .header("X-Vault-Token", vault_token)
            .json(&request_body)
            .send()
            .await
            .context("Failed to contact Vault")?;

        let status = response.status();
        if !status.is_success() {
            let detail = response.text().await.unwrap_or_default();
            bail!("Vault SSH sign failed ({}): {}", status, detail);
        }

        let payload: serde_json::Value = response
            .json()
            .await
            .context("Failed to parse Vault response")?;
        payload["data"]["signed_key"]
            .as_str()
            .map(str::to_owned)
            .ok_or_else(|| anyhow::anyhow!("Vault response missing signed_key"))
    }

    #[cfg(not(any(feature = "enterprise", feature = "nomad")))]
    {
        // Consume the values only the HTTP branch uses so this build stays warning-free.
        let _ = (sign_url, request_body, vault_token);
        bail!(
            "Vault SSH signing requires the 'enterprise' or 'nomad' feature \
             (for reqwest HTTP client). Rebuild with --features enterprise"
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh client public key in one-line OpenSSH form.
    fn client_public_key() -> String {
        generate_client_keypair().unwrap().1
    }

    /// Destination paths of every file `sshd_file_injections` produces for `config`.
    fn injected_paths(config: &SshConfig) -> Vec<String> {
        let (_, ca_pub) = generate_ca_keypair().unwrap();
        let files = sshd_file_injections(&ca_pub, config).unwrap();
        files.iter().map(|file| file.dest.clone()).collect()
    }

    #[test]
    fn test_ssh_config_defaults() {
        let defaults = SshConfig::default();
        assert!(!defaults.enabled);
        assert_eq!(defaults.port, 22);
        assert!(defaults.host_port.is_none());
        assert!(defaults.vault_addr.is_none());
        assert_eq!(defaults.vault_ssh_mount, "ssh");
        assert_eq!(defaults.vault_ssh_role, "agentkernel-client");
        assert_eq!(defaults.cert_ttl, "30m");
        assert_eq!(defaults.user, "sandbox");
    }

    #[test]
    fn test_generate_sshd_config_contains_directives() {
        let rendered = generate_sshd_config(&SshConfig::default());
        for directive in [
            "Port 22",
            "TrustedUserCAKeys /etc/ssh/ca.pub",
            "PasswordAuthentication no",
            "PermitRootLogin no",
            "AuthorizedPrincipalsFile /etc/ssh/principals",
            "PubkeyAuthentication yes",
            "HostKey /etc/ssh/ssh_host_ed25519_key",
        ] {
            assert!(
                rendered.contains(directive),
                "sshd_config is missing `{}`",
                directive
            );
        }
    }

    #[test]
    fn test_generate_sshd_config_custom_port() {
        let config = SshConfig {
            port: 2222,
            ..SshConfig::default()
        };
        let rendered = generate_sshd_config(&config);
        assert!(rendered.contains("Port 2222"));
        // "Port 2222" must not also be rendered as the default "Port 22".
        assert!(!rendered.contains("Port 22\n"));
    }

    #[test]
    fn test_generate_ca_keypair() {
        let (private_key, public_key) = generate_ca_keypair().unwrap();
        assert!(private_key.contains("BEGIN OPENSSH PRIVATE KEY"));
        assert!(private_key.contains("END OPENSSH PRIVATE KEY"));
        assert!(public_key.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn test_sshd_file_injections_returns_correct_files() {
        let paths = injected_paths(&SshConfig::default());
        for expected in [
            "/etc/ssh/sshd_config",
            "/etc/ssh/ca.pub",
            "/etc/ssh/principals",
            "/etc/ssh/ssh_host_ed25519_key",
            "/etc/ssh/ssh_host_ed25519_key.pub",
            "/tmp/start-sshd.sh",
            "/home/sandbox/.ssh/.keep",
        ] {
            assert!(
                paths.iter().any(|p| p == expected),
                "missing injected file {}",
                expected
            );
        }
    }

    #[test]
    fn test_sshd_file_injections_principals_content() {
        let config = SshConfig {
            user: "testuser".to_string(),
            ..SshConfig::default()
        };
        let (_, ca_pub) = generate_ca_keypair().unwrap();
        let injections = sshd_file_injections(&ca_pub, &config).unwrap();
        let principals = injections
            .iter()
            .find(|f| f.dest == "/etc/ssh/principals")
            .expect("principals file should be injected");
        assert_eq!(String::from_utf8_lossy(&principals.content), "testuser\n");
    }

    #[test]
    fn test_sshd_file_injections_custom_user_path() {
        let config = SshConfig {
            user: "agent".to_string(),
            ..SshConfig::default()
        };
        let paths = injected_paths(&config);
        assert!(paths.iter().any(|p| p == "/home/agent/.ssh/.keep"));
    }

    #[test]
    fn test_sign_client_key_local() {
        let (ca_private, _) = generate_ca_keypair().unwrap();
        let mut rng = rand::thread_rng();
        let client_pub = PrivateKey::random(&mut rng, Algorithm::Ed25519)
            .unwrap()
            .public_key()
            .to_openssh()
            .unwrap();

        let cert = sign_client_key_local(&ca_private, &client_pub, &["sandbox"], 1800).unwrap();
        assert!(cert.contains("ssh-ed25519-cert-v01@openssh.com"));

        let parsed = ssh_key::Certificate::from_openssh(&cert).unwrap();
        let extensions = parsed.extensions();
        assert!(
            extensions.get("permit-pty").is_some(),
            "Certificate must include permit-pty extension for PTY allocation"
        );
        assert!(extensions.get("permit-port-forwarding").is_some());
        assert!(extensions.get("permit-agent-forwarding").is_some());
    }

    #[test]
    fn test_start_sshd_script_content() {
        let config = SshConfig {
            user: "myuser".to_string(),
            port: 2222,
            ..SshConfig::default()
        };
        let script = generate_start_sshd_script(&config);
        for needle in [
            "#!/bin/sh",
            "myuser",
            "port 2222",
            "chmod 600 /etc/ssh/ssh_host_ed25519_key",
            "/usr/sbin/sshd",
        ] {
            assert!(
                script.contains(needle),
                "start script is missing `{}`",
                needle
            );
        }
    }

    #[test]
    fn test_generate_host_keypair() {
        let (private_key, public_key) = generate_host_keypair().unwrap();
        assert!(private_key.contains("BEGIN OPENSSH PRIVATE KEY"));
        assert!(public_key.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn test_generate_client_keypair() {
        let (private_key, public_key) = generate_client_keypair().unwrap();
        assert!(private_key.contains("BEGIN OPENSSH PRIVATE KEY"));
        assert!(private_key.contains("END OPENSSH PRIVATE KEY"));
        assert!(public_key.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn test_generate_client_keypair_is_unique() {
        let (first_private, first_public) = generate_client_keypair().unwrap();
        let (second_private, second_public) = generate_client_keypair().unwrap();
        assert_ne!(first_private, second_private);
        assert_ne!(first_public, second_public);
    }

    #[test]
    fn test_parse_ttl_to_secs_minutes() {
        for (input, expected) in [("30m", 1800_u64), ("5m", 300), ("1m", 60)] {
            assert_eq!(parse_ttl_to_secs(input).unwrap(), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_parse_ttl_to_secs_hours() {
        for (input, expected) in [("1h", 3600_u64), ("2h", 7200)] {
            assert_eq!(parse_ttl_to_secs(input).unwrap(), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_parse_ttl_to_secs_seconds() {
        for (input, expected) in [("90s", 90_u64), ("1s", 1)] {
            assert_eq!(parse_ttl_to_secs(input).unwrap(), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_parse_ttl_to_secs_combined() {
        for (input, expected) in [("1h30m", 5400_u64), ("2h15m30s", 8130)] {
            assert_eq!(parse_ttl_to_secs(input).unwrap(), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_parse_ttl_to_secs_plain_integer() {
        assert_eq!(parse_ttl_to_secs("3600").unwrap(), 3600);
        assert_eq!(parse_ttl_to_secs("0").unwrap(), 0);
    }

    #[test]
    fn test_parse_ttl_to_secs_invalid() {
        for bad in ["", "abc", "30x", "m30"] {
            assert!(parse_ttl_to_secs(bad).is_err(), "{:?} should be rejected", bad);
        }
    }

    #[test]
    fn test_parse_ttl_to_secs_trailing_digits_rejected() {
        assert!(parse_ttl_to_secs("30m15").is_err());
    }

    #[test]
    fn test_sign_client_key_local_routing() {
        let (ca_private, _) = generate_ca_keypair().unwrap();
        let config = SshConfig::default();
        let cert = sign_client_key(
            &config,
            Some(&ca_private),
            &client_public_key(),
            &["sandbox"],
            1800,
        )
        .unwrap();
        assert!(cert.contains("ssh-ed25519-cert-v01@openssh.com"));
    }

    #[test]
    fn test_sign_client_key_vault_routing_errors() {
        let config = SshConfig {
            vault_addr: Some("https://vault.example.com".to_string()),
            ..SshConfig::default()
        };
        let outcome = sign_client_key(&config, None, &client_public_key(), &["sandbox"], 1800);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sign_client_key_no_ca_key_errors() {
        let config = SshConfig::default();
        let outcome = sign_client_key(&config, None, &client_public_key(), &["sandbox"], 1800);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sign_client_key_with_generated_keypair() {
        let (ca_private, _) = generate_ca_keypair().unwrap();
        let cert = sign_client_key_local(
            &ca_private,
            &client_public_key(),
            &["sandbox", "agent"],
            600,
        )
        .unwrap();
        assert!(cert.contains("ssh-ed25519-cert-v01@openssh.com"));
    }
}