Skip to main content

punch_types/
ssrf.rs

1//! SSRF (Server-Side Request Forgery) protection engine.
2//!
3//! Guards the ring against fighters that try to reach out to internal
4//! network resources they have no business touching. The protector validates
5//! URLs before any HTTP request lands, blocking private IP ranges, dangerous
6//! schemes, and DNS rebinding attacks.
7
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
9
10use regex::Regex;
11use serde::{Deserialize, Serialize};
12
13// ---------------------------------------------------------------------------
14// SSRF violation
15// ---------------------------------------------------------------------------
16
/// Describes how a URL violated SSRF protection rules.
///
/// Every variant carries the offending value(s) as owned strings, so a
/// violation can be cloned, compared, and (de)serialized on its own.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SsrfViolation {
    /// The resolved IP address falls within a blocked CIDR range.
    ///
    /// `ip` is the offending address; `range` is the textual label of the
    /// CIDR block it matched (e.g., `10.0.0.0/8`).
    BlockedIp { ip: String, range: String },
    /// The URL scheme is not allowed (e.g., `file://`, `ftp://`).
    BlockedScheme { scheme: String },
    /// The hostname is explicitly blocked.
    BlockedHost { host: String },
    /// DNS resolution failed for the hostname.
    ///
    /// `reason` holds the textual form of the underlying resolver error.
    DnsResolutionFailed { host: String, reason: String },
    /// The resolved IP is in a private/reserved range.
    PrivateIp { ip: String },
    /// The URL matched a custom blocked pattern.
    ///
    /// `pattern` is the name the pattern was registered under, not the regex
    /// source; `url` is the full URL that matched.
    BlockedPattern { pattern: String, url: String },
    /// The URL could not be parsed.
    InvalidUrl { reason: String },
}
35
36impl std::fmt::Display for SsrfViolation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::BlockedIp { ip, range } => {
40                write!(f, "SSRF: IP {} falls within blocked range {}", ip, range)
41            }
42            Self::BlockedScheme { scheme } => {
43                write!(f, "SSRF: scheme '{}' is not allowed", scheme)
44            }
45            Self::BlockedHost { host } => {
46                write!(f, "SSRF: hostname '{}' is blocked", host)
47            }
48            Self::DnsResolutionFailed { host, reason } => {
49                write!(f, "SSRF: DNS resolution failed for '{}': {}", host, reason)
50            }
51            Self::PrivateIp { ip } => {
52                write!(f, "SSRF: resolved IP {} is in a private range", ip)
53            }
54            Self::BlockedPattern { pattern, url } => {
55                write!(
56                    f,
57                    "SSRF: URL '{}' matched blocked pattern '{}'",
58                    url, pattern
59                )
60            }
61            Self::InvalidUrl { reason } => {
62                write!(f, "SSRF: invalid URL: {}", reason)
63            }
64        }
65    }
66}
67
/// Marker impl so `SsrfViolation` can be used wherever a standard error is
/// expected (e.g., boxed as `Box<dyn std::error::Error>`).
impl std::error::Error for SsrfViolation {}
69
70// ---------------------------------------------------------------------------
71// CIDR range (simple implementation)
72// ---------------------------------------------------------------------------
73
/// An IP network in CIDR notation, used for address-membership tests.
#[derive(Debug, Clone)]
struct CidrRange {
    /// Textual form of the range (e.g., "127.0.0.0/8").
    label: String,
    /// Base address of the network.
    network: IpAddr,
    /// Number of leading bits an address must share with `network`.
    prefix_len: u8,
}

