Skip to main content

reqwest/dns/
resolve.rs

1use hyper_util::client::legacy::connect::dns::Name as HyperName;
2use tower_service::Service;
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use crate::error::BoxError;
13
/// Alias for an `Iterator` trait object over `SocketAddr`.
///
/// This is the boxed, `Send` iterator of addresses produced by a
/// [`Resolve`] implementation.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send>;

/// Alias for the `Future` type returned by a DNS resolver.
///
/// A pinned, boxed, `Send` future that yields either the resolved [`Addrs`]
/// or a boxed error. It is the return type of [`Resolve::resolve`].
pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>;
19
/// Trait for customizing DNS resolution in reqwest.
///
/// Implementors must be `Send + Sync`; within this crate they are stored
/// behind an `Arc<dyn Resolve>`.
pub trait Resolve: Send + Sync {
    /// Performs DNS resolution on a `Name`.
    /// The return type is a future containing an iterator of `SocketAddr`.
    ///
    /// It differs from `tower_service::Service<Name>` in several ways:
    ///  * It is assumed that `resolve` will always be ready to poll.
    ///  * It does not need a mutable reference to `self`.
    ///  * Since trait objects cannot make use of associated types, it requires
    ///    wrapping the returned `Future` and its contained `Iterator` with `Box`.
    ///
    /// An explicitly specified port in the URL overrides any port in the resolved `SocketAddr`s.
    /// Otherwise, port `0` is replaced by the conventional port for the given scheme (e.g. 80 for http).
    fn resolve(&self, name: Name) -> Resolving;
}
35
/// A name that must be resolved to addresses.
///
/// Thin wrapper around hyper-util's DNS name type. Build one with
/// [`str::parse`] (via the [`FromStr`] impl) and read it back with
/// [`Name::as_str`].
#[derive(Debug, Clone)]
pub struct Name(pub(super) HyperName);
39
/// A more general trait implemented for types implementing `Resolve`.
///
/// Besides any `R: Resolve`, it is also implemented for `Arc<R>`,
/// `Arc<dyn Resolve>`, and `Vec`s of resolvers.
///
/// Unnameable, only exported to aid seeing what implements this.
pub trait IntoResolve {
    /// Converts `self` into a shared, type-erased resolver.
    #[doc(hidden)]
    fn into_resolve(self) -> Arc<dyn Resolve>;
}
47
48impl Name {
49    /// View the name as a string.
50    pub fn as_str(&self) -> &str {
51        self.0.as_str()
52    }
53}
54
55impl FromStr for Name {
56    type Err = sealed::InvalidNameError;
57
58    fn from_str(host: &str) -> Result<Self, Self::Err> {
59        HyperName::from_str(host)
60            .map(Name)
61            .map_err(|_| sealed::InvalidNameError { _ext: () })
62    }
63}
64
/// Type-erased resolver handle used within the crate.
///
/// Cloning is cheap because only the inner `Arc` is cloned. The type also
/// implements `tower_service::Service<HyperName>`, so it can be used where a
/// hyper-util style DNS service is expected.
#[derive(Clone)]
pub(crate) struct DynResolver {
    // Shared, type-erased resolver that every lookup is forwarded to.
    resolver: Arc<dyn Resolve>,
}
69
70impl DynResolver {
71    pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self {
72        Self { resolver }
73    }
74
75    #[cfg(feature = "socks")]
76    pub(crate) fn gai() -> Self {
77        Self::new(Arc::new(super::gai::GaiResolver::new()))
78    }
79
80    /// Resolve an HTTP host and port, not just a domain name.
81    ///
82    /// This does the same thing that hyper-util's HttpConnector does, before
83    /// calling out to its underlying DNS resolver.
84    #[cfg(feature = "socks")]
85    pub(crate) async fn http_resolve(
86        &self,
87        target: &http::Uri,
88    ) -> Result<impl Iterator<Item = std::net::SocketAddr>, BoxError> {
89        let host = target.host().ok_or("missing host")?;
90        let port = target
91            .port_u16()
92            .unwrap_or_else(|| match target.scheme_str() {
93                Some("https") => 443,
94                Some("socks4") | Some("socks4a") | Some("socks5") | Some("socks5h") => 1080,
95                _ => 80,
96            });
97
98        let explicit_port = target.port().is_some();
99
100        let addrs = self.resolver.resolve(host.parse()?).await?;
101
102        Ok(addrs.map(move |mut addr| {
103            if explicit_port || addr.port() == 0 {
104                addr.set_port(port);
105            }
106            addr
107        }))
108    }
109}
110
111impl Service<HyperName> for DynResolver {
112    type Response = Addrs;
113    type Error = BoxError;
114    type Future = Resolving;
115
116    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        Poll::Ready(Ok(()))
118    }
119
120    fn call(&mut self, name: HyperName) -> Self::Future {
121        self.resolver.resolve(Name(name))
122    }
123}
124
/// Resolver that answers from a fixed table of per-host addresses and falls
/// back to a wrapped resolver for every other name.
pub(crate) struct DnsResolverWithOverrides {
    // Resolver consulted when a name has no entry in `overrides`.
    dns_resolver: Arc<dyn Resolve>,
    // Host string, matched exactly against `Name::as_str`, mapped to the
    // addresses to return. NOTE(review): lookups are case-sensitive; confirm
    // callers normalize host case when building this table.
    overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}
