1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10use crate::error::ShieldError;
11use crate::ir::tool_surface::PermissionType;
12use crate::ir::{ArgumentSource, ScanTarget};
13
/// Highest policy `schema_version` this build understands.
///
/// [`EgressPolicy::load`] rejects files declaring a newer version, and
/// [`EgressPolicy::from_scan_targets`] stamps generated policies with this value.
const CURRENT_SCHEMA_VERSION: u32 = 1;
15
/// Complete egress policy as stored in a TOML file (see [`EgressPolicy::load`] and
/// [`EgressPolicy::save`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EgressPolicy {
    /// Policy file format version. Required when deserializing (no serde default);
    /// [`EgressPolicy::load`] rejects values greater than [`CURRENT_SCHEMA_VERSION`].
    pub schema_version: u32,
    /// Domain allow/deny patterns. Required: there is no serde default, so a
    /// `[domains]` table must be present when deserializing.
    pub domains: DomainPolicy,
    /// Address-class blocking; when the table is absent, [`NetworkPolicy::default`]
    /// (every block enabled) applies.
    #[serde(default)]
    pub networks: NetworkPolicy,
    /// Request-rate limits; when the table is absent, [`RateLimitPolicy::default`]
    /// applies.
    #[serde(default)]
    pub rate_limits: RateLimitPolicy,
    /// Audit-log settings; when the table is absent, [`AuditPolicy::default`] applies.
    #[serde(default)]
    pub audit: AuditPolicy,
}
33
/// Domain-level allow/deny rules evaluated by [`EgressPolicy::is_domain_allowed`].
///
/// Entries are exact domain names or `*`-prefixed patterns such as `"*.example.com"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainPolicy {
    /// Permitted domain patterns. An empty list permits every domain that is not
    /// denied. Defaults to empty when omitted.
    #[serde(default)]
    pub allow: Vec<String>,
    /// Forbidden domain patterns; a match here overrides any allow entry.
    /// Defaults to empty when omitted.
    #[serde(default)]
    pub deny: Vec<String>,
}
44
/// Address-class blocks consulted by [`EgressPolicy::is_ip_blocked`].
///
/// Every flag defaults to `true` when omitted from the TOML (secure-by-default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
    /// Block private-range addresses (classified by `is_private_ip`).
    #[serde(default = "default_true")]
    pub block_private: bool,
    /// Block link-local addresses (classified by `is_link_local`).
    #[serde(default = "default_true")]
    pub block_link_local: bool,
    /// Block loopback / `localhost` targets (classified by `is_localhost`).
    #[serde(default = "default_true")]
    pub block_localhost: bool,
    /// Block cloud instance-metadata endpoints (classified by `is_metadata_ip`).
    #[serde(default = "default_true")]
    pub block_metadata: bool,
}
61
/// Serde default for the [`NetworkPolicy`] block flags: blocking is on unless the
/// policy file explicitly disables it.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
65
66impl Default for NetworkPolicy {
67 fn default() -> Self {
68 Self {
69 block_private: true,
70 block_link_local: true,
71 block_localhost: true,
72 block_metadata: true,
73 }
74 }
75}
76
/// Request-rate limits consulted through [`EgressPolicy::rate_limit_for`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitPolicy {
    /// Global requests-per-minute limit, used for any domain without a
    /// `per_domain` entry. Defaults to 60 when omitted.
    #[serde(default = "default_rate_limit")]
    pub max_requests_per_minute: u32,
    /// Requests-per-minute overrides keyed by the exact domain string (no pattern
    /// matching is applied on lookup). Defaults to empty when omitted.
    #[serde(default)]
    pub per_domain: HashMap<String, u32>,
}
87
/// Serde/`Default` value for [`RateLimitPolicy::max_requests_per_minute`].
fn default_rate_limit() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 60;
    REQUESTS_PER_MINUTE
}
91
92impl Default for RateLimitPolicy {
93 fn default() -> Self {
94 Self {
95 max_requests_per_minute: default_rate_limit(),
96 per_domain: HashMap::new(),
97 }
98 }
99}
100
/// Audit-logging settings carried by the policy file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditPolicy {
    /// Optional audit-log destination; `None` when omitted.
    /// NOTE(review): presumably where audit records are written — confirm with the consumer.
    #[serde(default)]
    pub log_path: Option<PathBuf>,
    /// Log output format string; defaults to `"json"` when omitted.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Defaults to `false` when omitted.
    /// NOTE(review): presumably controls whether allowed (not only blocked) requests
    /// are logged — confirm with the consumer.
    #[serde(default)]
    pub log_allowed: bool,
}
114
/// Serde/`Default` value for [`AuditPolicy::log_format`].
fn default_log_format() -> String {
    String::from("json")
}
118
119impl Default for AuditPolicy {
120 fn default() -> Self {
121 Self {
122 log_path: None,
123 log_format: default_log_format(),
124 log_allowed: false,
125 }
126 }
127}
128
129impl EgressPolicy {
130 pub fn load(path: &Path) -> Result<Self, ShieldError> {
132 let content = std::fs::read_to_string(path).map_err(ShieldError::Io)?;
133 let policy: Self = toml::from_str(&content)?;
134 if policy.schema_version > CURRENT_SCHEMA_VERSION {
135 return Err(ShieldError::Config(format!(
136 "Egress policy schema version {} is newer than supported version {}",
137 policy.schema_version, CURRENT_SCHEMA_VERSION
138 )));
139 }
140 Ok(policy)
141 }
142
143 pub fn save(&self, path: &Path) -> Result<(), ShieldError> {
145 let content = toml::to_string_pretty(self)?;
146 std::fs::write(path, content).map_err(ShieldError::Io)?;
147 Ok(())
148 }
149
150 pub fn is_domain_allowed(&self, domain: &str) -> bool {
155 if self
157 .domains
158 .deny
159 .iter()
160 .any(|pattern| domain_matches(domain, pattern))
161 {
162 return false;
163 }
164 if self.domains.allow.is_empty() {
166 return true;
167 }
168 self.domains
170 .allow
171 .iter()
172 .any(|pattern| domain_matches(domain, pattern))
173 }
174
175 pub fn is_ip_blocked(&self, ip: &str) -> bool {
177 if self.networks.block_localhost && is_localhost(ip) {
178 return true;
179 }
180 if self.networks.block_private && is_private_ip(ip) {
181 return true;
182 }
183 if self.networks.block_link_local && is_link_local(ip) {
184 return true;
185 }
186 if self.networks.block_metadata && is_metadata_ip(ip) {
187 return true;
188 }
189 false
190 }
191
192 pub fn rate_limit_for(&self, domain: &str) -> u32 {
196 self.rate_limits
197 .per_domain
198 .get(domain)
199 .copied()
200 .unwrap_or(self.rate_limits.max_requests_per_minute)
201 }
202
203 pub fn from_scan_targets(targets: &[ScanTarget]) -> Self {
212 let mut domains = std::collections::HashSet::new();
213
214 for target in targets {
215 for net_op in &target.execution.network_operations {
217 if let ArgumentSource::Literal(ref url) = net_op.url_arg {
218 if let Some(domain) = extract_domain(url) {
219 domains.insert(domain);
220 }
221 }
222 }
223
224 for tool in &target.tools {
226 for perm in &tool.declared_permissions {
227 if matches!(perm.permission_type, PermissionType::NetworkAccess) {
228 if let Some(ref scope) = perm.target {
229 if let Some(domain) = extract_domain(scope) {
230 domains.insert(domain);
231 }
232 }
233 }
234 }
235 }
236 }
237
238 let mut allow: Vec<String> = domains.into_iter().collect();
239 allow.sort();
240
241 EgressPolicy {
242 schema_version: CURRENT_SCHEMA_VERSION,
243 domains: DomainPolicy {
244 allow,
245 deny: vec![],
246 },
247 networks: NetworkPolicy::default(),
248 rate_limits: RateLimitPolicy::default(),
249 audit: AuditPolicy::default(),
250 }
251 }
252
253 pub fn merge_override(&self, operator: &EgressPolicy) -> EgressPolicy {
265 let allow = if operator.domains.allow.is_empty() {
267 self.domains.allow.clone()
269 } else if self.domains.allow.is_empty() {
270 operator.domains.allow.clone()
272 } else {
273 self.domains
275 .allow
276 .iter()
277 .filter(|d| {
278 operator
279 .domains
280 .allow
281 .iter()
282 .any(|o| domain_matches(d, o) || domain_matches(o, d))
283 })
284 .cloned()
285 .collect()
286 };
287
288 let mut deny = self.domains.deny.clone();
290 for d in &operator.domains.deny {
291 if !deny.contains(d) {
292 deny.push(d.clone());
293 }
294 }
295
296 let global_min = self
298 .rate_limits
299 .max_requests_per_minute
300 .min(operator.rate_limits.max_requests_per_minute);
301
302 let mut per_domain = self.rate_limits.per_domain.clone();
303 for (domain, &op_rate) in &operator.rate_limits.per_domain {
304 let entry = per_domain
305 .entry(domain.clone())
306 .or_insert(self.rate_limits.max_requests_per_minute);
307 *entry = (*entry).min(op_rate);
308 }
309
310 EgressPolicy {
311 schema_version: self.schema_version,
312 domains: DomainPolicy { allow, deny },
313 networks: NetworkPolicy {
314 block_private: self.networks.block_private || operator.networks.block_private,
315 block_link_local: self.networks.block_link_local
316 || operator.networks.block_link_local,
317 block_localhost: self.networks.block_localhost || operator.networks.block_localhost,
318 block_metadata: self.networks.block_metadata || operator.networks.block_metadata,
319 },
320 rate_limits: RateLimitPolicy {
321 max_requests_per_minute: global_min,
322 per_domain,
323 },
324 audit: operator.audit.clone(),
325 }
326 }
327
328 pub fn starter_toml() -> &'static str {
330 r#"# AgentShield Egress Policy
331# See: https://github.com/limaronaldo/agentshield
332
333schema_version = 1
334
335[domains]
336# Allowed domain patterns (glob-style)
337allow = ["*.example.com", "api.github.com"]
338# Explicitly denied (takes precedence over allow)
339deny = []
340
341[networks]
342block_private = true # 10.x, 172.16-31.x, 192.168.x
343block_link_local = true # 169.254.x
344block_localhost = true # 127.x, ::1
345block_metadata = true # 169.254.169.254, metadata.google.internal
346
347[rate_limits]
348max_requests_per_minute = 60
349
350[audit]
351# log_path = "agentshield-audit.jsonl"
352log_format = "json"
353log_allowed = false
354"#
355 }
356}
357
/// Extracts the host name from an `http(s)://` URL or a bare domain string.
///
/// * URLs: the scheme is matched case-insensitively. User-info (`user:pass@`),
///   port, path, query and fragment are dropped. A bracketed IPv6 literal
///   (`http://[::1]:8080/`) yields the address without brackets.
/// * Bare strings are accepted only when they contain a `.` and no `/`
///   (`"api.example.com"`, `"api.example.com:443"`); a port suffix is stripped.
///
/// Returns `None` for anything else: relative paths, dotless names such as
/// `"localhost"`, empty input, or a URL whose host is empty.
pub fn extract_domain(url_or_domain: &str) -> Option<String> {
    let after_scheme = match strip_http_scheme(url_or_domain) {
        Some(rest) => rest,
        None if url_or_domain.contains('.') && !url_or_domain.contains('/') => url_or_domain,
        None => return None,
    };

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Drop any `user:pass@` prefix: the real host follows the last `@`.
    // (Without this, `https://evil.com@api.github.com/` would yield `evil.com@...`.)
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    let host = match host_port.strip_prefix('[') {
        // Bracketed IPv6 literal: the host is everything up to the closing `]`.
        Some(bracketed) => bracketed.split(']').next().unwrap_or(""),
        None => host_port.split(':').next().unwrap_or(""),
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}

/// Strips a leading `https://` or `http://` (in any letter case) and returns the
/// remainder, or `None` when `s` starts with neither.
fn strip_http_scheme(s: &str) -> Option<&str> {
    ["https://", "http://"].iter().find_map(|scheme| {
        // `get` rather than slicing so multi-byte input cannot panic.
        let head = s.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            Some(&s[scheme.len()..])
        } else {
            None
        }
    })
}
387
/// Reports whether `domain` is covered by the allow/deny `pattern`.
///
/// * `"*.example.com"` matches any subdomain (`"api.example.com"`) and also the
///   apex `"example.com"`.
/// * Any other leading `*` is a plain suffix match (`"*example.com"` matches
///   `"myexample.com"`), and a lone `"*"` matches everything.
/// * Every other pattern must equal the domain exactly.
///
/// DNS names are case-insensitive and may carry a trailing root dot, so both sides
/// are compared in ASCII lowercase with trailing dots removed. Otherwise
/// `EVIL.example.com` or `evil.example.com.` would slip past a deny entry.
fn domain_matches(domain: &str, pattern: &str) -> bool {
    let domain = normalize_host(domain);
    let pattern = normalize_host(pattern);
    match pattern.strip_prefix('*') {
        // The apex form only applies when the wildcard is followed by a dot. Slicing
        // off one byte instead would make `*example.com` match `xample.com` and
        // panic on a multi-byte first character.
        Some(suffix) => {
            domain.ends_with(suffix) || suffix.strip_prefix('.') == Some(domain.as_str())
        }
        None => domain == pattern,
    }
}

