use std::{cell::RefCell, marker::PhantomData, num, rc::Rc, task::Context};
use ntex_bytes::ByteString;
use ntex_service::cfg::{Cfg, SharedCfg};
use ntex_service::{self as service, Pipeline, Service, ServiceCtx, ServiceFactory};
use ntex_util::services::buffer::{BufferService, BufferServiceError};
use ntex_util::services::inflight::InFlightService;
use ntex_util::{HashMap, HashSet, future::join};
use crate::error::{
DecodeError, DispatcherError, MqttError, PayloadError, ProtocolError, SpecViolation,
};
use crate::payload::{Payload, PayloadStatus};
use crate::{MqttServiceConfig, types::QoS};
use super::codec::{self, Decoded, DisconnectReasonCode, Encoded, Packet};
use super::control::{Pkt, ProtocolMessage, ProtocolMessageAck};
use super::publish::{Publish, PublishAck};
use super::{Session, ToPublishAck, shared::Ack, shared::MqttShared};
/// Builds the v5 dispatcher factory from a publish service factory and a control
/// service factory.
///
/// For every `(SharedCfg, Session)` the two inner factories are created concurrently
/// with the session. The control service is wrapped so that at most one control request
/// is in flight, with up to 16 more buffered. The resulting [`Dispatcher`] is returned.
///
/// # Errors
/// Initialization errors of either inner factory are converted into `InitErr` and
/// returned as `MqttError::Service`. The publish factory's error is reported first when
/// both fail.
pub(super) fn factory<St, T, P, E, InitErr>(
    publish: T,
    control: P,
) -> impl ServiceFactory<
    Decoded,
    (SharedCfg, Session<St>),
    Response = Option<Encoded>,
    Error = DispatcherError<E>,
    InitError = MqttError<InitErr>,
>
where
    St: 'static,
    E: From<P::Error> + 'static,
    T: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
    T::Error: ToPublishAck<Error = E>,
    P: ServiceFactory<ProtocolMessage, Session<St>, Response = ProtocolMessageAck> + 'static,
    InitErr: From<T::InitError> + From<P::InitError>,
{
    let shared_factories = Rc::new((publish, control));

    service::fn_factory_with_config(
        async move |(shared_cfg, session): (SharedCfg, Session<St>)| {
            let cfg: Cfg<MqttServiceConfig> = shared_cfg.get();
            let sink = session.sink().shared();

            // Create both services concurrently; the publish factory gets a clone of the
            // session and the control factory takes ownership of it.
            let publish_fut = shared_factories.0.create(session.clone());
            let control_fut = shared_factories.1.create(session);
            let (publish_res, control_res) = join(publish_fut, control_fut).await;

            let publish_srv =
                publish_res.map_err(|err| MqttError::Service(InitErr::from(err)))?;
            let control_srv = control_res
                .map_err(|err| MqttError::Service(InitErr::from(err)))?
                .map_err(|err| DispatcherError::Service(err.into()));

            // Allow a single control request in flight and buffer up to 16 more.
            // A cancelled buffered request is surfaced as a read timeout.
            let buffered = BufferService::new(16, InFlightService::new(1, control_srv))
                .map_err(|err| {
                    if let BufferServiceError::Service(inner_err) = err {
                        inner_err
                    } else {
                        DispatcherError::Protocol(ProtocolError::ReadTimeout)
                    }
                });
            let control_pipeline = Pipeline::new(buffered);

            Ok(Dispatcher::new(sink, publish_srv, control_pipeline, cfg))
        },
    )
}
impl crate::inflight::SizedRequest for Decoded {
    /// Size recorded with a decoded packet or publish; any other frame
    /// (e.g. a payload chunk) reports `0`.
    fn size(&self) -> u32 {
        match *self {
            Decoded::Packet(_, size) | Decoded::Publish(_, _, size) => size,
            _ => 0,
        }
    }

    /// `true` for a decoded PUBLISH frame.
    fn is_publish(&self) -> bool {
        matches!(*self, Decoded::Publish(..))
    }

