Skip to main content

agentshield/egress/
policy.rs

1//! Egress policy schema and validation.
2//!
3//! Parses `agentshield.egress.toml` files that define which domains,
4//! IPs, and rate limits are enforced by the `wrap` command proxy.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10use crate::error::ShieldError;
11use crate::ir::tool_surface::PermissionType;
12use crate::ir::{ArgumentSource, ScanTarget};
13
/// Newest policy schema version this module understands.
///
/// `EgressPolicy::load` rejects files declaring a higher `schema_version`, and
/// `EgressPolicy::from_scan_targets` stamps generated policies with this value.
const CURRENT_SCHEMA_VERSION: u32 = 1;
15
16/// Top-level egress policy loaded from `agentshield.egress.toml`.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EgressPolicy {
19    /// Schema version for forward compatibility checks.
20    pub schema_version: u32,
21    /// Domain allow/deny rules.
22    pub domains: DomainPolicy,
23    /// Network-level IP blocking rules.
24    #[serde(default)]
25    pub networks: NetworkPolicy,
26    /// Rate limiting configuration.
27    #[serde(default)]
28    pub rate_limits: RateLimitPolicy,
29    /// Audit logging configuration.
30    #[serde(default)]
31    pub audit: AuditPolicy,
32}
33
34/// Domain-level allow/deny policy using glob-style patterns.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct DomainPolicy {
37    /// Allowed domain patterns (glob-style: `"*.example.com"`, `"api.github.com"`).
38    #[serde(default)]
39    pub allow: Vec<String>,
40    /// Explicitly denied domain patterns (takes precedence over allow).
41    #[serde(default)]
42    pub deny: Vec<String>,
43}
44
45/// Network-level IP range blocking policy.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct NetworkPolicy {
48    /// Block private IP ranges (10.x, 172.16-31.x, 192.168.x). Default: true.
49    #[serde(default = "default_true")]
50    pub block_private: bool,
51    /// Block link-local addresses (169.254.x). Default: true.
52    #[serde(default = "default_true")]
53    pub block_link_local: bool,
54    /// Block localhost (127.x, ::1). Default: true.
55    #[serde(default = "default_true")]
56    pub block_localhost: bool,
57    /// Block cloud metadata endpoints (169.254.169.254, etc.). Default: true.
58    #[serde(default = "default_true")]
59    pub block_metadata: bool,
60}
61
/// Serde default helper: yields `true` for boolean fields that are on unless
/// the policy file says otherwise.
fn default_true() -> bool {
    true
}
65
66impl Default for NetworkPolicy {
67    fn default() -> Self {
68        Self {
69            block_private: true,
70            block_link_local: true,
71            block_localhost: true,
72            block_metadata: true,
73        }
74    }
75}
76
77/// Rate limiting configuration for outbound requests.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct RateLimitPolicy {
80    /// Maximum requests per minute per domain. 0 = unlimited.
81    #[serde(default = "default_rate_limit")]
82    pub max_requests_per_minute: u32,
83    /// Per-domain overrides (domain string -> requests per minute).
84    #[serde(default)]
85    pub per_domain: HashMap<String, u32>,
86}
87
/// Serde default helper: the global requests-per-minute ceiling used when the
/// policy file does not set one.
fn default_rate_limit() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 60;
    REQUESTS_PER_MINUTE
}
91
92impl Default for RateLimitPolicy {
93    fn default() -> Self {
94        Self {
95            max_requests_per_minute: default_rate_limit(),
96            per_domain: HashMap::new(),
97        }
98    }
99}
100
101/// Audit logging configuration for egress events.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct AuditPolicy {
104    /// Path to write audit log.
105    #[serde(default)]
106    pub log_path: Option<PathBuf>,
107    /// Log format: `"json"` or `"text"`.
108    #[serde(default = "default_log_format")]
109    pub log_format: String,
110    /// Log allowed requests too (not just blocked). Default: false.
111    #[serde(default)]
112    pub log_allowed: bool,
113}
114
/// Serde default helper: the audit log format used when none is configured.
fn default_log_format() -> String {
    String::from("json")
}
118
119impl Default for AuditPolicy {
120    fn default() -> Self {
121        Self {
122            log_path: None,
123            log_format: default_log_format(),
124            log_allowed: false,
125        }
126    }
127}
128
129impl EgressPolicy {
130    /// Load an egress policy from a TOML file.
131    pub fn load(path: &Path) -> Result<Self, ShieldError> {
132        let content = std::fs::read_to_string(path).map_err(ShieldError::Io)?;
133        let policy: Self = toml::from_str(&content)?;
134        if policy.schema_version > CURRENT_SCHEMA_VERSION {
135            return Err(ShieldError::Config(format!(
136                "Egress policy schema version {} is newer than supported version {}",
137                policy.schema_version, CURRENT_SCHEMA_VERSION
138            )));
139        }
140        Ok(policy)
141    }
142
143    /// Save an egress policy to a TOML file.
144    pub fn save(&self, path: &Path) -> Result<(), ShieldError> {
145        let content = toml::to_string_pretty(self)?;
146        std::fs::write(path, content).map_err(ShieldError::Io)?;
147        Ok(())
148    }
149
150    /// Check if a domain is allowed by this policy.
151    ///
152    /// Deny rules take precedence over allow rules. If the allow list is
153    /// empty, all domains not explicitly denied are allowed.
154    pub fn is_domain_allowed(&self, domain: &str) -> bool {
155        // Deny takes precedence
156        if self
157            .domains
158            .deny
159            .iter()
160            .any(|pattern| domain_matches(domain, pattern))
161        {
162            return false;
163        }
164        // If allow list is empty, allow all (that aren't denied)
165        if self.domains.allow.is_empty() {
166            return true;
167        }
168        // Must match at least one allow pattern
169        self.domains
170            .allow
171            .iter()
172            .any(|pattern| domain_matches(domain, pattern))
173    }
174
175    /// Check if an IP address is blocked by network policy.
176    pub fn is_ip_blocked(&self, ip: &str) -> bool {
177        if self.networks.block_localhost && is_localhost(ip) {
178            return true;
179        }
180        if self.networks.block_private && is_private_ip(ip) {
181            return true;
182        }
183        if self.networks.block_link_local && is_link_local(ip) {
184            return true;
185        }
186        if self.networks.block_metadata && is_metadata_ip(ip) {
187            return true;
188        }
189        false
190    }
191
192    /// Get rate limit for a domain (requests per minute).
193    ///
194    /// Returns the per-domain override if one exists, otherwise the global default.
195    pub fn rate_limit_for(&self, domain: &str) -> u32 {
196        self.rate_limits
197            .per_domain
198            .get(domain)
199            .copied()
200            .unwrap_or(self.rate_limits.max_requests_per_minute)
201    }
202
203    /// Build a starter egress policy by analyzing all `ScanTarget`s.
204    ///
205    /// Extracts domains from:
206    /// - Literal URL arguments in `NetworkOperation` entries
207    /// - `NetworkAccess` declared permissions with a scope/target URL or domain
208    ///
209    /// The resulting policy allows all discovered domains and uses safe defaults
210    /// for network-level blocking and rate limiting.
211    pub fn from_scan_targets(targets: &[ScanTarget]) -> Self {
212        let mut domains = std::collections::HashSet::new();
213
214        for target in targets {
215            // Extract domains from network operations with literal URLs
216            for net_op in &target.execution.network_operations {
217                if let ArgumentSource::Literal(ref url) = net_op.url_arg {
218                    if let Some(domain) = extract_domain(url) {
219                        domains.insert(domain);
220                    }
221                }
222            }
223
224            // Extract domains from tool declared permissions (NetworkAccess)
225            for tool in &target.tools {
226                for perm in &tool.declared_permissions {
227                    if matches!(perm.permission_type, PermissionType::NetworkAccess) {
228                        if let Some(ref scope) = perm.target {
229                            if let Some(domain) = extract_domain(scope) {
230                                domains.insert(domain);
231                            }
232                        }
233                    }
234                }
235            }
236        }
237
238        let mut allow: Vec<String> = domains.into_iter().collect();
239        allow.sort();
240
241        EgressPolicy {
242            schema_version: CURRENT_SCHEMA_VERSION,
243            domains: DomainPolicy {
244                allow,
245                deny: vec![],
246            },
247            networks: NetworkPolicy::default(),
248            rate_limits: RateLimitPolicy::default(),
249            audit: AuditPolicy::default(),
250        }
251    }
252
253    /// Merge with an operator override policy. The override can only restrict, never expand.
254    ///
255    /// Merge rules:
256    /// - `domains.allow` = intersection(self.allow, override.allow)
257    ///   If override.allow is empty, self.allow is kept (empty means "no restriction").
258    ///   If self.allow is empty (allow all), operator's allow list becomes the effective list.
259    /// - `domains.deny` = union(self.deny, override.deny)
260    /// - `networks`: if either policy blocks a range, it is blocked in the result
261    /// - `rate_limits.max_requests_per_minute` = min(self, override)
262    /// - `rate_limits.per_domain`: min rate per domain; missing entries inherit the global min
263    /// - `audit`: operator override wins (operator controls where logs go)
264    pub fn merge_override(&self, operator: &EgressPolicy) -> EgressPolicy {
265        // Allow list: intersection when both are non-empty; operator restricts further
266        let allow = if operator.domains.allow.is_empty() {
267            // Empty override allow = "no additional restriction on allow"
268            self.domains.allow.clone()
269        } else if self.domains.allow.is_empty() {
270            // Self allows all; operator restricts to its list
271            operator.domains.allow.clone()
272        } else {
273            // Both have allow lists: intersection (only domains in BOTH lists)
274            self.domains
275                .allow
276                .iter()
277                .filter(|d| {
278                    operator
279                        .domains
280                        .allow
281                        .iter()
282                        .any(|o| domain_matches(d, o) || domain_matches(o, d))
283                })
284                .cloned()
285                .collect()
286        };
287
288        // Deny list: union (operator can only add more denials)
289        let mut deny = self.domains.deny.clone();
290        for d in &operator.domains.deny {
291            if !deny.contains(d) {
292                deny.push(d.clone());
293            }
294        }
295
296        // Rate limits: take the minimum (more restrictive wins)
297        let global_min = self
298            .rate_limits
299            .max_requests_per_minute
300            .min(operator.rate_limits.max_requests_per_minute);
301
302        let mut per_domain = self.rate_limits.per_domain.clone();
303        for (domain, &op_rate) in &operator.rate_limits.per_domain {
304            let entry = per_domain
305                .entry(domain.clone())
306                .or_insert(self.rate_limits.max_requests_per_minute);
307            *entry = (*entry).min(op_rate);
308        }
309
310        EgressPolicy {
311            schema_version: self.schema_version,
312            domains: DomainPolicy { allow, deny },
313            networks: NetworkPolicy {
314                block_private: self.networks.block_private || operator.networks.block_private,
315                block_link_local: self.networks.block_link_local
316                    || operator.networks.block_link_local,
317                block_localhost: self.networks.block_localhost || operator.networks.block_localhost,
318                block_metadata: self.networks.block_metadata || operator.networks.block_metadata,
319            },
320            rate_limits: RateLimitPolicy {
321                max_requests_per_minute: global_min,
322                per_domain,
323            },
324            audit: operator.audit.clone(),
325        }
326    }
327
328    /// Generate a starter policy TOML string for `agentshield init --egress`.
329    pub fn starter_toml() -> &'static str {
330        r#"# AgentShield Egress Policy
331# See: https://github.com/limaronaldo/agentshield
332
333schema_version = 1
334
335[domains]
336# Allowed domain patterns (glob-style)
337allow = ["*.example.com", "api.github.com"]
338# Explicitly denied (takes precedence over allow)
339deny = []
340
341[networks]
342block_private = true      # 10.x, 172.16-31.x, 192.168.x
343block_link_local = true   # 169.254.x
344block_localhost = true     # 127.x, ::1
345block_metadata = true     # 169.254.169.254, metadata.google.internal
346
347[rate_limits]
348max_requests_per_minute = 60
349
350[audit]
351# log_path = "agentshield-audit.jsonl"
352log_format = "json"
353log_allowed = false
354"#
355    }
356}
357
/// Extract the hostname from a URL string or bare domain.
///
/// For `http://` / `https://` URLs (scheme matched case-insensitively) the
/// authority component is isolated and reduced to its host:
/// - everything from the first `/`, `?` or `#` onwards is dropped,
/// - a `user:password@` userinfo prefix is dropped,
/// - a trailing `:port` is dropped,
/// - a bracketed IPv6 literal is returned with its brackets (`"[::1]"`).
///
/// A bare domain (e.g. `"api.example.com"`) is returned unchanged when it
/// contains a dot and no slash. Returns `None` for everything else (paths,
/// dot-less names without a scheme, empty hosts, unterminated `[`).
pub fn extract_domain(url_or_domain: &str) -> Option<String> {
    /// Strip a leading `http://` or `https://`, ignoring ASCII case.
    fn strip_http_scheme(input: &str) -> Option<&str> {
        ["https://", "http://"].iter().find_map(|&scheme| {
            // `get` returns `None` instead of panicking on short input or when
            // the cut would fall inside a multi-byte character.
            let head = input.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                input.get(scheme.len()..)
            } else {
                None
            }
        })
    }

    let rest = match strip_http_scheme(url_or_domain) {
        Some(rest) => rest,
        None => {
            // Bare domain: must contain a dot and no slashes.
            let is_bare_domain = url_or_domain.contains('.') && !url_or_domain.contains('/');
            return if is_bare_domain {
                Some(url_or_domain.to_string())
            } else {
                None
            };
        }
    };

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest.split(|c| matches!(c, '/' | '?' | '#')).next()?;
    // Userinfo, if any, is everything up to the last '@'.
    let host_port = authority.rsplit('@').next()?;

    let host = if host_port.starts_with('[') {
        // IPv6 literal: the colons inside the brackets are not port separators.
        let end = host_port.find(']')?;
        &host_port[..=end]
    } else {
        host_port.split(':').next()?
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
387
/// Simple glob matching for domain patterns.
///
/// Supports `*.example.com` (matches `sub.example.com` and `example.com`)
/// and exact matches like `api.github.com`. A lone `*` matches every domain.
///
/// Host names are compared ASCII case-insensitively and a single trailing
/// root dot is ignored, so `EVIL.example.com` and `evil.example.com.` cannot
/// slip past a rule written as `evil.example.com`.
fn domain_matches(domain: &str, pattern: &str) -> bool {
    /// `str::ends_with`, ignoring ASCII case. Works on bytes so it can never
    /// slice inside a multi-byte character.
    fn ends_with_ignore_ascii_case(text: &str, suffix: &str) -> bool {
        let (text, suffix) = (text.as_bytes(), suffix.as_bytes());
        text.len() >= suffix.len()
            && text[text.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
    }

    let domain = domain.strip_suffix('.').unwrap_or(domain);
    let pattern = pattern.strip_suffix('.').unwrap_or(pattern);

    match pattern.strip_prefix('*') {
        Some(suffix) => {
            // "*.example.com" matches "sub.example.com" ...
            ends_with_ignore_ascii_case(domain, suffix)
                // ... and the apex "example.com", but only when the wildcard is
                // really followed by a dot.
                || matches!(
                    suffix.strip_prefix('.'),
                    Some(apex) if domain.eq_ignore_ascii_case(apex)
                )
        }
        None => domain.eq_ignore_ascii_case(pattern),
    }
}
400
/// Report whether `ip` designates the local machine.
///
/// Parsable addresses are classified numerically, so every spelling of the
/// IPv6 loopback (`::1`, `0:0:0:0:0:0:0:1`) and IPv4-mapped loopback
/// (`::ffff:127.0.0.1`) is recognised. Anything that does not parse as an IP
/// address is treated as loopback when it is the name `localhost` (any case,
/// optional trailing dot) or starts with `127.` — the prefix check is kept so
/// that unparsable spellings such as `127.1` still fail closed.
fn is_localhost(ip: &str) -> bool {
    use std::net::IpAddr;

    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_loopback(),
        Ok(IpAddr::V6(v6)) => match v6.octets() {
            // ::ffff:127.x.y.z — IPv4-mapped loopback.
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 127, _, _, _] => true,
            _ => v6.is_loopback(),
        },
        Err(_) => {
            let name = ip.strip_suffix('.').unwrap_or(ip);
            name.eq_ignore_ascii_case("localhost") || ip.starts_with("127.")
        }
    }
}
404
/// Report whether `ip` lies in a private address range.
///
/// Parsable addresses are classified numerically: RFC 1918 ranges for IPv4
/// (also when written as an IPv4-mapped IPv6 address) and the unique-local
/// range `fc00::/7` for IPv6. Input that does not parse as an IP address
/// falls back to conservative prefix checks; the IPv6 prefix check requires
/// a `:` so that host names such as `fdic.gov` are not mistaken for
/// unique-local addresses.
fn is_private_ip(ip: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr};

    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_private(),
        Ok(IpAddr::V6(v6)) => match v6.octets() {
            // ::ffff:a.b.c.d — judge the embedded IPv4 address.
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                Ipv4Addr::new(a, b, c, d).is_private()
            }
            // fc00::/7 covers both fc.. and fd.. (IPv6 unique-local).
            [first, ..] => first & 0xfe == 0xfc,
        },
        Err(_) => {
            // Fail closed on unparsable spellings (e.g. "10.1", zone-scoped IPv6).
            let lower = ip.to_ascii_lowercase();
            ip.starts_with("10.")
                || (ip.starts_with("172.") && is_172_private(ip))
                || ip.starts_with("192.168.")
                || (ip.contains(':') && (lower.starts_with("fc") || lower.starts_with("fd")))
        }
    }
}

