use bytes::{Buf, Bytes};
use std::future::poll_fn;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::task::Waker;
use std::{
fmt::Debug,
io::{self, Error, Result},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
runtime::Handle,
sync::{mpsc, oneshot},
time::sleep,
};
use crate::{
envelope::{Envelope, Protocol, Segment, Syn},
host::{is_same, SequencedSegment},
net::SocketPair,
world::World,
ToSocketAddrs, TRACING_TARGET,
};
use super::split_owned::{OwnedReadHalf, OwnedWriteHalf};
#[derive(Debug)]
pub struct TcpStream {
read_half: ReadHalf,
write_half: WriteHalf,
}
impl TcpStream {
pub(crate) fn new(
pair: SocketPair,
receiver: mpsc::Receiver<SequencedSegment>,
flow_control: BidiFlowControl,
) -> Self {
let pair = Arc::new(pair);
let read_half = ReadHalf {
pair: pair.clone(),
rx: Rx {
recv: receiver,
buffer: None,
},
is_closed: false,
flow_control: flow_control.read,
};
let write_half = WriteHalf {
pair,
is_shutdown: false,
flow_control: flow_control.write,
};
Self {
read_half,
write_half,
}
}
pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
let (ack, syn_ack) = oneshot::channel();
let (pair, rx, bidi) = World::current(|world| {
let dst = addr.to_socket_addr(&world.dns)?;
let (pair, rx, bidi) = {
let host = world.current_host_mut();
let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());
if dst.ip().is_loopback() {
local_addr.set_ip(dst.ip());
}
let pair = SocketPair::new(local_addr, dst);
let (rx, bidi) = host.tcp.new_stream(pair);
(pair, rx, bidi)
};
let syn = Protocol::Tcp(Segment::Syn(Syn { ack }));
if !is_same(pair.local, pair.remote) {
world.send_message(pair.local, pair.remote, syn)?;
} else {
send_loopback(pair.local, pair.remote, syn);
};
Ok::<_, Error>((pair, rx, bidi))
})?;
syn_ack.await.map_err(|_| {
io::Error::new(io::ErrorKind::ConnectionRefused, pair.remote.to_string())
})?;
tracing::trace!(target: TRACING_TARGET, src = ?pair.remote, dst = ?pair.local, protocol = %"TCP SYN-ACK", "Recv");
Ok(TcpStream::new(pair, rx, bidi))
}
pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
self.write_half.try_write(buf)
}
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.read_half.pair.local)
}
pub fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.read_half.pair.remote)
}
pub(crate) fn reunite(read_half: ReadHalf, write_half: WriteHalf) -> Self {
Self {
read_half,
write_half,
}
}
pub async fn writable(&self) -> Result<()> {
poll_fn(|cx| self.write_half.poll_writable(cx)).await
}
pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
(
OwnedReadHalf {
inner: self.read_half,
},
OwnedWriteHalf {
inner: self.write_half,
},
)
}
pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
Ok(())
}
pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
self.read_half.peek(buf).await
}
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
self.read_half.poll_peek(cx, buf)
}
}
/// Receiving side of a simulated TCP stream.
pub(crate) struct ReadHalf {
    pub(crate) pair: Arc<SocketPair>,
    rx: Rx,
    // Set once a FIN has been consumed; every later read/peek reports EOF.
    is_closed: bool,
    // One credit is released here for every data segment taken off `rx`.
    flow_control: Arc<FlowControl>,
}

/// Inbound segment channel plus the not-yet-consumed tail of the last data
/// segment.
struct Rx {
    recv: mpsc::Receiver<SequencedSegment>,
    buffer: Option<Bytes>,
}

impl ReadHalf {
    /// Shared body of `AsyncRead::poll_read`.
    ///
    /// Leftover bytes are served first. Otherwise the next segment is pulled
    /// from the channel. A FIN leaves `buf` untouched, which callers observe
    /// as EOF. A closed channel yields `ConnectionReset`.
    fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
        // EOF already seen, or no room offered: succeed without filling anything.
        if self.is_closed || buf.capacity() == 0 {
            return Poll::Ready(Ok(()));
        }

        if let Some(leftover) = self.rx.buffer.take() {
            self.rx.buffer = Self::put_slice(leftover, buf);
            return Poll::Ready(Ok(()));
        }

