use crate::{
net::{IpProtocol::Tcp, *},
plugin,
};
use bytes::{Buf, Bytes, BytesMut};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
fmt,
io::Result,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::*;
/// A simulated TCP stream between two nodes.
///
/// Bytes passed to `poll_write` are accumulated in `write_buf` and handed to
/// the peer as one payload per `poll_flush`; payloads received on `rx` are
/// parked in `read_buf` until `poll_read` has copied them out.
pub struct TcpStream {
    /// Binding created by `BindGuard::bind` in `connect_one` (always `Some`
    /// there). NOTE(review): presumably holding it keeps the local port
    /// reserved, and `None` is used by streams constructed elsewhere in the
    /// parent module — confirm.
    pub(super) guard: Option<Arc<BindGuard>>,
    /// Local address, returned by `local_addr`.
    pub(super) addr: SocketAddr,
    /// Remote address, returned by `peer_addr`.
    pub(super) peer: SocketAddr,
    /// Bytes accepted by `poll_write` that have not been flushed yet.
    pub(super) write_buf: BytesMut,
    /// Unread remainder of the most recently received payload.
    pub(super) read_buf: Bytes,
    /// Outgoing payload channel; `poll_flush` sends buffered bytes here.
    pub(super) tx: PayloadSender,
    /// Incoming payload channel; `poll_read` receives peer data from here.
    pub(super) rx: PayloadReceiver,
}
impl fmt::Debug for TcpStream {
    /// Formats the stream as `TcpStream { addr: .., peer: .. }`.
    ///
    /// Buffers, channels and the bind guard are deliberately left out; only
    /// the two endpoint addresses are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("TcpStream");
        builder.field("addr", &self.addr);
        builder.field("peer", &self.peer);
        builder.finish()
    }
}
impl TcpStream {
#[instrument]
pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
let mut last_err = None;
for addr in lookup_host(addr).await? {
match Self::connect_one(addr).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)
}))
}
#[instrument]
async fn connect_one(addr: SocketAddr) -> Result<TcpStream> {
let net = plugin::simulator::<NetSim>();
net.rand_delay().await?;
let guard = BindGuard::bind("0.0.0.0:0", Tcp, Arc::new(TcpStreamSocket)).await?;
let (tx, rx, local_addr) = net
.connect1(plugin::node(), guard.addr.port(), addr, Tcp)
.await?;
let stream = TcpStream {
guard: Some(Arc::new(guard)),
addr: local_addr,
peer: addr,
write_buf: Default::default(),
read_buf: Default::default(),
tx,
rx,
};
Ok(stream)
}
pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
Ok(())
}
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.addr)
}
pub fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.peer)
}
}
#[cfg(unix)]
/// Raw file-descriptor access for `TcpStream` (Unix only).
impl AsRawFd for TcpStream {
    /// Not implemented yet: calling this always panics via `todo!`.
    ///
    /// NOTE(review): the stream exchanges payloads through its `tx`/`rx`
    /// fields rather than a visible OS socket, so there is presumably no real
    /// descriptor to return — confirm before relying on this impl.
    ///
    /// # Panics
    /// Always.
    fn as_raw_fd(&self) -> RawFd {
        todo!("TcpStream::as_raw_fd");
    }
}
impl AsyncRead for TcpStream {
    /// Copies buffered peer data into `buf`, receiving a new payload if needed.
    ///
    /// Leftover bytes from the previous payload are served first. Otherwise
    /// the next payload is taken from `rx`. Empty payloads are skipped. A
    /// closed channel is reported as EOF: `Ready(Ok(()))` with nothing written.
    ///
    /// # Panics
    /// Panics if a received payload is not a `Bytes`. The peer's
    /// `poll_flush` only ever sends `Bytes`, so this indicates a bug.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        // Loop rather than recurse: a run of empty payloads (sent by a peer
        // flushing with nothing buffered) must not grow the stack.
        loop {
            if !self.read_buf.is_empty() {
                let len = self.read_buf.len().min(buf.remaining());
                buf.put_slice(&self.read_buf[..len]);
                self.read_buf.advance(len);
                return Poll::Ready(Ok(()));
            }
            match self.rx.poll_recv(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Ready(Some(data)) => {
                    self.read_buf = *data
                        .downcast::<Bytes>()
                        .expect("TcpStream payload must be `Bytes`");
                }
            }
        }
    }
}
impl AsyncWrite for TcpStream {
    /// Appends `buf` to the internal write buffer and reports it fully written.
    ///
    /// Nothing reaches the peer until `poll_flush` (or `poll_shutdown`) runs.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        self.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Sends everything buffered since the last flush to the peer as one payload.
    ///
    /// With nothing buffered this is a no-op. Previously an empty payload was
    /// sent, and the reader could only discard it.
    ///
    /// # Errors
    /// `ConnectionReset` if the payload channel rejects the send.
    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.write_buf.is_empty() {
            return Poll::Ready(Ok(()));
        }
        let data = self.write_buf.split().freeze();
        self.tx
            .send(Box::new(data))
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?;
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write side, first flushing any buffered data.
    ///
    /// Tokio's `AsyncWrite` contract states that a shutdown implies a flush.
    /// Without the flush, bytes written but not yet flushed would be lost.
    ///
    /// # Errors
    /// Same as `poll_flush`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.poll_flush(cx)
    }
}
/// Socket object that `connect_one` passes to `BindGuard::bind` when binding
/// the client's ephemeral local port.
struct TcpStreamSocket;
// Empty impl: relies only on whatever provided (default) items `Socket` has.
impl Socket for TcpStreamSocket {}