claude_agent/security/sandbox/
network.rs1use std::collections::HashSet;
4
5use crate::config::NetworkSandboxSettings;
6
/// Verdict returned by [`NetworkSandbox::check`] for a single domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DomainCheck {
    /// The connection may proceed.
    Allowed,
    /// The connection must be refused.
    Blocked,
}
12
/// Domain-level network policy: decides whether a connection to a host may
/// proceed, based on an allow list, a block list, and a permissive switch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NetworkSandbox {
    /// Allow entries: exact hosts, or `*.domain` / `.domain` wildcards.
    allowed_domains: HashSet<String>,
    /// Block entries (same pattern syntax); a match takes precedence over the allow list.
    blocked_domains: HashSet<String>,
    /// When `true`, `check` allows domains without requiring an allow-list match.
    permissive: bool,
}
19
20impl NetworkSandbox {
21 pub fn new() -> Self {
22 Self {
23 allowed_domains: default_allowed_domains(),
24 blocked_domains: HashSet::new(),
25 permissive: false,
26 }
27 }
28
29 pub fn from_settings(settings: &NetworkSandboxSettings) -> Self {
30 let mut allowed = default_allowed_domains();
31 allowed.extend(settings.allowed_domains.iter().cloned());
32
33 Self {
34 allowed_domains: allowed,
35 blocked_domains: settings.blocked_domains.clone(),
36 permissive: false,
37 }
38 }
39
40 pub fn permissive() -> Self {
41 Self {
42 allowed_domains: HashSet::new(),
43 blocked_domains: HashSet::new(),
44 permissive: true,
45 }
46 }
47
48 pub fn with_allowed_domains(mut self, domains: impl IntoIterator<Item = String>) -> Self {
49 self.allowed_domains.extend(domains);
50 self
51 }
52
53 pub fn with_blocked_domains(mut self, domains: impl IntoIterator<Item = String>) -> Self {
54 self.blocked_domains.extend(domains);
55 self
56 }
57
58 pub fn check(&self, domain: &str) -> DomainCheck {
59 if self.permissive {
60 return DomainCheck::Allowed;
61 }
62
63 let normalized = normalize_domain(domain);
64
65 if self.is_blocked(&normalized) {
66 return DomainCheck::Blocked;
67 }
68
69 if self.is_allowed(&normalized) {
70 return DomainCheck::Allowed;
71 }
72
73 DomainCheck::Blocked
75 }
76
77 fn is_blocked(&self, domain: &str) -> bool {
78 self.blocked_domains.contains(domain)
79 || self
80 .blocked_domains
81 .iter()
82 .any(|pattern| matches_domain_pattern(pattern, domain))
83 }
84
85 fn is_allowed(&self, domain: &str) -> bool {
86 if self.allowed_domains.is_empty() {
87 return true;
88 }
89
90 self.allowed_domains.contains(domain)
91 || self
92 .allowed_domains
93 .iter()
94 .any(|pattern| matches_domain_pattern(pattern, domain))
95 }
96
97 pub fn allowed_domains(&self) -> &HashSet<String> {
98 &self.allowed_domains
99 }
100
101 pub fn blocked_domains(&self) -> &HashSet<String> {
102 &self.blocked_domains
103 }
104}
105
106impl Default for NetworkSandbox {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
/// Hosts that a restrictive sandbox always allows: the Anthropic API and web
/// hosts, the statsig and sentry hosts, and the loopback names.
fn default_allowed_domains() -> HashSet<String> {
    const BUILTIN_HOSTS: [&str; 7] = [
        "api.anthropic.com",
        "claude.ai",
        "statsig.anthropic.com",
        "sentry.io",
        // Loopback in hostname, IPv4 and IPv6 form.
        "localhost",
        "127.0.0.1",
        "::1",
    ];

    let mut hosts = HashSet::new();
    for host in BUILTIN_HOSTS.iter() {
        hosts.insert(String::from(*host));
    }
    hosts
}
126
/// Reduces a host, `host:port`, or URL to the bare lowercase host used for
/// allow/block matching.
///
/// Steps:
/// 1. trim surrounding whitespace and lowercase;
/// 2. drop a leading `scheme://` (any syntactically valid scheme, not only http/https);
/// 3. keep only the authority, cutting at the first `/`, `?`, `#` or `\`;
/// 4. drop `user:password@` userinfo (everything up to the last `@`);
/// 5. strip the port: `[v6]` / `[v6]:port` yields the bracketed address, a bare
///    IPv6 literal such as `::1` is kept whole, and `host:port` yields `host`;
/// 6. drop trailing dots so the fully-qualified `example.com.` matches `example.com`.
///
/// Cutting at `?`/`#`/`\` and stripping userinfo matter for security. Without
/// them, `https://evil.com?.example.com` would end in `.example.com` and satisfy
/// a `*.example.com` allow pattern even though the real host is `evil.com`.
fn normalize_domain(domain: &str) -> String {
    let lowered = domain.trim().to_lowercase();
    let mut rest: &str = &lowered;

    // Only strip `scheme://` when the prefix is a valid scheme name, so a `://`
    // appearing later (e.g. inside a query string) is not mistaken for one.
    if let Some((scheme, after)) = rest.split_once("://") {
        let is_scheme = scheme.starts_with(|c: char| c.is_ascii_alphabetic())
            && scheme
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
        if is_scheme {
            rest = after;
        }
    }

    // The authority ends at the path, query or fragment. `\` is included because
    // WHATWG-style URL parsers treat it like `/` for http(s) URLs.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // The host follows the last `@`, if any userinfo is present.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);

    let host = if let Some(bracketed) = host_port.strip_prefix('[') {
        // `[v6addr]` or `[v6addr]:port`.
        bracketed.split_once(']').map_or(bracketed, |(addr, _)| addr)
    } else if host_port.matches(':').count() > 1 {
        // Bare IPv6 literal (e.g. `::1`): the colons belong to the address.
        host_port
    } else {
        host_port.split_once(':').map_or(host_port, |(name, _)| name)
    };

    host.trim_end_matches('.').to_string()
}
141
/// Returns `true` when `domain` matches a single allow/block `pattern`.
///
/// Pattern forms:
/// * `*.example.com` or `.example.com`: `example.com` itself or any subdomain.
///   The match is anchored on a `.` boundary, so `badexample.com` does not match.
/// * anything else: exact host match.
///
/// `domain` is expected to be already normalised (lowercase, no trailing dot).
/// The pattern is normalised the same way here (trimmed, trailing dots removed,
/// lowercased), so configured entries such as `Example.COM` or `*.example.com.`
/// still take effect instead of silently never matching.
fn matches_domain_pattern(pattern: &str, domain: &str) -> bool {
    let pattern = pattern.trim().trim_end_matches('.').to_lowercase();

    // `suffix` keeps its leading dot so `ends_with` only matches whole labels.
    let suffix: &str = if pattern.starts_with("*.") {
        &pattern[1..]
    } else if pattern.starts_with('.') {
        pattern.as_str()
    } else {
        return domain == pattern;
    };

    domain.ends_with(suffix) || domain == &suffix[1..]
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `String`s the builder methods take.
    fn owned(hosts: &[&str]) -> Vec<String> {
        hosts.iter().map(|h| h.to_string()).collect()
    }

    #[test]
    fn test_default_allowed() {
        let sandbox = NetworkSandbox::new();
        // Pairing each verdict with its host makes a failure name the domain.
        for host in ["api.anthropic.com", "claude.ai", "localhost"] {
            assert_eq!((host, sandbox.check(host)), (host, DomainCheck::Allowed));
        }
    }

    #[test]
    fn test_unknown_domain_blocked() {
        assert_eq!(NetworkSandbox::new().check("unknown.com"), DomainCheck::Blocked);
    }

    #[test]
    fn test_blocked_domain() {
        let sandbox = NetworkSandbox::new().with_blocked_domains(owned(&["evil.com"]));
        assert_eq!(sandbox.check("evil.com"), DomainCheck::Blocked);
    }

    #[test]
    fn test_wildcard_allowed() {
        let sandbox = NetworkSandbox::new().with_allowed_domains(owned(&["*.example.com"]));
        // The wildcard covers both subdomains and the apex domain itself.
        for host in ["sub.example.com", "example.com"] {
            assert_eq!((host, sandbox.check(host)), (host, DomainCheck::Allowed));
        }
    }

    #[test]
    fn test_wildcard_blocked() {
        let sandbox = NetworkSandbox::new().with_blocked_domains(owned(&["*.malware.com"]));
        assert_eq!(sandbox.check("sub.malware.com"), DomainCheck::Blocked);
    }

    #[test]
    fn test_normalize_domain() {
        let cases = [
            ("https://example.com/path", "example.com"),
            ("example.com:8080", "example.com"),
            ("EXAMPLE.COM", "example.com"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_domain(input), expected);
        }
    }

    #[test]
    fn test_permissive() {
        assert_eq!(NetworkSandbox::permissive().check("anything.com"), DomainCheck::Allowed);
    }

    #[test]
    fn test_blocked_takes_precedence() {
        let both = owned(&["example.com"]);
        let sandbox = NetworkSandbox::new()
            .with_allowed_domains(both.clone())
            .with_blocked_domains(both);
        assert_eq!(sandbox.check("example.com"), DomainCheck::Blocked);
    }
}