        let segment = match ready!(self.rx.recv.poll_recv(cx)) {
            Some(segment) => segment,
            None => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "Connection reset",
                )))
            }
        };

        tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %segment, "Recv");

        match segment {
            SequencedSegment::Data(bytes) => {
                self.flow_control.release();
                self.rx.buffer = Self::put_slice(bytes, buf);
            }
            SequencedSegment::Fin => self.is_closed = true,
        }
        Poll::Ready(Ok(()))
    }

    /// Copies as much of `src` as fits into `buf` and returns the number of
    /// bytes copied. `src` itself is left unchanged.
    fn copy_prefix(src: &[u8], buf: &mut ReadBuf) -> usize {
        let n = src.len().min(buf.remaining());
        buf.put_slice(&src[..n]);
        n
    }

    /// Moves as much of `avail` as fits into `buf`. Returns the remainder, or
    /// `None` when everything was consumed.
    fn put_slice(mut avail: Bytes, buf: &mut ReadBuf) -> Option<Bytes> {
        let copied = Self::copy_prefix(&avail, buf);
        avail.advance(copied);
        Some(avail).filter(|rest| !rest.is_empty())
    }

    /// Copies pending data into `buf` without consuming it.
    ///
    /// Returns `Ok(0)` at EOF or when `buf` has no capacity. A freshly
    /// received data segment is kept in the leftover buffer so the next read
    /// still sees it; its credit is released at that point.
    pub(crate) fn poll_peek(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<Result<usize>> {
        if self.is_closed || buf.capacity() == 0 {
            return Poll::Ready(Ok(0));
        }

        if let Some(leftover) = self.rx.buffer.as_ref() {
            return Poll::Ready(Ok(Self::copy_prefix(leftover, buf)));
        }

        match ready!(self.rx.recv.poll_recv(cx)) {
            None => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "Connection reset",
            ))),
            Some(segment) => {
                tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %segment, "Peek");
                let peeked = match segment {
                    SequencedSegment::Data(bytes) => {
                        self.flow_control.release();
                        let n = Self::copy_prefix(&bytes, buf);
                        self.rx.buffer = Some(bytes);
                        n
                    }
                    SequencedSegment::Fin => {
                        self.is_closed = true;
                        0
                    }
                };
                Poll::Ready(Ok(peeked))
            }
        }
    }

    /// Async wrapper around [`ReadHalf::poll_peek`].
    pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
        let mut read_buf = ReadBuf::new(buf);
        poll_fn(|cx| self.poll_peek(cx, &mut read_buf)).await
    }
}
/// Shows the socket pair and EOF flag; the channel and buffered bytes are
/// deliberately left out.
impl Debug for ReadHalf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ReadHalf");
        out.field("pair", &self.pair);
        out.field("is_closed", &self.is_closed);
        out.finish()
    }
}
/// Sending side of a simulated TCP stream.
pub(crate) struct WriteHalf {
    pub(crate) pair: Arc<SocketPair>,
    // Set after a FIN has been sent via `poll_shutdown`; further writes fail.
    is_shutdown: bool,
    // One credit is acquired here for every data segment sent.
    flow_control: Arc<FlowControl>,
}