/// Canonicalizes a host name or pattern for comparison: strips trailing root dots
/// and lowercases ASCII letters (`"Example.COM."` becomes `"example.com"`).
fn normalize_host(host: &str) -> String {
    host.trim_end_matches('.').to_ascii_lowercase()
}
400
/// Returns `true` when `ip` designates the local machine.
///
/// Parsed addresses cover:
/// * any IPv4 loopback (`127.0.0.0/8`);
/// * the IPv6 loopback in every spelling (`::1`, `0:0:0:0:0:0:0:1`);
/// * IPv4-mapped loopback (`::ffff:127.0.0.1`);
/// * the unspecified addresses `0.0.0.0` / `::`, which connect to the local host
///   on common OSes (a well-known SSRF bypass).
///
/// Surrounding `[...]` brackets are ignored. Strings that do not parse as an IP
/// fall back to the textual checks: a `127.` prefix (keeps shorthand such as
/// `127.1` blocked) or the name `localhost` in any case, with or without a
/// trailing dot.
fn is_localhost(ip: &str) -> bool {
    let host = ip.trim_start_matches('[').trim_end_matches(']');
    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_loopback() || v4.is_unspecified(),
        Ok(std::net::IpAddr::V6(v6)) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || matches!(
                    v6.to_ipv4_mapped(),
                    Some(v4) if v4.is_loopback() || v4.is_unspecified()
                )
        }
        Err(_) => {
            host.starts_with("127.") || host.trim_end_matches('.').eq_ignore_ascii_case("localhost")
        }
    }
}
404
/// Returns `true` when `ip` lies in a private address range: RFC 1918 IPv4
/// (`10/8`, `172.16/12`, `192.168/16`) or IPv6 unique-local (`fc00::/7`, which
/// covers both `fc..` and `fd..`), including IPv4-mapped IPv6 forms such as
/// `::ffff:10.0.0.1`. Surrounding `[...]` brackets are ignored.
///
/// Strings that do not parse as an IP (host names, shorthand like `10.1`) fall
/// back to textual prefix checks so odd spellings stay blocked. The IPv6
/// `fc`/`fd` fallback only applies to strings containing `:`, so host names such
/// as `fdic.gov` are not misclassified.
fn is_private_ip(ip: &str) -> bool {
    let host = ip.trim_start_matches('[').trim_end_matches(']');
    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_private(),
        Ok(std::net::IpAddr::V6(v6)) => {
            // fc00::/7: the top 7 bits of the first segment are 1111_110.
            (v6.segments()[0] & 0xfe00) == 0xfc00
                || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_private())
        }
        Err(_) => {
            host.starts_with("10.")
                || (host.starts_with("172.") && is_172_private(host))
                || host.starts_with("192.168.")
                || (host.contains(':') && {
                    let lower = host.to_ascii_lowercase();
                    lower.starts_with("fc") || lower.starts_with("fd")
                })
        }
    }
}

