use std::task::{Context, Poll, ready};
use std::{any, cell::Cell, cmp, future::poll_fn, io, pin::Pin, rc::Rc, rc::Weak};
use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{
Filter, Handle, Io, IoBoxed, IoContext, IoStream, IoTaskStatus, Readiness, types,
};
use ntex_util::time::Millis;
use tok_io::io::{AsyncRead, AsyncWrite, ReadBuf};
use tok_io::net::TcpStream;
impl IoStream for super::TcpStream {
    /// Starts I/O processing for a TCP connection.
    ///
    /// The stream is placed in a shared cell. A local read task and a local
    /// write task are spawned over it, and a handle is returned that answers
    /// peer-address and socket-option queries against the same stream.
    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
        let shared = Rc::new(Cell::new(Some(self.0)));
        // Async fns are lazy, so building the futures first has no side effects;
        // spawn order (reader, then writer) is preserved.
        let reader = run_rd(Rc::clone(&shared), ctx.clone());
        let writer = run_wrt(Rc::clone(&shared), ctx);
        tok_io::task::spawn_local(reader);
        tok_io::task::spawn_local(writer);
        let handle: Box<dyn Handle> = Box::new(HandleWrapper(shared));
        Some(handle)
    }
}
#[cfg(unix)]
impl IoStream for super::UnixStream {
    /// Starts I/O processing for a Unix-domain stream.
    ///
    /// Spawns local read and write tasks over a shared cell holding the
    /// stream. No query handle is returned for Unix sockets.
    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
        let shared = Rc::new(Cell::new(Some(self.0)));
        let reader = run_rd(Rc::clone(&shared), ctx.clone());
        tok_io::task::spawn_local(reader);
        tok_io::task::spawn_local(run_wrt(shared, ctx));
        None
    }
}
/// Query handle returned by the TCP `start`, sharing the stream with the
/// read/write tasks.
struct HandleWrapper(Rc<Cell<Option<TcpStream>>>);

