use super::types::ConnectionId;
use crate::TcpListenerTrait;
use crate::sim::state::CloseReason;
use crate::{Event, WeakSimWorld};
use futures::io::{AsyncRead, AsyncWrite};
use std::{
future::Future,
io::{self, IoSlice},
pin::Pin,
task::{Context, Poll},
};
use tracing::instrument;
/// Error reported when the weak simulation handle can no longer be upgraded.
fn sim_shutdown_error() -> io::Error {
    let kind = io::ErrorKind::BrokenPipe;
    io::Error::new(kind, "simulation shutdown")
}
/// Error returned when `roll_random_close` fires for a read or a write.
fn random_connection_failure_error() -> io::Error {
    let message = "Random connection failure (explicit)";
    io::Error::new(io::ErrorKind::ConnectionReset, message)
}
/// Error returned once a half-open connection reaches its error point
/// (`is_half_open` and `should_half_open_error` both true).
fn half_open_timeout_error() -> io::Error {
    let message = "Connection reset (half-open timeout)";
    io::Error::new(io::ErrorKind::ConnectionReset, message)
}
/// Error returned for I/O on a connection whose close reason is
/// `CloseReason::Aborted` (an RST rather than a graceful FIN).
fn connection_aborted_error() -> io::Error {
    let message = "Connection was aborted (RST)";
    io::Error::new(io::ErrorKind::ConnectionReset, message)
}
/// Simulated TCP stream backed by a connection inside the simulation world.
///
/// Implements `AsyncRead` and `AsyncWrite`. Read, write and close operations
/// upgrade the weak simulation handle first and fail with a `BrokenPipe`
/// ("simulation shutdown") error if the simulation is gone. Dropping the
/// stream closes the connection when the simulation is still alive.
pub struct SimTcpStream {
    /// Weak handle to the simulation; upgraded for each operation.
    sim: WeakSimWorld,
    /// Identifier of the simulated connection this stream reads and writes.
    connection_id: ConnectionId,
}
impl SimTcpStream {
pub(crate) fn new(sim: WeakSimWorld, connection_id: ConnectionId) -> Self {
Self { sim, connection_id }
}
#[must_use]
pub fn connection_id(&self) -> ConnectionId {
self.connection_id
}
#[must_use]
pub fn is_write_vectored(&self) -> bool {
true
}
fn write_guard_pre_backpressure(
&self,
sim: &crate::sim::SimWorld,
cx: &mut Context<'_>,
) -> Option<Poll<Result<usize, io::Error>>> {
if let Some(true) = sim.roll_random_close(self.connection_id) {
return Some(Poll::Ready(Err(random_connection_failure_error())));
}
if sim.is_send_closed(self.connection_id) {
return Some(Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Connection send side closed",
))));
}
if sim.is_connection_closed(self.connection_id) {
return Some(match sim.close_reason(self.connection_id) {
CloseReason::Aborted => Poll::Ready(Err(connection_aborted_error())),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Connection was closed (FIN)",
))),
});
}
if sim.is_connection_cut(self.connection_id) {
tracing::debug!(
"SimTcpStream::poll_write connection_id={} is cut, registering cut waker",
self.connection_id.0
);
sim.register_cut_waker(self.connection_id, cx.waker().clone());
tracing::debug!(
"SimTcpStream::poll_write connection_id={} registered waker for cut connection",
self.connection_id.0
);
return Some(Poll::Pending);
}
if sim.is_half_open(self.connection_id) && sim.should_half_open_error(self.connection_id) {
tracing::debug!(
"SimTcpStream::poll_write connection_id={} half-open error time reached, returning ECONNRESET",
self.connection_id.0
);
return Some(Poll::Ready(Err(half_open_timeout_error())));
}
None
}
fn write_guard_clog(
&self,
sim: &crate::sim::SimWorld,
cx: &mut Context<'_>,
) -> Option<Poll<Result<usize, io::Error>>> {
if sim.is_write_clogged(self.connection_id) {
sim.register_clog_waker(self.connection_id, cx.waker().clone());
return Some(Poll::Pending);
}
if sim.should_clog_write(self.connection_id) {
sim.clog_write(self.connection_id);
sim.register_clog_waker(self.connection_id, cx.waker().clone());
return Some(Poll::Pending);
}
None
}
}
impl Drop for SimTcpStream {
    /// Closes the simulated connection when the stream goes away, unless the
    /// simulation itself can no longer be reached.
    fn drop(&mut self) {
        let Ok(sim) = self.sim.upgrade() else {
            return;
        };
        tracing::debug!(
            "SimTcpStream dropping, closing connection {}",
            self.connection_id.0
        );
        sim.close_connection(self.connection_id);
    }
}
impl AsyncRead for SimTcpStream {
    /// Reads bytes the simulation has delivered on this connection.
    ///
    /// Checks run first, in order: random-close roll, receive side closed
    /// (EOF), expired half-open connection, read clogging. After those checks,
    /// a zero-length `buf` completes with `Ok(0)`. Otherwise available bytes
    /// are copied into `buf`. When no bytes are available, the result is
    /// decided by `poll_read_no_data`.
    ///
    /// # Errors
    /// - `BrokenPipe` when the simulation has been shut down.
    /// - `ConnectionReset` on a random-close roll, half-open timeout, or an
    ///   aborted (RST) connection.
    /// - `Other` when the simulation reports a read error.
    #[instrument(skip(self, cx, buf))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let id = self.connection_id;
        tracing::trace!(
            "SimTcpStream::poll_read called on connection_id={}",
            id.0
        );
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        if let Some(true) = sim.roll_random_close(id) {
            return Poll::Ready(Err(random_connection_failure_error()));
        }
        if sim.is_recv_closed(id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} recv side closed, returning EOF",
                id.0
            );
            return Poll::Ready(Ok(0));
        }
        if sim.is_half_open(id) && sim.should_half_open_error(id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} half-open error time reached, returning ECONNRESET",
                id.0
            );
            return Poll::Ready(Err(half_open_timeout_error()));
        }
        if sim.is_read_clogged(id) {
            sim.register_read_clog_waker(id, cx.waker().clone());
            return Poll::Pending;
        }
        if sim.should_clog_read(id) {
            sim.clog_read(id);
            sim.register_read_clog_waker(id, cx.waker().clone());
            return Poll::Pending;
        }
        // An empty buffer can never receive bytes. Without this early return,
        // the zero-byte result would reach `poll_read_no_data`, which parks on
        // the read waker. Every wake would then read zero bytes again, so the
        // read could stay pending forever even with data buffered. The `Read`
        // contract allows `Ok(0)` for a zero-length buffer.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut temp_buf = vec![0u8; buf.len()];
        let bytes_read = sim
            .read_from_connection(id, &mut temp_buf)
            .map_err(|e| io::Error::other(format!("read error: {e}")))?;
        tracing::trace!(
            "SimTcpStream::poll_read connection_id={} read {} bytes",
            id.0,
            bytes_read
        );
        if bytes_read > 0 {
            // The preview is built inside the macro, so it is only computed
            // when trace-level logging is enabled.
            tracing::trace!(
                "SimTcpStream::poll_read connection_id={} returning data: '{}'",
                id.0,
                String::from_utf8_lossy(&temp_buf[..bytes_read.min(20)])
            );
            buf[..bytes_read].copy_from_slice(&temp_buf[..bytes_read]);
            return Poll::Ready(Ok(bytes_read));
        }
        Self::poll_read_no_data(&sim, id, cx, buf)
    }
}
impl SimTcpStream {
    /// Slow path of `poll_read`, taken after a read returned no bytes.
    ///
    /// Resolution order:
    /// - remote FIN → `Ok(0)`;
    /// - closed connection → `ConnectionReset` error if aborted, else `Ok(0)`;
    /// - cut connection → register the cut waker and park.
    ///
    /// If none of these apply, the read waker is registered. The connection
    /// is then checked once more for data, FIN, or close before parking.
    fn poll_read_no_data(
        sim: &crate::sim::SimWorld,
        connection_id: crate::network::sim::ConnectionId,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let id = connection_id;

        if sim.is_remote_fin_received(id) {
            tracing::info!(
                "SimTcpStream::poll_read connection_id={} remote FIN received, returning EOF",
                id.0
            );
            return Poll::Ready(Ok(0));
        }

        if sim.is_connection_closed(id) {
            let aborted = sim.close_reason(id) == CloseReason::Aborted;
            return if aborted {
                tracing::info!(
                    "SimTcpStream::poll_read connection_id={} was aborted (RST)",
                    id.0
                );
                Poll::Ready(Err(connection_aborted_error()))
            } else {
                tracing::info!(
                    "SimTcpStream::poll_read connection_id={} closed gracefully (FIN)",
                    id.0
                );
                Poll::Ready(Ok(0))
            };
        }

        if sim.is_connection_cut(id) {
            tracing::debug!(
                "SimTcpStream::poll_read connection_id={} is cut, registering cut waker",
                id.0
            );
            sim.register_cut_waker(id, cx.waker().clone());
            return Poll::Pending;
        }

        tracing::trace!(
            "SimTcpStream::poll_read connection_id={} no data, registering waker",
            id.0
        );
        sim.register_read_waker(id, cx.waker().clone());

        // Check again now that the waker is registered.
        // NOTE(review): presumably guards against missing data or close events
        // that land between the first read and waker registration — confirm.
        let mut scratch = vec![0u8; buf.len()];
        let got = sim
            .read_from_connection(id, &mut scratch)
            .map_err(|e| io::Error::other(format!("recheck read error: {e}")))?;
        if got > 0 {
            buf[..got].copy_from_slice(&scratch[..got]);
            return Poll::Ready(Ok(got));
        }

        if sim.is_remote_fin_received(id) {
            Poll::Ready(Ok(0))
        } else if !sim.is_connection_closed(id) {
            Poll::Pending
        } else if sim.close_reason(id) == CloseReason::Aborted {
            Poll::Ready(Err(connection_aborted_error()))
        } else {
            Poll::Ready(Ok(0))
        }
    }
}
impl AsyncWrite for SimTcpStream {
    /// Queues bytes from `buf` on the connection's simulated send buffer.
    ///
    /// After the pre-backpressure guards, an empty `buf` completes with
    /// `Ok(0)`. Otherwise the write parks on the send-buffer waker until some
    /// space is free. It then passes the clog guard and accepts as many bytes
    /// as fit. This may be a partial write, which the `AsyncWrite` contract
    /// permits.
    ///
    /// Returns the number of bytes accepted.
    ///
    /// # Errors
    /// - `BrokenPipe` when the simulation is shut down or the connection is closed.
    /// - `ConnectionReset` on a random-close roll, abort (RST), or half-open timeout.
    /// - `Other` when `buffer_send` fails.
    #[instrument(skip(self, cx, buf))]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        if let Some(poll) = self.write_guard_pre_backpressure(&sim, cx) {
            return poll;
        }
        // Matches `poll_write_vectored`: nothing to queue, so no clog roll and
        // no empty `buffer_send`.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let available_buffer = sim.available_send_buffer(self.connection_id);
        if available_buffer == 0 {
            tracing::debug!(
                "SimTcpStream::poll_write connection_id={} buffer full (available={}, needed={}), waiting",
                self.connection_id.0,
                available_buffer,
                buf.len()
            );
            sim.register_send_buffer_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        if let Some(poll) = self.write_guard_clog(&sim, cx) {
            return poll;
        }
        // Accept only what fits. Waiting for room for the *whole* buffer would
        // never finish when `buf` is larger than the send buffer can ever hold.
        let accepted = &buf[..buf.len().min(available_buffer)];
        tracing::trace!(
            "SimTcpStream::poll_write buffering {} bytes: '{}' for ordered delivery",
            accepted.len(),
            String::from_utf8_lossy(&accepted[..accepted.len().min(20)])
        );
        sim.buffer_send(self.connection_id, accepted.to_vec())
            .map_err(|e| io::Error::other(format!("buffer send error: {e}")))?;
        Poll::Ready(Ok(accepted.len()))
    }

    /// Vectored form of `poll_write`. Accepts the longest prefix of the
    /// concatenated slices that fits in the free send-buffer space.
    ///
    /// The accepted bytes are gathered and queued with a single
    /// `buffer_send`. A failure therefore cannot leave earlier slices queued
    /// while the caller sees an error and retries them.
    /// NOTE(review): this assumes `buffer_send` is all-or-nothing — confirm.
    #[instrument(skip(self, cx, bufs))]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        if let Some(poll) = self.write_guard_pre_backpressure(&sim, cx) {
            return poll;
        }
        let total: usize = bufs.iter().map(|slice| slice.len()).sum();
        if total == 0 {
            return Poll::Ready(Ok(0));
        }
        let available = sim.available_send_buffer(self.connection_id);
        if available == 0 {
            sim.register_send_buffer_waker(self.connection_id, cx.waker().clone());
            return Poll::Pending;
        }
        if let Some(poll) = self.write_guard_clog(&sim, cx) {
            return poll;
        }
        let accepted = total.min(available);
        let mut payload = Vec::with_capacity(accepted);
        for slice in bufs {
            let room = accepted - payload.len();
            if room == 0 {
                break;
            }
            payload.extend_from_slice(&slice[..room.min(slice.len())]);
        }
        sim.buffer_send(self.connection_id, payload)
            .map_err(|e| io::Error::other(format!("buffer send error: {e}")))?;
        Poll::Ready(Ok(accepted))
    }

    /// No-op: the stream keeps no local write buffer; bytes go straight to
    /// the simulation via `buffer_send`.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the simulated connection and completes immediately.
    ///
    /// # Errors
    /// `BrokenPipe` when the simulation has been shut down.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        tracing::debug!(
            "SimTcpStream::poll_close closing connection {}",
            self.connection_id.0
        );
        sim.close_connection(self.connection_id);
        Poll::Ready(Ok(()))
    }
}
/// Future that resolves to the next connection pending on `local_addr`,
/// together with the peer's address.
pub struct AcceptFuture {
    /// Weak handle to the simulation; upgraded on each poll.
    sim: WeakSimWorld,
    /// Listening address whose pending connections this future accepts.
    local_addr: String,
}
impl Future for AcceptFuture {
    type Output = io::Result<(SimTcpStream, String)>;

