1use core::fmt::Display;
8use core::fmt::{self, Debug};
9use core::future::Future;
10use core::marker::PhantomData;
11use core::net::SocketAddr;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::io;
16
17use futures_channel::mpsc;
18use futures_channel::oneshot;
19use futures_util::future::BoxFuture;
20use futures_util::ready;
21use futures_util::stream::{Fuse, Peekable};
22use futures_util::stream::{Stream, StreamExt};
23#[cfg(feature = "serde")]
24use serde::{Deserialize, Serialize};
25use tracing::{debug, warn};
26
27use crate::error::NetError;
28use crate::proto::ProtoError;
29use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
30use crate::runtime::{RuntimeProvider, Time};
31
32mod dns_exchange;
33pub use dns_exchange::{DnsExchange, DnsExchangeBackground, DnsExchangeSend};
34
35pub mod dns_handle;
36pub use dns_handle::{DnsHandle, DnsStreamHandle};
37
38pub mod dns_multiplexer;
39pub use dns_multiplexer::DnsMultiplexer;
40
41pub mod retry_dns_handle;
42pub use retry_dns_handle::RetryDnsHandle;
43
44pub struct DnsResponseStream {
46 inner: DnsResponseStreamInner,
47 done: bool,
48}
49
50impl DnsResponseStream {
51 fn new(inner: DnsResponseStreamInner) -> Self {
52 Self { inner, done: false }
53 }
54}
55
56impl Stream for DnsResponseStream {
57 type Item = Result<DnsResponse, NetError>;
58
59 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 use DnsResponseStreamInner::*;
61
62 if self.done {
64 return Poll::Ready(None);
65 }
66
67 let Self { inner, done } = self.get_mut();
69
70 let result = match inner {
71 Timeout(fut) => {
72 let x = match ready!(fut.as_mut().poll(cx)) {
73 Ok(x) => x,
74 Err(e) => Err(e.into()),
75 };
76 *done = true;
77 x
78 }
79 Receiver(fut) => match ready!(Pin::new(fut).poll_next(cx)) {
80 Some(Ok(x)) => Ok(x),
81 Some(Err(e)) => Err(e),
82 None => return Poll::Ready(None),
83 },
84 Error(err) => {
85 *done = true;
86 Err(err.take().expect("cannot poll after complete"))
87 }
88 Boxed(fut) => {
89 let x = ready!(fut.as_mut().poll(cx));
90 *done = true;
91 x
92 }
93 };
94
95 match result {
96 Err(NetError::Timeout) => Poll::Ready(None),
97 r => Poll::Ready(Some(r)),
98 }
99 }
100}
101
102impl From<TimeoutFuture> for DnsResponseStream {
103 fn from(f: TimeoutFuture) -> Self {
104 Self::new(DnsResponseStreamInner::Timeout(f))
105 }
106}
107
108impl From<mpsc::Receiver<Result<DnsResponse, NetError>>> for DnsResponseStream {
109 fn from(receiver: mpsc::Receiver<Result<DnsResponse, NetError>>) -> Self {
110 Self::new(DnsResponseStreamInner::Receiver(receiver))
111 }
112}
113
114impl From<NetError> for DnsResponseStream {
115 fn from(e: NetError) -> Self {
116 Self::new(DnsResponseStreamInner::Error(Some(e)))
117 }
118}
119
120impl<F> From<Pin<Box<F>>> for DnsResponseStream
121where
122 F: Future<Output = Result<DnsResponse, NetError>> + Send + 'static,
123{
124 fn from(f: Pin<Box<F>>) -> Self {
125 Self::new(DnsResponseStreamInner::Boxed(f))
126 }
127}
128
129enum DnsResponseStreamInner {
130 Timeout(TimeoutFuture),
131 Receiver(mpsc::Receiver<Result<DnsResponse, NetError>>),
132 Error(Option<NetError>),
133 Boxed(BoxFuture<'static, Result<DnsResponse, NetError>>),
134}
135
136type TimeoutFuture = BoxFuture<'static, Result<Result<DnsResponse, NetError>, io::Error>>;
137
138fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
140 if let Err(error) = result {
141 if error.is_disconnected() {
142 debug!("ignoring send error on disconnected stream");
143 return;
144 }
145
146 warn!("error notifying wait, possible future leak: {:?}", error);
147 }
148}
149
150pub trait DnsClientStream:
152 Stream<Item = Result<SerialMessage, NetError>> + Unpin + Send + 'static
153{
154 type Time: Time;
156
157 fn name_server_addr(&self) -> SocketAddr;
159}
160
161pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
163
164#[derive(Clone)]
170pub struct BufDnsStreamHandle {
171 remote_addr: SocketAddr,
172 sender: mpsc::Sender<SerialMessage>,
173}
174
175impl BufDnsStreamHandle {
176 pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
182 Self::with_buffer_size(remote_addr, DEFAULT_STREAM_BUFFER_SIZE)
183 }
184
185 pub fn with_buffer_size(remote_addr: SocketAddr, buffer_size: usize) -> (Self, StreamReceiver) {
195 let (sender, receiver) = mpsc::channel(buffer_size);
196 let receiver = receiver.fuse().peekable();
197
198 let this = Self {
199 remote_addr,
200 sender,
201 };
202
203 (this, receiver)
204 }
205
206 pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
210 Self {
211 remote_addr,
212 sender: self.sender.clone(),
213 }
214 }
215}
216
217impl DnsStreamHandle for BufDnsStreamHandle {
218 fn send(&mut self, buffer: SerialMessage) -> Result<(), NetError> {
219 let sender: &mut _ = &mut self.sender;
220 sender
221 .try_send(SerialMessage::new(buffer.into_parts().0, self.remote_addr))
222 .map_err(|e| NetError::from(format!("mpsc::SendError {e}")))
223 }
224}
225
226pub trait DnsRequestSender: Stream<Item = Result<(), NetError>> + Send + Unpin + 'static {
232 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;
238
239 fn shutdown(&mut self);
243
244 fn is_shutdown(&self) -> bool;
246}
247
248#[derive(Clone)]
250
251pub struct BufDnsRequestStreamHandle<P> {
252 sender: mpsc::Sender<OneshotDnsRequest>,
253 _phantom: PhantomData<P>,
254}
255
256impl<P: RuntimeProvider> DnsHandle for BufDnsRequestStreamHandle<P> {
257 type Response = DnsResponseReceiver;
258 type Runtime = P;
259
260 fn send(&self, request: DnsRequest) -> Self::Response {
261 debug!(
262 "enqueueing message:{}:{:?}",
263 request.op_code, request.queries
264 );
265
266 let (request, oneshot) = OneshotDnsRequest::oneshot(request);
267 let mut sender = self.sender.clone();
268 let try_send = sender.try_send(request).map_err(|_| {
269 debug!("unable to enqueue message");
270 NetError::Busy
271 });
272
273 match try_send {
274 Ok(val) => val,
275 Err(err) => return DnsResponseReceiver::Err(Some(err)),
276 }
277
278 DnsResponseReceiver::Receiver(oneshot)
279 }
280}
281
282pub struct OneshotDnsRequest {
285 dns_request: DnsRequest,
286 sender_for_response: oneshot::Sender<DnsResponseStream>,
287}
288
289impl OneshotDnsRequest {
290 fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
291 let (sender_for_response, receiver) = oneshot::channel();
292
293 (
294 Self {
295 dns_request,
296 sender_for_response,
297 },
298 receiver,
299 )
300 }
301
302 fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
303 (
304 self.dns_request,
305 OneshotDnsResponse(self.sender_for_response),
306 )
307 }
308}
309
310struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);
311
312impl OneshotDnsResponse {
313 fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
314 self.0.send(serial_response)
315 }
316}
317
318pub enum DnsResponseReceiver {
320 Receiver(oneshot::Receiver<DnsResponseStream>),
322 Received(DnsResponseStream),
324 Err(Option<NetError>),
326}
327
328impl Stream for DnsResponseReceiver {
329 type Item = Result<DnsResponse, NetError>;
330
331 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
332 loop {
333 *self = match &mut *self {
334 Self::Receiver(receiver) => {
335 let receiver = Pin::new(receiver);
336 let future = ready!(
337 receiver
338 .poll(cx)
339 .map_err(|_| NetError::from("receiver was canceled"))
340 )?;
341 Self::Received(future)
342 }
343 Self::Received(stream) => {
344 return stream.poll_next_unpin(cx);
345 }
346 Self::Err(err) => return Poll::Ready(err.take().map(Err)),
347 };
348 }
349 }
350}
351
352pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
354 fn first_answer(self) -> FirstAnswerFuture<Self> {
357 FirstAnswerFuture { stream: Some(self) }
358 }
359}
360
361impl<E, S, T> FirstAnswer<T, E> for S
362where
363 S: Stream<Item = Result<T, E>> + Unpin + Sized,
364 E: From<ProtoError>,
365{
366}
367
368#[derive(Debug)]
370#[must_use = "futures do nothing unless you `.await` or poll them"]
371pub struct FirstAnswerFuture<S> {
372 stream: Option<S>,
373}
374
375impl<S: Stream<Item = Result<T, NetError>> + Unpin, T> Future for FirstAnswerFuture<S>
376where
377 S: Stream<Item = Result<T, NetError>> + Unpin + Sized,
378{
379 type Output = S::Item;
380
381 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
382 let s = self
383 .stream
384 .as_mut()
385 .expect("polling FirstAnswerFuture twice");
386 let item = match ready!(s.poll_next_unpin(cx)) {
387 Some(r) => r,
388 None => Err(NetError::Timeout),
389 };
390 self.stream.take();
391 Poll::Ready(item)
392 }
393}
394
/// A DNS transport protocol.
///
/// Every variant other than `Udp` and `Tcp` is gated behind a cargo feature.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// Plain UDP; the only datagram protocol.
    Udp,
    /// Plain TCP.
    Tcp,
    /// TLS (requires the `__tls` feature).
    #[cfg(feature = "__tls")]
    Tls,
    /// HTTPS (requires the `__https` feature).
    #[cfg(feature = "__https")]
    Https,
    /// QUIC (requires the `__quic` feature).
    #[cfg(feature = "__quic")]
    Quic,
    /// HTTP/3 (requires the `__h3` feature).
    #[cfg(feature = "__h3")]
    H3,
}

