use core::fmt::Display;
use core::fmt::{self, Debug};
use core::future::Future;
use core::marker::PhantomData;
use core::net::SocketAddr;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;
use futures_channel::mpsc;
use futures_channel::oneshot;
use futures_util::future::BoxFuture;
use futures_util::ready;
use futures_util::stream::{Fuse, Peekable};
use futures_util::stream::{Stream, StreamExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use crate::error::NetError;
use crate::proto::ProtoError;
use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
use crate::runtime::{RuntimeProvider, Time};
mod dns_exchange;
pub use dns_exchange::{DnsExchange, DnsExchangeBackground, DnsExchangeSend};
pub mod dns_handle;
pub use dns_handle::{DnsHandle, DnsStreamHandle};
pub mod dns_multiplexer;
pub use dns_multiplexer::DnsMultiplexer;
pub mod retry_dns_handle;
pub use retry_dns_handle::RetryDnsHandle;
/// A stream of DNS response results.
///
/// Wraps one of several response sources (see `DnsResponseStreamInner`) together with a
/// completion flag, so that a source which has already resolved is not polled again.
pub struct DnsResponseStream {
    /// The underlying source of responses.
    inner: DnsResponseStreamInner,
    /// Set once a single-shot source has resolved; later polls then yield `None`.
    done: bool,
}

impl DnsResponseStream {
    /// Wraps `inner` in a stream that has not completed yet.
    fn new(inner: DnsResponseStreamInner) -> Self {
        let done = false;
        Self { inner, done }
    }
}
impl Stream for DnsResponseStream {
type Item = Result<DnsResponse, NetError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use DnsResponseStreamInner::*;
if self.done {
return Poll::Ready(None);
}
let Self { inner, done } = self.get_mut();
let result = match inner {
Timeout(fut) => {
let x = match ready!(fut.as_mut().poll(cx)) {
Ok(x) => x,
Err(e) => Err(e.into()),
};
*done = true;
x
}
Receiver(fut) => match ready!(Pin::new(fut).poll_next(cx)) {
Some(Ok(x)) => Ok(x),
Some(Err(e)) => Err(e),
None => return Poll::Ready(None),
},
Error(err) => {
*done = true;
Err(err.take().expect("cannot poll after complete"))
}
Boxed(fut) => {
let x = ready!(fut.as_mut().poll(cx));
*done = true;
x
}
};
match result {
Err(NetError::Timeout) => Poll::Ready(None),
r => Poll::Ready(Some(r)),
}
}
}
/// Builds a response stream backed by a [`TimeoutFuture`].
impl From<TimeoutFuture> for DnsResponseStream {
    fn from(f: TimeoutFuture) -> Self {
        let inner = DnsResponseStreamInner::Timeout(f);
        Self::new(inner)
    }
}
/// Builds a response stream that forwards the items delivered on an mpsc channel.
impl From<mpsc::Receiver<Result<DnsResponse, NetError>>> for DnsResponseStream {
    fn from(receiver: mpsc::Receiver<Result<DnsResponse, NetError>>) -> Self {
        let inner = DnsResponseStreamInner::Receiver(receiver);
        Self::new(inner)
    }
}
impl From<NetError> for DnsResponseStream {
fn from(e: NetError) -> Self {
Self::new(DnsResponseStreamInner::Error(Some(e)))
}
}
/// Builds a response stream backed by a pinned, boxed future resolving to one response result.
impl<F> From<Pin<Box<F>>> for DnsResponseStream
where
    F: Future<Output = Result<DnsResponse, NetError>> + Send + 'static,
{
    fn from(f: Pin<Box<F>>) -> Self {
        // The concrete future type is erased when it is stored in the `Boxed` variant.
        let inner = DnsResponseStreamInner::Boxed(f);
        Self::new(inner)
    }
}
/// The concrete source of responses backing a `DnsResponseStream`.
enum DnsResponseStreamInner {
/// A boxed future whose output wraps the response result in an extra `io::Error` layer
/// (presumably produced by a timeout wrapper; confirm against the callers).
Timeout(TimeoutFuture),
/// A channel receiver delivering response results.
Receiver(mpsc::Receiver<Result<DnsResponse, NetError>>),
/// An error to be yielded as the stream's item; the `Option` is emptied when it is taken.
Error(Option<NetError>),
/// A boxed future resolving to a single response result.
Boxed(BoxFuture<'static, Result<DnsResponse, NetError>>),
}
/// Boxed future whose output is a response result nested inside an `io::Error` result.
type TimeoutFuture = BoxFuture<'static, Result<Result<DnsResponse, NetError>, io::Error>>;
/// Consumes the outcome of a channel `try_send`, logging a failure instead of returning it.
///
/// A failure caused by a disconnected channel is logged at debug level; any other failure is
/// logged as a warning. A successful send is ignored.
fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
    match result {
        Ok(_) => {}
        Err(error) if error.is_disconnected() => {
            debug!("ignoring send error on disconnected stream");
        }
        Err(error) => {
            warn!("error notifying wait, possible future leak: {:?}", error);
        }
    }
}
/// A stream of `SerialMessage` results that is associated with a name server address.
///
/// Implementors must be `Unpin`, `Send` and `'static`.
pub trait DnsClientStream:
Stream<Item = Result<SerialMessage, NetError>> + Unpin + Send + 'static
{
/// The `Time` implementation associated with this stream.
type Time: Time;
/// Returns the socket address of the name server for this stream
/// (presumably the remote peer the stream is connected to; confirm with implementors).
fn name_server_addr(&self) -> SocketAddr;
}
/// Receiving half created alongside a `BufDnsStreamHandle`: an mpsc receiver of
/// `SerialMessage`s, fused and wrapped so the next item can be peeked.
pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
/// A cloneable handle that queues `SerialMessage`s on a bounded channel, addressed to a
/// fixed remote address.
#[derive(Clone)]
pub struct BufDnsStreamHandle {
    /// Address attached to every message sent through this handle.
    remote_addr: SocketAddr,
    /// Sending half of the channel whose receiving half is the paired `StreamReceiver`.
    sender: mpsc::Sender<SerialMessage>,
}