    /// `true` for a continuation chunk of a streamed publish payload.
    fn is_chunk(&self) -> bool {
        matches!(*self, Decoded::PayloadChunk(..))
    }
}
/// MQTT v5 packet dispatcher.
///
/// PUBLISH packets are routed to the `publish` service. Protocol messages (PUBREL, AUTH,
/// PING, DISCONNECT, SUBSCRIBE, UNSUBSCRIBE) are routed to the control pipeline held in
/// `inner`.
pub(crate) struct Dispatcher<T, C, E> {
    // Service handling incoming PUBLISH packets.
    publish: T,
    // Reference-counted state: control pipeline, connection sink and publish bookkeeping.
    inner: Rc<Inner<C>>,
    // Per-service configuration (read for `handle_qos_after_disconnect` and
    // `max_payload_buffer_size`).
    cfg: Cfg<MqttServiceConfig>,
    // Ties the service error type `E` to the dispatcher without storing a value of it.
    _t: PhantomData<E>,
}
/// State shared between the dispatcher and its publish/control helpers.
struct Inner<C> {
    // Control service pipeline (buffered / in-flight limited by `factory`).
    control: Pipeline<C>,
    // Shared connection state: encoder, flags, streaming payload slot.
    sink: Rc<MqttShared>,
    // Mutable publish bookkeeping; borrowed only briefly, never across an `.await`
    // in the publish path.
    info: RefCell<PublishInfo>,
}
/// Bookkeeping for packets received from the peer.
struct PublishInfo {
    // Packet identifiers of incoming PUBLISH (QoS>0), SUBSCRIBE and UNSUBSCRIBE
    // packets that are still being processed; used for duplicate-id and
    // Receive Maximum checks.
    inflight: HashSet<num::NonZeroU16>,
    // Topic aliases registered by the peer, mapping alias -> topic name.
    aliases: HashMap<num::NonZeroU16, ByteString>,
}
impl<T, C, E> Dispatcher<T, C, E>
where
T: Service<Publish, Response = PublishAck>,
T::Error: ToPublishAck<Error = E>,
C: Service<ProtocolMessage, Response = ProtocolMessageAck, Error = DispatcherError<E>>,
{
fn new(
sink: Rc<MqttShared>,
publish: T,
control: Pipeline<C>,
cfg: Cfg<MqttServiceConfig>,
) -> Self {
Self {
cfg,
publish,
_t: PhantomData,
inner: Rc::new(Inner {
sink,
control,
info: RefCell::new(PublishInfo {
aliases: HashMap::default(),
inflight: HashSet::default(),
}),
}),
}
}
fn tag(&self) -> &'static str {
self.inner.sink.tag()
}
}
impl<T, C, E> Service<Decoded> for Dispatcher<T, C, E>
where
    T: Service<Publish, Response = PublishAck> + 'static,
    T::Error: ToPublishAck<Error = E>,
    C: Service<ProtocolMessage, Response = ProtocolMessageAck, Error = DispatcherError<E>>
        + 'static,
{
    type Response = Option<Encoded>;
    type Error = DispatcherError<E>;

    /// Waits until both the publish service and the control pipeline are ready.
    ///
    /// Once both are ready, a streamed publish payload in progress also applies
    /// backpressure: this waits on the payload's `ready()` before the next frame
    /// (possibly another payload chunk) is accepted. If that reports anything other
    /// than `PayloadStatus::Ready`, the connection is force-closed.
    ///
    /// # Errors
    /// The publish service's readiness error (via `into_error`) or the control
    /// pipeline's readiness error.
    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        let (res1, res2) = join(ctx.ready(&self.publish), self.inner.control.ready()).await;
        res1.map_err(|e| DispatcherError::Service(e.into_error()))?;
        res2?;

        // Backpressure only makes sense when we are about to accept more input.
        if let Some(pl) = self.inner.sink.payload.take() {
            // The payload slot is take/set: put the sender back before awaiting on it.
            self.inner.sink.payload.set(Some(pl.clone()));
            if pl.ready().await != PayloadStatus::Ready {
                self.inner.sink.force_close();
            }
        }
        Ok(())
    }

    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
        self.publish.poll(cx).map_err(|e| DispatcherError::Service(e.into_error()))?;
        self.inner.control.poll(cx)
    }

    /// Fails any streaming payload with `Disconnected`, drops the sink, then shuts
    /// down the publish service followed by the control pipeline.
    async fn shutdown(&self) {
        log::trace!("{}: Shutdown v5 dispatcher", self.tag());
        self.inner.sink.drop_payload(&PayloadError::Disconnected);
        self.inner.sink.drop_sink(true);
        self.publish.shutdown().await;
        self.inner.control.shutdown().await;
    }

    /// Dispatches one decoded frame.
    ///
    /// Returns an optional packet to send back to the peer, or an error that
    /// terminates the connection (spec violations, protocol errors, service errors).
    #[allow(clippy::too_many_lines, clippy::await_holding_refcell_ref)]
    async fn call(
        &self,
        request: Decoded,
        ctx: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        log::trace!("{}: Dispatch v5 packet: {:#?}", self.tag(), request);

        match request {
            Decoded::Publish(mut publish, payload, size) => {
                let info = self.inner.as_ref();
                let packet_id = publish.packet_id;

                // Wildcard characters are not allowed in a PUBLISH topic name.
                if publish.topic.contains(['#', '+']) {
                    return Err(SpecViolation::Pub_3_3_2_2.into());
                }

                {
                    // No `.await` inside this scope, so the RefCell borrow is never
                    // held across a suspension point.
                    let mut inner = info.info.borrow_mut();
                    let state = &self.inner.sink;

                    if packet_id.is_some() {
                        // Receive Maximum limits concurrently processed QoS 1/2 publishes.
                        let receive_max = state.receive_max();
                        if receive_max != 0 && inner.inflight.len() >= receive_max as usize {
                            log::trace!(
                                "{}: Receive maximum exceeded: max: {} in-flight: {}",
                                self.tag(),
                                receive_max,
                                inner.inflight.len()
                            );
                            return Err(SpecViolation::Pub_3_3_4_7.into());
                        }

                        if publish.qos > state.max_qos() {
                            log::trace!(
                                "{}: Max allowed QoS is violated, max {:?} provided {:?}",
                                self.tag(),
                                state.max_qos(),
                                publish.qos
                            );
                            return Err(SpecViolation::Connack_3_2_2_11.into());
                        }
                    }

                    // MQTT-3.2.2-14 applies to every QoS level, including QoS 0
                    // publishes that carry no packet identifier.
                    if publish.retain && !state.codec.retain_available() {
                        log::trace!("{}: Retain is not available but is set", self.tag());
                        return Err(SpecViolation::Connack_3_2_2_14.into());
                    }

                    // Reject a packet id that is still being processed.
                    if let Some(pid) = packet_id
                        && !inner.inflight.insert(pid)
                    {
                        let _ = self.inner.sink.encode_packet(codec::Packet::PublishAck(
                            codec::PublishAck {
                                packet_id: pid,
                                reason_code: codec::PublishAckReason::PacketIdentifierInUse,
                                ..Default::default()
                            },
                        ));
                        return Ok(None);
                    }

                    // Topic alias handling: an empty topic resolves a known alias;
                    // a non-empty topic registers or updates the alias.
                    if let Some(alias) = publish.properties.topic_alias {
                        if publish.topic.is_empty() {
                            if let Some(aliased_topic) = inner.aliases.get(&alias) {
                                publish.topic = aliased_topic.clone();
                            } else {
                                return Err(ProtocolError::violation(
                                    DisconnectReasonCode::TopicAliasInvalid,
                                    "Unknown topic alias",
                                )
                                .into());
                            }
                        } else {
                            match inner.aliases.entry(alias) {
                                std::collections::hash_map::Entry::Occupied(mut entry) => {
                                    if entry.get().as_str() != publish.topic.as_str() {
                                        let mut topic = publish.topic.clone();
                                        topic.trimdown();
                                        entry.insert(topic);
                                    }
                                }
                                std::collections::hash_map::Entry::Vacant(entry) => {
                                    if alias.get() > state.topic_alias_max() {
                                        return Err(SpecViolation::Connack_3_2_2_17.into());
                                    }
                                    let mut topic = publish.topic.clone();
                                    topic.trimdown();
                                    entry.insert(topic);
                                }
                            }
                        }
                    }

                    // After the connection is closed, only publishes up to the
                    // configured QoS (if any) are still handed to the service.
                    if state.is_closed()
                        && self
                            .cfg
                            .handle_qos_after_disconnect
                            .is_none_or(|max_qos| publish.qos > max_qos)
                    {
                        return Ok(None);
                    }
                }

                // A payload shorter than announced is streamed: the remaining bytes
                // arrive later as `Decoded::PayloadChunk` frames.
                let payload = if publish.payload_size == payload.len() as u32 {
                    Payload::from_bytes(payload)
                } else {
                    let (pl, sender) =
                        Payload::from_stream(payload, self.cfg.max_payload_buffer_size);
                    self.inner.sink.payload.set(Some(sender));
                    pl
                };

                publish_fn(
                    &self.publish,
                    Publish::new(publish, payload, size),
                    packet_id.map_or(0, num::NonZero::get),
                    info,
                    ctx,
                )
                .await
            }
            Decoded::PayloadChunk(buf, eof) => {
                if let Some(pl) = self.inner.sink.payload.take() {
                    pl.feed_data(buf);
                    if eof {
                        pl.feed_eof();
                    } else {
                        self.inner.sink.payload.set(Some(pl));
                    }
                    Ok(None)
                } else {
                    Err(ProtocolError::Decode(DecodeError::UnexpectedPayload).into())
                }
            }
            Decoded::Packet(Packet::PublishAck(packet), _) => {
                self.inner.sink.pkt_ack(Ack::Publish(packet))?;
                Ok(None)
            }
            Decoded::Packet(Packet::PublishReceived(pkt), _) => {
                self.inner.sink.pkt_ack(Ack::Receive(pkt))?;
                Ok(None)
            }
            Decoded::Packet(Packet::PublishRelease(ack), size) => {
                let id = ack.packet_id;
                if self.inner.info.borrow().inflight.contains(&id) {
                    // Pass the packet id so the QoS 2 id inserted on PUBLISH is
                    // released from the in-flight set once PUBREL is handled.
                    self.inner.control_pkt(ProtocolMessage::pubrel(ack, size), id.get()).await
                } else {
                    Ok(Some(Encoded::Packet(codec::Packet::PublishComplete(
                        codec::PublishAck2 {
                            packet_id: id,
                            reason_code: codec::PublishAck2Reason::PacketIdNotFound,
                            properties: codec::UserProperties::default(),
                            reason_string: None,
                        },
                    ))))
                }
            }
            Decoded::Packet(Packet::PublishComplete(pkt), _) => {
                self.inner.sink.pkt_ack(Ack::Complete(pkt))?;
                Ok(None)
            }
            Decoded::Packet(Packet::Auth(pkt), size) => {
                if self.inner.sink.is_closed() {
                    Ok(None)
                } else {
                    self.inner.control(ProtocolMessage::auth(pkt, size)).await
                }
            }
            Decoded::Packet(Packet::PingRequest, _) => {
                self.inner.control(ProtocolMessage::ping()).await
            }
            Decoded::Packet(Packet::Disconnect(pkt), size) => {
                self.inner.sink.set_disconnect_recv();
                if let Some(val) = pkt.session_expiry_interval_secs
                    && val > 0
                    && self.inner.sink.is_zero_session_expiry()
                {
                    Err(SpecViolation::Disconnect_3_14_2_22.into())
                } else {
                    // NOTE(review): the return value is discarded; this call only has an
                    // effect if `is_disconnect_sent` mutates state (e.g. marks the
                    // disconnect as sent) — confirm in `MqttShared`.
                    self.inner.sink.is_disconnect_sent();
                    self.inner.sink.close(None);
                    self.inner.control(ProtocolMessage::remote_disconnect(pkt, size)).await
                }
            }
            Decoded::Packet(Packet::Subscribe(pkt), size) => {
                if self.inner.sink.is_closed() {
                    Ok(None)
                } else if pkt.topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
                    Err(SpecViolation::Subs_4_7_1.into())
                } else if pkt.id.is_some() && !self.inner.sink.codec.sub_ids_available() {
                    log::trace!(
                        "{}: Subscription Identifiers are not supported but was set",
                        self.tag()
                    );
                    Err(SpecViolation::Connack_3_2_2_3_12.into())
                } else if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
                    let _ = self.inner.sink.encode_packet(codec::Packet::SubscribeAck(
                        codec::SubscribeAck {
                            packet_id: pkt.packet_id,
                            status: pkt
                                .topic_filters
                                .iter()
                                .map(|_| codec::SubscribeAckReason::PacketIdentifierInUse)
                                .collect(),
                            properties: codec::UserProperties::new(),
                            reason_string: None,
                        },
                    ));
                    Ok(None)
                } else {
                    let id = pkt.packet_id;
                    self.inner
                        .control_pkt(ProtocolMessage::subscribe(pkt, size), id.get())
                        .await
                }
            }
            Decoded::Packet(Packet::Unsubscribe(pkt), size) => {
                if self.inner.sink.is_closed() {
                    Ok(None)
                } else if pkt.topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
                    Err(SpecViolation::Subs_4_7_1.into())
                } else if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
                    let _ = self.inner.sink.encode_packet(codec::Packet::UnsubscribeAck(
                        codec::UnsubscribeAck {
                            packet_id: pkt.packet_id,
                            status: pkt
                                .topic_filters
                                .iter()
                                .map(|_| codec::UnsubscribeAckReason::PacketIdentifierInUse)
                                .collect(),
                            properties: codec::UserProperties::new(),
                            reason_string: None,
                        },
                    ));
                    Ok(None)
                } else {
                    let id = pkt.packet_id;
                    self.inner
                        .control_pkt(ProtocolMessage::unsubscribe(pkt, size), id.get())
                        .await
                }
            }
            Decoded::Packet(_, _) => Ok(None),
        }
    }
}
impl<C> Inner<C> {
    /// Forwards a control message that carries no packet identifier.
    async fn control<E>(
        &self,
        pkt: ProtocolMessage,
    ) -> Result<Option<Encoded>, DispatcherError<E>>
    where
        C: Service<ProtocolMessage, Response = ProtocolMessageAck, Error = DispatcherError<E>>,
    {
        const NO_PACKET_ID: u16 = 0;
        self.control_pkt(pkt, NO_PACKET_ID).await
    }

