1use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25 feature = "dns-over-openssl",
26 not(feature = "dns-over-rustls"),
27 not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
35use hickory_proto::udp::QuicLocalAddr;
36#[cfg(feature = "dns-over-https")]
37use proto::h2::{HttpsClientConnect, HttpsClientStream};
38#[cfg(feature = "dns-over-h3")]
39use proto::h3::{H3ClientConnect, H3ClientStream};
40#[cfg(feature = "dns-over-quic")]
41use proto::quic::{QuicClientConnect, QuicClientStream};
42use proto::tcp::DnsTcpStream;
43use proto::udp::DnsUdpSocket;
44use proto::{
45 self,
46 error::ProtoError,
47 op::NoopMessageFinalizer,
48 tcp::TcpClientConnect,
49 tcp::TcpClientStream,
50 udp::UdpClientConnect,
51 udp::UdpClientStream,
52 xfer::{
53 DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
54 DnsMultiplexerConnect, DnsRequest, DnsResponse,
55 },
56 Time,
57};
58#[cfg(feature = "tokio-runtime")]
59use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
60
61use crate::error::ResolveError;
62
63pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
65 type Handle: Clone + Send + Spawn + Sync + Unpin;
67
68 type Timer: Time + Send + Unpin;
70
71 #[cfg(not(any(feature = "dns-over-quic", feature = "dns-over-h3")))]
72 type Udp: DnsUdpSocket + Send;
74 #[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
75 type Udp: DnsUdpSocket + QuicLocalAddr + Send;
77
78 type Tcp: DnsTcpStream;
80
81 fn create_handle(&self) -> Self::Handle;
83
84 fn connect_tcp(
86 &self,
87 server_addr: SocketAddr,
88 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
89
90 fn bind_udp(
93 &self,
94 local_addr: SocketAddr,
95 server_addr: SocketAddr,
96 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
97}
98
99pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
102 type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
104 type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
106 type RuntimeProvider: RuntimeProvider;
108
109 fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
111 -> Self::FutureConn;
112}
113
114pub trait Spawn {
116 fn spawn_bg<F>(&mut self, future: F)
118 where
119 F: Future<Output = Result<(), ProtoError>> + Send + 'static;
120}
121
/// DNS-over-TCP framing on top of a TLS session that wraps the runtime's TCP stream `Io`.
///
/// The two `iocompat` adapters convert between the std-style and tokio-style async I/O traits
/// on either side of the TLS layer.
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<Io> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<Io>>>>;
126
127#[allow(clippy::large_enum_variant, clippy::type_complexity)]
129pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
130 Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
131 Tcp(
132 DnsExchangeConnect<
133 DnsMultiplexerConnect<
134 TcpClientConnect<<R as RuntimeProvider>::Tcp>,
135 TcpClientStream<<R as RuntimeProvider>::Tcp>,
136 NoopMessageFinalizer,
137 >,
138 DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
139 R::Timer,
140 >,
141 ),
142 #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
143 Tls(
144 DnsExchangeConnect<
145 DnsMultiplexerConnect<
146 Pin<
147 Box<
148 dyn Future<
149 Output = Result<
150 TlsClientStream<<R as RuntimeProvider>::Tcp>,
151 ProtoError,
152 >,
153 > + Send
154 + 'static,
155 >,
156 >,
157 TlsClientStream<<R as RuntimeProvider>::Tcp>,
158 NoopMessageFinalizer,
159 >,
160 DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
161 TokioTime,
162 >,
163 ),
164 #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
165 Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
166 #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
167 Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
168 #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
169 H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
170}
171
172#[must_use = "futures do nothing unless polled"]
174pub struct ConnectionFuture<R: RuntimeProvider> {
175 pub(crate) connect: ConnectionConnect<R>,
176 pub(crate) spawner: R::Handle,
177}
178
179impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
180 type Output = Result<GenericConnection, ResolveError>;
181
182 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183 Poll::Ready(Ok(match &mut self.connect {
184 ConnectionConnect::Udp(ref mut conn) => {
185 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
186 self.spawner.spawn_bg(bg);
187 GenericConnection(conn)
188 }
189 ConnectionConnect::Tcp(ref mut conn) => {
190 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
191 self.spawner.spawn_bg(bg);
192 GenericConnection(conn)
193 }
194 #[cfg(feature = "dns-over-tls")]
195 ConnectionConnect::Tls(ref mut conn) => {
196 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
197 self.spawner.spawn_bg(bg);
198 GenericConnection(conn)
199 }
200 #[cfg(feature = "dns-over-https")]
201 ConnectionConnect::Https(ref mut conn) => {
202 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
203 self.spawner.spawn_bg(bg);
204 GenericConnection(conn)
205 }
206 #[cfg(feature = "dns-over-quic")]
207 ConnectionConnect::Quic(ref mut conn) => {
208 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
209 self.spawner.spawn_bg(bg);
210 GenericConnection(conn)
211 }
212 #[cfg(feature = "dns-over-h3")]
213 ConnectionConnect::H3(ref mut conn) => {
214 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
215 self.spawner.spawn_bg(bg);
216 GenericConnection(conn)
217 }
218 }))
219 }
220}
221
222#[derive(Clone)]
224pub struct GenericConnection(DnsExchange);
225
226impl DnsHandle for GenericConnection {
227 type Response = ConnectionResponse;
228 type Error = ResolveError;
229
230 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
231 ConnectionResponse(self.0.send(request))
232 }
233}
234
235#[derive(Clone)]
237pub struct GenericConnector<P: RuntimeProvider> {
238 runtime_provider: P,
239}
240
241impl<P: RuntimeProvider> GenericConnector<P> {
242 pub fn new(runtime_provider: P) -> Self {
244 Self { runtime_provider }
245 }
246}
247
248impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
249 fn default() -> Self {
250 Self {
251 runtime_provider: P::default(),
252 }
253 }
254}
255
256impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
257 type Conn = GenericConnection;
258 type FutureConn = ConnectionFuture<P>;
259 type RuntimeProvider = P;
260
261 fn new_connection(
262 &self,
263 config: &NameServerConfig,
264 options: &ResolverOpts,
265 ) -> Self::FutureConn {
266 let dns_connect = match config.protocol {
267 Protocol::Udp => {
268 let provider_handle = self.runtime_provider.clone();
269 let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
270 provider_handle.bind_udp(local_addr, server_addr)
271 };
272 let stream = UdpClientStream::with_creator(
273 config.socket_addr,
274 None,
275 options.timeout,
276 Arc::new(closure),
277 );
278 let exchange = DnsExchange::connect(stream);
279 ConnectionConnect::Udp(exchange)
280 }
281 Protocol::Tcp => {
282 let socket_addr = config.socket_addr;
283 let timeout = options.timeout;
284 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
285
286 let (stream, handle) =
287 TcpClientStream::with_future(tcp_future, socket_addr, timeout);
288 let dns_conn = DnsMultiplexer::with_timeout(
290 stream,
291 handle,
292 timeout,
293 NoopMessageFinalizer::new(),
294 );
295
296 let exchange = DnsExchange::connect(dns_conn);
297 ConnectionConnect::Tcp(exchange)
298 }
299 #[cfg(feature = "dns-over-tls")]
300 Protocol::Tls => {
301 let socket_addr = config.socket_addr;
302 let timeout = options.timeout;
303 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
304 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
305
306 #[cfg(feature = "dns-over-rustls")]
307 let client_config = config.tls_config.clone();
308
309 #[cfg(feature = "dns-over-rustls")]
310 let (stream, handle) = {
311 crate::tls::new_tls_stream_with_future(
312 tcp_future,
313 socket_addr,
314 tls_dns_name,
315 client_config,
316 )
317 };
318 #[cfg(not(feature = "dns-over-rustls"))]
319 let (stream, handle) = {
320 crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
321 };
322
323 let dns_conn = DnsMultiplexer::with_timeout(
324 stream,
325 handle,
326 timeout,
327 NoopMessageFinalizer::new(),
328 );
329
330 let exchange = DnsExchange::connect(dns_conn);
331 ConnectionConnect::Tls(exchange)
332 }
333 #[cfg(feature = "dns-over-https")]
334 Protocol::Https => {
335 let socket_addr = config.socket_addr;
336 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
337 #[cfg(feature = "dns-over-rustls")]
338 let client_config = config.tls_config.clone();
339 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
340
341 let exchange = crate::h2::new_https_stream_with_future(
342 tcp_future,
343 socket_addr,
344 tls_dns_name,
345 client_config,
346 );
347 ConnectionConnect::Https(exchange)
348 }
349 #[cfg(feature = "dns-over-quic")]
350 Protocol::Quic => {
351 let socket_addr = config.socket_addr;
352 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
353 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
354 SocketAddr::V6(_) => {
355 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
356 }
357 });
358 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
359 #[cfg(feature = "dns-over-rustls")]
360 let client_config = config.tls_config.clone();
361 let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
362
363 let exchange = crate::quic::new_quic_stream_with_future(
364 udp_future,
365 socket_addr,
366 tls_dns_name,
367 client_config,
368 );
369 ConnectionConnect::Quic(exchange)
370 }
371 #[cfg(feature = "dns-over-h3")]
372 Protocol::H3 => {
373 let socket_addr = config.socket_addr;
374 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
375 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
376 SocketAddr::V6(_) => {
377 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
378 }
379 });
380 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
381 let client_config = config.tls_config.clone();
382 let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
383
384 let exchange = crate::h3::new_h3_stream_with_future(
385 udp_future,
386 socket_addr,
387 tls_dns_name,
388 client_config,
389 );
390 ConnectionConnect::H3(exchange)
391 }
392 };
393
394 ConnectionFuture::<P> {
395 connect: dns_connect,
396 spawner: self.runtime_provider.create_handle(),
397 }
398 }
399}
400
401#[must_use = "steam do nothing unless polled"]
403pub struct ConnectionResponse(DnsExchangeSend);
404
405impl Stream for ConnectionResponse {
406 type Item = Result<DnsResponse, ResolveError>;
407
408 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409 Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
410 }
411}
412
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    //! Tokio-backed implementations of `RuntimeProvider` and `Spawn`.

    use super::*;
    use std::sync::{Arc, Mutex, PoisonError};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// Spawner that keeps background tasks in a `JoinSet` shared by all clones of the handle.
    ///
    /// Per Tokio's `JoinSet` semantics, tasks still in the set are aborted when the set is
    /// dropped, i.e. once the last clone of this handle (including the one held by
    /// `TokioRuntimeProvider`) goes away.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        /// Spawns `future` into the shared set, then reaps tasks that have already finished.
        ///
        /// # Panics
        ///
        /// Panics if called outside the context of a Tokio runtime (a `JoinSet::spawn`
        /// requirement).
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // A panic while this lock is held (e.g. `JoinSet::spawn` outside a runtime) poisons
            // the mutex. `JoinSet::spawn` panics before inserting anything, so the set is still
            // consistent; recover the guard instead of panicking on every later spawn.
            let mut join_set = self
                .join_set
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            join_set.spawn(future);
            reap_tasks(&mut join_set);
        }
    }

    /// `RuntimeProvider` using Tokio's TCP/UDP sockets, timer and task spawning.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Creates a provider with an empty set of background tasks.
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            // Every handle shares this provider's task set.
            self.0.clone()
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                TokioTcpStream::connect(server_addr)
                    .await
                    .map(AsyncIoTokioAsStd)
            })
        }

        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            // Only the local address is used; the socket is not connected to the server here.
            Box::pin(TokioUdpSocket::bind(local_addr))
        }
    }

    /// Removes already-finished tasks from `join_set` without waiting, discarding their results.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        // `now_or_never` polls `join_next` once: `Some(Some(_))` means a task finished (keep
        // reaping), `Some(None)` means the set is empty, and `None` means nothing is done yet.
        while FutureExt::now_or_never(join_set.join_next())
            .flatten()
            .is_some()
        {}
    }

    /// Connection provider backed by the Tokio runtime.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}