1use std::net::{IpAddr, SocketAddr};
24use std::sync::Arc;
25
26use reqwest::dns::{Addrs, Name, Resolve, Resolving};
27use tokio::net::lookup_host;
28
29use crate::fetcher::ssrf::{SsrfError, SsrfLevel, validate_addresses};
30
tokio::task_local! {
    /// SSRF policy level attached to the current tokio task.
    ///
    /// [`SsrfValidatingResolver`] reads this value with `try_with` while it
    /// resolves a host name. When a level has been installed (for example via
    /// `SSRF_LEVEL.scope(level, fut)`), every resolved address is checked with
    /// `validate_addresses` against that level. When no level is set, the
    /// resolver returns the resolved addresses without validating them.
    pub static SSRF_LEVEL: SsrfLevel;
}
39
/// Error produced by [`SsrfValidatingResolver`] when `validate_addresses`
/// rejects the addresses a host name resolved to.
///
/// The wrapped [`SsrfError`] is the rejection reported by the validator; it is
/// also reachable through [`std::error::Error::source`], which is what
/// [`dial_blocked_cause`] relies on when walking an error chain.
#[derive(Debug)]
pub struct DialBlocked(pub SsrfError);
46
47impl std::fmt::Display for DialBlocked {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(
50 f,
51 "ssrf policy blocked dial-time address resolution: {}",
52 self.0
53 )
54 }
55}
56
57impl std::error::Error for DialBlocked {
58 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
59 Some(&self.0)
60 }
61}
62
/// DNS resolver (implements reqwest's `Resolve`) that checks the addresses a
/// host name resolves to against the task-local [`SSRF_LEVEL`] before
/// returning them.
///
/// The type carries no state; all policy comes from the task-local level that
/// is active when the resolution future runs.
#[derive(Default)]
pub struct SsrfValidatingResolver;
67
68impl Resolve for SsrfValidatingResolver {
69 fn resolve(&self, name: Name) -> Resolving {
70 let host = name.as_str().to_string();
71 Box::pin(async move {
72 let target = format!("{host}:0");
75 let resolved: Vec<SocketAddr> = lookup_host(target.as_str())
76 .await
77 .map_err(Box::<dyn std::error::Error + Send + Sync>::from)?
78 .collect();
79
80 if let Ok(level) = SSRF_LEVEL.try_with(|l| *l) {
81 let ips: Vec<IpAddr> = resolved.iter().map(|s| s.ip()).collect();
82 if let Err(e) = validate_addresses(&ips, level) {
83 return Err(
84 Box::new(DialBlocked(e)) as Box<dyn std::error::Error + Send + Sync>
85 );
86 }
87 }
88
89 let iter: Addrs = Box::new(resolved.into_iter());
90 Ok(iter)
91 })
92 }
93}
94
95pub fn shared_resolver() -> Arc<SsrfValidatingResolver> {
97 Arc::new(SsrfValidatingResolver)
98}
99
100pub fn dial_blocked_cause<'a>(
107 err: &'a (dyn std::error::Error + 'static),
108) -> Option<&'a DialBlocked> {
109 let mut current: Option<&(dyn std::error::Error + 'static)> = Some(err);
110 while let Some(e) = current {
111 if let Some(blocked) = e.downcast_ref::<DialBlocked>() {
112 return Some(blocked);
113 }
114 current = e.source();
115 }
116 None
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Resolves `localhost` through the resolver while `level` is installed
    /// as the task-local [`SSRF_LEVEL`].
    async fn resolve_localhost_under(
        level: SsrfLevel,
    ) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
        let resolver = SsrfValidatingResolver;
        let hostname: Name = "localhost".parse().unwrap();
        SSRF_LEVEL
            .scope(level, async move { resolver.resolve(hostname).await })
            .await
    }

    /// With no task-local level set, resolution is not validated.
    #[tokio::test]
    async fn resolver_passes_through_when_no_context_set() {
        let resolver = SsrfValidatingResolver;
        let hostname: Name = "localhost".parse().unwrap();
        let outcome = resolver.resolve(hostname).await;
        assert!(outcome.is_ok());
    }

    /// Under `Strict`, resolving `localhost` must fail with `DialBlocked`.
    #[tokio::test]
    async fn resolver_blocks_loopback_under_strict() {
        let err = match resolve_localhost_under(SsrfLevel::Strict).await {
            Ok(_) => panic!("strict must reject loopback"),
            Err(err) => err,
        };
        assert!(
            dial_blocked_cause(&*err).is_some(),
            "expected DialBlocked in source chain, got: {err}",
        );
    }

    /// Under `Loopback`, resolving `localhost` must succeed.
    #[tokio::test]
    async fn resolver_allows_loopback_under_loopback_level() {
        let outcome = resolve_localhost_under(SsrfLevel::Loopback).await;
        assert!(outcome.is_ok(), "loopback level must accept localhost");
    }

    /// `dial_blocked_cause` finds a `DialBlocked` nested behind another error.
    #[test]
    fn dial_blocked_walks_source_chain() {
        // Minimal outer error whose only job is to point at its inner cause.
        #[derive(Debug)]
        struct Wrap(DialBlocked);

        impl std::fmt::Display for Wrap {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "wrap")
            }
        }

        impl std::error::Error for Wrap {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&self.0)
            }
        }

        let wrapped = Wrap(DialBlocked(SsrfError::Address {
            address: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            level: SsrfLevel::Strict,
            reason: "loopback IPv4",
        }));
        assert!(dial_blocked_cause(&wrapped).is_some());
    }
}