use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
/// Validates a glob pattern taken from an SSH config `Include` directive.
///
/// The pattern is checked only syntactically; nothing is expanded or touched
/// on disk.
///
/// # Errors
///
/// Returns an error when the pattern:
/// - is longer than 512 bytes (the limit is measured with `str::len`);
/// - contains a recursive `**` segment;
/// - contains more than 5 `*` wildcards;
/// - is exactly `*` or `/*`, which would match far too broadly.
pub fn validate_glob_pattern(pattern: &str) -> Result<()> {
    const MAX_PATTERN_LEN: usize = 512;
    const MAX_WILDCARDS: usize = 5;

    // Bound the input first: it is cheap, and it keeps the later error
    // messages (which echo the pattern) from embedding arbitrarily long input.
    if pattern.len() > MAX_PATTERN_LEN {
        anyhow::bail!("Pattern is too long (max 512 characters)");
    }
    if pattern.contains("**") {
        anyhow::bail!("Recursive glob patterns (**) are not allowed for security reasons");
    }
    let wildcard_count = pattern.matches('*').count();
    if wildcard_count > MAX_WILDCARDS {
        anyhow::bail!("Too many wildcards in pattern '{pattern}'. Maximum 5 wildcards allowed.");
    }
    if matches!(pattern, "*" | "/*") {
        anyhow::bail!("Pattern '{pattern}' is too broad and could match system files");
    }
    Ok(())
}
/// Reports whether `path` lies under a directory trusted for SSH config files.
///
/// The trusted directories are the user's home directory (when it can be
/// determined), `/etc/ssh`, `/usr/local/etc/ssh` and the system temp
/// directory.
///
/// The check is purely lexical. It does not touch the filesystem or resolve
/// symlinks. Callers that need symlink resolution should canonicalize `path`
/// first. Any path containing a `..` component is rejected, because
/// `Path::starts_with` compares components without normalizing them.
#[cfg(not(test))]
pub fn is_path_allowed(path: &Path) -> bool {
    // Without this check, `~/../../etc/shadow` would "start with" the home
    // directory even though it points elsewhere.
    if path
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir))
    {
        return false;
    }
    let fixed_prefixes = [
        PathBuf::from("/etc/ssh"),
        PathBuf::from("/usr/local/etc/ssh"),
        std::env::temp_dir(),
    ];
    // An unknown home directory is simply skipped. Falling back to `/` would
    // make every absolute path count as allowed.
    dirs::home_dir()
        .into_iter()
        .chain(fixed_prefixes)
        .any(|prefix| path.starts_with(prefix))
}
/// Validates a file referenced by an SSH config `Include` directive.
///
/// A path that does not exist is accepted as-is, since there is nothing to
/// read. For an existing path, the function:
/// - rejects symlinks and anything that is not a regular file;
/// - canonicalizes the path;
/// - logs a warning (via `tracing`) when the result lies outside the home
///   directory, `/etc/ssh`, `/usr/local/etc/ssh` or the system temp directory.
///
/// # Errors
///
/// Returns an error when the path:
/// - is a symbolic link, including a dangling one;
/// - exists but is not a regular file;
/// - cannot be inspected or canonicalized for a reason other than not existing;
/// - is world-writable (Unix, non-test builds only).
///
/// Group-writable files (Unix, non-test builds) only produce a warning.
pub fn validate_include_path(path: &Path) -> Result<()> {
    // Use `symlink_metadata` rather than `Path::exists`. `exists` follows
    // symlinks, so a dangling symlink would look absent and skip every check.
    let metadata = match std::fs::symlink_metadata(path) {
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
        result => result
            .with_context(|| format!("Failed to get metadata for {}", path.display()))?,
    };
    if metadata.is_symlink() {
        anyhow::bail!(
            "Include path {} is a symbolic link. Symlinks are not allowed for security reasons.",
            path.display()
        );
    }
    if !metadata.is_file() {
        anyhow::bail!("Include path is not a regular file: {}", path.display());
    }
    let canonical = path
        .canonicalize()
        .with_context(|| format!("Failed to canonicalize {}", path.display()))?;
    // `canonicalize` already resolves `..`, so this is a defensive invariant
    // check. It inspects path components rather than substrings, so legitimate
    // names such as `backup..` are not mistaken for traversal.
    if canonical
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir))
    {
        anyhow::bail!(
            "Include path {} contains directory traversal sequences",
            path.display()
        );
    }
    let fixed_prefixes = [
        PathBuf::from("/etc/ssh"),
        PathBuf::from("/usr/local/etc/ssh"),
        std::env::temp_dir(),
    ];
    // An unknown home directory is skipped rather than replaced by `/`, which
    // would match every path and silence the warning entirely. Each prefix is
    // canonicalized when possible so it compares like-for-like with
    // `canonical` (e.g. a temp dir reached through a symlinked parent).
    let is_safe = dirs::home_dir()
        .into_iter()
        .chain(fixed_prefixes)
        .map(|prefix| prefix.canonicalize().unwrap_or(prefix))
        .any(|prefix| canonical.starts_with(prefix));
    if !is_safe {
        tracing::warn!(
            "Include path {} is outside of standard SSH config directories. This may be a security risk.",
            canonical.display()
        );
    }
    #[cfg(all(unix, not(test)))]
    {
        use std::os::unix::fs::PermissionsExt;
        // `metadata` came from `symlink_metadata`, and symlinks were rejected
        // above, so these are the permissions of the file itself.
        let mode = metadata.permissions().mode();
        if mode & 0o002 != 0 {
            anyhow::bail!(
                "SSH config file {} is world-writable. This is a security vulnerability.",
                path.display()
            );
        }
        if mode & 0o020 != 0 {
            tracing::warn!(
                "SSH config file {} is group-writable. This is a potential security risk.",
                path.display()
            );
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Patterns rejected for security reasons, plus a few that must pass.
    #[test]
    fn test_validate_glob_pattern_security() {
        let result = validate_glob_pattern("config.d/**/*.conf");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Recursive glob"));

        let result = validate_glob_pattern("a*/b*/c*/d*/e*/f*");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Too many wildcards")
        );

        let long_pattern = "a".repeat(600);
        let result = validate_glob_pattern(&long_pattern);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too long"));

        for broad in ["*", "/*"] {
            let result = validate_glob_pattern(broad);
            assert!(result.is_err());
            assert!(result.unwrap_err().to_string().contains("too broad"));
        }

        assert!(validate_glob_pattern("~/.ssh/config.d/*.conf").is_ok());
        assert!(validate_glob_pattern("/etc/ssh/*.conf").is_ok());
        assert!(validate_glob_pattern("config.d/[0-9][0-9]-*.conf").is_ok());
        // Glob validation is syntactic only; traversal is checked per path.
        assert!(validate_glob_pattern("../../../etc/passwd").is_ok());
    }

    /// The wildcard and length limits are inclusive upper bounds.
    #[test]
    fn test_validate_glob_pattern_limits() {
        assert!(validate_glob_pattern("a*/b*/c*/d*/e*").is_ok());
        assert!(validate_glob_pattern(&"a".repeat(512)).is_ok());
        let err = validate_glob_pattern(&"a".repeat(513)).unwrap_err();
        assert!(err.to_string().contains("too long"));
    }

    #[test]
    fn test_validate_include_path_missing_is_ok() {
        let missing = std::env::temp_dir().join(format!(
            "ssh_include_missing_{}_does_not_exist",
            std::process::id()
        ));
        assert!(validate_include_path(&missing).is_ok());
    }

    #[test]
    fn test_validate_include_path_rejects_directory() {
        let dir = std::env::temp_dir().join(format!("ssh_include_dir_{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let result = validate_include_path(&dir);
        let _ = std::fs::remove_dir(&dir);
        assert!(result.unwrap_err().to_string().contains("not a regular file"));
    }

    #[test]
    fn test_validate_include_path_accepts_regular_file() {
        let file = std::env::temp_dir().join(format!("ssh_include_file_{}", std::process::id()));
        std::fs::write(&file, "Host example\n").unwrap();
        let result = validate_include_path(&file);
        let _ = std::fs::remove_file(&file);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_validate_include_path_rejects_symlink() {
        let pid = std::process::id();
        let target = std::env::temp_dir().join(format!("ssh_include_target_{pid}"));
        let link = std::env::temp_dir().join(format!("ssh_include_link_{pid}"));
        std::fs::write(&target, "Host example\n").unwrap();
        let _ = std::fs::remove_file(&link);
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let result = validate_include_path(&link);
        let _ = std::fs::remove_file(&link);
        let _ = std::fs::remove_file(&target);
        assert!(result.unwrap_err().to_string().contains("symbolic link"));
    }
}