impl WriteHalf {
    /// Sends `buf` as a single data segment without waiting.
    ///
    /// Returns `Ok(0)` for an empty buffer.
    ///
    /// # Errors
    ///
    /// - `BrokenPipe` if the write side was shut down, or if no send sequence
    ///   number can be assigned for this pair.
    /// - `WouldBlock` when no send credit is available.
    /// - Any error returned by `World::send_message`.
    fn try_write(&self, buf: &[u8]) -> Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if self.is_shutdown {
            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
        }
        if !self.flow_control.try_acquire() {
            return Err(io::Error::new(
                io::ErrorKind::WouldBlock,
                "send buffer full",
            ));
        }
        let result = World::current(|world| {
            let bytes = Bytes::copy_from_slice(buf);
            let len = bytes.len();
            let seq = self.seq(world)?;
            self.send(world, Segment::Data(seq, bytes))?;
            Ok(len)
        });
        if result.is_err() {
            // Nothing reached the peer, so nobody will ever release this
            // credit; hand it back. Otherwise a broken stream eventually turns
            // `BrokenPipe` into `WouldBlock`/`Pending`.
            self.flow_control.release();
        }
        result
    }

    /// Ready once a send credit is available; `BrokenPipe` after shutdown.
    fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.is_shutdown {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Broken pipe",
            )));
        }
        if self.flow_control.has_credits() {
            return Poll::Ready(Ok(()));
        }
        self.flow_control.register_waker(cx.waker().clone());
        // Re-check after registering. A `release()` that ran between the check
        // above and the registration found no waker to wake, and this task
        // would otherwise stay pending forever.
        if self.flow_control.has_credits() {
            return Poll::Ready(Ok(()));
        }
        Poll::Pending
    }

    /// Shared body of `AsyncWrite::poll_write`: `try_write`, parking the task
    /// while the send credits are exhausted.
    fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        if self.is_shutdown {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        match self.try_write(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
            result => return Poll::Ready(result),
        }
        self.flow_control.register_waker(cx.waker().clone());
        // Retry once now that the waker is registered, for the same lost-wakeup
        // reason as in `poll_writable`.
        match self.try_write(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Sends a FIN and marks this half as shut down.
    /// A second call fails with `NotConnected`.
    fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
        if self.is_shutdown {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "Socket is not connected",
            )));
        }
        let res = World::current(|world| {
            let seq = self.seq(world)?;
            self.send(world, Segment::Fin(seq))?;
            self.is_shutdown = true;
            Ok(())
        });
        Poll::Ready(res)
    }

    /// Asks the current host for the next send sequence number of this pair.
    /// Maps a `None` answer to `BrokenPipe` (presumably the stream no longer
    /// exists on the host).
    fn seq(&self, world: &mut World) -> Result<u64> {
        world
            .current_host_mut()
            .tcp
            .assign_send_seq(*self.pair)
            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"))
    }

    /// Wraps `segment` as a TCP message and delivers it. When both endpoints
    /// satisfy `is_same`, delivery goes through `send_loopback`; otherwise it
    /// goes through `World::send_message`.
    fn send(&self, world: &mut World, segment: Segment) -> Result<()> {
        let message = Protocol::Tcp(segment);
        if is_same(self.pair.local, self.pair.remote) {
            send_loopback(self.pair.local, self.pair.remote, message);
        } else {
            world.send_message(self.pair.local, self.pair.remote, message)?;
        }
        Ok(())
    }
}
/// Delivers `message` from `src` to `dst` on the current host after one
/// simulation tick, using a spawned tokio task.
///
/// If no tokio runtime is active, the message is silently dropped. If the
/// receiving side rejects the envelope, the `Protocol` value it returns
/// (presumably a reset) is delivered back with source and destination
/// swapped. Any error from that second delivery is ignored.
fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
    if Handle::try_current().is_ok() {
        tokio::spawn(async move {
            let delay = World::current(|world| world.tick_duration);
            sleep(delay).await;
            World::current(|world| {
                let outcome = world
                    .current_host_mut()
                    .receive_from_network(Envelope { src, dst, message });
                if let Err(bounce) = outcome {
                    let reply = Envelope {
                        src: dst,
                        dst: src,
                        message: bounce,
                    };
                    let _ = world.current_host_mut().receive_from_network(reply);
                }
            })
        });
    }
}
/// A pair of flow-control pools, one for each direction of a connection.
#[derive(Clone, Debug)]
pub(crate) struct BidiFlowControl {
    write: Arc<FlowControl>,
    read: Arc<FlowControl>,
}

impl BidiFlowControl {
    /// Creates two independent pools, each starting with `capacity` credits.
    pub(crate) fn new(capacity: usize) -> Self {
        let pool = || Arc::new(FlowControl::new(capacity));
        Self {
            write: pool(),
            read: pool(),
        }
    }

    /// Swaps the two pools, giving the view from the other endpoint.
    pub(crate) fn invert(self) -> Self {
        let Self { write, read } = self;
        Self {
            write: read,
            read: write,
        }
    }
}
/// A counting pool of send credits with a single parked waker.
///
/// `try_acquire` takes one credit and `release` gives one back, waking the
/// most recently registered waker. In this file, `WriteHalf` acquires one
/// credit per data segment it sends. `ReadHalf` releases one per data segment
/// it receives.
pub(crate) struct FlowControl {
    // Credits currently available.
    credits: AtomicUsize,
    // Most recently registered waker; registering again replaces it.
    waker: Mutex<Option<Waker>>,
}

impl FlowControl {
    /// Creates a pool holding `capacity` credits.
    fn new(capacity: usize) -> Self {
        Self {
            credits: AtomicUsize::new(capacity),
            waker: Mutex::new(None),
        }
    }