129
130impl DnsResolverWithOverrides {
131    pub(crate) fn new(
132        dns_resolver: Arc<dyn Resolve>,
133        overrides: HashMap<String, Vec<SocketAddr>>,
134    ) -> Self {
135        DnsResolverWithOverrides {
136            dns_resolver,
137            overrides: Arc::new(overrides),
138        }
139    }
140}
141
142impl Resolve for DnsResolverWithOverrides {
143    fn resolve(&self, name: Name) -> Resolving {
144        match self.overrides.get(name.as_str()) {
145            Some(dest) => {
146                let addrs: Addrs = Box::new(dest.clone().into_iter());
147                Box::pin(std::future::ready(Ok(addrs)))
148            }
149            None => self.dns_resolver.resolve(name),
150        }
151    }
152}
153
impl IntoResolve for Arc<dyn Resolve> {
    /// Returns `self` unchanged, because it is already type-erased.
    fn into_resolve(self) -> Arc<dyn Resolve> {
        self
    }
}
159
160impl<R> IntoResolve for Arc<R>
161where
162    R: Resolve + 'static,
163{
164    fn into_resolve(self) -> Arc<dyn Resolve> {
165        self
166    }
167}
168
169impl<R> IntoResolve for R
170where
171    R: Resolve + 'static,
172{
173    fn into_resolve(self) -> Arc<dyn Resolve> {
174        Arc::new(self)
175    }
176}
177
/// Chains multiple resolvers and consults them in order.
///
/// Built by the `IntoResolve` impls for `Vec`s whose length is not exactly
/// one. The precise fallback rule is implemented in its `Resolve` impl.
struct ChainedResolver {
    // Resolvers in priority order; earlier entries are tried first.
    resolvers: Vec<Arc<dyn Resolve>>,
}
182
183impl Resolve for ChainedResolver {
184    fn resolve(&self, name: Name) -> Resolving {
185        let resolvers = self.resolvers.clone();
186        Box::pin(async move {
187            let mut last_err = None;
188            for resolver in &resolvers {
189                match resolver.resolve(name.clone()).await {
190                    Ok(addrs) => return Ok(addrs),
191                    Err(e) => last_err = Some(e),
192                }
193            }
194            Err(last_err.unwrap_or_else(|| "all DNS resolvers failed".into()))
195        })
196    }
197}
198
199impl<R: Resolve + 'static> IntoResolve for Vec<R> {
200    fn into_resolve(self) -> Arc<dyn Resolve> {
201        if self.len() == 1 {
202            return Arc::new(self.into_iter().next().unwrap());
203        }
204        Arc::new(ChainedResolver {
205            resolvers: self.into_iter().map(|r| Arc::new(r) as Arc<dyn Resolve>).collect(),
206        })
207    }
208}
209
210impl IntoResolve for Vec<Arc<dyn Resolve>> {
211    fn into_resolve(self) -> Arc<dyn Resolve> {
212        if self.len() == 1 {
213            return self.into_iter().next().unwrap();
214        }
215        Arc::new(ChainedResolver { resolvers: self })
216    }
217}
218
mod sealed {
    use std::fmt;

    /// Error returned when a string cannot be parsed into a DNS `Name`.
    ///
    /// The `_ext` field is only visible to the parent module, so code outside
    /// that module cannot construct this type directly.
    #[derive(Debug)]
    pub struct InvalidNameError {
        pub(super) _ext: (),
    }

    impl fmt::Display for InvalidNameError {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(formatter, "invalid DNS name")
        }
    }

    impl std::error::Error for InvalidNameError {}
}