use core::future::Future;
use core::marker::PhantomData;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::stream::{Peekable, Stream, StreamExt};
use tracing::debug;
use crate::error::*;
#[cfg(feature = "std")]
use crate::runtime::Time;
use crate::xfer::DnsResponseReceiver;
#[cfg(any(feature = "std", feature = "no-std-rand"))]
use crate::xfer::dns_handle::DnsHandle;
use crate::xfer::{
BufDnsRequestStreamHandle, CHANNEL_BUFFER_SIZE, DnsRequest, DnsRequestSender, DnsResponse,
OneshotDnsRequest,
};
/// Cloneable front-end for submitting DNS requests to a connection.
///
/// Requests are pushed through `sender` into a channel that is normally read by the paired
/// [`DnsExchangeBackground`]. That background future must be polled for any request to reach
/// the underlying connection.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchange {
    /// Sending half of the request channel consumed by the background task.
    sender: BufDnsRequestStreamHandle,
}
impl DnsExchange {
pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
where
S: DnsRequestSender + 'static + Send + Unpin,
{
let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
let message_sender = BufDnsRequestStreamHandle { sender };
Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
}
pub fn from_stream_with_receiver<S, TE>(
stream: S,
receiver: mpsc::Receiver<OneshotDnsRequest>,
sender: BufDnsRequestStreamHandle,
) -> (Self, DnsExchangeBackground<S, TE>)
where
S: DnsRequestSender + 'static + Send + Unpin,
{
let background = DnsExchangeBackground {
io_stream: stream,
outbound_messages: receiver.peekable(),
marker: PhantomData,
};
(Self { sender }, background)
}
#[cfg(feature = "std")]
pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
where
F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
S: DnsRequestSender + 'static + Send + Unpin,
TE: Time + Unpin,
{
let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
let message_sender = BufDnsRequestStreamHandle { sender };
DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
}
pub fn error<F, S, TE>(error: ProtoError) -> DnsExchangeConnect<F, S, TE>
where
F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
S: DnsRequestSender + 'static + Send + Unpin,
TE: Time + Unpin,
{
DnsExchangeConnect(DnsExchangeConnectInner::Error(error))
}
}
impl Clone for DnsExchange {
    /// Returns another handle whose sender is a clone of this one's `BufDnsRequestStreamHandle`.
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Self { sender }
    }
}
#[cfg(any(feature = "std", feature = "no-std-rand"))]
impl DnsHandle for DnsExchange {
    type Response = DnsExchangeSend;

    /// Queues `request` on the request channel and returns a stream of its responses.
    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
        let result = self.sender.send(request);
        // The background calls `io_stream.shutdown()` once the request channel reports it is
        // closed (no senders left). Holding a sender clone inside the response stream keeps
        // that from happening while this response is still outstanding, even if every
        // `DnsExchange` handle has already been dropped.
        let keepalive = self.sender.clone();
        DnsExchangeSend {
            result,
            _sender: keepalive,
        }
    }
}
/// Response stream returned by `DnsExchange::send`.
///
/// Yields the responses for a single request by polling `result`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeSend {
    /// Receiver for this request's responses.
    result: DnsResponseReceiver,
    /// Clone of the exchange's sender. It is never read; it is held only so the request
    /// channel stays open while this response is outstanding.
    _sender: BufDnsRequestStreamHandle,
}
impl Stream for DnsExchangeSend {
    type Item = Result<DnsResponse, ProtoError>;

    /// Forwards polling to the inner `DnsResponseReceiver`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.result.poll_next_unpin(cx)
    }
}
/// Background future that drives a `DnsRequestSender` and feeds it queued requests.
///
/// It must be polled (e.g. spawned on an executor) for any request sent through the paired
/// [`DnsExchange`] to be forwarded. It resolves when the underlying stream finishes or errors.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeBackground<S, TE>
where
    S: DnsRequestSender + 'static + Send + Unpin,
{
    /// The connection being driven; each dequeued request is handed to it via `send_message`.
    io_stream: S,
    /// Receiving half of the request channel.
    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
    /// Carries the `TE` type parameter (bounded by `Time` in the `Future` impl) without storing
    /// a value.
    marker: PhantomData<TE>,
}
impl<S, TE> DnsExchangeBackground<S, TE>
where
    S: DnsRequestSender + 'static + Send + Unpin,
{
    /// Mutably borrows the connection and the request receiver at the same time, so both can
    /// be pinned and polled within a single `poll` call.
    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
        let Self {
            io_stream,
            outbound_messages,
            ..
        } = self;
        (io_stream, outbound_messages)
    }
}
impl<S, TE> Future for DnsExchangeBackground<S, TE>
where
    S: DnsRequestSender + 'static + Send + Unpin,
    TE: Time + Unpin,
{
    type Output = Result<(), ProtoError>;

    /// Drives the connection and forwards queued requests to it.
    ///
    /// Each loop iteration polls `io_stream` once and then takes at most one request from the
    /// channel. A `Ready(Some(Ok(())))` from `io_stream` is treated like a non-shutdown
    /// `Pending`: the loop moves on to the request channel.
    ///
    /// When the channel is closed, `io_stream.shutdown()` is called and the loop continues to
    /// poll `io_stream`. It resolves to `Ok(())` when `io_stream` ends and to `Err(_)` if
    /// `io_stream` reports an error.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (io, requests) = self.pollable_split();
        let mut io = Pin::new(io);
        let mut requests = Pin::new(requests);

        loop {
            // Drive the underlying connection first.
            match io.as_mut().poll_next(cx) {
                Poll::Ready(None) => {
                    debug!("io_stream is done, shutting down");
                    return Poll::Ready(Ok(()));
                }
                Poll::Ready(Some(Err(err))) => {
                    debug!(
                        error = err.as_dyn(),
                        "io_stream hit an error, shutting down"
                    );
                    return Poll::Ready(Err(err));
                }
                // Once shut down, no new requests are accepted; only wait on the connection.
                Poll::Pending if io.is_shutdown() => return Poll::Pending,
                Poll::Pending | Poll::Ready(Some(Ok(()))) => {}
            }

            // Then pull at most one queued request.
            let request = match requests.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    // Every sender is gone: nothing else can be queued, so wind the connection
                    // down and keep polling it until it finishes.
                    io.shutdown();
                    continue;
                }
                Poll::Ready(Some(request)) => request,
            };

            let (dns_request, serial_response): (DnsRequest, _) = request.into_parts();
            let response = io.send_message(dns_request);
            // A failed hand-off means the requester is no longer listening; that is not fatal
            // for the connection, so it is only logged.
            if let Err(_unsent) = serial_response.send_response(response) {
                debug!("failed to associate send_message response to the sender");
            }
        }
    }
}
/// Future that resolves to a connected `(DnsExchange, DnsExchangeBackground)` pair.
///
/// It is created by `DnsExchange::connect`, or in an already-failed state by
/// `DnsExchange::error`. The returned background must be polled for requests to be forwarded.
pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender + 'static,
    TE: Time + Unpin;
