use super::types::{ConnectionId, ListenerId};
use crate::TcpListenerTrait;
use crate::sim::state::CloseReason;
use crate::{Event, WeakSimWorld};
use async_trait::async_trait;
use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::instrument;
/// Error reported when the weak simulation handle can no longer be upgraded.
fn sim_shutdown_error() -> io::Error {
    const MESSAGE: &str = "simulation shutdown";
    io::Error::new(io::ErrorKind::BrokenPipe, MESSAGE)
}
/// Error injected when the simulation's random-close roll fires for a connection.
fn random_connection_failure_error() -> io::Error {
    let kind = io::ErrorKind::ConnectionReset;
    io::Error::new(kind, "Random connection failure (explicit)")
}
/// Error returned once a half-open connection reaches its error deadline.
fn half_open_timeout_error() -> io::Error {
    let kind = io::ErrorKind::ConnectionReset;
    io::Error::new(kind, "Connection reset (half-open timeout)")
}
/// Error returned when a connection was closed with `CloseReason::Aborted`.
fn connection_aborted_error() -> io::Error {
    let kind = io::ErrorKind::ConnectionReset;
    io::Error::new(kind, "Connection was aborted (RST)")
}
/// Handle to one end of a simulated TCP connection.
///
/// The stream stores only a `WeakSimWorld`. Reads, writes and shutdown
/// upgrade it first and fail with `BrokenPipe` ("simulation shutdown") when
/// the upgrade fails.
pub struct SimTcpStream {
    /// Handle to the owning simulation world.
    sim: WeakSimWorld,
    /// Identifier of this connection inside the simulation.
    connection_id: ConnectionId,
}

impl SimTcpStream {
    /// Wraps an existing simulated connection identified by `connection_id`.
    pub(crate) fn new(sim: WeakSimWorld, connection_id: ConnectionId) -> Self {
        Self {
            connection_id,
            sim,
        }
    }

