use crate::security::config::{SecurityConfig, SecurityLevel};
use crate::security::error::SecurityError;
use crate::security::path::PathValidator;
use crate::security::rate_limit::ActionTracker;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
/// Central security gate for file and action requests.
///
/// Combines the active [`SecurityConfig`], an [`ActionTracker`] used for
/// rate limiting, and the workspace directory that relative paths are
/// resolved against.
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
    /// Active rules: security level, path/extension restrictions, symlink and
    /// workspace-only switches, and size/rate limits.
    config: SecurityConfig,
    /// Action counter consulted by the rate-limit methods.
    /// NOTE(review): whether clones of the policy share this counter depends on
    /// `ActionTracker`'s `Clone` impl — confirm.
    tracker: ActionTracker,
    /// Root that relative paths are joined onto and, in workspace-only mode,
    /// the directory resolved paths must stay inside.
    workspace_dir: PathBuf,
}
impl SecurityPolicy {
pub fn new(workspace_dir: PathBuf) -> Self {
Self {
config: SecurityConfig::default(),
tracker: ActionTracker::new(),
workspace_dir,
}
}
pub fn with_config(workspace_dir: PathBuf, config: SecurityConfig) -> Self {
Self {
config,
tracker: ActionTracker::new(),
workspace_dir,
}
}
pub fn from_level(workspace_dir: PathBuf, level: SecurityLevel) -> Self {
Self {
config: SecurityConfig::from_level(level),
tracker: ActionTracker::new(),
workspace_dir,
}
}
pub fn config(&self) -> &SecurityConfig {
&self.config
}
pub fn workspace_dir(&self) -> &Path {
&self.workspace_dir
}
pub fn is_read_only(&self) -> bool {
self.config.is_read_only()
}
pub fn has_shell_access(&self) -> bool {
matches!(self.config.level, SecurityLevel::Permissive)
}
pub fn is_path_allowed(&self, path: &str) -> Result<(), SecurityError> {
if PathValidator::contains_null_bytes(path) {
return Err(SecurityError::InvalidPathFormat {
reason: "Path contains null bytes".to_string(),
});
}
if PathValidator::contains_path_traversal(path) {
return Err(SecurityError::ForbiddenComponent {
component: "parent directory (..)".to_string(),
});
}
if PathValidator::contains_url_encoded_traversal(path) {
return Err(SecurityError::InvalidPathFormat {
reason: "URL-encoded path traversal detected".to_string(),
});
}
if PathValidator::starts_with_tilde(path) {
return Err(SecurityError::InvalidPathFormat {
reason: "Tilde expansion is not allowed".to_string(),
});
}
if self.config.workspace_only && PathValidator::is_absolute(path) {
return Err(SecurityError::InvalidPathFormat {
reason: "Absolute paths are not allowed in workspace-only mode".to_string(),
});
}
if let Some(prefix) =
PathValidator::matches_forbidden_prefix(path, &self.config.forbidden_paths)
{
return Err(SecurityError::PathNotAllowed {
path: format!("matches forbidden prefix: {}", prefix),
});
}
if let Some(ext) = PathValidator::get_extension(path) {
if PathValidator::is_extension_forbidden(&ext, &self.config.forbidden_extensions) {
return Err(SecurityError::ForbiddenExtension { ext });
}
}
Ok(())
}
pub fn resolve_path(&self, path: &str) -> PathBuf {
if Path::new(path).is_absolute() {
PathBuf::from(path)
} else {
self.workspace_dir.join(path)
}
}
pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool {
let resolved_canonical = if let Ok(c) = resolved.canonicalize() {
c
} else {
resolved.to_path_buf()
};
let workspace_canonical = if let Ok(c) = self.workspace_dir.canonicalize() {
c
} else {
self.workspace_dir.clone()
};
if resolved_canonical.starts_with(&workspace_canonical) {
return true;
}
PathValidator::is_within_allowed_roots(resolved, &self.config.allowed_roots)
}
pub async fn validate_path(&self, path: &str) -> Result<PathBuf, SecurityError> {
self.is_path_allowed(path)?;
let full_path = self.resolve_path(path);
let resolved = match tokio::fs::canonicalize(&full_path).await {
Ok(p) => p,
Err(_) => full_path, };
if self.config.workspace_only && !self.is_resolved_path_allowed(&resolved) {
return Err(SecurityError::PathEscapesWorkspace { resolved });
}
if !self.config.allow_symlinks {
if let Ok(meta) = tokio::fs::symlink_metadata(&resolved).await {
if meta.file_type().is_symlink() {
return Err(SecurityError::SymlinkNotAllowed { path: resolved });
}
}
}
Ok(resolved)
}
pub async fn validate_parent_directory(&self, path: &Path) -> Result<PathBuf, SecurityError> {
let Some(parent) = path.parent() else {
return Err(SecurityError::InvalidPathFormat {
reason: "Path has no parent directory".to_string(),
});
};
if let Err(e) = tokio::fs::create_dir_all(parent).await {
return Err(SecurityError::InvalidPathFormat {
reason: format!("Failed to create parent directories: {}", e),
});
}
let resolved_parent = tokio::fs::canonicalize(parent).await.map_err(|e| {
SecurityError::InvalidPathFormat {
reason: format!("Failed to resolve parent directory: {}", e),
}
})?;
if self.config.workspace_only && !self.is_resolved_path_allowed(&resolved_parent) {
return Err(SecurityError::PathEscapesWorkspace {
resolved: resolved_parent,
});
}
Ok(resolved_parent)
}
pub fn is_rate_limited(&self) -> bool {
self.tracker
.is_rate_limited(self.config.max_actions_per_hour)
}
pub fn record_action(&self) -> usize {
self.tracker.record()
}
pub fn try_record_action(&self) -> Result<(), SecurityError> {
if !self.tracker.try_record(self.config.max_actions_per_hour) {
let count = self.tracker.count();
return Err(SecurityError::RateLimitExceeded {
count,
max: self.config.max_actions_per_hour,
});
}
Ok(())
}
pub fn action_count(&self) -> usize {
self.tracker.count()
}
pub fn can_act(&self) -> Result<(), SecurityError> {
if self.is_read_only() {
return Err(SecurityError::ReadOnlyMode);
}
self.try_record_action()
}
pub fn check_file_size(&self, size: u64) -> Result<(), SecurityError> {
if self.config.max_file_size > 0 && size > self.config.max_file_size {
Err(SecurityError::FileTooLarge {
size,
max_size: self.config.max_file_size,
})
} else {
Ok(())
}
}
}
impl Default for SecurityPolicy {
    /// Builds a policy via [`SecurityPolicy::new`], rooted at the process's
    /// current working directory.
    ///
    /// Falls back to `"."` when the working directory cannot be determined.
    fn default() -> Self {
        let workspace_dir = match std::env::current_dir() {
            Ok(cwd) => cwd,
            Err(_) => PathBuf::from("."),
        };
        Self::new(workspace_dir)
    }
}
/// Reference-counted handle to a single [`SecurityPolicy`], so several owners
/// can hold the same policy without copying it.
pub type SharedSecurityPolicy = Arc<SecurityPolicy>;
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a default policy rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory outlives the test body.
    fn create_test_policy() -> (SecurityPolicy, TempDir) {
        let workspace = TempDir::new().unwrap();
        let root = workspace.path().to_path_buf();
        (SecurityPolicy::new(root), workspace)
    }

    #[test]
    fn test_path_validation_null_bytes() {
        let (policy, _guard) = create_test_policy();
        let outcome = policy.is_path_allowed("/path\0/file");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_path_validation_traversal() {
        let (policy, _guard) = create_test_policy();
        for candidate in ["../etc/passwd", "/path/../file"] {
            assert!(policy.is_path_allowed(candidate).is_err());
        }
    }

    #[test]
    fn test_rate_limiting() {
        let (policy, _guard) = create_test_policy();
        let limit = policy.config.max_actions_per_hour;
        (0..limit).for_each(|_| assert!(policy.try_record_action().is_ok()));
        // One past the budget must be refused.
        assert!(policy.try_record_action().is_err());
    }

    #[test]
    fn test_read_only_mode() {
        let sandbox = TempDir::new().unwrap();
        let root = sandbox.path().to_path_buf();
        let policy = SecurityPolicy::from_level(root, SecurityLevel::Paranoid);
        assert!(policy.is_read_only());
        assert!(policy.can_act().is_err());
    }

    #[tokio::test]
    async fn test_validate_path() {
        let (policy, workspace) = create_test_policy();
        let target = workspace.path().join("test.txt");
        tokio::fs::write(&target, "test").await.unwrap();

        let inside = policy.validate_path("test.txt").await;
        assert!(inside.is_ok());

        let escaping = policy.validate_path("../test.txt").await;
        assert!(escaping.is_err());
    }
}