hickory_resolver/name_server/connection_provider.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::future::Future;
9use std::io;
10use std::marker::Unpin;
11#[cfg(feature = "__quic")]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
13use std::pin::Pin;
14#[cfg(feature = "__https")]
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18use crate::proto::runtime::Spawn;
19#[cfg(feature = "tokio")]
20use crate::proto::runtime::TokioRuntimeProvider;
21#[cfg(feature = "__tls")]
22use crate::proto::runtime::iocompat::AsyncIoStdAsTokio;
23use futures_util::future::FutureExt;
24use futures_util::ready;
25use futures_util::stream::{Stream, StreamExt};
26#[cfg(feature = "__tls")]
27use tokio_rustls::client::TlsStream as TokioTlsStream;
28
29use crate::config::{NameServerConfig, ResolverOpts};
30#[cfg(any(feature = "__h3", feature = "__https"))]
31use crate::proto;
32#[cfg(feature = "__https")]
33use crate::proto::h2::{HttpsClientConnect, HttpsClientStream};
34#[cfg(feature = "__h3")]
35use crate::proto::h3::{H3ClientConnect, H3ClientStream};
36#[cfg(feature = "__quic")]
37use crate::proto::quic::{QuicClientConnect, QuicClientStream};
38#[cfg(feature = "tokio")]
39#[allow(unused_imports)] // Complicated cfg for which protocols are enabled
40use crate::proto::runtime::TokioTime;
41#[cfg(feature = "__tls")]
42use crate::proto::runtime::iocompat::AsyncIoTokioAsStd;
43use crate::proto::{
44    ProtoError,
45    runtime::RuntimeProvider,
46    tcp::TcpClientStream,
47    udp::{UdpClientConnect, UdpClientStream},
48    xfer::{
49        DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
50        DnsMultiplexerConnect, DnsRequest, DnsResponse, Protocol,
51    },
52};
53
54/// Create `DnsHandle` with the help of `RuntimeProvider`.
55/// This trait is designed for customization.
56pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
57    /// The handle to the connection for sending DNS requests.
58    type Conn: DnsHandle + Clone + Send + Sync + 'static;
59    /// Ths future is responsible for spawning any background tasks as necessary.
60    type FutureConn: Future<Output = Result<Self::Conn, ProtoError>> + Send + 'static;
61    /// Provider that handles the underlying I/O and timing.
62    type RuntimeProvider: RuntimeProvider;
63
64    /// Create a new connection.
65    fn new_connection(
66        &self,
67        config: &NameServerConfig,
68        options: &ResolverOpts,
69    ) -> Result<Self::FutureConn, io::Error>;
70}
71
/// Stream type used for DNS-over-TLS connections layered on the transport `T`.
///
/// From the inside out: `T` is passed through `AsyncIoStdAsTokio`, wrapped in a
/// `tokio_rustls` client `TlsStream`, passed back through `AsyncIoTokioAsStd`, and
/// finally framed by `TcpClientStream`.
#[cfg(feature = "__tls")]
type TlsClientStream<T> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<T>>>>;
75
76/// The variants of all supported connections for the Resolver
77#[allow(clippy::large_enum_variant, clippy::type_complexity)]
78pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
79    Udp(DnsExchangeConnect<UdpClientConnect<R>, UdpClientStream<R>, R::Timer>),
80    Tcp(
81        DnsExchangeConnect<
82            DnsMultiplexerConnect<
83                Pin<Box<dyn Future<Output = Result<TcpClientStream<R::Tcp>, ProtoError>> + Send>>,
84                TcpClientStream<<R as RuntimeProvider>::Tcp>,
85            >,
86            DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>>,
87            R::Timer,
88        >,
89    ),
90    #[cfg(feature = "__tls")]
91    Tls(
92        DnsExchangeConnect<
93            DnsMultiplexerConnect<
94                Pin<
95                    Box<
96                        dyn Future<
97                                Output = Result<
98                                    TlsClientStream<<R as RuntimeProvider>::Tcp>,
99                                    ProtoError,
100                                >,
101                            > + Send
102                            + 'static,
103                    >,
104                >,
105                TlsClientStream<<R as RuntimeProvider>::Tcp>,
106            >,
107            DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>>,
108            TokioTime,
109        >,
110    ),
111    #[cfg(all(feature = "__https", feature = "tokio"))]
112    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
113    #[cfg(all(feature = "__quic", feature = "tokio"))]
114    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
115    #[cfg(all(feature = "__h3", feature = "tokio"))]
116    H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
117}
118
119/// Resolves to a new Connection
120#[must_use = "futures do nothing unless polled"]
121pub struct ConnectionFuture<R: RuntimeProvider> {
122    pub(crate) connect: ConnectionConnect<R>,
123    pub(crate) spawner: R::Handle,
124}
125
126impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
127    type Output = Result<GenericConnection, ProtoError>;
128
129    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130        Poll::Ready(Ok(match &mut self.connect {
131            ConnectionConnect::Udp(conn) => {
132                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
133                self.spawner.spawn_bg(bg);
134                GenericConnection(conn)
135            }
136            ConnectionConnect::Tcp(conn) => {
137                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
138                self.spawner.spawn_bg(bg);
139                GenericConnection(conn)
140            }
141            #[cfg(feature = "__tls")]
142            ConnectionConnect::Tls(conn) => {
143                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
144                self.spawner.spawn_bg(bg);
145                GenericConnection(conn)
146            }
147            #[cfg(feature = "__https")]
148            ConnectionConnect::Https(conn) => {
149                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
150                self.spawner.spawn_bg(bg);
151                GenericConnection(conn)
152            }
153            #[cfg(feature = "__quic")]
154            ConnectionConnect::Quic(conn) => {
155                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
156                self.spawner.spawn_bg(bg);
157                GenericConnection(conn)
158            }
159            #[cfg(feature = "__h3")]
160            ConnectionConnect::H3(conn) => {
161                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
162                self.spawner.spawn_bg(bg);
163                GenericConnection(conn)
164            }
165        }))
166    }
167}
168
169/// A connected DNS handle
170#[derive(Clone)]
171pub struct GenericConnection(DnsExchange);
172
173impl DnsHandle for GenericConnection {
174    type Response = ConnectionResponse;
175
176    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
177        ConnectionResponse(self.0.send(request))
178    }
179}
180
/// [`GenericConnector`] specialised for the tokio runtime ([`TokioRuntimeProvider`]).
///
/// This is the default `ConnectionProvider`; the connections it creates are
/// `GenericConnection`s.
#[cfg(feature = "tokio")]
pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
184
185/// Default connector for `GenericConnection`
186#[derive(Clone)]
187pub struct GenericConnector<P: RuntimeProvider> {
188    runtime_provider: P,
189}
190
191impl<P: RuntimeProvider> GenericConnector<P> {
192    /// Create a new instance.
193    pub fn new(runtime_provider: P) -> Self {
194        Self { runtime_provider }
195    }
196}
197
198impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
199    fn default() -> Self {
200        Self {
201            runtime_provider: P::default(),
202        }
203    }
204}
205
206impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
207    type Conn = GenericConnection;
208    type FutureConn = ConnectionFuture<P>;
209    type RuntimeProvider = P;
210
211    fn new_connection(
212        &self,
213        config: &NameServerConfig,
214        options: &ResolverOpts,
215    ) -> Result<Self::FutureConn, io::Error> {
216        let dns_connect = match (config.protocol, self.runtime_provider.quic_binder()) {
217            (Protocol::Udp, _) => {
218                let provider_handle = self.runtime_provider.clone();
219                let stream = UdpClientStream::builder(config.socket_addr, provider_handle)
220                    .with_timeout(Some(options.timeout))
221                    .with_os_port_selection(options.os_port_selection)
222                    .avoid_local_ports(options.avoid_local_udp_ports.clone())
223                    .with_bind_addr(config.bind_addr)
224                    .build();
225                let exchange = DnsExchange::connect(stream);
226                ConnectionConnect::Udp(exchange)
227            }
228            (Protocol::Tcp, _) => {
229                let (future, handle) = TcpClientStream::new(
230                    config.socket_addr,
231                    config.bind_addr,
232                    Some(options.timeout),
233                    self.runtime_provider.clone(),
234                );
235
236                // TODO: need config for Signer...
237                let dns_conn = DnsMultiplexer::with_timeout(future, handle, options.timeout, None);
238                let exchange = DnsExchange::connect(dns_conn);
239                ConnectionConnect::Tcp(exchange)
240            }
241            #[cfg(feature = "__tls")]
242            (Protocol::Tls, _) => {
243                let socket_addr = config.socket_addr;
244                let timeout = options.timeout;
245                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
246                let tcp_future = self.runtime_provider.connect_tcp(socket_addr, None, None);
247
248                let (stream, handle) = crate::tls::new_tls_stream_with_future(
249                    tcp_future,
250                    socket_addr,
251                    tls_dns_name,
252                    options.tls_config.clone(),
253                );
254
255                let dns_conn = DnsMultiplexer::with_timeout(stream, handle, timeout, None);
256                let exchange = DnsExchange::connect(dns_conn);
257                ConnectionConnect::Tls(exchange)
258            }
259            #[cfg(feature = "__https")]
260            (Protocol::Https, _) => {
261                let socket_addr = config.socket_addr;
262                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
263                let http_endpoint = config
264                    .http_endpoint
265                    .clone()
266                    .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
267                let tcp_future = self.runtime_provider.connect_tcp(socket_addr, None, None);
268
269                let exchange = crate::h2::new_https_stream_with_future(
270                    tcp_future,
271                    socket_addr,
272                    tls_dns_name,
273                    http_endpoint,
274                    Arc::new(options.tls_config.clone()),
275                );
276                ConnectionConnect::Https(exchange)
277            }
278            #[cfg(feature = "__quic")]
279            (Protocol::Quic, Some(binder)) => {
280                let socket_addr = config.socket_addr;
281                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
282                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
283                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
284                });
285                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
286                let client_config = options.tls_config.clone();
287                let socket = binder.bind_quic(bind_addr, socket_addr)?;
288
289                let exchange = crate::quic::new_quic_stream_with_future(
290                    socket,
291                    socket_addr,
292                    tls_dns_name,
293                    client_config,
294                );
295                ConnectionConnect::Quic(exchange)
296            }
297            #[cfg(feature = "__h3")]
298            (Protocol::H3, Some(binder)) => {
299                let socket_addr = config.socket_addr;
300                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
301                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
302                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
303                });
304                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
305                let http_endpoint = config
306                    .http_endpoint
307                    .clone()
308                    .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
309                let client_config = options.tls_config.clone();
310                let socket = binder.bind_quic(bind_addr, socket_addr)?;
311
312                let exchange = crate::h3::new_h3_stream_with_future(
313                    socket,
314                    socket_addr,
315                    tls_dns_name,
316                    http_endpoint,
317                    client_config,
318                );
319                ConnectionConnect::H3(exchange)
320            }
321            (protocol, _) => {
322                return Err(io::Error::new(
323                    io::ErrorKind::InvalidInput,
324                    format!("unsupported protocol: {protocol:?}"),
325                ));
326            }
327        };
328
329        Ok(ConnectionFuture::<P> {
330            connect: dns_connect,
331            spawner: self.runtime_provider.create_handle(),
332        })
333    }
334}
335
336/// A stream of response to a DNS request.
337#[must_use = "streams do nothing unless polled"]
338pub struct ConnectionResponse(DnsExchangeSend);
339
340impl Stream for ConnectionResponse {
341    type Item = Result<DnsResponse, ProtoError>;
342
343    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
344        Poll::Ready(ready!(self.0.poll_next_unpin(cx)))
345    }
346}