use std::collections::HashSet;
use std::path::{Component, Path, PathBuf};
use std::time::Duration;

use crate::error::{Error, Result};
#[derive(Debug, Clone, Default)]
pub struct PathAllowlist {
pub read: HashSet<PathBuf>,
pub write: HashSet<PathBuf>,
pub deny: HashSet<PathBuf>,
}
impl PathAllowlist {
pub fn none() -> Self {
Self::default()
}
pub fn all() -> Self {
Self {
read: [PathBuf::from("/")].into_iter().collect(),
write: [PathBuf::from("/")].into_iter().collect(),
deny: HashSet::new(),
}
}
pub fn allow_read(mut self, path: impl Into<PathBuf>) -> Self {
self.read.insert(path.into());
self
}
pub fn allow_write(mut self, path: impl Into<PathBuf>) -> Self {
self.write.insert(path.into());
self
}
pub fn allow_rw(self, path: impl Into<PathBuf>) -> Self {
let path = path.into();
self.allow_read(path.clone()).allow_write(path)
}
pub fn deny(mut self, path: impl Into<PathBuf>) -> Self {
self.deny.insert(path.into());
self
}
pub fn can_read(&self, path: &Path) -> bool {
if self.is_denied(path) {
return false;
}
self.read.iter().any(|allowed| path.starts_with(allowed))
}
pub fn can_write(&self, path: &Path) -> bool {
if self.is_denied(path) {
return false;
}
self.write.iter().any(|allowed| path.starts_with(allowed))
}
fn is_denied(&self, path: &Path) -> bool {
self.deny.iter().any(|denied| path.starts_with(denied))
}
pub fn check_read(&self, path: &Path) -> Result<()> {
if self.can_read(path) {
Ok(())
} else {
Err(Error::path_not_allowed(path.display().to_string()))
}
}
pub fn check_write(&self, path: &Path) -> Result<()> {
if self.can_write(path) {
Ok(())
} else {
Err(Error::path_not_allowed(path.display().to_string()))
}
}
}
/// Network access policy keyed on host names.
///
/// Each entry in `allowed` / `denied` is a pattern, compared case-insensitively:
/// - `*` matches every host;
/// - `*.example.com` matches `example.com` itself and any name ending in `.example.com`
///   (but not `badexample.com`);
/// - anything else must equal the host exactly.
///
/// A single trailing root dot is ignored on both hosts and patterns. The fully qualified form
/// `example.com.` is therefore treated the same as `example.com` and cannot be used to slip past
/// a `denied` entry. `denied` takes precedence over `allowed`, and a host matching neither is
/// refused.
#[derive(Debug, Clone, Default)]
pub struct HostAllowlist {
    /// Patterns of hosts that may be accessed.
    pub allowed: HashSet<String>,
    /// Patterns of hosts that are always refused, overriding `allowed`.
    pub denied: HashSet<String>,
}

impl HostAllowlist {
    /// Returns an allowlist that refuses every host.
    pub fn none() -> Self {
        Self::default()
    }

    /// Returns an allowlist whose single `*` pattern accepts every host.
    pub fn all() -> Self {
        Self {
            allowed: std::iter::once("*".to_string()).collect(),
            denied: HashSet::new(),
        }
    }

    /// Adds the pattern `host` to the allowed set.
    pub fn allow(mut self, host: impl Into<String>) -> Self {
        self.allowed.insert(host.into());
        self
    }

    /// Adds the pattern `host` to the denied set.
    pub fn deny(mut self, host: impl Into<String>) -> Self {
        self.denied.insert(host.into());
        self
    }

    /// Returns `true` if `host` matches no `denied` pattern and at least one `allowed` pattern.
    pub fn can_access(&self, host: &str) -> bool {
        let lowered = host.to_lowercase();
        let host = Self::strip_root_dot(&lowered);
        if self
            .denied
            .iter()
            .any(|pattern| Self::host_matches(host, pattern))
        {
            return false;
        }
        self.allowed
            .iter()
            .any(|pattern| Self::host_matches(host, pattern))
    }

    /// Returns `true` if `host`, already lowercased and without a trailing dot, matches `pattern`.
    fn host_matches(host: &str, pattern: &str) -> bool {
        let pattern = pattern.to_lowercase();
        // Test the bare wildcard before dropping the root dot, so that a degenerate `*.`
        // pattern cannot collapse into "match everything".
        if pattern == "*" {
            return true;
        }
        let pattern = Self::strip_root_dot(&pattern);
        match pattern.strip_prefix("*.") {
            // `&pattern[1..]` is the dotted suffix `.example.com` (the leading `*` is one byte).
            Some(base) => host == base || host.ends_with(&pattern[1..]),
            None => host == pattern,
        }
    }

    /// Drops a single trailing DNS root dot (`example.com.` -> `example.com`).
    fn strip_root_dot(name: &str) -> &str {
        name.strip_suffix('.').unwrap_or(name)
    }

