use core::future::Future;
use core::marker::PhantomData;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_channel::mpsc;
use futures_util::stream::{Peekable, Stream, StreamExt};
use tracing::debug;
use crate::error::NetError;
use crate::proto::op::{DnsRequest, DnsResponse};
use crate::runtime::RuntimeProvider;
use crate::runtime::Time;
use crate::xfer::dns_handle::DnsHandle;
use crate::xfer::{
BufDnsRequestStreamHandle, DEFAULT_STREAM_BUFFER_SIZE, DnsRequestSender, DnsResponseReceiver,
OneshotDnsRequest,
};
/// Handle for submitting [`DnsRequest`]s to a [`DnsExchangeBackground`] future.
///
/// Created together with its background future by [`DnsExchange::from_stream`].
/// Requests are only forwarded to the underlying sender while that background
/// future is being polled. Cloneable when `P: Clone`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchange<P> {
    /// Buffered handle onto the request channel that the background future reads.
    sender: BufDnsRequestStreamHandle<P>,
}
impl<P: RuntimeProvider> DnsExchange<P> {
    /// Splits a [`DnsRequestSender`] into a request handle and the background
    /// future that services it.
    ///
    /// Requests sent through the returned [`DnsExchange`] are queued on an
    /// `mpsc` channel created with buffer size [`DEFAULT_STREAM_BUFFER_SIZE`].
    /// They only reach `stream` while the returned [`DnsExchangeBackground`]
    /// is being polled.
    pub fn from_stream<S: DnsRequestSender>(
        stream: S,
    ) -> (Self, DnsExchangeBackground<S, P::Timer>) {
        let (request_tx, request_rx) = mpsc::channel(DEFAULT_STREAM_BUFFER_SIZE);

        // The background side owns the transport and the receiving end of the queue.
        let background = DnsExchangeBackground {
            io_stream: stream,
            outbound_messages: request_rx.peekable(),
            marker: PhantomData,
        };

        // The caller-facing side only holds the sending end.
        let handle = Self {
            sender: BufDnsRequestStreamHandle {
                sender: request_tx,
                _phantom: PhantomData,
            },
        };

        (handle, background)
    }
}
impl<P: Clone> Clone for DnsExchange<P> {
    /// Produces another handle by cloning the inner request handle.
    ///
    /// NOTE(review): the copy is expected to feed the same request channel,
    /// assuming `BufDnsRequestStreamHandle::clone` clones its inner `mpsc` sender.
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Self { sender }
    }
}
impl<P: RuntimeProvider> DnsHandle for DnsExchange<P> {
    type Response = DnsExchangeSend<P>;
    type Runtime = P;

    /// Enqueues `request` for the background future and returns the stream of
    /// its responses.
    ///
    /// The returned [`DnsExchangeSend`] also holds a clone of the request
    /// handle. While that response stream is alive, the request channel
    /// therefore still has a live sender, and the background future does not
    /// see it as closed.
    fn send(&self, request: DnsRequest) -> Self::Response {
        // Enqueue first, then take the keep-alive clone (same order as before).
        let result = self.sender.send(request);
        let _sender = self.sender.clone();
        DnsExchangeSend { result, _sender }
    }
}
/// Stream of responses for a single request, returned by
/// [`DnsHandle::send`] on a [`DnsExchange`].
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeSend<P> {
    /// Receiver for the responses to the request this value was created for.
    result: DnsResponseReceiver,
    /// Never read. Held only so a sender for the request channel stays alive
    /// while this stream exists, which keeps the background future from
    /// treating the channel as closed.
    /// NOTE(review): relies on `BufDnsRequestStreamHandle` clones owning an
    /// `mpsc` sender — confirm.
    _sender: BufDnsRequestStreamHandle<P>,
}
impl<P: Unpin> Stream for DnsExchangeSend<P> {
    type Item = Result<DnsResponse, NetError>;

    /// Forwards polling straight to the inner response receiver.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let receiver = &mut self.result;
        Pin::new(receiver).poll_next(cx)
    }
}
/// Future that forwards requests queued by [`DnsExchange`] handles to the
/// wrapped [`DnsRequestSender`] and drives that sender's stream.
///
/// No request is sent unless this future is polled, for example by spawning it.
/// It is returned by [`DnsExchange::from_stream`].
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeBackground<S, TE> {
    /// The request sender whose stream is driven and to which requests are forwarded.
    io_stream: S,
    /// Requests submitted through `DnsExchange` handles, waiting to be forwarded.
    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
    /// Records the timer type `TE` (bounded by `Time` in the `Future` impl)
    /// without storing a value of it.
    marker: PhantomData<TE>,
}
impl<S, TE> DnsExchangeBackground<S, TE> {
    /// Borrows the I/O stream and the outbound request queue as two disjoint
    /// mutable references, so each can be pinned and polled independently.
    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
        let Self {
            io_stream,
            outbound_messages,
            ..
        } = self;
        (io_stream, outbound_messages)
    }
}
impl<S: DnsRequestSender, TE: Time> Future for DnsExchangeBackground<S, TE> {
type Output = ();
#[allow(clippy::unused_unit)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let (io_stream, outbound_messages) = self.pollable_split();
let mut io_stream = Pin::new(io_stream);
let mut outbound_messages = Pin::new(outbound_messages);
loop {
match io_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(()))) => (),
Poll::Pending => {
if io_stream.is_shutdown() {
return Poll::Pending;
}
()
} Poll::Ready(None) => {
debug!("io_stream is done, shutting down");
return Poll::Ready(());
}
Poll::Ready(Some(Err(error))) => {
debug!(
%error,
"io_stream hit an error, shutting down"
);
return Poll::Ready(());
}
}
match outbound_messages.as_mut().poll_next(cx) {
Poll::Ready(Some(dns_request)) => {
let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
match serial_response.send_response(io_stream.send_message(dns_request)) {
Ok(()) => (),
Err(_) => {
debug!("failed to associate send_message response to the sender");
}
}
}
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
io_stream.shutdown();
}
}
}
}
}