1use futures::{prelude::*, channel::oneshot, future::BoxFuture};
37use libp2p_core::{
38 Transport,
39 multiaddr::{Protocol, Multiaddr},
40 transport::{TransportError, ListenerEvent}
41};
42use log::{error, debug, trace};
43use std::{error, fmt, io, net::ToSocketAddrs};
44
45#[derive(Clone)]
53pub struct DnsConfig<T> {
54 inner: T,
56 thread_pool: futures::executor::ThreadPool,
58}
59
60impl<T> DnsConfig<T> {
61 pub fn new(inner: T) -> Result<DnsConfig<T>, io::Error> {
63 DnsConfig::with_resolve_threads(inner, 1)
64 }
65
66 pub fn with_resolve_threads(inner: T, num_threads: usize) -> Result<DnsConfig<T>, io::Error> {
68 let thread_pool = futures::executor::ThreadPool::builder()
69 .pool_size(num_threads)
70 .name_prefix("libp2p-dns-")
71 .create()?;
72
73 trace!("Created a DNS thread pool");
74
75 Ok(DnsConfig {
76 inner,
77 thread_pool,
78 })
79 }
80}
81
82impl<T> fmt::Debug for DnsConfig<T>
83where
84 T: fmt::Debug,
85{
86 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
87 fmt.debug_tuple("DnsConfig").field(&self.inner).finish()
88 }
89}
90
91impl<T> Transport for DnsConfig<T>
92where
93 T: Transport + Send + 'static,
94 T::Error: Send,
95 T::Dial: Send
96{
97 type Output = T::Output;
98 type Error = DnsErr<T::Error>;
99 type Listener = stream::MapErr<
100 stream::MapOk<T::Listener,
101 fn(ListenerEvent<T::ListenerUpgrade, T::Error>) -> ListenerEvent<Self::ListenerUpgrade, Self::Error>>,
102 fn(T::Error) -> Self::Error>;
103 type ListenerUpgrade = future::MapErr<T::ListenerUpgrade, fn(T::Error) -> Self::Error>;
104 type Dial = future::Either<
105 future::MapErr<T::Dial, fn(T::Error) -> Self::Error>,
106 BoxFuture<'static, Result<Self::Output, Self::Error>>
107 >;
108
109 fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
110 let listener = self.inner.listen_on(addr).map_err(|err| err.map(DnsErr::Underlying))?;
111 let listener = listener
112 .map_ok::<_, fn(_) -> _>(|event| {
113 event
114 .map(|upgr| {
115 upgr.map_err::<_, fn(_) -> _>(DnsErr::Underlying)
116 })
117 .map_err(DnsErr::Underlying)
118 })
119 .map_err::<_, fn(_) -> _>(DnsErr::Underlying);
120 Ok(listener)
121 }
122
123 fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
124 let contains_dns = addr.iter().any(|cmp| match cmp {
127 Protocol::Dns(_) => true,
128 Protocol::Dns4(_) => true,
129 Protocol::Dns6(_) => true,
130 _ => false,
131 });
132
133 if !contains_dns {
134 trace!("Pass-through address without DNS: {}", addr);
135 let inner_dial = self.inner.dial(addr)
136 .map_err(|err| err.map(DnsErr::Underlying))?;
137 return Ok(inner_dial.map_err::<_, fn(_) -> _>(DnsErr::Underlying).left_future());
138 }
139
140 trace!("Dialing address with DNS: {}", addr);
141 let resolve_futs = addr.iter()
142 .map(|cmp| match cmp {
143 Protocol::Dns(ref name) | Protocol::Dns4(ref name) | Protocol::Dns6(ref name) => {
144 let name = name.to_string();
145 let to_resolve = format!("{}:0", name);
146 let (tx, rx) = oneshot::channel();
147 self.thread_pool.spawn_ok(async {
148 let to_resolve = to_resolve;
149 let _ = tx.send(match to_resolve[..].to_socket_addrs() {
150 Ok(list) => Ok(list.map(|s| s.ip()).collect::<Vec<_>>()),
151 Err(e) => Err(e),
152 });
153 });
154
155 let (dns4, dns6) = match cmp {
156 Protocol::Dns(_) => (true, true),
157 Protocol::Dns4(_) => (true, false),
158 Protocol::Dns6(_) => (false, true),
159 _ => unreachable!(),
160 };
161
162 async move {
163 let list = rx.await
164 .map_err(|_| {
165 error!("DNS resolver crashed");
166 DnsErr::ResolveFail(name.clone())
167 })?
168 .map_err(|err| DnsErr::ResolveError {
169 domain_name: name.clone(),
170 error: err,
171 })?;
172
173 list.into_iter()
174 .filter_map(|addr| {
175 if (dns4 && addr.is_ipv4()) || (dns6 && addr.is_ipv6()) {
176 Some(Protocol::from(addr))
177 } else {
178 None
179 }
180 })
181 .next()
182 .ok_or_else(|| DnsErr::ResolveFail(name))
183 }.left_future()
184 },
185 cmp => future::ready(Ok(cmp.acquire())).right_future()
186 })
187 .collect::<stream::FuturesOrdered<_>>();
188
189 let future = resolve_futs.collect::<Vec<_>>()
190 .then(move |outcome| async move {
191 let outcome = outcome.into_iter().collect::<Result<Vec<_>, _>>()?;
192 let outcome = outcome.into_iter().collect::<Multiaddr>();
193 debug!("DNS resolution outcome: {} => {}", addr, outcome);
194
195 match self.inner.dial(outcome) {
196 Ok(d) => d.await.map_err(DnsErr::Underlying),
197 Err(TransportError::MultiaddrNotSupported(_addr)) =>
198 Err(DnsErr::MultiaddrNotSupported),
199 Err(TransportError::Other(err)) => Err(DnsErr::Underlying(err))
200 }
201 });
202
203 Ok(future.boxed().right_future())
204 }
205
206 fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
207 self.inner.address_translation(server, observed)
208 }
209}
210
/// Error produced by the `DnsConfig` transport.
#[derive(Debug)]
pub enum DnsErr<TErr> {
    /// Error reported by the wrapped transport.
    Underlying(TErr),
    /// The name did not yield an address of an accepted family, or the
    /// resolver task went away before answering. Holds the domain name.
    ResolveFail(String),
    /// The blocking lookup (`ToSocketAddrs`) itself returned an I/O error.
    ResolveError {
        /// Domain name whose lookup failed.
        domain_name: String,
        /// Error returned by the lookup.
        error: io::Error,
    },
    /// The inner transport rejected the address obtained after resolution.
    MultiaddrNotSupported,
}
226
227impl<TErr> fmt::Display for DnsErr<TErr>
228where TErr: fmt::Display
229{
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 match self {
232 DnsErr::Underlying(err) => write!(f, "{}", err),
233 DnsErr::ResolveFail(addr) => write!(f, "Failed to resolve DNS address: {:?}", addr),
234 DnsErr::ResolveError { domain_name, error } => {
235 write!(f, "Failed to resolve DNS address: {:?}; {:?}", domain_name, error)
236 },
237 DnsErr::MultiaddrNotSupported => write!(f, "Resolve multiaddr not supported"),
238 }
239 }
240}
241
242impl<TErr> error::Error for DnsErr<TErr>
243where TErr: error::Error + 'static
244{
245 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
246 match self {
247 DnsErr::Underlying(err) => Some(err),
248 DnsErr::ResolveFail(_) => None,
249 DnsErr::ResolveError { error, .. } => Some(error),
250 DnsErr::MultiaddrNotSupported => None,
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::DnsConfig;
    use futures::{future::BoxFuture, prelude::*, stream::BoxStream};
    use libp2p_core::{
        Transport,
        multiaddr::{Protocol, Multiaddr},
        transport::ListenerEvent,
        transport::TransportError,
    };

    /// Dials `/dns4`, `/dns6` and plain `/ip4` addresses and checks that the
    /// inner transport is only ever handed `/ip4|/ip6` followed by `/tcp`.
    ///
    /// NOTE(review): the DNS cases look up `example.com` for real, so this test
    /// needs working name resolution (including an IPv6 answer for `/dns6`).
    #[test]
    fn basic_resolve() {
        /// Fake transport that validates the dialed address shape and succeeds at once.
        #[derive(Clone)]
        struct CustomTransport;

        impl Transport for CustomTransport {
            type Output = ();
            type Error = std::io::Error;
            type Listener = BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
            type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
            type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

            fn listen_on(self, _: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
                unreachable!()
            }

            fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
                let components: Vec<_> = addr.iter().collect();
                assert_eq!(components.len(), 2);
                match components[1] {
                    Protocol::Tcp(_) => {}
                    _ => panic!(),
                }
                match components[0] {
                    Protocol::Ip4(_) | Protocol::Ip6(_) => {}
                    _ => panic!(),
                }
                Ok(Box::pin(future::ready(Ok(()))))
            }

            fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option<Multiaddr> {
                None
            }
        }

        futures::executor::block_on(async move {
            let transport = DnsConfig::new(CustomTransport).unwrap();

            for dns_addr in &["/dns4/example.com/tcp/20000", "/dns6/example.com/tcp/20000"] {
                transport
                    .clone()
                    .dial(dns_addr.parse().unwrap())
                    .unwrap()
                    .await
                    .unwrap();
            }

            transport
                .dial("/ip4/1.2.3.4/tcp/20000".parse().unwrap())
                .unwrap()
                .await
                .unwrap();
        });
    }
}