    /// Checks access to `host`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `Error::host_not_allowed`, carrying `host` exactly as the
    /// caller passed it, when [`can_access`](Self::can_access) is `false`.
    pub fn check(&self, host: &str) -> Result<()> {
        if self.can_access(host) {
            Ok(())
        } else {
            Err(Error::host_not_allowed(host))
        }
    }
}
#[derive(Debug, Clone)]
pub struct SafetyConfig {
pub paths: PathAllowlist,
pub hosts: HostAllowlist,
pub env_vars: Option<HashSet<String>>,
pub allow_process: bool,
pub allowed_commands: Option<HashSet<String>>,
pub default_timeout: Duration,
pub max_timeout: Duration,
}
impl Default for SafetyConfig {
fn default() -> Self {
Self {
paths: PathAllowlist::none(),
hosts: HostAllowlist::none(),
env_vars: Some(HashSet::new()),
allow_process: false,
allowed_commands: None,
default_timeout: Duration::from_secs(30),
max_timeout: Duration::from_secs(300),
}
}
}
impl SafetyConfig {
pub fn new() -> Self {
Self::default()
}
pub fn permissive() -> Self {
Self {
paths: PathAllowlist::all(),
hosts: HostAllowlist::all(),
env_vars: None,
allow_process: true,
allowed_commands: None,
default_timeout: Duration::from_secs(60),
max_timeout: Duration::from_secs(3600),
}
}
pub fn strict() -> Self {
Self {
paths: PathAllowlist::none(),
hosts: HostAllowlist::none(),
env_vars: Some(HashSet::new()),
allow_process: false,
allowed_commands: Some(HashSet::new()),
default_timeout: Duration::from_secs(10),
max_timeout: Duration::from_secs(30),
}
}
pub fn with_paths(mut self, paths: PathAllowlist) -> Self {
self.paths = paths;
self
}
pub fn with_hosts(mut self, hosts: HostAllowlist) -> Self {
self.hosts = hosts;
self
}
pub fn with_env_vars<I, S>(mut self, vars: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.env_vars = Some(vars.into_iter().map(Into::into).collect());
self
}
pub fn allow_all_env(mut self) -> Self {
self.env_vars = None;
self
}
pub fn with_allow_process(mut self, allow: bool) -> Self {
self.allow_process = allow;
self
}
pub fn with_allowed_commands<I, S>(mut self, commands: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.allowed_commands = Some(commands.into_iter().map(Into::into).collect());
self
}
pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
self.default_timeout = timeout;
self
}
pub fn with_max_timeout(mut self, timeout: Duration) -> Self {
self.max_timeout = timeout;
self
}
pub fn can_access_env(&self, name: &str) -> bool {
match &self.env_vars {
None => true,
Some(allowed) => allowed.contains(name),
}
}
pub fn check_env(&self, name: &str) -> Result<()> {
if self.can_access_env(name) {
Ok(())
} else {
Err(Error::not_permitted(format!(
"environment variable access denied: {}",
name
)))
}
}
pub fn can_execute(&self, command: &str) -> bool {
if !self.allow_process {
return false;
}
match &self.allowed_commands {
None => true,
Some(allowed) => allowed.contains(command),
}
}
pub fn check_execute(&self, command: &str) -> Result<()> {
if !self.allow_process {
return Err(Error::not_permitted("process execution not allowed"));
}
if let Some(ref allowed) = self.allowed_commands {
if !allowed.contains(command) {
return Err(Error::not_permitted(format!(
"command not allowed: {}",
command
)));
}
}
Ok(())
}
pub fn clamp_timeout(&self, timeout: Duration) -> Duration {
timeout.min(self.max_timeout)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allow/deny precedence for paths, expressed as (path, expected) tables.
    #[test]
    fn test_path_allowlist() {
        let allowlist = PathAllowlist::none()
            .allow_read("/tmp")
            .allow_rw("/home/user/data")
            .deny("/home/user/data/secret");

        let read_cases = [
            ("/tmp/file.txt", true),
            ("/home/user/data/file.txt", true),
            ("/home/user/data/secret/key", false),
            ("/etc/passwd", false),
        ];
        for &(path, expected) in &read_cases {
            assert_eq!(
                allowlist.can_read(Path::new(path)),
                expected,
                "can_read({})",
                path
            );
        }

        let write_cases = [
            ("/tmp/file.txt", false),
            ("/home/user/data/file.txt", true),
            ("/home/user/data/secret/key", false),
        ];
        for &(path, expected) in &write_cases {
            assert_eq!(
                allowlist.can_write(Path::new(path)),
                expected,
                "can_write({})",
                path
            );
        }
    }

    /// Exact, wildcard and denied host patterns.
    #[test]
    fn test_host_allowlist() {
        let allowlist = HostAllowlist::none()
            .allow("api.example.com")
            .allow("*.trusted.org")
            .deny("evil.trusted.org");

        let cases = [
            ("api.example.com", true),
            ("sub.trusted.org", true),
            ("trusted.org", true),
            ("evil.trusted.org", false),
            ("other.com", false),
        ];
        for &(host, expected) in &cases {
            assert_eq!(allowlist.can_access(host), expected, "can_access({})", host);
        }
    }

    /// Environment and command restrictions configured through the builder.
    #[test]
    fn test_safety_config() {
        let config = SafetyConfig::new()
            .with_env_vars(["PATH", "HOME"])
            .with_allow_process(true)
            .with_allowed_commands(["ls", "cat"]);

        for &(var, expected) in &[("PATH", true), ("SECRET", false)] {
            assert_eq!(config.can_access_env(var), expected, "can_access_env({})", var);
        }
        for &(cmd, expected) in &[("ls", true), ("rm", false)] {
            assert_eq!(config.can_execute(cmd), expected, "can_execute({})", cmd);
        }
    }

    /// Timeouts under the ceiling pass through; larger ones are capped.
    #[test]
    fn test_timeout_clamping() {
        let config = SafetyConfig::new().with_max_timeout(Duration::from_secs(60));
        for &(requested, expected) in &[(30, 30), (120, 60)] {
            assert_eq!(
                config.clamp_timeout(Duration::from_secs(requested)),
                Duration::from_secs(expected)
            );
        }
    }
}