hickory_resolver/name_server/
connection_provider.rs1use std::future::Future;
9use std::io;
10use std::marker::Unpin;
11#[cfg(feature = "__quic")]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
13use std::pin::Pin;
14#[cfg(feature = "__https")]
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18use crate::proto::runtime::Spawn;
19#[cfg(feature = "tokio")]
20use crate::proto::runtime::TokioRuntimeProvider;
21#[cfg(feature = "__tls")]
22use crate::proto::runtime::iocompat::AsyncIoStdAsTokio;
23use futures_util::future::FutureExt;
24use futures_util::ready;
25use futures_util::stream::{Stream, StreamExt};
26#[cfg(feature = "__tls")]
27use tokio_rustls::client::TlsStream as TokioTlsStream;
28
29use crate::config::{NameServerConfig, ResolverOpts};
30#[cfg(any(feature = "__h3", feature = "__https"))]
31use crate::proto;
32#[cfg(feature = "__https")]
33use crate::proto::h2::{HttpsClientConnect, HttpsClientStream};
34#[cfg(feature = "__h3")]
35use crate::proto::h3::{H3ClientConnect, H3ClientStream};
36#[cfg(feature = "__quic")]
37use crate::proto::quic::{QuicClientConnect, QuicClientStream};
38#[cfg(feature = "tokio")]
39#[allow(unused_imports)] use crate::proto::runtime::TokioTime;
41#[cfg(feature = "__tls")]
42use crate::proto::runtime::iocompat::AsyncIoTokioAsStd;
43use crate::proto::{
44 ProtoError,
45 runtime::RuntimeProvider,
46 tcp::TcpClientStream,
47 udp::{UdpClientConnect, UdpClientStream},
48 xfer::{
49 DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
50 DnsMultiplexerConnect, DnsRequest, DnsResponse, Protocol,
51 },
52};
53
54pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
57 type Conn: DnsHandle + Clone + Send + Sync + 'static;
59 type FutureConn: Future<Output = Result<Self::Conn, ProtoError>> + Send + 'static;
61 type RuntimeProvider: RuntimeProvider;
63
64 fn new_connection(
66 &self,
67 config: &NameServerConfig,
68 options: &ResolverOpts,
69 ) -> Result<Self::FutureConn, io::Error>;
70}
71
72#[cfg(feature = "__tls")]
73type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
75
76#[allow(clippy::large_enum_variant, clippy::type_complexity)]
78pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
79 Udp(DnsExchangeConnect<UdpClientConnect<R>, UdpClientStream<R>, R::Timer>),
80 Tcp(
81 DnsExchangeConnect<
82 DnsMultiplexerConnect<
83 Pin<Box<dyn Future<Output = Result<TcpClientStream<R::Tcp>, ProtoError>> + Send>>,
84 TcpClientStream<<R as RuntimeProvider>::Tcp>,
85 >,
86 DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>>,
87 R::Timer,
88 >,
89 ),
90 #[cfg(feature = "__tls")]
91 Tls(
92 DnsExchangeConnect<
93 DnsMultiplexerConnect<
94 Pin<
95 Box<
96 dyn Future<
97 Output = Result<
98 TlsClientStream<<R as RuntimeProvider>::Tcp>,
99 ProtoError,
100 >,
101 > + Send
102 + 'static,
103 >,
104 >,
105 TlsClientStream<<R as RuntimeProvider>::Tcp>,
106 >,
107 DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>>,
108 TokioTime,
109 >,
110 ),
111 #[cfg(all(feature = "__https", feature = "tokio"))]
112 Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
113 #[cfg(all(feature = "__quic", feature = "tokio"))]
114 Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
115 #[cfg(all(feature = "__h3", feature = "tokio"))]
116 H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
117}
118
119#[must_use = "futures do nothing unless polled"]
121pub struct ConnectionFuture<R: RuntimeProvider> {
122 pub(crate) connect: ConnectionConnect<R>,
123 pub(crate) spawner: R::Handle,
124}
125
126impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
127 type Output = Result<GenericConnection, ProtoError>;
128
129 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 Poll::Ready(Ok(match &mut self.connect {
131 ConnectionConnect::Udp(conn) => {
132 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
133 self.spawner.spawn_bg(bg);
134 GenericConnection(conn)
135 }
136 ConnectionConnect::Tcp(conn) => {
137 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
138 self.spawner.spawn_bg(bg);
139 GenericConnection(conn)
140 }
141 #[cfg(feature = "__tls")]
142 ConnectionConnect::Tls(conn) => {
143 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
144 self.spawner.spawn_bg(bg);
145 GenericConnection(conn)
146 }
147 #[cfg(feature = "__https")]
148 ConnectionConnect::Https(conn) => {
149 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
150 self.spawner.spawn_bg(bg);
151 GenericConnection(conn)
152 }
153 #[cfg(feature = "__quic")]
154 ConnectionConnect::Quic(conn) => {
155 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
156 self.spawner.spawn_bg(bg);
157 GenericConnection(conn)
158 }
159 #[cfg(feature = "__h3")]
160 ConnectionConnect::H3(conn) => {
161 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
162 self.spawner.spawn_bg(bg);
163 GenericConnection(conn)
164 }
165 }))
166 }
167}
168
169#[derive(Clone)]
171pub struct GenericConnection(DnsExchange);
172
173impl DnsHandle for GenericConnection {
174 type Response = ConnectionResponse;
175
176 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
177 ConnectionResponse(self.0.send(request))
178 }
179}
180
181#[cfg(feature = "tokio")]
183pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
184
185#[derive(Clone)]
187pub struct GenericConnector<P: RuntimeProvider> {
188 runtime_provider: P,
189}
190
191impl<P: RuntimeProvider> GenericConnector<P> {
192 pub fn new(runtime_provider: P) -> Self {
194 Self { runtime_provider }
195 }
196}
197
198impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
199 fn default() -> Self {
200 Self {
201 runtime_provider: P::default(),
202 }
203 }
204}
205
206impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
207 type Conn = GenericConnection;
208 type FutureConn = ConnectionFuture<P>;
209 type RuntimeProvider = P;
210
211 fn new_connection(
212 &self,
213 config: &NameServerConfig,
214 options: &ResolverOpts,
215 ) -> Result<Self::FutureConn, io::Error> {
216 let dns_connect = match (config.protocol, self.runtime_provider.quic_binder()) {
217 (Protocol::Udp, _) => {
218 let provider_handle = self.runtime_provider.clone();
219 let stream = UdpClientStream::builder(config.socket_addr, provider_handle)
220 .with_timeout(Some(options.timeout))
221 .with_os_port_selection(options.os_port_selection)
222 .avoid_local_ports(options.avoid_local_udp_ports.clone())
223 .with_bind_addr(config.bind_addr)
224 .build();
225 let exchange = DnsExchange::connect(stream);
226 ConnectionConnect::Udp(exchange)
227 }
228 (Protocol::Tcp, _) => {
229 let (future, handle) = TcpClientStream::new(
230 config.socket_addr,
231 config.bind_addr,
232 Some(options.timeout),
233 self.runtime_provider.clone(),
234 );
235
236 let dns_conn = DnsMultiplexer::with_timeout(future, handle, options.timeout, None);
238 let exchange = DnsExchange::connect(dns_conn);
239 ConnectionConnect::Tcp(exchange)
240 }
241 #[cfg(feature = "__tls")]
242 (Protocol::Tls, _) => {
243 let socket_addr = config.socket_addr;
244 let timeout = options.timeout;
245 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
246 let tcp_future = self.runtime_provider.connect_tcp(socket_addr, None, None);
247
248 let (stream, handle) = crate::tls::new_tls_stream_with_future(
249 tcp_future,
250 socket_addr,
251 tls_dns_name,
252 options.tls_config.clone(),
253 );
254
255 let dns_conn = DnsMultiplexer::with_timeout(stream, handle, timeout, None);
256 let exchange = DnsExchange::connect(dns_conn);
257 ConnectionConnect::Tls(exchange)
258 }
259 #[cfg(feature = "__https")]
260 (Protocol::Https, _) => {
261 let socket_addr = config.socket_addr;
262 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
263 let http_endpoint = config
264 .http_endpoint
265 .clone()
266 .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
267 let tcp_future = self.runtime_provider.connect_tcp(socket_addr, None, None);
268
269 let exchange = crate::h2::new_https_stream_with_future(
270 tcp_future,
271 socket_addr,
272 tls_dns_name,
273 http_endpoint,
274 Arc::new(options.tls_config.clone()),
275 );
276 ConnectionConnect::Https(exchange)
277 }
278 #[cfg(feature = "__quic")]
279 (Protocol::Quic, Some(binder)) => {
280 let socket_addr = config.socket_addr;
281 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
282 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
283 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
284 });
285 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
286 let client_config = options.tls_config.clone();
287 let socket = binder.bind_quic(bind_addr, socket_addr)?;
288
289 let exchange = crate::quic::new_quic_stream_with_future(
290 socket,
291 socket_addr,
292 tls_dns_name,
293 client_config,
294 );
295 ConnectionConnect::Quic(exchange)
296 }
297 #[cfg(feature = "__h3")]
298 (Protocol::H3, Some(binder)) => {
299 let socket_addr = config.socket_addr;
300 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
301 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
302 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
303 });
304 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
305 let http_endpoint = config
306 .http_endpoint
307 .clone()
308 .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
309 let client_config = options.tls_config.clone();
310 let socket = binder.bind_quic(bind_addr, socket_addr)?;
311
312 let exchange = crate::h3::new_h3_stream_with_future(
313 socket,
314 socket_addr,
315 tls_dns_name,
316 http_endpoint,
317 client_config,
318 );
319 ConnectionConnect::H3(exchange)
320 }
321 (protocol, _) => {
322 return Err(io::Error::new(
323 io::ErrorKind::InvalidInput,
324 format!("unsupported protocol: {protocol:?}"),
325 ));
326 }
327 };
328
329 Ok(ConnectionFuture::<P> {
330 connect: dns_connect,
331 spawner: self.runtime_provider.create_handle(),
332 })
333 }
334}
335
336#[must_use = "streams do nothing unless polled"]
338pub struct ConnectionResponse(DnsExchangeSend);
339
340impl Stream for ConnectionResponse {
341 type Item = Result<DnsResponse, ProtoError>;
342
343 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
344 Poll::Ready(ready!(self.0.poll_next_unpin(cx)))
345 }
346}