1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
14use std::time::Duration;
15use url::Url;
16
/// Upper bound (5 seconds) on how long a single DNS lookup may run before it is abandoned.
const DNS_LOOKUP_TIMEOUT: Duration = Duration::from_millis(5_000);
20
/// Reasons a URL can be rejected by the validators in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UrlValidationError {
    /// The input could not be parsed as a URL; carries the parser's error message.
    InvalidUrl(String),
    /// The scheme is neither `http` nor `https`; carries the offending scheme.
    DisallowedScheme(String),
    /// The URL has no host component.
    MissingHostname,
    /// The host was rejected; carries the host name or a short description of the rejection.
    BlockedHost(String),
}

impl std::fmt::Display for UrlValidationError {
    /// Renders the user-facing message for each rejection reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            UrlValidationError::MissingHostname => f.write_str("URL must have a hostname"),
            UrlValidationError::InvalidUrl(msg) => write!(f, "Invalid URL: {msg}"),
            UrlValidationError::BlockedHost(host) => {
                write!(f, "Blocked host: {host} (private/internal address)")
            }
            UrlValidationError::DisallowedScheme(scheme) => {
                write!(f, "Disallowed URL scheme: {scheme} (must be http or https)")
            }
        }
    }
}

impl std::error::Error for UrlValidationError {}
50
51pub fn validate_safe_url(raw_url: &str) -> Result<Url, UrlValidationError> {
57 let url = Url::parse(raw_url).map_err(|e| UrlValidationError::InvalidUrl(e.to_string()))?;
58
59 match url.scheme() {
61 "http" | "https" => {}
62 other => return Err(UrlValidationError::DisallowedScheme(other.to_string())),
63 }
64
65 let host = url.host_str().ok_or(UrlValidationError::MissingHostname)?;
67
68 if is_blocked_host(host) {
70 return Err(UrlValidationError::BlockedHost(host.to_string()));
71 }
72
73 Ok(url)
74}
75
/// Validates `raw_url` and resolves its hostname using the default resolver
/// (`default_dns_resolve`).
///
/// This is a thin wrapper around `validate_url_with_resolver`. On success it returns the parsed
/// URL together with the socket addresses the hostname resolved to; when the host is an IP
/// literal no lookup is performed and the address list is empty.
///
/// NOTE(review): the name suggests callers connect to the returned addresses rather than
/// resolving the hostname a second time — confirm against the call sites.
///
/// # Errors
///
/// Returns the same errors as `validate_safe_url`, plus [`UrlValidationError::BlockedHost`] when
/// the lookup fails, returns no addresses, or yields an address rejected by `is_blocked_ip`.
pub async fn validate_url_dns_pinned(
    raw_url: &str,
) -> Result<(Url, Vec<SocketAddr>), UrlValidationError> {
    validate_url_with_resolver(raw_url, default_dns_resolve).await
}
94
95async fn validate_url_with_resolver<R, F>(
97 raw_url: &str,
98 resolve: R,
99) -> Result<(Url, Vec<SocketAddr>), UrlValidationError>
100where
101 R: Fn(String, u16) -> F,
102 F: std::future::Future<Output = Result<Vec<SocketAddr>, std::io::Error>>,
103{
104 let url = validate_safe_url(raw_url)?;
106 let host = url
107 .host_str()
108 .ok_or(UrlValidationError::MissingHostname)?
109 .to_string();
110
111 let bare = host
113 .strip_prefix('[')
114 .and_then(|s| s.strip_suffix(']'))
115 .unwrap_or(&host);
116 if bare.parse::<IpAddr>().is_ok() {
117 return Ok((url, Vec::new()));
118 }
119
120 let port = url.port_or_known_default().unwrap_or(443);
122 let addrs = resolve(host.clone(), port)
123 .await
124 .map_err(|_| UrlValidationError::BlockedHost(host.clone()))?;
125
126 if addrs.is_empty() {
127 return Err(UrlValidationError::BlockedHost(host.clone()));
128 }
129
130 for addr in &addrs {
131 if is_blocked_ip(addr.ip()) {
132 tracing::warn!(
133 host = %host,
134 resolved_ip = %addr.ip(),
135 "DNS rebinding check blocked: hostname resolves to private address"
136 );
137 return Err(UrlValidationError::BlockedHost(format!(
138 "{host} resolves to blocked address {}",
139 addr.ip()
140 )));
141 }
142 }
143
144 Ok((url, addrs))
145}
146
147async fn default_dns_resolve(host: String, port: u16) -> Result<Vec<SocketAddr>, std::io::Error> {
150 tokio::time::timeout(
151 DNS_LOOKUP_TIMEOUT,
152 tokio::net::lookup_host(format!("{host}:{port}")),
153 )
154 .await
155 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "DNS lookup timed out"))?
156 .map(|iter| iter.collect())
157}
158
159fn is_blocked_host(host: &str) -> bool {
161 let host_lower = host.to_lowercase();
162
163 if host_lower == "localhost"
165 || host_lower == "localhost."
166 || host_lower.ends_with(".localhost")
167 || host_lower.ends_with(".localhost.")
168 {
169 return true;
170 }
171
172 let bare = host_lower
174 .strip_prefix('[')
175 .and_then(|s| s.strip_suffix(']'))
176 .unwrap_or(&host_lower);
177
178 if let Ok(ip) = bare.parse::<IpAddr>() {
180 return is_blocked_ip(ip);
181 }
182
183 if host_lower == "metadata.google.internal" || host_lower == "metadata.google.internal." {
185 return true;
186 }
187
188 false
189}
190
191pub fn is_blocked_ip(ip: IpAddr) -> bool {
198 match ip {
199 IpAddr::V4(v4) => is_blocked_ipv4(v4),
200 IpAddr::V6(v6) => is_blocked_ipv6(v6),
201 }
202}
203
/// Returns `true` when `ip` lies in an IPv4 range that must never be contacted on behalf of a
/// user-supplied URL.
///
/// In addition to loopback, private, link-local, shared (CGNAT) and documentation space, this
/// rejects the remaining non-global special-purpose ranges: all of `0.0.0.0/8` (not only the
/// unspecified address), benchmarking space `198.18.0.0/15`, multicast `224.0.0.0/4`, and
/// reserved space `240.0.0.0/4`, which includes the limited broadcast address.
fn is_blocked_ipv4(ip: Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        // 0.0.0.0/8 — "this network", including the unspecified address 0.0.0.0
        [0, ..]
        // 10.0.0.0/8 — private
        | [10, ..]
        // 100.64.0.0/10 — shared address space (carrier-grade NAT)
        | [100, 64..=127, ..]
        // 127.0.0.0/8 — loopback
        | [127, ..]
        // 169.254.0.0/16 — link-local, including the cloud metadata address
        | [169, 254, ..]
        // 172.16.0.0/12 — private
        | [172, 16..=31, ..]
        // 192.0.2.0/24 — documentation (TEST-NET-1)
        | [192, 0, 2, _]
        // 192.168.0.0/16 — private
        | [192, 168, ..]
        // 198.18.0.0/15 — network benchmarking
        | [198, 18..=19, ..]
        // 198.51.100.0/24 — documentation (TEST-NET-2)
        | [198, 51, 100, _]
        // 203.0.113.0/24 — documentation (TEST-NET-3)
        | [203, 0, 113, _]
        // 224.0.0.0/4 multicast and 240.0.0.0/4 reserved (incl. 255.255.255.255)
        | [224..=255, ..]
    )
}
252
253fn is_blocked_ipv6(ip: Ipv6Addr) -> bool {
254 if ip.is_loopback() {
256 return true;
257 }
258
259 if ip.is_unspecified() {
261 return true;
262 }
263
264 let segments = ip.segments();
266 if segments[0] & 0xffc0 == 0xfe80 {
267 return true;
268 }
269
270 if segments[0] & 0xfe00 == 0xfc00 {
272 return true;
273 }
274
275 if let Some(v4) = ip.to_ipv4_mapped() {
277 return is_blocked_ipv4(v4);
278 }
279
280 false
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `outcome` is an `Err(BlockedHost(_))`.
    fn is_blocked<T>(outcome: &Result<T, UrlValidationError>) -> bool {
        matches!(outcome, Err(UrlValidationError::BlockedHost(_)))
    }

    /// True when `outcome` is an `Err(DisallowedScheme(_))`.
    fn is_bad_scheme<T>(outcome: &Result<T, UrlValidationError>) -> bool {
        matches!(outcome, Err(UrlValidationError::DisallowedScheme(_)))
    }

    // --- Accepted URLs -----------------------------------------------------------------------

    #[test]
    fn accepts_https_public_url() {
        assert!(validate_safe_url("https://mcp.example.com/v1/mcp").is_ok());
    }

    #[test]
    fn accepts_http_public_url() {
        assert!(validate_safe_url("http://mcp.example.com/v1/mcp").is_ok());
    }

    #[test]
    fn accepts_url_with_port() {
        assert!(validate_safe_url("https://mcp.example.com:8443/v1/mcp").is_ok());
    }

    #[test]
    fn accepts_url_with_path_and_query() {
        assert!(validate_safe_url("https://api.example.com/mcp?key=val").is_ok());
    }

    // --- Scheme checks -----------------------------------------------------------------------

    #[test]
    fn rejects_ftp_scheme() {
        assert!(is_bad_scheme(&validate_safe_url("ftp://evil.com/file")));
    }

    #[test]
    fn rejects_file_scheme() {
        assert!(is_bad_scheme(&validate_safe_url("file:///etc/passwd")));
    }

    #[test]
    fn rejects_javascript_scheme() {
        let outcome = validate_safe_url("javascript:alert(1)");
        assert!(matches!(
            outcome,
            Err(UrlValidationError::DisallowedScheme(_)) | Err(UrlValidationError::MissingHostname)
        ));
    }

    #[test]
    fn rejects_data_scheme() {
        let outcome = validate_safe_url("data:text/plain,hello");
        assert!(matches!(
            outcome,
            Err(UrlValidationError::DisallowedScheme(_)) | Err(UrlValidationError::MissingHostname)
        ));
    }

    // --- Malformed input ---------------------------------------------------------------------

    #[test]
    fn rejects_empty_string() {
        assert!(validate_safe_url("").is_err());
    }

    #[test]
    fn rejects_not_a_url() {
        assert!(validate_safe_url("not a url").is_err());
    }

    // --- Localhost names ---------------------------------------------------------------------

    #[test]
    fn rejects_localhost() {
        assert!(is_blocked(&validate_safe_url("http://localhost/path")));
    }

    #[test]
    fn rejects_localhost_with_port() {
        assert!(is_blocked(&validate_safe_url("http://localhost:8080/path")));
    }

    #[test]
    fn rejects_subdomain_of_localhost() {
        assert!(is_blocked(&validate_safe_url("http://foo.localhost/path")));
    }

    // --- Loopback addresses ------------------------------------------------------------------

    #[test]
    fn rejects_127_0_0_1() {
        assert!(is_blocked(&validate_safe_url("http://127.0.0.1/path")));
    }

    #[test]
    fn rejects_127_x_x_x() {
        assert!(is_blocked(&validate_safe_url("http://127.255.0.1/path")));
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert!(is_blocked(&validate_safe_url("http://[::1]/path")));
    }

    // --- Private IPv4 ranges -----------------------------------------------------------------

    #[test]
    fn rejects_10_x() {
        assert!(is_blocked(&validate_safe_url("http://10.0.0.1/path")));
    }

    #[test]
    fn rejects_172_16_x() {
        assert!(is_blocked(&validate_safe_url("http://172.16.0.1/path")));
    }

    #[test]
    fn rejects_172_31_x() {
        assert!(is_blocked(&validate_safe_url("http://172.31.255.255/path")));
    }

    #[test]
    fn accepts_172_32_x() {
        assert!(validate_safe_url("http://172.32.0.1/path").is_ok());
    }

    #[test]
    fn rejects_192_168_x() {
        assert!(is_blocked(&validate_safe_url("http://192.168.1.1/path")));
    }

    // --- Link-local and cloud metadata -------------------------------------------------------

    #[test]
    fn rejects_link_local() {
        assert!(is_blocked(&validate_safe_url("http://169.254.1.1/path")));
    }

    #[test]
    fn rejects_cloud_metadata_ip() {
        assert!(is_blocked(&validate_safe_url(
            "http://169.254.169.254/latest/meta-data/"
        )));
    }

    #[test]
    fn rejects_gce_metadata_hostname() {
        assert!(is_blocked(&validate_safe_url(
            "http://metadata.google.internal/computeMetadata/v1/"
        )));
    }

    // --- Unspecified addresses ---------------------------------------------------------------

    #[test]
    fn rejects_unspecified_v4() {
        assert!(is_blocked(&validate_safe_url("http://0.0.0.0/path")));
    }

    // --- IPv6 ranges -------------------------------------------------------------------------

    #[test]
    fn rejects_ipv6_unspecified() {
        assert!(is_blocked(&validate_safe_url("http://[::]/path")));
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert!(is_blocked(&validate_safe_url("http://[fe80::1]/path")));
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert!(is_blocked(&validate_safe_url("http://[fd00::1]/path")));
    }

    #[test]
    fn rejects_ipv4_mapped_ipv6_private() {
        assert!(is_blocked(&validate_safe_url(
            "http://[::ffff:127.0.0.1]/path"
        )));
    }

    #[test]
    fn rejects_ipv4_mapped_ipv6_metadata() {
        assert!(is_blocked(&validate_safe_url(
            "http://[::ffff:169.254.169.254]/latest/meta-data/"
        )));
    }

    // --- Shared address space ----------------------------------------------------------------

    #[test]
    fn rejects_cgnat() {
        assert!(is_blocked(&validate_safe_url("http://100.64.0.1/path")));
    }

    // --- Error rendering ---------------------------------------------------------------------

    #[test]
    fn error_display_messages() {
        let blocked = UrlValidationError::BlockedHost("localhost".into()).to_string();
        assert!(blocked.contains("private/internal"));

        let scheme = UrlValidationError::DisallowedScheme("ftp".into()).to_string();
        assert!(scheme.contains("http or https"));
    }

    // --- DNS-pinned validation with inputs that never reach the resolver ---------------------

    #[tokio::test]
    async fn dns_pinned_rejects_private_ip_literal() {
        let outcome = validate_url_dns_pinned("http://10.0.0.1/mcp").await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_pinned_rejects_loopback_ip_literal() {
        let outcome = validate_url_dns_pinned("http://127.0.0.1/mcp").await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_pinned_rejects_metadata_ip_literal() {
        let outcome = validate_url_dns_pinned("http://169.254.169.254/latest/meta-data/").await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_pinned_rejects_localhost_hostname() {
        let outcome = validate_url_dns_pinned("http://localhost:8080/mcp").await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_pinned_rejects_bad_scheme() {
        let outcome = validate_url_dns_pinned("ftp://example.com/mcp").await;
        assert!(is_bad_scheme(&outcome));
    }

    // --- Fake resolvers ----------------------------------------------------------------------

    /// Resolver that always answers with a private address.
    async fn private_ip_resolver(
        _host: String,
        _port: u16,
    ) -> Result<Vec<SocketAddr>, std::io::Error> {
        let addr: SocketAddr = "10.0.0.1:80".parse().unwrap();
        Ok(vec![addr])
    }

    /// Resolver that always answers with a public address.
    async fn public_ip_resolver(
        _host: String,
        _port: u16,
    ) -> Result<Vec<SocketAddr>, std::io::Error> {
        let addr: SocketAddr = "1.1.1.1:443".parse().unwrap();
        Ok(vec![addr])
    }

    /// Resolver that always fails with a timeout error.
    async fn failing_resolver(
        _host: String,
        _port: u16,
    ) -> Result<Vec<SocketAddr>, std::io::Error> {
        let timed_out = std::io::Error::new(std::io::ErrorKind::TimedOut, "DNS lookup timed out");
        Err(timed_out)
    }

    /// Resolver that succeeds but returns no addresses.
    async fn empty_resolver(_host: String, _port: u16) -> Result<Vec<SocketAddr>, std::io::Error> {
        Ok(Vec::new())
    }

    // --- DNS-pinned validation with injected resolvers ---------------------------------------

    #[tokio::test]
    async fn dns_resolver_blocks_hostname_resolving_to_private_ip() {
        let result =
            validate_url_with_resolver("http://evil.example.com/mcp", private_ip_resolver).await;
        assert!(is_blocked(&result), "expected BlockedHost, got {result:?}");
    }

    #[tokio::test]
    async fn dns_resolver_allows_hostname_resolving_to_public_ip() {
        let outcome =
            validate_url_with_resolver("https://mcp.example.com/v1/mcp", public_ip_resolver).await;
        let (url, addrs) = outcome.expect("should succeed");
        assert_eq!(url.host_str(), Some("mcp.example.com"));
        assert_eq!(addrs.len(), 1);
    }

    #[tokio::test]
    async fn dns_resolver_blocks_on_lookup_failure() {
        let outcome = validate_url_with_resolver("http://example.com/mcp", failing_resolver).await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_resolver_blocks_empty_response() {
        let outcome = validate_url_with_resolver("http://example.com/mcp", empty_resolver).await;
        assert!(is_blocked(&outcome));
    }

    #[tokio::test]
    async fn dns_resolver_returns_addrs_for_connection_pinning() {
        let outcome =
            validate_url_with_resolver("https://mcp.example.com/v1/mcp", public_ip_resolver).await;
        let (_url, addrs) = outcome.unwrap();
        assert!(!addrs.is_empty());
    }
}