impl BufDnsStreamHandle {
    /// Creates a handle for `remote_addr` together with the receiver for its messages,
    /// using the default channel buffer size.
    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
        Self::with_buffer_size(remote_addr, DEFAULT_STREAM_BUFFER_SIZE)
    }

    /// Creates a handle for `remote_addr` together with the receiver for its messages,
    /// backed by an mpsc channel created with the given `buffer_size`.
    pub fn with_buffer_size(remote_addr: SocketAddr, buffer_size: usize) -> (Self, StreamReceiver) {
        let (sender, raw_receiver) = mpsc::channel(buffer_size);
        let handle = Self {
            remote_addr,
            sender,
        };
        (handle, raw_receiver.fuse().peekable())
    }

    /// Returns a new handle that shares this handle's channel but targets `remote_addr`.
    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
        let sender = self.sender.clone();
        Self {
            remote_addr,
            sender,
        }
    }
}
impl DnsStreamHandle for BufDnsStreamHandle {
    /// Queues `buffer` on the channel without blocking.
    ///
    /// The address carried by `buffer` is discarded and replaced by this handle's remote
    /// address. A failed `try_send` is converted into a `NetError` built from the formatted
    /// channel error.
    fn send(&mut self, buffer: SerialMessage) -> Result<(), NetError> {
        let payload = buffer.into_parts().0;
        let message = SerialMessage::new(payload, self.remote_addr);

        match self.sender.try_send(message) {
            Ok(()) => Ok(()),
            Err(e) => Err(NetError::from(format!("mpsc::SendError {e}"))),
        }
    }
}
/// A sender of DNS requests that is itself a stream of `Result<(), NetError>` items.
///
/// Implementors must be `Send`, `Unpin` and `'static`.
pub trait DnsRequestSender: Stream<Item = Result<(), NetError>> + Send + Unpin + 'static {
/// Submits `request` and returns the stream on which its responses are delivered.
fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;
/// Requests that the sender shut down (exact semantics are implementor-defined).
fn shutdown(&mut self);
/// Reports whether the sender has been shut down.
fn is_shutdown(&self) -> bool;
}
/// A cloneable handle that queues DNS requests on a channel, each paired with a oneshot
/// channel for its response stream.
#[derive(Clone)]
pub struct BufDnsRequestStreamHandle<P> {
/// Sending half of the channel on which requests are queued.
sender: mpsc::Sender<OneshotDnsRequest>,
/// Ties the handle to the type `P` without storing a value of it.
_phantom: PhantomData<P>,
}
impl<P: RuntimeProvider> DnsHandle for BufDnsRequestStreamHandle<P> {
    type Response = DnsResponseReceiver;
    type Runtime = P;