/// For a `172.`-prefixed string, returns `true` when the second dotted component
/// parses as an integer in `16..=31` (the `172.16.0.0/12` private block).
fn is_172_private(ip: &str) -> bool {
    ip.strip_prefix("172.")
        .and_then(|rest| rest.split('.').next())
        .and_then(|second_octet| second_octet.parse::<u8>().ok())
        .map_or(false, |n| (16..=31).contains(&n))
}
423
/// Returns `true` when `ip` is a link-local address: IPv4 `169.254.0.0/16` or
/// IPv6 `fe80::/10` (`fe80`–`febf`), including IPv4-mapped IPv6 forms.
/// Surrounding `[...]` brackets are ignored.
///
/// Strings that do not parse as an IP (e.g. IPv6 with a `%zone` suffix) fall back
/// to case-insensitive `169.254.` / `fe80:` prefix checks.
fn is_link_local(ip: &str) -> bool {
    let host = ip.trim_start_matches('[').trim_end_matches(']');
    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(std::net::IpAddr::V6(v6)) => {
            // fe80::/10: the top 10 bits of the first segment are 1111_1110_10.
            (v6.segments()[0] & 0xffc0) == 0xfe80
                || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_link_local())
        }
        Err(_) => {
            let lower = host.to_ascii_lowercase();
            lower.starts_with("169.254.") || lower.starts_with("fe80:")
        }
    }
}
427
/// Returns `true` when `ip` names a cloud instance-metadata endpoint.
///
/// The IPv4 addresses `169.254.169.254`, `100.100.100.200` and `169.254.170.2`
/// are matched directly or in IPv4-mapped IPv6 form. The AWS IMDS IPv6 address
/// `fd00:ec2::254` is also matched. Any string containing
/// `metadata.google.internal` matches, compared case-insensitively. Surrounding
/// `[...]` brackets are ignored.
fn is_metadata_ip(ip: &str) -> bool {
    const METADATA_V4: [std::net::Ipv4Addr; 3] = [
        std::net::Ipv4Addr::new(169, 254, 169, 254),
        std::net::Ipv4Addr::new(100, 100, 100, 200),
        std::net::Ipv4Addr::new(169, 254, 170, 2),
    ];
    const AWS_IMDS_V6: std::net::Ipv6Addr =
        std::net::Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);

    let host = ip.trim_start_matches('[').trim_end_matches(']');
    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => METADATA_V4.contains(&v4),
        Ok(std::net::IpAddr::V6(v6)) => {
            v6 == AWS_IMDS_V6
                || matches!(v6.to_ipv4_mapped(), Some(v4) if METADATA_V4.contains(&v4))
        }
        // Host names are case-insensitive.
        Err(_) => host
            .to_ascii_lowercase()
            .contains("metadata.google.internal"),
    }
}
434
/// Unit tests for policy (de)serialization, evaluation, generation from scan
/// targets, and operator-override merging.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixture: one wildcard and one exact allow entry, a deny entry that overlaps
    /// the wildcard, default network blocks, and a per-domain rate override.
    fn sample_policy() -> EgressPolicy {
        EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec!["*.example.com".into(), "api.github.com".into()],
                deny: vec!["evil.example.com".into()],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy {
                max_requests_per_minute: 60,
                per_domain: {
                    let mut m = HashMap::new();
                    m.insert("api.github.com".into(), 30);
                    m
                },
            },
            audit: AuditPolicy::default(),
        }
    }

    /// Saving and reloading a policy must reproduce every field.
    #[test]
    fn test_load_and_save_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("egress.toml");

        let original = sample_policy();
        original.save(&path).unwrap();

        let loaded = EgressPolicy::load(&path).unwrap();

        assert_eq!(loaded.schema_version, original.schema_version);
        assert_eq!(loaded.domains.allow, original.domains.allow);
        assert_eq!(loaded.domains.deny, original.domains.deny);
        assert_eq!(
            loaded.networks.block_private,
            original.networks.block_private
        );
        assert_eq!(
            loaded.networks.block_localhost,
            original.networks.block_localhost
        );
        assert_eq!(
            loaded.networks.block_link_local,
            original.networks.block_link_local
        );
        assert_eq!(
            loaded.networks.block_metadata,
            original.networks.block_metadata
        );
        assert_eq!(
            loaded.rate_limits.max_requests_per_minute,
            original.rate_limits.max_requests_per_minute
        );
        assert_eq!(
            loaded.rate_limits.per_domain,
            original.rate_limits.per_domain
        );
        assert_eq!(loaded.audit.log_format, original.audit.log_format);
        assert_eq!(loaded.audit.log_allowed, original.audit.log_allowed);
        assert_eq!(loaded.audit.log_path, original.audit.log_path);
    }

    /// Exact and wildcard allow entries, including the wildcard's apex domain.
    #[test]
    fn test_domain_allowed() {
        let policy = sample_policy();

        // Exact allow entry.
        assert!(policy.is_domain_allowed("api.github.com"));
        // Subdomain covered by `*.example.com`.
        assert!(policy.is_domain_allowed("sub.example.com"));
        // `*.example.com` also covers the apex.
        assert!(policy.is_domain_allowed("example.com"));
        // Not in the allow list.
        assert!(!policy.is_domain_allowed("random.org"));
    }

    /// A deny entry overrides an overlapping wildcard allow entry.
    #[test]
    fn test_domain_denied_takes_precedence() {
        let policy = sample_policy();

        // `evil.example.com` matches `*.example.com` but is explicitly denied.
        assert!(
            !policy.is_domain_allowed("evil.example.com"),
            "deny should take precedence over allow"
        );
    }

    /// An empty allow list permits everything except denied domains.
    #[test]
    fn test_empty_allow_list_allows_all() {
        let policy = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec!["blocked.com".into()],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy::default(),
        };

        assert!(policy.is_domain_allowed("anything.com"));
        assert!(policy.is_domain_allowed("whatever.org"));
        assert!(
            !policy.is_domain_allowed("blocked.com"),
            "deny should still block even with empty allow"
        );
    }

    /// With all network blocks enabled, each address class is blocked and public
    /// addresses are not.
    #[test]
    fn test_ip_blocking() {
        let policy = sample_policy();

        // Loopback / localhost.
        assert!(policy.is_ip_blocked("127.0.0.1"));
        assert!(policy.is_ip_blocked("127.0.0.2"));
        assert!(policy.is_ip_blocked("::1"));
        assert!(policy.is_ip_blocked("localhost"));

        // Private ranges.
        assert!(policy.is_ip_blocked("10.0.0.1"));
        assert!(policy.is_ip_blocked("172.16.0.1"));
        assert!(policy.is_ip_blocked("172.31.255.255"));
        assert!(policy.is_ip_blocked("192.168.1.1"));

        // Just outside 172.16.0.0/12.
        assert!(!policy.is_ip_blocked("172.15.0.1"));
        assert!(!policy.is_ip_blocked("172.32.0.1"));

        // Link-local.
        assert!(policy.is_ip_blocked("169.254.1.1"));
        assert!(policy.is_ip_blocked("fe80::1"));

        // Cloud metadata endpoints.
        assert!(policy.is_ip_blocked("169.254.169.254"));
        assert!(policy.is_ip_blocked("metadata.google.internal"));
        assert!(policy.is_ip_blocked("100.100.100.200"));
        assert!(policy.is_ip_blocked("169.254.170.2"));

        // Public addresses pass.
        assert!(!policy.is_ip_blocked("8.8.8.8"));
        assert!(!policy.is_ip_blocked("1.1.1.1"));
    }

    /// A per-domain override takes precedence over the global limit.
    #[test]
    fn test_rate_limit_per_domain() {
        let policy = sample_policy();
        assert_eq!(policy.rate_limit_for("api.github.com"), 30);
    }

    /// Domains without an override use the global limit.
    #[test]
    fn test_rate_limit_default() {
        let policy = sample_policy();
        assert_eq!(policy.rate_limit_for("unknown.com"), 60);
    }

    /// Loading a file whose schema_version exceeds CURRENT_SCHEMA_VERSION fails
    /// with a message naming the version.
    #[test]
    fn test_future_schema_rejected() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("future.toml");

        let content = r#"
