1use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
5use url::Url;
6
/// Host names and literal addresses that every validator refuses out of the
/// box, before any DNS lookup takes place.
///
/// Entries are compared against the URL's host; `UrlValidator::new` copies
/// them into the validator's own deny-list.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    // Loopback / "this host" names and literals. The IPv6 loopback is listed
    // both bare and in the bracketed form used inside URLs.
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::1",
    "[::1]",
    // Metadata service endpoints.
    "169.254.169.254",
    "metadata.google.internal",
    "metadata.goog",
];
18
/// Configurable checker that decides whether a URL is acceptable to fetch.
///
/// Built with [`UrlValidator::new`] and refined through the `with_*` builder
/// methods; [`UrlValidator::validate`] applies the configured policy.
#[derive(Clone, Debug)]
pub struct UrlValidator {
    /// Optional allow-list. When `Some`, a host must equal one of the entries
    /// or be a subdomain of one; when `None`, no allow-list is applied.
    allowed_domains: Option<Vec<String>>,
    /// Deny-list. A host that equals an entry, or is a subdomain of one, is
    /// rejected.
    blocked_hosts: Vec<String>,
    /// When `false`, hosts resolving to private-range addresses are rejected.
    allow_private_ips: bool,
    /// Redirect limit. The validator only stores this value and exposes it
    /// through `max_redirects()`.
    max_redirects: usize,
    /// When `true`, `http://` URLs are rejected and only `https://` passes.
    require_https: bool,
}
49
/// Result of a successful [`UrlValidator::validate`] call.
#[derive(Clone, Debug)]
pub struct ValidatedUrl {
    /// The URL as parsed from the input string.
    pub url: Url,
    /// Socket addresses the URL's host resolved to during validation, all of
    /// which passed the validator's IP checks. The port is the URL's explicit
    /// port or the known default for its scheme.
    pub addresses: Vec<SocketAddr>,
}
67
impl Default for UrlValidator {
    /// Returns the same configuration as [`UrlValidator::new`].
    fn default() -> Self {
        Self::new()
    }
}
73
74impl UrlValidator {
75 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 allowed_domains: None,
80 blocked_hosts: DEFAULT_BLOCKED_HOSTS
81 .iter()
82 .map(|&s| s.to_string())
83 .collect(),
84 allow_private_ips: false,
85 max_redirects: 3,
86 require_https: true,
87 }
88 }
89
90 #[must_use]
92 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
93 self.allowed_domains = Some(domains);
94 self
95 }
96
97 #[must_use]
99 pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
100 self.blocked_hosts.extend(hosts);
101 self
102 }
103
104 #[must_use]
106 pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
107 self.allow_private_ips = allow;
108 self
109 }
110
111 #[must_use]
113 pub const fn with_max_redirects(mut self, max: usize) -> Self {
114 self.max_redirects = max;
115 self
116 }
117
118 #[must_use]
120 pub const fn with_allow_http(mut self) -> Self {
121 self.require_https = false;
122 self
123 }
124
125 #[must_use]
127 pub const fn max_redirects(&self) -> usize {
128 self.max_redirects
129 }
130
131 pub async fn validate(&self, url_str: &str) -> Result<ValidatedUrl> {
148 let url = Url::parse(url_str).context("Invalid URL format")?;
149
150 match url.scheme() {
152 "https" => {}
153 "http" => {
154 if self.require_https {
155 bail!("HTTPS required, but HTTP URL provided");
156 }
157 }
158 scheme => bail!("Unsupported URL scheme: {scheme}"),
159 }
160
161 let host = url.host_str().context("URL must have a host")?;
163
164 if self.blocked_hosts.iter().any(|blocked| {
166 host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
167 }) {
168 bail!("Access to host '{host}' is blocked");
169 }
170
171 if let Some(ref allowed) = self.allowed_domains {
173 let is_allowed = allowed.iter().any(|domain| {
174 host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
175 });
176 if !is_allowed {
177 bail!("Host '{host}' is not in the allowed domains list");
178 }
179 }
180
181 let port = url.port_or_known_default().unwrap_or(443);
184 let addresses = self.validate_resolved_ip(host, port).await?;
185
186 Ok(ValidatedUrl { url, addresses })
187 }
188
189 async fn validate_resolved_ip(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
196 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
199 .await
200 .map(Iterator::collect)
201 .unwrap_or_default();
202
203 if addrs.is_empty() {
204 bail!("Could not resolve host '{host}' — blocking unresolvable URLs for safety");
205 }
206
207 for addr in &addrs {
208 let ip = addr.ip();
209 if !self.allow_private_ips && is_private_ip(&ip) {
210 bail!("Access to private IP address {ip} is blocked");
211 }
212 if is_loopback(&ip) {
213 bail!("Access to loopback address {ip} is blocked");
214 }
215 if is_link_local(&ip) {
216 bail!("Access to link-local address {ip} is blocked");
217 }
218 }
219
220 Ok(addrs)
221 }
222}
223
224fn is_private_ip(ip: &IpAddr) -> bool {
229 match ip {
230 IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
231 IpAddr::V6(ipv6) => {
232 if let Some(mapped_v4) = ipv6.to_ipv4_mapped() {
234 return is_private_ipv4(mapped_v4);
235 }
236 is_private_ipv6(ipv6)
237 }
238 }
239}
240
/// Reports whether `ip` falls in one of the private IPv4 ranges:
/// `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, or `100.64.0.0/10`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    match first {
        10 => true,
        172 => (16..=31).contains(&second),
        192 => second == 168,
        100 => (64..=127).contains(&second),
        _ => false,
    }
}
267
/// Reports whether `ip` lies in `fc00::/7`, i.e. its top seven bits are
/// `1111 110`.
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // Only the first byte matters for a /7 prefix; the lowest bit of that
    // byte is outside the prefix and is masked away.
    let leading = ip.octets()[0];
    leading & 0xfe == 0xfc
}
274
/// Reports whether `ip` is a loopback address.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is judged by the IPv4
/// address it carries, so `::ffff:127.0.0.1` counts as loopback.
const fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_loopback(),
            None => v6.is_loopback(),
        },
    }
}
289
/// Reports whether `ip` is a link-local address.
///
/// IPv4 addresses, including those carried in IPv4-mapped IPv6 form, use the
/// standard library's IPv4 link-local test; other IPv6 addresses are
/// link-local when they fall in `fe80::/10`.
const fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_link_local(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_link_local(),
            // Keep the top ten bits of the first segment and compare them
            // with the fe80::/10 prefix.
            None => matches!(v6.segments()[0] & 0xffc0, 0xfe80),
        },
    }
}
306
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: tests that expect `validate` to succeed perform a real DNS lookup
    // for `example.com`, so they need a working resolver.

    #[tokio::test]
    async fn test_valid_https_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").await.is_ok());
        assert!(validator.validate("https://example.com/path").await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_returns_vetted_addresses() -> Result<()> {
        let validator = UrlValidator::new();
        let validated = validator
            .validate("https://example.com")
            .await
            .context("example.com should validate")?;
        assert!(
            !validated.addresses.is_empty(),
            "validation must return the vetted IP addresses for pinning"
        );
        assert!(validated.addresses.iter().all(|a| a.port() == 443));
        Ok(())
    }

    #[tokio::test]
    async fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    #[tokio::test]
    async fn test_http_allowed_with_flag() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").await.is_ok());
    }

    #[tokio::test]
    async fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").await.is_err());
        assert!(validator.validate("http://127.0.0.1").await.is_err());
        assert!(validator.validate("http://[::1]").await.is_err());
    }

    #[tokio::test]
    async fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").await.is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").await.is_err());
        assert!(validator.validate("").await.is_err());
        assert!(validator.validate("ftp://example.com").await.is_err());
    }

    #[tokio::test]
    async fn test_allowed_domains() {
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").await.is_ok());

        let result = validator.validate("https://other.com").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[tokio::test]
    async fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn test_is_private_ipv4() {
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));

        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    // Pins both edges of the 100.64.0.0/10 range that `is_private_ipv4`
    // also treats as private.
    #[test]
    fn test_is_private_ipv4_shared_address_space() {
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));

        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }

    #[tokio::test]
    async fn test_unresolvable_host_blocked() {
        let validator = UrlValidator::new();
        let result = validator
            .validate("https://this-domain-does-not-exist-xyz123.example")
            .await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Could not resolve host"),
            "Expected DNS resolution failure, got: {err_msg}"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_private_detected() {
        // ::ffff:10.0.0.1
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_loopback_detected() {
        // ::ffff:127.0.0.1
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001));
        assert!(is_loopback(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_link_local_detected() {
        // ::ffff:169.254.169.254
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0xa9fe));
        assert!(is_link_local(&ip));
    }

    #[test]
    fn test_regular_ipv6_private_still_detected() {
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_public_not_flagged() {
        // ::ffff:8.8.8.8
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0808, 0x0808));
        assert!(!is_private_ip(&ip));
        assert!(!is_loopback(&ip));
        assert!(!is_link_local(&ip));
    }

    // Native (non-mapped) IPv6 addresses take a different branch from the
    // IPv4-mapped cases above.
    #[test]
    fn test_native_ipv6_loopback_detected() {
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);
        assert!(is_loopback(&ip));
        assert!(!is_link_local(&ip));
        assert!(!is_private_ip(&ip));
    }

    #[test]
    fn test_native_ipv6_link_local_bounds() {
        let first: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1));
        let last: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfebf, 0xffff, 0, 0, 0, 0, 0, 1));
        let outside: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_link_local(&first));
        assert!(is_link_local(&last));
        assert!(!is_link_local(&outside));
    }

    #[test]
    fn test_native_ipv6_private_bounds() {
        let upper_half: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
        let above: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfe00, 0, 0, 0, 0, 0, 0, 1));
        let below: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfbff, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_private_ip(&upper_half));
        assert!(!is_private_ip(&above));
        assert!(!is_private_ip(&below));
    }
}