    /// Queues `request` without blocking and returns a receiver for its responses.
    ///
    /// If the request cannot be queued, the returned receiver is one that holds
    /// `NetError::Busy` instead of a channel.
    fn send(&self, request: DnsRequest) -> Self::Response {
        debug!(
            "enqueueing message:{}:{:?}",
            request.op_code, request.queries
        );

        let (queued, response_rx) = OneshotDnsRequest::oneshot(request);

        // `try_send` needs `&mut`, so operate on a clone of the shared sender.
        let mut sender = self.sender.clone();
        if sender.try_send(queued).is_err() {
            debug!("unable to enqueue message");
            return DnsResponseReceiver::Err(Some(NetError::Busy));
        }

        DnsResponseReceiver::Receiver(response_rx)
    }
}
/// A DNS request bundled with the oneshot sender on which its response stream is delivered.
pub struct OneshotDnsRequest {
    /// The request to be sent.
    dns_request: DnsRequest,
    /// Channel used to hand the response stream back to the requester.
    sender_for_response: oneshot::Sender<DnsResponseStream>,
}

impl OneshotDnsRequest {
    /// Wraps `dns_request` with a fresh oneshot channel and returns the wrapped request
    /// together with the channel's receiving half.
    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
        let (sender_for_response, receiver) = oneshot::channel();
        let wrapped = Self {
            dns_request,
            sender_for_response,
        };
        (wrapped, receiver)
    }

    /// Splits the bundle into the request and the handle used to deliver its response.
    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
        let Self {
            dns_request,
            sender_for_response,
        } = self;
        (dns_request, OneshotDnsResponse(sender_for_response))
    }
}
/// The sending half of a oneshot channel carrying a `DnsResponseStream`.
struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);

impl OneshotDnsResponse {
    /// Delivers `serial_response` over the channel, consuming this handle.
    ///
    /// On failure the undelivered stream is handed back in the `Err` variant.
    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
        let Self(sender) = self;
        sender.send(serial_response)
    }
}
/// The states of a response receiver: waiting for a response stream, forwarding one, or
/// holding an error.
pub enum DnsResponseReceiver {
/// Still waiting for the response stream to arrive on the oneshot channel.
Receiver(oneshot::Receiver<DnsResponseStream>),
/// The response stream has arrived; polls are forwarded to it.
Received(DnsResponseStream),
/// An error to be yielded; the `Option` is emptied when the error is taken.
Err(Option<NetError>),
}
impl Stream for DnsResponseReceiver {
type Item = Result<DnsResponse, NetError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
*self = match &mut *self {
Self::Receiver(receiver) => {
let receiver = Pin::new(receiver);
let future = ready!(
receiver
.poll(cx)
.map_err(|_| NetError::from("receiver was canceled"))
)?;
Self::Received(future)
}
Self::Received(stream) => {
return stream.poll_next_unpin(cx);
}
Self::Err(err) => return Poll::Ready(err.take().map(Err)),
};
}
}
}
/// Extension trait for `Unpin` streams of `Result<T, E>` that wraps the stream in a
/// [`FirstAnswerFuture`].
pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
    /// Consumes the stream and returns a [`FirstAnswerFuture`] holding it.
    fn first_answer(self) -> FirstAnswerFuture<Self> {
        let stream = Some(self);
        FirstAnswerFuture { stream }
    }
}

