// agentzero_core/common/url_policy.rs

1use anyhow::anyhow;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
3use url::Url;
4
5/// URL access policy enforcement.
6///
7/// Checks URLs against private IP blocking, domain allowlists/blocklists,
8/// CIDR ranges, and DNS rebinding protection before allowing network access.
9#[derive(Debug, Clone)]
10pub struct UrlAccessPolicy {
11    pub block_private_ip: bool,
12    pub allow_loopback: bool,
13    pub allow_cidrs: Vec<CidrRange>,
14    pub allow_domains: Vec<String>,
15    pub enforce_domain_allowlist: bool,
16    pub domain_allowlist: Vec<String>,
17    pub domain_blocklist: Vec<String>,
18    pub approved_domains: Vec<String>,
19    /// When true, domains not in `approved_domains` or `allow_domains` require
20    /// explicit first-visit approval from the user before access is granted.
21    pub require_first_visit_approval: bool,
22}
23
24impl Default for UrlAccessPolicy {
25    fn default() -> Self {
26        Self {
27            block_private_ip: true,
28            allow_loopback: false,
29            allow_cidrs: Vec::new(),
30            allow_domains: Vec::new(),
31            enforce_domain_allowlist: false,
32            domain_allowlist: Vec::new(),
33            domain_blocklist: Vec::new(),
34            approved_domains: Vec::new(),
35            require_first_visit_approval: false,
36        }
37    }
38}
39
/// A parsed CIDR range for IP matching.
///
/// Equality and hashing are structural: two values are equal only when both
/// the stored `network` address and `prefix_len` are identical, so
/// `10.1.0.0/8` and `10.0.0.0/8` compare as different values even though they
/// describe the same set of addresses.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CidrRange {
    /// Base address of the range; its family (v4/v6) decides which addresses
    /// can match.
    pub network: IpAddr,
    /// Number of leading bits of `network` that an address must share.
    pub prefix_len: u8,
}
46
47impl CidrRange {
48    /// Parse a CIDR string like "10.0.0.0/8" or "::1/128".
49    pub fn parse(s: &str) -> anyhow::Result<Self> {
50        let parts: Vec<&str> = s.split('/').collect();
51        if parts.len() != 2 {
52            return Err(anyhow!("invalid CIDR notation: {s}"));
53        }
54        let network: IpAddr = parts[0]
55            .parse()
56            .map_err(|_| anyhow!("invalid IP in CIDR: {}", parts[0]))?;
57        let prefix_len: u8 = parts[1]
58            .parse()
59            .map_err(|_| anyhow!("invalid prefix length in CIDR: {}", parts[1]))?;
60        let max_prefix = match network {
61            IpAddr::V4(_) => 32,
62            IpAddr::V6(_) => 128,
63        };
64        if prefix_len > max_prefix {
65            return Err(anyhow!(
66                "prefix length {prefix_len} exceeds maximum {max_prefix}"
67            ));
68        }
69        Ok(Self {
70            network,
71            prefix_len,
72        })
73    }
74
75    /// Check if an IP address falls within this CIDR range.
76    pub fn contains(&self, ip: &IpAddr) -> bool {
77        match (&self.network, ip) {
78            (IpAddr::V4(net), IpAddr::V4(addr)) => {
79                let net_bits = u32::from(*net);
80                let addr_bits = u32::from(*addr);
81                if self.prefix_len == 0 {
82                    return true;
83                }
84                let mask = u32::MAX << (32 - self.prefix_len);
85                (net_bits & mask) == (addr_bits & mask)
86            }
87            (IpAddr::V6(net), IpAddr::V6(addr)) => {
88                let net_bits = u128::from(*net);
89                let addr_bits = u128::from(*addr);
90                if self.prefix_len == 0 {
91                    return true;
92                }
93                let mask = u128::MAX << (128 - self.prefix_len);
94                (net_bits & mask) == (addr_bits & mask)
95            }
96            _ => false, // v4 vs v6 mismatch
97        }
98    }
99}
100
/// Outcome of evaluating a URL against a `UrlAccessPolicy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UrlPolicyResult {
    /// Access may proceed.
    Allowed,
    /// Access is pending: the user must approve a first visit to `domain`.
    RequiresApproval { domain: String },
    /// Access is denied; `reason` carries a human-readable explanation.
    Blocked { reason: String },
}
111
112/// Enforce URL access policy on a parsed URL.
113///
114/// This checks: domain blocklist, private IP blocking, domain allowlist,
115/// and first-visit approval requirements.
116pub fn enforce_url_policy(url: &Url, policy: &UrlAccessPolicy) -> UrlPolicyResult {
117    let host = match url.host_str() {
118        Some(h) => h.to_lowercase(),
119        None => {
120            return UrlPolicyResult::Blocked {
121                reason: "URL has no host".to_string(),
122            }
123        }
124    };
125
126    // 1. Check domain blocklist first (always applies)
127    if is_domain_blocked(&host, &policy.domain_blocklist) {
128        return UrlPolicyResult::Blocked {
129            reason: format!("domain `{host}` is in the blocklist"),
130        };
131    }
132
133    // 2. Check explicit allow_domains (bypass all other checks)
134    if is_domain_allowed(&host, &policy.allow_domains) {
135        return UrlPolicyResult::Allowed;
136    }
137
138    // 3. Check approved_domains
139    if is_domain_allowed(&host, &policy.approved_domains) {
140        return UrlPolicyResult::Allowed;
141    }
142
143    // 4. Check private IP blocking
144    if policy.block_private_ip {
145        match check_private_ip(&host, policy) {
146            PrivateIpResult::NotPrivate => {}
147            PrivateIpResult::AllowedLoopback => {}
148            PrivateIpResult::AllowedByCidr => {}
149            PrivateIpResult::Blocked(reason) => {
150                return UrlPolicyResult::Blocked { reason };
151            }
152            PrivateIpResult::DnsRebindingRisk(reason) => {
153                return UrlPolicyResult::Blocked { reason };
154            }
155        }
156    }
157
158    // 5. Check domain allowlist enforcement
159    if policy.enforce_domain_allowlist && !is_domain_allowed(&host, &policy.domain_allowlist) {
160        return UrlPolicyResult::Blocked {
161            reason: format!("domain `{host}` is not in the allowlist"),
162        };
163    }
164
165    // 6. Check first-visit approval requirement
166    if policy.require_first_visit_approval {
167        return UrlPolicyResult::RequiresApproval {
168            domain: host.to_string(),
169        };
170    }
171
172    UrlPolicyResult::Allowed
173}
174
/// Check if a domain matches any entry in a domain list.
///
/// Supports exact match and subdomain matching (e.g., "api.example.com"
/// matches "example.com"). Entries are compared case-insensitively; `host` is
/// compared as given.
///
/// A single trailing dot on either the host or an entry is ignored, so the
/// fully-qualified spelling `example.com.` is treated the same as
/// `example.com`. Empty entries never match.
fn is_domain_allowed(host: &str, domains: &[String]) -> bool {
    // "evil.com." names the same host as "evil.com"; without this
    // normalisation the dotted form would slip past a list entry.
    let host = host.strip_suffix('.').unwrap_or(host);

    domains.iter().any(|entry| {
        let lowered = entry.to_lowercase();
        let suffix = lowered.strip_suffix('.').unwrap_or(lowered.as_str());
        if suffix.is_empty() {
            return false;
        }
        // Match when the host equals the entry, or when the entry is preceded
        // by a label separator (so "notevil.com" does not match "evil.com").
        match host.strip_suffix(suffix) {
            Some(leading) => leading.is_empty() || leading.ends_with('.'),
            None => false,
        }
    })
}
183
184fn is_domain_blocked(host: &str, blocklist: &[String]) -> bool {
185    is_domain_allowed(host, blocklist)
186}
187
/// Internal verdict of the loopback/private-address checks.
enum PrivateIpResult {
    /// The address is neither loopback nor private.
    NotPrivate,
    /// The address is loopback and the policy permits loopback.
    AllowedLoopback,
    /// The address falls inside one of the policy's allowed CIDR ranges.
    AllowedByCidr,
    /// The address is rejected; the payload is the explanation.
    Blocked(String),
    /// A domain name resolved to a rejected address; the payload is the
    /// explanation.
    DnsRebindingRisk(String),
}
195
196fn check_private_ip(host: &str, policy: &UrlAccessPolicy) -> PrivateIpResult {
197    // Try to parse as IP literal first
198    if let Ok(ip) = host.parse::<IpAddr>() {
199        return check_ip_address(&ip, policy);
200    }
201
202    // Resolve domain to IP addresses for DNS rebinding protection
203    let socket_addr = format!("{host}:80");
204    match socket_addr.to_socket_addrs() {
205        Ok(addrs) => {
206            for addr in addrs {
207                let ip = addr.ip();
208                match check_ip_address(&ip, policy) {
209                    PrivateIpResult::NotPrivate
210                    | PrivateIpResult::AllowedLoopback
211                    | PrivateIpResult::AllowedByCidr => continue,
212                    PrivateIpResult::Blocked(_) => {
213                        return PrivateIpResult::DnsRebindingRisk(format!(
214                            "domain `{host}` resolves to private IP {ip}; possible DNS rebinding"
215                        ));
216                    }
217                    PrivateIpResult::DnsRebindingRisk(r) => {
218                        return PrivateIpResult::DnsRebindingRisk(r)
219                    }
220                }
221            }
222            PrivateIpResult::NotPrivate
223        }
224        Err(_) => {
225            // DNS resolution failed — not necessarily a security issue,
226            // let the HTTP client handle connectivity errors
227            PrivateIpResult::NotPrivate
228        }
229    }
230}
231
232fn check_ip_address(ip: &IpAddr, policy: &UrlAccessPolicy) -> PrivateIpResult {
233    // Check if IP is in explicitly allowed CIDRs first
234    for cidr in &policy.allow_cidrs {
235        if cidr.contains(ip) {
236            return PrivateIpResult::AllowedByCidr;
237        }
238    }
239
240    if ip.is_loopback() {
241        if policy.allow_loopback {
242            return PrivateIpResult::AllowedLoopback;
243        }
244        return PrivateIpResult::Blocked(format!("loopback address {ip} is blocked"));
245    }
246
247    if is_private_ip(ip) {
248        return PrivateIpResult::Blocked(format!("private IP {ip} is blocked"));
249    }
250
251    PrivateIpResult::NotPrivate
252}
253
254/// Check if an IP address is in a private/reserved range.
255fn is_private_ip(ip: &IpAddr) -> bool {
256    match ip {
257        IpAddr::V4(v4) => is_private_ipv4(v4),
258        IpAddr::V6(v6) => is_private_ipv6(v6),
259    }
260}
261
/// Report whether an IPv4 address lies in one of the private/reserved ranges
/// recognised by this module.
///
/// The 127.0.0.0/8 loopback block is not part of this list.
fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        // 10.0.0.0/8
        [10, ..]
        // 172.16.0.0/12
        | [172, 16..=31, ..]
        // 192.168.0.0/16
        | [192, 168, ..]
        // 169.254.0.0/16 (link-local)
        | [169, 254, ..]
        // 100.64.0.0/10 (carrier-grade NAT)
        | [100, 64..=127, ..]
        // 0.0.0.0/8
        | [0, ..]
        // 240.0.0.0/4 (reserved)
        | [240..=255, ..]
    )
}
294
295fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
296    let segments = ip.segments();
297    // ::1 (loopback — handled separately)
298    // :: (unspecified)
299    if ip.is_unspecified() {
300        return true;
301    }
302    // fc00::/7 (unique local)
303    if (segments[0] & 0xfe00) == 0xfc00 {
304        return true;
305    }
306    // fe80::/10 (link-local)
307    if (segments[0] & 0xffc0) == 0xfe80 {
308        return true;
309    }
310    // ::ffff:0:0/96 (IPv4-mapped — check the embedded v4 address)
311    if segments[0..5] == [0, 0, 0, 0, 0] && segments[5] == 0xffff {
312        let v4 = Ipv4Addr::new(
313            (segments[6] >> 8) as u8,
314            segments[6] as u8,
315            (segments[7] >> 8) as u8,
316            segments[7] as u8,
317        );
318        return is_private_ipv4(&v4);
319    }
320    false
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal, panicking on malformed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Parses `raw` as a URL and evaluates it against `policy`.
    fn evaluate(raw: &str, policy: &UrlAccessPolicy) -> UrlPolicyResult {
        let url = Url::parse(raw).unwrap();
        enforce_url_policy(&url, policy)
    }

    /// True when evaluating `raw` against `policy` yields a `Blocked` result.
    fn is_blocked(raw: &str, policy: &UrlAccessPolicy) -> bool {
        matches!(evaluate(raw, policy), UrlPolicyResult::Blocked { .. })
    }

    // ---- CidrRange ----

    #[test]
    fn cidr_parse_valid() {
        let parsed = CidrRange::parse("10.0.0.0/8").unwrap();
        assert_eq!(parsed.prefix_len, 8);
    }

    #[test]
    fn cidr_parse_invalid() {
        for bad in ["not-a-cidr", "10.0.0.0/33"] {
            assert!(CidrRange::parse(bad).is_err());
        }
    }

    #[test]
    fn cidr_contains_ipv4() {
        let range = CidrRange::parse("192.168.1.0/24").unwrap();
        assert!(range.contains(&ip("192.168.1.100")));
        assert!(!range.contains(&ip("192.168.2.1")));
    }

    #[test]
    fn cidr_contains_ipv6() {
        let range = CidrRange::parse("fc00::/7").unwrap();
        assert!(range.contains(&ip("fd12::1")));
        assert!(!range.contains(&ip("2001:db8::1")));
    }

    // ---- private-range classification ----

    #[test]
    fn private_ipv4_ranges() {
        let private = [
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.0.1",
            "169.254.1.1",
            "100.64.0.1",
            "0.0.0.0",
        ];
        for addr in private {
            assert!(is_private_ip(&ip(addr)));
        }
        for addr in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(&ip(addr)));
        }
    }

    #[test]
    fn private_ipv6_ranges() {
        for addr in ["fc00::1", "fd12:3456::1", "fe80::1", "::"] {
            assert!(is_private_ip(&ip(addr)));
        }
        assert!(!is_private_ip(&ip("2001:db8::1")));
    }

    #[test]
    fn ipv4_mapped_ipv6_private() {
        // Embedded private IPv4 address.
        assert!(is_private_ip(&ip("::ffff:192.168.1.1")));
        // Embedded public IPv4 address.
        assert!(!is_private_ip(&ip("::ffff:8.8.8.8")));
    }

    // ---- policy enforcement: IP literals ----

    #[test]
    fn policy_blocks_private_ip_literal() {
        let policy = UrlAccessPolicy::default();
        assert!(is_blocked("http://192.168.1.1/api", &policy));
    }

    #[test]
    fn policy_allows_public_ip() {
        let policy = UrlAccessPolicy::default();
        assert_eq!(
            evaluate("https://8.8.8.8/dns-query", &policy),
            UrlPolicyResult::Allowed
        );
    }

    #[test]
    fn policy_blocks_loopback_by_default() {
        let policy = UrlAccessPolicy::default();
        assert!(is_blocked("http://127.0.0.1:8080", &policy));
    }

    #[test]
    fn policy_allows_loopback_when_configured() {
        let policy = UrlAccessPolicy {
            allow_loopback: true,
            ..Default::default()
        };
        assert_eq!(
            evaluate("http://127.0.0.1:8080", &policy),
            UrlPolicyResult::Allowed
        );
    }

    #[test]
    fn policy_allow_cidrs_exempts_private_ip() {
        let policy = UrlAccessPolicy {
            allow_cidrs: vec![CidrRange::parse("10.0.0.0/8").unwrap()],
            ..Default::default()
        };
        assert_eq!(
            evaluate("http://10.1.2.3/api", &policy),
            UrlPolicyResult::Allowed
        );
    }

    // ---- policy enforcement: domain lists ----

    #[test]
    fn policy_domain_blocklist() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["evil.com".to_string()],
            ..Default::default()
        };
        assert!(is_blocked("https://evil.com/phish", &policy));
    }

    #[test]
    fn policy_domain_blocklist_subdomain() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["evil.com".to_string()],
            ..Default::default()
        };
        assert!(is_blocked("https://api.evil.com/data", &policy));
    }

    #[test]
    fn policy_domain_allowlist_enforced() {
        // Note: neither host is pre-cleared, so both go through the
        // private-IP check, which attempts a name lookup.
        let policy = UrlAccessPolicy {
            enforce_domain_allowlist: true,
            domain_allowlist: vec!["api.example.com".to_string()],
            ..Default::default()
        };

        assert_eq!(
            evaluate("https://api.example.com/v1", &policy),
            UrlPolicyResult::Allowed
        );
        assert!(is_blocked("https://other.com/v1", &policy));
    }

    #[test]
    fn policy_allow_domains_bypass_private_ip_check() {
        let policy = UrlAccessPolicy {
            allow_domains: vec!["internal.corp".to_string()],
            ..Default::default()
        };
        // Whatever the name would resolve to, allow_domains decides first.
        assert_eq!(
            evaluate("http://internal.corp/api", &policy),
            UrlPolicyResult::Allowed
        );
    }

    #[test]
    fn policy_no_host_blocked() {
        let policy = UrlAccessPolicy::default();
        assert!(is_blocked("file:///etc/passwd", &policy));
    }

    #[test]
    fn policy_approved_domains_allowed() {
        let policy = UrlAccessPolicy {
            approved_domains: vec!["trusted.io".to_string()],
            ..Default::default()
        };
        assert_eq!(
            evaluate("https://trusted.io/data", &policy),
            UrlPolicyResult::Allowed
        );
    }

    #[test]
    fn default_policy_allows_public_domains() {
        // Note: goes through the private-IP check, which attempts a name lookup.
        let policy = UrlAccessPolicy::default();
        assert_eq!(
            evaluate("https://api.github.com/repos", &policy),
            UrlPolicyResult::Allowed
        );
    }

    #[test]
    fn domain_matching_case_insensitive() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["Evil.Com".to_string()],
            ..Default::default()
        };
        assert!(is_blocked("https://evil.com/path", &policy));
    }
}