use core::mem;
use core::net::SocketAddr;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;
use futures_io::IoSlice;
use futures_util::stream::Stream;
use futures_util::{self, future::Future, ready};
use tracing::{debug, trace};
use crate::proto::op::SerialMessage;
use crate::runtime::{DnsTcpStream, Time};
use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
/// Progress of the write of a single outbound message to the socket.
///
/// A message goes on the wire as a two-byte length prefix followed by the payload; the
/// variants track how far that write has advanced.
enum WriteTcpState {
/// The length prefix has not been fully written yet. The prefix and the payload are
/// offered to the socket together in one vectored write.
LenBytes {
/// Number of bytes written so far, counted across `length` followed by `bytes`.
pos: usize,
/// Big-endian encoding of the payload length.
length: [u8; 2],
/// Payload to write after the length prefix.
bytes: Vec<u8>,
},
/// The length prefix is fully written; the rest of the payload is being written.
Bytes {
/// Number of payload bytes already written.
pos: usize,
/// Payload being written.
bytes: Vec<u8>,
},
/// Every byte has been handed to the socket; waiting for the flush to complete.
Flushing,
}
/// Progress of the read of a single inbound message from the socket.
pub(crate) enum ReadTcpState {
/// Reading the two-byte length prefix of the next message.
LenBytes {
/// Number of length bytes read so far.
pos: usize,
/// Buffer for the length prefix; decoded as a big-endian `u16` once full.
bytes: [u8; 2],
},
/// Reading the message payload.
Bytes {
/// Number of payload bytes read so far.
pos: usize,
/// Payload buffer, allocated to the length announced by the prefix.
bytes: Vec<u8>,
},
}
/// A stream of length-prefixed messages exchanged with a single remote peer over `S`.
///
/// Polling the stream both writes queued outbound messages to the socket and yields
/// messages read from it.
#[must_use = "futures do nothing unless polled"]
pub struct TcpStream<S: DnsTcpStream> {
// Underlying socket that messages are written to and read from.
socket: S,
// Receiver polled for the next message to send.
outbound_messages: StreamReceiver,
// State of the message currently being written; `None` when no write is in progress.
send_state: Option<WriteTcpState>,
// State of the message currently being read.
read_state: ReadTcpState,
// Address of the remote peer; outbound destinations are checked against it and
// received messages are tagged with it.
peer_addr: SocketAddr,
}
impl<S: DnsTcpStream> TcpStream<S> {
    /// Returns the address of the remote peer this stream talks to.
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Borrows the socket, the outbound queue and both I/O states mutably at the same time,
    /// so they can be polled independently of each other.
    fn pollable_split(
        &mut self,
    ) -> (
        &mut S,
        &mut StreamReceiver,
        &mut Option<WriteTcpState>,
        &mut ReadTcpState,
    ) {
        let Self {
            socket,
            outbound_messages,
            send_state,
            read_state,
            ..
        } = self;
        (socket, outbound_messages, send_state, read_state)
    }

