use anyhow::{Context, Result};
use std::collections::HashMap;
use std::process::Command;
use std::time::Duration;
use super::MatchContext;
/// Number of seconds the polling loop in `execute_match_command` lets a `Match exec`
/// command run before killing it and reporting the condition as not matched.
const EXEC_TIMEOUT_SECS: u64 = 5;
/// Evaluates a `Match exec` condition by running `command` and reporting whether it exited
/// successfully.
///
/// The command is checked with [`validate_exec_command`], tokenised with
/// `shell_words::split`, and only then has the `%x` tokens of each word expanded through
/// [`expand_variables`]. The program is spawned directly (no shell) with every context
/// variable exported as `SSH_MATCH_<KEY>`, and it is killed if it is still running after
/// `EXEC_TIMEOUT_SECS` seconds. The same code path is used on every platform.
///
/// # Errors
///
/// Returns an error when validation fails, when the command cannot be tokenised, or when it
/// contains no program to run. Spawn failures, non-zero exit codes and timeouts are reported
/// as `Ok(false)` rather than as errors.
pub fn execute_match_command(command: &str, context: &MatchContext) -> Result<bool> {
    use std::process::Stdio;

    validate_exec_command(command)?;

    // Split first, expand second: a substituted value is always confined to the word it was
    // written in, so whitespace, quotes or backslashes inside a value (e.g. a Windows path)
    // can neither add arguments nor be mangled by the tokenizer.
    let parts: Vec<String> = shell_words::split(command)
        .with_context(|| format!("Failed to parse command: {command}"))?
        .iter()
        .map(|part| expand_variables(part, &context.variables))
        .collect();

    let (program, args) = match parts.split_first() {
        Some((program, args)) if !program.is_empty() => (program, args),
        _ => anyhow::bail!("Empty command for Match exec"),
    };

    tracing::debug!("Executing Match exec command: {}", parts.join(" "));

    let mut cmd = Command::new(program);
    // The child's output is never consumed, so it is discarded. An unread pipe would block
    // the child as soon as the pipe buffer filled up.
    cmd.args(args)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());
    for (key, value) in &context.variables {
        cmd.env(format!("SSH_MATCH_{}", key.to_uppercase()), value);
    }

    let mut child = match cmd.spawn() {
        Ok(child) => child,
        Err(e) => {
            // A program that cannot be started simply means "condition not met".
            tracing::debug!("Failed to spawn Match exec command '{}': {}", program, e);
            return Ok(false);
        }
    };

    Ok(wait_with_timeout(
        &mut child,
        program,
        Duration::from_secs(EXEC_TIMEOUT_SECS),
    ))
}

