Skip to main content

do_memory_mcp/sandbox/
network.rs

//! Network access control for sandboxed code
//!
//! Implements network restrictions with:
//! - Domain whitelist enforcement
//! - HTTPS-only mode
//! - A configurable request limit (`max_requests`; stored here, not enforced by this module)
//! - IP address validation

9use anyhow::{Result, bail};
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
11use tracing::{debug, warn};
12
/// Network access restrictions for sandboxed code.
///
/// Build one with [`NetworkRestrictions::deny_all`], [`NetworkRestrictions::allow_domains`],
/// or a struct literal ending in `..Default::default()`. Then check destinations with
/// [`NetworkRestrictions::validate_url`] or [`NetworkRestrictions::validate_domain`].
#[derive(Clone, Debug)]
pub struct NetworkRestrictions {
    /// When `true`, the public validation methods reject everything.
    pub block_all: bool,
    /// Domains that may be contacted; subdomains of an entry are accepted as well.
    /// An empty list denies every domain.
    pub allowed_domains: Vec<String>,
    /// IP addresses that may be contacted when a host is given as an IP literal.
    pub allowed_ips: Vec<IpAddr>,
    /// Accept only `https://` URLs and reject `http://`.
    pub https_only: bool,
    /// Reject IP literals in private and other special-purpose ranges
    /// (RFC 1918 as well as ranges such as loopback and link-local).
    pub block_private_ips: bool,
    /// Reject localhost names and loopback addresses.
    pub block_localhost: bool,
    /// Maximum number of requests. This module only stores the value;
    /// nothing here counts requests or enforces the limit.
    pub max_requests: usize,
}
31
32impl Default for NetworkRestrictions {
33    fn default() -> Self {
34        Self {
35            block_all: true,
36            allowed_domains: vec![],
37            allowed_ips: vec![],
38            https_only: true,
39            block_private_ips: true,
40            block_localhost: true,
41            max_requests: 0,
42        }
43    }
44}
45
46impl NetworkRestrictions {
47    /// Create a deny-all configuration
48    pub fn deny_all() -> Self {
49        Self {
50            block_all: true,
51            ..Default::default()
52        }
53    }
54
55    /// Create an allow-specific configuration
56    pub fn allow_domains(domains: Vec<String>) -> Self {
57        Self {
58            block_all: false,
59            allowed_domains: domains,
60            https_only: true,
61            block_private_ips: true,
62            block_localhost: true,
63            max_requests: 10,
64            ..Default::default()
65        }
66    }
67
68    /// Validate a URL for network access
69    ///
70    /// # Security
71    ///
72    /// This method checks:
73    /// 1. If network access is allowed at all
74    /// 2. Protocol (HTTP vs HTTPS)
75    /// 3. Domain/hostname against whitelist
76    /// 4. IP address restrictions
77    pub fn validate_url(&self, url: &str) -> Result<()> {
78        // Check if all network access is blocked
79        if self.block_all {
80            bail!(NetworkSecurityError::NetworkAccessDenied {
81                reason: "All network access is blocked".to_string()
82            });
83        }
84
85        // Parse URL
86        let parsed = url::Url::parse(url).map_err(|e| NetworkSecurityError::InvalidUrl {
87            url: url.to_string(),
88            reason: e.to_string(),
89        })?;
90
91        // Validate scheme
92        self.validate_scheme(&parsed)?;
93
94        // Validate host
95        if let Some(host) = parsed.host_str() {
96            self.validate_host(host)?;
97        } else {
98            bail!(NetworkSecurityError::InvalidUrl {
99                url: url.to_string(),
100                reason: "No host specified".to_string()
101            });
102        }
103
104        debug!("URL validated: {}", url);
105        Ok(())
106    }
107
108    /// Validate a domain name for network access
109    pub fn validate_domain(&self, domain: &str) -> Result<()> {
110        if self.block_all {
111            bail!(NetworkSecurityError::NetworkAccessDenied {
112                reason: "All network access is blocked".to_string()
113            });
114        }
115
116        // Check if domain is in whitelist
117        if !self.is_domain_allowed(domain) {
118            warn!("Domain access denied: {} (not in whitelist)", domain);
119            bail!(NetworkSecurityError::DomainNotInWhitelist {
120                domain: domain.to_string(),
121                allowed_domains: self.allowed_domains.clone()
122            });
123        }
124
125        // Check for localhost
126        if self.block_localhost && is_localhost(domain) {
127            bail!(NetworkSecurityError::LocalhostAccessDenied {
128                domain: domain.to_string()
129            });
130        }
131
132        debug!("Domain validated: {}", domain);
133        Ok(())
134    }
135
136    /// Validate URL scheme (HTTP/HTTPS)
137    fn validate_scheme(&self, url: &url::Url) -> Result<()> {
138        let scheme = url.scheme();
139
140        match scheme {
141            "https" => Ok(()),
142            "http" => {
143                if self.https_only {
144                    bail!(NetworkSecurityError::HttpNotAllowed {
145                        url: url.to_string()
146                    });
147                }
148                Ok(())
149            }
150            _ => bail!(NetworkSecurityError::UnsupportedProtocol {
151                protocol: scheme.to_string(),
152                url: url.to_string()
153            }),
154        }
155    }
156
157    /// Validate host (domain or IP)
158    fn validate_host(&self, host: &str) -> Result<()> {
159        // Try parsing as IP address first
160        if let Ok(ip) = host.parse::<IpAddr>() {
161            return self.validate_ip(&ip);
162        }
163
164        // Otherwise treat as domain name
165        self.validate_domain(host)
166    }
167
168    /// Validate IP address
169    fn validate_ip(&self, ip: &IpAddr) -> Result<()> {
170        // Check if IP is in allowed list
171        if !self.allowed_ips.is_empty() && !self.allowed_ips.contains(ip) {
172            bail!(NetworkSecurityError::IpNotInWhitelist {
173                ip: ip.to_string(),
174                allowed_ips: self.allowed_ips.iter().map(|i| i.to_string()).collect()
175            });
176        }
177
178        // Check for localhost
179        if self.block_localhost && is_localhost_ip(ip) {
180            bail!(NetworkSecurityError::LocalhostAccessDenied {
181                domain: ip.to_string()
182            });
183        }
184
185        // Check for private IPs
186        if self.block_private_ips && is_private_ip(ip) {
187            bail!(NetworkSecurityError::PrivateIpAccessDenied { ip: ip.to_string() });
188        }
189
190        Ok(())
191    }
192
193    /// Check if a domain is in the allowed list
194    fn is_domain_allowed(&self, domain: &str) -> bool {
195        if self.allowed_domains.is_empty() {
196            // No whitelist = deny all
197            return false;
198        }
199
200        // Exact match
201        if self.allowed_domains.contains(&domain.to_string()) {
202            return true;
203        }
204
205        // Subdomain match (e.g., "api.example.com" matches "example.com")
206        for allowed in &self.allowed_domains {
207            if domain.ends_with(&format!(".{}", allowed)) {
208                return true;
209            }
210        }
211
212        false
213    }
214}
215
/// Check whether a host name or textual IP address refers to the local host.
///
/// The comparison is case-insensitive and accepts:
/// - `localhost` and `localhost.localdomain`;
/// - every name under the special-use `.localhost` domain (RFC 6761 §6.3);
/// - the fully-qualified forms of these names, with one trailing dot;
/// - IP literals, optionally in URL bracket form such as `[::1]`, that are loopback
///   (`127.0.0.0/8`, `::1`), unspecified (`0.0.0.0`, `::`), or the IPv4-mapped form of either.
fn is_localhost(domain: &str) -> bool {
    let lowered = domain.to_lowercase();
    let name = lowered.strip_suffix('.').unwrap_or(&lowered);
    let name = name
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(name);

    if matches!(name, "localhost" | "localhost.localdomain") {
        return true;
    }

    // Every `.localhost` name is reserved for loopback, so a whitelist entry of
    // "localhost" (which also admits its subdomains) must not let "api.localhost" through.
    if name.ends_with(".localhost") {
        return true;
    }

    match name.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_loopback() || v4.is_unspecified(),
        Ok(IpAddr::V6(v6)) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || matches!(
                    v6.to_ipv4_mapped(),
                    Some(v4) if v4.is_loopback() || v4.is_unspecified()
                )
        }
        Err(_) => false,
    }
}
223
/// Check whether an IP address refers to the local host.
///
/// Matches the following:
/// - Loopback addresses (`127.0.0.0/8`, `::1`).
/// - The unspecified addresses (`0.0.0.0`, `::`). Used as a destination, these typically
///   reach the local host, and `is_localhost` already lists "0.0.0.0" as a localhost name.
/// - The IPv4-mapped form of either kind, e.g. `::ffff:127.0.0.1`.
fn is_localhost_ip(ip: &IpAddr) -> bool {
    let is_local_v4 = |v4: &Ipv4Addr| v4.is_loopback() || v4.is_unspecified();
    match ip {
        IpAddr::V4(v4) => is_local_v4(v4),
        IpAddr::V6(v6) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || matches!(v6.to_ipv4_mapped(), Some(v4) if is_local_v4(&v4))
        }
    }
}
231
232/// Check if an IP address is private (RFC1918)
233fn is_private_ip(ip: &IpAddr) -> bool {
234    match ip {
235        IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
236        IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
237    }
238}
239
/// Check whether an IPv4 address lies outside publicly routable unicast space.
///
/// Covers the RFC 1918 private ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) and other
/// special-purpose ranges a sandbox should not reach:
/// - "this network" 0.0.0.0/8, which includes the unspecified address;
/// - loopback 127.0.0.0/8;
/// - link-local 169.254.0.0/16;
/// - shared address space 100.64.0.0/10 (RFC 6598);
/// - the documentation ranges;
/// - multicast 224.0.0.0/4;
/// - the limited broadcast address.
fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();

    // Shared address space (carrier-grade NAT), 100.64.0.0/10. `Ipv4Addr::is_shared`
    // is not stable, so test the 10-bit prefix by hand.
    let is_shared = first == 100 && (second & 0b1100_0000) == 0b0100_0000;

    // "This network" 0.0.0.0/8, which includes the unspecified address 0.0.0.0.
    let is_this_network = first == 0;

    ip.is_private()
        || ip.is_loopback()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_documentation()
        || ip.is_multicast()
        || is_shared
        || is_this_network
}
252
253/// Check if IPv6 address is private
254fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
255    // Check loopback (::1)
256    if ip.is_loopback() {
257        return true;
258    }
259
260    // Check unique local addresses (fc00::/7)
261    // This is equivalent to is_unique_local() which requires MSRV 1.84.0
262    let segments = ip.segments();
263    if (segments[0] & 0xfe00) == 0xfc00 {
264        return true;
265    }
266
267    // Check multicast
268    if ip.is_multicast() {
269        return true;
270    }
271
272    false
273}
274
275/// Network security errors
276#[derive(Debug, thiserror::Error)]
277pub enum NetworkSecurityError {
278    #[error("Network access denied: {reason}")]
279    NetworkAccessDenied { reason: String },
280
281    #[error("Invalid URL: {url} - {reason}")]
282    InvalidUrl { url: String, reason: String },
283
284    #[error("HTTP not allowed (HTTPS only): {url}")]
285    HttpNotAllowed { url: String },
286
287    #[error("Unsupported protocol: {protocol} in URL: {url}")]
288    UnsupportedProtocol { protocol: String, url: String },
289
290    #[error("Domain not in whitelist: {domain} (allowed: {allowed_domains:?})")]
291    DomainNotInWhitelist {
292        domain: String,
293        allowed_domains: Vec<String>,
294    },
295
296    #[error("IP not in whitelist: {ip} (allowed: {allowed_ips:?})")]
297    IpNotInWhitelist {
298        ip: String,
299        allowed_ips: Vec<String>,
300    },
301
302    #[error("Localhost access denied: {domain}")]
303    LocalhostAccessDenied { domain: String },
304
305    #[error("Private IP access denied: {ip}")]
306    PrivateIpAccessDenied { ip: String },
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwrap a failed validation into its typed `NetworkSecurityError`, so tests can
    /// assert which check rejected the input rather than merely that something failed.
    fn security_error(result: Result<()>) -> NetworkSecurityError {
        result
            .expect_err("expected validation to fail")
            .downcast::<NetworkSecurityError>()
            .expect("validation errors should be NetworkSecurityError values")
    }

    #[test]
    fn test_deny_all() {
        let restrictions = NetworkRestrictions::deny_all();
        let err = security_error(restrictions.validate_url("https://example.com"));
        assert!(
            matches!(err, NetworkSecurityError::NetworkAccessDenied { .. }),
            "{}",
            err
        );
    }

    #[test]
    fn test_https_only() {
        let restrictions = NetworkRestrictions::allow_domains(vec!["example.com".to_string()]);
        assert!(restrictions.validate_url("https://example.com").is_ok());
        let err = security_error(restrictions.validate_url("http://example.com"));
        assert!(
            matches!(err, NetworkSecurityError::HttpNotAllowed { .. }),
            "{}",
            err
        );
    }

    #[test]
    fn test_unsupported_scheme() {
        let restrictions = NetworkRestrictions::allow_domains(vec!["example.com".to_string()]);
        let err = security_error(restrictions.validate_url("ftp://example.com"));
        assert!(
            matches!(err, NetworkSecurityError::UnsupportedProtocol { .. }),
            "{}",
            err
        );
    }

    #[test]
    fn test_domain_whitelist() {
        let restrictions = NetworkRestrictions::allow_domains(vec!["example.com".to_string()]);

        assert!(restrictions.validate_url("https://example.com").is_ok());
        assert!(restrictions.validate_url("https://api.example.com").is_ok());
        let err = security_error(restrictions.validate_url("https://evil.com"));
        assert!(
            matches!(err, NetworkSecurityError::DomainNotInWhitelist { .. }),
            "{}",
            err
        );
    }

    #[test]
    fn test_localhost_blocking() {
        let mut restrictions = NetworkRestrictions::allow_domains(vec!["localhost".to_string()]);
        restrictions.block_localhost = true;

        // "localhost" passes the whitelist, so the localhost check must be what rejects it.
        let err = security_error(restrictions.validate_domain("localhost"));
        assert!(
            matches!(err, NetworkSecurityError::LocalhostAccessDenied { .. }),
            "{}",
            err
        );
        assert!(restrictions.validate_domain("127.0.0.1").is_err());
    }

    #[test]
    fn test_private_ip_blocking() {
        let private_ips: Vec<IpAddr> = ["10.0.0.1", "172.16.0.1", "192.168.1.1", "169.254.1.1"]
            .iter()
            .map(|ip| ip.parse().unwrap())
            .collect();

        // Whitelist the addresses so that the private-range check, not the IP whitelist,
        // is what rejects them.
        let restrictions = NetworkRestrictions {
            block_all: false,
            allowed_ips: private_ips.clone(),
            block_private_ips: true,
            ..Default::default()
        };

        for ip in &private_ips {
            let err = security_error(restrictions.validate_ip(ip));
            assert!(
                matches!(err, NetworkSecurityError::PrivateIpAccessDenied { .. }),
                "{}: {}",
                ip,
                err
            );
        }
    }

    #[test]
    fn test_loopback_ip_reports_localhost() {
        let loopback: IpAddr = "127.0.0.1".parse().unwrap();
        let restrictions = NetworkRestrictions {
            block_all: false,
            allowed_ips: vec![loopback],
            ..Default::default()
        };

        let err = security_error(restrictions.validate_ip(&loopback));
        assert!(
            matches!(err, NetworkSecurityError::LocalhostAccessDenied { .. }),
            "{}",
            err
        );
    }

    #[test]
    fn test_public_ip_allowed() {
        let restrictions = NetworkRestrictions {
            block_all: false,
            allowed_ips: vec!["8.8.8.8".parse().unwrap()],
            block_private_ips: true,
            ..Default::default()
        };

        let ip: IpAddr = "8.8.8.8".parse().unwrap();
        assert!(restrictions.validate_ip(&ip).is_ok());
    }

    #[test]
    fn test_is_localhost() {
        assert!(is_localhost("localhost"));
        assert!(is_localhost("LOCALHOST"));
        assert!(is_localhost("127.0.0.1"));
        assert!(!is_localhost("example.com"));
    }

    #[test]
    fn test_is_private_ipv4() {
        let private = Ipv4Addr::new(192, 168, 1, 1);
        assert!(is_private_ipv4(&private));

        let public = Ipv4Addr::new(8, 8, 8, 8);
        assert!(!is_private_ipv4(&public));
    }

    #[test]
    fn test_subdomain_matching() {
        let restrictions = NetworkRestrictions::allow_domains(vec!["example.com".to_string()]);

        assert!(restrictions.is_domain_allowed("example.com"));
        assert!(restrictions.is_domain_allowed("api.example.com"));
        assert!(restrictions.is_domain_allowed("foo.bar.example.com"));
        assert!(!restrictions.is_domain_allowed("examplecom"));
        assert!(!restrictions.is_domain_allowed("evil.com"));
    }
}