    /// Takes one credit if any are available; returns whether it succeeded.
    fn try_acquire(&self) -> bool {
        self.credits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
            .is_ok()
    }

    /// Returns one credit and wakes the registered waker, if any.
    fn release(&self) {
        self.credits.fetch_add(1, Ordering::Release);
        // Take the waker in its own statement so the mutex guard is dropped
        // before `wake()` runs. Written as `if let Some(w) = lock().take()`,
        // the guard would live until the end of the `if let` and be held
        // across the wake call.
        let waker = self.waker.lock().unwrap().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Stores `waker` to be woken by the next `release`, replacing any
    /// previously registered one.
    fn register_waker(&self, waker: Waker) {
        *self.waker.lock().unwrap() = Some(waker);
    }

    /// Returns whether at least one credit is currently available.
    fn has_credits(&self) -> bool {
        self.credits.load(Ordering::Acquire) > 0
    }
}
/// Shows a snapshot of the available credits; the waker slot is omitted.
impl Debug for FlowControl {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let available = self.credits.load(Ordering::Relaxed);
        let mut out = f.debug_struct("FlowControl");
        out.field("credits", &available);
        out.finish()
    }
}
/// Shows the socket pair and shutdown flag; flow-control state is left out.
impl Debug for WriteHalf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WriteHalf");
        out.field("pair", &self.pair);
        out.field("is_shutdown", &self.is_shutdown);
        out.finish()
    }
}
/// Delegates to `ReadHalf::poll_read_priv`. `ReadHalf` is `Unpin`, so the
/// pin can be unwrapped directly.
impl AsyncRead for ReadHalf {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<Result<()>> {
        let this = self.get_mut();
        this.poll_read_priv(cx, buf)
    }
}
/// Reads are served entirely by the read half.
impl AsyncRead for TcpStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<Result<()>> {
        let reader = &mut self.get_mut().read_half;
        Pin::new(reader).poll_read(cx, buf)
    }
}
/// Adapts the write half's private poll helpers to tokio's `AsyncWrite`.
impl AsyncWrite for WriteHalf {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        let this = self.get_mut();
        this.poll_write_priv(cx, buf)
    }

    // Each write is handed to the simulation immediately by `try_write`, so
    // there is never anything buffered to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();
        this.poll_shutdown_priv()
    }
}
/// Every write-side operation is forwarded to the write half.
impl AsyncWrite for TcpStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let writer = &mut self.get_mut().write_half;
        Pin::new(writer).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let writer = &mut self.get_mut().write_half;
        Pin::new(writer).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let writer = &mut self.get_mut().write_half;
        Pin::new(writer).poll_shutdown(cx)
    }
}
/// On drop, if the read half has not reached EOF and data is still unread, a
/// RST is sent to the peer and the stream is reset on the current host.
/// Otherwise only the read half is closed. Does nothing when no `World` is
/// set.
impl Drop for ReadHalf {
    fn drop(&mut self) {
        World::current_if_set(|world| {
            let pair = *self.pair;

            // Short-circuit order matters: `try_recv` consumes a queued
            // segment and the host query runs last, exactly as before.
            let unread = !self.is_closed
                && (self.rx.buffer.is_some()
                    || matches!(self.rx.recv.try_recv(), Ok(SequencedSegment::Data(_)))
                    || world.current_host_mut().tcp.has_buffered_data(pair));

            if !unread {
                world.current_host_mut().tcp.close_stream_half(pair);
                return;
            }

            let rst = Protocol::Tcp(Segment::Rst);
            if is_same(pair.local, pair.remote) {
                send_loopback(pair.local, pair.remote, rst);
            } else {
                let _ = world.send_message(pair.local, pair.remote, rst);
            }
            world.current_host_mut().tcp.reset_stream(pair);
        })
    }
}
/// On drop, sends a FIN if `poll_shutdown` never did and a sequence number
/// can still be assigned (send errors are ignored). In every case the write
/// half is then closed on the current host. Does nothing when no `World` is
/// set.
impl Drop for WriteHalf {
    fn drop(&mut self) {
        World::current_if_set(|world| {
            let fin_seq = if self.is_shutdown {
                None
            } else {
                self.seq(world).ok()
            };
            if let Some(seq) = fin_seq {
                let _ = self.send(world, Segment::Fin(seq));
            }
            world.current_host_mut().tcp.close_stream_half(*self.pair);
        })
    }
}