    /// Forwards a control message and turns the service's ack into a response packet.
    ///
    /// On success, a non-zero `packet_id` is removed from the in-flight set. A
    /// `Pkt::Disconnect` is only emitted when `is_disconnect_sent()` returns `false`.
    /// When the ack requests a disconnect, the sink is dropped and any streaming
    /// payload fails with `PayloadError::Service`.
    ///
    /// # Errors
    /// A control-service error is returned unchanged, after the sink has been dropped
    /// and any streaming payload has been failed.
    async fn control_pkt<E>(
        &self,
        pkt: ProtocolMessage,
        packet_id: u16,
    ) -> Result<Option<Encoded>, DispatcherError<E>>
    where
        C: Service<ProtocolMessage, Response = ProtocolMessageAck, Error = DispatcherError<E>>,
    {
        let ack = match self.control.call(pkt).await {
            Err(err) => {
                self.sink.drop_sink(false);
                self.sink.drop_payload(&PayloadError::Service);
                return Err(err);
            }
            Ok(ack) => ack,
        };

        if let Some(id) = num::NonZeroU16::new(packet_id) {
            self.info.borrow_mut().inflight.remove(&id);
        }

        let encoded = match ack.packet {
            Pkt::None => None,
            Pkt::Packet(pkt) => Some(Encoded::Packet(pkt)),
            Pkt::Disconnect(pkt) => {
                if self.sink.is_disconnect_sent() {
                    None
                } else {
                    Some(Encoded::Packet(codec::Packet::from(pkt)))
                }
            }
        };

        if ack.disconnect {
            self.sink.drop_sink(true);
            self.sink.drop_payload(&PayloadError::Service);
        }
        Ok(encoded)
    }
}
/// Runs the publish service for one PUBLISH and builds the acknowledgement.
///
/// `packet_id == 0` means QoS 0: nothing is sent back. Otherwise a QoS 2 publish
/// answers with PUBREC and keeps its id in flight. Any other QoS answers with
/// PUBACK and releases the id.
///
/// # Errors
/// A service error is returned as `DispatcherError::Service`. For QoS 0 it is
/// converted with `into_error`. For QoS>0 it is first offered to `try_ack`, and only
/// the error `try_ack` hands back is propagated.
async fn publish_fn<'f, T, C, E>(
    publish: &T,
    pkt: Publish,
    packet_id: u16,
    inner: &'f Inner<C>,
    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
) -> Result<Option<Encoded>, DispatcherError<E>>
where
    T: Service<Publish, Response = PublishAck>,
    T::Error: ToPublishAck<Error = E>,
    C: Service<ProtocolMessage, Response = ProtocolMessageAck, Error = DispatcherError<E>>,
{
    let is_exactly_once = pkt.qos() == QoS::ExactlyOnce;

    let service_ack = match ctx.call(publish, pkt).await {
        Ok(ack) => ack,
        Err(err) if packet_id == 0 => {
            return Err(DispatcherError::Service(err.into_error()));
        }
        Err(err) => err.try_ack().map_err(DispatcherError::Service)?,
    };

    let Some(id) = num::NonZeroU16::new(packet_id) else {
        return Ok(None);
    };

    let wire_ack = codec::PublishAck {
        packet_id: id,
        reason_code: service_ack.reason_code,
        reason_string: service_ack.reason_string,
        properties: service_ack.properties,
    };

    let response = if is_exactly_once {
        // QoS 2: the id stays in flight until the matching PUBREL is processed.
        codec::Packet::PublishReceived(wire_ack)
    } else {
        inner.info.borrow_mut().inflight.remove(&id);
        codec::Packet::PublishAck(wire_ack)
    };
    Ok(Some(Encoded::Packet(response)))
}
#[cfg(test)]
mod tests {
    use std::num::{NonZeroU16, NonZeroU32};

