#![allow(clippy::missing_panics_doc)]
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
use ntex_bytes::{BufMut, Bytes, BytesMut};
use ntex_util::time::{Millis, sleep};
use crate::{Handle, IoContext, IoStream, IoTaskStatus, Readiness, types};
/// Shared, lockable slot holding the waker of the task that is waiting on a
/// channel.
///
/// Polling code stores `cx.waker()` into the slot; `wake` drains the slot and
/// wakes whatever was stored.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Takes the stored waker, if any, and wakes it.
    ///
    /// The slot is left empty afterwards, so calling this again without a new
    /// registration in between does nothing.
    ///
    /// # Panics
    ///
    /// Panics if the slot's mutex is poisoned.
    fn wake(&self) {
        // Take the waker out in a separate statement: the mutex guard is a
        // temporary of this `let`, so it is released *before* `Waker::wake`
        // runs. With `if let Some(w) = self.0.lock()...take()` the guard
        // would stay alive for the whole body and a waker that touches this
        // slot synchronously would block on the lock.
        let waker = self.0.lock().unwrap().borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl fmt::Debug for AtomicWaker {
    /// Prints only the type name; the stored waker is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
/// One end of an in-memory stream pair used for testing.
///
/// A pair is built by `IoTest::create()`; both ends share the same two
/// channels, wired crosswise (one end's `local` is the other end's `remote`).
#[derive(Debug)]
pub struct IoTest {
/// Which end this is, and whether it is the original or a clone.
tp: Type,
/// Address returned by `Handle::query` for `types::PeerAddr`, if set.
peer_addr: Option<net::SocketAddr>,
/// Drop flags shared by both ends and all their clones.
state: Arc<Mutex<RefCell<State>>>,
/// Channel this end reads from (`poll_read_buf`, `read`, `read_any`).
local: Arc<Mutex<RefCell<Channel>>>,
/// Channel this end writes into (`poll_write_buf`, `write`).
remote: Arc<Mutex<RefCell<Channel>>>,
}
// Status bits kept on each `Channel`.
bitflags::bitflags! {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct IoTestFlags: u8 {
// Cleared by `poll_write_buf` whenever bytes are appended to the channel;
// nothing in this file sets it.
const FLUSHED = 0b0000_0001;
// Set by `run` on the local channel once the io task has finished shutting
// down; queried through `Channel::is_closed`.
const CLOSED = 0b0000_0010;
}
}
/// Identifies which end of the pair an `IoTest` is.
///
/// The `*Clone` variants are assigned by `IoTest::clone`; dropping a clone
/// does not set the corresponding drop flag in `State`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum Type {
/// First element of the pair returned by `IoTest::create()`.
Client,
/// Second element of the pair returned by `IoTest::create()`.
Server,
/// Clone of a client end (or of another client clone).
ClientClone,
/// Clone of a server end (or of another server clone).
ServerClone,
}
/// Drop flags shared between both ends of a pair.
#[derive(Copy, Clone, Default, Debug)]
struct State {
/// Set when the `Type::Client` instance (not a clone) is dropped.
client_dropped: bool,
/// Set when the `Type::Server` instance (not a clone) is dropped.
server_dropped: bool,
}
/// One direction of the in-memory connection plus its scripted behaviour.
#[derive(Default, Debug)]
struct Channel {
/// Bytes written by one end and not yet consumed by the other.
buf: BytesMut,
/// Number of bytes `poll_write_buf` may still append; decremented on every
/// write and set through `IoTest::remote_buffer_cap`. Defaults to 0.
buf_cap: usize,
/// `FLUSHED` / `CLOSED` status bits.
flags: IoTestFlags,
/// Waker of the task waiting on this channel.
waker: AtomicWaker,
/// Scripted outcome consulted by `poll_read_buf` when `buf` is empty.
read: IoTestState,
/// Scripted outcome consulted by `poll_write_buf`.
write: IoTestState,
}
// SAFETY: NOTE(review) - no justification is given in this file. Every
// `Channel` visible here lives inside an `Arc<Mutex<RefCell<Channel>>>` and is
// only accessed after locking that mutex; presumably that is what these impls
// rely on. TODO confirm that every field type (notably `BytesMut`) tolerates
// being moved to / referenced from another thread under that discipline.
unsafe impl Sync for Channel {}
unsafe impl Send for Channel {}
impl Channel {
    /// Reports whether the `CLOSED` flag is set on this channel.
    fn is_closed(&self) -> bool {
        let Self { flags, .. } = self;
        flags.contains(IoTestFlags::CLOSED)
    }
}
impl Default for IoTestFlags {
fn default() -> Self {
IoTestFlags::empty()
}
}
/// Scripted outcome for the next poll of one direction of a `Channel`.
///
/// `poll_read_buf` and `poll_write_buf` read the state with `mem::take`, so
/// the stored value falls back to `Ok` after it has been observed (except
/// where the poll function writes it back explicitly).
#[derive(Debug, Default)]
enum IoTestState {
/// Normal operation (the default).
#[default]
Ok,
/// The poll returns `Poll::Pending`.
Pending,
/// The poll returns `Ok(0)`; `poll_read_buf` restores this state so it
/// keeps reporting it.
Close,
/// The poll returns this error once.
Err(io::Error),
}
impl IoTest {
pub fn create() -> (IoTest, IoTest) {
let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let state = Arc::new(Mutex::new(RefCell::new(State::default())));
(
IoTest {
tp: Type::Client,
peer_addr: None,
local: local.clone(),
remote: remote.clone(),
state: state.clone(),
},
IoTest {
state,
peer_addr: None,
tp: Type::Server,
local: remote,
remote: local,
},
)
}
pub fn is_client_dropped(&self) -> bool {
self.state.lock().unwrap().borrow().client_dropped
}
pub fn is_server_dropped(&self) -> bool {
self.state.lock().unwrap().borrow().server_dropped
}
pub fn is_closed(&self) -> bool {
self.remote.lock().unwrap().borrow().is_closed()
}
#[must_use]
pub fn set_peer_addr(mut self, addr: net::SocketAddr) -> Self {
self.peer_addr = Some(addr);
self
}
pub fn read_pending(&self) {
self.remote.lock().unwrap().borrow_mut().read = IoTestState::Pending;
}
pub fn read_error(&self, err: io::Error) {
let channel = self.remote.lock().unwrap();
channel.borrow_mut().read = IoTestState::Err(err);
channel.borrow().waker.wake();
}
pub fn write_error(&self, err: io::Error) {
self.local.lock().unwrap().borrow_mut().write = IoTestState::Err(err);
self.remote.lock().unwrap().borrow().waker.wake();
}
pub fn local_buffer<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut();
f(&mut ch.buf)
}
pub fn remote_buffer<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let guard = self.remote.lock().unwrap();
let mut ch = guard.borrow_mut();
f(&mut ch.buf)
}
pub async fn close(&self) {
{
let guard = self.remote.lock().unwrap();
let mut remote = guard.borrow_mut();
remote.read = IoTestState::Close;
remote.waker.wake();
log::debug!("close remote socket");
}
sleep(Millis(35)).await;
}
pub fn write<T: AsRef<[u8]>>(&self, data: T) {
let guard = self.remote.lock().unwrap();
let mut write = guard.borrow_mut();
write.buf.extend_from_slice(data.as_ref());
write.waker.wake();
}
pub fn remote_buffer_cap(&self, cap: usize) {
self.local.lock().unwrap().borrow_mut().buf_cap = cap;
self.remote.lock().unwrap().borrow().waker.wake();
}
pub fn read_any(&self) -> Bytes {
self.local.lock().unwrap().borrow_mut().buf.take()
}
pub async fn read(&self) -> Result<Bytes, io::Error> {
if self.local.lock().unwrap().borrow().buf.is_empty() {
poll_fn(|cx| {
let guard = self.local.lock().unwrap();
let read = guard.borrow_mut();
if read.buf.is_empty() {
let closed = match self.tp {
Type::Client | Type::ClientClone => {
self.is_server_dropped() || read.is_closed()
}
Type::Server | Type::ServerClone => self.is_client_dropped(),
};
if closed {
Poll::Ready(())
} else {
*read.waker.0.lock().unwrap().borrow_mut() =
Some(cx.waker().clone());
drop(read);
drop(guard);
Poll::Pending
}
} else {
Poll::Ready(())
}
})
.await;
}
Ok(self.local.lock().unwrap().borrow_mut().buf.take())
}
pub fn poll_read_buf(
&self,
cx: &mut Context<'_>,
buf: &mut BytesMut,
) -> Poll<io::Result<usize>> {
let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut();
*ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
if !ch.buf.is_empty() {
let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
let b = ch.buf.split_to(size);
buf.put_slice(&b);
return Poll::Ready(Ok(size));
}
match mem::take(&mut ch.read) {
IoTestState::Ok | IoTestState::Pending => Poll::Pending,
IoTestState::Close => {
ch.read = IoTestState::Close;
Poll::Ready(Ok(0))
}
IoTestState::Err(e) => Poll::Ready(Err(e)),
}
}
pub fn poll_write_buf(
&self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let guard = self.remote.lock().unwrap();
let mut ch = guard.borrow_mut();
match mem::take(&mut ch.write) {
IoTestState::Ok => {
let cap = cmp::min(buf.len(), ch.buf_cap);
if cap > 0 {
ch.buf.extend(&buf[..cap]);
ch.buf_cap -= cap;
ch.flags.remove(IoTestFlags::FLUSHED);
ch.waker.wake();
Poll::Ready(Ok(cap))
} else {
*self
.local
.lock()
.unwrap()
.borrow_mut()
.waker
.0
.lock()
.unwrap()
.borrow_mut() = Some(cx.waker().clone());
Poll::Pending
}
}
IoTestState::Close => Poll::Ready(Ok(0)),
IoTestState::Pending => {
*self
.local
.lock()
.unwrap()
.borrow_mut()
.waker
.0
.lock()
.unwrap()
.borrow_mut() = Some(cx.waker().clone());
Poll::Pending
}
IoTestState::Err(e) => Poll::Ready(Err(e)),
}
}
}
impl Clone for IoTest {
fn clone(&self) -> Self {
let tp = match self.tp {
Type::Server => Type::ServerClone,
Type::Client => Type::ClientClone,
val => val,
};
IoTest {
tp,
local: self.local.clone(),
remote: self.remote.clone(),
state: self.state.clone(),
peer_addr: self.peer_addr,
}
}
}
impl Drop for IoTest {
    /// Records that this end went away and signals the peer.
    ///
    /// For the original client/server instance the matching drop flag is set
    /// in the shared state (clones leave the flags untouched). In every case
    /// the remote channel's read state is set to `Close` and the task
    /// registered on it is woken.
    fn drop(&mut self) {
        {
            // A poisoned mutex must not make `drop` panic: this may run while
            // the thread is already unwinding, and a second panic aborts the
            // process. The data behind the lock is still usable here.
            let guard = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            // Update the flag in place under a single lock acquisition so a
            // concurrent drop of the other end cannot be overwritten.
            let mut state = guard.borrow_mut();
            match self.tp {
                Type::Server => state.server_dropped = true,
                Type::Client => state.client_dropped = true,
                Type::ServerClone | Type::ClientClone => (),
            }
        }

        let guard = self
            .remote
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut remote = guard.borrow_mut();
        remote.read = IoTestState::Close;
        remote.waker.wake();
        log::debug!("drop remote socket");
    }
}
impl IoStream for IoTest {
    /// Spawns the io task (`run`) for this stream and returns a handle that
    /// shares the stream with that task.
    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
        let shared = Rc::new(self);
        let task_io = Rc::clone(&shared);
        ntex_util::spawn(run(task_io, ctx));
        Some(Box::new(shared))
    }
}
impl Handle for Rc<IoTest> {
    /// Answers `types::PeerAddr` queries with the configured peer address.
    ///
    /// Returns `None` for any other type id, or when no address was set.
    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
        if id != any::TypeId::of::<types::PeerAddr>() {
            return None;
        }
        let addr = self.peer_addr?;
        Some(Box::new(types::PeerAddr(addr)))
    }
}
/// Reason the io task's main loop (`turn`) finished.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Status {
/// Write readiness reported `Readiness::Shutdown`; `run` passes
/// `flush = true` to `ctx.shutdown` in this case.
Shutdown,
/// The task should stop; `run` passes `flush = false` to `ctx.shutdown`.
Terminate,
}
/// Io task for an `IoTest` stream.
///
/// Drives `turn` until it completes, then (unless the context is already
/// stopped) keeps writing and calling `ctx.shutdown` until either finishes.
/// Finally marks the local channel `CLOSED` and stops the context if needed.
async fn run(io: Rc<IoTest>, ctx: IoContext) {
    let status = poll_fn(|cx| turn(&io, &ctx, cx)).await;
    log::debug!("{}: Shuting down io", ctx.tag());

    if !ctx.is_stopped() {
        // Only a graceful shutdown asks the context to flush.
        let flush = matches!(status, Status::Shutdown);
        poll_fn(|cx| match write(&io, &ctx, cx) {
            Poll::Ready(Status::Terminate) => Poll::Ready(()),
            _ => ctx.shutdown(flush, cx),
        })
        .await;
    }

    {
        let guard = io.local.lock().unwrap();
        guard.borrow_mut().flags.insert(IoTestFlags::CLOSED);
    }
    log::debug!("{}: Shutdown complete", ctx.tag());

    if !ctx.is_stopped() {
        ctx.stop(None);
    }
}
/// One iteration of the io task: services the read side, then the write side.
///
/// Returns `Poll::Pending` while both sides are pending. If the write side
/// is ready its status is returned; if only the read side is ready the
/// result is `Status::Terminate`.
fn turn(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
    let read_poll = match ctx.poll_read_ready(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Readiness::Ready) => read(io, ctx, cx),
        Poll::Ready(Readiness::Shutdown | Readiness::Terminate) => Poll::Ready(()),
    };

    let write_poll = match ctx.poll_write_ready(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Readiness::Ready) => write(io, ctx, cx),
        Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
        Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
    };

    match (read_poll, write_poll) {
        // The write side's verdict wins whenever it has one.
        (_, Poll::Ready(status)) => Poll::Ready(status),
        (Poll::Pending, Poll::Pending) => Poll::Pending,
        (Poll::Ready(()), Poll::Pending) => Poll::Ready(Status::Terminate),
    }
}
/// Flushes the context's write buffer, if there is one, into the stream.
///
/// Returns `Poll::Ready(Status::Terminate)` when releasing the buffer
/// reports `IoTaskStatus::Stop`; otherwise `Poll::Pending`.
fn write(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status> {
    let Some(mut buf) = ctx.get_write_buf() else {
        return Poll::Pending;
    };

    let result = write_io(io, &mut buf, cx, ctx.tag());
    if matches!(ctx.release_write_buf(buf, result), IoTaskStatus::Stop) {
        Poll::Ready(Status::Terminate)
    } else {
        Poll::Pending
    }
}
/// Reads from the stream into the context's read buffer until the stream is
/// pending, the buffer is full, the stream reports end-of-stream, or an
/// error occurs.
///
/// The outcome and the total byte count are handed to
/// `ctx.release_read_buf`; the function returns `Poll::Ready(())` when that
/// call reports `IoTaskStatus::Stop` and `Poll::Pending` otherwise.
fn read(io: &IoTest, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<()> {
    let mut buf = ctx.get_read_buf();
    let mut total = 0;

    let outcome = loop {
        ctx.resize_read_buf(&mut buf);
        match io.poll_read_buf(cx, &mut buf) {
            // Pending after some data was read still counts as progress.
            Poll::Pending if total > 0 => break Poll::Ready(Ok(())),
            Poll::Pending => break Poll::Pending,
            Poll::Ready(Err(err)) => break Poll::Ready(Err(Some(err))),
            // A zero-length read is reported as an error without a cause.
            Poll::Ready(Ok(0)) => break Poll::Ready(Err(None)),
            Poll::Ready(Ok(size)) => {
                total += size;
                if buf.remaining_mut() == 0 {
                    break Poll::Ready(Ok(()));
                }
            }
        }
    };

    if ctx.release_read_buf(total, buf, outcome) == IoTaskStatus::Stop {
        Poll::Ready(())
    } else {
        Poll::Pending
    }
}
/// Writes as much of `buf` as the stream accepts right now.
///
/// Returns `Poll::Ready(Ok(n))` with the number of bytes written when at
/// least one byte went out, and `Poll::Pending` when `buf` is empty or
/// nothing could be written. A zero-length write is turned into an
/// `io::ErrorKind::WriteZero` error; any error from the stream is returned
/// as is.
pub(super) fn write_io(
    io: &IoTest,
    buf: &mut BytesMut,
    cx: &mut Context<'_>,
    tag: &'static str,
) -> Poll<io::Result<usize>> {
    let len = buf.len();
    if len == 0 {
        return Poll::Pending;
    }

    log::debug!("{tag}: flushing framed transport: {len}");
    let mut written = 0;
    loop {
        match io.poll_write_buf(cx, &buf[written..]) {
            Poll::Pending => break,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(0)) => {
                log::trace!("{tag}: disconnected during flush, written {written}");
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                )));
            }
            Poll::Ready(Ok(n)) => {
                written += n;
                if written == len {
                    break;
                }
            }
        }
    }
    log::debug!("{tag}: flushed {written} bytes");

    if written > 0 {
        Poll::Ready(Ok(written))
    } else {
        Poll::Pending
    }
}
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
use super::*;
use ntex_util::future::lazy;
// Smoke test for the pair: type tags, Debug output, pending polls, drop
// flags and peer-address queries.
#[ntex::test]
async fn basic() {
let (client, server) = IoTest::create();
// Originals keep their tag; clones get the `*Clone` tag.
assert_eq!(client.tp, Type::Client);
assert_eq!(client.clone().tp, Type::ClientClone);
assert_eq!(server.tp, Type::Server);
assert_eq!(server.clone().tp, Type::ServerClone);
assert!(format!("{server:?}").contains("IoTest"));
assert!(format!("{:?}", AtomicWaker::default()).contains("AtomicWaker"));
// Nothing is buffered, so the client's read poll is pending.
server.read_pending();
let mut buf = BytesMut::new();
let res = lazy(|cx| client.poll_read_buf(cx, &mut buf)).await;
assert!(res.is_pending());
// The write budget (`buf_cap`) defaults to 0, so the write poll is pending.
server.read_pending();
let res = lazy(|cx| server.poll_write_buf(cx, b"123")).await;
assert!(res.is_pending());
// Dropping the original client sets the shared `client_dropped` flag.
assert!(!server.is_client_dropped());
drop(client);
assert!(server.is_client_dropped());
// The flag is set by dropping the original server, and is visible from a clone.
let server2 = server.clone();
assert!(!server2.is_server_dropped());
drop(server);
assert!(server2.is_server_dropped());
let res = lazy(|cx| server2.poll_write_buf(cx, b"123")).await;
assert!(res.is_pending());
// The peer address set on the stream is returned by `Io::query`.
let (client, _) = IoTest::create();
let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
let client = crate::Io::from(client.set_peer_addr(addr));
let item = client.query::<crate::types::PeerAddr>();
assert!(format!("{item:?}").contains("QueryItem(127.0.0.1:8080)"));
}
}