use std::cell::RefCell;
use std::task::{Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};
use ntex::io::DispatchItem;
use ntex::service::{Pipeline, Service, ServiceCall, ServiceCtx};
use ntex::util::{inflight::InFlightService, BoxFuture, Either, HashSet, Ready};
use crate::error::{HandshakeError, MqttError, ProtocolError};
use crate::v3::shared::{Ack, MqttShared};
use crate::v3::{codec, control::ControlResultKind, publish::Publish};
use super::control::{ControlMessage, ControlResult};
/// Builds the client-side dispatcher service for an MQTT v3 connection.
///
/// * `sink` – shared connection state handed to the dispatcher.
/// * `inflight` – limit passed to the wrapping [`InFlightService`].
/// * `publish` – service invoked for incoming PUBLISH packets.
/// * `control` – service invoked for control messages; its errors are wrapped in
///   [`MqttError::Service`] before the dispatcher sees them.
pub(super) fn create_dispatcher<T, C, E>(
    sink: Rc<MqttShared>,
    inflight: usize,
    publish: T,
    control: C,
) -> impl Service<DispatchItem<Rc<MqttShared>>, Response = Option<codec::Packet>, Error = MqttError<E>>
where
    E: 'static,
    T: Service<Publish, Response = Either<(), Publish>, Error = E> + 'static,
    C: Service<ControlMessage<E>, Response = ControlResult, Error = E> + 'static,
{
    // The dispatcher expects control errors already lifted into `MqttError`.
    let wrapped_control = control.map_err(MqttError::Service);
    let dispatcher = Dispatcher::new(sink, publish, wrapped_control);
    InFlightService::new(inflight, dispatcher)
}
/// Service that routes items produced by the connection dispatcher to the
/// user-provided `publish` and `control` services.
pub(crate) struct Dispatcher<T, C: Service<ControlMessage<E>>, E> {
    /// Service called for every incoming PUBLISH packet.
    publish: T,
    /// Future that sends `ControlMessage::closed()` to the control service; created
    /// lazily on the first `poll_shutdown` call.
    shutdown: RefCell<Option<BoxFuture<'static, ()>>>,
    /// State shared with the response futures and the shutdown future.
    inner: Rc<Inner<C>>,
    _t: PhantomData<E>,
}
/// State shared between the dispatcher and the futures it hands out.
struct Inner<C> {
    /// Control service that receives every non-publish event and unhandled publishes.
    control: C,
    /// Shared connection state: used to record server acks, close the connection and
    /// toggle write back-pressure.
    sink: Rc<MqttShared>,
    /// Packet ids of incoming PUBLISH packets that carried an id and have not been
    /// acknowledged yet; a second PUBLISH reusing one of them is rejected.
    inflight: RefCell<HashSet<NonZeroU16>>,
}
impl<T, C, E> Dispatcher<T, C, E>
where
T: Service<Publish, Response = Either<(), Publish>, Error = E>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
pub(crate) fn new(sink: Rc<MqttShared>, publish: T, control: C) -> Self {
Self {
publish,
shutdown: RefCell::new(None),
inner: Rc::new(Inner { sink, control, inflight: RefCell::new(HashSet::default()) }),
_t: PhantomData,
}
}
}
impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
where
    T: Service<Publish, Response = Either<(), Publish>, Error = E>,
    C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>> + 'static,
    E: 'static,
{
    type Response = Option<codec::Packet>;
    type Error = MqttError<E>;
    type Future<'f> = Either<
        PublishResponse<'f, T, C, E>,
        Either<Ready<Self::Response, MqttError<E>>, ControlResponse<'f, C, E>>,
    > where Self: 'f;

    /// Ready only when both the publish and the control service are ready.
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Both services are polled unconditionally so each one registers the waker.
        let publish_ready = self.publish.poll_ready(cx).map_err(MqttError::Service)?;
        let control_ready = self.inner.control.poll_ready(cx)?;
        if publish_ready.is_pending() || control_ready.is_pending() {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Closes the sink, notifies the control service with `ControlMessage::closed()`
    /// and shuts down both inner services.
    ///
    /// May be called repeatedly; resolves once all three steps have completed.
    fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
        let mut shutdown = self.shutdown.borrow_mut();
        if shutdown.is_none() {
            self.inner.sink.close();
            let inner = self.inner.clone();
            *shutdown = Some(Box::pin(async move {
                let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await;
            }));
        }

        let closed = shutdown.as_mut().expect("guard above");
        let res0 = closed.as_mut().poll(cx);
        if res0.is_ready() {
            // Polling a completed `async` block panics, but this method keeps being
            // called while the inner services are still shutting down. Replace the
            // finished future with one that is always ready.
            *closed = Box::pin(std::future::poll_fn(|_| Poll::Ready(())));
        }
        let res1 = self.publish.poll_shutdown(cx);
        let res2 = self.inner.control.poll_shutdown(cx);
        if res0.is_pending() || res1.is_pending() || res2.is_pending() {
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }

    /// Handles one item from the connection and resolves to the packet (if any) that
    /// must be written back to the server.
    fn call<'a>(
        &'a self,
        packet: DispatchItem<Rc<MqttShared>>,
        ctx: ServiceCtx<'a, Self>,
    ) -> Self::Future<'a> {
        log::trace!("Dispatch packet: {:#?}", packet);

        match packet {
            DispatchItem::Item((codec::Packet::Publish(publish), size)) => {
                let inner = self.inner.as_ref();
                let packet_id = publish.packet_id;

                // Track the id until the publish is acknowledged; reusing a tracked id
                // is a protocol violation.
                if let Some(pid) = packet_id {
                    if !inner.inflight.borrow_mut().insert(pid) {
                        log::trace!("Duplicated packet id for publish packet: {:?}", pid);
                        return Either::Right(Either::Left(Ready::Err(MqttError::Handshake(
                            HandshakeError::Protocol(ProtocolError::generic_violation(
                                "PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]",
                            )),
                        ))));
                    }
                }

                Either::Left(PublishResponse {
                    fut: ctx.call(&self.publish, Publish::new(publish, size)),
                    fut_c: None,
                    packet_id,
                    inner,
                    ctx,
                })
            }
            DispatchItem::Item((codec::Packet::PublishAck { packet_id }, _)) => Either::Right(
                Either::Left(ack_response(&self.inner.sink, Ack::Publish(packet_id))),
            ),
            DispatchItem::Item((codec::Packet::SubscribeAck { packet_id, status }, _)) => {
                Either::Right(Either::Left(ack_response(
                    &self.inner.sink,
                    Ack::Subscribe { packet_id, status },
                )))
            }
            DispatchItem::Item((codec::Packet::UnsubscribeAck { packet_id }, _)) => {
                Either::Right(Either::Left(ack_response(
                    &self.inner.sink,
                    Ack::Unsubscribe(packet_id),
                )))
            }
            DispatchItem::Item((
                pkt @ (codec::Packet::PingRequest
                | codec::Packet::Disconnect
                | codec::Packet::Subscribe { .. }
                | codec::Packet::Unsubscribe { .. }),
                _,
            )) => Either::Right(Either::Left(Ready::Err(
                HandshakeError::Protocol(ProtocolError::unexpected_packet(
                    pkt.packet_type(),
                    "Packet of the type is not expected from server",
                ))
                .into(),
            ))),
            DispatchItem::Item((pkt, _)) => {
                log::debug!("Unsupported packet: {:?}", pkt);
                Either::Right(Either::Left(Ready::Ok(None)))
            }
            DispatchItem::EncoderError(err) => {
                Either::Right(Either::Right(ControlResponse::new(
                    ControlMessage::proto_error(ProtocolError::Encode(err)),
                    &self.inner,
                    ctx,
                )))
            }
            DispatchItem::DecoderError(err) => {
                Either::Right(Either::Right(ControlResponse::new(
                    ControlMessage::proto_error(ProtocolError::Decode(err)),
                    &self.inner,
                    ctx,
                )))
            }
            DispatchItem::Disconnect(err) => Either::Right(Either::Right(
                ControlResponse::new(ControlMessage::peer_gone(err), &self.inner, ctx),
            )),
            DispatchItem::KeepAliveTimeout => {
                Either::Right(Either::Right(ControlResponse::new(
                    ControlMessage::proto_error(ProtocolError::KeepAliveTimeout),
                    &self.inner,
                    ctx,
                )))
            }
            DispatchItem::ReadTimeout => Either::Right(Either::Right(ControlResponse::new(
                ControlMessage::proto_error(ProtocolError::ReadTimeout),
                &self.inner,
                ctx,
            ))),
            DispatchItem::WBackPressureEnabled => {
                self.inner.sink.enable_wr_backpressure();
                Either::Right(Either::Left(Ready::Ok(None)))
            }
            DispatchItem::WBackPressureDisabled => {
                self.inner.sink.disable_wr_backpressure();
                Either::Right(Either::Left(Ready::Ok(None)))
            }
        }
    }
}