impl Handle for HandleWrapper {
    /// Answers `types::PeerAddr` and `SocketOptions` queries; any other type id
    /// yields `None`.
    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
        if id == any::TypeId::of::<types::PeerAddr>() {
            // `Cell` only gives by-value access, so the stream is moved out and put
            // back. If it is currently checked out (e.g. by an I/O task that is
            // mid-poll), report "unknown" instead of panicking.
            let inner = self.0.take()?;
            let result = inner.peer_addr();
            self.0.set(Some(inner));
            if let Ok(addr) = result {
                return Some(Box::new(types::PeerAddr(addr)));
            }
        } else if id == any::TypeId::of::<SocketOptions>() {
            // Weak reference: the options object must not keep the socket alive.
            return Some(Box::new(SocketOptions(Rc::downgrade(&self.0))));
        }
        None
    }
}
/// Reason the write task left its main write loop.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Status {
    /// Graceful shutdown: the shutdown phase is asked to flush pending data.
    Shutdown,
    /// Immediate termination: shutdown proceeds without flushing.
    Terminate,
}
/// Read-side I/O task.
///
/// Waits for `ctx` to report read readiness and then drains the stream into
/// the context's read buffer via [`read`]. The task ends when the context
/// reports `Shutdown`/`Terminate`, or when [`read`] reports that the context
/// asked the task to stop.
async fn run_rd<T>(io: Rc<Cell<Option<T>>>, ctx: IoContext)
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // The future resolves to `()`, so there is nothing to bind.
    poll_fn(|cx| {
        // The stream is shared through a `Cell`; move it out for the duration of
        // this poll and always put it back before returning.
        let mut inner = io.take().unwrap();
        let result = match ctx.poll_read_ready(cx) {
            Poll::Ready(Readiness::Ready) => read(&mut inner, &ctx, cx),
            Poll::Ready(Readiness::Shutdown | Readiness::Terminate) => Poll::Ready(()),
            Poll::Pending => Poll::Pending,
        };
        io.set(Some(inner));
        result
    })
    .await;
}
/// Write-side I/O task.
///
/// Phase 1: wait for `ctx` write readiness and flush the context's write
/// buffer via [`write`]. This continues until the context reports
/// `Shutdown`/`Terminate`, or [`write`] reports that the context asked the
/// task to stop (`Status::Terminate`).
///
/// Phase 2 (skipped if the context is already stopped): keep writing buffered
/// data while driving `ctx.shutdown`. `flush = true` is passed only after a
/// graceful `Shutdown`. The phase ends early if [`write`] reports `Terminate`.
///
/// Phase 3: call `AsyncWrite::poll_shutdown` on the stream, log the outcome,
/// and stop the context if it has not been stopped yet.
async fn run_wrt<T>(io: Rc<Cell<Option<T>>>, ctx: IoContext)
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // Phase 1. The stream is moved out of the shared cell for each poll and put
    // back before the closure returns.
    let st = poll_fn(|cx| {
        let mut inner = io.take().unwrap();
        let result = match ctx.poll_write_ready(cx) {
            Poll::Ready(Readiness::Ready) => write(&mut inner, &ctx, cx),
            Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
            Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
            Poll::Pending => Poll::Pending,
        };
        io.set(Some(inner));
        result
    })
    .await;
    log::trace!("{}: Shuting down io {:?}", ctx.tag(), ctx.is_stopped());
    // Phase 2.
    if !ctx.is_stopped() {
        // Only a graceful shutdown asks the context to flush.
        let flush = st == Status::Shutdown;
        poll_fn(|cx| {
            let mut inner = io.take().unwrap();
            // Keep pushing buffered data out while shutting down; abandon the
            // shutdown as soon as the context asks the task to stop.
            let result = if write(&mut inner, &ctx, cx) == Poll::Ready(Status::Terminate) {
                Poll::Ready(())
            } else {
                ctx.shutdown(flush, cx)
            };
            io.set(Some(inner));
            result
        })
        .await;
    }
    // Phase 3: shut down the transport itself.
    let result = poll_fn(|cx| {
        let mut inner = io.take().unwrap();
        let result = Pin::new(&mut inner).poll_shutdown(cx);
        io.set(Some(inner));
        result
    })
    .await;
    log::trace!("{}: Shutdown complete, result {result:?}", ctx.tag());
    if !ctx.is_stopped() {
        ctx.stop(None);
    }
}
/// Pushes the context's pending write buffer, if any, out to `io`.
///
/// The outcome of [`write_io`] is handed back to the context together with the
/// buffer. Returns `Ready(Status::Terminate)` only when the context then asks
/// the task to stop; in every other case it returns `Pending`.
fn write<T>(io: &mut T, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    match ctx.get_write_buf() {
        None => Poll::Pending,
        Some(mut pending) => {
            let outcome = write_io(io, &mut pending, cx);
            match ctx.release_write_buf(pending, outcome) {
                IoTaskStatus::Stop => Poll::Ready(Status::Terminate),
                _ => Poll::Pending,
            }
        }
    }
}
/// Drains `io` into the context's read buffer until the reader would block.
///
/// The loop keeps reading while data is available, accumulating the byte count.
/// It produces one outcome for the context:
/// - `Ready(Ok(()))` when at least one byte was read before the reader blocked;
/// - `Pending` when nothing was read;
/// - `Ready(Err(None))` on end of stream (`Ok(0)` from the reader);
/// - `Ready(Err(Some(e)))` on a read error.
///
/// Returns `Ready(())` only if the context asks the task to stop after
/// receiving the buffer, otherwise `Pending`.
fn read<T: AsyncRead + Unpin>(
    io: &mut T,
    ctx: &IoContext,
    cx: &mut Context<'_>,
) -> Poll<()> {
    let mut buf = ctx.get_read_buf();
    let mut total = 0;
    let outcome = loop {
        ctx.resize_read_buf(&mut buf);
        match read_buf(Pin::new(&mut *io), cx, &mut buf) {
            Poll::Ready(Ok(0)) => break Poll::Ready(Err(None)),
            Poll::Ready(Ok(size)) => total += size,
            Poll::Ready(Err(err)) => break Poll::Ready(Err(Some(err))),
            Poll::Pending if total > 0 => break Poll::Ready(Ok(())),
            Poll::Pending => break Poll::Pending,
        }
    };
    match ctx.release_read_buf(total, buf, outcome) {
        IoTaskStatus::Stop => Poll::Ready(()),
        _ => Poll::Pending,
    }
}
/// Performs a single `poll_read` from `io` directly into the spare capacity of
/// `buf`, without zeroing it first.
///
/// Returns `Ready(Ok(n))` with the number of bytes appended to `buf`,
/// `Pending` if the reader is not ready, or the reader's error.
///
/// NOTE(review): if `buf` has no spare capacity, `poll_read` receives a
/// zero-length buffer and this returns `Ok(0)`, which [`read`] treats as end
/// of stream. [`read`] calls `ctx.resize_read_buf` first, which presumably
/// guarantees space — confirm.
fn read_buf<T: AsyncRead>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut BytesMut,
) -> Poll<io::Result<usize>> {
    let n = {
        // View the uninitialized spare capacity returned by `chunk_mut` as a
        // `ReadBuf` so the reader can fill it in place.
        let dst = buf.chunk_mut().as_mut();
        let mut buf = ReadBuf::uninit(dst);
        let ptr = buf.filled().as_ptr();
        if io.poll_read(cx, &mut buf)?.is_pending() {
            return Poll::Pending;
        }
        // The reader must not have swapped the `ReadBuf`'s backing storage: the
        // filled bytes have to live in `buf`'s spare capacity for the
        // `advance_mut` below to be sound.
        assert_eq!(ptr, buf.filled().as_ptr());
        buf.filled().len()
    };
    // SAFETY: `ReadBuf::filled` covers only initialized bytes. The assertion
    // above shows those `n` bytes are the start of `buf`'s spare capacity, so
    // growing the length by `n` exposes only initialized data.
    unsafe {
        buf.advance_mut(n);
    }
    Poll::Ready(Ok(n))
}
/// Writes as much of `buf` to `io` as the transport accepts right now.
///
/// Returns:
/// - `Ready(Ok(n))` with the number of bytes written, after starting a flush
///   (a `Pending` flush is ignored);
/// - `Pending` when `buf` is empty or nothing could be written;
/// - the first error from `poll_write`/`poll_flush`.
///
/// A write that accepts zero bytes is reported as `WriteZero`. `buf` itself is
/// not advanced; the caller accounts for the written count.
fn write_io<T: AsyncRead + AsyncWrite + Unpin>(
    io: &mut T,
    buf: &mut BytesMut,
    cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
    let total = buf.len();
    if total == 0 {
        return Poll::Pending;
    }

    let mut offset = 0;
    loop {
        match Pin::new(&mut *io).poll_write(cx, &buf[offset..])? {
            Poll::Pending => break,
            Poll::Ready(0) => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                )));
            }
            Poll::Ready(n) => {
                offset += n;
                if offset == total {
                    break;
                }
            }
        }
    }

    if offset == 0 {
        return Poll::Pending;
    }
    // Kick off a flush: errors are surfaced, a pending flush is not waited on.
    if let Poll::Ready(Err(err)) = Pin::new(&mut *io).poll_flush(cx) {
        return Poll::Ready(Err(err));
    }
    Poll::Ready(Ok(offset))
}
/// Wrapper around an [`IoBoxed`] that exposes it through tokio's
/// `AsyncRead`/`AsyncWrite` traits.
#[derive(Debug)]
pub struct TokioIoBoxed(IoBoxed);

