1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
9
10use regex::Regex;
11use serde::{Deserialize, Serialize};
12
/// Reasons a URL can be rejected by the SSRF checks in this file.
///
/// Every variant stores its details as owned strings, so a violation can be
/// cloned, compared, serialized with serde and rendered through `Display`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SsrfViolation {
    /// An IP address matched a configured CIDR range; `range` is that range's label.
    BlockedIp { ip: String, range: String },
    /// The URL's scheme is on the scheme blocklist.
    BlockedScheme { scheme: String },
    /// The URL's host is on the host blocklist.
    BlockedHost { host: String },
    /// The host could not be resolved; `reason` is the resolver's error text.
    DnsResolutionFailed { host: String, reason: String },
    /// An IP address was classified as private by `is_private_ip`.
    PrivateIp { ip: String },
    /// The URL matched a blocked regex; `pattern` is the name the pattern was registered under.
    BlockedPattern { pattern: String, url: String },
    /// The URL could not be split into scheme and host.
    InvalidUrl { reason: String },
}
35
36impl std::fmt::Display for SsrfViolation {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Self::BlockedIp { ip, range } => {
40 write!(f, "SSRF: IP {} falls within blocked range {}", ip, range)
41 }
42 Self::BlockedScheme { scheme } => {
43 write!(f, "SSRF: scheme '{}' is not allowed", scheme)
44 }
45 Self::BlockedHost { host } => {
46 write!(f, "SSRF: hostname '{}' is blocked", host)
47 }
48 Self::DnsResolutionFailed { host, reason } => {
49 write!(f, "SSRF: DNS resolution failed for '{}': {}", host, reason)
50 }
51 Self::PrivateIp { ip } => {
52 write!(f, "SSRF: resolved IP {} is in a private range", ip)
53 }
54 Self::BlockedPattern { pattern, url } => {
55 write!(
56 f,
57 "SSRF: URL '{}' matched blocked pattern '{}'",
58 url, pattern
59 )
60 }
61 Self::InvalidUrl { reason } => {
62 write!(f, "SSRF: invalid URL: {}", reason)
63 }
64 }
65 }
66}
67
// Marker impl: no methods are overridden, so the `Debug` derive and the
// `Display` impl on `SsrfViolation` provide everything `Error` requires.
impl std::error::Error for SsrfViolation {}
69
/// An IP network in CIDR form, used to test whether an address lies inside it.
#[derive(Debug, Clone)]
struct CidrRange {
    /// The CIDR text exactly as given to `parse_cidr`; reported in
    /// `SsrfViolation::BlockedIp` when an address matches.
    label: String,
    /// Address the prefix bits are taken from.
    network: IpAddr,
    /// Number of leading bits an address must share with `network`.
    prefix_len: u8,
}
84
85impl CidrRange {
86 fn contains(&self, ip: &IpAddr) -> bool {
87 match (&self.network, ip) {
88 (IpAddr::V4(net), IpAddr::V4(addr)) => {
89 let net_bits = u32::from(*net);
90 let addr_bits = u32::from(*addr);
91 if self.prefix_len == 0 {
92 return true;
93 }
94 if self.prefix_len >= 32 {
95 return net_bits == addr_bits;
96 }
97 let mask = !((1u32 << (32 - self.prefix_len)) - 1);
98 (net_bits & mask) == (addr_bits & mask)
99 }
100 (IpAddr::V6(net), IpAddr::V6(addr)) => {
101 let net_bits = u128::from(*net);
102 let addr_bits = u128::from(*addr);
103 if self.prefix_len == 0 {
104 return true;
105 }
106 if self.prefix_len >= 128 {
107 return net_bits == addr_bits;
108 }
109 let mask = !((1u128 << (128 - self.prefix_len)) - 1);
110 (net_bits & mask) == (addr_bits & mask)
111 }
112 _ => false,
113 }
114 }
115}
116
117fn parse_cidr(s: &str) -> Option<CidrRange> {
118 let parts: Vec<&str> = s.split('/').collect();
119 if parts.len() != 2 {
120 return None;
121 }
122 let ip: IpAddr = parts[0].parse().ok()?;
123 let prefix_len: u8 = parts[1].parse().ok()?;
124 Some(CidrRange {
125 label: s.to_string(),
126 network: ip,
127 prefix_len,
128 })
129}
130
/// Configurable URL validator that rejects requests aimed at internal targets.
///
/// All state is private; it is seeded by `new` and adjusted through the
/// `allow_*`, `block_*`, `add_*` and `set_*` methods.
#[derive(Debug, Clone)]
pub struct SsrfProtector {
    /// CIDR ranges an IP address must not fall into.
    blocked_ranges: Vec<CidrRange>,
    /// Lower-cased host names that are always rejected.
    blocked_hosts: Vec<String>,
    /// URL schemes that are always rejected (compared lower-cased).
    blocked_schemes: Vec<String>,
    /// Lower-cased host names accepted without IP or DNS checks.
    allowed_hosts: Vec<String>,
    /// `(name, regex)` pairs matched against the whole URL text.
    blocked_patterns: Vec<(String, Regex)>,
    /// Whether non-literal host names are resolved and their addresses checked.
    dns_check_enabled: bool,
}
155
156impl Default for SsrfProtector {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162impl SsrfProtector {
163 pub fn new() -> Self {
165 let default_cidrs = [
166 "127.0.0.0/8",
167 "10.0.0.0/8",
168 "172.16.0.0/12",
169 "192.168.0.0/16",
170 "169.254.0.0/16",
171 "::1/128",
172 "fc00::/7",
173 "fe80::/10",
174 ];
175
176 let blocked_ranges: Vec<CidrRange> =
177 default_cidrs.iter().filter_map(|c| parse_cidr(c)).collect();
178
179 Self {
180 blocked_ranges,
181 blocked_hosts: vec![
182 "localhost".to_string(),
183 "metadata.google.internal".to_string(),
184 "169.254.169.254".to_string(),
185 ],
186 blocked_schemes: vec!["file".to_string(), "ftp".to_string(), "gopher".to_string()],
187 allowed_hosts: Vec::new(),
188 blocked_patterns: Vec::new(),
189 dns_check_enabled: true,
190 }
191 }
192
193 pub fn allow_host(&mut self, host: &str) {
195 self.allowed_hosts.push(host.to_lowercase());
196 }
197
198 pub fn add_blocked_pattern(&mut self, name: &str, pattern: &str) {
200 if let Ok(re) = Regex::new(pattern) {
201 self.blocked_patterns.push((name.to_string(), re));
202 }
203 }
204
205 pub fn add_blocked_range(&mut self, cidr: &str) {
207 if let Some(range) = parse_cidr(cidr) {
208 self.blocked_ranges.push(range);
209 }
210 }
211
212 pub fn block_host(&mut self, host: &str) {
214 self.blocked_hosts.push(host.to_lowercase());
215 }
216
217 pub fn set_dns_check(&mut self, enabled: bool) {
219 self.dns_check_enabled = enabled;
220 }
221
222 pub fn validate_url(&self, url: &str) -> Result<(), SsrfViolation> {
224 for (name, re) in &self.blocked_patterns {
226 if re.is_match(url) {
227 return Err(SsrfViolation::BlockedPattern {
228 pattern: name.clone(),
229 url: url.to_string(),
230 });
231 }
232 }
233
234 let scheme = extract_scheme(url)?;
236 if self.blocked_schemes.contains(&scheme.to_lowercase()) {
237 return Err(SsrfViolation::BlockedScheme { scheme });
238 }
239
240 let host = extract_host(url)?;
242 let host_lower = host.to_lowercase();
243
244 if self.blocked_hosts.contains(&host_lower) {
246 return Err(SsrfViolation::BlockedHost {
247 host: host.to_string(),
248 });
249 }
250
251 if self.allowed_hosts.contains(&host_lower) {
253 return Ok(());
254 }
255
256 if let Ok(ip) = host.parse::<IpAddr>() {
258 self.check_ip(&ip)?;
259 return Ok(());
260 }
261
262 if self.dns_check_enabled {
264 self.check_dns(&host)?;
265 }
266
267 Ok(())
268 }
269
270 fn check_ip(&self, ip: &IpAddr) -> Result<(), SsrfViolation> {
272 if is_private_ip(ip) {
274 return Err(SsrfViolation::PrivateIp { ip: ip.to_string() });
275 }
276
277 for range in &self.blocked_ranges {
279 if range.contains(ip) {
280 return Err(SsrfViolation::BlockedIp {
281 ip: ip.to_string(),
282 range: range.label.clone(),
283 });
284 }
285 }
286
287 Ok(())
288 }
289
290 fn check_dns(&self, host: &str) -> Result<(), SsrfViolation> {
292 let addr_str = format!("{}:80", host);
293 let addrs = addr_str
294 .to_socket_addrs()
295 .map_err(|e| SsrfViolation::DnsResolutionFailed {
296 host: host.to_string(),
297 reason: e.to_string(),
298 })?;
299
300 for addr in addrs {
301 self.check_ip(&addr.ip())?;
302 }
303
304 Ok(())
305 }
306}
307
308fn is_private_ip(ip: &IpAddr) -> bool {
314 match ip {
315 IpAddr::V4(v4) => {
316 v4.is_loopback()
317 || v4.is_private()
318 || v4.is_link_local()
319 || v4.is_broadcast()
320 || v4.is_unspecified()
321 || *v4 == Ipv4Addr::new(169, 254, 169, 254)
322 }
323 IpAddr::V6(v6) => {
324 v6.is_loopback()
325 || v6.is_unspecified()
326 || is_ipv6_unique_local(v6)
327 || is_ipv6_link_local(v6)
328 }
329 }
330}
331
/// Returns `true` for addresses in `fc00::/7` (IPv6 unique-local).
fn is_ipv6_unique_local(v6: &Ipv6Addr) -> bool {
    // Only the top seven bits of the first 16-bit group take part in the match.
    let leading_group = v6.segments()[0];
    leading_group & 0xFE00 == 0xFC00
}
337
/// Returns `true` for addresses in `fe80::/10` (IPv6 link-local).
fn is_ipv6_link_local(v6: &Ipv6Addr) -> bool {
    // Only the top ten bits of the first 16-bit group take part in the match.
    let leading_group = v6.segments()[0];
    leading_group & 0xFFC0 == 0xFE80
}
343
344fn extract_scheme(url: &str) -> Result<String, SsrfViolation> {
346 if let Some(idx) = url.find("://") {
347 Ok(url[..idx].to_string())
348 } else {
349 Err(SsrfViolation::InvalidUrl {
350 reason: "missing scheme (no :// found)".into(),
351 })
352 }
353}
354
355fn extract_host(url: &str) -> Result<String, SsrfViolation> {
357 let after_scheme =
358 url.find("://")
359 .map(|i| &url[i + 3..])
360 .ok_or_else(|| SsrfViolation::InvalidUrl {
361 reason: "missing scheme".into(),
362 })?;
363
364 let after_userinfo = if let Some(at) = after_scheme.find('@') {
366 &after_scheme[at + 1..]
367 } else {
368 after_scheme
369 };
370
371 if after_userinfo.starts_with('[') {
373 if let Some(end) = after_userinfo.find(']') {
374 return Ok(after_userinfo[1..end].to_string());
375 }
376 return Err(SsrfViolation::InvalidUrl {
377 reason: "unclosed bracket in IPv6 address".into(),
378 });
379 }
380
381 let host = after_userinfo
383 .split([':', '/', '?', '#'])
384 .next()
385 .unwrap_or("");
386
387 if host.is_empty() {
388 return Err(SsrfViolation::InvalidUrl {
389 reason: "empty hostname".into(),
390 });
391 }
392
393 Ok(host.to_string())
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Default protector with DNS lookups switched off, so tests never touch
    /// the network.
    fn protector_no_dns() -> SsrfProtector {
        let mut protector = SsrfProtector::new();
        protector.set_dns_check(false);
        protector
    }

    /// Validates `url` with the offline default protector.
    fn verdict(url: &str) -> Result<(), SsrfViolation> {
        protector_no_dns().validate_url(url)
    }

    #[test]
    fn test_blocks_localhost() {
        let outcome = verdict("http://localhost/admin");
        assert!(matches!(outcome, Err(SsrfViolation::BlockedHost { .. })));
    }

    #[test]
    fn test_blocks_127_0_0_1() {
        let violation = verdict("http://127.0.0.1/admin")
            .expect_err("loopback address must be rejected");
        match violation {
            SsrfViolation::PrivateIp { ip } | SsrfViolation::BlockedIp { ip, .. } => {
                assert!(ip.starts_with("127."));
            }
            other => panic!("expected PrivateIp or BlockedIp, got {:?}", other),
        }
    }

    #[test]
    fn test_blocks_10_x_private_range() {
        assert!(verdict("http://10.0.0.1/internal").is_err());
    }

    #[test]
    fn test_blocks_172_16_private_range() {
        assert!(verdict("http://172.16.0.1/secret").is_err());
    }

    #[test]
    fn test_blocks_192_168_private_range() {
        assert!(verdict("http://192.168.1.1/router").is_err());
    }

    #[test]
    fn test_blocks_link_local() {
        assert!(verdict("http://169.254.169.254/latest/meta-data/").is_err());
    }

    #[test]
    fn test_blocks_ipv6_localhost() {
        assert!(verdict("http://[::1]/admin").is_err());
    }

    #[test]
    fn test_blocks_file_scheme() {
        let outcome = verdict("file:///etc/passwd");
        assert!(matches!(outcome, Err(SsrfViolation::BlockedScheme { .. })));
    }

    #[test]
    fn test_blocks_ftp_scheme() {
        let outcome = verdict("ftp://internal-server/data");
        assert!(matches!(outcome, Err(SsrfViolation::BlockedScheme { .. })));
    }

    #[test]
    fn test_blocks_gopher_scheme() {
        let outcome = verdict("gopher://evil.com/1");
        assert!(matches!(outcome, Err(SsrfViolation::BlockedScheme { .. })));
    }

    #[test]
    fn test_allows_public_url() {
        assert!(verdict("https://example.com/api").is_ok());
    }

    #[test]
    fn test_allows_explicit_allowed_host() {
        let mut protector = protector_no_dns();
        protector.allow_host("internal.mycompany.com");
        assert!(protector
            .validate_url("http://internal.mycompany.com/api")
            .is_ok());
    }

    #[test]
    fn test_blocks_custom_pattern() {
        let mut protector = protector_no_dns();
        protector.add_blocked_pattern("aws_metadata", r"169\.254\.169\.254");
        assert!(protector
            .validate_url("http://169.254.169.254/latest/")
            .is_err());
    }

    #[test]
    fn test_blocks_metadata_google_internal() {
        assert!(verdict("http://metadata.google.internal/computeMetadata/v1/").is_err());
    }

    #[test]
    fn test_allows_public_ip() {
        assert!(verdict("http://8.8.8.8/dns").is_ok());
    }

    #[test]
    fn test_invalid_url_no_scheme() {
        let outcome = verdict("just-a-hostname");
        assert!(matches!(outcome, Err(SsrfViolation::InvalidUrl { .. })));
    }

    #[test]
    fn test_cidr_range_contains() {
        let range = parse_cidr("10.0.0.0/8").unwrap();
        let inside = [Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 255, 255, 255)];
        for addr in inside {
            assert!(range.contains(&IpAddr::V4(addr)));
        }
        assert!(!range.contains(&IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1))));
    }

    #[test]
    fn test_ipv6_unique_local_blocked() {
        assert!(verdict("http://[fd00::1]/internal").is_err());
    }

    #[test]
    fn test_url_with_port() {
        assert!(verdict("http://192.168.1.1:8080/api").is_err());
    }

    #[test]
    fn test_url_with_userinfo() {
        assert!(verdict("http://admin:pass@10.0.0.1/secret").is_err());
    }
}