actix_tls/connect/
resolver.rs

1use std::{
2    future::Future,
3    io,
4    net::SocketAddr,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8    vec::IntoIter,
9};
10
11use actix_rt::task::{spawn_blocking, JoinHandle};
12use actix_service::{Service, ServiceFactory};
13use actix_utils::future::{ok, Ready};
14use futures_core::{future::LocalBoxFuture, ready};
15use tracing::trace;
16
17use super::{ConnectError, ConnectInfo, Host, Resolve};
18
19/// DNS resolver service factory.
20#[derive(Clone, Default)]
21pub struct Resolver {
22    resolver: ResolverService,
23}
24
25impl Resolver {
26    /// Constructs a new resolver factory with a custom resolver.
27    pub fn custom(resolver: impl Resolve + 'static) -> Self {
28        Self {
29            resolver: ResolverService::custom(resolver),
30        }
31    }
32
33    /// Returns a new resolver service.
34    pub fn service(&self) -> ResolverService {
35        self.resolver.clone()
36    }
37}
38
39impl<R: Host> ServiceFactory<ConnectInfo<R>> for Resolver {
40    type Response = ConnectInfo<R>;
41    type Error = ConnectError;
42    type Config = ();
43    type Service = ResolverService;
44    type InitError = ();
45    type Future = Ready<Result<Self::Service, Self::InitError>>;
46
47    fn new_service(&self, _: ()) -> Self::Future {
48        ok(self.resolver.clone())
49    }
50}
51
52#[derive(Clone)]
53enum ResolverKind {
54    /// Built-in DNS resolver.
55    ///
56    /// See [`std::net::ToSocketAddrs`] trait.
57    Default,
58
59    /// Custom, user-provided DNS resolver.
60    Custom(Rc<dyn Resolve>),
61}
62
63impl Default for ResolverKind {
64    fn default() -> Self {
65        Self::Default
66    }
67}
68
69/// DNS resolver service.
70#[derive(Clone, Default)]
71pub struct ResolverService {
72    kind: ResolverKind,
73}
74
75impl ResolverService {
76    /// Constructor for custom Resolve trait object and use it as resolver.
77    pub fn custom(resolver: impl Resolve + 'static) -> Self {
78        Self {
79            kind: ResolverKind::Custom(Rc::new(resolver)),
80        }
81    }
82
83    /// Resolve DNS with default resolver.
84    fn default_lookup<R: Host>(
85        req: &ConnectInfo<R>,
86    ) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> {
87        // reconstruct host; concatenate hostname and port together
88        let host = format!("{}:{}", req.hostname(), req.port());
89
90        // run blocking DNS lookup in thread pool since DNS lookups can take upwards of seconds on
91        // some platforms if conditions are poor and OS-level cache is not populated
92        spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host))
93    }
94}
95
96impl<R: Host> Service<ConnectInfo<R>> for ResolverService {
97    type Response = ConnectInfo<R>;
98    type Error = ConnectError;
99    type Future = ResolverFut<R>;
100
101    actix_service::always_ready!();
102
103    fn call(&self, req: ConnectInfo<R>) -> Self::Future {
104        if req.addr.is_resolved() {
105            // socket address(es) already resolved; return existing connection request
106            ResolverFut::Resolved(Some(req))
107        } else if let Ok(ip) = req.hostname().parse() {
108            // request hostname is valid ip address; add address to request and return
109            let addr = SocketAddr::new(ip, req.port());
110            let req = req.set_addr(Some(addr));
111            ResolverFut::Resolved(Some(req))
112        } else {
113            trace!("DNS resolver: resolving host {:?}", req.hostname());
114
115            match &self.kind {
116                ResolverKind::Default => {
117                    let fut = Self::default_lookup(&req);
118                    ResolverFut::LookUp(fut, Some(req))
119                }
120
121                ResolverKind::Custom(resolver) => {
122                    let resolver = Rc::clone(resolver);
123
124                    ResolverFut::LookupCustom(Box::pin(async move {
125                        let addrs = resolver
126                            .lookup(req.hostname(), req.port())
127                            .await
128                            .map_err(ConnectError::Resolver)?;
129
130                        let req = req.set_addrs(addrs);
131
132                        if req.addr.is_unresolved() {
133                            Err(ConnectError::NoRecords)
134                        } else {
135                            Ok(req)
136                        }
137                    }))
138                }
139            }
140        }
141    }
142}
143
144/// Future for resolver service.
145#[doc(hidden)]
146pub enum ResolverFut<R: Host> {
147    Resolved(Option<ConnectInfo<R>>),
148    LookUp(
149        JoinHandle<io::Result<IntoIter<SocketAddr>>>,
150        Option<ConnectInfo<R>>,
151    ),
152    LookupCustom(LocalBoxFuture<'static, Result<ConnectInfo<R>, ConnectError>>),
153}
154
155impl<R: Host> Future for ResolverFut<R> {
156    type Output = Result<ConnectInfo<R>, ConnectError>;
157
158    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        match self.get_mut() {
160            Self::Resolved(conn) => Poll::Ready(Ok(conn
161                .take()
162                .expect("ResolverFuture polled after finished"))),
163
164            Self::LookUp(fut, req) => {
165                let res = match ready!(Pin::new(fut).poll(cx)) {
166                    Ok(Ok(res)) => Ok(res),
167                    Ok(Err(err)) => Err(ConnectError::Resolver(Box::new(err))),
168                    Err(err) => Err(ConnectError::Io(err.into())),
169                };
170
171                let req = req.take().unwrap();
172
173                let addrs = res.map_err(|err| {
174                    trace!(
175                        "DNS resolver: failed to resolve host {:?} err: {:?}",
176                        req.hostname(),
177                        err
178                    );
179
180                    err
181                })?;
182
183                let req = req.set_addrs(addrs);
184
185                trace!(
186                    "DNS resolver: host {:?} resolved to {:?}",
187                    req.hostname(),
188                    req.addrs()
189                );
190
191                if req.addr.is_unresolved() {
192                    Poll::Ready(Err(ConnectError::NoRecords))
193                } else {
194                    Poll::Ready(Ok(req))
195                }
196            }
197
198            Self::LookupCustom(fut) => fut.as_mut().poll(cx),
199        }
200    }
201}