impl std::ops::Deref for TokioIoBoxed {
    type Target = IoBoxed;

    #[inline]
    fn deref(&self) -> &IoBoxed {
        let TokioIoBoxed(inner) = self;
        inner
    }
}

impl From<IoBoxed> for TokioIoBoxed {
    fn from(io: IoBoxed) -> Self {
        Self(io)
    }
}

impl<F: Filter> From<Io<F>> for TokioIoBoxed {
    fn from(io: Io<F>) -> Self {
        let boxed: IoBoxed = io.into();
        Self::from(boxed)
    }
}
impl AsyncRead for TokioIoBoxed {
    /// Copies already-buffered bytes from the boxed `Io` into `buf`.
    ///
    /// If nothing was buffered, waits on the `Io`'s read readiness:
    /// - an error is returned as-is;
    /// - `Ok(None)` completes with no bytes written into `buf`;
    /// - `Ok(Some(()))` (and a not-ready result) yields `Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let copied = self.0.with_read_buf(|src| {
            let n = cmp::min(buf.remaining(), src.len());
            let chunk = src.split_to(n);
            buf.put_slice(&chunk);
            n
        });
        if copied != 0 {
            return Poll::Ready(Ok(()));
        }

        let outcome = match ready!(self.0.poll_read_ready(cx)) {
            Err(e) => Err(e),
            Ok(None) => Ok(()),
            Ok(Some(())) => return Poll::Pending,
        };
        Poll::Ready(outcome)
    }
}
impl AsyncWrite for TokioIoBoxed {
    /// Hands all of `buf` to the boxed `Io`'s `write`. On success it reports
    /// the full length as written. It never returns `Pending`.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let accepted = buf.len();
        let outcome = self.0.write(buf);
        Poll::Ready(outcome.map(|()| accepted))
    }

    /// Delegates to the boxed `Io`'s `poll_flush`, passing `false` as its flag.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: Pin<&Self> = self.as_ref();
        this.0.poll_flush(cx, false)
    }

    /// Delegates to the boxed `Io`'s `poll_shutdown`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: Pin<&Self> = self.as_ref();
        this.0.poll_shutdown(cx)
    }
}
/// Socket-level options for a tokio-backed TCP connection.
///
/// Holds only a weak reference to the shared stream, so it never keeps the
/// connection alive. Once every strong owner has dropped the stream, all
/// setters fail with `NotConnected`.
#[derive(Debug)]
pub struct SocketOptions(Weak<Cell<Option<TcpStream>>>);