    use ntex_bytes::{ByteString, Bytes};
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};

    use super::*;
    use crate::{error, v5::codec};

    // Minimal service error type used by the test services.
    #[derive(Debug)]
    struct TestError;

    impl From<()> for TestError {
        fn from((): ()) -> Self {
            TestError
        }
    }

    // Never converts an error into an ack, so `try_ack` always hands the error back.
    impl TryFrom<TestError> for PublishAck {
        type Error = TestError;
        fn try_from(err: TestError) -> Result<Self, Self::Error> {
            Err(err)
        }
    }

    // Drives one dispatcher through a sequence of spec violations. The calls share
    // dispatcher state (in-flight ids, topic aliases), so their order matters.
    #[ntex::test]
    async fn test_spec_violations() {
        // Max QoS 1, retain and subscription identifiers unavailable, topic alias max 1.
        let cfg: SharedCfg = SharedCfg::new("DBG")
            .add(MqttServiceConfig::new().set_max_qos(QoS::AtLeastOnce))
            .into();
        let io = Io::new(IoTest::create().0, cfg.clone());
        let codec = codec::Codec::default();
        codec.set_retain_available(false);
        codec.set_sub_ids_available(false);
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, Rc::default()));
        shared.set_topic_alias_max(1);

        let disp = Pipeline::new(Dispatcher::new(
            shared.clone(),
            fn_service(async |msg: Publish| Ok::<_, TestError>(msg.ack())),
            Pipeline::new(fn_service(async |msg: ProtocolMessage| {
                Ok::<_, DispatcherError<TestError>>(msg.ack())
            })),
            cfg.get(),
        ));

        // PUBLISH with RETAIN set while retain is unavailable.
        let err = disp
            .call(Decoded::Publish(
                codec::Publish {
                    retain: true,
                    qos: QoS::AtLeastOnce,
                    packet_id: NonZeroU16::new(1),
                    ..Default::default()
                },
                Bytes::new(),
                999,
            ))
            .await
            .err()
            .unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(
            err.inner,
            error::ViolationInner::Spec(error::SpecViolation::Connack_3_2_2_14)
        );

        // Empty topic with an alias that was never registered.
        let mut pkt = codec::Publish::default();
        pkt.properties.topic_alias = NonZeroU16::new(1);
        let err = disp.call(Decoded::Publish(pkt, Bytes::new(), 999)).await.err().unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(
            err.inner,
            error::ViolationInner::Common {
                reason: DisconnectReasonCode::TopicAliasInvalid,
                message: "Unknown topic alias"
            }
        );

        // Registering alias 1 (within the max of 1) succeeds.
        let mut pkt = codec::Publish {
            packet_id: NonZeroU16::new(1),
            topic: ByteString::from_static("test"),
            ..Default::default()
        };
        pkt.properties.topic_alias = NonZeroU16::new(1);
        let res = disp.call(Decoded::Publish(pkt, Bytes::new(), 999)).await;
        assert!(res.is_ok());

        // Registering alias 2 exceeds the topic alias maximum.
        let mut pkt = codec::Publish {
            packet_id: NonZeroU16::new(2),
            topic: ByteString::from_static("test2"),
            ..Default::default()
        };
        pkt.properties.topic_alias = NonZeroU16::new(2);
        let err = disp.call(Decoded::Publish(pkt, Bytes::new(), 999)).await.err().unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(
            err.inner,
            error::ViolationInner::Spec(error::SpecViolation::Connack_3_2_2_17)
        );

        // PUBREL for an unknown packet id is answered with PUBCOMP / PacketIdNotFound.
        let pkt = disp
            .call(Decoded::Packet(
                Packet::PublishRelease(codec::PublishAck2 {
                    packet_id: NonZeroU16::new(100).unwrap(),
                    reason_code: codec::PublishAck2Reason::Success,
                    properties: codec::UserProperties::default(),
                    reason_string: None,
                }),
                999,
            ))
            .await
            .ok()
            .unwrap()
            .unwrap();
        let Encoded::Packet(Packet::PublishComplete(pkt)) = pkt else { panic!() };
        assert_eq!(pkt.reason_code, codec::PublishAck2Reason::PacketIdNotFound);

        // SUBSCRIBE with an invalid (empty) topic filter.
        let err = disp
            .call(Decoded::Packet(
                Packet::Subscribe(codec::Subscribe {
                    packet_id: NonZeroU16::new(1).unwrap(),
                    id: None,
                    user_properties: codec::UserProperties::default(),
                    topic_filters: vec![(
                        ByteString::new(),
                        codec::SubscriptionOptions::default(),
                    )],
                }),
                999,
            ))
            .await
            .err()
            .unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(err.inner, error::ViolationInner::Spec(error::SpecViolation::Subs_4_7_1));

        // SUBSCRIBE carrying a subscription identifier while those are unavailable.
        let err = disp
            .call(Decoded::Packet(
                Packet::Subscribe(codec::Subscribe {
                    packet_id: NonZeroU16::new(1).unwrap(),
                    id: NonZeroU32::new(1),
                    user_properties: codec::UserProperties::default(),
                    topic_filters: vec![(
                        ByteString::from_static("test"),
                        codec::SubscriptionOptions::default(),
                    )],
                }),
                999,
            ))
            .await
            .err()
            .unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(
            err.inner,
            error::ViolationInner::Spec(error::SpecViolation::Connack_3_2_2_3_12)
        );

        // UNSUBSCRIBE with an invalid (empty) topic filter.
        let err = disp
            .call(Decoded::Packet(
                Packet::Unsubscribe(codec::Unsubscribe {
                    packet_id: NonZeroU16::new(1).unwrap(),
                    user_properties: codec::UserProperties::default(),
                    topic_filters: vec![ByteString::new()],
                }),
                999,
            ))
            .await
            .err()
            .unwrap();
        let DispatcherError::Protocol(ProtocolError::ProtocolViolation(err)) = err else {
            panic!()
        };
        assert_eq!(err.inner, error::ViolationInner::Spec(error::SpecViolation::Subs_4_7_1));
    }
}