1use anyhow::anyhow;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
3use url::Url;
4
5#[derive(Debug, Clone)]
10pub struct UrlAccessPolicy {
11 pub block_private_ip: bool,
12 pub allow_loopback: bool,
13 pub allow_cidrs: Vec<CidrRange>,
14 pub allow_domains: Vec<String>,
15 pub enforce_domain_allowlist: bool,
16 pub domain_allowlist: Vec<String>,
17 pub domain_blocklist: Vec<String>,
18 pub approved_domains: Vec<String>,
19 pub require_first_visit_approval: bool,
22}
23
24impl Default for UrlAccessPolicy {
25 fn default() -> Self {
26 Self {
27 block_private_ip: true,
28 allow_loopback: false,
29 allow_cidrs: Vec::new(),
30 allow_domains: Vec::new(),
31 enforce_domain_allowlist: false,
32 domain_allowlist: Vec::new(),
33 domain_blocklist: Vec::new(),
34 approved_domains: Vec::new(),
35 require_first_visit_approval: false,
36 }
37 }
38}
39
/// An IP network written as a base address plus a prefix length (CIDR notation).
///
/// Works for both IPv4 and IPv6; the address family is the family of `network`.
#[derive(Debug, Clone)]
pub struct CidrRange {
    /// Base address of the range. Bits beyond `prefix_len` are ignored when matching.
    pub network: IpAddr,
    /// Number of leading bits an address must share with `network` to be in the range.
    pub prefix_len: u8,
}
46
47impl CidrRange {
48 pub fn parse(s: &str) -> anyhow::Result<Self> {
50 let parts: Vec<&str> = s.split('/').collect();
51 if parts.len() != 2 {
52 return Err(anyhow!("invalid CIDR notation: {s}"));
53 }
54 let network: IpAddr = parts[0]
55 .parse()
56 .map_err(|_| anyhow!("invalid IP in CIDR: {}", parts[0]))?;
57 let prefix_len: u8 = parts[1]
58 .parse()
59 .map_err(|_| anyhow!("invalid prefix length in CIDR: {}", parts[1]))?;
60 let max_prefix = match network {
61 IpAddr::V4(_) => 32,
62 IpAddr::V6(_) => 128,
63 };
64 if prefix_len > max_prefix {
65 return Err(anyhow!(
66 "prefix length {prefix_len} exceeds maximum {max_prefix}"
67 ));
68 }
69 Ok(Self {
70 network,
71 prefix_len,
72 })
73 }
74
75 pub fn contains(&self, ip: &IpAddr) -> bool {
77 match (&self.network, ip) {
78 (IpAddr::V4(net), IpAddr::V4(addr)) => {
79 let net_bits = u32::from(*net);
80 let addr_bits = u32::from(*addr);
81 if self.prefix_len == 0 {
82 return true;
83 }
84 let mask = u32::MAX << (32 - self.prefix_len);
85 (net_bits & mask) == (addr_bits & mask)
86 }
87 (IpAddr::V6(net), IpAddr::V6(addr)) => {
88 let net_bits = u128::from(*net);
89 let addr_bits = u128::from(*addr);
90 if self.prefix_len == 0 {
91 return true;
92 }
93 let mask = u128::MAX << (128 - self.prefix_len);
94 (net_bits & mask) == (addr_bits & mask)
95 }
96 _ => false, }
98 }
99}
100
/// Verdict produced by [`enforce_url_policy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UrlPolicyResult {
    /// The URL passed every check and may be accessed.
    Allowed,
    /// The URL passed every check, but the policy asks for approval of `domain`
    /// (the lower-cased URL host) before it is accessed.
    RequiresApproval { domain: String },
    /// The URL was rejected; `reason` is a human-readable explanation.
    Blocked { reason: String },
}
111
112pub fn enforce_url_policy(url: &Url, policy: &UrlAccessPolicy) -> UrlPolicyResult {
117 let host = match url.host_str() {
118 Some(h) => h.to_lowercase(),
119 None => {
120 return UrlPolicyResult::Blocked {
121 reason: "URL has no host".to_string(),
122 }
123 }
124 };
125
126 if is_domain_blocked(&host, &policy.domain_blocklist) {
128 return UrlPolicyResult::Blocked {
129 reason: format!("domain `{host}` is in the blocklist"),
130 };
131 }
132
133 if is_domain_allowed(&host, &policy.allow_domains) {
135 return UrlPolicyResult::Allowed;
136 }
137
138 if is_domain_allowed(&host, &policy.approved_domains) {
140 return UrlPolicyResult::Allowed;
141 }
142
143 if policy.block_private_ip {
145 match check_private_ip(&host, policy) {
146 PrivateIpResult::NotPrivate => {}
147 PrivateIpResult::AllowedLoopback => {}
148 PrivateIpResult::AllowedByCidr => {}
149 PrivateIpResult::Blocked(reason) => {
150 return UrlPolicyResult::Blocked { reason };
151 }
152 PrivateIpResult::DnsRebindingRisk(reason) => {
153 return UrlPolicyResult::Blocked { reason };
154 }
155 }
156 }
157
158 if policy.enforce_domain_allowlist && !is_domain_allowed(&host, &policy.domain_allowlist) {
160 return UrlPolicyResult::Blocked {
161 reason: format!("domain `{host}` is not in the allowlist"),
162 };
163 }
164
165 if policy.require_first_visit_approval {
167 return UrlPolicyResult::RequiresApproval {
168 domain: host.to_string(),
169 };
170 }
171
172 UrlPolicyResult::Allowed
173}
174
/// Reports whether `host` equals, or is a subdomain of, any entry in `domains`.
///
/// Matching is case-insensitive and ignores a trailing root dot on either side,
/// so `evil.com.` is treated the same as `evil.com` (both name the same DNS
/// host). Subdomain matching is label-aligned: `api.evil.com` matches
/// `evil.com`, `notevil.com` does not. Empty entries never match.
fn is_domain_allowed(host: &str, domains: &[String]) -> bool {
    // Normalise once; the fully-qualified form with a trailing dot must not
    // slip past list entries written without it.
    let host = host.trim_end_matches('.').to_lowercase();
    if host.is_empty() {
        return false;
    }

    domains.iter().any(|entry| {
        let pattern = entry.trim_end_matches('.').to_lowercase();
        if pattern.is_empty() {
            return false;
        }
        match host.strip_suffix(pattern.as_str()) {
            // Either an exact match, or the remainder ends on a label boundary.
            Some(rest) => rest.is_empty() || rest.ends_with('.'),
            None => false,
        }
    })
}
183
184fn is_domain_blocked(host: &str, blocklist: &[String]) -> bool {
185 is_domain_allowed(host, blocklist)
186}
187
/// Outcome of screening a host or address for loopback/private targets.
enum PrivateIpResult {
    /// Nothing objectionable was found.
    NotPrivate,
    /// The address is loopback and the policy permits loopback.
    AllowedLoopback,
    /// The address falls inside one of the policy's allowed CIDR ranges.
    AllowedByCidr,
    /// The address is a blocked loopback/private address; carries the reason.
    Blocked(String),
    /// A host name resolved to a blocked address; carries the reason.
    DnsRebindingRisk(String),
}
195
196fn check_private_ip(host: &str, policy: &UrlAccessPolicy) -> PrivateIpResult {
197 if let Ok(ip) = host.parse::<IpAddr>() {
199 return check_ip_address(&ip, policy);
200 }
201
202 let socket_addr = format!("{host}:80");
204 match socket_addr.to_socket_addrs() {
205 Ok(addrs) => {
206 for addr in addrs {
207 let ip = addr.ip();
208 match check_ip_address(&ip, policy) {
209 PrivateIpResult::NotPrivate
210 | PrivateIpResult::AllowedLoopback
211 | PrivateIpResult::AllowedByCidr => continue,
212 PrivateIpResult::Blocked(_) => {
213 return PrivateIpResult::DnsRebindingRisk(format!(
214 "domain `{host}` resolves to private IP {ip}; possible DNS rebinding"
215 ));
216 }
217 PrivateIpResult::DnsRebindingRisk(r) => {
218 return PrivateIpResult::DnsRebindingRisk(r)
219 }
220 }
221 }
222 PrivateIpResult::NotPrivate
223 }
224 Err(_) => {
225 PrivateIpResult::NotPrivate
228 }
229 }
230}
231
232fn check_ip_address(ip: &IpAddr, policy: &UrlAccessPolicy) -> PrivateIpResult {
233 for cidr in &policy.allow_cidrs {
235 if cidr.contains(ip) {
236 return PrivateIpResult::AllowedByCidr;
237 }
238 }
239
240 if ip.is_loopback() {
241 if policy.allow_loopback {
242 return PrivateIpResult::AllowedLoopback;
243 }
244 return PrivateIpResult::Blocked(format!("loopback address {ip} is blocked"));
245 }
246
247 if is_private_ip(ip) {
248 return PrivateIpResult::Blocked(format!("private IP {ip} is blocked"));
249 }
250
251 PrivateIpResult::NotPrivate
252}
253
254fn is_private_ip(ip: &IpAddr) -> bool {
256 match ip {
257 IpAddr::V4(v4) => is_private_ipv4(v4),
258 IpAddr::V6(v6) => is_private_ipv6(v6),
259 }
260}
261
/// Reports whether an IPv4 address lies in one of the non-public ranges this
/// module blocks.
///
/// Loopback (`127.0.0.0/8`) is deliberately not part of this list.
fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        // 10.0.0.0/8
        [10, ..]
            // 172.16.0.0/12
            | [172, 16..=31, ..]
            // 192.168.0.0/16
            | [192, 168, ..]
            // 169.254.0.0/16
            | [169, 254, ..]
            // 100.64.0.0/10
            | [100, 64..=127, ..]
            // 0.0.0.0/8
            | [0, ..]
            // 240.0.0.0/4
            | [240..=255, ..]
    )
}
294
295fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
296 let segments = ip.segments();
297 if ip.is_unspecified() {
300 return true;
301 }
302 if (segments[0] & 0xfe00) == 0xfc00 {
304 return true;
305 }
306 if (segments[0] & 0xffc0) == 0xfe80 {
308 return true;
309 }
310 if segments[0..5] == [0, 0, 0, 0, 0] && segments[5] == 0xffff {
312 let v4 = Ipv4Addr::new(
313 (segments[6] >> 8) as u8,
314 segments[6] as u8,
315 (segments[7] >> 8) as u8,
316 segments[7] as u8,
317 );
318 return is_private_ipv4(&v4);
319 }
320 false
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal, panicking on malformed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Parses `raw` as a URL and evaluates it against `policy`.
    fn evaluate(raw: &str, policy: &UrlAccessPolicy) -> UrlPolicyResult {
        let url = Url::parse(raw).unwrap();
        enforce_url_policy(&url, policy)
    }

    /// Asserts that `raw` is rejected by `policy`.
    fn assert_blocked(raw: &str, policy: &UrlAccessPolicy) {
        assert!(matches!(
            evaluate(raw, policy),
            UrlPolicyResult::Blocked { .. }
        ));
    }

    /// Asserts that `raw` is accepted outright by `policy`.
    fn assert_allowed(raw: &str, policy: &UrlAccessPolicy) {
        assert_eq!(evaluate(raw, policy), UrlPolicyResult::Allowed);
    }

    #[test]
    fn cidr_parse_valid() {
        let range = CidrRange::parse("10.0.0.0/8").unwrap();
        assert_eq!(range.prefix_len, 8);
    }

    #[test]
    fn cidr_parse_invalid() {
        for malformed in &["not-a-cidr", "10.0.0.0/33"] {
            assert!(CidrRange::parse(malformed).is_err());
        }
    }

    #[test]
    fn cidr_contains_ipv4() {
        let range = CidrRange::parse("192.168.1.0/24").unwrap();
        assert!(range.contains(&ip("192.168.1.100")));
        assert!(!range.contains(&ip("192.168.2.1")));
    }

    #[test]
    fn cidr_contains_ipv6() {
        let range = CidrRange::parse("fc00::/7").unwrap();
        assert!(range.contains(&ip("fd12::1")));
        assert!(!range.contains(&ip("2001:db8::1")));
    }

    #[test]
    fn private_ipv4_ranges() {
        let private = [
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.0.1",
            "169.254.1.1",
            "100.64.0.1",
            "0.0.0.0",
        ];
        for addr in &private {
            assert!(is_private_ip(&ip(addr)));
        }

        for addr in &["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(&ip(addr)));
        }
    }

    #[test]
    fn private_ipv6_ranges() {
        for addr in &["fc00::1", "fd12:3456::1", "fe80::1", "::"] {
            assert!(is_private_ip(&ip(addr)));
        }
        assert!(!is_private_ip(&ip("2001:db8::1")));
    }

    #[test]
    fn ipv4_mapped_ipv6_private() {
        assert!(is_private_ip(&ip("::ffff:192.168.1.1")));
        assert!(!is_private_ip(&ip("::ffff:8.8.8.8")));
    }

    #[test]
    fn policy_blocks_private_ip_literal() {
        assert_blocked("http://192.168.1.1/api", &UrlAccessPolicy::default());
    }

    #[test]
    fn policy_allows_public_ip() {
        assert_allowed("https://8.8.8.8/dns-query", &UrlAccessPolicy::default());
    }

    #[test]
    fn policy_blocks_loopback_by_default() {
        assert_blocked("http://127.0.0.1:8080", &UrlAccessPolicy::default());
    }

    #[test]
    fn policy_allows_loopback_when_configured() {
        let policy = UrlAccessPolicy {
            allow_loopback: true,
            ..Default::default()
        };
        assert_allowed("http://127.0.0.1:8080", &policy);
    }

    #[test]
    fn policy_allow_cidrs_exempts_private_ip() {
        let policy = UrlAccessPolicy {
            allow_cidrs: vec![CidrRange::parse("10.0.0.0/8").unwrap()],
            ..Default::default()
        };
        assert_allowed("http://10.1.2.3/api", &policy);
    }

    #[test]
    fn policy_domain_blocklist() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["evil.com".to_string()],
            ..Default::default()
        };
        assert_blocked("https://evil.com/phish", &policy);
    }

    #[test]
    fn policy_domain_blocklist_subdomain() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["evil.com".to_string()],
            ..Default::default()
        };
        assert_blocked("https://api.evil.com/data", &policy);
    }

    #[test]
    fn policy_domain_allowlist_enforced() {
        let policy = UrlAccessPolicy {
            enforce_domain_allowlist: true,
            domain_allowlist: vec!["api.example.com".to_string()],
            ..Default::default()
        };

        assert_allowed("https://api.example.com/v1", &policy);
        assert_blocked("https://other.com/v1", &policy);
    }

    #[test]
    fn policy_allow_domains_bypass_private_ip_check() {
        let policy = UrlAccessPolicy {
            allow_domains: vec!["internal.corp".to_string()],
            ..Default::default()
        };
        assert_allowed("http://internal.corp/api", &policy);
    }

    #[test]
    fn policy_no_host_blocked() {
        assert_blocked("file:///etc/passwd", &UrlAccessPolicy::default());
    }

    #[test]
    fn policy_approved_domains_allowed() {
        let policy = UrlAccessPolicy {
            approved_domains: vec!["trusted.io".to_string()],
            ..Default::default()
        };
        assert_allowed("https://trusted.io/data", &policy);
    }

    #[test]
    fn default_policy_allows_public_domains() {
        assert_allowed("https://api.github.com/repos", &UrlAccessPolicy::default());
    }

    #[test]
    fn domain_matching_case_insensitive() {
        let policy = UrlAccessPolicy {
            domain_blocklist: vec!["Evil.Com".to_string()],
            ..Default::default()
        };
        assert_blocked("https://evil.com/path", &policy);
    }
}