impl SocketOptions {
    /// Sets `SO_LINGER` on the socket.
    ///
    /// # Errors
    /// Returns an error if the socket is gone or currently in use, or if the
    /// OS rejects the option. The socket stays attached in every case.
    #[deprecated = "`SO_LINGER` causes the socket to block the thread on drop"]
    #[allow(deprecated)] // the underlying tokio setter is deprecated as well
    pub fn set_linger(&self, dur: Option<Millis>) -> io::Result<()> {
        self.with_stream(|io| io.set_linger(dur.map(Into::into)))
    }

    /// Sets the IP time-to-live (`IP_TTL`) on the socket.
    ///
    /// # Errors
    /// Returns an error if the socket is gone or currently in use, or if the
    /// OS rejects the option. The socket stays attached in every case.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.with_stream(|io| io.set_ttl(ttl))
    }

    /// Temporarily moves the stream out of the shared cell, runs `f` on it,
    /// and always puts it back — even when `f` fails.
    ///
    /// Previously an early `?` on the setter's error skipped the restore. That
    /// dropped (closed) the socket and made the I/O tasks panic on their next
    /// `take().unwrap()`.
    fn with_stream<R>(&self, f: impl FnOnce(&TcpStream) -> io::Result<R>) -> io::Result<R> {
        let inner = self.try_self()?;
        let stream = inner
            .take()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "socket is in use"))?;
        let result = f(&stream);
        inner.set(Some(stream));
        result
    }

    /// Upgrades the weak reference, failing with `NotConnected` once the
    /// stream has been dropped.
    fn try_self(&self) -> io::Result<Rc<Cell<Option<TcpStream>>>> {
        self.0
            .upgrade()
            .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "socket is gone"))
    }
}