/// Passes a server acknowledgement (PUBACK, SUBACK or UNSUBACK) to `MqttShared::pkt_ack`.
///
/// Resolves to `None` (nothing is written back) on success. An error reported by
/// `pkt_ack` is surfaced as `MqttError::Handshake(HandshakeError::Protocol(_))`.
fn ack_response<E>(sink: &MqttShared, ack: Ack) -> Ready<Option<codec::Packet>, MqttError<E>> {
    match sink.pkt_ack(ack) {
        Ok(_) => Ready::Ok(None),
        Err(e) => Ready::Err(MqttError::Handshake(HandshakeError::Protocol(e))),
    }
}
// Future returned by `Dispatcher::call` for an incoming PUBLISH packet.
// It first awaits the publish service and may then delegate to the control service.
pin_project_lite::pin_project! {
    pub(crate) struct PublishResponse<'f, T: Service<Publish>, C: Service<ControlMessage<E>>, E>
    where T: 'f
    {
        // In-progress call of the publish service.
        #[pin]
        fut: ServiceCall<'f, T, Publish>,
        // Control-service call, started on a publish error or an unhandled publish.
        #[pin]
        fut_c: Option<ControlResponse<'f, C, E>>,
        // Id of the incoming publish, if it carried one; released from `inflight` on ack.
        packet_id: Option<NonZeroU16>,
        inner: &'f Inner<C>,
        ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
    }
}
impl<'f, T, C, E> Future for PublishResponse<'f, T, C, E>
where
    T: Service<Publish, Response = Either<(), Publish>, Error = E>,
    C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
    type Output = Result<Option<codec::Packet>, MqttError<E>>;

    /// Drives the publish-service call.
    ///
    /// * `Either::Left` from the publish service: the packet id (if any) is released
    ///   and a PUBACK is produced.
    /// * Publish-service error or `Either::Right`: the control service is invoked
    ///   (`ControlMessage::error` / `ControlMessage::publish`), and this future resolves
    ///   with the control service's outcome.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Once a control-service call has started, it alone decides the result.
        if let Some(fut) = self.as_mut().project().fut_c.as_pin_mut() {
            return fut.poll(cx);
        }

        let mut this = self.as_mut().project();
        let res = match this.fut.poll(cx) {
            Poll::Ready(Ok(item)) => item,
            Poll::Ready(Err(e)) => {
                // Report the publish error to the control service. The packet id stays
                // in `inflight` here; it is only released if the control service answers
                // with `PublishAck`.
                this.fut_c.set(Some(ControlResponse::new(
                    ControlMessage::error(e),
                    *this.inner,
                    *this.ctx,
                )));
                // Re-enter to poll the freshly created control future immediately.
                return self.poll(cx);
            }
            Poll::Pending => return Poll::Pending,
        };

        match res {
            Either::Left(_) => {
                log::trace!("Publish result for packet {:?} is ready", this.packet_id);

                if let Some(packet_id) = this.packet_id {
                    // Handled: the id may be reused by the server after this PUBACK.
                    this.inner.inflight.borrow_mut().remove(packet_id);
                    Poll::Ready(Ok(Some(codec::Packet::PublishAck { packet_id: *packet_id })))
                } else {
                    Poll::Ready(Ok(None))
                }
            }
            Either::Right(pkt) => {
                // Unhandled publish: forward it to the control service, whose
                // `PublishAck` result releases the id in `ControlResponse::poll`.
                this.fut_c.set(Some(ControlResponse::new(
                    ControlMessage::publish(pkt.into_inner()),
                    *this.inner,
                    *this.ctx,
                )));
                self.poll(cx)
            }
        }
    }
}
// Future wrapping a single call of the control service; resolves to the packet (if any)
// that must be sent back to the server.
pin_project_lite::pin_project! {
    pub(crate) struct ControlResponse<'f, C: Service<ControlMessage<E>>, E>
    where C: 'f, E: 'f
    {
        // In-progress call of the control service.
        #[pin]
        fut: ServiceCall<'f, C, ControlMessage<E>>,
        // Shared dispatcher state, used to release packet ids and close the sink.
        inner: &'f Inner<C>,
    }
}
impl<'f, C, E> ControlResponse<'f, C, E>
where
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
fn new<T>(
msg: ControlMessage<E>,
inner: &'f Inner<C>,
ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
) -> Self {
Self { fut: ctx.call(&inner.control, msg), inner }
}
}
impl<'f, C, E> Future for ControlResponse<'f, C, E>
where
    C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
    type Output = Result<Option<codec::Packet>, MqttError<E>>;

    /// Waits for the control service and maps its `ControlResultKind` to the reply
    /// packet (if any); errors from the control service are passed through unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let outcome = match this.fut.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(outcome)) => outcome,
        };

        let reply = match outcome.result {
            ControlResultKind::Ping => Some(codec::Packet::PingResponse),
            ControlResultKind::PublishAck(packet_id) => {
                // The publish is acknowledged now, so its id leaves the in-flight set.
                this.inner.inflight.borrow_mut().remove(&packet_id);
                Some(codec::Packet::PublishAck { packet_id })
            }
            // Subscription results are not expected here.
            ControlResultKind::Subscribe(_) | ControlResultKind::Unsubscribe(_) => {
                unreachable!()
            }
            ControlResultKind::Disconnect => {
                this.inner.sink.close();
                None
            }
            ControlResultKind::Closed | ControlResultKind::Nothing => None,
        };
        Poll::Ready(Ok(reply))
    }
}
#[cfg(test)]
mod tests {
    use ntex::time::{sleep, Seconds};
    use ntex::util::{lazy, ByteString, Bytes};
    use ntex::{io::Io, service::fn_service, testing::IoTest};
    use std::rc::Rc;