/// Report whether a dotted string starting with `172.` has a second octet in
/// `16..=31`, i.e. falls inside `172.16.0.0/12`.
///
/// Returns `false` when the prefix is missing or the second octet is not a
/// number that fits in a `u8`.
fn is_172_private(ip: &str) -> bool {
    ip.strip_prefix("172.")
        .and_then(|rest| rest.split('.').next())
        .and_then(|octet| octet.parse::<u8>().ok())
        .map_or(false, |n| (16..=31).contains(&n))
}
423
/// Report whether `ip` is a link-local address.
///
/// Parsable addresses are classified numerically: `169.254.0.0/16` for IPv4
/// (also when IPv4-mapped) and the whole `fe80::/10` block for IPv6, in any
/// letter case. Input that does not parse (for example a zone-scoped
/// `fe80::1%eth0`) falls back to the textual prefixes.
fn is_link_local(ip: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr};

    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(IpAddr::V6(v6)) => match v6.octets() {
            // ::ffff:a.b.c.d — judge the embedded IPv4 address.
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                Ipv4Addr::new(a, b, c, d).is_link_local()
            }
            // fe80::/10: first byte 0xfe, top two bits of the second byte are 10.
            [0xfe, second, ..] => second & 0xc0 == 0x80,
            _ => false,
        },
        Err(_) => ip.starts_with("169.254.") || ip.to_ascii_lowercase().starts_with("fe80:"),
    }
}
427
/// Report whether `ip` is a known cloud metadata endpoint.
///
/// Recognised endpoints:
/// - `169.254.169.254` (common instance metadata address)
/// - `100.100.100.200` (Alibaba Cloud)
/// - `169.254.170.2` (AWS ECS)
/// - `fd00:ec2::254` (AWS instance metadata over IPv6)
/// - any host name containing `metadata.google.internal` (any letter case)
///
/// Addresses are compared numerically, so IPv4-mapped IPv6 spellings such as
/// `::ffff:169.254.169.254` are recognised too.
fn is_metadata_ip(ip: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    const METADATA_V4: [Ipv4Addr; 3] = [
        Ipv4Addr::new(169, 254, 169, 254),
        Ipv4Addr::new(100, 100, 100, 200), // Alibaba Cloud
        Ipv4Addr::new(169, 254, 170, 2),   // AWS ECS
    ];
    // AWS instance metadata service over IPv6.
    const AWS_IMDS_V6: Ipv6Addr = Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);

    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => METADATA_V4.contains(&v4),
        Ok(IpAddr::V6(v6)) => match v6.octets() {
            // ::ffff:a.b.c.d — judge the embedded IPv4 address.
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                METADATA_V4.contains(&Ipv4Addr::new(a, b, c, d))
            }
            _ => v6 == AWS_IMDS_V6,
        },
        // Not an IP literal: only the well-known metadata host name applies.
        Err(_) => ip
            .to_ascii_lowercase()
            .contains("metadata.google.internal"),
    }
}
434
435#[cfg(test)]
436mod tests {
437    use super::*;
438    use tempfile::TempDir;
439
440    fn sample_policy() -> EgressPolicy {
441        EgressPolicy {
442            schema_version: 1,
443            domains: DomainPolicy {
444                allow: vec!["*.example.com".into(), "api.github.com".into()],
445                deny: vec!["evil.example.com".into()],
446            },
447            networks: NetworkPolicy::default(),
448            rate_limits: RateLimitPolicy {
449                max_requests_per_minute: 60,
450                per_domain: {
451                    let mut m = HashMap::new();
452                    m.insert("api.github.com".into(), 30);
453                    m
454                },
455            },
456            audit: AuditPolicy::default(),
457        }
458    }
459
460    #[test]
461    fn test_load_and_save_roundtrip() {
462        let tmp = TempDir::new().unwrap();
463        let path = tmp.path().join("egress.toml");
464
465        let original = sample_policy();
466        original.save(&path).unwrap();
467
468        let loaded = EgressPolicy::load(&path).unwrap();
469
470        assert_eq!(loaded.schema_version, original.schema_version);
471        assert_eq!(loaded.domains.allow, original.domains.allow);
472        assert_eq!(loaded.domains.deny, original.domains.deny);
473        assert_eq!(
474            loaded.networks.block_private,
475            original.networks.block_private
476        );
477        assert_eq!(
478            loaded.networks.block_localhost,
479            original.networks.block_localhost
480        );
481        assert_eq!(
482            loaded.networks.block_link_local,
483            original.networks.block_link_local
484        );
485        assert_eq!(
486            loaded.networks.block_metadata,
487            original.networks.block_metadata
488        );
489        assert_eq!(
490            loaded.rate_limits.max_requests_per_minute,
491            original.rate_limits.max_requests_per_minute
492        );
493        assert_eq!(
494            loaded.rate_limits.per_domain,
495            original.rate_limits.per_domain
496        );
497        assert_eq!(loaded.audit.log_format, original.audit.log_format);
498        assert_eq!(loaded.audit.log_allowed, original.audit.log_allowed);
499        assert_eq!(loaded.audit.log_path, original.audit.log_path);
500    }
501
502    #[test]
503    fn test_domain_allowed() {
504        let policy = sample_policy();
505
506        // Exact match
507        assert!(policy.is_domain_allowed("api.github.com"));
508        // Glob match
509        assert!(policy.is_domain_allowed("sub.example.com"));
510        // Base domain matches *.example.com
511        assert!(policy.is_domain_allowed("example.com"));
512        // Not in allow list
513        assert!(!policy.is_domain_allowed("random.org"));
514    }
515
516    #[test]
517    fn test_domain_denied_takes_precedence() {
518        let policy = sample_policy();
519
520        // evil.example.com matches *.example.com (allow) but also deny list
521        assert!(
522            !policy.is_domain_allowed("evil.example.com"),
523            "deny should take precedence over allow"
524        );
525    }
526
527    #[test]
528    fn test_empty_allow_list_allows_all() {
529        let policy = EgressPolicy {
530            schema_version: 1,
531            domains: DomainPolicy {
532                allow: vec![],
533                deny: vec!["blocked.com".into()],
534            },
535            networks: NetworkPolicy::default(),
536            rate_limits: RateLimitPolicy::default(),
537            audit: AuditPolicy::default(),
538        };
539
540        assert!(policy.is_domain_allowed("anything.com"));
541        assert!(policy.is_domain_allowed("whatever.org"));
542        assert!(
543            !policy.is_domain_allowed("blocked.com"),
544            "deny should still block even with empty allow"
545        );
546    }
547
548    #[test]
549    fn test_ip_blocking() {
550        let policy = sample_policy();
551
552        // Localhost
553        assert!(policy.is_ip_blocked("127.0.0.1"));
554        assert!(policy.is_ip_blocked("127.0.0.2"));
555        assert!(policy.is_ip_blocked("::1"));
556        assert!(policy.is_ip_blocked("localhost"));
557
558        // Private ranges
559        assert!(policy.is_ip_blocked("10.0.0.1"));
560        assert!(policy.is_ip_blocked("172.16.0.1"));
561        assert!(policy.is_ip_blocked("172.31.255.255"));
562        assert!(policy.is_ip_blocked("192.168.1.1"));
563
564        // Not private (172.15.x is outside the private range)
565        assert!(!policy.is_ip_blocked("172.15.0.1"));
566        assert!(!policy.is_ip_blocked("172.32.0.1"));
567
568        // Link-local
569        assert!(policy.is_ip_blocked("169.254.1.1"));
570        assert!(policy.is_ip_blocked("fe80::1"));
571
572        // Metadata endpoints
573        assert!(policy.is_ip_blocked("169.254.169.254"));
574        assert!(policy.is_ip_blocked("metadata.google.internal"));
575        assert!(policy.is_ip_blocked("100.100.100.200"));
576        assert!(policy.is_ip_blocked("169.254.170.2"));
577
578        // Public IP should not be blocked
579        assert!(!policy.is_ip_blocked("8.8.8.8"));
580        assert!(!policy.is_ip_blocked("1.1.1.1"));
581    }
582
583    #[test]
584    fn test_rate_limit_per_domain() {
585        let policy = sample_policy();
586        assert_eq!(policy.rate_limit_for("api.github.com"), 30);
587    }
588
589    #[test]
590    fn test_rate_limit_default() {
591        let policy = sample_policy();
592        assert_eq!(policy.rate_limit_for("unknown.com"), 60);
593    }
594
595    #[test]
596    fn test_future_schema_rejected() {
597        let tmp = TempDir::new().unwrap();
598        let path = tmp.path().join("future.toml");
599
600        let content = r#"
601schema_version = 99
602
603[domains]
604allow = []
605deny = []
606"#;
607        std::fs::write(&path, content).unwrap();
608
609        let result = EgressPolicy::load(&path);
610        assert!(result.is_err());
611
612        let err_msg = result.unwrap_err().to_string();
613        assert!(
614            err_msg.contains("99") && err_msg.contains("newer"),
615            "Error should mention unsupported schema version, got: {err_msg}"
616        );
617    }
618
619    #[test]
620    fn test_starter_toml_parses() {
621        let toml_str = EgressPolicy::starter_toml();
622        let policy: EgressPolicy =
623            toml::from_str(toml_str).expect("starter_toml() should produce valid TOML");
624        assert_eq!(policy.schema_version, 1);
625        assert!(!policy.domains.allow.is_empty());
626        assert!(policy.networks.block_private);
627        assert!(policy.networks.block_metadata);
628        assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
629        assert_eq!(policy.audit.log_format, "json");
630    }
631
632    // ── from_scan_targets tests ─────────────────────────────────────────────
633
634    #[test]
635    fn test_extract_domain_from_url() {
636        // Full URLs with scheme
637        assert_eq!(
638            extract_domain("https://api.example.com/v1/items"),
639            Some("api.example.com".into())
640        );
641        assert_eq!(
642            extract_domain("http://api.example.com:8080/path"),
643            Some("api.example.com".into())
644        );
645        assert_eq!(
646            extract_domain("https://api.github.com"),
647            Some("api.github.com".into())
648        );
649        // Bare domain
650        assert_eq!(
651            extract_domain("api.example.com"),
652            Some("api.example.com".into())
653        );
654        // No dot, no scheme → None
655        assert_eq!(extract_domain("localhost"), None);
656        // Path without scheme → None (has slash)
657        assert_eq!(extract_domain("/some/path"), None);
658        // Empty string → None
659        assert_eq!(extract_domain(""), None);
660    }
661
662    #[test]
663    fn test_from_scan_targets_extracts_domains() {
664        use crate::ir::execution_surface::{ExecutionSurface, NetworkOperation};
665        use crate::ir::tool_surface::{DeclaredPermission, PermissionType, ToolSurface};
666        use crate::ir::{
667            ArgumentSource, DataSurface, DependencySurface, Framework, ProvenanceSurface,
668            ScanTarget, SourceLocation,
669        };
670        use std::path::PathBuf;
671
672        let make_loc = || SourceLocation {
673            file: PathBuf::from("server.py"),
674            line: 1,
675            column: 0,
676            end_line: None,
677            end_column: None,
678        };
679
680        let target = ScanTarget {
681            name: "test-server".into(),
682            framework: Framework::Mcp,
683            root_path: PathBuf::from("/tmp/test"),
684            tools: vec![ToolSurface {
685                name: "fetch_data".into(),
686                description: None,
687                input_schema: None,
688                output_schema: None,
689                declared_permissions: vec![DeclaredPermission {
690                    permission_type: PermissionType::NetworkAccess,
691                    target: Some("https://api.stripe.com/v1".into()),
692                    description: None,
693                }],
694                defined_at: None,
695            }],
696            execution: ExecutionSurface {
697                network_operations: vec![
698                    NetworkOperation {
699                        function: "requests.get".into(),
700                        url_arg: ArgumentSource::Literal("https://api.openai.com/v1/chat".into()),
701                        method: Some("GET".into()),
702                        sends_data: false,
703                        location: make_loc(),
704                    },
705                    NetworkOperation {
706                        function: "requests.post".into(),
707                        // Non-literal: should be skipped
708                        url_arg: ArgumentSource::Parameter { name: "url".into() },
709                        method: Some("POST".into()),
710                        sends_data: true,
711                        location: make_loc(),
712                    },
713                ],
714                ..ExecutionSurface::default()
715            },
716            data: DataSurface::default(),
717            dependencies: DependencySurface::default(),
718            provenance: ProvenanceSurface::default(),
719            source_files: vec![],
720        };
721
722        let policy = EgressPolicy::from_scan_targets(&[target]);
723
724        // schema_version must be 1
725        assert_eq!(policy.schema_version, 1);
726        // deny list must be empty (starter policy)
727        assert!(policy.domains.deny.is_empty());
728        // allow list must contain both discovered domains, sorted
729        assert!(
730            policy.domains.allow.contains(&"api.openai.com".to_string()),
731            "Expected api.openai.com in allow list, got: {:?}",
732            policy.domains.allow
733        );
734        assert!(
735            policy.domains.allow.contains(&"api.stripe.com".to_string()),
736            "Expected api.stripe.com in allow list, got: {:?}",
737            policy.domains.allow
738        );
739        // Allow list should be sorted
740        assert_eq!(
741            policy.domains.allow,
742            {
743                let mut sorted = policy.domains.allow.clone();
744                sorted.sort();
745                sorted
746            },
747            "Allow list should be sorted"
748        );
749        // Network defaults must be secure
750        assert!(policy.networks.block_private);
751        assert!(policy.networks.block_localhost);
752        assert!(policy.networks.block_link_local);
753        assert!(policy.networks.block_metadata);
754        // Rate limit default
755        assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
756    }
757
758    // ── merge_override tests ─────────────────────────────────────────────────
759
760    fn base_policy() -> EgressPolicy {
761        EgressPolicy {
762            schema_version: 1,
763            domains: DomainPolicy {
764                allow: vec![
765                    "api.example.com".into(),
766                    "api.github.com".into(),
767                    "api.openai.com".into(),
768                ],
769                deny: vec!["evil.com".into()],
770            },
771            networks: NetworkPolicy {
772                block_private: false,
773                block_link_local: true,
774                block_localhost: true,
775                block_metadata: false,
776            },
777            rate_limits: RateLimitPolicy {
778                max_requests_per_minute: 60,
779                per_domain: {
780                    let mut m = HashMap::new();
781                    m.insert("api.openai.com".into(), 20);
782                    m
783                },
784            },
785            audit: AuditPolicy {
786                log_path: Some(PathBuf::from("/tmp/base-audit.jsonl")),
787                log_format: "json".into(),
788                log_allowed: false,
789            },
790        }
791    }
792
793    #[test]
794    fn test_merge_deny_union() {
795        let base = base_policy();
796        let operator = EgressPolicy {
797            schema_version: 1,
798            domains: DomainPolicy {
799                allow: vec![],
800                deny: vec!["extra-bad.com".into()],
801            },
802            networks: NetworkPolicy::default(),
803            rate_limits: RateLimitPolicy::default(),
804            audit: AuditPolicy::default(),
805        };
806
807        let merged = base.merge_override(&operator);
808
809        assert!(
810            merged.domains.deny.contains(&"evil.com".to_string()),
811            "base deny entry must be preserved"
812        );
813        assert!(
814            merged.domains.deny.contains(&"extra-bad.com".to_string()),
815            "operator deny entry must be added"
816        );
817        assert_eq!(merged.domains.deny.len(), 2);
818    }
819
820    #[test]
821    fn test_merge_allow_intersection() {
822        let base = base_policy();
823        let operator = EgressPolicy {
824            schema_version: 1,
825            domains: DomainPolicy {
826                // B and C overlap with base; D is operator-only (not in base → excluded)
827                allow: vec![
828                    "api.github.com".into(),
829                    "api.openai.com".into(),
830                    "api.stripe.com".into(),
831                ],
832                deny: vec![],
833            },
834            networks: NetworkPolicy::default(),
835            rate_limits: RateLimitPolicy::default(),
836            audit: AuditPolicy::default(),
837        };
838
839        let merged = base.merge_override(&operator);
840
841        assert!(
842            merged.domains.allow.contains(&"api.github.com".to_string()),
843            "intersection: api.github.com must be in result"
844        );
845        assert!(
846            merged.domains.allow.contains(&"api.openai.com".to_string()),
847            "intersection: api.openai.com must be in result"
848        );
849        assert!(
850            !merged
851                .domains
852                .allow
853                .contains(&"api.example.com".to_string()),
854            "api.example.com not in operator allow → must be excluded"
855        );
856        assert!(
857            !merged.domains.allow.contains(&"api.stripe.com".to_string()),
858            "api.stripe.com not in base allow → must be excluded"
859        );
860    }
861
862    #[test]
863    fn test_merge_rate_limits_min() {
864        let base = base_policy(); // global = 60, openai = 20
865        let operator = EgressPolicy {
866            schema_version: 1,
867            domains: DomainPolicy {
868                allow: vec![],
869                deny: vec![],
870            },
871            networks: NetworkPolicy::default(),
872            rate_limits: RateLimitPolicy {
873                max_requests_per_minute: 30,
874                per_domain: {
875                    let mut m = HashMap::new();
876                    m.insert("api.openai.com".into(), 10);
877                    m.insert("api.github.com".into(), 5);
878                    m
879                },
880            },
881            audit: AuditPolicy::default(),
882        };
883
884        let merged = base.merge_override(&operator);
885
886        assert_eq!(
887            merged.rate_limits.max_requests_per_minute, 30,
888            "global rate: min(60, 30) = 30"
889        );
890        assert_eq!(
891            merged.rate_limits.per_domain["api.openai.com"], 10,
892            "per-domain rate: min(20, 10) = 10"
893        );
894        assert_eq!(
895            merged.rate_limits.per_domain["api.github.com"], 5,
896            "operator-only per-domain: min(60, 5) = 5"
897        );
898    }
899
900    #[test]
901    fn test_merge_network_blocks_or() {
902        let base = base_policy(); // block_private=false, block_metadata=false
903        let operator = EgressPolicy {
904            schema_version: 1,
905            domains: DomainPolicy {
906                allow: vec![],
907                deny: vec![],
908            },
909            networks: NetworkPolicy {
910                block_private: true,
911                block_link_local: false,
912                block_localhost: false,
913                block_metadata: true,
914            },
915            rate_limits: RateLimitPolicy::default(),
916            audit: AuditPolicy::default(),
917        };
918
919        let merged = base.merge_override(&operator);
920
921        assert!(merged.networks.block_private, "false || true = true");
922        assert!(
923            merged.networks.block_link_local,
924            "true || false = true (base had it)"
925        );
926        assert!(
927            merged.networks.block_localhost,
928            "true || false = true (base had it)"
929        );
930        assert!(merged.networks.block_metadata, "false || true = true");
931    }
932
933    #[test]
934    fn test_merge_empty_override_allow_keeps_base() {
935        let base = base_policy(); // has 3 allow entries
936        let operator = EgressPolicy {
937            schema_version: 1,
938            domains: DomainPolicy {
939                allow: vec![], // empty = no restriction on allow
940                deny: vec![],
941            },
942            networks: NetworkPolicy::default(),
943            rate_limits: RateLimitPolicy::default(),
944            audit: AuditPolicy::default(),
945        };
946
947        let merged = base.merge_override(&operator);
948
949        assert_eq!(
950            merged.domains.allow, base.domains.allow,
951            "empty operator allow must not restrict base allow list"
952        );
953    }
954
955    #[test]
956    fn test_merge_audit_override_wins() {
957        let base = base_policy(); // log_path = /tmp/base-audit.jsonl
958        let operator = EgressPolicy {
959            schema_version: 1,
960            domains: DomainPolicy {
961                allow: vec![],
962                deny: vec![],
963            },
964            networks: NetworkPolicy::default(),
965            rate_limits: RateLimitPolicy::default(),
966            audit: AuditPolicy {
967                log_path: Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
968                log_format: "text".into(),
969                log_allowed: true,
970            },
971        };
972
973        let merged = base.merge_override(&operator);
974
975        assert_eq!(
976            merged.audit.log_path,
977            Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
978            "operator audit log_path must win"
979        );
980        assert_eq!(
981            merged.audit.log_format, "text",
982            "operator audit log_format must win"
983        );
984        assert!(
985            merged.audit.log_allowed,
986            "operator audit log_allowed must win"
987        );
988    }
989
990    #[test]
991    fn test_emit_egress_policy_integration() {
992        // Scan the vuln_ssrf fixture — it has literal HTTP requests.
993        // Build the policy from targets embedded in the report and round-trip it.
994        use crate::{scan, ScanOptions};
995        use std::path::Path;
996
997        let opts = ScanOptions::default();
998        let report = scan(Path::new("tests/fixtures/mcp_servers/vuln_ssrf"), &opts)
999            .expect("scan should succeed");
1000
1001        let policy = EgressPolicy::from_scan_targets(&report.targets);
1002
1003        // Policy must be valid (round-trip save/load)
1004        let tmp = TempDir::new().unwrap();
1005        let policy_path = tmp.path().join("agentshield.egress.toml");
1006        policy.save(&policy_path).unwrap();
1007
1008        let loaded = EgressPolicy::load(&policy_path).unwrap();
1009        assert_eq!(loaded.schema_version, 1);
1010        assert!(loaded.networks.block_private);
1011        assert!(loaded.networks.block_metadata);
1012        // deny list must be empty in a generated starter policy
1013        assert!(loaded.domains.deny.is_empty());
1014    }
1015}