/// Polls `child` until it exits or `timeout` elapses and returns whether it exited successfully.
///
/// On timeout, or if polling itself fails, the child is killed and reaped and `false` is returned.
fn wait_with_timeout(child: &mut std::process::Child, program: &str, timeout: Duration) -> bool {
    use std::time::Instant;

    const POLL_INTERVAL: Duration = Duration::from_millis(50);

    let start = Instant::now();
    loop {
        match child.try_wait() {
            Ok(Some(status)) => {
                let success = status.success();
                tracing::debug!(
                    "Match exec command '{}' completed in {:.1}s with status: {} (exit code: {:?})",
                    program,
                    start.elapsed().as_secs_f64(),
                    success,
                    status.code()
                );
                return success;
            }
            Ok(None) if start.elapsed() > timeout => {
                tracing::warn!(
                    "Match exec command '{}' exceeded timeout of {}s, killing process",
                    program,
                    timeout.as_secs()
                );
                let _ = child.kill();
                // `wait` blocks until the killed process is gone, which also reaps it.
                let _ = child.wait();
                return false;
            }
            Ok(None) => std::thread::sleep(POLL_INTERVAL),
            Err(e) => {
                tracing::error!("Error waiting for Match exec command '{}': {}", program, e);
                let _ = child.kill();
                // Reap the child so it does not linger as a zombie.
                let _ = child.wait();
                return false;
            }
        }
    }
}
/// Checks a `Match exec` command against a set of security heuristics before it is run.
///
/// The checks are applied in this order:
/// 1. the command is at most 1024 bytes long;
/// 2. it contains no control characters other than tab;
/// 3. it contains none of the deny-listed substrings (destructive tools, redirections,
///    command chaining, substitution syntax, path traversal);
/// 4. its quotes are balanced and it has no backtick, `$(` or `${` outside single quotes;
/// 5. the program it starts is not a shell, interpreter, network tool or permission tool.
///
/// A warning is logged for sensitive programs (`sudo`, `ssh`, ...) and an informational
/// message for programs outside the small allowlist; neither rejects the command.
///
/// # Errors
///
/// Returns an error describing the first check that failed.
pub fn validate_exec_command(command: &str) -> Result<()> {
    const MAX_COMMAND_LENGTH: usize = 1024;
    const DANGEROUS_PATTERNS: &[&str] = &[
        "rm ", "rm\t", "rm-", "rmdir", "dd ", "dd\t", "mkfs", "format", "fdisk", ">", ">>", "<",
        "<<", "|", ";", "&&", "||", "&", "`", "$(", "${", "\\n", "\\r", "../", "..\\", "~/.",
        "~root",
    ];
    const BLOCKED_COMMANDS: &[&str] = &[
        "sh", "bash", "zsh", "ksh", "csh", "fish", "python", "python2", "python3", "perl", "ruby",
        "php", "node", "nc", "netcat", "ncat", "socat", "wget", "curl", "fetch", "chmod", "chown",
        "chgrp",
    ];
    const SENSITIVE_COMMANDS: &[&str] = &["sudo", "su", "doas", "passwd", "ssh", "scp", "sftp"];
    const SAFE_COMMANDS: &[&str] = &[
        "test", "[", "ls", "cat", "grep", "head", "tail", "echo", "true", "false", "date",
        "hostname",
    ];

    if command.len() > MAX_COMMAND_LENGTH {
        anyhow::bail!(
            "Match exec command is too long ({} bytes). Maximum allowed is {} bytes.",
            command.len(),
            MAX_COMMAND_LENGTH
        );
    }

    if command.chars().any(|c| c.is_control() && c != '\t') {
        anyhow::bail!(
            "Match exec command contains control characters. This is blocked for security."
        );
    }

    // Plain substring search: deliberately strict, so harmless text that happens to contain
    // one of these sequences (e.g. "information" contains "format") is rejected as well.
    for pattern in DANGEROUS_PATTERNS {
        if command.contains(pattern) {
            anyhow::bail!(
                "Match exec command contains potentially dangerous pattern '{pattern}'. \
                 This is blocked for security reasons."
            );
        }
    }

    // Quote-aware scan. A quote character only opens or closes a quoted section when it is
    // not itself inside the other kind of quotes, and a backslash escapes the next character
    // everywhere except inside single quotes. The substitution checks duplicate entries of
    // the deny-list above and act as a second line of defence.
    let mut open_quote: Option<char> = None;
    let mut chars = command.chars().peekable();
    while let Some(ch) = chars.next() {
        let upcoming = chars.peek().copied();
        match (open_quote, ch) {
            (Some(quote), c) if c == quote => open_quote = None,
            // Everything inside single quotes is literal.
            (Some('\''), _) => {}
            (_, '\\') => {
                // Skip the escaped character so it cannot open or close a quote.
                chars.next();
            }
            (None, '\'' | '"') => open_quote = Some(ch),
            (_, '`') => {
                anyhow::bail!(
                    "Match exec command contains backtick outside single quotes. \
                     This could allow command substitution."
                );
            }
            // Inspect the character following *this* `$`, not the first `$` of the command.
            (_, '$') if matches!(upcoming, Some('(' | '{')) => {
                anyhow::bail!(
                    "Match exec command contains potential command or variable substitution. \
                     This is blocked for security."
                );
            }
            _ => {}
        }
    }
    if open_quote.is_some() {
        anyhow::bail!("Match exec command has unbalanced quotes.");
    }

    // Compare against the program as the tokenizer will see it, so quoting or escaping the
    // name (`"bash"`, `b\ash`) cannot slip past the lists below.
    let leading = leading_word(command);
    let first_word = leading.trim_start_matches('/');
    let program = first_word.rsplit('/').next().unwrap_or(first_word);

    if let Some(blocked) = BLOCKED_COMMANDS.iter().find(|&&name| name == program) {
        anyhow::bail!(
            "Match exec command uses blocked executable '{blocked}'. \
             Executing shells or interpreters is not allowed for security."
        );
    }

    if let Some(cmd) = SENSITIVE_COMMANDS.iter().find(|&&name| name == program) {
        tracing::warn!(
            "Match exec command uses potentially sensitive command '{}'. \
             Please ensure this is intentional and secure.",
            cmd
        );
    }

    if !SAFE_COMMANDS.iter().any(|&safe| safe == program) {
        tracing::info!(
            "Match exec command '{}' is not in the safe command allowlist. \
             Consider using one of: {:?}",
            first_word,
            SAFE_COMMANDS
        );
    }

    Ok(())
}