    /// Returns the simulation-side identifier of this connection.
    pub fn connection_id(&self) -> ConnectionId {
        let Self { connection_id, .. } = self;
        *connection_id
    }
}
impl Drop for SimTcpStream {
    /// Closes the simulated connection when the stream handle is dropped.
    ///
    /// If the simulation handle can no longer be upgraded there is nothing to
    /// close, and the drop does nothing.
    fn drop(&mut self) {
        let Ok(sim) = self.sim.upgrade() else {
            return;
        };
        let id = self.connection_id;
        tracing::debug!("SimTcpStream dropping, closing connection {}", id.0);
        sim.close_connection(id);
    }
}
impl AsyncRead for SimTcpStream {
    /// Reads buffered bytes from the simulated connection into `buf`.
    ///
    /// Before any data is consumed, these checks run in order:
    ///
    /// 1. the random-close roll (`ConnectionReset`),
    /// 2. a closed receive side (EOF),
    /// 3. an expired half-open connection (`ConnectionReset`),
    /// 4. read clogging (parks on the read-clog waker).
    ///
    /// Available bytes are copied straight into the unfilled part of `buf`.
    /// When nothing is available the poll returns one of:
    ///
    /// - EOF on a remote FIN or graceful close,
    /// - `ConnectionReset` when the connection was aborted,
    /// - `Pending` otherwise, parking on the cut waker for a cut connection
    ///   or on the read waker for an idle one.
    ///
    /// After the read waker is registered the connection is read once more,
    /// so data that arrived in between is not missed.
    ///
    /// A `buf` with no remaining capacity completes immediately with
    /// `Ok(())`, as a zero-length read does on a real socket. Without that
    /// guard it would park on a waker that can never make progress.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - `BrokenPipe` if the simulation handle cannot be upgraded,
    /// - `ConnectionReset` for injected failures and aborted connections,
    /// - `Other` when the simulation reports a read or waker-registration
    ///   failure.
    #[instrument(skip(self, cx, buf))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        tracing::trace!(
            "SimTcpStream::poll_read called on connection_id={}",
            self.connection_id.0
        );
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        if let Some(true) = sim.roll_random_close(self.connection_id) {
            return Poll::Ready(Err(random_connection_failure_error()));
        }
        if sim.is_recv_closed(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} recv side closed, returning EOF",
                self.connection_id.0
            );
            return Poll::Ready(Ok(()));
        }
        if sim.is_half_open(self.connection_id) && sim.should_half_open_error(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} half-open error time reached, returning ECONNRESET",
                self.connection_id.0
            );
            return Poll::Ready(Err(half_open_timeout_error()));
        }
        if sim.is_read_clogged(self.connection_id) {
            sim.register_read_clog_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        if sim.should_clog_read(self.connection_id) {
            sim.clog_read(self.connection_id);
            sim.register_read_clog_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        // A zero-capacity buffer can never receive bytes. Parking here would
        // re-park on every wake-up, so complete the read right away instead.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        // Read straight into the caller's buffer rather than a temporary Vec.
        let dst = buf.initialize_unfilled();
        let bytes_read = sim
            .read_from_connection(self.connection_id, &mut *dst)
            .map_err(|e| io::Error::other(format!("read error: {}", e)))?;
        tracing::trace!(
            "SimTcpStream::poll_read connection_id={} read {} bytes",
            self.connection_id.0,
            bytes_read
        );
        if bytes_read > 0 {
            // The preview is built inside the macro, so it is only computed
            // when trace logging is enabled.
            tracing::trace!(
                "SimTcpStream::poll_read connection_id={} returning data: '{}'",
                self.connection_id.0,
                String::from_utf8_lossy(&dst[..bytes_read.min(20)])
            );
            buf.advance(bytes_read);
            return Poll::Ready(Ok(()));
        }

        if sim.is_remote_fin_received(self.connection_id) {
            tracing::info!(
                "SimTcpStream::poll_read connection_id={} remote FIN received, returning EOF (0 bytes)",
                self.connection_id.0
            );
            return Poll::Ready(Ok(()));
        }
        if sim.is_connection_closed(self.connection_id) {
            return match sim.close_reason(self.connection_id) {
                CloseReason::Aborted => {
                    tracing::info!(
                        "SimTcpStream::poll_read connection_id={} was aborted (RST), returning ECONNRESET",
                        self.connection_id.0
                    );
                    Poll::Ready(Err(connection_aborted_error()))
                }
                _ => {
                    tracing::info!(
                        "SimTcpStream::poll_read connection_id={} is closed gracefully (FIN), returning EOF (0 bytes)",
                        self.connection_id.0
                    );
                    Poll::Ready(Ok(()))
                }
            };
        }
        if sim.is_connection_cut(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} is cut, registering cut waker",
                self.connection_id.0
            );
            sim.register_cut_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }

        tracing::trace!(
            "SimTcpStream::poll_read connection_id={} no data, registering waker",
            self.connection_id.0
        );
        sim.register_read_waker(self.connection_id, cx.waker().clone())
            .map_err(|e| io::Error::other(format!("waker registration error: {}", e)))?;

        // Re-read once after registering the waker, so data that landed
        // between the first read and the registration is not left unnoticed.
        let dst = buf.initialize_unfilled();
        let bytes_read_recheck = sim
            .read_from_connection(self.connection_id, &mut *dst)
            .map_err(|e| io::Error::other(format!("recheck read error: {}", e)))?;
        if bytes_read_recheck > 0 {
            tracing::trace!(
                "SimTcpStream::poll_read connection_id={} found data on recheck: '{}' (race condition avoided)",
                self.connection_id.0,
                String::from_utf8_lossy(&dst[..bytes_read_recheck.min(20)])
            );
            buf.advance(bytes_read_recheck);
            return Poll::Ready(Ok(()));
        }

        if sim.is_remote_fin_received(self.connection_id) {
            tracing::info!(
                "SimTcpStream::poll_read connection_id={} remote FIN received on recheck, returning EOF (0 bytes)",
                self.connection_id.0
            );
            return Poll::Ready(Ok(()));
        }
        if sim.is_connection_closed(self.connection_id) {
            return match sim.close_reason(self.connection_id) {
                CloseReason::Aborted => {
                    tracing::info!(
                        "SimTcpStream::poll_read connection_id={} was aborted on recheck (RST), returning ECONNRESET",
                        self.connection_id.0
                    );
                    Poll::Ready(Err(connection_aborted_error()))
                }
                _ => {
                    tracing::info!(
                        "SimTcpStream::poll_read connection_id={} is closed on recheck (FIN), returning EOF (0 bytes)",
                        self.connection_id.0
                    );
                    Poll::Ready(Ok(()))
                }
            };
        }
        if sim.is_connection_cut(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} is cut on recheck, waiting",
                self.connection_id.0
            );
            // Mirror the first-pass cut branch. The read waker alone may not
            // fire when the cut is restored.
            sim.register_cut_waker(self.connection_id, cx.waker().clone());
        }
        Poll::Pending
    }
}
impl AsyncWrite for SimTcpStream {
    /// Queues bytes for ordered delivery on the simulated connection.
    ///
    /// These checks run in order:
    ///
    /// 1. the random-close roll,
    /// 2. a closed send side,
    /// 3. a closed connection (reset if aborted, broken pipe otherwise),
    /// 4. a cut connection (parks on the cut waker),
    /// 5. an expired half-open timeout,
    /// 6. send-buffer space,
    /// 7. write clogging.
    ///
    /// Accepted bytes are handed to `buffer_send`.
    ///
    /// Only as many bytes as currently fit in the send buffer are accepted, so
    /// the returned count may be smaller than `buf.len()`. This is a short
    /// write, which the `AsyncWrite` contract permits and real TCP produces.
    /// The task parks on the send-buffer waker only when no space is left.
    /// Waiting for room for the entire slice would never complete when `buf`
    /// is larger than the buffer's total capacity, e.g. `write_all` with a big
    /// payload.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - `BrokenPipe` if the simulation handle cannot be upgraded, or if the
    ///   send side or the connection was closed gracefully,
    /// - `ConnectionReset` for injected failures, half-open timeouts and
    ///   aborted connections,
    /// - `Other` if `buffer_send` fails.
    #[instrument(skip(self, cx, buf))]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        if let Some(true) = sim.roll_random_close(self.connection_id) {
            return Poll::Ready(Err(random_connection_failure_error()));
        }
        if sim.is_send_closed(self.connection_id) {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Connection send side closed",
            )));
        }
        if sim.is_connection_closed(self.connection_id) {
            return match sim.close_reason(self.connection_id) {
                CloseReason::Aborted => Poll::Ready(Err(connection_aborted_error())),
                _ => Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "Connection was closed (FIN)",
                ))),
            };
        }
        if sim.is_connection_cut(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_write connection_id={} is cut, registering cut waker",
                self.connection_id.0
            );
            sim.register_cut_waker(self.connection_id, cx.waker().clone());
            tracing::debug!(
                "SimTcpStream::poll_write connection_id={} registered waker for cut connection",
                self.connection_id.0
            );
            return Poll::Pending;
        }
        if sim.is_half_open(self.connection_id) && sim.should_half_open_error(self.connection_id) {
            tracing::debug!(
                "SimTcpStream::poll_write connection_id={} half-open error time reached, returning ECONNRESET",
                self.connection_id.0
            );
            return Poll::Ready(Err(half_open_timeout_error()));
        }

        let available_buffer = sim.available_send_buffer(self.connection_id);
        // Accept whatever fits. An empty `buf` still takes the normal path
        // below, as before.
        let to_write = buf.len().min(available_buffer);
        if to_write == 0 && !buf.is_empty() {
            tracing::debug!(
                "SimTcpStream::poll_write connection_id={} buffer full (available={}, needed={}), waiting",
                self.connection_id.0,
                available_buffer,
                buf.len()
            );
            sim.register_send_buffer_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        if sim.is_write_clogged(self.connection_id) {
            sim.register_clog_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        if sim.should_clog_write(self.connection_id) {
            sim.clog_write(self.connection_id);
            sim.register_clog_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }

        let chunk = &buf[..to_write];
        // The preview is built inside the macro, so it is only computed when
        // trace logging is enabled.
        tracing::trace!(
            "SimTcpStream::poll_write buffering {} bytes: '{}' for ordered delivery",
            chunk.len(),
            String::from_utf8_lossy(&chunk[..chunk.len().min(20)])
        );
        sim.buffer_send(self.connection_id, chunk.to_vec())
            .map_err(|e| io::Error::other(format!("buffer send error: {}", e)))?;
        Poll::Ready(Ok(to_write))
    }

    /// Always reports success immediately.
    ///
    /// `poll_write` has already handed accepted bytes to the simulation via
    /// `buffer_send`, so there is nothing held back here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the connection through `close_connection`, the same call `Drop`
    /// makes.
    ///
    /// NOTE(review): this does not look like a send-only half close, even
    /// though the simulation tracks a separate send-closed state
    /// (`is_send_closed`). Confirm that a full close is the intended shutdown
    /// semantics.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        tracing::debug!(
            "SimTcpStream::poll_shutdown closing connection {}",
            self.connection_id.0
        );
        sim.close_connection(self.connection_id);
        Poll::Ready(Ok(()))
    }
}
pub struct AcceptFuture {
sim: WeakSimWorld,
local_addr: String,
#[allow(dead_code)] listener_id: ListenerId,
}
impl Future for AcceptFuture {
type Output = io::Result<(SimTcpStream, String)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let sim = match self.sim.upgrade() {
Ok(sim) => sim,
Err(_) => return Poll::Ready(Err(sim_shutdown_error())),
};
match sim.pending_connection(&self.local_addr) {
Ok(Some(connection_id)) => {
let delay = sim.with_network_config(|config| {
crate::network::sample_duration(&config.accept_latency)
});
sim.schedule_event(
Event::Connection {
id: connection_id.0,
state: crate::ConnectionStateChange::ConnectionReady,
},
delay,
);
let peer_addr = sim
.connection_peer_address(connection_id)
.unwrap_or_else(|| "unknown:0".to_string());
let stream = SimTcpStream::new(self.sim.clone(), connection_id);
Poll::Ready(Ok((stream, peer_addr)))
}
Ok(None) => {
if let Err(e) = sim.register_accept_waker(&self.local_addr, cx.waker().clone()) {
Poll::Ready(Err(io::Error::other(format!(
"failed to register accept waker: {}",
e
))))
} else {
Poll::Pending
}
}
Err(e) => Poll::Ready(Err(io::Error::other(format!(
"failed to get pending connection: {}",
e
)))),
}
}
}
/// Simulated TCP listener that accepts connections addressed to `local_addr`.
pub struct SimTcpListener {
    /// Handle to the owning simulation world.
    sim: WeakSimWorld,
    // Forwarded into every `AcceptFuture` this listener creates.
    #[allow(dead_code)]
    listener_id: ListenerId,
    /// Address this listener accepts connections on.
    local_addr: String,
}

impl SimTcpListener {
    /// Creates a listener handle for `local_addr`, identified by `listener_id`.
    pub(crate) fn new(sim: WeakSimWorld, listener_id: ListenerId, local_addr: String) -> Self {
        Self {
            local_addr,
            listener_id,
            sim,
        }
    }
}
#[async_trait(?Send)]
impl TcpListenerTrait for SimTcpListener {
type TcpStream = SimTcpStream;
#[instrument(skip(self))]
async fn accept(&self) -> io::Result<(Self::TcpStream, String)> {
AcceptFuture {
sim: self.sim.clone(),
local_addr: self.local_addr.clone(),
listener_id: self.listener_id,
}
.await
}
fn local_addr(&self) -> io::Result<String> {
Ok(self.local_addr.clone())
}
}