hickory_resolver/name_server/connection_provider.rs

// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25    feature = "dns-over-openssl",
26    not(feature = "dns-over-rustls"),
27    not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
35use hickory_proto::udp::QuicLocalAddr;
36#[cfg(feature = "dns-over-https")]
37use proto::h2::{HttpsClientConnect, HttpsClientStream};
38#[cfg(feature = "dns-over-h3")]
39use proto::h3::{H3ClientConnect, H3ClientStream};
40#[cfg(feature = "dns-over-quic")]
41use proto::quic::{QuicClientConnect, QuicClientStream};
42use proto::tcp::DnsTcpStream;
43use proto::udp::DnsUdpSocket;
44use proto::{
45    self,
46    error::ProtoError,
47    op::NoopMessageFinalizer,
48    tcp::TcpClientConnect,
49    tcp::TcpClientStream,
50    udp::UdpClientConnect,
51    udp::UdpClientStream,
52    xfer::{
53        DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
54        DnsMultiplexerConnect, DnsRequest, DnsResponse,
55    },
56    Time,
57};
58#[cfg(feature = "tokio-runtime")]
59use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
60
61use crate::error::ResolveError;
62
63/// RuntimeProvider defines which async runtime that handles IO and timers.
64pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
65    /// Handle to the executor;
66    type Handle: Clone + Send + Spawn + Sync + Unpin;
67
68    /// Timer
69    type Timer: Time + Send + Unpin;
70
71    #[cfg(not(any(feature = "dns-over-quic", feature = "dns-over-h3")))]
72    /// UdpSocket
73    type Udp: DnsUdpSocket + Send;
74    #[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
75    /// UdpSocket, where `QuicLocalAddr` is for `quinn` crate.
76    type Udp: DnsUdpSocket + QuicLocalAddr + Send;
77
78    /// TcpStream
79    type Tcp: DnsTcpStream;
80
81    /// Create a runtime handle
82    fn create_handle(&self) -> Self::Handle;
83
84    /// Create a TCP connection with custom configuration.
85    fn connect_tcp(
86        &self,
87        server_addr: SocketAddr,
88    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
89
90    /// Create a UDP socket bound to `local_addr`. The returned value should **not** be connected to `server_addr`.
91    /// *Notice: the future should be ready once returned at best effort. Otherwise UDP DNS may need much more retries.*
92    fn bind_udp(
93        &self,
94        local_addr: SocketAddr,
95        server_addr: SocketAddr,
96    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
97}
98
99/// Create `DnsHandle` with the help of `RuntimeProvider`.
100/// This trait is designed for customization.
101pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
102    /// The handle to the connect for sending DNS requests.
103    type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
104    /// Ths future is responsible for spawning any background tasks as necessary.
105    type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
106    /// Provider that handles the underlying I/O and timing.
107    type RuntimeProvider: RuntimeProvider;
108
109    /// Create a new connection.
110    fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
111        -> Self::FutureConn;
112}
113
114/// A type defines the Handle which can spawn future.
115pub trait Spawn {
116    /// Spawn a future in the background
117    fn spawn_bg<F>(&mut self, future: F)
118    where
119        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
120}
121
#[cfg(feature = "dns-over-tls")]
/// TLS session over the runtime TCP stream `S`, passed through the `iocompat` adapters
/// (`AsyncIoStdAsTokio` inside the TLS layer, `AsyncIoTokioAsStd` outside it).
type TlsOverTcp<S> = AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>;

