use super::traits::{TcpListenerApi, TcpStreamApi};
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::io;
use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::task::{Context, Poll, Waker};
const VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES: usize = 1024;
const VIRTUAL_TCP_ACCEPT_QUEUE_CAPACITY: usize = 16;
#[derive(Debug, Default)]
struct AcceptWaiters {
waiters: Mutex<Vec<Waker>>,
}
impl AcceptWaiters {
fn register(&self, waker: &Waker) {
let mut waiters = self.waiters.lock();
if waiters.iter().any(|existing| existing.will_wake(waker)) {
return;
}
if waiters.len() >= 32 {
let evicted = waiters.remove(0);
evicted.wake();
}
waiters.push(waker.clone());
}
fn wake_all(&self) {
let waiters = {
let mut guard = self.waiters.lock();
std::mem::take(&mut *guard)
};
for waiter in waiters {
waiter.wake();
}
}
fn wake_others(&self, current: &Waker) {
let waiters = {
let mut guard = self.waiters.lock();
std::mem::take(&mut *guard)
};
for waiter in waiters {
if !waiter.will_wake(current) {
waiter.wake();
}
}
}
}
use std::task::Wake;
impl Wake for AcceptWaiters {
fn wake(self: Arc<Self>) {
self.wake_all();
}
fn wake_by_ref(self: &Arc<Self>) {
self.wake_all();
}
}
struct ChannelHalf {
buf: VecDeque<u8>,
waker: Option<Waker>,
write_waker: Option<Waker>,
closed: bool,
read_shutdown: bool,
}
impl ChannelHalf {
fn new() -> Self {
Self {
buf: VecDeque::with_capacity(VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES),
waker: None,
write_waker: None,
closed: false,
read_shutdown: false,
}
}
fn take_wakers(&mut self) -> (Option<Waker>, Option<Waker>) {
(self.waker.take(), self.write_waker.take())
}
}
pub struct VirtualTcpStream {
local_addr: SocketAddr,
peer_addr: SocketAddr,
read_half: Arc<Mutex<ChannelHalf>>,
write_half: Arc<Mutex<ChannelHalf>>,
nodelay: AtomicBool,
ttl: AtomicU32,
write_shutdown: AtomicBool,
}
impl std::fmt::Debug for VirtualTcpStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("VirtualTcpStream")
.field("local_addr", &self.local_addr)
.field("peer_addr", &self.peer_addr)
.finish_non_exhaustive()
}
}
impl VirtualTcpStream {
#[must_use]
pub fn pair(a_addr: SocketAddr, b_addr: SocketAddr) -> (Self, Self) {
let a_to_b = Arc::new(Mutex::new(ChannelHalf::new()));
let b_to_a = Arc::new(Mutex::new(ChannelHalf::new()));
let stream_a = Self {
local_addr: a_addr,
peer_addr: b_addr,
read_half: Arc::clone(&b_to_a),
write_half: Arc::clone(&a_to_b),
nodelay: AtomicBool::new(false),
ttl: AtomicU32::new(64),
write_shutdown: AtomicBool::new(false),
};
let stream_b = Self {
local_addr: b_addr,
peer_addr: a_addr,
read_half: a_to_b,
write_half: b_to_a,
nodelay: AtomicBool::new(false),
ttl: AtomicU32::new(64),
write_shutdown: AtomicBool::new(false),
};
(stream_a, stream_b)
}
}
impl AsyncRead for VirtualTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut half = this.read_half.lock();
if half.read_shutdown {
return Poll::Ready(Ok(()));
}
if half.buf.is_empty() {
if half.closed {
return Poll::Ready(Ok(()));
}
if !half.waker.as_ref().is_some_and(|w| w.will_wake(cx.waker())) {
half.waker = Some(cx.waker().clone());
}
return Poll::Pending;
}
let unfilled = buf.unfilled();
let to_read = unfilled.len().min(half.buf.len());
let (front, back) = half.buf.as_slices();
let front_copy = front.len().min(to_read);
unfilled[..front_copy].copy_from_slice(&front[..front_copy]);
if front_copy < to_read {
unfilled[front_copy..to_read].copy_from_slice(&back[..to_read - front_copy]);
}
half.buf.drain(..to_read);
let write_wake = half.write_waker.take();
drop(half);
buf.advance(to_read);
if let Some(waker) = write_wake {
waker.wake();
}
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for VirtualTcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.write_shutdown.load(Ordering::Relaxed) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"write half shutdown",
)));
}
let mut half = this.write_half.lock();
if half.closed {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"peer closed",
)));
}
if half.read_shutdown {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"peer shut down read half",
)));
}
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
if half.buf.len() >= VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES {
if !half
.write_waker
.as_ref()
.is_some_and(|w| w.will_wake(cx.waker()))
{
half.write_waker = Some(cx.waker().clone());
}
return Poll::Pending;
}
let available = VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES.saturating_sub(half.buf.len());
let to_write = buf.len().min(available);
half.buf.extend(&buf[..to_write]);
let wake = half.waker.take();
drop(half);
if let Some(waker) = wake {
waker.wake();
}
Poll::Ready(Ok(to_write))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
this.write_shutdown.store(true, Ordering::Relaxed);
let mut half = this.write_half.lock();
half.closed = true;
let wake = half.waker.take();
drop(half);
if let Some(waker) = wake {
waker.wake();
}
Poll::Ready(Ok(()))
}
}
impl Unpin for VirtualTcpStream {}
impl Drop for VirtualTcpStream {
fn drop(&mut self) {
*self.write_shutdown.get_mut() = true;
let read_wake = {
let mut half = self.read_half.lock();
half.read_shutdown = true;
half.buf.clear();
half.take_wakers()
};
if let Some(waker) = read_wake.0 {
waker.wake();
}
if let Some(waker) = read_wake.1 {
waker.wake();
}
let write_wake = {
let mut half = self.write_half.lock();
half.closed = true;
half.take_wakers()
};
if let Some(waker) = write_wake.0 {
waker.wake();
}
if let Some(waker) = write_wake.1 {
waker.wake();
}
}
}
#[allow(clippy::manual_async_fn)] impl TcpStreamApi for VirtualTcpStream {
fn connect<A: ToSocketAddrs + Send + 'static>(
_addr: A,
) -> impl std::future::Future<Output = io::Result<Self>> + Send {
async {
Err(io::Error::new(
io::ErrorKind::Unsupported,
"VirtualTcpStream::connect not supported; use VirtualTcpStream::pair()",
))
}
}
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(self.peer_addr)
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match how {
Shutdown::Read | Shutdown::Both => {
let mut half = self.read_half.lock();
half.read_shutdown = true;
half.buf.clear();
let wake = half.take_wakers();
drop(half);
if let Some(waker) = wake.0 {
waker.wake();
}
if let Some(waker) = wake.1 {
waker.wake();
}
}
Shutdown::Write => {}
}
match how {
Shutdown::Write | Shutdown::Both => {
self.write_shutdown.store(true, Ordering::Relaxed);
let mut half = self.write_half.lock();
half.closed = true;
let wake = half.take_wakers();
drop(half);
if let Some(waker) = wake.0 {
waker.wake();
}
if let Some(waker) = wake.1 {
waker.wake();
}
}
Shutdown::Read => {}
}
Ok(())
}
fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.nodelay.store(nodelay, Ordering::Relaxed);
Ok(())
}
fn nodelay(&self) -> io::Result<bool> {
Ok(self.nodelay.load(Ordering::Relaxed))
}
fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.ttl.store(ttl, Ordering::Relaxed);
Ok(())
}
fn ttl(&self) -> io::Result<u32> {
Ok(self.ttl.load(Ordering::Relaxed))
}
}
struct VirtualListenerState {
connections: VecDeque<(VirtualTcpStream, SocketAddr)>,
closed: bool,
}
pub struct VirtualTcpListener {
addr: SocketAddr,
state: Arc<Mutex<VirtualListenerState>>,
accept_waiters: Arc<AcceptWaiters>,
}
impl Drop for VirtualTcpListener {
fn drop(&mut self) {
self.close();
}
}
impl std::fmt::Debug for VirtualTcpListener {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("VirtualTcpListener")
.field("addr", &self.addr)
.finish_non_exhaustive()
}
}
impl VirtualTcpListener {
#[must_use]
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
state: Arc::new(Mutex::new(VirtualListenerState {
connections: VecDeque::with_capacity(VIRTUAL_TCP_ACCEPT_QUEUE_CAPACITY),
closed: false,
})),
accept_waiters: Arc::new(AcceptWaiters::default()),
}
}
pub fn inject_connection(&self, stream: VirtualTcpStream, remote_addr: SocketAddr) {
{
let mut state = self.state.lock();
if state.closed {
return;
}
state.connections.push_back((stream, remote_addr));
}
self.accept_waiters.wake_all();
}
#[must_use]
pub fn pending_count(&self) -> usize {
self.state.lock().connections.len()
}
pub fn close(&self) {
{
let mut state = self.state.lock();
state.closed = true;
state.connections.clear();
}
self.accept_waiters.wake_all();
}
#[must_use]
pub fn injector(&self) -> VirtualConnectionInjector {
VirtualConnectionInjector {
state: Arc::clone(&self.state),
accept_waiters: Arc::clone(&self.accept_waiters),
}
}
}
#[derive(Clone)]
pub struct VirtualConnectionInjector {
state: Arc<Mutex<VirtualListenerState>>,
accept_waiters: Arc<AcceptWaiters>,
}
impl VirtualConnectionInjector {
pub fn inject(&self, stream: VirtualTcpStream, remote_addr: SocketAddr) {
{
let mut state = self.state.lock();
if state.closed {
return;
}
state.connections.push_back((stream, remote_addr));
}
self.accept_waiters.wake_all();
}
}
#[allow(clippy::manual_async_fn)] impl TcpListenerApi for VirtualTcpListener {
type Stream = VirtualTcpStream;
fn bind<A: ToSocketAddrs + Send + 'static>(
addr: A,
) -> impl std::future::Future<Output = io::Result<Self>> + Send {
async move {
let socket_addr = addr.to_socket_addrs()?.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "no socket addresses found")
})?;
Ok(Self::new(socket_addr))
}
}
fn accept(
&self,
) -> impl std::future::Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send {
let state = Arc::clone(&self.state);
let accept_waiters = Arc::clone(&self.accept_waiters);
async move {
std::future::poll_fn(|cx| {
let mut guard = state.lock();
if guard.closed {
drop(guard);
accept_waiters.wake_others(cx.waker());
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"virtual listener closed",
)));
}
if let Some(conn) = guard.connections.pop_front() {
drop(guard);
accept_waiters.wake_others(cx.waker());
return Poll::Ready(Ok(conn));
}
drop(guard);
accept_waiters.register(cx.waker());
Poll::Pending
})
.await
}
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(Self::Stream, SocketAddr)>> {
let mut state = self.state.lock();
if state.closed {
drop(state);
self.accept_waiters.wake_others(cx.waker());
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"virtual listener closed",
)));
}
if let Some(conn) = state.connections.pop_front() {
drop(state);
self.accept_waiters.wake_others(cx.waker());
return Poll::Ready(Ok(conn));
}
drop(state);
self.accept_waiters.register(cx.waker());
Poll::Pending
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.addr)
}
fn pending_connections(&self) -> Option<usize> {
Some(self.pending_count())
}
fn set_ttl(&self, _ttl: u32) -> io::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Waker};
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
struct CountWaker(std::sync::atomic::AtomicUsize);
impl Wake for CountWaker {
fn wake(self: Arc<Self>) {
self.0.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.0.fetch_add(1, Ordering::Relaxed);
}
}
fn count_waker() -> (Arc<CountWaker>, Waker) {
let inner = Arc::new(CountWaker(std::sync::atomic::AtomicUsize::new(0)));
(Arc::clone(&inner), Waker::from(inner))
}
fn addr(s: &str) -> SocketAddr {
s.parse().unwrap()
}
#[test]
fn virtual_stream_pair_read_write() {
let (mut a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let data = b"hello";
let result = Pin::new(&mut a).poll_write(&mut cx, data);
assert!(matches!(result, Poll::Ready(Ok(5))));
let mut buf = [0u8; 16];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut b).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Ready(Ok(()))));
assert_eq!(read_buf.filled(), b"hello");
}
#[test]
fn virtual_stream_addresses() {
let (a, b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
assert_eq!(a.local_addr().unwrap(), addr("127.0.0.1:1000"));
assert_eq!(a.peer_addr().unwrap(), addr("127.0.0.1:2000"));
assert_eq!(b.local_addr().unwrap(), addr("127.0.0.1:2000"));
assert_eq!(b.peer_addr().unwrap(), addr("127.0.0.1:1000"));
}
#[test]
fn virtual_stream_eof_on_write_shutdown() {
let (mut a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = Pin::new(&mut a).poll_shutdown(&mut cx);
assert!(matches!(result, Poll::Ready(Ok(()))));
let mut buf = [0u8; 16];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut b).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Ready(Ok(()))));
assert_eq!(read_buf.filled().len(), 0); }
#[test]
fn virtual_stream_eof_on_drop() {
let (a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
drop(a);
let mut buf = [0u8; 16];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut b).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Ready(Ok(()))));
assert_eq!(read_buf.filled().len(), 0); }
#[test]
fn virtual_stream_write_after_shutdown_errors() {
let (mut a, _b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(
Pin::new(&mut a).poll_shutdown(&mut cx),
Poll::Ready(Ok(()))
));
let result = Pin::new(&mut a).poll_write(&mut cx, b"data");
assert!(matches!(result, Poll::Ready(Err(_))));
}
#[test]
fn virtual_stream_pending_when_empty() {
let (_a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut buf = [0u8; 16];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut b).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Pending));
}
#[test]
fn virtual_listener_bind_and_accept() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
assert_eq!(listener.local_addr().unwrap(), addr("127.0.0.1:8080"));
assert_eq!(listener.pending_count(), 0);
assert_eq!(listener.pending_connections(), Some(0));
let (client, server) =
VirtualTcpStream::pair(addr("127.0.0.1:9000"), addr("127.0.0.1:8080"));
let _ = client; listener.inject_connection(server, addr("127.0.0.1:9000"));
assert_eq!(listener.pending_count(), 1);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = listener.poll_accept(&mut cx);
match &result {
Poll::Ready(Ok((stream, remote_addr))) => {
assert_eq!(*remote_addr, addr("127.0.0.1:9000"));
assert_eq!(stream.local_addr().unwrap(), addr("127.0.0.1:8080"));
}
other => {
assert!(
matches!(other, Poll::Ready(Ok(_))),
"expected Ready(Ok(...)), got: {other:?}",
);
}
}
}
#[test]
fn virtual_listener_pending_when_no_connections() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = listener.poll_accept(&mut cx);
assert!(matches!(result, Poll::Pending));
}
#[test]
fn virtual_listener_wakes_all_pending_accept_waiters() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
let (hits1, waker1) = count_waker();
let (hits2, waker2) = count_waker();
let mut cx1 = Context::from_waker(&waker1);
let mut cx2 = Context::from_waker(&waker2);
assert!(matches!(listener.poll_accept(&mut cx1), Poll::Pending));
assert!(matches!(listener.poll_accept(&mut cx2), Poll::Pending));
let (_client1, server1) =
VirtualTcpStream::pair(addr("127.0.0.1:9000"), addr("127.0.0.1:8080"));
listener.inject_connection(server1, addr("127.0.0.1:9000"));
assert_eq!(hits1.0.load(Ordering::Relaxed), 1);
assert_eq!(hits2.0.load(Ordering::Relaxed), 1);
assert!(matches!(listener.poll_accept(&mut cx2), Poll::Ready(Ok(_))));
assert!(matches!(listener.poll_accept(&mut cx1), Poll::Pending));
let (_client2, server2) =
VirtualTcpStream::pair(addr("127.0.0.1:9001"), addr("127.0.0.1:8080"));
listener.inject_connection(server2, addr("127.0.0.1:9001"));
assert_eq!(hits1.0.load(Ordering::Relaxed), 2);
}
#[test]
fn virtual_listener_closed_returns_error() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
listener.close();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = listener.poll_accept(&mut cx);
assert!(matches!(result, Poll::Ready(Err(_))));
}
#[test]
fn virtual_listener_close_drops_pending_connections() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
let (_client, server) =
VirtualTcpStream::pair(addr("127.0.0.1:9000"), addr("127.0.0.1:8080"));
listener.inject_connection(server, addr("127.0.0.1:9000"));
assert_eq!(listener.pending_count(), 1);
listener.close();
assert_eq!(listener.pending_count(), 0);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = listener.poll_accept(&mut cx);
assert!(matches!(result, Poll::Ready(Err(_))));
}
#[test]
fn virtual_listener_inject_after_close_is_ignored() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
listener.close();
let (_client, server) =
VirtualTcpStream::pair(addr("127.0.0.1:9001"), addr("127.0.0.1:8080"));
listener.inject_connection(server, addr("127.0.0.1:9001"));
assert_eq!(listener.pending_count(), 0);
let injector = listener.injector();
let (_client2, server2) =
VirtualTcpStream::pair(addr("127.0.0.1:9002"), addr("127.0.0.1:8080"));
injector.inject(server2, addr("127.0.0.1:9002"));
assert_eq!(listener.pending_count(), 0);
}
#[test]
fn virtual_listener_injector_works() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
let injector = listener.injector();
let (_client, server) =
VirtualTcpStream::pair(addr("127.0.0.1:9000"), addr("127.0.0.1:8080"));
injector.inject(server, addr("127.0.0.1:9000"));
assert_eq!(listener.pending_count(), 1);
}
#[test]
fn virtual_listener_drop_marks_closed() {
let listener = VirtualTcpListener::new(addr("127.0.0.1:8080"));
let state = Arc::clone(&listener.state);
drop(listener);
let closed = state.lock().closed;
assert!(closed);
}
#[test]
fn virtual_listener_bind_via_trait() {
futures_lite::future::block_on(async {
let listener = VirtualTcpListener::bind("127.0.0.1:0").await.expect("bind");
assert!(listener.local_addr().is_ok());
});
}
#[test]
fn virtual_stream_bidirectional() {
let (mut a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b"from-a"),
Poll::Ready(Ok(6))
));
let mut buf = [0u8; 16];
let mut read_buf = ReadBuf::new(&mut buf);
assert!(matches!(
Pin::new(&mut b).poll_read(&mut cx, &mut read_buf),
Poll::Ready(Ok(()))
));
assert_eq!(read_buf.filled(), b"from-a");
assert!(matches!(
Pin::new(&mut b).poll_write(&mut cx, b"from-b"),
Poll::Ready(Ok(6))
));
let mut buf2 = [0u8; 16];
let mut read_buf2 = ReadBuf::new(&mut buf2);
assert!(matches!(
Pin::new(&mut a).poll_read(&mut cx, &mut read_buf2),
Poll::Ready(Ok(()))
));
assert_eq!(read_buf2.filled(), b"from-b");
}
#[test]
fn virtual_stream_nodelay_and_ttl() {
let (a, b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
assert!(!a.nodelay().unwrap());
assert_eq!(a.ttl().unwrap(), 64);
assert!(!b.nodelay().unwrap());
assert_eq!(b.ttl().unwrap(), 64);
assert!(a.set_nodelay(true).is_ok());
assert!(a.set_ttl(128).is_ok());
assert!(a.nodelay().unwrap());
assert_eq!(a.ttl().unwrap(), 128);
assert!(!b.nodelay().unwrap());
assert_eq!(b.ttl().unwrap(), 64);
}
#[test]
fn virtual_stream_shutdown_both() {
let (a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
a.shutdown(Shutdown::Both).unwrap();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut buf = [0u8; 8];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut b).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Ready(Ok(()))));
assert_eq!(read_buf.filled().len(), 0);
}
#[test]
fn virtual_stream_shutdown_read_rejects_peer_write() {
let (mut a, mut b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
a.shutdown(Shutdown::Read).unwrap();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let result = Pin::new(&mut b).poll_write(&mut cx, b"discarded");
assert!(matches!(
result,
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe
));
let mut buf = [0u8; 8];
let mut read_buf = ReadBuf::new(&mut buf);
let result = Pin::new(&mut a).poll_read(&mut cx, &mut read_buf);
assert!(matches!(result, Poll::Ready(Ok(()))));
assert_eq!(read_buf.filled().len(), 0);
}
#[test]
fn virtual_stream_shutdown_read_wakes_blocked_writer() {
let (mut a, b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let full = vec![7u8; VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES];
let (wake_counter, waker) = count_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, &full),
Poll::Ready(Ok(VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES))
));
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b"x"),
Poll::Pending
));
b.shutdown(Shutdown::Read).unwrap();
assert_eq!(wake_counter.0.load(Ordering::Relaxed), 1);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b"x"),
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe
));
}
#[test]
fn virtual_stream_zero_len_write_on_full_buffer_is_ready() {
let (mut a, _b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let full = vec![5u8; VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES];
let (wake_counter, waker) = count_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, &full),
Poll::Ready(Ok(VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES))
));
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b""),
Poll::Ready(Ok(0))
));
assert_eq!(
wake_counter.0.load(Ordering::Relaxed),
0,
"zero-length writes must not register a blocked-writer wake"
);
}
#[test]
fn virtual_stream_drop_wakes_blocked_writer() {
let (mut a, b) = VirtualTcpStream::pair(addr("127.0.0.1:1000"), addr("127.0.0.1:2000"));
let full = vec![9u8; VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES];
let (wake_counter, waker) = count_waker();
let mut cx = Context::from_waker(&waker);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, &full),
Poll::Ready(Ok(VIRTUAL_TCP_CHANNEL_CAPACITY_BYTES))
));
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b"y"),
Poll::Pending
));
drop(b);
assert_eq!(wake_counter.0.load(Ordering::Relaxed), 1);
assert!(matches!(
Pin::new(&mut a).poll_write(&mut cx, b"y"),
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe
));
}
#[test]
fn virtual_stream_connect_returns_unsupported() {
futures_lite::future::block_on(async {
let result = VirtualTcpStream::connect("127.0.0.1:8080").await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::Unsupported);
});
}
}