actix_tls/connect/
resolver.rs1use std::{
2 future::Future,
3 io,
4 net::SocketAddr,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8 vec::IntoIter,
9};
10
11use actix_rt::task::{spawn_blocking, JoinHandle};
12use actix_service::{Service, ServiceFactory};
13use actix_utils::future::{ok, Ready};
14use futures_core::{future::LocalBoxFuture, ready};
15use tracing::trace;
16
17use super::{ConnectError, ConnectInfo, Host, Resolve};
18
19#[derive(Clone, Default)]
21pub struct Resolver {
22 resolver: ResolverService,
23}
24
25impl Resolver {
26 pub fn custom(resolver: impl Resolve + 'static) -> Self {
28 Self {
29 resolver: ResolverService::custom(resolver),
30 }
31 }
32
33 pub fn service(&self) -> ResolverService {
35 self.resolver.clone()
36 }
37}
38
39impl<R: Host> ServiceFactory<ConnectInfo<R>> for Resolver {
40 type Response = ConnectInfo<R>;
41 type Error = ConnectError;
42 type Config = ();
43 type Service = ResolverService;
44 type InitError = ();
45 type Future = Ready<Result<Self::Service, Self::InitError>>;
46
47 fn new_service(&self, _: ()) -> Self::Future {
48 ok(self.resolver.clone())
49 }
50}
51
52#[derive(Clone)]
53enum ResolverKind {
54 Default,
58
59 Custom(Rc<dyn Resolve>),
61}
62
63impl Default for ResolverKind {
64 fn default() -> Self {
65 Self::Default
66 }
67}
68
69#[derive(Clone, Default)]
71pub struct ResolverService {
72 kind: ResolverKind,
73}
74
75impl ResolverService {
76 pub fn custom(resolver: impl Resolve + 'static) -> Self {
78 Self {
79 kind: ResolverKind::Custom(Rc::new(resolver)),
80 }
81 }
82
83 fn default_lookup<R: Host>(
85 req: &ConnectInfo<R>,
86 ) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> {
87 let host = format!("{}:{}", req.hostname(), req.port());
89
90 spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host))
93 }
94}
95
96impl<R: Host> Service<ConnectInfo<R>> for ResolverService {
97 type Response = ConnectInfo<R>;
98 type Error = ConnectError;
99 type Future = ResolverFut<R>;
100
101 actix_service::always_ready!();
102
103 fn call(&self, req: ConnectInfo<R>) -> Self::Future {
104 if req.addr.is_resolved() {
105 ResolverFut::Resolved(Some(req))
107 } else if let Ok(ip) = req.hostname().parse() {
108 let addr = SocketAddr::new(ip, req.port());
110 let req = req.set_addr(Some(addr));
111 ResolverFut::Resolved(Some(req))
112 } else {
113 trace!("DNS resolver: resolving host {:?}", req.hostname());
114
115 match &self.kind {
116 ResolverKind::Default => {
117 let fut = Self::default_lookup(&req);
118 ResolverFut::LookUp(fut, Some(req))
119 }
120
121 ResolverKind::Custom(resolver) => {
122 let resolver = Rc::clone(resolver);
123
124 ResolverFut::LookupCustom(Box::pin(async move {
125 let addrs = resolver
126 .lookup(req.hostname(), req.port())
127 .await
128 .map_err(ConnectError::Resolver)?;
129
130 let req = req.set_addrs(addrs);
131
132 if req.addr.is_unresolved() {
133 Err(ConnectError::NoRecords)
134 } else {
135 Ok(req)
136 }
137 }))
138 }
139 }
140 }
141 }
142}
143
144#[doc(hidden)]
146pub enum ResolverFut<R: Host> {
147 Resolved(Option<ConnectInfo<R>>),
148 LookUp(
149 JoinHandle<io::Result<IntoIter<SocketAddr>>>,
150 Option<ConnectInfo<R>>,
151 ),
152 LookupCustom(LocalBoxFuture<'static, Result<ConnectInfo<R>, ConnectError>>),
153}
154
155impl<R: Host> Future for ResolverFut<R> {
156 type Output = Result<ConnectInfo<R>, ConnectError>;
157
158 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 match self.get_mut() {
160 Self::Resolved(conn) => Poll::Ready(Ok(conn
161 .take()
162 .expect("ResolverFuture polled after finished"))),
163
164 Self::LookUp(fut, req) => {
165 let res = match ready!(Pin::new(fut).poll(cx)) {
166 Ok(Ok(res)) => Ok(res),
167 Ok(Err(err)) => Err(ConnectError::Resolver(Box::new(err))),
168 Err(err) => Err(ConnectError::Io(err.into())),
169 };
170
171 let req = req.take().unwrap();
172
173 let addrs = res.map_err(|err| {
174 trace!(
175 "DNS resolver: failed to resolve host {:?} err: {:?}",
176 req.hostname(),
177 err
178 );
179
180 err
181 })?;
182
183 let req = req.set_addrs(addrs);
184
185 trace!(
186 "DNS resolver: host {:?} resolved to {:?}",
187 req.hostname(),
188 req.addrs()
189 );
190
191 if req.addr.is_unresolved() {
192 Poll::Ready(Err(ConnectError::NoRecords))
193 } else {
194 Poll::Ready(Ok(req))
195 }
196 }
197
198 Self::LookupCustom(fut) => fut.as_mut().poll(cx),
199 }
200 }
201}