    use super::*;
    use crate::v3::{codec::Codec, MqttSink, QoS};

    /// Builds a QoS 1 PUBLISH item with empty topic and payload and the given packet id.
    fn qos1_publish(id: u16) -> DispatchItem<Rc<MqttShared>> {
        DispatchItem::Item((
            codec::Packet::Publish(codec::Publish {
                dup: false,
                retain: false,
                qos: QoS::AtLeastOnce,
                topic: ByteString::new(),
                packet_id: NonZeroU16::new(id),
                payload: Bytes::new(),
            }),
            999,
        ))
    }

    #[ntex::test]
    async fn test_dup_packet_id() {
        let io = Io::new(IoTest::create().0);
        let codec = codec::Codec::default();
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

        // The publish service sleeps, so the first publish stays in flight.
        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            fn_service(|_| async {
                sleep(Seconds(10)).await;
                Ok(Either::Left(()))
            }),
            fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })),
        ));

        // Poll the first call once so the dispatcher processes it.
        let mut first = Box::pin(disp.call(qos1_publish(1)));
        let _ = lazy(|cx| Pin::new(&mut first).poll(cx)).await;

        // A second publish reusing the same id must be rejected.
        let second = Box::pin(disp.call(qos1_publish(1)));
        match second.await.err().unwrap() {
            MqttError::Handshake(HandshakeError::Protocol(msg)) => {
                assert!(format!("{}", msg)
                    .contains("PUBLISH received with packet id that is already in use"))
            }
            _ => panic!(),
        }
    }

    #[ntex::test]
    async fn test_wr_backpressure() {
        let io = Io::new(IoTest::create().0);
        let codec = Codec::default();
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            fn_service(|_| Ready::Ok(Either::Left(()))),
            fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })),
        ));

        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        // Enabling back-pressure makes the sink report not-ready and hand out waiters.
        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let rx = shared.wait_readiness();
        let rx2 = shared.wait_readiness().unwrap();
        assert!(rx.is_some());
        let rx = rx.unwrap();

        // Disabling it wakes the first waiter only.
        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
    }
}