1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
35
36use thiserror::Error;
37
38#[derive(Debug, Error)]
40pub enum SsrfError {
41 #[error("invalid URL: {0}")]
42 InvalidUrl(String),
43
44 #[error("URL scheme '{0}' not allowed — only http/https")]
45 DisallowedScheme(String),
46
47 #[error("URL has no host component")]
48 MissingHost,
49
50 #[error("DNS resolution failed for '{host}': {source}")]
51 DnsResolutionFailed {
52 host: String,
53 #[source]
54 source: std::io::Error,
55 },
56
57 #[error("DNS resolution returned no addresses for '{0}'")]
58 NoAddressesResolved(String),
59
60 #[error(
61 "target '{host}' resolves to {ip} which is in a blocked range ({reason}); \
62 pointing the cloud runner at internal addresses is not allowed"
63 )]
64 BlockedAddress {
65 host: String,
66 ip: IpAddr,
67 reason: &'static str,
68 },
69}
70
/// Knobs that relax the address checks.
///
/// The derived `Default` leaves every relaxation off, which is the same
/// configuration as [`Policy::strict`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Policy {
    /// When `true`, loopback addresses are accepted instead of rejected.
    pub allow_loopback: bool,
}

impl Policy {
    /// Policy with no relaxations: loopback targets are rejected.
    pub const fn strict() -> Self {
        Self { allow_loopback: false }
    }