impl CidrRange {
    /// Returns `true` when `ip` lies inside this range.
    ///
    /// Addresses from a different family than `network` never match. A prefix
    /// length of zero matches every address of the family; a prefix length at
    /// or beyond the family's bit width requires an exact match.
    fn contains(&self, ip: &IpAddr) -> bool {
        match (self.network, *ip) {
            (IpAddr::V4(base), IpAddr::V4(candidate)) => {
                // Number of trailing (host) bits that are free to differ.
                let host_bits = 32u32.saturating_sub(u32::from(self.prefix_len));
                // A full-width shift (prefix 0) yields `None`, i.e. an empty mask.
                let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
                ((u32::from(base) ^ u32::from(candidate)) & mask) == 0
            }
            (IpAddr::V6(base), IpAddr::V6(candidate)) => {
                let host_bits = 128u32.saturating_sub(u32::from(self.prefix_len));
                let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
                ((u128::from(base) ^ u128::from(candidate)) & mask) == 0
            }
            // Mixed address families are never considered a match.
            _ => false,
        }
    }
}
116
117fn parse_cidr(s: &str) -> Option<CidrRange> {
118    let parts: Vec<&str> = s.split('/').collect();
119    if parts.len() != 2 {
120        return None;
121    }
122    let ip: IpAddr = parts[0].parse().ok()?;
123    let prefix_len: u8 = parts[1].parse().ok()?;
124    Some(CidrRange {
125        label: s.to_string(),
126        network: ip,
127        prefix_len,
128    })
129}
130
131// ---------------------------------------------------------------------------
132// SsrfProtector
133// ---------------------------------------------------------------------------
134
/// The SSRF protection engine — validates URLs before they leave the ring.
///
/// Blocks requests to private IP ranges, dangerous schemes, and specific
/// hostnames. Supports allow-listing for trusted internal hosts and custom
/// regex-based blocking patterns.
#[derive(Debug, Clone)]
pub struct SsrfProtector {
    /// CIDR ranges to block; an IP inside any of them is rejected.
    blocked_ranges: Vec<CidrRange>,
    /// Hostnames to block (stored lowercased).
    blocked_hosts: Vec<String>,
    /// URL schemes to block (compared against the lowercased URL scheme).
    blocked_schemes: Vec<String>,
    /// Hostnames explicitly allowed (bypass IP checks). The block list is
    /// consulted first, so a blocked host cannot be allow-listed.
    allowed_hosts: Vec<String>,
    /// Custom regex patterns to block, stored as `(name, compiled regex)`;
    /// the name is what gets reported when the pattern matches.
    blocked_patterns: Vec<(String, Regex)>,
    /// Whether to perform DNS resolution checks for hosts that are not
    /// literal IP addresses.
    dns_check_enabled: bool,
}
155
156impl Default for SsrfProtector {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl SsrfProtector {
163    /// Create a new protector with default blocked ranges and schemes.
164    pub fn new() -> Self {
165        let default_cidrs = [
166            "127.0.0.0/8",
167            "10.0.0.0/8",
168            "172.16.0.0/12",
169            "192.168.0.0/16",
170            "169.254.0.0/16",
171            "::1/128",
172            "fc00::/7",
173            "fe80::/10",
174        ];
175
176        let blocked_ranges: Vec<CidrRange> =
177            default_cidrs.iter().filter_map(|c| parse_cidr(c)).collect();
178
179        Self {
180            blocked_ranges,
181            blocked_hosts: vec![
182                "localhost".to_string(),
183                "metadata.google.internal".to_string(),
184                "169.254.169.254".to_string(),
185            ],
186            blocked_schemes: vec!["file".to_string(), "ftp".to_string(), "gopher".to_string()],
187            allowed_hosts: Vec::new(),
188            blocked_patterns: Vec::new(),
189            dns_check_enabled: true,
190        }
191    }
192
193    /// Add a hostname to the allow-list (bypasses IP range checks).
194    pub fn allow_host(&mut self, host: &str) {
195        self.allowed_hosts.push(host.to_lowercase());
196    }
197
198    /// Add a custom regex pattern to block.
199    pub fn add_blocked_pattern(&mut self, name: &str, pattern: &str) {
200        if let Ok(re) = Regex::new(pattern) {
201            self.blocked_patterns.push((name.to_string(), re));
202        }
203    }
204
205    /// Add a custom CIDR range to block.
206    pub fn add_blocked_range(&mut self, cidr: &str) {
207        if let Some(range) = parse_cidr(cidr) {
208            self.blocked_ranges.push(range);
209        }
210    }
211
212    /// Block an additional hostname.
213    pub fn block_host(&mut self, host: &str) {
214        self.blocked_hosts.push(host.to_lowercase());
215    }
216
217    /// Enable or disable DNS resolution checks.
218    pub fn set_dns_check(&mut self, enabled: bool) {
219        self.dns_check_enabled = enabled;
220    }
221
222    /// Validate a URL, returning `Ok(())` if it is safe to request.
223    pub fn validate_url(&self, url: &str) -> Result<(), SsrfViolation> {
224        // Check custom blocked patterns first.
225        for (name, re) in &self.blocked_patterns {
226            if re.is_match(url) {
227                return Err(SsrfViolation::BlockedPattern {
228                    pattern: name.clone(),
229                    url: url.to_string(),
230                });
231            }
232        }
233
234        // Parse scheme.
235        let scheme = extract_scheme(url)?;
236        if self.blocked_schemes.contains(&scheme.to_lowercase()) {
237            return Err(SsrfViolation::BlockedScheme { scheme });
238        }
239
240        // Parse host.
241        let host = extract_host(url)?;
242        let host_lower = host.to_lowercase();
243
244        // Check blocked hosts.
245        if self.blocked_hosts.contains(&host_lower) {
246            return Err(SsrfViolation::BlockedHost {
247                host: host.to_string(),
248            });
249        }
250
251        // If the host is explicitly allowed, skip IP checks.
252        if self.allowed_hosts.contains(&host_lower) {
253            return Ok(());
254        }
255
256        // Check if the host is a literal IP address.
257        if let Ok(ip) = host.parse::<IpAddr>() {
258            self.check_ip(&ip)?;
259            return Ok(());
260        }
261
262        // DNS resolution check.
263        if self.dns_check_enabled {
264            self.check_dns(&host)?;
265        }
266
267        Ok(())
268    }
269
270    /// Check a resolved IP against blocked ranges.
271    fn check_ip(&self, ip: &IpAddr) -> Result<(), SsrfViolation> {
272        // Check if IP is private/reserved.
273        if is_private_ip(ip) {
274            return Err(SsrfViolation::PrivateIp { ip: ip.to_string() });
275        }
276
277        // Check explicit CIDR blocks.
278        for range in &self.blocked_ranges {
279            if range.contains(ip) {
280                return Err(SsrfViolation::BlockedIp {
281                    ip: ip.to_string(),
282                    range: range.label.clone(),
283                });
284            }
285        }
286
287        Ok(())
288    }
289
290    /// Resolve the hostname and verify all resolved IPs are safe.
291    fn check_dns(&self, host: &str) -> Result<(), SsrfViolation> {
292        let addr_str = format!("{}:80", host);
293        let addrs = addr_str
294            .to_socket_addrs()
295            .map_err(|e| SsrfViolation::DnsResolutionFailed {
296                host: host.to_string(),
297                reason: e.to_string(),
298            })?;
299
300        for addr in addrs {
301            self.check_ip(&addr.ip())?;
302        }
303
304        Ok(())
305    }
306}
307
308// ---------------------------------------------------------------------------
309// Helper functions
310// ---------------------------------------------------------------------------
311
/// Check if an IP address is in a private/reserved range.
///
/// IPv4: loopback, RFC 1918 private, link-local, broadcast, and unspecified
/// addresses. IPv6: loopback, unspecified, unique-local (`fc00::/7`), and
/// link-local (`fe80::/10`) addresses. An IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) is judged by its embedded IPv4 address, so an internal
/// IPv4 target cannot be disguised as IPv6.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => is_private_ipv4(&embedded),
            None => {
                v6.is_loopback()
                    || v6.is_unspecified()
                    || is_ipv6_unique_local(v6)
                    || is_ipv6_link_local(v6)
            }
        },
    }
}