#[cfg(feature = "dns-over-tls")]
/// Predefined DNS client stream carried over a `TlsOverTcp` connection.
type TlsClientStream<S> = TcpClientStream<TlsOverTcp<S>>;
126
127/// The variants of all supported connections for the Resolver
128#[allow(clippy::large_enum_variant, clippy::type_complexity)]
129pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
130    Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
131    Tcp(
132        DnsExchangeConnect<
133            DnsMultiplexerConnect<
134                TcpClientConnect<<R as RuntimeProvider>::Tcp>,
135                TcpClientStream<<R as RuntimeProvider>::Tcp>,
136                NoopMessageFinalizer,
137            >,
138            DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
139            R::Timer,
140        >,
141    ),
142    #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
143    Tls(
144        DnsExchangeConnect<
145            DnsMultiplexerConnect<
146                Pin<
147                    Box<
148                        dyn Future<
149                                Output = Result<
150                                    TlsClientStream<<R as RuntimeProvider>::Tcp>,
151                                    ProtoError,
152                                >,
153                            > + Send
154                            + 'static,
155                    >,
156                >,
157                TlsClientStream<<R as RuntimeProvider>::Tcp>,
158                NoopMessageFinalizer,
159            >,
160            DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
161            TokioTime,
162        >,
163    ),
164    #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
165    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
166    #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
167    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
168    #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
169    H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
170}
171
172/// Resolves to a new Connection
173#[must_use = "futures do nothing unless polled"]
174pub struct ConnectionFuture<R: RuntimeProvider> {
175    pub(crate) connect: ConnectionConnect<R>,
176    pub(crate) spawner: R::Handle,
177}
178
179impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
180    type Output = Result<GenericConnection, ResolveError>;
181
182    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183        Poll::Ready(Ok(match &mut self.connect {
184            ConnectionConnect::Udp(ref mut conn) => {
185                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
186                self.spawner.spawn_bg(bg);
187                GenericConnection(conn)
188            }
189            ConnectionConnect::Tcp(ref mut conn) => {
190                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
191                self.spawner.spawn_bg(bg);
192                GenericConnection(conn)
193            }
194            #[cfg(feature = "dns-over-tls")]
195            ConnectionConnect::Tls(ref mut conn) => {
196                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
197                self.spawner.spawn_bg(bg);
198                GenericConnection(conn)
199            }
200            #[cfg(feature = "dns-over-https")]
201            ConnectionConnect::Https(ref mut conn) => {
202                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
203                self.spawner.spawn_bg(bg);
204                GenericConnection(conn)
205            }
206            #[cfg(feature = "dns-over-quic")]
207            ConnectionConnect::Quic(ref mut conn) => {
208                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
209                self.spawner.spawn_bg(bg);
210                GenericConnection(conn)
211            }
212            #[cfg(feature = "dns-over-h3")]
213            ConnectionConnect::H3(ref mut conn) => {
214                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
215                self.spawner.spawn_bg(bg);
216                GenericConnection(conn)
217            }
218        }))
219    }
220}
221
222/// A connected DNS handle
223#[derive(Clone)]
224pub struct GenericConnection(DnsExchange);
225
226impl DnsHandle for GenericConnection {
227    type Response = ConnectionResponse;
228    type Error = ResolveError;
229
230    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
231        ConnectionResponse(self.0.send(request))
232    }
233}
234
235/// Default connector for `GenericConnection`
236#[derive(Clone)]
237pub struct GenericConnector<P: RuntimeProvider> {
238    runtime_provider: P,
239}
240
241impl<P: RuntimeProvider> GenericConnector<P> {
242    /// Create a new instance.
243    pub fn new(runtime_provider: P) -> Self {
244        Self { runtime_provider }
245    }
246}
247
248impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
249    fn default() -> Self {
250        Self {
251            runtime_provider: P::default(),
252        }
253    }
254}
255
256impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
257    type Conn = GenericConnection;
258    type FutureConn = ConnectionFuture<P>;
259    type RuntimeProvider = P;
260
261    fn new_connection(
262        &self,
263        config: &NameServerConfig,
264        options: &ResolverOpts,
265    ) -> Self::FutureConn {
266        let dns_connect = match config.protocol {
267            Protocol::Udp => {
268                let provider_handle = self.runtime_provider.clone();
269                let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
270                    provider_handle.bind_udp(local_addr, server_addr)
271                };
272                let stream = UdpClientStream::with_creator(
273                    config.socket_addr,
274                    None,
275                    options.timeout,
276                    Arc::new(closure),
277                );
278                let exchange = DnsExchange::connect(stream);
279                ConnectionConnect::Udp(exchange)
280            }
281            Protocol::Tcp => {
282                let socket_addr = config.socket_addr;
283                let timeout = options.timeout;
284                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
285
286                let (stream, handle) =
287                    TcpClientStream::with_future(tcp_future, socket_addr, timeout);
288                // TODO: need config for Signer...
289                let dns_conn = DnsMultiplexer::with_timeout(
290                    stream,
291                    handle,
292                    timeout,
293                    NoopMessageFinalizer::new(),
294                );
295
296                let exchange = DnsExchange::connect(dns_conn);
297                ConnectionConnect::Tcp(exchange)
298            }
299            #[cfg(feature = "dns-over-tls")]
300            Protocol::Tls => {
301                let socket_addr = config.socket_addr;
302                let timeout = options.timeout;
303                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
304                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
305
306                #[cfg(feature = "dns-over-rustls")]
307                let client_config = config.tls_config.clone();
308
309                #[cfg(feature = "dns-over-rustls")]
310                let (stream, handle) = {
311                    crate::tls::new_tls_stream_with_future(
312                        tcp_future,
313                        socket_addr,
314                        tls_dns_name,
315                        client_config,
316                    )
317                };
318                #[cfg(not(feature = "dns-over-rustls"))]
319                let (stream, handle) = {
320                    crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
321                };
322
323                let dns_conn = DnsMultiplexer::with_timeout(
324                    stream,
325                    handle,
326                    timeout,
327                    NoopMessageFinalizer::new(),
328                );
329
330                let exchange = DnsExchange::connect(dns_conn);
331                ConnectionConnect::Tls(exchange)
332            }
333            #[cfg(feature = "dns-over-https")]
334            Protocol::Https => {
335                let socket_addr = config.socket_addr;
336                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
337                #[cfg(feature = "dns-over-rustls")]
338                let client_config = config.tls_config.clone();
339                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
340
341                let exchange = crate::h2::new_https_stream_with_future(
342                    tcp_future,
343                    socket_addr,
344                    tls_dns_name,
345                    client_config,
346                );
347                ConnectionConnect::Https(exchange)
348            }
349            #[cfg(feature = "dns-over-quic")]
350            Protocol::Quic => {
351                let socket_addr = config.socket_addr;
352                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
353                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
354                    SocketAddr::V6(_) => {
355                        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
356                    }
357                });
358                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
359                #[cfg(feature = "dns-over-rustls")]
360                let client_config = config.tls_config.clone();
361                let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
362
363                let exchange = crate::quic::new_quic_stream_with_future(
364                    udp_future,
365                    socket_addr,
366                    tls_dns_name,
367                    client_config,
368                );
369                ConnectionConnect::Quic(exchange)
370            }
371            #[cfg(feature = "dns-over-h3")]
372            Protocol::H3 => {
373                let socket_addr = config.socket_addr;
374                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
375                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
376                    SocketAddr::V6(_) => {
377                        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
378                    }
379                });
380                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
381                let client_config = config.tls_config.clone();
382                let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
383
384                let exchange = crate::h3::new_h3_stream_with_future(
385                    udp_future,
386                    socket_addr,
387                    tls_dns_name,
388                    client_config,
389                );
390                ConnectionConnect::H3(exchange)
391            }
392        };
393
394        ConnectionFuture::<P> {
395            connect: dns_connect,
396            spawner: self.runtime_provider.create_handle(),
397        }
398    }
399}
400
401/// A stream of response to a DNS request.
402#[must_use = "steam do nothing unless polled"]
403pub struct ConnectionResponse(DnsExchangeSend);
404
405impl Stream for ConnectionResponse {
406    type Item = Result<DnsResponse, ResolveError>;
407
408    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
410    }
411}
412
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use std::sync::{Arc, Mutex, PoisonError};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// A handle to the Tokio runtime
    ///
    /// Background tasks are collected in a `JoinSet` that is shared by every clone of the handle.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // Recover from a poisoned lock rather than panicking. `JoinSet::spawn` panics when
            // called outside a Tokio runtime; without this, that single panic would poison the
            // mutex and make every later `spawn_bg` call panic too. This block only touches the
            // set through `spawn` and `reap_tasks`.
            let mut join_set = self.join_set.lock().unwrap_or_else(PoisonError::into_inner);
            join_set.spawn(future);
            reap_tasks(&mut join_set);
        }
    }

    /// The Tokio Runtime for async execution
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Create a Tokio runtime
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            self.0.clone()
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                TokioTcpStream::connect(server_addr)
                    .await
                    .map(AsyncIoTokioAsStd)
            })
        }

        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }
    }

    /// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
    ///
    /// The tasks' results (including `JoinError`s) are discarded.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        while FutureExt::now_or_never(join_set.join_next())
            .flatten()
            .is_some()
        {}
    }

    /// Default ConnectionProvider with `GenericConnection`.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}