Skip to main content

rover/fetcher/
dns.rs

//! DNS resolution with dial-time SSRF enforcement.
//!
//! Closes the TOCTOU window between Rover's pre-flight `validate_addresses`
//! and reqwest's internal dial-time resolution (`docs/security.md` §"DNS
//! rebinding"). A malicious authoritative nameserver can return a public IP
//! to our pre-flight lookup and a loopback/RFC1918 address to reqwest's
//! later lookup; the connection then targets the unsafe IP even though only
//! the public IP was ever validated and the request was authorised.
//!
//! Fix: install a custom [`reqwest::dns::Resolve`] on the shared client that
//! re-runs the same address validator at the moment of dial. The per-request
//! `SsrfLevel` is carried via the [`SSRF_LEVEL`] task-local, populated by
//! [`crate::fetcher::fetch`] (and by any other module that needs validated
//! outbound DNS — see `robots.rs`, `headless/intercept.rs`).
//!
//! Requests issued without setting `SSRF_LEVEL` fall through to a plain
//! `tokio::net::lookup_host` with no policy check. This is intentional: the
//! resolver is shared by every consumer of the client (including test code
//! and the cloud captioner/summariser paths that don't go through our SSRF
//! gate), and silently rejecting their lookups would be surprising. Every
//! caller that should be policed must therefore set the task-local explicitly.

23use std::net::{IpAddr, SocketAddr};
24use std::sync::Arc;
25
26use reqwest::dns::{Addrs, Name, Resolve, Resolving};
27use tokio::net::lookup_host;
28
29use crate::fetcher::ssrf::{SsrfError, SsrfLevel, validate_addresses};
30
tokio::task_local! {
    /// Per-request SSRF level consulted by [`SsrfValidatingResolver`].
    ///
    /// Set via `SSRF_LEVEL.scope(level, fut).await` around any call into
    /// `reqwest::Client` that should be policed (so redirects to a new host
    /// inside a single request are also covered).
    ///
    /// When no scope is active, the resolver performs a plain lookup with no
    /// policy check. Per tokio's task-local semantics, the value is only
    /// visible to code polled inside the scoped future. Work moved onto a
    /// different task (e.g. via `tokio::spawn`) does not inherit it.
    pub static SSRF_LEVEL: SsrfLevel;
}
39
/// Wrapper error returned by the resolver when SSRF policy rejects an
/// address at dial time. Carried inside `reqwest::Error`'s source chain so
/// the retry classifier can promote it to a fatal failure rather than
/// looping on what looks like a transient connect error.
///
/// The wrapped [`SsrfError`] is the rejection reported by
/// `validate_addresses`. It is exposed both as the public tuple field and via
/// [`std::error::Error::source`].
#[derive(Debug)]
pub struct DialBlocked(pub SsrfError);
46
47impl std::fmt::Display for DialBlocked {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(
50            f,
51            "ssrf policy blocked dial-time address resolution: {}",
52            self.0
53        )
54    }
55}
56
57impl std::error::Error for DialBlocked {
58    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
59        Some(&self.0)
60    }
61}
62
/// A stateless `reqwest::dns::Resolve` implementation that enforces the SSRF
/// address policy on every resolution performed while an [`SSRF_LEVEL`]
/// scope is active.
///
/// Outside such a scope it behaves like a plain `tokio::net::lookup_host`
/// (see the module docs for why this fall-through is intentional).
#[derive(Default)]
pub struct SsrfValidatingResolver;
67
68impl Resolve for SsrfValidatingResolver {
69    fn resolve(&self, name: Name) -> Resolving {
70        let host = name.as_str().to_string();
71        Box::pin(async move {
72            // reqwest sets the real destination port on the SocketAddr after
73            // it gets back from the resolver, so port 0 here is fine.
74            let target = format!("{host}:0");
75            let resolved: Vec<SocketAddr> = lookup_host(target.as_str())
76                .await
77                .map_err(Box::<dyn std::error::Error + Send + Sync>::from)?
78                .collect();
79
80            if let Ok(level) = SSRF_LEVEL.try_with(|l| *l) {
81                let ips: Vec<IpAddr> = resolved.iter().map(|s| s.ip()).collect();
82                if let Err(e) = validate_addresses(&ips, level) {
83                    return Err(
84                        Box::new(DialBlocked(e)) as Box<dyn std::error::Error + Send + Sync>
85                    );
86                }
87            }
88
89            let iter: Addrs = Box::new(resolved.into_iter());
90            Ok(iter)
91        })
92    }
93}
94
95/// Convenience: an `Arc` wrapper suitable for `ClientBuilder::dns_resolver`.
96pub fn shared_resolver() -> Arc<SsrfValidatingResolver> {
97    Arc::new(SsrfValidatingResolver)
98}
99
100/// Walk a `reqwest::Error`'s source chain looking for a [`DialBlocked`].
101///
102/// reqwest wraps resolver errors in its own `reqwest::Error` (typically with
103/// `is_connect() == true`), so callers that want to distinguish "SSRF blocked
104/// the dial" from "the server is down" need to inspect the chain. Used by
105/// `retry.rs` to keep retries from re-trying a forbidden destination.
106pub fn dial_blocked_cause<'a>(
107    err: &'a (dyn std::error::Error + 'static),
108) -> Option<&'a DialBlocked> {
109    let mut current: Option<&(dyn std::error::Error + 'static)> = Some(err);
110    while let Some(e) = current {
111        if let Some(blocked) = e.downcast_ref::<DialBlocked>() {
112            return Some(blocked);
113        }
114        current = e.source();
115    }
116    None
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Host name looked up by every resolver test; it resolves locally on
    /// every platform the suite runs on.
    fn localhost() -> Name {
        "localhost".parse().unwrap()
    }

    #[tokio::test]
    async fn resolver_passes_through_when_no_context_set() {
        // Outside any SSRF_LEVEL scope the policy must not be consulted, so the
        // lookup succeeds regardless of which addresses `localhost` maps to.
        let resolver = SsrfValidatingResolver;
        let outcome = resolver.resolve(localhost()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn resolver_blocks_loopback_under_strict() {
        let resolver = SsrfValidatingResolver;
        let outcome = SSRF_LEVEL
            .scope(SsrfLevel::Strict, async {
                resolver.resolve(localhost()).await
            })
            .await;
        match outcome {
            Ok(_) => panic!("strict must reject loopback"),
            Err(err) => assert!(
                dial_blocked_cause(&*err).is_some(),
                "expected DialBlocked in source chain, got: {err}",
            ),
        }
    }

    #[tokio::test]
    async fn resolver_allows_loopback_under_loopback_level() {
        let resolver = SsrfValidatingResolver;
        let outcome = SSRF_LEVEL
            .scope(SsrfLevel::Loopback, async {
                resolver.resolve(localhost()).await
            })
            .await;
        assert!(outcome.is_ok(), "loopback level must accept localhost");
    }

    #[test]
    fn dial_blocked_walks_source_chain() {
        // Extra error layer whose only job is to expose a DialBlocked as its
        // source, so the chain walk has to go at least one level deep.
        #[derive(Debug)]
        struct Wrap(DialBlocked);

        impl std::fmt::Display for Wrap {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "wrap")
            }
        }

        impl std::error::Error for Wrap {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&self.0)
            }
        }

        let wrapped = Wrap(DialBlocked(SsrfError::Address {
            address: IpAddr::V4(Ipv4Addr::LOCALHOST),
            level: SsrfLevel::Strict,
            reason: "loopback IPv4",
        }));
        assert!(dial_blocked_cause(&wrapped).is_some());
    }
}