/// IPv4 half of [`is_private_ip`].
fn is_private_ipv4(v4: &Ipv4Addr) -> bool {
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_unspecified()
        // Already inside the link-local range above; spelled out because it
        // is the well-known cloud metadata endpoint.
        || *v4 == Ipv4Addr::new(169, 254, 169, 254)
}

/// Returns `true` for IPv6 unique-local addresses (`fc00::/7`).
fn is_ipv6_unique_local(v6: &Ipv6Addr) -> bool {
    // fc00::/7 — the top seven bits are 1111110.
    let first_byte = v6.octets()[0];
    (first_byte & 0xFE) == 0xFC
}

/// Returns `true` for IPv6 link-local addresses (`fe80::/10`).
fn is_ipv6_link_local(v6: &Ipv6Addr) -> bool {
    // fe80::/10 — first byte 0xFE, then the top two bits of the second are 10.
    let octets = v6.octets();
    octets[0] == 0xFE && (octets[1] & 0xC0) == 0x80
}
343
344/// Extract the scheme from a URL string.
345fn extract_scheme(url: &str) -> Result<String, SsrfViolation> {
346    if let Some(idx) = url.find("://") {
347        Ok(url[..idx].to_string())
348    } else {
349        Err(SsrfViolation::InvalidUrl {
350            reason: "missing scheme (no :// found)".into(),
351        })
352    }
353}
354
355/// Extract the hostname from a URL string.
356fn extract_host(url: &str) -> Result<String, SsrfViolation> {
357    let after_scheme =
358        url.find("://")
359            .map(|i| &url[i + 3..])
360            .ok_or_else(|| SsrfViolation::InvalidUrl {
361                reason: "missing scheme".into(),
362            })?;
363
364    // Strip userinfo (user:pass@).
365    let after_userinfo = if let Some(at) = after_scheme.find('@') {
366        &after_scheme[at + 1..]
367    } else {
368        after_scheme
369    };
370
371    // Handle IPv6 addresses in brackets.
372    if after_userinfo.starts_with('[') {
373        if let Some(end) = after_userinfo.find(']') {
374            return Ok(after_userinfo[1..end].to_string());
375        }
376        return Err(SsrfViolation::InvalidUrl {
377            reason: "unclosed bracket in IPv6 address".into(),
378        });
379    }
380
381    // Take everything before the first : or / or ? or #.
382    let host = after_userinfo
383        .split([':', '/', '?', '#'])
384        .next()
385        .unwrap_or("");
386
387    if host.is_empty() {
388        return Err(SsrfViolation::InvalidUrl {
389            reason: "empty hostname".into(),
390        });
391    }
392
393    Ok(host.to_string())
394}
395
396// ---------------------------------------------------------------------------
397// Tests
398// ---------------------------------------------------------------------------
399
#[cfg(test)]
mod tests {
    // Unit tests for the SSRF protector. Every test disables DNS resolution
    // (see `protector_no_dns`), so only the string-level and literal-IP rules
    // are exercised and no network access is needed.
    use super::*;

