use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use anyhow::{Context, Result};
use regex::Regex;
use russh::ChannelId;
use russh::server::Handle;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;
use tokio::process::Command;
use crate::shared::auth_types::UserInfo;
/// Exit code reported by [`CommandExecutor::execute`] when the command exceeds the configured
/// timeout and is killed.
// NOTE(review): 124 presumably mirrors the exit status of coreutils `timeout` — confirm intent.
const EXIT_CODE_TIMEOUT: i32 = 124;
/// Exit code reported by [`CommandExecutor::execute`] when [`CommandExecutor::validate_command`]
/// rejects the command, so nothing is spawned.
// NOTE(review): 126 presumably mirrors the shell's "cannot execute" status — confirm intent.
const EXIT_CODE_REJECTED: i32 = 126;
/// Environment variable names that are never forwarded from [`ExecConfig::env`] to the spawned
/// process. Entries whose key equals one of these names exactly (case-sensitive) are skipped
/// with a warning.
const DANGEROUS_ENV_VARS: &[&str] = &[
"LD_PRELOAD",
"LD_LIBRARY_PATH",
"BASH_ENV",
"ENV",
"PROMPT_COMMAND",
"PERL5LIB",
"PYTHONPATH",
"RUBYLIB",
];
/// Settings that control how [`CommandExecutor`] validates and runs commands.
///
/// Every field carries a serde default, so fields omitted from a serialized configuration fall
/// back to the values documented below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecConfig {
/// Shell binary used to run commands; it is invoked as `<shell> -c <command>`.
/// Defaults to `/bin/sh`.
#[serde(default = "default_shell")]
pub default_shell: PathBuf,
/// Extra environment variables set on the spawned process after the built-in base
/// environment. Keys listed in `DANGEROUS_ENV_VARS` are skipped at execution time.
#[serde(default)]
pub env: HashMap<String, String>,
/// Maximum run time in seconds; `0` disables the timeout. Defaults to 3600.
#[serde(default = "default_timeout_secs")]
pub timeout_secs: u64,
/// Working directory for the spawned process; when `None`, the user's home directory is used.
#[serde(default)]
pub working_dir: Option<PathBuf>,
/// Optional allowlist. When `Some`, the first whitespace-delimited word of a command must
/// equal one of these entries exactly (case-sensitive).
#[serde(default)]
pub allowed_commands: Option<Vec<String>>,
/// Blocklist patterns. A command is rejected when its lowercased, whitespace-collapsed form
/// contains the lowercased pattern anywhere (substring match).
#[serde(default = "default_blocked_commands")]
pub blocked_commands: Vec<String>,
}
/// Fallback for the `default_shell` setting: the shell at `/bin/sh`.
fn default_shell() -> PathBuf {
    const FALLBACK_SHELL: &str = "/bin/sh";
    PathBuf::from(FALLBACK_SHELL)
}
/// Fallback for the `timeout_secs` setting: one hour, expressed in seconds.
fn default_timeout_secs() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}
/// Fallback for the `blocked_commands` setting.
///
/// The list covers destructive file/disk tools, power management, privilege escalation,
/// package managers and kernel-module tools, in that order.
fn default_blocked_commands() -> Vec<String> {
    const BLOCKED: &[&str] = &[
        // Destructive file / disk tools.
        "rm", "mkfs", "dd", "shred",
        // Power management.
        "reboot", "shutdown", "halt", "poweroff",
        // Privilege escalation.
        "sudo", "su", "doas",
        // Package managers.
        "apt", "apt-get", "yum", "dnf", "pacman",
        // Kernel modules.
        "insmod", "rmmod", "modprobe",
    ];
    BLOCKED.iter().map(|name| (*name).to_owned()).collect()
}
impl Default for ExecConfig {
    /// Builds the configuration from the same fallback helpers serde uses for missing fields:
    /// default shell, default timeout, default blocklist, no extra environment, no working
    /// directory override and no allowlist.
    fn default() -> Self {
        let default_shell = default_shell();
        let timeout_secs = default_timeout_secs();
        let blocked_commands = default_blocked_commands();
        Self {
            default_shell,
            timeout_secs,
            blocked_commands,
            env: HashMap::new(),
            working_dir: None,
            allowed_commands: None,
        }
    }
}
impl ExecConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the execution time limit, or `None` when `timeout_secs` is `0` (no limit).
    pub fn timeout(&self) -> Option<Duration> {
        match self.timeout_secs {
            0 => None,
            secs => Some(Duration::from_secs(secs)),
        }
    }

    /// Replaces the shell used to run commands.
    pub fn with_shell(self, shell: impl Into<PathBuf>) -> Self {
        Self {
            default_shell: shell.into(),
            ..self
        }
    }

    /// Replaces the timeout, in seconds; `0` disables it.
    pub fn with_timeout_secs(self, secs: u64) -> Self {
        Self {
            timeout_secs: secs,
            ..self
        }
    }

    /// Sets the working directory for spawned commands.
    pub fn with_working_dir(self, dir: impl Into<PathBuf>) -> Self {
        Self {
            working_dir: Some(dir.into()),
            ..self
        }
    }

    /// Adds one environment variable, replacing any previous value stored under the same key.
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, contents) = (key.into(), value.into());
        self.env.insert(name, contents);
        self
    }

    /// Installs an allowlist, replacing any previous one.
    pub fn with_allowed_commands(self, commands: Vec<String>) -> Self {
        Self {
            allowed_commands: Some(commands),
            ..self
        }
    }

    /// Appends one pattern to the blocklist, keeping the existing entries.
    pub fn with_blocked_command(mut self, pattern: impl Into<String>) -> Self {
        let entry: String = pattern.into();
        self.blocked_commands.push(entry);
        self
    }
}
/// Validates commands against an [`ExecConfig`] policy and runs accepted ones through the
/// configured shell, forwarding their output to an SSH channel.
pub struct CommandExecutor {
// Policy and runtime settings applied to every command handled by this executor.
config: ExecConfig,
}
impl CommandExecutor {
/// Creates an executor that takes ownership of `config` and applies it to every command.
pub fn new(config: ExecConfig) -> Self {
Self { config }
}
/// Validates `command` and, if accepted, runs it as `<default_shell> -c <command>`, streaming
/// the process's stdout and stderr to the SSH channel `channel_id` through `handle`.
///
/// Returns the exit code to report for the command:
/// - `EXIT_CODE_REJECTED` when [`Self::validate_command`] fails; the reason is sent on
///   extended-data stream `1` and nothing is spawned.
/// - `EXIT_CODE_TIMEOUT` when the configured timeout elapses; the process is killed and a
///   notice is sent on extended-data stream `1`.
/// - Otherwise the process's own exit code, or `1` when the exit status carries no code.
///
/// # Errors
///
/// Returns an error when the process cannot be spawned or when waiting on it fails.
/// Failures to send data to the channel are ignored.
pub async fn execute(
&self,
command: &str,
user_info: &UserInfo,
channel_id: ChannelId,
handle: Handle,
) -> Result<i32> {
// Reject before spawning anything; the send result is deliberately ignored (best effort).
if let Err(e) = self.validate_command(command) {
tracing::warn!(
user = %user_info.username,
command = %command,
"Command validation failed: {}",
e
);
let error_msg = format!("Command rejected: {e}\n");
let _ = handle
.extended_data(
channel_id,
1,
bytes::Bytes::copy_from_slice(error_msg.as_bytes()),
)
.await;
return Ok(EXIT_CODE_REJECTED);
}
tracing::info!(
user = %user_info.username,
command = %command,
"Executing command"
);
let mut cmd = Command::new(&self.config.default_shell);
cmd.arg("-c").arg(command);
// Start from an empty environment so nothing leaks in from the server process, then set a
// fixed base environment derived from the authenticated user.
cmd.env_clear();
cmd.env("HOME", &user_info.home_dir);
cmd.env("USER", &user_info.username);
cmd.env("LOGNAME", &user_info.username);
cmd.env("SHELL", &user_info.shell);
cmd.env("PATH", "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin");
// Configured variables are applied after the base set, so they can replace base values;
// names listed in DANGEROUS_ENV_VARS are skipped.
for (key, value) in &self.config.env {
if DANGEROUS_ENV_VARS.contains(&key.as_str()) {
tracing::warn!(
user = %user_info.username,
env_var = %key,
"Blocked dangerous environment variable"
);
continue;
}
cmd.env(key, value);
}
// A configured working directory wins; otherwise run in the user's home directory.
let work_dir = self
.config
.working_dir
.clone()
.unwrap_or_else(|| user_info.home_dir.clone());
cmd.current_dir(&work_dir);
// No stdin is offered to the command; stdout/stderr are captured so they can be forwarded.
cmd.stdin(std::process::Stdio::null());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
cmd.kill_on_drop(true);
#[cfg(unix)]
{
// Put the child in its own process group so the timeout path below can signal the whole
// group (the shell and anything it started) rather than only the shell.
cmd.process_group(0);
}
let mut child = cmd.spawn().context("Failed to spawn command")?;
let stdout = child.stdout.take();
let stderr = child.stderr.take();
let stdout_handle = handle.clone();
let stderr_handle = handle.clone();
// Forward each pipe on its own task so both streams drain while we wait for the process.
let stdout_task = tokio::spawn(async move {
if let Some(stdout) = stdout {
Self::stream_output(stdout, channel_id, stdout_handle, false).await
} else {
Ok(())
}
});
let stderr_task = tokio::spawn(async move {
if let Some(stderr) = stderr {
Self::stream_output(stderr, channel_id, stderr_handle, true).await
} else {
Ok(())
}
});
// With a timeout configured, bound only the wait on the child; without one, wait indefinitely.
let exit_status = if let Some(timeout) = self.config.timeout() {
match tokio::time::timeout(timeout, child.wait()).await {
Ok(status) => status?,
Err(_) => {
tracing::warn!(
user = %user_info.username,
command = %command,
"Command timed out after {} seconds",
self.config.timeout_secs
);
#[cfg(unix)]
{
if let Some(pid) = child.id() {
// SAFETY: `libc::kill` takes plain integer arguments and dereferences no memory.
// A negative pid addresses a process group; the child was spawned with
// `process_group(0)`, so `-(pid)` is presumably its own group — confirm.
// The return value is ignored (best effort).
unsafe {
libc::kill(-(pid as i32), libc::SIGKILL);
}
}
}
// Also kill the direct child through tokio; errors are ignored.
let _ = child.kill().await;
let timeout_msg = format!(
"Command timed out after {} seconds\n",
self.config.timeout_secs
);
let _ = handle
.extended_data(
channel_id,
1,
bytes::Bytes::copy_from_slice(timeout_msg.as_bytes()),
)
.await;
// NOTE(review): this path returns without joining the two streaming tasks.
return Ok(EXIT_CODE_TIMEOUT);
}
}
} else {
child.wait().await?
};
// Let both forwarding tasks finish draining the pipes; their results are ignored.
let _ = tokio::join!(stdout_task, stderr_task);
// `code()` is `None` when the status carries no exit code; report 1 in that case.
let exit_code = exit_status.code().unwrap_or(1);
tracing::debug!(
user = %user_info.username,
command = %command,
exit_code = %exit_code,
"Command completed"
);
Ok(exit_code)
}
/// Copies everything readable from `output` to the SSH channel `channel_id`.
///
/// Data is read in chunks of up to 8192 bytes. Each chunk is sent as extended data with
/// code `1` when `is_stderr` is set, and as regular channel data otherwise. Forwarding stops
/// at end of input, or after the first failed send (which is logged and not treated as an
/// error).
///
/// # Errors
///
/// Returns an error only when reading from `output` fails.
async fn stream_output(
    mut output: impl AsyncReadExt + Unpin,
    channel_id: ChannelId,
    handle: Handle,
    is_stderr: bool,
) -> Result<()> {
    let mut chunk = [0u8; 8192];
    loop {
        let filled = match output.read(&mut chunk).await? {
            // A zero-length read means the pipe reached end of input.
            0 => return Ok(()),
            len => len,
        };
        let payload = bytes::Bytes::copy_from_slice(&chunk[..filled]);
        let sent = if is_stderr {
            handle.extended_data(channel_id, 1, payload).await
        } else {
            handle.data(channel_id, payload).await
        };
        if sent.is_ok() {
            continue;
        }
        // The channel is no longer accepting data; stop forwarding without failing.
        tracing::warn!(
            channel = ?channel_id,
            is_stderr = %is_stderr,
            "Failed to send data to channel"
        );
        return Ok(());
    }
}
pub fn validate_command(&self, command: &str) -> Result<()> {
let normalized = command.to_lowercase();
let normalized = normalized.split_whitespace().collect::<Vec<_>>().join(" ");
let chaining_patterns = [
";", "&&", "||", "|", "`", "$(", "$((", ">", ">>", "<", "<<<", "&", "\n", "\r", ];
for pattern in &chaining_patterns {
if command.contains(pattern) {
if *pattern == "|" && !command.contains("||") {
tracing::info!("Command contains pipe operator: {}", command);
continue;
}
if (*pattern == ">" || *pattern == ">>") && !command.contains("/dev/") {
tracing::info!("Command contains redirection: {}", command);
continue;
}
anyhow::bail!(
"Command contains shell metacharacter that could enable command chaining: '{pattern}'"
);
}
}
let dangerous_patterns = [
(r"(?i)\$\{[^}]*\}", "Variable expansion"),
(r"(?i)\$[A-Za-z_][A-Za-z0-9_]*", "Variable substitution"),
(r"(?i)<\([^)]*\)", "Process substitution"),
(r"(?i)>\([^)]*\)", "Process substitution"),
];
for (pattern, description) in &dangerous_patterns {
if let Ok(re) = Regex::new(pattern)
&& re.is_match(command)
{
anyhow::bail!("Command contains dangerous pattern ({})", description);
}
}
for blocked in &self.config.blocked_commands {
let blocked_normalized = blocked.to_lowercase();
if normalized.contains(&blocked_normalized) {
anyhow::bail!("Command contains blocked pattern: '{blocked}'");
}
if let Some(first_word) = normalized.split_whitespace().next()
&& first_word == blocked_normalized
{
anyhow::bail!("Command '{first_word}' is blocked");
}
}
if let Some(ref allowed) = self.config.allowed_commands {
if command.contains(';')
|| command.contains("&&")
|| command.contains("||")
|| command.contains("$(")
|| command.contains('`')
{
anyhow::bail!("Command chaining is not allowed when using command allowlist");
}
let cmd_name = command.split_whitespace().next().unwrap_or("");
let is_allowed = allowed.iter().any(|a| {
cmd_name == a
});
if !is_allowed {
anyhow::bail!("Command '{cmd_name}' is not in the allowed list");
}
}
Ok(())
}
}
// Unit tests for `ExecConfig` (defaults, builder, serde round trip) and for
// `CommandExecutor::validate_command`. Nothing here spawns a process.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration uses /bin/sh, a 3600 second timeout, no working directory,
// no allowlist and a non-empty blocklist.
#[test]
fn test_exec_config_default() {
let config = ExecConfig::default();
assert_eq!(config.default_shell, PathBuf::from("/bin/sh"));
assert_eq!(config.timeout_secs, 3600);
assert!(config.working_dir.is_none());
assert!(config.allowed_commands.is_none());
assert!(!config.blocked_commands.is_empty());
}
// Every builder method stores its argument; `with_blocked_command` appends to the defaults.
#[test]
fn test_exec_config_builder() {
let config = ExecConfig::new()
.with_shell("/bin/bash")
.with_timeout_secs(600)
.with_working_dir("/tmp")
.with_env("LANG", "en_US.UTF-8")
.with_allowed_commands(vec!["ls".to_string(), "cat".to_string()])
.with_blocked_command("dangerous_cmd");
assert_eq!(config.default_shell, PathBuf::from("/bin/bash"));
assert_eq!(config.timeout_secs, 600);
assert_eq!(config.working_dir, Some(PathBuf::from("/tmp")));
assert_eq!(config.env.get("LANG"), Some(&"en_US.UTF-8".to_string()));
assert!(config.allowed_commands.is_some());
assert!(
config
.blocked_commands
.contains(&"dangerous_cmd".to_string())
);
}
// `timeout()` maps 0 seconds to `None` and any other value to `Some(Duration)`.
#[test]
fn test_exec_config_timeout() {
let mut config = ExecConfig::default();
assert_eq!(config.timeout(), Some(Duration::from_secs(3600)));
config.timeout_secs = 0;
assert!(config.timeout().is_none());
config.timeout_secs = 300;
assert_eq!(config.timeout(), Some(Duration::from_secs(300)));
}
// Default policy: blocklisted commands and chained commands are rejected, while a few
// harmless commands pass.
#[test]
fn test_validate_command_blocked() {
let config = ExecConfig::default();
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("rm -rf /").is_err());
assert!(executor.validate_command("rm -fr /home").is_err());
assert!(executor.validate_command("sudo mkfs /dev/sda").is_err());
assert!(
executor
.validate_command("dd if=/dev/zero of=/dev/sda")
.is_err()
);
assert!(executor.validate_command("ls; rm -rf /").is_err());
assert!(executor.validate_command("ls && rm -rf /").is_err());
assert!(executor.validate_command("ls || rm -rf /").is_err());
assert!(executor.validate_command("ls `rm -rf /`").is_err());
assert!(executor.validate_command("ls $(rm -rf /)").is_err());
assert!(executor.validate_command("ls -la").is_ok());
assert!(executor.validate_command("cat /etc/passwd").is_ok());
assert!(executor.validate_command("echo hello").is_ok());
}
// With an allowlist, only commands whose first word is listed pass; chaining is rejected.
#[test]
fn test_validate_command_whitelist() {
let config = ExecConfig::new().with_allowed_commands(vec![
"ls".to_string(),
"cat".to_string(),
"echo".to_string(),
]);
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls -la").is_ok());
assert!(executor.validate_command("cat /etc/passwd").is_ok());
assert!(executor.validate_command("echo hello world").is_ok());
assert!(executor.validate_command("rm -rf /").is_err());
assert!(
executor
.validate_command("wget http://example.com")
.is_err()
);
assert!(
executor
.validate_command("curl http://example.com")
.is_err()
);
assert!(executor.validate_command("ls; rm -rf /").is_err());
assert!(
executor
.validate_command("cat /etc/passwd && rm -rf /")
.is_err()
);
}
// Allowlist and an extra blocklist entry configured together.
#[test]
fn test_validate_command_combined() {
let config = ExecConfig::new()
.with_allowed_commands(vec!["ls".to_string(), "echo".to_string()])
.with_blocked_command("dangerous");
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls -la").is_ok());
assert!(executor.validate_command("echo hello").is_ok());
assert!(executor.validate_command("cat file.txt").is_err());
assert!(executor.validate_command("ls; echo test").is_err());
}
// Without an allowlist, the empty command is accepted.
#[test]
fn test_validate_command_empty() {
let config = ExecConfig::default();
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("").is_ok());
}
// An allowlisted command passes both with and without arguments.
#[test]
fn test_validate_command_whitespace() {
let config = ExecConfig::new().with_allowed_commands(vec!["ls".to_string()]);
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls").is_ok());
assert!(executor.validate_command("ls -la").is_ok());
}
// Spot-checks a few entries of the default blocklist.
#[test]
fn test_default_blocked_commands() {
let blocked = default_blocked_commands();
assert!(blocked.contains(&"rm".to_string()));
assert!(blocked.contains(&"mkfs".to_string()));
assert!(blocked.contains(&"dd".to_string()));
assert!(blocked.contains(&"sudo".to_string()));
}
// Construction from a default configuration compiles and runs.
#[test]
fn test_command_executor_creation() {
let config = ExecConfig::default();
let _executor = CommandExecutor::new(config);
}
// A configuration survives a YAML round trip through serde.
#[test]
fn test_exec_config_serialization() {
let config = ExecConfig::new()
.with_shell("/bin/bash")
.with_timeout_secs(1800)
.with_env("LANG", "C.UTF-8");
let yaml = serde_yaml::to_string(&config).unwrap();
assert!(yaml.contains("/bin/bash"));
assert!(yaml.contains("1800"));
let deserialized: ExecConfig = serde_yaml::from_str(&yaml).unwrap();
assert_eq!(deserialized.default_shell, PathBuf::from("/bin/bash"));
assert_eq!(deserialized.timeout_secs, 1800);
assert_eq!(deserialized.env.get("LANG"), Some(&"C.UTF-8".to_string()));
}
// Chaining, command substitution, redirection to /dev/, variable expansion and process
// substitution are all rejected under the default policy.
#[test]
fn test_command_injection_prevention() {
let config = ExecConfig::default();
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls; rm -rf /").is_err());
assert!(executor.validate_command("ls && rm -rf /").is_err());
assert!(executor.validate_command("ls || rm -rf /").is_err());
assert!(executor.validate_command("ls `whoami`").is_err());
assert!(executor.validate_command("ls $(whoami)").is_err());
assert!(executor.validate_command("cat file > /dev/sda").is_err());
assert!(executor.validate_command("cat file >> /dev/sda").is_err());
assert!(executor.validate_command("echo ${PATH}").is_err());
assert!(executor.validate_command("echo $HOME").is_err());
assert!(executor.validate_command("cat <(ls)").is_err());
assert!(executor.validate_command("cat >(cat)").is_err());
}
// Blocklist matching ignores the case of the command.
#[test]
fn test_blocklist_normalization() {
let config = ExecConfig::default();
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("RM -rf /").is_err());
assert!(executor.validate_command("Rm -rf /").is_err());
assert!(executor.validate_command("rM -rf /").is_err());
assert!(executor.validate_command("rm -rf /").is_err());
assert!(executor.validate_command("SUDO apt-get install").is_err());
assert!(executor.validate_command("SuDo apt-get install").is_err());
}
// Allowlist entries must equal the command name exactly; a shared prefix is not enough.
#[test]
fn test_allowlist_exact_match() {
let config =
ExecConfig::new().with_allowed_commands(vec!["ls".to_string(), "cat".to_string()]);
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls -la").is_ok());
assert!(executor.validate_command("cat file.txt").is_ok());
assert!(executor.validate_command("lsof").is_err());
assert!(executor.validate_command("catch").is_err());
}
// Spot-checks the contents of DANGEROUS_ENV_VARS.
#[test]
fn test_dangerous_env_vars() {
let dangerous_vars = DANGEROUS_ENV_VARS;
assert!(dangerous_vars.contains(&"LD_PRELOAD"));
assert!(dangerous_vars.contains(&"LD_LIBRARY_PATH"));
assert!(dangerous_vars.contains(&"BASH_ENV"));
assert!(dangerous_vars.contains(&"ENV"));
assert!(dangerous_vars.contains(&"PROMPT_COMMAND"));
}
// Broader spot-check of the default blocklist (overlaps `test_default_blocked_commands`).
#[test]
fn test_default_blocked_patterns() {
let blocked = default_blocked_commands();
assert!(blocked.contains(&"rm".to_string()));
assert!(blocked.contains(&"sudo".to_string()));
assert!(blocked.contains(&"mkfs".to_string()));
assert!(blocked.contains(&"dd".to_string()));
assert!(blocked.contains(&"reboot".to_string()));
assert!(blocked.contains(&"shutdown".to_string()));
assert!(blocked.contains(&"apt".to_string()));
assert!(blocked.contains(&"yum".to_string()));
}
// Only the `||` operator is exercised here; a single pipe is not covered by this test.
#[test]
fn test_pipe_handling() {
let config = ExecConfig::default();
let executor = CommandExecutor::new(config);
assert!(executor.validate_command("ls || rm -rf /").is_err());
}
}