/// Returns the first word of `command` with shell quoting and backslash escapes removed.
fn leading_word(command: &str) -> String {
    let mut word = String::new();
    let mut open_quote: Option<char> = None;
    let mut chars = command.trim_start().chars();
    while let Some(ch) = chars.next() {
        match (open_quote, ch) {
            (Some(quote), c) if c == quote => open_quote = None,
            (Some('\''), c) => word.push(c),
            (_, '\\') => {
                if let Some(escaped) = chars.next() {
                    word.push(escaped);
                }
            }
            (None, '\'' | '"') => open_quote = Some(ch),
            // Unquoted whitespace ends the first word.
            (None, c) if c.is_whitespace() => break,
            (_, c) => word.push(c),
        }
    }
    word
}
pub fn expand_variables(command: &str, variables: &HashMap<String, String>) -> String {
if !command.contains('%') {
return command.to_string();
}
let mut result = String::with_capacity(command.len() + 32);
let mut chars = command.chars().peekable();
while let Some(ch) = chars.next() {
if ch == '%' {
if let Some(&next_ch) = chars.peek() {
let key = next_ch.to_string();
if let Some(value) = variables.get(&key) {
result.push_str(value);
chars.next(); } else {
result.push(ch); }
} else {
result.push(ch); }
} else {
result.push(ch);
}
}
result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh::ssh_config::match_directive::MatchContext;

    /// Simple probes are accepted; destructive or shell-chaining commands are rejected.
    #[test]
    fn test_validate_exec_command() {
        for accepted in ["test -f /tmp/file", "ls -la", "echo hello"] {
            assert!(validate_exec_command(accepted).is_ok());
        }

        for rejected in [
            "rm -rf /",
            "ls; rm file",
            "echo `whoami`",
            "cat file | grep pattern",
            "dd if=/dev/zero of=/dev/sda",
        ] {
            assert!(validate_exec_command(rejected).is_err());
        }
    }

    /// Known `%x` tokens are replaced by their values.
    #[test]
    fn test_expand_variables() {
        let variables: HashMap<String, String> = [
            ("h", "example.com"),
            ("u", "testuser"),
            ("l", "localuser"),
        ]
        .iter()
        .map(|&(key, value)| (key.to_string(), value.to_string()))
        .collect();

        assert_eq!(
            expand_variables("test -f /tmp/%h.lock", &variables),
            "test -f /tmp/example.com.lock"
        );
        assert_eq!(
            expand_variables("echo %u@%h", &variables),
            "echo testuser@example.com"
        );
    }

    /// Length limit, unbalanced quotes and deny-listed patterns.
    #[test]
    fn test_validate_exec_security_edge_cases() {
        // Exactly at the limit is fine, one byte over is not.
        assert!(validate_exec_command(&"a".repeat(1024)).is_ok());
        assert!(validate_exec_command(&"a".repeat(1025)).is_err());

        let unbalanced = ["echo \"hello", "echo 'hello", "echo \"hello'"];
        let dangerous = [
            "rm -rf /",
            "dd if=/dev/zero",
            "ls;rm file",
            "echo hello ; rm file",
        ];
        for rejected in unbalanced.iter().chain(dangerous.iter()) {
            assert!(validate_exec_command(rejected).is_err());
        }
    }

    /// A command that outlives the timeout is killed and evaluates to `false`.
    #[test]
    #[cfg(unix)]
    fn test_exec_timeout() {
        use std::time::Instant;

        let context = MatchContext::new("example.com".to_string(), None).unwrap();
        let started = Instant::now();
        let matched = execute_match_command("sleep 10", &context).unwrap();
        let duration = started.elapsed();

        assert!(
            !matched,
            "Long-running command should timeout and return false"
        );
        assert!(
            duration.as_secs() <= EXEC_TIMEOUT_SECS + 1,
            "Should timeout within {EXEC_TIMEOUT_SECS} seconds, took {duration:?}"
        );
    }

    /// A program that cannot be spawned is "no match", not an error.
    #[test]
    #[cfg(unix)]
    fn test_exec_nonexistent_command() {
        let context = MatchContext::new("example.com".to_string(), None).unwrap();
        let matched = execute_match_command("nonexistent_command_12345", &context).unwrap();
        assert!(!matched, "Nonexistent command should return false");
    }

    /// The exit status of the command decides the result.
    #[test]
    #[cfg(unix)]
    fn test_exec_exit_code_handling() {
        let context = MatchContext::new("example.com".to_string(), None).unwrap();

        let succeeded = execute_match_command("test -d /tmp", &context).unwrap();
        assert!(succeeded, "Successful command should return true");

        let failed = execute_match_command("test -f /nonexistent_file_12345", &context).unwrap();
        assert!(!failed, "Failed command should return false");
    }

    // NOTE(review): as written in this file, `execute_match_command` only returns `Err` for
    // validation, tokenising or empty-command failures, and `echo test` passes validation,
    // so this assertion does not look satisfiable. Confirm whether exec is really meant to
    // be rejected on Windows.
    #[test]
    #[cfg(windows)]
    fn test_exec_disabled_on_windows() {
        let context = MatchContext::new("example.com".to_string(), None).unwrap();
        let outcome = execute_match_command("echo test", &context);
        assert!(
            outcome.is_err(),
            "exec should be disabled on Windows for security"
        );
    }

    /// Unknown tokens, doubled percent signs and tokens at either end of the input.
    #[test]
    fn test_expand_variables_edge_cases() {
        let variables: HashMap<String, String> =
            std::iter::once(("h".to_string(), "example.com".to_string())).collect();

        // `%u` is not defined here, so the text is passed through untouched.
        assert_eq!(
            expand_variables("test -f /tmp/%unknown", &variables),
            "test -f /tmp/%unknown"
        );

        assert!(expand_variables("echo 100%%", &variables).contains('%'));

        assert_eq!(
            expand_variables("%h.example.com", &variables),
            "example.com.example.com"
        );
        assert_eq!(
            expand_variables("prefix-%h", &variables),
            "prefix-example.com"
        );
    }
}