use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use crate::error::ShieldError;
use crate::ir::tool_surface::PermissionType;
use crate::ir::{ArgumentSource, ScanTarget};
/// Highest `schema_version` this build understands. `EgressPolicy::load`
/// rejects files that declare a newer version, and `from_scan_targets`
/// stamps generated policies with this value.
const CURRENT_SCHEMA_VERSION: u32 = 1;
/// Top-level egress policy: domain allow/deny lists, network-range blocking,
/// rate limits, and audit settings. Loaded from / saved to TOML.
///
/// Only `schema_version` and `[domains]` are required in the file; the other
/// sections fall back to their `Default` implementations when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EgressPolicy {
// Compared against `CURRENT_SCHEMA_VERSION` when loading; newer is rejected.
pub schema_version: u32,
// Domain allow/deny patterns (required section, no serde default).
pub domains: DomainPolicy,
// Which internal address categories are blocked (all on by default).
#[serde(default)]
pub networks: NetworkPolicy,
// Global and per-domain requests-per-minute values.
#[serde(default)]
pub rate_limits: RateLimitPolicy,
// Audit log configuration.
#[serde(default)]
pub audit: AuditPolicy,
}
/// Domain allow/deny patterns. A pattern starting with `*` matches by suffix
/// (see `domain_matches`); any other pattern must match exactly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainPolicy {
// Allowed patterns. An empty list allows every domain that is not denied.
#[serde(default)]
pub allow: Vec<String>,
// Denied patterns; checked before `allow`, so a deny always wins.
#[serde(default)]
pub deny: Vec<String>,
}
/// Switches for blocking categories of internal destinations. Each flag
/// defaults to `true` when omitted from the policy file; `is_ip_blocked`
/// consults them with the matching `is_*` predicate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
// Addresses recognized by `is_private_ip`.
#[serde(default = "default_true")]
pub block_private: bool,
// Addresses recognized by `is_link_local`.
#[serde(default = "default_true")]
pub block_link_local: bool,
// Addresses/names recognized by `is_localhost`.
#[serde(default = "default_true")]
pub block_localhost: bool,
// Cloud metadata endpoints recognized by `is_metadata_ip`.
#[serde(default = "default_true")]
pub block_metadata: bool,
}
/// Serde default for the `NetworkPolicy` flags: every block is enabled
/// unless the policy file explicitly turns it off.
const fn default_true() -> bool {
    true
}
impl Default for NetworkPolicy {
fn default() -> Self {
Self {
block_private: true,
block_link_local: true,
block_localhost: true,
block_metadata: true,
}
}
}
/// Requests-per-minute settings returned by `EgressPolicy::rate_limit_for`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitPolicy {
// Global value used for any domain without a `per_domain` entry
// (defaults to `default_rate_limit()`, i.e. 60).
#[serde(default = "default_rate_limit")]
pub max_requests_per_minute: u32,
// Per-domain overrides, looked up by exact key.
#[serde(default)]
pub per_domain: HashMap<String, u32>,
}
/// Serde default for `RateLimitPolicy::max_requests_per_minute`:
/// sixty requests per minute.
const fn default_rate_limit() -> u32 {
    60
}
impl Default for RateLimitPolicy {
    /// Global limit from `default_rate_limit`, with no per-domain overrides.
    fn default() -> Self {
        let per_domain = HashMap::default();
        Self {
            per_domain,
            max_requests_per_minute: default_rate_limit(),
        }
    }
}
/// Audit logging settings. They are only stored and merged in this file;
/// the logger that consumes them lives elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditPolicy {
// Optional path of the audit log file (`None` when omitted).
#[serde(default)]
pub log_path: Option<PathBuf>,
// Log format name; defaults to "json".
#[serde(default = "default_log_format")]
pub log_format: String,
// Defaults to false. NOTE(review): presumably controls whether allowed
// (not just blocked) requests are logged — confirm against the consumer.
#[serde(default)]
pub log_allowed: bool,
}
/// Serde default for `AuditPolicy::log_format`.
fn default_log_format() -> String {
    String::from("json")
}
impl Default for AuditPolicy {
fn default() -> Self {
Self {
log_path: None,
log_format: default_log_format(),
log_allowed: false,
}
}
}
impl EgressPolicy {
pub fn load(path: &Path) -> Result<Self, ShieldError> {
let content = std::fs::read_to_string(path).map_err(ShieldError::Io)?;
let policy: Self = toml::from_str(&content)?;
if policy.schema_version > CURRENT_SCHEMA_VERSION {
return Err(ShieldError::Config(format!(
"Egress policy schema version {} is newer than supported version {}",
policy.schema_version, CURRENT_SCHEMA_VERSION
)));
}
Ok(policy)
}
pub fn save(&self, path: &Path) -> Result<(), ShieldError> {
let content = toml::to_string_pretty(self)?;
std::fs::write(path, content).map_err(ShieldError::Io)?;
Ok(())
}
pub fn is_domain_allowed(&self, domain: &str) -> bool {
if self
.domains
.deny
.iter()
.any(|pattern| domain_matches(domain, pattern))
{
return false;
}
if self.domains.allow.is_empty() {
return true;
}
self.domains
.allow
.iter()
.any(|pattern| domain_matches(domain, pattern))
}
pub fn is_ip_blocked(&self, ip: &str) -> bool {
if self.networks.block_localhost && is_localhost(ip) {
return true;
}
if self.networks.block_private && is_private_ip(ip) {
return true;
}
if self.networks.block_link_local && is_link_local(ip) {
return true;
}
if self.networks.block_metadata && is_metadata_ip(ip) {
return true;
}
false
}
pub fn rate_limit_for(&self, domain: &str) -> u32 {
self.rate_limits
.per_domain
.get(domain)
.copied()
.unwrap_or(self.rate_limits.max_requests_per_minute)
}
pub fn from_scan_targets(targets: &[ScanTarget]) -> Self {
let mut domains = std::collections::HashSet::new();
for target in targets {
for net_op in &target.execution.network_operations {
if let ArgumentSource::Literal(ref url) = net_op.url_arg {
if let Some(domain) = extract_domain(url) {
domains.insert(domain);
}
}
}
for tool in &target.tools {
for perm in &tool.declared_permissions {
if matches!(perm.permission_type, PermissionType::NetworkAccess) {
if let Some(ref scope) = perm.target {
if let Some(domain) = extract_domain(scope) {
domains.insert(domain);
}
}
}
}
}
}
let mut allow: Vec<String> = domains.into_iter().collect();
allow.sort();
EgressPolicy {
schema_version: CURRENT_SCHEMA_VERSION,
domains: DomainPolicy {
allow,
deny: vec![],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
}
}
pub fn merge_override(&self, operator: &EgressPolicy) -> EgressPolicy {
let allow = if operator.domains.allow.is_empty() {
self.domains.allow.clone()
} else if self.domains.allow.is_empty() {
operator.domains.allow.clone()
} else {
self.domains
.allow
.iter()
.filter(|d| {
operator
.domains
.allow
.iter()
.any(|o| domain_matches(d, o) || domain_matches(o, d))
})
.cloned()
.collect()
};
let mut deny = self.domains.deny.clone();
for d in &operator.domains.deny {
if !deny.contains(d) {
deny.push(d.clone());
}
}
let global_min = self
.rate_limits
.max_requests_per_minute
.min(operator.rate_limits.max_requests_per_minute);
let mut per_domain = self.rate_limits.per_domain.clone();
for (domain, &op_rate) in &operator.rate_limits.per_domain {
let entry = per_domain
.entry(domain.clone())
.or_insert(self.rate_limits.max_requests_per_minute);
*entry = (*entry).min(op_rate);
}
EgressPolicy {
schema_version: self.schema_version,
domains: DomainPolicy { allow, deny },
networks: NetworkPolicy {
block_private: self.networks.block_private || operator.networks.block_private,
block_link_local: self.networks.block_link_local
|| operator.networks.block_link_local,
block_localhost: self.networks.block_localhost || operator.networks.block_localhost,
block_metadata: self.networks.block_metadata || operator.networks.block_metadata,
},
rate_limits: RateLimitPolicy {
max_requests_per_minute: global_min,
per_domain,
},
audit: operator.audit.clone(),
}
}
pub fn starter_toml() -> &'static str {
r#"# AgentShield Egress Policy
# See: https://github.com/limaronaldo/agentshield
schema_version = 1
[domains]
# Allowed domain patterns (glob-style)
allow = ["*.example.com", "api.github.com"]
# Explicitly denied (takes precedence over allow)
deny = []
[networks]
block_private = true # 10.x, 172.16-31.x, 192.168.x
block_link_local = true # 169.254.x
block_localhost = true # 127.x, ::1
block_metadata = true # 169.254.169.254, metadata.google.internal
[rate_limits]
max_requests_per_minute = 60
[audit]
# log_path = "agentshield-audit.jsonl"
log_format = "json"
log_allowed = false
"#
}
}
/// Extracts the host from an `http://` / `https://` URL, or passes a bare
/// domain through unchanged.
///
/// For URLs:
/// * the scheme is matched case-insensitively;
/// * the authority ends at the first `/`, `?`, or `#`;
/// * any `user:pass@` userinfo is dropped, so the real host is returned;
/// * a bracketed IPv6 literal (`[::1]`) is returned without brackets;
/// * a trailing `:port` is removed.
///
/// Any other input is returned as-is when it contains a `.` and no `/`
/// (e.g. `api.example.com`). Otherwise the result is `None`. `None` is also
/// returned when a URL has an empty host.
pub fn extract_domain(url_or_domain: &str) -> Option<String> {
    // `get` returns None instead of panicking if the cut lands inside a
    // multi-byte character.
    let after_scheme = ["https://", "http://"].iter().find_map(|scheme| {
        url_or_domain
            .get(..scheme.len())
            .filter(|prefix| prefix.eq_ignore_ascii_case(scheme))
            .map(|_| &url_or_domain[scheme.len()..])
    });
    let rest = match after_scheme {
        Some(rest) => rest,
        None if url_or_domain.contains('.') && !url_or_domain.contains('/') => {
            return Some(url_or_domain.to_string())
        }
        None => return None,
    };
    // The authority component stops at the path, query, or fragment.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;
    // Userinfo precedes the last '@'; otherwise `user:pw@host` would yield `user`.
    let host_port = authority.rsplit('@').next()?;
    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: its colons are part of the address, not a port.
        Some(bracketed) => bracketed.split(']').next()?,
        None => host_port.split(':').next()?,
    };
    if host.is_empty() {
        return None;
    }
    Some(host.to_string())
}
/// Returns `true` when `domain` is covered by `pattern`.
///
/// * Comparison ignores ASCII case (DNS names are case-insensitive).
/// * A single trailing `.` (fully-qualified form) is ignored on either side,
///   so `EVIL.com.` matches `evil.com`.
/// * A pattern starting with `*` matches any domain ending with the rest of
///   the pattern. When that rest begins with `.` (e.g. `*.example.com`), the
///   bare apex (`example.com`) matches too.
/// * Any other pattern must match exactly.
fn domain_matches(domain: &str, pattern: &str) -> bool {
    // Lower-case and drop one trailing dot so equivalent spellings compare equal.
    fn normalize(name: &str) -> String {
        name.strip_suffix('.').unwrap_or(name).to_ascii_lowercase()
    }
    let domain = normalize(domain);
    let pattern = normalize(pattern);
    match pattern.strip_prefix('*') {
        // `strip_prefix('.')` replaces the old `&suffix[1..]`, which panicked on
        // a multi-byte character and matched the wrong apex for `*example.com`.
        Some(suffix) => {
            domain.ends_with(suffix) || suffix.strip_prefix('.') == Some(domain.as_str())
        }
        None => domain == pattern,
    }
}
/// Returns `true` when `ip` refers to the local machine. This covers:
///
/// * the name `localhost` and any `*.localhost` name (case-insensitive,
///   trailing dot ignored);
/// * any IPv4 / IPv6 loopback address, in any textual form, including
///   IPv4-mapped IPv6 such as `::ffff:127.0.0.1`;
/// * the unspecified addresses `0.0.0.0` and `::`, which reach local
///   services when used as a connect target.
///
/// Text that does not parse as an address (e.g. the shorthand `127.1`)
/// still matches the `127.` prefix.
fn is_localhost(ip: &str) -> bool {
    use std::net::IpAddr;
    let host = ip.strip_suffix('.').unwrap_or(ip).to_ascii_lowercase();
    if host == "localhost" || host.ends_with(".localhost") {
        return true;
    }
    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_loopback() || v4.is_unspecified(),
        Ok(IpAddr::V6(v6)) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || matches!(
                    v6.to_ipv4_mapped(),
                    Some(v4) if v4.is_loopback() || v4.is_unspecified()
                )
        }
        Err(_) => ip.starts_with("127."),
    }
}
/// Returns `true` when `ip` is a private, non-globally-routable address:
///
/// * IPv4 RFC 1918 ranges: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`;
/// * IPv6 unique-local addresses (`fc00::/7`), in any letter case;
/// * IPv4-mapped IPv6 forms of the IPv4 ranges (e.g. `::ffff:10.0.0.1`).
///
/// Text that does not parse as an address falls back to prefix checks.
/// Examples are shorthand such as `10.1` and IPv6 with a `%zone` suffix.
/// The `fd` prefix only counts when the text contains `:`, so hostnames
/// like `fdic.gov` are not treated as private.
fn is_private_ip(ip: &str) -> bool {
    use std::net::IpAddr;
    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_private(),
        Ok(IpAddr::V6(v6)) => {
            // fc00::/7 — unique local addresses (RFC 4193).
            (v6.segments()[0] & 0xfe00) == 0xfc00
                || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_private())
        }
        Err(_) => {
            let lower = ip.to_ascii_lowercase();
            lower.starts_with("10.")
                || (lower.starts_with("172.") && is_172_private(&lower))
                || lower.starts_with("192.168.")
                || (lower.starts_with("fd") && lower.contains(':'))
        }
    }
}

