use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use regex::Regex;
use thiserror::Error;
use anyhow::{Result, Context};
/// Root of the compliance configuration document.
///
/// `ComplianceMonitor::load_config` deserializes this type from YAML text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceConfig {
/// Identity of the agent the policy applies to.
pub agent: AgentConfig,
/// Per-area permission rules (network, filesystem, env, tools, model, content).
pub permissions: PermissionsConfig,
/// Optional logging settings; may be omitted from the document.
pub logging: Option<LoggingConfig>,
/// Optional alerting switches; may be omitted from the document.
pub alerts: Option<AlertsConfig>,
/// Optional enterprise metadata; may be omitted from the document.
pub enterprise: Option<EnterpriseConfig>,
}
/// Identity metadata for the monitored agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
/// Agent name; copied into `ComplianceViolation::agent_id` for every recorded violation.
pub name: String,
/// Agent version string.
pub version: String,
/// Optional tenant identifier.
pub tenant: Option<String>,
/// Optional compliance level label.
// NOTE(review): no code visible in this file interprets this value — confirm where it is consumed.
pub compliance_level: Option<String>,
}
/// Groups the permission rules for every area the monitor polices.
///
/// All six sections are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionsConfig {
/// Outbound network rules.
pub network: NetworkConfig,
/// File read/write rules.
pub filesystem: FilesystemConfig,
/// Environment-variable access rules.
pub env: EnvConfig,
/// Tool-invocation rules.
pub tools: ToolsConfig,
/// Model selection and parameter limits.
pub model: ModelConfig,
/// Input/output content filters.
pub content: ContentConfig,
}
/// Network access rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkConfig {
/// Default-deny switch; deserializes to `false` when the key is absent.
#[serde(default)]
pub default_deny: bool,
/// Allow rules; collected into the monitor's network allow set at construction.
pub allow: Vec<String>,
/// Block rules; collected into the monitor's network block set. A missing key is
/// treated as an empty list.
pub block: Option<Vec<String>>,
}
/// Filesystem access rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilesystemConfig {
/// Default-deny switch; deserializes to `false` when the key is absent.
#[serde(default)]
pub default_deny: bool,
/// Rules consulted for `"read"` access checks.
pub allow_read: Vec<String>,
/// Rules consulted for `"write"` access checks.
pub allow_write: Vec<String>,
/// Patterns that block a path regardless of access mode; they are checked before
/// the allow rules. A missing key is treated as an empty list.
pub block_all: Option<Vec<String>>,
}
/// Environment-variable access rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvConfig {
/// Default-deny switch; deserializes to `false` when the key is absent.
#[serde(default)]
pub default_deny: bool,
/// Variable names that may be read; compared against the requested name exactly.
pub allow: Vec<String>,
/// Fragments that block a variable when its name contains them (compared
/// case-insensitively). A missing key is treated as an empty list.
pub block: Option<Vec<String>>,
}
/// Tool-invocation rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsConfig {
/// When `true`, a tool must be listed (by `name`) in `tools.json` to be used.
pub enforce_whitelist: bool,
/// Audit switch for tool calls; deserializes to `false` when the key is absent.
#[serde(default)]
pub audit_all_calls: bool,
/// Literal text fragments that are forbidden in tool code snippets. Each entry is
/// regex-escaped and compiled case-insensitively when the monitor is built.
pub blocked_patterns: Vec<String>,
/// Optional tool timeout in seconds.
// NOTE(review): no code visible in this file enforces this timeout — confirm where it is applied.
pub timeout_seconds: Option<u64>,
}
/// Model selection and request-parameter limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Models that may be used; an empty list places no restriction on the model name.
pub allowed_models: Vec<String>,
/// Upper bound for a requested sampling temperature.
pub max_temperature: f64,
/// Upper bound for a requested `max_tokens` value.
pub max_tokens: u32,
/// Audit switch for model API calls; deserializes to `false` when the key is absent.
#[serde(default)]
pub audit_api_calls: bool,
}
/// Content filters applied to text entering and leaving the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentConfig {
/// Regular expressions (compiled case-insensitively, not escaped) that reject input text.
pub blocked_input_patterns: Vec<String>,
/// Regular expressions (compiled case-insensitively, not escaped) that reject output text.
pub blocked_output_patterns: Vec<String>,
/// Optional cap on output size, compared against the output's length in bytes.
/// It is not applied to input.
pub max_response_length: Option<usize>,
}
/// Logging settings. Every boolean switch deserializes to `false` when its key is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
/// Log level name (required).
pub level: String,
/// Switch for auditing all requests.
#[serde(default)]
pub audit_all_requests: bool,
/// Switch for auditing all tool calls.
#[serde(default)]
pub audit_all_tool_calls: bool,
/// Switch for logging network requests.
#[serde(default)]
pub log_network_requests: bool,
/// Switch for logging file accesses.
#[serde(default)]
pub log_file_access: bool,
/// Optional log retention period in days.
// NOTE(review): retention is not enforced by any code visible in this file.
pub retention_days: Option<u32>,
}
/// Alerting switches. Every field deserializes to `false` when its key is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlertsConfig {
/// Switch for real-time monitoring.
#[serde(default)]
pub real_time_monitoring: bool,
/// Enables alerts for violations in the `"network"` and `"filesystem"` categories.
#[serde(default)]
pub security_violation_alerts: bool,
/// Enables alerts for violations in the `"tools"` and `"content"` categories.
#[serde(default)]
pub compliance_breach_alerts: bool,
/// Switch for performance-anomaly alerts.
#[serde(default)]
pub performance_anomaly_alerts: bool,
}
/// Optional organisational metadata attached to the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnterpriseConfig {
/// Optional tenant identifier.
pub tenant_id: Option<String>,
/// Optional department name.
pub department: Option<String>,
/// Optional cost-center code.
pub cost_center: Option<String>,
}
/// A single recorded policy violation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceViolation {
/// Time the violation was created, in whole seconds since the Unix epoch.
pub timestamp: u64,
/// Policy area that was violated, e.g. `"network"`, `"filesystem"`, `"env"`,
/// `"tools"`, `"content"`, `"model"` or `"integrity"`.
pub category: String,
/// Human-readable description of the violation.
pub message: String,
/// Structured context for the violation (shape depends on the category).
pub details: serde_json::Value,
/// Name of the agent from the active configuration (`agent.name`).
pub agent_id: String,
}
/// Errors produced by compliance checks; each variant carries a descriptive message
/// that is interpolated into the `Display` text shown in its `#[error]` attribute.
#[derive(Error, Debug)]
pub enum ComplianceError {
/// A network request was denied.
#[error("Network compliance violation: {0}")]
Network(String),
/// A file access was denied or used an unsupported mode.
#[error("Filesystem compliance violation: {0}")]
Filesystem(String),
/// An environment-variable read was denied.
#[error("Environment variable compliance violation: {0}")]
Environment(String),
/// A tool invocation was denied.
#[error("Tool usage compliance violation: {0}")]
Tool(String),
/// Input or output text violated the content policy.
#[error("Content policy compliance violation: {0}")]
Content(String),
/// A model name or model parameter violated the policy.
#[error("Model usage compliance violation: {0}")]
Model(String),
/// A file's hash did not match the expected value.
#[error("Integrity compliance violation: {0}")]
Integrity(String),
/// The configuration is invalid.
#[error("Configuration error: {0}")]
Configuration(String),
/// A generic check failure.
#[error("Compliance check failed: {0}")]
CheckFailed(String),
}
/// Runtime enforcer for a [`ComplianceConfig`].
///
/// Rule sets are precomputed at construction and held behind `Arc`, and the mutable
/// state (violation history and verdict caches) is held behind `Arc<Mutex<_>>` /
/// `Arc<RwLock<_>>`, so a cloned monitor shares all of it with the original.
#[derive(Debug, Clone)]
pub struct ComplianceMonitor {
// Full copy of the configuration the monitor was built from.
config: ComplianceConfig,
// Optional structured logger; `None` unless supplied via `new_with_logger`.
logger: Option<Arc<crate::logger::SekuireLogger>>,
// Built from `permissions.network.allow`.
network_allow_set: Arc<HashSet<String>>,
// Built from `permissions.network.block` (empty when the key is absent).
network_block_set: Arc<HashSet<String>>,
// Built from `permissions.filesystem.allow_read`.
files_read_set: Arc<HashSet<String>>,
// Built from `permissions.filesystem.allow_write`.
files_write_set: Arc<HashSet<String>>,
// Built from `permissions.filesystem.block_all` (empty when the key is absent).
files_block_patterns: Arc<Vec<String>>,
// Built from `permissions.env.allow`.
env_allow_set: Arc<HashSet<String>>,
// Built from `permissions.env.block` (empty when the key is absent).
env_block_set: Arc<HashSet<String>>,
// Tool names read from `tools.json`; empty when that file does not exist.
allowed_tools_set: Arc<HashSet<String>>,
// Escaped, case-insensitive regexes compiled from `permissions.tools.blocked_patterns`.
tool_block_patterns: Arc<Vec<Regex>>,
// Case-insensitive regexes compiled from `permissions.content.blocked_input_patterns`.
input_block_patterns: Arc<Vec<Regex>>,
// Case-insensitive regexes compiled from `permissions.content.blocked_output_patterns`.
output_block_patterns: Arc<Vec<Regex>>,
// Built from `permissions.model.allowed_models`.
allowed_models_set: Arc<HashSet<String>>,
// History of recorded violations.
violations: Arc<Mutex<Vec<ComplianceViolation>>>,
// Cached allow/deny verdicts for network checks.
network_cache: Arc<RwLock<std::collections::HashMap<String, bool>>>,
// Cached allow/deny verdicts for content checks.
content_cache: Arc<RwLock<std::collections::HashMap<String, bool>>>,
// Entry count at which a verdict cache starts evicting; set to 10000 at construction.
max_cache_size: usize,
}
impl ComplianceMonitor {
/// Builds a monitor from the YAML configuration file at `config_path`.
///
/// # Errors
/// Fails when the file cannot be read or parsed, or when [`Self::from_config`]
/// rejects the parsed configuration.
pub fn new<P: AsRef<Path>>(config_path: P) -> Result<Self> {
    Self::load_config(config_path.as_ref())
        .context("Failed to load compliance configuration")
        .and_then(Self::from_config)
}
pub fn new_with_logger<P: AsRef<Path>>(
config_path: P,
logger: Arc<crate::logger::SekuireLogger>,
) -> Result<Self> {
let config = Self::load_config(&config_path)
.context("Failed to load compliance configuration")?;
let mut monitor = Self::from_config(config)?;
monitor.logger = Some(logger);
Ok(monitor)
}
pub fn from_config(config: ComplianceConfig) -> Result<Self> {
let mut tool_patterns = Vec::new();
for pattern in &config.permissions.tools.blocked_patterns {
let regex = Regex::new(&format!("(?i){}", regex::escape(pattern)))
.context("Failed to compile tool block pattern")?;
tool_patterns.push(regex);
}
let mut input_patterns = Vec::new();
for pattern in &config.permissions.content.blocked_input_patterns {
let regex = Regex::new(&format!("(?i){}", pattern))
.context("Failed to compile input block pattern")?;
input_patterns.push(regex);
}
let mut output_patterns = Vec::new();
for pattern in &config.permissions.content.blocked_output_patterns {
let regex = Regex::new(&format!("(?i){}", pattern))
.context("Failed to compile output block pattern")?;
output_patterns.push(regex);
}
let allowed_tools = if Path::new("tools.json").exists() {
let content = fs::read_to_string("tools.json")
.context("Failed to read tools.json")?;
let tools: Vec<serde_json::Value> = serde_json::from_str(&content)
.context("Failed to parse tools.json")?;
tools
.iter()
.filter_map(|tool| tool.get("name"))
.filter_map(|name| name.as_str())
.map(|name| name.to_string())
.collect()
} else {
HashSet::new()
};
let monitor = Self {
config: config.clone(),
logger: None,
network_allow_set: Arc::new(config.permissions.network.allow.into_iter().collect()),
network_block_set: Arc::new(config.permissions.network.block.unwrap_or_default().into_iter().collect()),
files_read_set: Arc::new(config.permissions.filesystem.allow_read.into_iter().collect()),
files_write_set: Arc::new(config.permissions.filesystem.allow_write.into_iter().collect()),
files_block_patterns: Arc::new(config.permissions.filesystem.block_all.unwrap_or_default()),
env_allow_set: Arc::new(config.permissions.env.allow.into_iter().collect()),
env_block_set: Arc::new(config.permissions.env.block.unwrap_or_default().into_iter().collect()),
allowed_tools_set: Arc::new(allowed_tools),
tool_block_patterns: Arc::new(tool_patterns),
input_block_patterns: Arc::new(input_patterns),
output_block_patterns: Arc::new(output_patterns),
allowed_models_set: Arc::new(config.permissions.model.allowed_models.into_iter().collect()),
violations: Arc::new(Mutex::new(Vec::new())),
network_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
content_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
max_cache_size: 10000,
};
Ok(monitor)
}
/// Reads the file at `config_path` and deserializes it from YAML.
///
/// # Errors
/// Fails when the file cannot be read or its contents are not a valid
/// [`ComplianceConfig`] document.
fn load_config<P: AsRef<Path>>(config_path: P) -> Result<ComplianceConfig> {
    let yaml_text = fs::read_to_string(config_path.as_ref())
        .context("Failed to read configuration file")?;
    let parsed: ComplianceConfig =
        serde_yaml::from_str(&yaml_text).context("Failed to parse YAML configuration")?;
    Ok(parsed)
}
/// Placeholder that does nothing and always returns `Ok(())`.
///
/// Pattern compilation actually happens in `from_config`.
// NOTE(review): nothing visible in this file calls this method — candidate for removal.
fn initialize_regex_patterns(&self) -> Result<()> {
Ok(())
}
/// Placeholder that does nothing and always returns `Ok(())`.
///
/// The allowed-tool set is actually loaded from `tools.json` in `from_config`.
// NOTE(review): nothing visible in this file calls this method — candidate for removal.
fn load_allowed_tools(&self) -> Result<()> {
Ok(())
}
pub fn check_network_access(&self, url: &str, method: &str) -> Result<()> {
{
let cache = self.network_cache.read().unwrap();
if let Some(allowed) = cache.get(url) {
return if *allowed { Ok(()) } else {
Err(ComplianceError::Network(format!("Access to {} is blocked", url)))
};
}
}
for blocked in self.network_block_set.iter() {
if url.starts_with(blocked) {
let violation = self.create_violation(
"network",
format!("Network access to {} blocked by blocklist rule: {}", url, blocked),
serde_json::json!({
"url": url,
"method": method,
"rule": blocked
}),
);
self.log_violation(violation);
{
let mut cache = self.network_cache.write().unwrap();
self.set_cache_value(&mut *cache, url.to_string(), false);
}
return Err(ComplianceError::Network(violation.message));
}
}
if !self.network_allow_set.is_empty() {
let allowed = self.network_allow_set.iter().any(|allowed| url.starts_with(allowed));
if !allowed {
let violation = self.create_violation(
"network",
format!("Network access to {} not in allowlist", url),
serde_json::json!({
"url": url,
"method": method
}),
);
self.log_violation(violation);
{
let mut cache = self.network_cache.write().unwrap();
self.set_cache_value(&mut *cache, url.to_string(), false);
}
return Err(ComplianceError::Network(violation.message));
}
}
{
let mut cache = self.network_cache.write().unwrap();
self.set_cache_value(&mut *cache, url.to_string(), true);
}
if self.config.permissions.network.log_network_requests {
if let Some(logger) = &self.logger {
logger.log_event(
crate::logger::EventType::NetworkAccess,
crate::logger::Severity::Info,
serde_json::json!({
"url": url,
"method": method,
"status": "allowed",
}),
);
}
}
Ok(())
}
pub fn check_file_access(&self, file_path: &str, mode: &str) -> Result<()> {
let resolved_path = fs::canonicalize(file_path)
.unwrap_or_else(|_| PathBuf::from(file_path));
let path_str = resolved_path.to_string_lossy();
for blocked_pattern in self.files_block_patterns.iter() {
if path_str.contains(blocked_pattern) || path_str.starts_with(blocked_pattern) {
let violation = self.create_violation(
"filesystem",
format!("File access blocked by pattern: {}", blocked_pattern),
serde_json::json!({
"path": path_str,
"mode": mode,
"pattern": blocked_pattern
}),
);
self.log_violation(violation);
return Err(ComplianceError::Filesystem(violation.message));
}
}
match mode {
"read" => {
if !self.files_read_set.is_empty() {
let allowed = self.files_read_set.iter().any(|allowed| {
path_str.ends_with(allowed) || path_str == *allowed
});
if !allowed {
let violation = self.create_violation(
"filesystem",
format!("File read access not allowed: {}", path_str),
serde_json::json!({
"path": path_str,
"mode": mode
}),
);
self.log_violation(violation);
return Err(ComplianceError::Filesystem(violation.message));
}
}
}
"write" => {
if !self.files_write_set.is_empty() {
let allowed = self.files_write_set.iter().any(|allowed| path_str.starts_with(allowed));
if !allowed {
let violation = self.create_violation(
"filesystem",
format!("File write access not allowed: {}", path_str),
serde_json::json!({
"path": path_str,
"mode": mode
}),
);
self.log_violation(violation);
return Err(ComplianceError::Filesystem(violation.message));
}
}
}
_ => {
return Err(ComplianceError::Filesystem(format!("Unsupported access mode: {}", mode)));
}
}
if self.config.permissions.network.log_network_requests {
if let Some(logger) = &self.logger {
logger.log_event(
crate::logger::EventType::FileAccess,
crate::logger::Severity::Info,
serde_json::json!({
"path": path_str,
"mode": mode,
"status": "allowed",
}),
);
}
}
Ok(())
}
pub fn safe_read_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
let path_str = path.as_ref().to_string_lossy();
self.check_file_access(&path_str, "read")?;
fs::read_to_string(path)
.context("Failed to read file")
}
pub fn safe_write_file<P: AsRef<Path>, C: AsRef<[u8]>>(&self, path: P, contents: C) -> Result<()> {
let path_str = path.as_ref().to_string_lossy();
self.check_file_access(&path_str, "write")?;
fs::write(path, contents)
.context("Failed to write file")
}
/// Checks whether the environment variable `var_name` may be read.
///
/// A variable is blocked when its name contains any block fragment
/// (case-insensitively). Otherwise, with a non-empty allowlist the name must be
/// listed exactly; with an empty allowlist it is allowed unless `env.default_deny`
/// is set.
///
/// # Errors
/// Returns [`ComplianceError::Environment`] (wrapped in `anyhow::Error`) when access
/// is denied. Denials are recorded as violations.
pub fn check_env_access(&self, var_name: &str) -> Result<()> {
    let var_lower = var_name.to_lowercase();
    let blocked = self
        .env_block_set
        .iter()
        .any(|fragment| var_lower.contains(fragment.to_lowercase().as_str()));
    if blocked {
        let violation = self.create_violation(
            "env",
            format!("Environment variable access blocked: {}", var_name),
            serde_json::json!({
                "var_name": var_name
            }),
        );
        // Keep the message: `log_violation` takes ownership of the violation.
        let message = violation.message.clone();
        self.log_violation(violation);
        return Err(ComplianceError::Environment(message).into());
    }

    let allowlisted = if self.env_allow_set.is_empty() {
        // No allow rules: open by default unless the policy is default-deny.
        !self.config.permissions.env.default_deny
    } else {
        self.env_allow_set.contains(var_name)
    };
    if !allowlisted {
        let violation = self.create_violation(
            "env",
            format!("Environment variable {} not in allowlist", var_name),
            serde_json::json!({
                "var_name": var_name
            }),
        );
        let message = violation.message.clone();
        self.log_violation(violation);
        return Err(ComplianceError::Environment(message).into());
    }
    Ok(())
}
/// Returns the value of `var_name` when policy allows reading it.
///
/// Yields `None` when the access check denies the variable, or when the variable is
/// unset or not valid Unicode.
pub fn safe_get_env(&self, var_name: &str) -> Option<String> {
    match self.check_env_access(var_name) {
        Ok(()) => std::env::var(var_name).ok(),
        Err(_) => None,
    }
}
/// Checks whether `tool_name` may be invoked, optionally inspecting `code_snippet`.
///
/// The snippet (when given) is scanned for blocked patterns first. Then, if
/// `tools.enforce_whitelist` is set, the tool must be in the allowed-tool set loaded
/// from `tools.json`. Allowed calls are reported to the logger when either
/// `tools.audit_all_calls` or `logging.audit_all_tool_calls` is enabled.
///
/// # Errors
/// Returns [`ComplianceError::Tool`] (wrapped in `anyhow::Error`) when the call is
/// denied. Denials are recorded as violations.
pub fn check_tool_usage(&self, tool_name: &str, code_snippet: Option<&str>) -> Result<()> {
    if let Some(code) = code_snippet {
        let offending = self
            .tool_block_patterns
            .iter()
            .find(|pattern| pattern.is_match(code));
        if let Some(pattern) = offending {
            let violation = self.create_violation(
                "tools",
                format!("Tool usage contains blocked pattern: {}", pattern.as_str()),
                serde_json::json!({
                    "tool": tool_name,
                    "pattern": pattern.as_str()
                }),
            );
            // Keep the message: `log_violation` takes ownership of the violation.
            let message = violation.message.clone();
            self.log_violation(violation);
            return Err(ComplianceError::Tool(message).into());
        }
    }

    if self.config.permissions.tools.enforce_whitelist
        && !self.allowed_tools_set.contains(tool_name)
    {
        let violation = self.create_violation(
            "tools",
            format!("Tool {} not in approved tools.json whitelist", tool_name),
            serde_json::json!({
                "tool": tool_name,
                "allowed_tools": self.allowed_tools_set.iter().cloned().collect::<Vec<_>>()
            }),
        );
        let message = violation.message.clone();
        self.log_violation(violation);
        return Err(ComplianceError::Tool(message).into());
    }

    // Two switches can request tool auditing: `tools.audit_all_calls` and the optional
    // `logging.audit_all_tool_calls`. Honour either.
    let audit_calls = self.config.permissions.tools.audit_all_calls
        || self
            .config
            .logging
            .as_ref()
            .map_or(false, |logging| logging.audit_all_tool_calls);
    if audit_calls {
        if let Some(logger) = &self.logger {
            logger.log_event(
                crate::logger::EventType::ToolExecution,
                crate::logger::Severity::Info,
                serde_json::json!({
                    "tool": tool_name,
                    "status": "allowed",
                    "has_code_snippet": code_snippet.is_some(),
                }),
            );
        }
    }
    Ok(())
}
/// Screens text entering the agent against the blocked input patterns.
///
/// # Errors
/// Fails when `content` matches a blocked input pattern.
pub fn check_input_content(&self, content: &str) -> Result<()> {
    let input_rules = self.input_block_patterns.as_slice();
    self.check_content(content, input_rules, "input")
}
/// Screens text leaving the agent against the blocked output patterns and the
/// optional maximum response length.
///
/// # Errors
/// Fails when `content` matches a blocked output pattern or is too long.
pub fn check_output_content(&self, content: &str) -> Result<()> {
    let output_rules = self.output_block_patterns.as_slice();
    self.check_content(content, output_rules, "output")
}
/// Screens `content` against `patterns` and, for `"output"`, the configured maximum
/// response length (measured in bytes).
///
/// Verdicts are cached per `(content_type, full content)` pair. The key uses the
/// whole text so that two texts sharing a prefix can never share a verdict, and it is
/// built without byte-offset slicing so multi-byte characters cannot cause a panic.
/// Texts larger than `MAX_CACHEABLE_BYTES` are always re-scanned instead of cached to
/// keep the cache's memory use bounded.
///
/// # Errors
/// Returns [`ComplianceError::Content`] (wrapped in `anyhow::Error`) when the text is
/// rejected. Rejections are recorded as violations; a rejection served from the cache
/// is not recorded again.
///
/// # Panics
/// Panics if the cache lock is poisoned.
fn check_content(&self, content: &str, patterns: &[Regex], content_type: &str) -> Result<()> {
    const MAX_CACHEABLE_BYTES: usize = 4096;

    let cache_key = if content.len() <= MAX_CACHEABLE_BYTES {
        Some(format!("{}:{}", content_type, content))
    } else {
        None
    };

    if let Some(key) = &cache_key {
        let cache = self.content_cache.read().unwrap();
        if let Some(&allowed) = cache.get(key) {
            return if allowed {
                Ok(())
            } else {
                Err(ComplianceError::Content(format!(
                    "Content check failed for {}",
                    content_type
                ))
                .into())
            };
        }
    }

    if let Some(pattern) = patterns.iter().find(|pattern| pattern.is_match(content)) {
        let violation = self.create_violation(
            "content",
            format!("{} contains blocked pattern: {}", content_type, pattern.as_str()),
            serde_json::json!({
                "pattern": pattern.as_str(),
                "content_length": content.len(),
                "content_type": content_type
            }),
        );
        // Keep the message: `log_violation` takes ownership of the violation.
        let message = violation.message.clone();
        self.log_violation(violation);
        if let Some(key) = cache_key {
            let mut cache = self.content_cache.write().unwrap();
            self.set_cache_value(&mut *cache, key, false);
        }
        return Err(ComplianceError::Content(message).into());
    }

    if content_type == "output" {
        if let Some(max_length) = self.config.permissions.content.max_response_length {
            if content.len() > max_length {
                let violation = self.create_violation(
                    "content",
                    format!(
                        "{} length {} exceeds maximum {}",
                        content_type,
                        content.len(),
                        max_length
                    ),
                    serde_json::json!({
                        "content_length": content.len(),
                        "max_length": max_length,
                        "content_type": content_type
                    }),
                );
                let message = violation.message.clone();
                self.log_violation(violation);
                return Err(ComplianceError::Content(message).into());
            }
        }
    }

    if let Some(key) = cache_key {
        let mut cache = self.content_cache.write().unwrap();
        self.set_cache_value(&mut *cache, key, true);
    }
    Ok(())
}
/// Checks a model request against the model policy.
///
/// The model name must be in the allowlist when that list is non-empty. A supplied
/// `temperature` must not exceed `model.max_temperature` (a NaN temperature is
/// rejected, since it would otherwise compare as "not greater" and slip through), and
/// a supplied `max_tokens` must not exceed `model.max_tokens`. Parameters passed as
/// `None` are not checked.
///
/// # Errors
/// Returns [`ComplianceError::Model`] (wrapped in `anyhow::Error`) when the request is
/// denied. Denials are recorded as violations.
pub fn check_model_usage(&self, model: &str, temperature: Option<f64>, max_tokens: Option<u32>) -> Result<()> {
    let limits = &self.config.permissions.model;

    if !self.allowed_models_set.is_empty() && !self.allowed_models_set.contains(model) {
        let violation = self.create_violation(
            "model",
            format!("Model {} not in allowlist", model),
            serde_json::json!({
                "model": model,
                "allowed_models": self.allowed_models_set.iter().cloned().collect::<Vec<_>>()
            }),
        );
        // Keep the message: `log_violation` takes ownership of the violation.
        let message = violation.message.clone();
        self.log_violation(violation);
        return Err(ComplianceError::Model(message).into());
    }

    if let Some(temp) = temperature {
        if temp.is_nan() || temp > limits.max_temperature {
            let violation = self.create_violation(
                "model",
                format!("Temperature {} exceeds maximum {}", temp, limits.max_temperature),
                serde_json::json!({
                    "model": model,
                    "temperature": temp
                }),
            );
            let message = violation.message.clone();
            self.log_violation(violation);
            return Err(ComplianceError::Model(message).into());
        }
    }

    if let Some(tokens) = max_tokens {
        if tokens > limits.max_tokens {
            let violation = self.create_violation(
                "model",
                format!("Max tokens {} exceeds maximum {}", tokens, limits.max_tokens),
                serde_json::json!({
                    "model": model,
                    "max_tokens": tokens
                }),
            );
            let message = violation.message.clone();
            self.log_violation(violation);
            return Err(ComplianceError::Model(message).into());
        }
    }
    Ok(())
}
/// Returns the BLAKE3 digest of `content` as a lowercase hexadecimal string.
///
/// The digest is rendered with `Hash::to_hex`, the crate's documented hex encoding,
/// rather than a `{:x}` format specifier.
///
/// # Errors
/// Currently never fails; the `Result` is kept for interface stability.
pub fn hash_content(&self, content: &str) -> Result<String> {
    use blake3::Hasher;

    let mut hasher = Hasher::new();
    hasher.update(content.as_bytes());
    Ok(hasher.finalize().to_hex().to_string())
}
pub fn check_file_integrity(&self, file_path: &str, expected_hash: &str) -> Result<()> {
let content = fs::read_to_string(file_path)
.context("Failed to read file for integrity check")?;
let actual_hash = self.hash_content(&content)?;
if actual_hash != expected_hash {
let violation = self.create_violation(
"integrity",
format!("File integrity check failed: {}", file_path),
serde_json::json!({
"file": file_path,
"expected": expected_hash,
"actual": actual_hash
}),
);
self.log_violation(violation);
return Err(ComplianceError::Integrity(violation.message));
}
Ok(())
}
/// Builds a [`ComplianceViolation`] stamped with the current time and the configured
/// agent name.
///
/// The timestamp is in whole seconds since the Unix epoch. If the system clock reports
/// a time before the epoch the timestamp falls back to `0` instead of panicking, so a
/// misconfigured clock cannot take down the enforcement path.
fn create_violation(&self, category: &str, message: String, details: serde_json::Value) -> ComplianceViolation {
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);

    ComplianceViolation {
        timestamp,
        category: category.to_string(),
        message,
        details,
        agent_id: self.config.agent.name.clone(),
    }
}
/// Records `violation` and reports it on every configured channel.
///
/// The violation is appended to the in-memory history (once the history exceeds
/// 10000 entries the oldest 5000 are discarded), printed to stderr, sent to the
/// structured logger when one is attached, and escalated through
/// `send_security_alert` when the matching alert switch is enabled.
///
/// # Panics
/// Panics if the history lock is poisoned.
fn log_violation(&self, violation: ComplianceViolation) {
    {
        let mut history = self.violations.lock().unwrap();
        history.push(violation.clone());
        if history.len() > 10000 {
            history.drain(..5000);
        }
    }

    eprintln!("⚠️ Compliance Violation [{}]: {}", violation.category, violation.message);

    if let Some(logger) = self.logger.as_ref() {
        logger.log_event(
            crate::logger::EventType::PolicyViolation,
            crate::logger::Severity::Error,
            serde_json::json!({
                "category": violation.category,
                "message": violation.message.clone(),
                "details": violation.details,
                "agent_id": violation.agent_id,
                "timestamp": violation.timestamp,
            }),
        );
    }

    // Which alert switch applies depends on the violation's category; categories
    // without a switch never alert.
    let should_alert = self.config.alerts.as_ref().map_or(false, |alerts| {
        match violation.category.as_str() {
            "network" | "filesystem" => alerts.security_violation_alerts,
            "tools" | "content" => alerts.compliance_breach_alerts,
            _ => false,
        }
    });
    if should_alert {
        self.send_security_alert(&violation);
    }
}
fn send_security_alert(&self, violation: &ComplianceViolation) {
println!("🚨 Security Alert: {} - {}", violation.category, violation.message);
}
/// Returns a JSON summary of the monitor's state: number of recorded violations, the
/// most recent violation (or `null`), the list of enforced rule areas, and rule/cache
/// counts.
///
/// # Panics
/// Panics if the history lock or a cache lock is poisoned.
pub fn get_compliance_status(&self) -> serde_json::Value {
    let history = self.violations.lock().unwrap();
    let net_cache = self.network_cache.read().unwrap();
    let text_cache = self.content_cache.read().unwrap();

    let violations_count = history.len();
    let last_violation = history.last().cloned();

    let security_stats = serde_json::json!({
        "network_rules": self.network_allow_set.len() + self.network_block_set.len(),
        "allowed_tools": self.allowed_tools_set.len(),
        "content_patterns": self.input_block_patterns.len() + self.output_block_patterns.len(),
        "cache_size": net_cache.len() + text_cache.len()
    });

    serde_json::json!({
        "config_loaded": true,
        "violations_count": violations_count,
        "last_violation": last_violation,
        "rules_enforced": [
            "network_access",
            "file_access",
            "tool_usage",
            "content_filtering",
            "env_var_access",
            "model_usage"
        ],
        "security_stats": security_stats
    })
}
/// Returns copies of recorded violations, optionally restricted to one `category`.
///
/// Without a `limit` the result is in recording order (oldest first). With a `limit`
/// the result holds at most that many entries, newest first.
///
/// # Panics
/// Panics if the history lock is poisoned.
pub fn get_violations(&self, category: Option<&str>, limit: Option<usize>) -> Vec<ComplianceViolation> {
    let history = self.violations.lock().unwrap();

    let matching: Vec<ComplianceViolation> = match category {
        Some(wanted) => history
            .iter()
            .filter(|entry| entry.category == wanted)
            .cloned()
            .collect(),
        None => history.to_vec(),
    };

    match limit {
        Some(max_entries) => matching.into_iter().rev().take(max_entries).collect(),
        None => matching,
    }
}
/// Empties the violation history and both verdict caches.
///
/// # Panics
/// Panics if any of the three locks is poisoned.
pub fn clear_violations(&self) {
    {
        let mut history = self.violations.lock().unwrap();
        history.clear();
    }
    {
        let mut net_cache = self.network_cache.write().unwrap();
        net_cache.clear();
    }
    {
        let mut text_cache = self.content_cache.write().unwrap();
        text_cache.clear();
    }
}
fn set_cache_value<K, V>(&self, cache: &mut std::collections::HashMap<K, V>, key: K, value: V) {
if cache.len() >= self.max_cache_size {
let keys_to_delete: Vec<K> = cache.keys().take(self.max_cache_size / 5).cloned().collect();
for key in keys_to_delete {
cache.remove(&key);
}
}
cache.insert(key, value);
}
/// Runs the network check for a fixed list of common endpoints so their verdicts are
/// already cached. The individual results are discarded; a denied endpoint is still
/// recorded as a violation by the check itself.
pub fn pre_warm_caches(&self) {
    const COMMON_ENDPOINTS: [&str; 3] = [
        "https://api.openai.com",
        "https://api.anthropic.com",
        "https://registry.sekuire.ai",
    ];

    COMMON_ENDPOINTS.iter().for_each(|endpoint| {
        let _ = self.check_network_access(endpoint, "GET");
    });
}
}
/// Evaluates `$check` and, if it is an `Err`, returns early from the enclosing
/// function with `ComplianceError::$error_type` whose message is
/// `"<$message>: <error>"`. An `Ok` value is discarded and execution continues.
///
/// `$monitor` is accepted for call-site symmetry with the other `enforce_*` macros
/// but is not used by the expansion.
#[macro_export]
macro_rules! enforce_compliance {
    ($monitor:expr, $check:expr, $error_type:ident, $message:expr) => {
        match $check {
            Ok(_) => {}
            Err(e) => {
                return Err($crate::compliance::ComplianceError::$error_type(format!(
                    "{}: {}",
                    $message, e
                )));
            }
        }
    };
}
/// Runs `$monitor.check_network_access($url, $method)` and returns early from the
/// enclosing function with a `ComplianceError::Network` prefixed with
/// "Network access denied" when the check fails.
///
/// The enclosing function must return a `Result` whose error type is
/// `ComplianceError`, because the expansion returns that error type directly.
#[macro_export]
macro_rules! enforce_network {
($monitor:expr, $url:expr, $method:expr) => {
$crate::enforce_compliance!($monitor, $monitor.check_network_access($url, $method), Network, "Network access denied")
};
}
/// Runs `$monitor.check_file_access($path, $mode)` and returns early from the
/// enclosing function with a `ComplianceError::Filesystem` prefixed with
/// "File access denied" when the check fails.
///
/// The enclosing function must return a `Result` whose error type is
/// `ComplianceError`, because the expansion returns that error type directly.
#[macro_export]
macro_rules! enforce_file_access {
($monitor:expr, $path:expr, $mode:expr) => {
$crate::enforce_compliance!($monitor, $monitor.check_file_access($path, $mode), Filesystem, "File access denied")
};
}
/// Runs `$monitor.check_tool_usage($tool, $code)` and returns early from the
/// enclosing function with a `ComplianceError::Tool` prefixed with
/// "Tool usage denied" when the check fails.
///
/// The enclosing function must return a `Result` whose error type is
/// `ComplianceError`, because the expansion returns that error type directly.
#[macro_export]
macro_rules! enforce_tool_usage {
($monitor:expr, $tool:expr, $code:expr) => {
$crate::enforce_compliance!($monitor, $monitor.check_tool_usage($tool, $code), Tool, "Tool usage denied")
};
}
/// Screens `$content` in the given `$direction` and returns early from the enclosing
/// function when the check fails.
///
/// `$direction` must evaluate to `"input"` or `"output"`; any other value makes the
/// expansion return `ComplianceError::Content("Invalid content direction")`. A failed
/// check returns a `ComplianceError::Content` prefixed with "Content policy violation".
///
/// The expansion consists of a `let` statement followed by a nested macro call, so
/// the macro is meant to be used in statement position. The enclosing function must
/// return a `Result` whose error type is `ComplianceError`.
#[macro_export]
macro_rules! enforce_content {
($monitor:expr, $content:expr, $direction:expr) => {
let check_result = match $direction {
"input" => $monitor.check_input_content($content),
"output" => $monitor.check_output_content($content),
_ => return Err($crate::compliance::ComplianceError::Content("Invalid content direction".to_string())),
};
$crate::enforce_compliance!($monitor, check_result, Content, "Content policy violation")
};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds a configuration with no allow/block rules in any area, so each test only
    /// has to fill in the section it exercises.
    fn permissive_config() -> ComplianceConfig {
        ComplianceConfig {
            agent: AgentConfig {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
                tenant: None,
                compliance_level: None,
            },
            permissions: PermissionsConfig {
                network: NetworkConfig {
                    default_deny: false,
                    allow: vec![],
                    block: None,
                },
                filesystem: FilesystemConfig {
                    default_deny: false,
                    allow_read: vec![],
                    allow_write: vec![],
                    block_all: None,
                },
                env: EnvConfig {
                    default_deny: false,
                    allow: vec![],
                    block: None,
                },
                tools: ToolsConfig {
                    enforce_whitelist: true,
                    audit_all_calls: false,
                    blocked_patterns: vec![],
                    timeout_seconds: None,
                },
                model: ModelConfig {
                    allowed_models: vec![],
                    max_temperature: 1.0,
                    max_tokens: 1000,
                    audit_api_calls: false,
                },
                content: ContentConfig {
                    blocked_input_patterns: vec![],
                    blocked_output_patterns: vec![],
                    max_response_length: None,
                },
            },
            logging: None,
            alerts: None,
            enterprise: None,
        }
    }

    #[test]
    fn test_config_loading() {
        // The nesting is significant: without it the document cannot be deserialized
        // into `ComplianceConfig`.
        let config_content = r#"
agent:
  name: "test-agent"
  version: "1.0.0"
permissions:
  network:
    allow: ["api.openai.com"]
    block: ["github.com"]
  filesystem:
    allow_read: ["config/"]
    allow_write: ["logs/"]
  env:
    allow: ["OPENAI_API_KEY"]
  tools:
    enforce_whitelist: true
    blocked_patterns: ["os.system", "eval("]
  model:
    allowed_models: ["gpt-4"]
    max_temperature: 0.7
    max_tokens: 1000
  content:
    blocked_input_patterns: ["password"]
    blocked_output_patterns: ["sudo su"]
"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(config_content.as_bytes()).unwrap();

        let monitor = ComplianceMonitor::new(temp_file.path()).unwrap();

        assert_eq!(monitor.config.agent.name, "test-agent");
        assert_eq!(monitor.network_allow_set.len(), 1);
        assert!(monitor.network_allow_set.contains("api.openai.com"));
    }

    #[test]
    fn test_network_compliance() {
        let mut config = permissive_config();
        config.permissions.network.allow = vec!["api.openai.com".to_string()];
        config.permissions.network.block = Some(vec!["github.com".to_string()]);

        let monitor = ComplianceMonitor::from_config(config).unwrap();

        assert!(monitor.check_network_access("https://api.openai.com/v1/chat", "POST").is_ok());
        assert!(monitor.check_network_access("https://github.com/repo", "GET").is_err());
        assert!(monitor.check_network_access("https://example.com", "GET").is_err());
    }

    #[test]
    fn test_content_compliance() {
        let mut config = permissive_config();
        config.permissions.content.blocked_input_patterns =
            vec!["password".to_string(), "secret".to_string()];
        config.permissions.content.blocked_output_patterns = vec!["sudo su".to_string()];

        let monitor = ComplianceMonitor::from_config(config).unwrap();

        assert!(monitor.check_input_content("What's the admin password?").is_err());
        assert!(monitor.check_input_content("Hello, how can I help you?").is_ok());
        assert!(monitor.check_output_content("Run: sudo su to get root").is_err());
        assert!(monitor.check_output_content("Here's the information you requested").is_ok());
    }
}