    /// Wraps an already established `stream` to `peer_addr`.
    ///
    /// Returns the stream together with the handle used to queue outbound messages; the
    /// handle is created with `BufDnsStreamHandle::new`.
    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
        let (handle, receiver) = BufDnsStreamHandle::new(peer_addr);
        (
            Self::from_stream_with_receiver(stream, peer_addr, receiver),
            handle,
        )
    }

    /// Same as [`Self::from_stream`], except that the handle is created with
    /// `BufDnsStreamHandle::with_buffer_size`, to which `buffer_size` is forwarded.
    pub fn from_stream_with_buffer_size(
        stream: S,
        peer_addr: SocketAddr,
        buffer_size: usize,
    ) -> (Self, BufDnsStreamHandle) {
        let (handle, receiver) = BufDnsStreamHandle::with_buffer_size(peer_addr, buffer_size);
        (
            Self::from_stream_with_receiver(stream, peer_addr, receiver),
            handle,
        )
    }

    /// Wraps an already established `socket`, taking outbound messages from an existing
    /// receiver.
    ///
    /// The stream starts with no write in progress and expects the length prefix of the
    /// first inbound message.
    pub fn from_stream_with_receiver(
        socket: S,
        peer_addr: SocketAddr,
        outbound_messages: StreamReceiver,
    ) -> Self {
        let read_state = ReadTcpState::LenBytes {
            pos: 0,
            bytes: [0; 2],
        };

        Self {
            socket,
            outbound_messages,
            send_state: None,
            read_state,
            peer_addr,
        }
    }

    /// Builds a stream from a `future` that resolves to the connected socket.
    ///
    /// # Arguments
    ///
    /// * `future` - future yielding the underlying socket
    /// * `name_server` - address of the remote peer; recorded as the stream's peer address
    /// * `timeout` - passed to `S::Time::timeout` together with `future`
    ///
    /// Returns a future resolving to the stream, and the handle used to queue outbound
    /// messages. The handle is available immediately, before the connection completes.
    pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
        future: F,
        name_server: SocketAddr,
        timeout: Duration,
    ) -> (
        impl Future<Output = Result<Self, io::Error>> + Send,
        BufDnsStreamHandle,
    ) {
        let (handle, receiver) = BufDnsStreamHandle::new(name_server);
        (
            Self::connect_with_future(future, name_server, timeout, receiver),
            handle,
        )
    }

    /// Awaits `future` under `S::Time::timeout` and wraps the resulting socket.
    ///
    /// Both the error produced by the timeout wrapper and the error produced by `future`
    /// itself are propagated to the caller.
    async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
        future: F,
        name_server: SocketAddr,
        timeout: Duration,
        outbound_messages: StreamReceiver,
    ) -> Result<Self, io::Error> {
        // Outer layer: result of the timeout wrapper; inner layer: result of the connect.
        let connected = S::Time::timeout(timeout, future).await?;
        let socket = connected?;
        debug!("TCP connection established to: {}", name_server);

        Ok(Self::from_stream_with_receiver(
            socket,
            name_server,
            outbound_messages,
        ))
    }
}
impl<S: DnsTcpStream> Stream for TcpStream<S> {
type Item = io::Result<SerialMessage>;
#[allow(clippy::cognitive_complexity)]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let peer = self.peer_addr;
let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
let mut socket = Pin::new(socket);
let mut outbound_messages = Pin::new(outbound_messages);
loop {
if send_state.is_some() {
match send_state {
Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
let wrote = ready!(socket.as_mut().poll_write_vectored(
cx,
&[IoSlice::new(&length[*pos..]), IoSlice::new(bytes)]
))?;
*pos += wrote;
}
Some(WriteTcpState::Bytes { pos, bytes }) => {
let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
*pos += wrote;
}
Some(WriteTcpState::Flushing) => {
ready!(socket.as_mut().poll_flush(cx))?;
}
_ => (),
}
let current_state = send_state.take();
match current_state {
Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
if pos < length.len() {
*send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
} else if pos < length.len() + bytes.len() {
*send_state = Some(WriteTcpState::Bytes {
pos: pos - length.len(),
bytes,
});
} else {
*send_state = Some(WriteTcpState::Flushing);
}
}
Some(WriteTcpState::Bytes { pos, bytes }) => {
if pos < bytes.len() {
*send_state = Some(WriteTcpState::Bytes { pos, bytes });
} else {
*send_state = Some(WriteTcpState::Flushing);
}
}
Some(WriteTcpState::Flushing) => {
send_state.take();
}
None => (),
};
} else {
match outbound_messages.as_mut().poll_next(cx)
{
Poll::Ready(Some(message)) => {
let (buffer, dst) = message.into();
if peer != dst {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("mismatched peer: {peer} and dst: {dst}"),
))));
}
let len = u16::to_be_bytes(buffer.len() as u16);
debug!("sending message len: {} to: {}", buffer.len(), dst);
*send_state = Some(WriteTcpState::LenBytes {
pos: 0,
length: len,
bytes: buffer,
});
}
Poll::Pending => break,
Poll::Ready(None) => {
debug!("no messages to send");
break;
}
}
}
}
let mut ret_buf = None;
while ret_buf.is_none() {
let new_state: Option<ReadTcpState> = match read_state {
ReadTcpState::LenBytes { pos, bytes } => {
let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
if read == 0 {
debug!("zero bytes read, stream closed?");
if *pos == 0 {
return Poll::Ready(None);
} else {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"closed while reading length",
))));
}
}
trace!("in ReadTcpState::LenBytes: {}", pos);
*pos += read;
if *pos < bytes.len() {
trace!("remain ReadTcpState::LenBytes: {}", pos);
None
} else {
let length = u16::from_be_bytes(*bytes);
trace!("got length: {}", length);
let mut bytes = vec![0; length as usize];
bytes.resize(length as usize, 0);
trace!("move ReadTcpState::Bytes: {}", bytes.len());
Some(ReadTcpState::Bytes { pos: 0, bytes })
}
}
ReadTcpState::Bytes { pos, bytes } => {
let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
if read == 0 {
trace!("zero bytes read for message, stream closed?");
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"closed while reading message",
))));
}
trace!("in ReadTcpState::Bytes: {}", bytes.len());
*pos += read;
if *pos < bytes.len() {
trace!("remain ReadTcpState::Bytes: {}", bytes.len());
None
} else {
trace!("reset ReadTcpState::LenBytes: {}", 0);
Some(ReadTcpState::LenBytes {
pos: 0,
bytes: [0u8; 2],
})
}
}
};
if let Some(state) = new_state {
if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
assert_eq!(pos, bytes.len());
ret_buf = Some(bytes);
}
}
}
if let Some(buffer) = ret_buf {
let src_addr = self.peer_addr;
Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
} else {
debug!("bottomed out");
Poll::Pending
}
}
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tcp::tests::tcp_stream_test;

    /// Runs the shared TCP stream test against the IPv4 loopback address.
    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        subscribe();
        let server_ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
        tcp_stream_test(server_ip, TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared TCP stream test against the IPv6 loopback address (`::1`).
    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        subscribe();
        let server_ip = IpAddr::V6(Ipv6Addr::LOCALHOST);
        tcp_stream_test(server_ip, TokioRuntimeProvider::new()).await;
    }
}