/// For text starting with `172.`, returns whether its second dotted
/// component is in `16..=31` (the `172.16.0.0/12` block). Returns `false`
/// if that component is not a `u8`.
fn is_172_private(ip: &str) -> bool {
    ip.strip_prefix("172.")
        .and_then(|rest| rest.split('.').next())
        .and_then(|second_octet| second_octet.parse::<u8>().ok())
        .map_or(false, |n| (16..=31).contains(&n))
}
/// Returns `true` when `ip` is a link-local address:
///
/// * IPv4 `169.254.0.0/16`;
/// * IPv6 `fe80::/10`, in any letter case;
/// * IPv4-mapped IPv6 forms such as `::ffff:169.254.169.254`.
///
/// Text that does not parse as an address falls back to prefix checks. An
/// example is `fe80::1%eth0`, because of the zone suffix.
fn is_link_local(ip: &str) -> bool {
    use std::net::IpAddr;
    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(IpAddr::V6(v6)) => {
            // fe80::/10 — unicast link-local.
            (v6.segments()[0] & 0xffc0) == 0xfe80
                || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_link_local())
        }
        Err(_) => {
            let lower = ip.to_ascii_lowercase();
            lower.starts_with("169.254.") || lower.starts_with("fe80:")
        }
    }
}
/// Returns `true` when `ip` names a well-known cloud metadata endpoint:
///
/// * `169.254.169.254`, `100.100.100.200`, `169.254.170.2`, including their
///   IPv4-mapped IPv6 forms;
/// * the IPv6 endpoint `fd00:ec2::254`;
/// * any text containing `metadata.google.internal`, matched
///   case-insensitively.
fn is_metadata_ip(ip: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    const METADATA_V4: [Ipv4Addr; 3] = [
        Ipv4Addr::new(169, 254, 169, 254),
        Ipv4Addr::new(100, 100, 100, 200),
        Ipv4Addr::new(169, 254, 170, 2),
    ];
    // fd00:ec2::254 — IPv6 endpoint of the AWS instance metadata service.
    const METADATA_V6: Ipv6Addr = Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);

    if ip.to_ascii_lowercase().contains("metadata.google.internal") {
        return true;
    }
    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => METADATA_V4.contains(&v4),
        Ok(IpAddr::V6(v6)) => {
            v6 == METADATA_V6
                || matches!(v6.to_ipv4_mapped(), Some(v4) if METADATA_V4.contains(&v4))
        }
        Err(_) => false,
    }
}
// Unit tests: TOML round-trip, domain allow/deny matching, IP blocking,
// rate-limit lookup, starter template parsing, URL/domain extraction,
// policy generation from scan targets, and operator-override merging.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Fixture: allow `*.example.com` and `api.github.com`, deny
// `evil.example.com`, default network blocks, 60 req/min globally with a
// 30 req/min override for `api.github.com`.
fn sample_policy() -> EgressPolicy {
EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec!["*.example.com".into(), "api.github.com".into()],
deny: vec!["evil.example.com".into()],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy {
max_requests_per_minute: 60,
per_domain: {
let mut m = HashMap::new();
m.insert("api.github.com".into(), 30);
m
},
},
audit: AuditPolicy::default(),
}
}
// save() followed by load() must reproduce every field.
#[test]
fn test_load_and_save_roundtrip() {
let tmp = TempDir::new().unwrap();
let path = tmp.path().join("egress.toml");
let original = sample_policy();
original.save(&path).unwrap();
let loaded = EgressPolicy::load(&path).unwrap();
assert_eq!(loaded.schema_version, original.schema_version);
assert_eq!(loaded.domains.allow, original.domains.allow);
assert_eq!(loaded.domains.deny, original.domains.deny);
assert_eq!(
loaded.networks.block_private,
original.networks.block_private
);
assert_eq!(
loaded.networks.block_localhost,
original.networks.block_localhost
);
assert_eq!(
loaded.networks.block_link_local,
original.networks.block_link_local
);
assert_eq!(
loaded.networks.block_metadata,
original.networks.block_metadata
);
assert_eq!(
loaded.rate_limits.max_requests_per_minute,
original.rate_limits.max_requests_per_minute
);
assert_eq!(
loaded.rate_limits.per_domain,
original.rate_limits.per_domain
);
assert_eq!(loaded.audit.log_format, original.audit.log_format);
assert_eq!(loaded.audit.log_allowed, original.audit.log_allowed);
assert_eq!(loaded.audit.log_path, original.audit.log_path);
}
// Exact, wildcard-subdomain, and wildcard-apex matches are allowed;
// unlisted domains are not.
#[test]
fn test_domain_allowed() {
let policy = sample_policy();
assert!(policy.is_domain_allowed("api.github.com"));
assert!(policy.is_domain_allowed("sub.example.com"));
assert!(policy.is_domain_allowed("example.com"));
assert!(!policy.is_domain_allowed("random.org"));
}
// `evil.example.com` matches `*.example.com` but is also denied; deny wins.
#[test]
fn test_domain_denied_takes_precedence() {
let policy = sample_policy();
assert!(
!policy.is_domain_allowed("evil.example.com"),
"deny should take precedence over allow"
);
}
// An empty allow list means "allow everything not denied".
#[test]
fn test_empty_allow_list_allows_all() {
let policy = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![],
deny: vec!["blocked.com".into()],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
};
assert!(policy.is_domain_allowed("anything.com"));
assert!(policy.is_domain_allowed("whatever.org"));
assert!(
!policy.is_domain_allowed("blocked.com"),
"deny should still block even with empty allow"
);
}
// With all network blocks enabled: loopback, private, link-local, and
// metadata addresses are blocked; public addresses and the edges just
// outside 172.16.0.0/12 are not.
#[test]
fn test_ip_blocking() {
let policy = sample_policy();
assert!(policy.is_ip_blocked("127.0.0.1"));
assert!(policy.is_ip_blocked("127.0.0.2"));
assert!(policy.is_ip_blocked("::1"));
assert!(policy.is_ip_blocked("localhost"));
assert!(policy.is_ip_blocked("10.0.0.1"));
assert!(policy.is_ip_blocked("172.16.0.1"));
assert!(policy.is_ip_blocked("172.31.255.255"));
assert!(policy.is_ip_blocked("192.168.1.1"));
assert!(!policy.is_ip_blocked("172.15.0.1"));
assert!(!policy.is_ip_blocked("172.32.0.1"));
assert!(policy.is_ip_blocked("169.254.1.1"));
assert!(policy.is_ip_blocked("fe80::1"));
assert!(policy.is_ip_blocked("169.254.169.254"));
assert!(policy.is_ip_blocked("metadata.google.internal"));
assert!(policy.is_ip_blocked("100.100.100.200"));
assert!(policy.is_ip_blocked("169.254.170.2"));
assert!(!policy.is_ip_blocked("8.8.8.8"));
assert!(!policy.is_ip_blocked("1.1.1.1"));
}
// A per-domain override is returned when present.
#[test]
fn test_rate_limit_per_domain() {
let policy = sample_policy();
assert_eq!(policy.rate_limit_for("api.github.com"), 30);
}
// Domains without an override fall back to the global limit.
#[test]
fn test_rate_limit_default() {
let policy = sample_policy();
assert_eq!(policy.rate_limit_for("unknown.com"), 60);
}
// load() rejects a schema_version newer than CURRENT_SCHEMA_VERSION with
// an error that names the version.
#[test]
fn test_future_schema_rejected() {
let tmp = TempDir::new().unwrap();
let path = tmp.path().join("future.toml");
let content = r#"
schema_version = 99
[domains]
allow = []
deny = []
"#;
std::fs::write(&path, content).unwrap();
let result = EgressPolicy::load(&path);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("99") && err_msg.contains("newer"),
"Error should mention unsupported schema version, got: {err_msg}"
);
}
// The starter template must deserialize into a valid policy.
#[test]
fn test_starter_toml_parses() {
let toml_str = EgressPolicy::starter_toml();
let policy: EgressPolicy =
toml::from_str(toml_str).expect("starter_toml() should produce valid TOML");
assert_eq!(policy.schema_version, 1);
assert!(!policy.domains.allow.is_empty());
assert!(policy.networks.block_private);
assert!(policy.networks.block_metadata);
assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
assert_eq!(policy.audit.log_format, "json");
}
// Host extraction from URLs (path and port stripped) and bare domains.
#[test]
fn test_extract_domain_from_url() {
assert_eq!(
extract_domain("https://api.example.com/v1/items"),
Some("api.example.com".into())
);
assert_eq!(
extract_domain("http://api.example.com:8080/path"),
Some("api.example.com".into())
);
assert_eq!(
extract_domain("https://api.github.com"),
Some("api.github.com".into())
);
assert_eq!(
extract_domain("api.example.com"),
Some("api.example.com".into())
);
assert_eq!(extract_domain("localhost"), None);
assert_eq!(extract_domain("/some/path"), None);
assert_eq!(extract_domain(""), None);
}
// Generated policy collects literal URL hosts and declared NetworkAccess
// targets. The non-literal URL argument is skipped. The allow list is
// sorted and all defaults are applied.
#[test]
fn test_from_scan_targets_extracts_domains() {
use crate::ir::execution_surface::{ExecutionSurface, NetworkOperation};
use crate::ir::tool_surface::{DeclaredPermission, PermissionType, ToolSurface};
use crate::ir::{
ArgumentSource, DataSurface, DependencySurface, Framework, ProvenanceSurface,
ScanTarget, SourceLocation,
};
use std::path::PathBuf;
let make_loc = || SourceLocation {
file: PathBuf::from("server.py"),
line: 1,
column: 0,
end_line: None,
end_column: None,
};
let target = ScanTarget {
name: "test-server".into(),
framework: Framework::Mcp,
root_path: PathBuf::from("/tmp/test"),
tools: vec![ToolSurface {
name: "fetch_data".into(),
description: None,
input_schema: None,
output_schema: None,
declared_permissions: vec![DeclaredPermission {
permission_type: PermissionType::NetworkAccess,
target: Some("https://api.stripe.com/v1".into()),
description: None,
}],
defined_at: None,
}],
execution: ExecutionSurface {
network_operations: vec![
NetworkOperation {
function: "requests.get".into(),
url_arg: ArgumentSource::Literal("https://api.openai.com/v1/chat".into()),
method: Some("GET".into()),
sends_data: false,
location: make_loc(),
},
NetworkOperation {
function: "requests.post".into(),
url_arg: ArgumentSource::Parameter { name: "url".into() },
method: Some("POST".into()),
sends_data: true,
location: make_loc(),
},
],
..ExecutionSurface::default()
},
data: DataSurface::default(),
dependencies: DependencySurface::default(),
provenance: ProvenanceSurface::default(),
source_files: vec![],
};
let policy = EgressPolicy::from_scan_targets(&[target]);
assert_eq!(policy.schema_version, 1);
assert!(policy.domains.deny.is_empty());
assert!(
policy.domains.allow.contains(&"api.openai.com".to_string()),
"Expected api.openai.com in allow list, got: {:?}",
policy.domains.allow
);
assert!(
policy.domains.allow.contains(&"api.stripe.com".to_string()),
"Expected api.stripe.com in allow list, got: {:?}",
policy.domains.allow
);
assert_eq!(
policy.domains.allow,
{
let mut sorted = policy.domains.allow.clone();
sorted.sort();
sorted
},
"Allow list should be sorted"
);
assert!(policy.networks.block_private);
assert!(policy.networks.block_localhost);
assert!(policy.networks.block_link_local);
assert!(policy.networks.block_metadata);
assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
}
// Merge fixture with deliberately mixed settings:
// * allow list of three exact domains;
// * one deny entry;
// * private and metadata blocking off;
// * a 20 req/min override for api.openai.com;
// * a custom audit log path.
fn base_policy() -> EgressPolicy {
EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![
"api.example.com".into(),
"api.github.com".into(),
"api.openai.com".into(),
],
deny: vec!["evil.com".into()],
},
networks: NetworkPolicy {
block_private: false,
block_link_local: true,
block_localhost: true,
block_metadata: false,
},
rate_limits: RateLimitPolicy {
max_requests_per_minute: 60,
per_domain: {
let mut m = HashMap::new();
m.insert("api.openai.com".into(), 20);
m
},
},
audit: AuditPolicy {
log_path: Some(PathBuf::from("/tmp/base-audit.jsonl")),
log_format: "json".into(),
log_allowed: false,
},
}
}
// Deny lists are unioned.
#[test]
fn test_merge_deny_union() {
let base = base_policy();
let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![],
deny: vec!["extra-bad.com".into()],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
};
let merged = base.merge_override(&operator);
assert!(
merged.domains.deny.contains(&"evil.com".to_string()),
"base deny entry must be preserved"
);
assert!(
merged.domains.deny.contains(&"extra-bad.com".to_string()),
"operator deny entry must be added"
);
assert_eq!(merged.domains.deny.len(), 2);
}
// Two non-empty allow lists are intersected.
#[test]
fn test_merge_allow_intersection() {
let base = base_policy();
let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![
"api.github.com".into(),
"api.openai.com".into(),
"api.stripe.com".into(),
],
deny: vec![],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
};
let merged = base.merge_override(&operator);
assert!(
merged.domains.allow.contains(&"api.github.com".to_string()),
"intersection: api.github.com must be in result"
);
assert!(
merged.domains.allow.contains(&"api.openai.com".to_string()),
"intersection: api.openai.com must be in result"
);
assert!(
!merged
.domains
.allow
.contains(&"api.example.com".to_string()),
"api.example.com not in operator allow → must be excluded"
);
assert!(
!merged.domains.allow.contains(&"api.stripe.com".to_string()),
"api.stripe.com not in base allow → must be excluded"
);
}
// Global and per-domain rate limits take the minimum of both sides.
#[test]
fn test_merge_rate_limits_min() {
let base = base_policy(); let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![],
deny: vec![],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy {
max_requests_per_minute: 30,
per_domain: {
let mut m = HashMap::new();
m.insert("api.openai.com".into(), 10);
m.insert("api.github.com".into(), 5);
m
},
},
audit: AuditPolicy::default(),
};
let merged = base.merge_override(&operator);
assert_eq!(
merged.rate_limits.max_requests_per_minute, 30,
"global rate: min(60, 30) = 30"
);
assert_eq!(
merged.rate_limits.per_domain["api.openai.com"], 10,
"per-domain rate: min(20, 10) = 10"
);
assert_eq!(
merged.rate_limits.per_domain["api.github.com"], 5,
"operator-only per-domain: min(60, 5) = 5"
);
}
// Network blocks are OR-ed: either side enabling a block enables it.
#[test]
fn test_merge_network_blocks_or() {
let base = base_policy(); let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![],
deny: vec![],
},
networks: NetworkPolicy {
block_private: true,
block_link_local: false,
block_localhost: false,
block_metadata: true,
},
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
};
let merged = base.merge_override(&operator);
assert!(merged.networks.block_private, "false || true = true");
assert!(
merged.networks.block_link_local,
"true || false = true (base had it)"
);
assert!(
merged.networks.block_localhost,
"true || false = true (base had it)"
);
assert!(merged.networks.block_metadata, "false || true = true");
}
// An empty operator allow list leaves the base allow list untouched.
#[test]
fn test_merge_empty_override_allow_keeps_base() {
let base = base_policy(); let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![], deny: vec![],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy::default(),
};
let merged = base.merge_override(&operator);
assert_eq!(
merged.domains.allow, base.domains.allow,
"empty operator allow must not restrict base allow list"
);
}
// The operator's audit settings replace the base's wholesale.
#[test]
fn test_merge_audit_override_wins() {
let base = base_policy(); let operator = EgressPolicy {
schema_version: 1,
domains: DomainPolicy {
allow: vec![],
deny: vec![],
},
networks: NetworkPolicy::default(),
rate_limits: RateLimitPolicy::default(),
audit: AuditPolicy {
log_path: Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
log_format: "text".into(),
log_allowed: true,
},
};
let merged = base.merge_override(&operator);
assert_eq!(
merged.audit.log_path,
Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
"operator audit log_path must win"
);
assert_eq!(
merged.audit.log_format, "text",
"operator audit log_format must win"
);
assert!(
merged.audit.log_allowed,
"operator audit log_allowed must win"
);
}
// Integration: scan a fixture directory, generate a policy, save it, and
// reload it. The fixture path is relative, so this presumably expects to
// run from the crate root (cargo's default for tests) — confirm if moved.
#[test]
fn test_emit_egress_policy_integration() {
use crate::{scan, ScanOptions};
use std::path::Path;
let opts = ScanOptions::default();
let report = scan(Path::new("tests/fixtures/mcp_servers/vuln_ssrf"), &opts)
.expect("scan should succeed");
let policy = EgressPolicy::from_scan_targets(&report.targets);
let tmp = TempDir::new().unwrap();
let policy_path = tmp.path().join("agentshield.egress.toml");
policy.save(&policy_path).unwrap();
let loaded = EgressPolicy::load(&policy_path).unwrap();
assert_eq!(loaded.schema_version, 1);
assert!(loaded.networks.block_private);
assert!(loaded.networks.block_metadata);
assert!(loaded.domains.deny.is_empty());
}
}