// do_memory_mcp/sandbox/network.rs
use anyhow::{Result, bail};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use tracing::{debug, warn};
12
/// Policy describing which outbound network destinations are permitted.
///
/// This type only holds settings; the `validate_*` methods enforce them.
#[derive(Debug, Clone)]
pub struct NetworkRestrictions {
    /// When `true`, every URL and domain is rejected before any other check runs.
    pub block_all: bool,
    /// Domains that may be contacted; each entry also admits its subdomains.
    pub allowed_domains: Vec<String>,
    /// IP addresses that IP-literal hosts are compared against.
    pub allowed_ips: Vec<IpAddr>,
    /// When `true`, plain `http` URLs are rejected and only `https` is accepted.
    pub https_only: bool,
    /// When `true`, private and otherwise non-public IP addresses are rejected.
    pub block_private_ips: bool,
    /// When `true`, loopback host names and addresses are rejected.
    pub block_localhost: bool,
    /// Request budget for this policy.
    ///
    /// NOTE(review): no method in this file reads this field, so it is not enforced
    /// by the validators here — confirm where (or whether) it is enforced.
    pub max_requests: usize,
}
31
32impl Default for NetworkRestrictions {
33 fn default() -> Self {
34 Self {
35 block_all: true,
36 allowed_domains: vec![],
37 allowed_ips: vec![],
38 https_only: true,
39 block_private_ips: true,
40 block_localhost: true,
41 max_requests: 0,
42 }
43 }
44}
45
46impl NetworkRestrictions {
47 pub fn deny_all() -> Self {
49 Self {
50 block_all: true,
51 ..Default::default()
52 }
53 }
54
55 pub fn allow_domains(domains: Vec<String>) -> Self {
57 Self {
58 block_all: false,
59 allowed_domains: domains,
60 https_only: true,
61 block_private_ips: true,
62 block_localhost: true,
63 max_requests: 10,
64 ..Default::default()
65 }
66 }
67
68 pub fn validate_url(&self, url: &str) -> Result<()> {
78 if self.block_all {
80 bail!(NetworkSecurityError::NetworkAccessDenied {
81 reason: "All network access is blocked".to_string()
82 });
83 }
84
85 let parsed = url::Url::parse(url).map_err(|e| NetworkSecurityError::InvalidUrl {
87 url: url.to_string(),
88 reason: e.to_string(),
89 })?;
90
91 self.validate_scheme(&parsed)?;
93
94 if let Some(host) = parsed.host_str() {
96 self.validate_host(host)?;
97 } else {
98 bail!(NetworkSecurityError::InvalidUrl {
99 url: url.to_string(),
100 reason: "No host specified".to_string()
101 });
102 }
103
104 debug!("URL validated: {}", url);
105 Ok(())
106 }
107
108 pub fn validate_domain(&self, domain: &str) -> Result<()> {
110 if self.block_all {
111 bail!(NetworkSecurityError::NetworkAccessDenied {
112 reason: "All network access is blocked".to_string()
113 });
114 }
115
116 if !self.is_domain_allowed(domain) {
118 warn!("Domain access denied: {} (not in whitelist)", domain);
119 bail!(NetworkSecurityError::DomainNotInWhitelist {
120 domain: domain.to_string(),
121 allowed_domains: self.allowed_domains.clone()
122 });
123 }
124
125 if self.block_localhost && is_localhost(domain) {
127 bail!(NetworkSecurityError::LocalhostAccessDenied {
128 domain: domain.to_string()
129 });
130 }
131
132 debug!("Domain validated: {}", domain);
133 Ok(())
134 }
135
136 fn validate_scheme(&self, url: &url::Url) -> Result<()> {
138 let scheme = url.scheme();
139
140 match scheme {
141 "https" => Ok(()),
142 "http" => {
143 if self.https_only {
144 bail!(NetworkSecurityError::HttpNotAllowed {
145 url: url.to_string()
146 });
147 }
148 Ok(())
149 }
150 _ => bail!(NetworkSecurityError::UnsupportedProtocol {
151 protocol: scheme.to_string(),
152 url: url.to_string()
153 }),
154 }
155 }
156
157 fn validate_host(&self, host: &str) -> Result<()> {
159 if let Ok(ip) = host.parse::<IpAddr>() {
161 return self.validate_ip(&ip);
162 }
163
164 self.validate_domain(host)
166 }
167
168 fn validate_ip(&self, ip: &IpAddr) -> Result<()> {
170 if !self.allowed_ips.is_empty() && !self.allowed_ips.contains(ip) {
172 bail!(NetworkSecurityError::IpNotInWhitelist {
173 ip: ip.to_string(),
174 allowed_ips: self.allowed_ips.iter().map(|i| i.to_string()).collect()
175 });
176 }
177
178 if self.block_localhost && is_localhost_ip(ip) {
180 bail!(NetworkSecurityError::LocalhostAccessDenied {
181 domain: ip.to_string()
182 });
183 }
184
185 if self.block_private_ips && is_private_ip(ip) {
187 bail!(NetworkSecurityError::PrivateIpAccessDenied { ip: ip.to_string() });
188 }
189
190 Ok(())
191 }
192
193 fn is_domain_allowed(&self, domain: &str) -> bool {
195 if self.allowed_domains.is_empty() {
196 return false;
198 }
199
200 if self.allowed_domains.contains(&domain.to_string()) {
202 return true;
203 }
204
205 for allowed in &self.allowed_domains {
207 if domain.ends_with(&format!(".{}", allowed)) {
208 return true;
209 }
210 }
211
212 false
213 }
214}
215
/// Returns `true` if the host name `domain` refers to the local machine.
///
/// Matching ignores case and a single trailing dot, so the fully-qualified spelling
/// `localhost.` also matches. Besides the fixed spellings below, any subdomain of
/// `localhost` matches, because RFC 6761 reserves the whole `.localhost` zone for
/// loopback.
fn is_localhost(domain: &str) -> bool {
    let lowered = domain.to_lowercase();
    let name = lowered.strip_suffix('.').unwrap_or(lowered.as_str());
    matches!(
        name,
        "localhost" | "localhost.localdomain" | "127.0.0.1" | "::1" | "0.0.0.0"
    ) || name.ends_with(".localhost")
}
223
/// Returns `true` if connecting to `ip` targets the local machine.
///
/// Covers loopback (`127.0.0.0/8`, `::1`) and the unspecified addresses (`0.0.0.0`,
/// `::`). On common platforms such as Linux, the unspecified addresses also connect
/// to the local host. Both forms are also recognised when wrapped as IPv4-mapped IPv6
/// addresses (`::ffff:127.0.0.1`).
fn is_localhost_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => ipv4.is_loopback() || ipv4.is_unspecified(),
        IpAddr::V6(ipv6) => {
            ipv6.is_loopback()
                || ipv6.is_unspecified()
                || matches!(
                    ipv6.to_ipv4_mapped(),
                    Some(embedded) if embedded.is_loopback() || embedded.is_unspecified()
                )
        }
    }
}
231
232fn is_private_ip(ip: &IpAddr) -> bool {
234 match ip {
235 IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
236 IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
237 }
238}
239
/// Returns `true` if `ip` is not a publicly routable IPv4 unicast address.
///
/// Covers:
/// - RFC 1918 private ranges, loopback, and link-local `169.254.0.0/16`.
/// - Broadcast and the documentation ranges.
/// - "This network" `0.0.0.0/8`, including the unspecified address.
/// - Shared address space `100.64.0.0/10` (RFC 6598).
/// - Multicast `224.0.0.0/4` and the reserved block `240.0.0.0/4`.
fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_private()
        || ip.is_loopback()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_documentation()
        || ip.is_multicast()
        // 0.0.0.0/8 "this network"
        || first == 0
        // 100.64.0.0/10 shared address space (carrier-grade NAT)
        || (first == 100 && (second & 0xc0) == 64)
        // 240.0.0.0/4 reserved
        || first >= 240
}
252
/// Returns `true` if `ip` is not a publicly routable IPv6 unicast address.
///
/// Covers:
/// - Loopback `::1` and unspecified `::`.
/// - Unique local `fc00::/7` and link-local `fe80::/10`.
/// - Documentation `2001:db8::/32`.
/// - Multicast `ff00::/8`.
///
/// IPv4-mapped addresses (`::ffff:a.b.c.d`) are not unwrapped here.
fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    let segments = ip.segments();
    ip.is_loopback()
        || ip.is_unspecified()
        || ip.is_multicast()
        // fc00::/7 unique local
        || (segments[0] & 0xfe00) == 0xfc00
        // fe80::/10 link-local
        || (segments[0] & 0xffc0) == 0xfe80
        // 2001:db8::/32 documentation
        || (segments[0] == 0x2001 && segments[1] == 0x0db8)
}
274
275#[derive(Debug, thiserror::Error)]
277pub enum NetworkSecurityError {
278 #[error("Network access denied: {reason}")]
279 NetworkAccessDenied { reason: String },
280
281 #[error("Invalid URL: {url} - {reason}")]
282 InvalidUrl { url: String, reason: String },
283
284 #[error("HTTP not allowed (HTTPS only): {url}")]
285 HttpNotAllowed { url: String },
286
287 #[error("Unsupported protocol: {protocol} in URL: {url}")]
288 UnsupportedProtocol { protocol: String, url: String },
289
290 #[error("Domain not in whitelist: {domain} (allowed: {allowed_domains:?})")]
291 DomainNotInWhitelist {
292 domain: String,
293 allowed_domains: Vec<String>,
294 },
295
296 #[error("IP not in whitelist: {ip} (allowed: {allowed_ips:?})")]
297 IpNotInWhitelist {
298 ip: String,
299 allowed_ips: Vec<String>,
300 },
301
302 #[error("Localhost access denied: {domain}")]
303 LocalhostAccessDenied { domain: String },
304
305 #[error("Private IP access denied: {ip}")]
306 PrivateIpAccessDenied { ip: String },
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy whitelisting only `example.com` (and its subdomains).
    fn example_only() -> NetworkRestrictions {
        NetworkRestrictions::allow_domains(vec!["example.com".to_string()])
    }

    #[test]
    fn test_deny_all() {
        let outcome = NetworkRestrictions::deny_all().validate_url("https://example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_https_only() {
        let policy = example_only();
        assert!(policy.validate_url("https://example.com").is_ok());
        assert!(policy.validate_url("http://example.com").is_err());
    }

    #[test]
    fn test_domain_whitelist() {
        let policy = example_only();

        for accepted in ["https://example.com", "https://api.example.com"] {
            assert!(policy.validate_url(accepted).is_ok());
        }
        assert!(policy.validate_url("https://evil.com").is_err());
    }

    #[test]
    fn test_localhost_blocking() {
        let policy = NetworkRestrictions {
            block_localhost: true,
            ..NetworkRestrictions::allow_domains(vec!["localhost".to_string()])
        };

        assert!(policy.validate_domain("localhost").is_err());
        assert!(policy.validate_domain("127.0.0.1").is_err());
    }

    #[test]
    fn test_private_ip_blocking() {
        let policy = NetworkRestrictions {
            block_all: false,
            block_private_ips: true,
            ..Default::default()
        };

        let private_ips = [
            "10.0.0.1",
            "172.16.0.1",
            "192.168.1.1",
            "127.0.0.1",
            "169.254.1.1",
        ];
        for raw in private_ips.iter() {
            let addr: IpAddr = raw.parse().unwrap();
            assert!(policy.validate_ip(&addr).is_err());
        }
    }

    #[test]
    fn test_public_ip_allowed() {
        let public: IpAddr = "8.8.8.8".parse().unwrap();
        let policy = NetworkRestrictions {
            block_all: false,
            allowed_ips: vec![public],
            block_private_ips: true,
            ..Default::default()
        };

        assert!(policy.validate_ip(&public).is_ok());
    }

    #[test]
    fn test_is_localhost() {
        for name in ["localhost", "LOCALHOST", "127.0.0.1"] {
            assert!(is_localhost(name));
        }
        assert!(!is_localhost("example.com"));
    }

    #[test]
    fn test_is_private_ipv4() {
        assert!(is_private_ipv4(&Ipv4Addr::new(192, 168, 1, 1)));
        assert!(!is_private_ipv4(&Ipv4Addr::new(8, 8, 8, 8)));
    }

    #[test]
    fn test_subdomain_matching() {
        let policy = example_only();

        for allowed in ["example.com", "api.example.com", "foo.bar.example.com"] {
            assert!(policy.is_domain_allowed(allowed));
        }
        for denied in ["examplecom", "evil.com"] {
            assert!(!policy.is_domain_allowed(denied));
        }
    }
}