    /// Build a default protector with DNS resolution checks turned off.
    fn protector_no_dns() -> SsrfProtector {
        let mut p = SsrfProtector::new();
        p.set_dns_check(false);
        p
    }

    // --- Blocked hosts and literal IPs -------------------------------------

    #[test]
    fn test_blocks_localhost() {
        let p = protector_no_dns();
        let result = p.validate_url("http://localhost/admin");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SsrfViolation::BlockedHost { .. }
        ));
    }

    #[test]
    fn test_blocks_127_0_0_1() {
        let p = protector_no_dns();
        let result = p.validate_url("http://127.0.0.1/admin");
        assert!(result.is_err());
        // Either rule may fire for loopback; both report the offending IP.
        match result.unwrap_err() {
            SsrfViolation::PrivateIp { ip } | SsrfViolation::BlockedIp { ip, .. } => {
                assert!(ip.starts_with("127."));
            }
            other => panic!("expected PrivateIp or BlockedIp, got {:?}", other),
        }
    }

    #[test]
    fn test_blocks_10_x_private_range() {
        let p = protector_no_dns();
        let result = p.validate_url("http://10.0.0.1/internal");
        assert!(result.is_err());
    }

    #[test]
    fn test_blocks_172_16_private_range() {
        let p = protector_no_dns();
        let result = p.validate_url("http://172.16.0.1/secret");
        assert!(result.is_err());
    }

    #[test]
    fn test_blocks_192_168_private_range() {
        let p = protector_no_dns();
        let result = p.validate_url("http://192.168.1.1/router");
        assert!(result.is_err());
    }

    #[test]
    fn test_blocks_link_local() {
        let p = protector_no_dns();
        let result = p.validate_url("http://169.254.169.254/latest/meta-data/");
        assert!(result.is_err());
    }

    #[test]
    fn test_blocks_ipv6_localhost() {
        let p = protector_no_dns();
        let result = p.validate_url("http://[::1]/admin");
        assert!(result.is_err());
    }

    // --- Blocked schemes ---------------------------------------------------

    #[test]
    fn test_blocks_file_scheme() {
        let p = protector_no_dns();
        let result = p.validate_url("file:///etc/passwd");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SsrfViolation::BlockedScheme { .. }
        ));
    }

    #[test]
    fn test_blocks_ftp_scheme() {
        let p = protector_no_dns();
        let result = p.validate_url("ftp://internal-server/data");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SsrfViolation::BlockedScheme { .. }
        ));
    }

    #[test]
    fn test_blocks_gopher_scheme() {
        let p = protector_no_dns();
        let result = p.validate_url("gopher://evil.com/1");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SsrfViolation::BlockedScheme { .. }
        ));
    }

    // --- Allowed URLs, allow-list, and custom rules ------------------------

    #[test]
    fn test_allows_public_url() {
        let p = protector_no_dns();
        let result = p.validate_url("https://example.com/api");
        assert!(result.is_ok());
    }

    #[test]
    fn test_allows_explicit_allowed_host() {
        let mut p = protector_no_dns();
        p.allow_host("internal.mycompany.com");
        let result = p.validate_url("http://internal.mycompany.com/api");
        assert!(result.is_ok());
    }

    #[test]
    fn test_blocks_custom_pattern() {
        let mut p = protector_no_dns();
        p.add_blocked_pattern("aws_metadata", r"169\.254\.169\.254");
        let result = p.validate_url("http://169.254.169.254/latest/");
        assert!(result.is_err());
    }

    #[test]
    fn test_blocks_metadata_google_internal() {
        let p = protector_no_dns();
        let result = p.validate_url("http://metadata.google.internal/computeMetadata/v1/");
        assert!(result.is_err());
    }

    #[test]
    fn test_allows_public_ip() {
        let p = protector_no_dns();
        let result = p.validate_url("http://8.8.8.8/dns");
        assert!(result.is_ok());
    }

    // --- Parsing edge cases ------------------------------------------------

    #[test]
    fn test_invalid_url_no_scheme() {
        let p = protector_no_dns();
        let result = p.validate_url("just-a-hostname");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SsrfViolation::InvalidUrl { .. }
        ));
    }

    #[test]
    fn test_cidr_range_contains() {
        let range = parse_cidr("10.0.0.0/8").unwrap();
        assert!(range.contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(range.contains(&IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255))));
        assert!(!range.contains(&IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1))));
    }

    #[test]
    fn test_ipv6_unique_local_blocked() {
        let p = protector_no_dns();
        let result = p.validate_url("http://[fd00::1]/internal");
        assert!(result.is_err());
    }

    #[test]
    fn test_url_with_port() {
        let p = protector_no_dns();
        let result = p.validate_url("http://192.168.1.1:8080/api");
        assert!(result.is_err());
    }

    #[test]
    fn test_url_with_userinfo() {
        let p = protector_no_dns();
        let result = p.validate_url("http://admin:pass@10.0.0.1/secret");
        assert!(result.is_err());
    }
}