impl Protocol {
    /// Returns `true` only for [`Protocol::Udp`].
    pub fn is_datagram(self) -> bool {
        matches!(self, Self::Udp)
    }

    /// Returns `true` for every protocol except [`Protocol::Udp`].
    pub fn is_stream(self) -> bool {
        !matches!(self, Self::Udp)
    }

    /// Returns `true` for every protocol except plain `Udp` and `Tcp`.
    pub fn is_encrypted(self) -> bool {
        match self {
            Self::Udp | Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => true,
            #[cfg(feature = "__https")]
            Self::Https => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
        }
    }
}

/// Formats the protocol as its lowercase name, e.g. `udp`.
impl Display for Protocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Udp => "udp",
            Self::Tcp => "tcp",
            #[cfg(feature = "__tls")]
            Self::Tls => "tls",
            #[cfg(feature = "__https")]
            Self::Https => "https",
            #[cfg(feature = "__quic")]
            Self::Quic => "quic",
            #[cfg(feature = "__h3")]
            Self::H3 => "h3",
        };
        f.write_str(name)
    }
}
466
/// A five-second timeout; by its name, intended for establishing connections.
// NOTE(review): no users are visible in this module, which is presumably why the lint is
// silenced (users may only exist under some feature combinations); confirm.
#[allow(unused)]
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Channel buffer size used by `BufDnsStreamHandle::new`.
const DEFAULT_STREAM_BUFFER_SIZE: usize = 32;