    /// Takes the next connection pending on `local_addr`. If there is none,
    /// registers the accept waker and parks.
    ///
    /// On success it schedules a `ConnectionReady` event for the connection.
    /// The event's delay is sampled from the network config's
    /// `accept_latency`. The returned peer address is `"unknown:0"` when the
    /// simulation has none recorded.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let sim = match self.sim.upgrade() {
            Ok(sim) => sim,
            Err(_) => return Poll::Ready(Err(sim_shutdown_error())),
        };

        let connection_id = match sim.pending_connection(&self.local_addr) {
            Some(id) => id,
            None => {
                sim.register_accept_waker(&self.local_addr, cx.waker().clone());
                return Poll::Pending;
            }
        };

        let accept_delay = sim
            .with_network_config(|config| crate::network::sample_duration(&config.accept_latency));
        let ready_event = Event::Connection {
            id: connection_id.0,
            state: crate::ConnectionStateChange::ConnectionReady,
        };
        sim.schedule_event(ready_event, accept_delay);

        let peer_addr = sim
            .connection_peer_address(connection_id)
            .unwrap_or_else(|| String::from("unknown:0"));
        Poll::Ready(Ok((
            SimTcpStream::new(self.sim.clone(), connection_id),
            peer_addr,
        )))
    }
}
/// Simulated TCP listener for `local_addr`.
///
/// `accept` resolves with connections the simulation reports as pending for
/// this address.
pub struct SimTcpListener {
    /// Weak handle to the simulation that owns the pending connections.
    sim: WeakSimWorld,
    /// Address this listener accepts connections on; also returned by `local_addr`.
    local_addr: String,
}
impl SimTcpListener {
pub(crate) fn new(sim: WeakSimWorld, local_addr: String) -> Self {
Self { sim, local_addr }
}
}
impl TcpListenerTrait for SimTcpListener {
    type TcpStream = SimTcpStream;

    /// Waits for the next connection pending on this listener's address and
    /// returns the stream together with the peer address.
    #[instrument(skip(self))]
    async fn accept(&self) -> io::Result<(Self::TcpStream, String)> {
        let pending = AcceptFuture {
            local_addr: self.local_addr.clone(),
            sim: self.sim.clone(),
        };
        pending.await
    }

    /// Returns the address this listener was created with; never fails.
    fn local_addr(&self) -> io::Result<String> {
        let addr = self.local_addr.clone();
        Ok(addr)
    }
}