schema_version = 99

[domains]
allow = []
deny = []
"#;
        std::fs::write(&path, content).unwrap();

        let result = EgressPolicy::load(&path);
        assert!(result.is_err());

        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("99") && err_msg.contains("newer"),
            "Error should mention unsupported schema version, got: {err_msg}"
        );
    }

    /// The starter template must parse and carry the documented defaults.
    #[test]
    fn test_starter_toml_parses() {
        let toml_str = EgressPolicy::starter_toml();
        let policy: EgressPolicy =
            toml::from_str(toml_str).expect("starter_toml() should produce valid TOML");
        assert_eq!(policy.schema_version, 1);
        assert!(!policy.domains.allow.is_empty());
        assert!(policy.networks.block_private);
        assert!(policy.networks.block_metadata);
        assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
        assert_eq!(policy.audit.log_format, "json");
    }

    /// Host extraction from URLs and bare domains, and rejection of non-domains.
    #[test]
    fn test_extract_domain_from_url() {
        // URL with path.
        assert_eq!(
            extract_domain("https://api.example.com/v1/items"),
            Some("api.example.com".into())
        );
        // Port is stripped.
        assert_eq!(
            extract_domain("http://api.example.com:8080/path"),
            Some("api.example.com".into())
        );
        // No path.
        assert_eq!(
            extract_domain("https://api.github.com"),
            Some("api.github.com".into())
        );
        // Bare domain.
        assert_eq!(
            extract_domain("api.example.com"),
            Some("api.example.com".into())
        );
        // Dotless name without scheme is rejected.
        assert_eq!(extract_domain("localhost"), None);
        // Relative path is rejected.
        assert_eq!(extract_domain("/some/path"), None);
        // Empty input.
        assert_eq!(extract_domain(""), None);
    }

    /// from_scan_targets collects literal URLs and declared NetworkAccess targets,
    /// skips parameter-sourced URLs, sorts the result, and uses secure defaults.
    #[test]
    fn test_from_scan_targets_extracts_domains() {
        use crate::ir::execution_surface::{ExecutionSurface, NetworkOperation};
        use crate::ir::tool_surface::{DeclaredPermission, PermissionType, ToolSurface};
        use crate::ir::{
            ArgumentSource, DataSurface, DependencySurface, Framework, ProvenanceSurface,
            ScanTarget, SourceLocation,
        };
        use std::path::PathBuf;

        let make_loc = || SourceLocation {
            file: PathBuf::from("server.py"),
            line: 1,
            column: 0,
            end_line: None,
            end_column: None,
        };

        let target = ScanTarget {
            name: "test-server".into(),
            framework: Framework::Mcp,
            root_path: PathBuf::from("/tmp/test"),
            tools: vec![ToolSurface {
                name: "fetch_data".into(),
                description: None,
                input_schema: None,
                output_schema: None,
                declared_permissions: vec![DeclaredPermission {
                    permission_type: PermissionType::NetworkAccess,
                    target: Some("https://api.stripe.com/v1".into()),
                    description: None,
                }],
                defined_at: None,
            }],
            execution: ExecutionSurface {
                network_operations: vec![
                    NetworkOperation {
                        function: "requests.get".into(),
                        url_arg: ArgumentSource::Literal("https://api.openai.com/v1/chat".into()),
                        method: Some("GET".into()),
                        sends_data: false,
                        location: make_loc(),
                    },
                    // Parameter-sourced URL: cannot be resolved statically, so skipped.
                    NetworkOperation {
                        function: "requests.post".into(),
                        url_arg: ArgumentSource::Parameter { name: "url".into() },
                        method: Some("POST".into()),
                        sends_data: true,
                        location: make_loc(),
                    },
                ],
                ..ExecutionSurface::default()
            },
            data: DataSurface::default(),
            dependencies: DependencySurface::default(),
            provenance: ProvenanceSurface::default(),
            source_files: vec![],
        };

        let policy = EgressPolicy::from_scan_targets(&[target]);

        assert_eq!(policy.schema_version, 1);
        // Generated policies never deny anything.
        assert!(policy.domains.deny.is_empty());
        // Domain from the literal network-operation URL.
        assert!(
            policy.domains.allow.contains(&"api.openai.com".to_string()),
            "Expected api.openai.com in allow list, got: {:?}",
            policy.domains.allow
        );
        // Domain from the declared NetworkAccess permission.
        assert!(
            policy.domains.allow.contains(&"api.stripe.com".to_string()),
            "Expected api.stripe.com in allow list, got: {:?}",
            policy.domains.allow
        );
        // Output order is deterministic (sorted).
        assert_eq!(
            policy.domains.allow,
            {
                let mut sorted = policy.domains.allow.clone();
                sorted.sort();
                sorted
            },
            "Allow list should be sorted"
        );
        // Network and rate-limit sections use their defaults.
        assert!(policy.networks.block_private);
        assert!(policy.networks.block_localhost);
        assert!(policy.networks.block_link_local);
        assert!(policy.networks.block_metadata);
        assert_eq!(policy.rate_limits.max_requests_per_minute, 60);
    }

    /// Base fixture for the merge tests: three allowed domains, one deny entry,
    /// mixed network flags, an openai per-domain limit, and a custom audit path.
    fn base_policy() -> EgressPolicy {
        EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![
                    "api.example.com".into(),
                    "api.github.com".into(),
                    "api.openai.com".into(),
                ],
                deny: vec!["evil.com".into()],
            },
            networks: NetworkPolicy {
                block_private: false,
                block_link_local: true,
                block_localhost: true,
                block_metadata: false,
            },
            rate_limits: RateLimitPolicy {
                max_requests_per_minute: 60,
                per_domain: {
                    let mut m = HashMap::new();
                    m.insert("api.openai.com".into(), 20);
                    m
                },
            },
            audit: AuditPolicy {
                log_path: Some(PathBuf::from("/tmp/base-audit.jsonl")),
                log_format: "json".into(),
                log_allowed: false,
            },
        }
    }

    /// Deny lists are unioned.
    #[test]
    fn test_merge_deny_union() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec!["extra-bad.com".into()],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy::default(),
        };

        let merged = base.merge_override(&operator);

        assert!(
            merged.domains.deny.contains(&"evil.com".to_string()),
            "base deny entry must be preserved"
        );
        assert!(
            merged.domains.deny.contains(&"extra-bad.com".to_string()),
            "operator deny entry must be added"
        );
        assert_eq!(merged.domains.deny.len(), 2);
    }

    /// Two non-empty allow lists are intersected.
    #[test]
    fn test_merge_allow_intersection() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                // Overlaps base on github/openai; adds stripe, omits example.
                allow: vec![
                    "api.github.com".into(),
                    "api.openai.com".into(),
                    "api.stripe.com".into(),
                ],
                deny: vec![],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy::default(),
        };

        let merged = base.merge_override(&operator);

        assert!(
            merged.domains.allow.contains(&"api.github.com".to_string()),
            "intersection: api.github.com must be in result"
        );
        assert!(
            merged.domains.allow.contains(&"api.openai.com".to_string()),
            "intersection: api.openai.com must be in result"
        );
        assert!(
            !merged
                .domains
                .allow
                .contains(&"api.example.com".to_string()),
            "api.example.com not in operator allow → must be excluded"
        );
        assert!(
            !merged.domains.allow.contains(&"api.stripe.com".to_string()),
            "api.stripe.com not in base allow → must be excluded"
        );
    }

    /// Global and per-domain rate limits take the minimum of both policies.
    #[test]
    fn test_merge_rate_limits_min() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec![],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy {
                max_requests_per_minute: 30,
                per_domain: {
                    let mut m = HashMap::new();
                    m.insert("api.openai.com".into(), 10);
                    m.insert("api.github.com".into(), 5);
                    m
                },
            },
            audit: AuditPolicy::default(),
        };

        let merged = base.merge_override(&operator);

        assert_eq!(
            merged.rate_limits.max_requests_per_minute, 30,
            "global rate: min(60, 30) = 30"
        );
        assert_eq!(
            merged.rate_limits.per_domain["api.openai.com"], 10,
            "per-domain rate: min(20, 10) = 10"
        );
        assert_eq!(
            merged.rate_limits.per_domain["api.github.com"], 5,
            "operator-only per-domain: min(60, 5) = 5"
        );
    }

    /// Network block flags are OR-ed, so a block enabled on either side stays on.
    #[test]
    fn test_merge_network_blocks_or() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec![],
            },
            networks: NetworkPolicy {
                block_private: true,
                block_link_local: false,
                block_localhost: false,
                block_metadata: true,
            },
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy::default(),
        };

        let merged = base.merge_override(&operator);

        assert!(merged.networks.block_private, "false || true = true");
        assert!(
            merged.networks.block_link_local,
            "true || false = true (base had it)"
        );
        assert!(
            merged.networks.block_localhost,
            "true || false = true (base had it)"
        );
        assert!(merged.networks.block_metadata, "false || true = true");
    }

    /// An empty operator allow list leaves the base allow list untouched.
    #[test]
    fn test_merge_empty_override_allow_keeps_base() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec![],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy::default(),
        };

        let merged = base.merge_override(&operator);

        assert_eq!(
            merged.domains.allow, base.domains.allow,
            "empty operator allow must not restrict base allow list"
        );
    }

    /// Audit settings come entirely from the operator policy.
    #[test]
    fn test_merge_audit_override_wins() {
        let base = base_policy();
        let operator = EgressPolicy {
            schema_version: 1,
            domains: DomainPolicy {
                allow: vec![],
                deny: vec![],
            },
            networks: NetworkPolicy::default(),
            rate_limits: RateLimitPolicy::default(),
            audit: AuditPolicy {
                log_path: Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
                log_format: "text".into(),
                log_allowed: true,
            },
        };

        let merged = base.merge_override(&operator);

        assert_eq!(
            merged.audit.log_path,
            Some(PathBuf::from("/var/log/agentshield/operator.jsonl")),
            "operator audit log_path must win"
        );
        assert_eq!(
            merged.audit.log_format, "text",
            "operator audit log_format must win"
        );
        assert!(
            merged.audit.log_allowed,
            "operator audit log_allowed must win"
        );
    }

    /// End to end: scan a fixture, generate a policy, save it, reload it.
    /// Depends on the on-disk fixture `tests/fixtures/mcp_servers/vuln_ssrf`.
    #[test]
    fn test_emit_egress_policy_integration() {
        use crate::{scan, ScanOptions};
        use std::path::Path;

        let opts = ScanOptions::default();
        let report = scan(Path::new("tests/fixtures/mcp_servers/vuln_ssrf"), &opts)
            .expect("scan should succeed");

        let policy = EgressPolicy::from_scan_targets(&report.targets);

        // Write to a temporary file and read it back.
        let tmp = TempDir::new().unwrap();
        let policy_path = tmp.path().join("agentshield.egress.toml");
        policy.save(&policy_path).unwrap();

        let loaded = EgressPolicy::load(&policy_path).unwrap();
        assert_eq!(loaded.schema_version, 1);
        assert!(loaded.networks.block_private);
        assert!(loaded.networks.block_metadata);
        // Generated policies never include deny entries.
        assert!(loaded.domains.deny.is_empty());
    }
}