    /// Policy that accepts loopback targets (named for use from tests).
    pub const fn for_test() -> Self {
        Self { allow_loopback: true }
    }
}
97
98pub async fn validate_target_url(url: &str, policy: Policy) -> Result<(), SsrfError> {
104 let parsed = url::Url::parse(url).map_err(|e| SsrfError::InvalidUrl(e.to_string()))?;
105
106 let scheme = parsed.scheme();
107 if scheme != "http" && scheme != "https" {
108 return Err(SsrfError::DisallowedScheme(scheme.to_string()));
109 }
110
111 let host = parsed.host_str().ok_or(SsrfError::MissingHost)?.to_string();
112 let port = parsed.port_or_known_default().unwrap_or(80);
113
114 if let Ok(ip) = host.parse::<IpAddr>() {
116 check_ip(&host, ip, policy)?;
117 return Ok(());
118 }
119
120 let lookup_target = format!("{}:{}", host, port);
125 let addrs: Vec<std::net::SocketAddr> = tokio::net::lookup_host(&lookup_target)
126 .await
127 .map_err(|source| SsrfError::DnsResolutionFailed {
128 host: host.clone(),
129 source,
130 })?
131 .collect();
132
133 if addrs.is_empty() {
134 return Err(SsrfError::NoAddressesResolved(host));
135 }
136
137 for addr in addrs {
138 check_ip(&host, addr.ip(), policy)?;
139 }
140
141 Ok(())
142}
143
144fn check_ip(host: &str, ip: IpAddr, policy: Policy) -> Result<(), SsrfError> {
145 if let Some(reason) = blocked_reason(ip, policy) {
146 return Err(SsrfError::BlockedAddress {
147 host: host.to_string(),
148 ip,
149 reason,
150 });
151 }
152 Ok(())
153}
154
155fn blocked_reason(ip: IpAddr, policy: Policy) -> Option<&'static str> {
156 match ip {
157 IpAddr::V4(v4) => blocked_reason_v4(v4, policy),
158 IpAddr::V6(v6) => blocked_reason_v6(v6, policy),
159 }
160}
161
162fn blocked_reason_v4(ip: Ipv4Addr, policy: Policy) -> Option<&'static str> {
163 if ip.is_loopback() {
164 if policy.allow_loopback {
165 return None;
166 }
167 return Some("IPv4 loopback (127.0.0.0/8)");
168 }
169 if ip.is_unspecified() {
170 return Some("IPv4 unspecified (0.0.0.0)");
171 }
172 if ip.is_broadcast() {
173 return Some("IPv4 broadcast");
174 }
175 if ip.is_link_local() {
176 return Some("IPv4 link-local (169.254.0.0/16, includes cloud metadata IP)");
177 }
178 if ip.is_private() {
179 return Some("IPv4 RFC1918 private (10/8, 172.16/12, 192.168/16)");
180 }
181 if ip.is_documentation() {
182 return Some("IPv4 documentation range (RFC5737)");
183 }
184 let octets = ip.octets();
188 if octets[0] == 100 && (64..=127).contains(&octets[1]) {
189 return Some("IPv4 CGNAT (100.64.0.0/10)");
190 }
191 if octets[0] == 198 && (octets[1] == 18 || octets[1] == 19) {
193 return Some("IPv4 benchmark (198.18.0.0/15)");
194 }
195 None
196}
197
198fn blocked_reason_v6(ip: Ipv6Addr, policy: Policy) -> Option<&'static str> {
199 if ip.is_loopback() {
200 if policy.allow_loopback {
201 return None;
202 }
203 return Some("IPv6 loopback (::1)");
204 }
205 if ip.is_unspecified() {
206 return Some("IPv6 unspecified (::)");
207 }
208 let segments = ip.segments();
209 if (segments[0] & 0xffc0) == 0xfe80 {
211 return Some("IPv6 link-local (fe80::/10)");
212 }
213 if (segments[0] & 0xfe00) == 0xfc00 {
215 return Some("IPv6 unique-local (fc00::/7)");
216 }
217 if let Some(v4) = ip.to_ipv4_mapped() {
220 return blocked_reason_v4(v4, policy);
221 }
222 None
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `addr` is blocked under `policy` and that the reported
    /// reason mentions `fragment`.
    fn assert_blocked(addr: &str, policy: Policy, fragment: &str) {
        let parsed: IpAddr = addr.parse().unwrap();
        match blocked_reason(parsed, policy) {
            Some(reason) => assert!(
                reason.contains(fragment),
                "{addr} blocked but reason '{reason}' missing fragment '{fragment}'"
            ),
            None => panic!("expected {addr} to be blocked"),
        }
    }

    /// Asserts that `addr` passes the classifier under `policy`.
    fn assert_allowed(addr: &str, policy: Policy) {
        let parsed: IpAddr = addr.parse().unwrap();
        let verdict = blocked_reason(parsed, policy);
        assert!(verdict.is_none(), "{addr} unexpectedly blocked");
    }

    // ---- IPv4 classification -------------------------------------------------

    #[test]
    fn blocks_loopback_v4_strict() {
        for addr in ["127.0.0.1", "127.255.255.254"] {
            assert_blocked(addr, Policy::strict(), "loopback");
        }
    }

    #[test]
    fn allows_loopback_v4_in_test_policy() {
        assert_allowed("127.0.0.1", Policy::for_test());
    }

    #[test]
    fn blocks_link_local_aws_metadata() {
        assert_blocked("169.254.169.254", Policy::strict(), "link-local");
    }

    #[test]
    fn blocks_rfc1918_ranges() {
        for addr in ["10.0.0.1", "172.16.0.1", "172.31.255.255", "192.168.0.1"] {
            assert_blocked(addr, Policy::strict(), "RFC1918");
        }
    }

    #[test]
    fn blocks_cgnat() {
        for addr in ["100.64.0.1", "100.127.255.255"] {
            assert_blocked(addr, Policy::strict(), "CGNAT");
        }
    }

    #[test]
    fn allows_ranges_outside_cgnat() {
        // One address just below and one just above 100.64.0.0/10.
        for addr in ["100.63.255.255", "100.128.0.1"] {
            assert_allowed(addr, Policy::strict());
        }
    }

    #[test]
    fn blocks_benchmark_range() {
        for addr in ["198.18.0.1", "198.19.255.255"] {
            assert_blocked(addr, Policy::strict(), "benchmark");
        }
    }

    #[test]
    fn allows_public_v4() {
        for addr in ["8.8.8.8", "1.1.1.1", "142.250.190.78"] {
            assert_allowed(addr, Policy::strict());
        }
    }

    // ---- IPv6 classification -------------------------------------------------

    #[test]
    fn blocks_loopback_v6_strict() {
        assert_blocked("::1", Policy::strict(), "loopback");
    }

    #[test]
    fn blocks_link_local_v6() {
        // First and last /16 inside fe80::/10.
        for addr in ["fe80::1", "febf::1"] {
            assert_blocked(addr, Policy::strict(), "link-local");
        }
    }

    #[test]
    fn blocks_ula() {
        for addr in ["fc00::1", "fd12:3456::1"] {
            assert_blocked(addr, Policy::strict(), "unique-local");
        }
    }

    #[test]
    fn blocks_ipv4_mapped_private() {
        let cases = [
            ("::ffff:10.0.0.1", "RFC1918"),
            ("::ffff:127.0.0.1", "loopback"),
        ];
        for (addr, fragment) in cases {
            assert_blocked(addr, Policy::strict(), fragment);
        }
    }

    #[test]
    fn allows_public_v6() {
        for addr in ["2606:4700:4700::1111", "2001:4860:4860::8888"] {
            assert_allowed(addr, Policy::strict());
        }
    }

    // ---- End-to-end URL validation (IP literals only, no DNS needed) ---------

    #[tokio::test]
    async fn validate_rejects_non_http_scheme() {
        let err = validate_target_url("file:///etc/passwd", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::DisallowedScheme(s) if s == "file"));
    }

    #[tokio::test]
    async fn validate_rejects_garbage_url() {
        let err = validate_target_url("not a url", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::InvalidUrl(_)));
    }

    #[tokio::test]
    async fn validate_rejects_literal_loopback() {
        let err = validate_target_url("http://127.0.0.1/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_literal_metadata_ip() {
        let target = "http://169.254.169.254/latest/meta-data/";
        let err = validate_target_url(target, Policy::strict())
            .await
            .unwrap_err();
        match err {
            SsrfError::BlockedAddress { reason, .. } => assert!(reason.contains("link-local")),
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_rejects_literal_rfc1918() {
        let err = validate_target_url("http://10.0.0.1/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_allows_loopback_in_test_policy() {
        validate_target_url("http://127.0.0.1:8080/", Policy::for_test())
            .await
            .unwrap();
    }
}