1use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
5use url::Url;
6
/// Hostnames and IP literals rejected before any DNS lookup is attempted.
///
/// A URL's host is rejected when it equals one of these entries or is a
/// subdomain of one. IPv6 literals appear in the bracketed form (`[::1]`)
/// because that is how the URL host string renders them; the bare `::1`
/// entry is kept for backward compatibility.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    // Loopback names and literals.
    "localhost",
    "127.0.0.1",
    "::1",
    "[::1]",
    // Unspecified ("any") addresses: on common OSes a connect() to these
    // typically reaches the local machine, so treat them like loopback.
    "0.0.0.0",
    "[::]",
    // Cloud instance-metadata endpoints.
    "169.254.169.254",
    "metadata.google.internal",
    "metadata.goog",
];
18
/// Policy object that checks outbound URLs before they are fetched.
///
/// Construct it with `UrlValidator::new()` (or `Default`) and adjust it with
/// the `with_*` builder methods.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UrlValidator {
    /// When `Some`, only these domains and their subdomains are accepted.
    allowed_domains: Option<Vec<String>>,
    /// Hosts (and their subdomains) that are always rejected.
    blocked_hosts: Vec<String>,
    /// When `true`, addresses in private ranges are accepted; loopback and
    /// link-local addresses are still rejected.
    allow_private_ips: bool,
    /// Redirect limit stored for callers; validation itself never follows
    /// redirects.
    max_redirects: usize,
    /// When `true`, plain `http://` URLs are rejected.
    require_https: bool,
}
49
50impl Default for UrlValidator {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl UrlValidator {
57 #[must_use]
59 pub fn new() -> Self {
60 Self {
61 allowed_domains: None,
62 blocked_hosts: DEFAULT_BLOCKED_HOSTS
63 .iter()
64 .map(|&s| s.to_string())
65 .collect(),
66 allow_private_ips: false,
67 max_redirects: 3,
68 require_https: true,
69 }
70 }
71
72 #[must_use]
74 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
75 self.allowed_domains = Some(domains);
76 self
77 }
78
79 #[must_use]
81 pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
82 self.blocked_hosts.extend(hosts);
83 self
84 }
85
86 #[must_use]
88 pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
89 self.allow_private_ips = allow;
90 self
91 }
92
93 #[must_use]
95 pub const fn with_max_redirects(mut self, max: usize) -> Self {
96 self.max_redirects = max;
97 self
98 }
99
100 #[must_use]
102 pub const fn with_allow_http(mut self) -> Self {
103 self.require_https = false;
104 self
105 }
106
107 #[must_use]
109 pub const fn max_redirects(&self) -> usize {
110 self.max_redirects
111 }
112
113 pub fn validate(&self, url_str: &str) -> Result<Url> {
125 let url = Url::parse(url_str).context("Invalid URL format")?;
126
127 match url.scheme() {
129 "https" => {}
130 "http" => {
131 if self.require_https {
132 bail!("HTTPS required, but HTTP URL provided");
133 }
134 }
135 scheme => bail!("Unsupported URL scheme: {scheme}"),
136 }
137
138 let host = url.host_str().context("URL must have a host")?;
140
141 if self.blocked_hosts.iter().any(|blocked| {
143 host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
144 }) {
145 bail!("Access to host '{host}' is blocked");
146 }
147
148 if let Some(ref allowed) = self.allowed_domains {
150 let is_allowed = allowed.iter().any(|domain| {
151 host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
152 });
153 if !is_allowed {
154 bail!("Host '{host}' is not in the allowed domains list");
155 }
156 }
157
158 self.validate_resolved_ip(host)?;
160
161 Ok(url)
162 }
163
164 fn validate_resolved_ip(&self, host: &str) -> Result<()> {
170 let addrs: Vec<_> = format!("{host}:80")
172 .to_socket_addrs()
173 .map(Iterator::collect)
174 .unwrap_or_default();
175
176 if addrs.is_empty() {
177 bail!("Could not resolve host '{host}' — blocking unresolvable URLs for safety");
178 }
179
180 for addr in addrs {
181 let ip = addr.ip();
182 if !self.allow_private_ips && is_private_ip(&ip) {
183 bail!("Access to private IP address {ip} is blocked");
184 }
185 if is_loopback(&ip) {
186 bail!("Access to loopback address {ip} is blocked");
187 }
188 if is_link_local(&ip) {
189 bail!("Access to link-local address {ip} is blocked");
190 }
191 }
192
193 Ok(())
194 }
195}
196
197fn is_private_ip(ip: &IpAddr) -> bool {
202 match ip {
203 IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
204 IpAddr::V6(ipv6) => {
205 if let Some(mapped_v4) = ipv6.to_ipv4_mapped() {
207 return is_private_ipv4(mapped_v4);
208 }
209 is_private_ipv6(ipv6)
210 }
211 }
212}
213
/// Returns `true` for the RFC 1918 private blocks and the RFC 6598 shared
/// address space.
///
/// Covered: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16` and
/// `100.64.0.0/10`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    // `Ipv4Addr::is_private` covers exactly the three RFC 1918 blocks.
    if ip.is_private() {
        return true;
    }
    // 100.64.0.0/10: first octet is 100 and the second octet's top two bits
    // are `01`, i.e. second octet in 64..=127.
    let octets = ip.octets();
    octets[0] == 100 && (octets[1] & 0b1100_0000) == 0b0100_0000
}
240
/// Returns `true` for IPv6 unique local addresses (`fc00::/7`, RFC 4193).
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // The 7-bit prefix fits inside the first octet: `1111_110x`.
    (ip.octets()[0] & 0b1111_1110) == 0b1111_1100
}
247
/// Returns `true` for loopback addresses (`127.0.0.0/8` and `::1`).
///
/// IPv4 loopback written as an IPv4-mapped IPv6 address (`::ffff:127.x.y.z`)
/// is also detected.
const fn is_loopback(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_loopback(),
            None => v6.is_loopback(),
        },
    }
}
262
/// Returns `true` for link-local addresses (`169.254.0.0/16` and `fe80::/10`).
///
/// IPv4-mapped IPv6 addresses are unwrapped and checked against the IPv4
/// range.
const fn is_link_local(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(v4) => v4.is_link_local(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_link_local(),
            None => {
                // fe80::/10: first octet 0xfe, top two bits of the second
                // octet are `10`.
                let octets = v6.octets();
                octets[0] == 0xfe && (octets[1] & 0xc0) == 0x80
            }
        },
    }
}
279
/// Unit tests for `UrlValidator` and the IP classification helpers.
///
/// Tests that validate real hostnames (e.g. `example.com`) perform DNS
/// resolution and therefore need network access. Tests using IP literals or
/// the helper functions directly run offline.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_https_url() {
        // Needs DNS: example.com must resolve.
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://example.com/path").is_ok());
    }

    #[test]
    fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    #[test]
    fn test_http_allowed_with_flag() {
        // Needs DNS: example.com must resolve.
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").is_ok());
    }

    #[test]
    fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").is_err());
        assert!(validator.validate("http://127.0.0.1").is_err());
        assert!(validator.validate("http://[::1]").is_err());
    }

    #[test]
    fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .is_err()
        );
    }

    #[test]
    fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").is_err());
        assert!(validator.validate("").is_err());
        assert!(validator.validate("ftp://example.com").is_err());
    }

    #[test]
    fn test_allowed_domains() {
        // Needs DNS: example.com must resolve.
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").is_ok());

        let result = validator.validate("https://other.com");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[test]
    fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn test_private_ip_literal_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        let err = validator
            .validate("http://10.0.0.1")
            .unwrap_err()
            .to_string();
        assert!(err.contains("private IP"), "got: {err}");
    }

    #[test]
    fn test_allow_private_ips_still_blocks_loopback() {
        let validator = UrlValidator::new()
            .with_allow_http()
            .with_allow_private_ips(true);
        assert!(validator.validate("http://10.0.0.1").is_ok());

        // 127.0.0.2 is not in the literal blocklist, so this exercises the
        // resolved-address loopback check.
        let err = validator
            .validate("http://127.0.0.2")
            .unwrap_err()
            .to_string();
        assert!(err.contains("loopback"), "got: {err}");
    }

    #[test]
    fn test_ipv6_link_local_literal_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        let err = validator
            .validate("http://[fe80::1]")
            .unwrap_err()
            .to_string();
        assert!(err.contains("link-local"), "got: {err}");
    }

    #[test]
    fn test_ipv4_mapped_loopback_literal_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://[::ffff:127.0.0.1]").is_err());
    }

    #[test]
    fn test_is_private_ipv4() {
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));

        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    #[test]
    fn test_is_private_ipv4_shared_address_space() {
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }

    #[test]
    fn test_unresolvable_host_blocked() {
        let validator = UrlValidator::new();
        let result = validator.validate("https://this-domain-does-not-exist-xyz123.example");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Could not resolve host"),
            "Expected DNS resolution failure, got: {err_msg}"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_private_detected() {
        // ::ffff:10.0.0.1
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_loopback_detected() {
        // ::ffff:127.0.0.1
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001));
        assert!(is_loopback(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_link_local_detected() {
        // ::ffff:169.254.169.254
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0xa9fe));
        assert!(is_link_local(&ip));
    }

    #[test]
    fn test_ipv6_link_local_detected() {
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_regular_ipv6_private_still_detected() {
        // fc00::1 (unique local)
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_public_not_flagged() {
        // ::ffff:8.8.8.8
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0808, 0x0808));
        assert!(!is_private_ip(&ip));
        assert!(!is_loopback(&ip));
        assert!(!is_link_local(&ip));
    }
}