// Blanket implementation: every `Unpin` stream of `Result<T, E>` whose error type can be
// built from a `ProtoError` gets `first_answer` with the default body.
impl<E, S, T> FirstAnswer<T, E> for S
where
    S: Stream<Item = Result<T, E>> + Unpin + Sized,
    E: From<ProtoError>,
{
}
/// Future that holds a stream until it has been polled to completion.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FirstAnswerFuture<S> {
/// The wrapped stream; set to `None` once the future has produced its output.
stream: Option<S>,
}
impl<S, T> Future for FirstAnswerFuture<S>
where
    S: Stream<Item = Result<T, NetError>> + Unpin,
{
    type Output = S::Item;

    /// Resolves to the first item produced by the wrapped stream.
    ///
    /// If the stream ends without producing an item, the output is `Err(NetError::Timeout)`.
    /// The stream is dropped as soon as the output is ready.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let stream = self
            .stream
            .as_mut()
            .expect("polling FirstAnswerFuture twice");

        let first = ready!(stream.poll_next_unpin(cx));
        let output = first.unwrap_or_else(|| Err(NetError::Timeout));

        // Release the stream now that the single answer has been obtained.
        self.stream = None;
        Poll::Ready(output)
    }
}
/// The transport protocols known to this module.
///
/// Variants other than `Udp` and `Tcp` only exist when the matching crate feature is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// UDP; the only protocol reported as datagram-based.
    Udp,
    /// TCP.
    Tcp,
    /// TLS; available with the `__tls` feature.
    #[cfg(feature = "__tls")]
    Tls,
    /// HTTPS; available with the `__https` feature.
    #[cfg(feature = "__https")]
    Https,
    /// QUIC; available with the `__quic` feature.
    #[cfg(feature = "__quic")]
    Quic,
    /// HTTP/3; available with the `__h3` feature.
    #[cfg(feature = "__h3")]
    H3,
}

impl Protocol {
    /// Returns `true` only for [`Protocol::Udp`].
    pub fn is_datagram(self) -> bool {
        self == Self::Udp
    }

    /// Returns `true` for every protocol for which [`Protocol::is_datagram`] is `false`.
    pub fn is_stream(self) -> bool {
        !self.is_datagram()
    }

    /// Returns `false` for `Udp` and `Tcp`, and `true` for every other protocol.
    pub fn is_encrypted(self) -> bool {
        match self {
            Self::Udp | Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => true,
            #[cfg(feature = "__https")]
            Self::Https => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
        }
    }
}

impl Display for Protocol {
    /// Writes the lowercase name of the protocol, e.g. `udp` or `tcp`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Udp => "udp",
            Self::Tcp => "tcp",
            #[cfg(feature = "__tls")]
            Self::Tls => "tls",
            #[cfg(feature = "__https")]
            Self::Https => "https",
            #[cfg(feature = "__quic")]
            Self::Quic => "quic",
            #[cfg(feature = "__h3")]
            Self::H3 => "h3",
        };
        f.write_str(label)
    }
}
/// Crate-wide connect timeout of five seconds.
// NOTE(review): `allow(unused)` presumably covers builds where no user of this constant is
// compiled in — confirm against the feature-gated callers.
#[allow(unused)] pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Channel buffer size used by `BufDnsStreamHandle::new`.
const DEFAULT_STREAM_BUFFER_SIZE: usize = 32;