use std::task::{Context, Poll};
use std::{
cell::Cell, cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc, time,
};
use ntex::codec::{Decoder, Encoder};
use ntex::io::{DispatchItem, IoBoxed, IoRef, IoStatusUpdate, RecvError};
use ntex::service::{IntoService, Service};
use ntex::time::Seconds;
use ntex::util::{ready, Pool};
type Response<U> = <U as Encoder>::Item;
pin_project_lite::pin_project! {
pub(crate) struct Dispatcher<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Encoder,
U: Decoder,
<U as Encoder>::Item: 'static,
{
codec: U,
service: S,
state: Rc<RefCell<DispatcherState<S, U>>>,
inner: DispatcherInner,
st: IoDispatcherState,
flags: Cell<Flags>,
pool: Pool,
#[pin]
response: Option<S::Future>,
response_idx: usize,
}
}
bitflags::bitflags! {
struct Flags: u8 {
const READY_ERR = 0b0001;
const IO_ERR = 0b0010;
}
}
struct DispatcherInner {
io: IoBoxed,
keepalive_timeout: Cell<time::Duration>,
}
struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
error: Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>,
base: usize,
queue: VecDeque<ServiceResult<Result<S::Response, S::Error>>>,
}
enum ServiceResult<T> {
Pending,
Ready(T),
}
impl<T> ServiceResult<T> {
fn take(&mut self) -> Option<T> {
let slf = std::mem::replace(self, ServiceResult::Pending);
match slf {
ServiceResult::Pending => None,
ServiceResult::Ready(result) => Some(result),
}
}
}
#[derive(Copy, Clone, Debug)]
enum IoDispatcherState {
Processing,
Stop,
Shutdown,
}
pub(crate) enum IoDispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
impl<S, U> From<S> for IoDispatcherError<S, U> {
fn from(err: S) -> Self {
IoDispatcherError::Service(err)
}
}
fn insert_flags(flags: &mut Cell<Flags>, f: Flags) {
let mut val = flags.get();
val.insert(f);
flags.set(val);
}
impl<S, U> Dispatcher<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + Clone + 'static,
<U as Encoder>::Item: 'static,
{
pub(crate) fn new<F: IntoService<S, DispatchItem<U>>>(
io: IoBoxed,
codec: U,
service: F,
) -> Self {
let keepalive_timeout = Cell::new(Seconds(30).into());
io.start_keepalive_timer(keepalive_timeout.get());
let state = Rc::new(RefCell::new(DispatcherState {
error: None,
base: 0,
queue: VecDeque::new(),
}));
let pool = io.memory_pool().pool();
Dispatcher {
state,
codec,
pool,
st: IoDispatcherState::Processing,
service: service.into_service(),
response: None,
response_idx: 0,
flags: Cell::new(Flags::empty()),
inner: DispatcherInner { io, keepalive_timeout },
}
}
pub(crate) fn keepalive_timeout(self, timeout: Seconds) -> Self {
let timeout = timeout.into();
self.inner.io.start_keepalive_timer(timeout);
self.inner.keepalive_timeout.set(timeout);
self
}
pub(crate) fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.io.set_disconnect_timeout(val.into());
self
}
}
impl DispatcherInner {
fn update_keepalive(&self) {
self.io.start_keepalive_timer(self.keepalive_timeout.get());
}
fn unregister_keepalive(&self) {
self.io.remove_keepalive_timer();
self.keepalive_timeout.set(time::Duration::ZERO);
}
}
impl<S, U> DispatcherState<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
fn handle_result(
&mut self,
item: Result<S::Response, S::Error>,
response_idx: usize,
write: &IoRef,
codec: &U,
wake: bool,
) {
let idx = response_idx.wrapping_sub(self.base);
if idx == 0 {
let _ = self.queue.pop_front();
self.base = self.base.wrapping_add(1);
match item {
Err(err) => {
self.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = write.encode(item, codec) {
self.error = Some(IoDispatcherError::Encoder(err));
}
}
Ok(None) => (),
}
while let Some(item) = self.queue.front_mut().and_then(|v| v.take()) {
let _ = self.queue.pop_front();
self.base = self.base.wrapping_add(1);
match item {
Err(err) => {
self.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = write.encode(item, codec) {
self.error = Some(IoDispatcherError::Encoder(err));
}
}
Ok(None) => (),
}
}
if wake && self.queue.is_empty() {
write.wake()
}
} else {
self.queue[idx] = ServiceResult::Ready(item);
}
}
}
impl<S, U> Future for Dispatcher<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + Clone + 'static,
<U as Encoder>::Item: 'static,
{
type Output = Result<(), S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
let io = &this.inner.io;
if let Some(fut) = this.response.as_mut().as_pin_mut() {
match fut.poll(cx) {
Poll::Pending => (),
Poll::Ready(item) => {
this.state.borrow_mut().handle_result(
item,
*this.response_idx,
io.as_ref(),
this.codec,
false,
);
this.response.set(None);
}
}
}
if this.pool.poll_ready(cx).is_pending() {
io.pause();
return Poll::Pending;
}
loop {
match this.st {
IoDispatcherState::Processing => {
match this.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let item = match ready!(io.poll_recv(this.codec, cx)) {
Ok(el) => {
this.inner.update_keepalive();
Some(DispatchItem::Item(el))
}
Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop");
*this.st = IoDispatcherState::Stop;
None
}
Err(RecvError::KeepAlive) => {
log::trace!("keepalive timeout");
*this.st = IoDispatcherState::Stop;
let mut inner = this.state.borrow_mut();
if inner.error.is_none() {
inner.error = Some(IoDispatcherError::KeepAlive);
}
Some(DispatchItem::KeepAliveTimeout)
}
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(io.poll_flush(cx, false)) {
*this.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(Some(err)))
} else {
continue;
}
}
Err(RecvError::Decoder(err)) => {
*this.st = IoDispatcherState::Stop;
Some(DispatchItem::DecoderError(err))
}
Err(RecvError::PeerGone(err)) => {
*this.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(err))
}
};
if let Some(item) = item {
if this.response.is_none() {
this.response.set(Some(this.service.call(item)));
let res =
this.response.as_mut().as_pin_mut().unwrap().poll(cx);
let mut inner = this.state.borrow_mut();
let response_idx =
inner.base.wrapping_add(inner.queue.len() as usize);
if let Poll::Ready(res) = res {
if inner.queue.is_empty() {
match res {
Err(err) => {
inner.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) =
io.encode(item, this.codec)
{
inner.error = Some(
IoDispatcherError::Encoder(err),
);
}
}
Ok(None) => (),
}
} else {
*this.response_idx = response_idx;
inner.queue.push_back(ServiceResult::Ready(res));
}
this.response.set(None);
} else {
*this.response_idx = response_idx;
inner.queue.push_back(ServiceResult::Pending);
}
} else {
let mut inner = this.state.borrow_mut();
let response_idx =
inner.base.wrapping_add(inner.queue.len() as usize);
inner.queue.push_back(ServiceResult::Pending);
let st = io.get_ref();
let codec = this.codec.clone();
let inner = this.state.clone();
let fut = this.service.call(item);
ntex::rt::spawn(async move {
let item = fut.await;
inner.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
}
}
Poll::Pending => {
log::trace!("service is not ready, pause read task");
io.pause();
return Poll::Pending;
}
Poll::Ready(Err(err)) => {
log::trace!("service readiness check failed, stopping");
*this.st = IoDispatcherState::Stop;
this.state.borrow_mut().error =
Some(IoDispatcherError::Service(err));
insert_flags(&mut *this.flags, Flags::READY_ERR);
}
}
}
IoDispatcherState::Stop => {
this.inner.unregister_keepalive();
if !this.flags.get().contains(Flags::READY_ERR) {
let _ = this.service.poll_ready(cx);
}
if this.state.borrow().queue.is_empty() {
if io.poll_shutdown(cx).is_ready() {
log::trace!("io shutdown completed");
*this.st = IoDispatcherState::Shutdown;
continue;
}
} else if !this.flags.get().contains(Flags::IO_ERR) {
match ready!(this.inner.io.poll_status_update(cx)) {
IoStatusUpdate::PeerGone(_)
| IoStatusUpdate::Stop
| IoStatusUpdate::KeepAlive => {
insert_flags(&mut *this.flags, Flags::IO_ERR);
continue;
}
IoStatusUpdate::WriteBackpressure => {
if ready!(this.inner.io.poll_flush(cx, true)).is_err() {
insert_flags(&mut *this.flags, Flags::IO_ERR);
}
continue;
}
}
}
return Poll::Pending;
}
IoDispatcherState::Shutdown => {
let is_err = this.state.borrow().error.is_some();
return if this.service.poll_shutdown(cx, is_err).is_ready() {
log::trace!("service shutdown is completed, stop");
Poll::Ready(
if let Some(IoDispatcherError::Service(err)) =
this.state.borrow_mut().error.take()
{
Err(err)
} else {
Ok(())
},
)
} else {
Poll::Pending
};
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::cell::Cell;
use ntex::channel::condition::Condition;
use ntex::codec::BytesCodec;
use ntex::io as nio;
use ntex::testing::Io;
use ntex::time::{sleep, Millis};
use ntex::util::{Bytes, Ready};
use super::*;
impl<S, U> Dispatcher<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
pub(crate) fn new_debug<F: IntoService<S, DispatchItem<U>>>(
io: Io,
codec: U,
service: F,
) -> (Self, nio::IoRef) {
let keepalive_timeout = Cell::new(Seconds(30).into());
let io = nio::Io::new(io);
io.start_keepalive_timer(keepalive_timeout.get());
let rio = io.get_ref();
let state = Rc::new(RefCell::new(DispatcherState {
error: None,
base: 0,
queue: VecDeque::new(),
}));
(
Dispatcher {
state,
codec,
service: service.into_service(),
st: IoDispatcherState::Processing,
response: None,
response_idx: 0,
pool: io.memory_pool().pool(),
flags: Cell::new(Flags::empty()),
inner: DispatcherInner { keepalive_timeout, io: IoBoxed::from(io) },
},
rio,
)
}
}
#[ntex::test]
async fn test_basic() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, _) = Dispatcher::new_debug(
server,
BytesCodec,
ntex::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
sleep(Millis(50)).await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
ntex::rt::spawn(async move {
let _ = disp.await;
});
sleep(Millis(25)).await;
client.write("GET /test HTTP/1\r\n\r\n");
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_ordering() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("test");
let condition = Condition::new();
let waiter = condition.wait();
let (disp, _) = Dispatcher::new_debug(
server,
BytesCodec,
ntex::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
let waiter = waiter.clone();
async move {
waiter.await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}
}),
);
ntex::rt::spawn(async move {
let _ = disp.await;
});
sleep(Millis(50)).await;
client.write("test");
sleep(Millis(50)).await;
client.write("test");
sleep(Millis(50)).await;
condition.notify();
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"testtesttest"));
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, io) = Dispatcher::new_debug(
server,
BytesCodec,
ntex::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
ntex::rt::spawn(async move {
let _ = disp.disconnect_timeout(Seconds(1)).await;
});
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(io.encode(Bytes::from_static(b"test"), &BytesCodec).is_ok());
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
io.close();
sleep(Millis(1200)).await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_err_in_service() {
let (client, server) = Io::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, io) = Dispatcher::new_debug(
server,
BytesCodec,
ntex::service::fn_service(|_: DispatchItem<BytesCodec>| async move {
Err::<Option<Bytes>, _>(())
}),
);
ntex::rt::spawn(async move {
let _ = disp.await;
});
io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap();
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(!client.is_closed());
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_err_in_service_ready() {
let (client, server) = Io::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let counter = Rc::new(Cell::new(0));
struct Srv(Rc<Cell<usize>>);
impl Service<DispatchItem<BytesCodec>> for Srv {
type Response = Option<Response<BytesCodec>>;
type Error = ();
type Future = Ready<Option<Response<BytesCodec>>, ()>;
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.0.set(self.0.get() + 1);
Poll::Ready(Err(()))
}
fn call(&self, _: DispatchItem<BytesCodec>) -> Self::Future {
Ready::Ok(None)
}
}
let (disp, io) = Dispatcher::new_debug(server, BytesCodec, Srv(counter.clone()));
io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &mut BytesCodec).unwrap();
ntex::rt::spawn(async move {
let _ = disp.await;
});
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(client.is_closed());
client.close().await;
assert!(client.is_server_dropped());
assert_eq!(counter.get(), 1);
}
}