Skip to main content

harness_webfetch/
ssrf.rs

1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2use std::str::FromStr;
3
4use crate::types::WebFetchSessionConfig;
5
/// Category of special-purpose address that the SSRF guard refuses by default.
///
/// Each variant names a family of addresses; which concrete ranges map to
/// which variant is decided by the per-family classifiers in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockClass {
    /// Loopback addresses (e.g. `127.0.0.0/8`, `::1`).
    Loopback,
    /// Private / internal network ranges (e.g. RFC 1918 space).
    Private,
    /// Link-local addresses (e.g. `fe80::/10`).
    LinkLocal,
    /// Addresses associated with cloud metadata services.
    Metadata,
    /// Reserved or otherwise special-purpose addresses (e.g. `0.0.0.0/8`).
    Reserved,
}
14
15impl BlockClass {
16    pub fn as_str(&self) -> &'static str {
17        match self {
18            Self::Loopback => "loopback",
19            Self::Private => "private",
20            Self::LinkLocal => "link-local",
21            Self::Metadata => "metadata",
22            Self::Reserved => "reserved",
23        }
24    }
25}
26
/// Outcome of an SSRF check for a host.
#[derive(Clone, Debug)]
pub enum SsrfDecision {
    /// The host passed the check.
    Allowed,
    /// The host was refused.
    Blocked {
        /// Why the host was refused.
        reason: String,
        /// Suggestion for what the caller can do about it.
        hint: String,
    },
}
32
33/// Resolve the host and run the IP classifier against each result.
34/// Returns `Blocked` if any resolved IP lands in a range not opted into.
35pub async fn classify_host(host: &str, session: &WebFetchSessionConfig) -> SsrfDecision {
36    let addresses = match resolve_host(host).await {
37        Ok(addrs) => addrs,
38        Err(e) => {
39            return SsrfDecision::Blocked {
40                reason: format!("DNS resolution failed: {}", e),
41                hint: "Check that the hostname is reachable and correct.".to_string(),
42            };
43        }
44    };
45    if addresses.is_empty() {
46        return SsrfDecision::Blocked {
47            reason: "Hostname did not resolve to any address.".to_string(),
48            hint: "Check DNS or try a different host.".to_string(),
49        };
50    }
51    for addr in &addresses {
52        if let Some(class) = classify_ip(*addr) {
53            if !opted_in(class, session) {
54                return SsrfDecision::Blocked {
55                    reason: format!(
56                        "Host resolved to blocked IP range: {} ({})",
57                        addr,
58                        class.as_str()
59                    ),
60                    hint: hint_for(class).to_string(),
61                };
62            }
63        }
64    }
65    SsrfDecision::Allowed
66}
67
68/// Synchronous helper — reads the input string as an IP if possible,
69/// classifies. Used by tests + by the host-resolver path (each resolved
70/// address is fed in here).
71pub fn classify_ip(addr: IpAddr) -> Option<BlockClass> {
72    match addr {
73        IpAddr::V4(v4) => classify_v4(v4),
74        IpAddr::V6(v6) => classify_v6(v6),
75    }
76}
77
78fn classify_v4(addr: Ipv4Addr) -> Option<BlockClass> {
79    let [a, b, _, _] = addr.octets();
80    // Loopback 127.0.0.0/8
81    if a == 127 {
82        return Some(BlockClass::Loopback);
83    }
84    // Link-local / metadata 169.254.0.0/16
85    if a == 169 && b == 254 {
86        return Some(BlockClass::Metadata);
87    }
88    // RFC 1918 private
89    if a == 10 {
90        return Some(BlockClass::Private);
91    }
92    if a == 172 && (16..=31).contains(&b) {
93        return Some(BlockClass::Private);
94    }
95    if a == 192 && b == 168 {
96        return Some(BlockClass::Private);
97    }
98    // 0.0.0.0/8 "this network"
99    if a == 0 {
100        return Some(BlockClass::Reserved);
101    }
102    if addr == Ipv4Addr::BROADCAST {
103        return Some(BlockClass::Reserved);
104    }
105    // 100.64.0.0/10 CGNAT
106    if a == 100 && (64..=127).contains(&b) {
107        return Some(BlockClass::Private);
108    }
109    None
110}
111
112fn classify_v6(addr: Ipv6Addr) -> Option<BlockClass> {
113    if addr == Ipv6Addr::LOCALHOST {
114        return Some(BlockClass::Loopback);
115    }
116    if addr == Ipv6Addr::UNSPECIFIED {
117        return Some(BlockClass::Reserved);
118    }
119    let segments = addr.segments();
120    let first = segments[0];
121    // fe80::/10 link-local
122    if (first & 0xffc0) == 0xfe80 {
123        return Some(BlockClass::LinkLocal);
124    }
125    // fc00::/7 unique local
126    if (first & 0xfe00) == 0xfc00 {
127        return Some(BlockClass::Private);
128    }
129    // ::ffff:0:0/96 IPv4-mapped — classify the inner v4
130    if let Some(v4) = addr.to_ipv4_mapped() {
131        return classify_v4(v4);
132    }
133    None
134}
135
136fn opted_in(class: BlockClass, session: &WebFetchSessionConfig) -> bool {
137    match class {
138        BlockClass::Loopback => session.allow_loopback,
139        BlockClass::Private => session.allow_private_networks,
140        BlockClass::LinkLocal => session.allow_private_networks || session.allow_metadata,
141        BlockClass::Metadata => session.allow_metadata,
142        BlockClass::Reserved => false,
143    }
144}
145
146fn hint_for(class: BlockClass) -> &'static str {
147    match class {
148        BlockClass::Loopback => {
149            "Loopback is blocked by default. If you need localhost for a developer workload, the session must set allow_loopback: true."
150        }
151        BlockClass::Private => {
152            "Private IP ranges (RFC 1918) are blocked by default. Set session.allow_private_networks: true to enable."
153        }
154        BlockClass::LinkLocal => {
155            "Link-local addresses are blocked by default. Set session.allow_private_networks or session.allow_metadata as appropriate."
156        }
157        BlockClass::Metadata => {
158            "Cloud metadata endpoints (169.254.169.254) are blocked by default to prevent credential exfiltration. If this is intentional, set session.allow_metadata: true — but be aware of the security implications."
159        }
160        BlockClass::Reserved => {
161            "Reserved / special-purpose IP range (0.0.0.0/8, broadcast, etc.) — not a useful target."
162        }
163    }
164}
165
166/// Resolve a host to IPs via tokio's blocking lookup_host. Short-circuits
167/// if the host is already an IP literal.
168pub async fn resolve_host(host: &str) -> Result<Vec<IpAddr>, String> {
169    if let Ok(addr) = IpAddr::from_str(host) {
170        return Ok(vec![addr]);
171    }
172    // lookup_host takes "host:port" — use a dummy port.
173    let query = format!("{}:0", host);
174    let res = tokio::net::lookup_host(query).await;
175    match res {
176        Ok(iter) => Ok(iter.map(|sa| sa.ip()).collect()),
177        Err(e) => Err(e.to_string()),
178    }
179}