impl<F, S, TE> DnsExchangeConnect<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender + 'static,
    TE: Time + Unpin,
{
    /// Starts in the `Connecting` state, holding `connect_future` and both channel halves until
    /// the connection attempt resolves.
    fn connect(
        connect_future: F,
        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
        sender: BufDnsRequestStreamHandle,
    ) -> Self {
        let inner = DnsExchangeConnectInner::Connecting {
            connect_future,
            outbound_messages: Some(outbound_messages),
            sender: Some(sender),
        };
        Self(inner)
    }
}
#[allow(clippy::type_complexity)]
impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender + 'static + Send + Unpin,
    TE: Time + Unpin,
{
    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;

    /// Delegates to the wrapped connection state machine.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll(cx)
    }
}
/// State machine behind `DnsExchangeConnect`.
// NOTE(review): presumably allowed because `Connected` embeds a whole background future,
// making it much larger than the other variants — confirm.
#[allow(clippy::large_enum_variant)]
enum DnsExchangeConnectInner<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
    S: DnsRequestSender + 'static + Send,
    TE: Time + Unpin,
{
    /// Waiting on `connect_future`. Both channel halves are kept in `Option`s so they can be
    /// moved out with `take` when this state is left.
    Connecting {
        connect_future: F,
        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
        sender: Option<BufDnsRequestStreamHandle>,
    },
    /// Connected. Polling this state returns a clone of `exchange` and takes `background`, so
    /// polling it a second time panics.
    Connected {
        exchange: DnsExchange,
        background: Option<DnsExchangeBackground<S, TE>>,
    },
    /// The connection attempt failed. Queued requests are answered with `error` until the
    /// channel reports it is closed, and then `error` is returned.
    FailAll {
        error: ProtoError,
        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
    },
    /// Resolves to a clone of this error on every poll.
    Error(ProtoError),
}
#[allow(clippy::type_complexity)]
impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender + 'static + Send + Unpin,
    TE: Time + Unpin,
{
    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;

    /// Advances the connection state machine.
    ///
    /// `Connecting` moves to `Connected` or `FailAll` once `connect_future` resolves, and the
    /// new state is handled in the same call.
    ///
    /// # Panics
    ///
    /// Panics with "cannot poll after complete" if polled again after a connected pair has
    /// been returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let next = match &mut *self {
                Self::Connecting {
                    connect_future,
                    outbound_messages,
                    sender,
                } => match Pin::new(connect_future).poll(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Ok(stream)) => {
                        let receiver = outbound_messages
                            .take()
                            .expect("cannot poll after complete");
                        let handle = sender.take().expect("cannot poll after complete");
                        let (exchange, background) =
                            DnsExchange::from_stream_with_receiver(stream, receiver, handle);
                        Self::Connected {
                            exchange,
                            background: Some(background),
                        }
                    }
                    Poll::Ready(Err(error)) => {
                        debug!(error = error.as_dyn(), "stream errored while connecting");
                        // `sender` is deliberately left in place. It is dropped when this state
                        // is overwritten below, so (absent other clones) the channel closes and
                        // the `FailAll` drain can finish.
                        Self::FailAll {
                            error,
                            outbound_messages: outbound_messages
                                .take()
                                .expect("cannot poll after complete"),
                        }
                    }
                },
                Self::Connected {
                    exchange,
                    background,
                } => {
                    let exchange = exchange.clone();
                    let background = background.take().expect("cannot poll after complete");
                    return Poll::Ready(Ok((exchange, background)));
                }
                Self::FailAll {
                    error,
                    outbound_messages,
                } => loop {
                    match outbound_messages.poll_next_unpin(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(None) => return Poll::Ready(Err(error.clone())),
                        Poll::Ready(Some(outbound_message)) => {
                            let (_request, responder) = outbound_message.into_parts();
                            // Best effort: the requester may already be gone.
                            responder.send_response(error.clone().into()).ok();
                        }
                    }
                },
                Self::Error(error) => return Poll::Ready(Err(error.clone())),
            };
            *self = next;
        }
    }
}