use std::{fmt, io, marker::PhantomData, ptr};
use ntex_bytes::ByteString;
use super::codec::{self, DisconnectReasonCode, QoS, UserProperties};
use crate::error;
/// Control message handed to the control service of a connection.
///
/// Every variant can be answered with a [`ControlAck`] through [`Control::ack`].
#[derive(Debug)]
pub enum Control<E> {
/// Protocol packet: AUTH, PUBREL, SUBSCRIBE, UNSUBSCRIBE or DISCONNECT.
Protocol(CtlFrame),
/// Flow event: ping or a write-backpressure change.
Flow(CtlFlow),
/// Stop reason: service error, protocol error, or the peer went away.
Stop(CtlReason<E>),
/// Shutdown notification.
Shutdown(Shutdown),
}
/// Protocol packets wrapped by [`Control::Protocol`].
#[derive(Debug)]
pub enum CtlFrame {
/// AUTH packet.
Auth(Auth),
/// PUBREL packet; acknowledging it produces PUBCOMP.
PublishRelease(PublishRelease),
/// SUBSCRIBE packet; acknowledging it produces SUBACK.
Subscribe(Subscribe),
/// UNSUBSCRIBE packet; acknowledging it produces UNSUBACK.
Unsubscribe(Unsubscribe),
/// DISCONNECT packet sent by the remote side.
Disconnect(Disconnect),
}
/// Flow events wrapped by [`Control::Flow`].
#[derive(Debug)]
pub enum CtlFlow {
/// Ping; acknowledging it produces PINGRESP.
Ping(Ping),
/// Write backpressure was turned on or off.
WrBackpressure(WrBackpressure),
}
/// Stop reasons wrapped by [`Control::Stop`].
#[derive(Debug)]
pub enum CtlReason<E> {
/// Error produced by the service.
Error(Error<E>),
/// Protocol error, carrying a prepared DISCONNECT packet.
ProtocolError(ProtocolError),
/// The peer went away, optionally with an I/O error.
PeerGone(PeerGone),
}
/// Packet to emit in response to a control message.
#[derive(Debug)]
pub(crate) enum Pkt {
/// No response packet.
None,
/// Respond with a DISCONNECT packet.
Disconnect(codec::Disconnect),
/// Respond with any other packet (e.g. PINGRESP, SUBACK, PUBCOMP).
Packet(codec::Packet),
}
/// Acknowledgement of a [`Control`] message.
#[derive(Debug)]
pub struct ControlAck {
/// Response packet, if any.
pub(crate) packet: Pkt,
/// `true` when handling this ack should close the connection
/// (consumed by the dispatcher, which is not visible here).
pub(crate) disconnect: bool,
}
impl<E> Control<E> {
pub(crate) fn is_protocol(&self) -> bool {
matches!(self, Control::Protocol(_))
}
#[doc(hidden)]
pub fn auth(pkt: codec::Auth, size: u32) -> Self {
Control::Protocol(CtlFrame::Auth(Auth { pkt, size }))
}
pub(crate) fn pubrel(pkt: codec::PublishAck2, size: u32) -> Self {
Control::Protocol(CtlFrame::PublishRelease(PublishRelease::new(pkt, size)))
}
#[doc(hidden)]
pub fn subscribe(pkt: codec::Subscribe, size: u32) -> Self {
Control::Protocol(CtlFrame::Subscribe(Subscribe::new(pkt, size)))
}
#[doc(hidden)]
pub fn unsubscribe(pkt: codec::Unsubscribe, size: u32) -> Self {
Control::Protocol(CtlFrame::Unsubscribe(Unsubscribe::new(pkt, size)))
}
#[doc(hidden)]
pub fn ping() -> Self {
Control::Flow(CtlFlow::Ping(Ping))
}
#[doc(hidden)]
pub fn remote_disconnect(pkt: codec::Disconnect, size: u32) -> Self {
Control::Protocol(CtlFrame::Disconnect(Disconnect(pkt, size)))
}
pub(super) const fn wr_backpressure(enabled: bool) -> Self {
Control::Flow(CtlFlow::WrBackpressure(WrBackpressure(enabled)))
}
pub(super) fn error(err: E) -> Self {
Control::Stop(CtlReason::Error(Error::new(err)))
}
pub(super) fn peer_gone(err: Option<io::Error>) -> Self {
Control::Stop(CtlReason::PeerGone(PeerGone(err)))
}
pub(super) fn spec(err: error::SpecViolation) -> Self {
Control::Stop(CtlReason::ProtocolError(ProtocolError::new(error::ProtocolError::spec(
err,
))))
}
pub(super) fn proto_error(err: error::ProtocolError) -> Self {
Control::Stop(CtlReason::ProtocolError(ProtocolError::new(err)))
}
pub(super) const fn shutdown() -> Self {
Control::Shutdown(Shutdown)
}
pub fn disconnect(&self) -> ControlAck {
let pkt = codec::Disconnect {
reason_code: codec::DisconnectReasonCode::NormalDisconnection,
session_expiry_interval_secs: None,
server_reference: None,
reason_string: None,
user_properties: Vec::default(),
};
ControlAck { packet: Pkt::Disconnect(pkt), disconnect: true }
}
pub fn disconnect_with(&self, pkt: codec::Disconnect) -> ControlAck {
ControlAck { packet: Pkt::Disconnect(pkt), disconnect: true }
}
pub fn ack(self) -> ControlAck {
match self {
Control::Protocol(CtlFrame::Auth(_)) => super::disconnect(error::ERR_AUTH_NOT_SUP),
Control::Protocol(CtlFrame::PublishRelease(msg)) => msg.ack(),
Control::Protocol(CtlFrame::Subscribe(msg)) => msg.ack(),
Control::Protocol(CtlFrame::Unsubscribe(msg)) => msg.ack(),
Control::Protocol(CtlFrame::Disconnect(msg)) => msg.ack(),
Control::Flow(CtlFlow::Ping(msg)) => msg.ack(),
Control::Flow(CtlFlow::WrBackpressure(msg)) => msg.ack(),
Control::Stop(CtlReason::Error(_)) => super::disconnect(error::ERR_CTL_NOT_SUP),
Control::Stop(CtlReason::ProtocolError(msg)) => msg.ack(),
Control::Stop(CtlReason::PeerGone(msg)) => msg.ack(),
Control::Shutdown(msg) => msg.ack(),
}
}
}
impl CtlFlow {
pub fn ack(self) -> ControlAck {
match self {
CtlFlow::Ping(msg) => msg.ack(),
CtlFlow::WrBackpressure(msg) => msg.ack(),
}
}
}
/// AUTH packet together with its encoded size.
#[derive(Debug)]
pub struct Auth {
    pkt: codec::Auth,
    size: u32,
}

impl Auth {
    /// Borrows the decoded AUTH packet.
    pub fn packet(&self) -> &codec::Auth {
        let Self { pkt, .. } = self;
        pkt
    }

    /// Encoded packet size supplied when this message was built.
    pub fn packet_size(&self) -> u32 {
        let Self { size, .. } = self;
        *size
    }

    /// Answers with `response` as an AUTH packet, without requesting a disconnect.
    pub fn ack(self, response: codec::Auth) -> ControlAck {
        let packet = Pkt::Packet(codec::Packet::Auth(response));
        ControlAck { disconnect: false, packet }
    }
}
/// PUBREL packet together with the PUBCOMP that will answer it.
#[derive(Debug, Clone)]
pub struct PublishRelease {
    pkt: codec::PublishAck2,
    result: codec::PublishAck2,
    size: u32,
}

impl PublishRelease {
    /// Wraps `pkt` (encoded size `size`) and prepares a `Success` PUBCOMP that
    /// reuses its packet id.
    pub(crate) fn new(pkt: codec::PublishAck2, size: u32) -> Self {
        let result = codec::PublishAck2 {
            packet_id: pkt.packet_id,
            reason_code: codec::PublishAck2Reason::Success,
            properties: codec::UserProperties::default(),
            reason_string: None,
        };
        Self { pkt, result, size }
    }

    /// Borrows the received PUBREL packet.
    pub fn packet(&self) -> &codec::PublishAck2 {
        let Self { pkt, .. } = self;
        pkt
    }

    /// Encoded packet size supplied when this message was built.
    pub fn packet_size(&self) -> u32 {
        let Self { size, .. } = self;
        *size
    }

    /// Lets `f` edit the user properties of the outgoing PUBCOMP.
    #[inline]
    #[must_use]
    pub fn properties<F>(self, f: F) -> Self
    where
        F: FnOnce(&mut codec::UserProperties),
    {
        let mut this = self;
        f(&mut this.result.properties);
        this
    }

    /// Sets the reason string of the outgoing PUBCOMP.
    #[inline]
    #[must_use]
    pub fn reason(self, reason: ByteString) -> Self {
        let mut this = self;
        this.result.reason_string = Some(reason);
        this
    }

    /// Answers with the prepared PUBCOMP, without requesting a disconnect.
    pub fn ack(self) -> ControlAck {
        let Self { result, .. } = self;
        let packet = Pkt::Packet(codec::Packet::PublishComplete(result));
        ControlAck { disconnect: false, packet }
    }
}
/// Ping notification; acknowledged with a PINGRESP packet.
#[derive(Debug, Copy, Clone)]
pub struct Ping;

impl Ping {
    /// Answers with PINGRESP, without requesting a disconnect.
    pub fn ack(self) -> ControlAck {
        let packet = Pkt::Packet(codec::Packet::PingResponse);
        ControlAck { disconnect: false, packet }
    }
}
/// DISCONNECT packet sent by the remote side, with its encoded size.
#[derive(Debug, Clone)]
pub struct Disconnect(pub(crate) codec::Disconnect, pub(crate) u32);

impl Disconnect {
    /// Borrows the decoded DISCONNECT packet.
    pub fn packet(&self) -> &codec::Disconnect {
        let Self(pkt, _) = self;
        pkt
    }

    /// Encoded packet size supplied when this message was built.
    pub fn packet_size(&self) -> u32 {
        let Self(_, size) = self;
        *size
    }

    /// Acknowledges the disconnect: no response packet, and the connection is
    /// flagged for closing.
    pub fn ack(self) -> ControlAck {
        ControlAck { disconnect: true, packet: Pkt::None }
    }
}
/// SUBSCRIBE packet together with the SUBACK being built for it.
///
/// Every topic filter starts with `UnspecifiedError` status. Use
/// [`Subscribe::iter_mut`] to confirm or fail individual filters, then
/// [`Subscribe::ack`] to produce the SUBACK.
#[derive(Debug, Clone)]
pub struct Subscribe {
    packet: codec::Subscribe,
    result: codec::SubscribeAck,
    size: u32,
}

impl Subscribe {
    /// Wraps `packet` (encoded size `size`) and prepares a SUBACK holding one
    /// `UnspecifiedError` status per topic filter.
    pub fn new(packet: codec::Subscribe, size: u32) -> Self {
        let count = packet.topic_filters.len();
        let result = codec::SubscribeAck {
            status: vec![codec::SubscribeAckReason::UnspecifiedError; count],
            packet_id: packet.packet_id,
            properties: codec::UserProperties::default(),
            reason_string: None,
        };
        Self { packet, result, size }
    }

    /// Returns an iterator over the requested subscriptions.
    ///
    /// Each yielded [`Subscription`] can set the SUBACK status of its topic filter.
    /// All items may be held and mutated at the same time.
    #[inline]
    #[must_use]
    pub fn iter_mut(&mut self) -> SubscribeIter<'_> {
        // The iterator writes status entries through this pointer, so it must carry
        // the write permission of the unique borrow. A pointer obtained from a
        // shared reference (`ptr::from_ref(..).cast_mut()`) must never be written
        // through.
        SubscribeIter { subs: ptr::from_mut(self), entry: 0, lt: PhantomData }
    }

    /// Sets the reason string of the SUBACK.
    #[inline]
    #[must_use]
    pub fn ack_reason(mut self, reason: ByteString) -> Self {
        self.result.reason_string = Some(reason);
        self
    }

    /// Lets `f` edit the user properties of the SUBACK.
    #[inline]
    #[must_use]
    pub fn ack_properties<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&mut codec::UserProperties),
    {
        f(&mut self.result.properties);
        self
    }

    /// Produces the SUBACK with the statuses set so far, without requesting a disconnect.
    #[inline]
    pub fn ack(self) -> ControlAck {
        ControlAck {
            packet: Pkt::Packet(codec::Packet::SubscribeAck(self.result)),
            disconnect: false,
        }
    }

    /// Borrows the received SUBSCRIBE packet.
    #[inline]
    pub fn packet(&self) -> &codec::Subscribe {
        &self.packet
    }

    /// Encoded packet size supplied when this message was built.
    #[inline]
    pub fn packet_size(&self) -> u32 {
        self.size
    }
}

impl<'a> IntoIterator for &'a mut Subscribe {
    type Item = Subscription<'a>;
    type IntoIter = SubscribeIter<'a>;

    fn into_iter(self) -> SubscribeIter<'a> {
        self.iter_mut()
    }
}

/// Iterator returned by [`Subscribe::iter_mut`].
pub struct SubscribeIter<'a> {
    // Raw pointer, so that yielded items can simultaneously hold `&'a` borrows
    // into the topic filters and `&'a mut` borrows into distinct status slots.
    subs: *mut Subscribe,
    // Index of the next topic filter to yield.
    entry: usize,
    // Ties the iterator to the exclusive borrow taken by `iter_mut`.
    lt: PhantomData<&'a mut Subscribe>,
}

impl fmt::Debug for SubscribeIter<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SubscribeIter").finish()
    }
}

impl<'a> SubscribeIter<'a> {
    /// Yields the next topic filter with a writable handle to its SUBACK status.
    fn next_unsafe(&mut self) -> Option<Subscription<'a>> {
        let subs = self.subs;
        let idx = self.entry;

        // SAFETY: `subs` was derived via `ptr::from_mut` from the `&'a mut Subscribe`
        // passed to `iter_mut`, and `lt` keeps that exclusive borrow alive for `'a`.
        // The pointee is therefore valid and only reachable through this iterator.
        // Topic filters are only ever borrowed shared, so any number of `&'a`
        // borrows into them may coexist.
        let (topic, options) = unsafe { (*subs).packet.topic_filters.get(idx)? };

        // SAFETY: same provenance argument as above. Only the single status slot is
        // borrowed mutably. `Vec::as_mut_ptr` does not materialise a `&mut` to the
        // whole slice (as `IndexMut` would), so the `&'a mut` slots already handed
        // out for earlier entries stay valid. `idx` strictly increases, so each
        // slot is handed out at most once and the yielded borrows never alias.
        // The assertion keeps the offset in bounds.
        let status = unsafe {
            let statuses = &mut (*subs).result.status;
            assert!(idx < statuses.len(), "SUBACK status list shorter than topic filters");
            &mut *statuses.as_mut_ptr().add(idx)
        };

        self.entry += 1;
        Some(Subscription { topic, options, status })
    }
}

impl<'a> Iterator for SubscribeIter<'a> {
    type Item = Subscription<'a>;

    #[inline]
    fn next(&mut self) -> Option<Subscription<'a>> {
        self.next_unsafe()
    }
}
/// One requested subscription yielded by [`SubscribeIter`].
///
/// Setting its status decides the reason code reported for this topic filter in
/// the SUBACK.
#[derive(Debug)]
pub struct Subscription<'a> {
    topic: &'a ByteString,
    options: &'a codec::SubscriptionOptions,
    status: &'a mut codec::SubscribeAckReason,
}

impl<'a> Subscription<'a> {
    /// Requested topic filter.
    #[inline]
    pub fn topic(&self) -> &'a ByteString {
        self.topic
    }

    /// Requested subscription options.
    ///
    /// Like [`Subscription::topic`], the reference lives for `'a`, so it can be
    /// kept while this item's status is being updated.
    #[inline]
    pub fn options(&self) -> &'a codec::SubscriptionOptions {
        self.options
    }

    /// Sets the status of this topic filter to the given (failure) reason code.
    #[inline]
    pub fn fail(&mut self, status: codec::SubscribeAckReason) {
        *self.status = status;
    }

    /// Grants the subscription with the given QoS.
    #[inline]
    pub fn confirm(&mut self, qos: QoS) {
        *self.status = match qos {
            QoS::AtMostOnce => codec::SubscribeAckReason::GrantedQos0,
            QoS::AtLeastOnce => codec::SubscribeAckReason::GrantedQos1,
            QoS::ExactlyOnce => codec::SubscribeAckReason::GrantedQos2,
        };
    }

    /// Hidden alias of [`Subscription::confirm`].
    #[inline]
    #[doc(hidden)]
    pub fn subscribe(&mut self, qos: QoS) {
        self.confirm(qos);
    }
}
/// UNSUBSCRIBE packet together with the UNSUBACK being built for it.
///
/// Every topic filter starts with `Success` status. Use [`Unsubscribe::iter_mut`]
/// to fail individual filters, then [`Unsubscribe::ack`] to produce the UNSUBACK.
#[derive(Debug, Clone)]
pub struct Unsubscribe {
    packet: codec::Unsubscribe,
    result: codec::UnsubscribeAck,
    size: u32,
}

impl Unsubscribe {
    /// Wraps `packet` (encoded size `size`) and prepares an UNSUBACK holding one
    /// `Success` status per topic filter.
    pub fn new(packet: codec::Unsubscribe, size: u32) -> Self {
        let count = packet.topic_filters.len();
        let result = codec::UnsubscribeAck {
            status: vec![codec::UnsubscribeAckReason::Success; count],
            packet_id: packet.packet_id,
            properties: codec::UserProperties::default(),
            reason_string: None,
        };
        Self { packet, result, size }
    }

    /// User properties of the received UNSUBSCRIBE packet.
    #[inline]
    pub fn properties(&self) -> &codec::UserProperties {
        &self.packet.user_properties
    }

    /// Iterates over the topic filters of the UNSUBSCRIBE packet.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &ByteString> {
        self.packet.topic_filters.iter()
    }

    /// Returns an iterator over the topic filters.
    ///
    /// Each yielded [`UnsubscribeItem`] can set the UNSUBACK status of its filter.
    /// All items may be held and mutated at the same time.
    #[inline]
    pub fn iter_mut(&mut self) -> UnsubscribeIter<'_> {
        // The iterator writes status entries through this pointer, so it must carry
        // the write permission of the unique borrow. A pointer obtained from a
        // shared reference (`ptr::from_ref(..).cast_mut()`) must never be written
        // through.
        UnsubscribeIter {
            subs: ptr::from_mut::<Unsubscribe>(self),
            entry: 0,
            lt: PhantomData,
        }
    }

    /// Sets the reason string of the UNSUBACK.
    #[inline]
    #[must_use]
    pub fn ack_reason(mut self, reason: ByteString) -> Self {
        self.result.reason_string = Some(reason);
        self
    }

    /// Lets `f` edit the user properties of the UNSUBACK.
    #[inline]
    #[must_use]
    pub fn ack_properties<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&mut codec::UserProperties),
    {
        f(&mut self.result.properties);
        self
    }

    /// Produces the UNSUBACK with the statuses set so far, without requesting a
    /// disconnect.
    #[inline]
    pub fn ack(self) -> ControlAck {
        ControlAck {
            packet: Pkt::Packet(codec::Packet::UnsubscribeAck(self.result)),
            disconnect: false,
        }
    }

    /// Borrows the received UNSUBSCRIBE packet.
    #[inline]
    pub fn packet(&self) -> &codec::Unsubscribe {
        &self.packet
    }

    /// Encoded packet size supplied when this message was built.
    #[inline]
    pub fn packet_size(&self) -> u32 {
        self.size
    }
}

impl<'a> IntoIterator for &'a mut Unsubscribe {
    type Item = UnsubscribeItem<'a>;
    type IntoIter = UnsubscribeIter<'a>;

    fn into_iter(self) -> UnsubscribeIter<'a> {
        self.iter_mut()
    }
}

/// Iterator returned by [`Unsubscribe::iter_mut`].
pub struct UnsubscribeIter<'a> {
    // Raw pointer, so that yielded items can simultaneously hold `&'a` borrows
    // into the topic filters and `&'a mut` borrows into distinct status slots.
    subs: *mut Unsubscribe,
    // Index of the next topic filter to yield.
    entry: usize,
    // Ties the iterator to the exclusive borrow taken by `iter_mut`.
    lt: PhantomData<&'a mut Unsubscribe>,
}

impl fmt::Debug for UnsubscribeIter<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UnsubscribeIter").finish()
    }
}

impl<'a> UnsubscribeIter<'a> {
    /// Yields the next topic filter with a writable handle to its UNSUBACK status.
    fn next_unsafe(&mut self) -> Option<UnsubscribeItem<'a>> {
        let subs = self.subs;
        let idx = self.entry;

        // SAFETY: `subs` was derived via `ptr::from_mut` from the
        // `&'a mut Unsubscribe` passed to `iter_mut`, and `lt` keeps that exclusive
        // borrow alive for `'a`. The pointee is therefore valid and only reachable
        // through this iterator. Topic filters are only ever borrowed shared, so
        // any number of `&'a` borrows into them may coexist.
        let topic = unsafe { (*subs).packet.topic_filters.get(idx)? };

        // SAFETY: same provenance argument as above. Only the single status slot is
        // borrowed mutably. `Vec::as_mut_ptr` does not materialise a `&mut` to the
        // whole slice (as `IndexMut` would), so the `&'a mut` slots already handed
        // out for earlier entries stay valid. `idx` strictly increases, so each
        // slot is handed out at most once and the yielded borrows never alias.
        // The assertion keeps the offset in bounds.
        let status = unsafe {
            let statuses = &mut (*subs).result.status;
            assert!(idx < statuses.len(), "UNSUBACK status list shorter than topic filters");
            &mut *statuses.as_mut_ptr().add(idx)
        };

        self.entry += 1;
        Some(UnsubscribeItem { topic, status })
    }
}

impl<'a> Iterator for UnsubscribeIter<'a> {
    type Item = UnsubscribeItem<'a>;

    #[inline]
    fn next(&mut self) -> Option<UnsubscribeItem<'a>> {
        self.next_unsafe()
    }
}
/// One topic filter of an UNSUBSCRIBE packet, with a writable UNSUBACK status.
#[derive(Debug)]
pub struct UnsubscribeItem<'a> {
    topic: &'a ByteString,
    status: &'a mut codec::UnsubscribeAckReason,
}

impl<'a> UnsubscribeItem<'a> {
    /// Topic filter named in the UNSUBSCRIBE packet.
    #[inline]
    pub fn topic(&self) -> &'a ByteString {
        let Self { topic, .. } = self;
        *topic
    }

    /// Sets the status of this topic filter to the given (failure) reason code.
    #[inline]
    pub fn fail(&mut self, status: codec::UnsubscribeAckReason) {
        let Self { status: slot, .. } = self;
        **slot = status;
    }

    /// Marks this topic filter as successfully unsubscribed.
    #[inline]
    pub fn success(&mut self) {
        let Self { status: slot, .. } = self;
        **slot = codec::UnsubscribeAckReason::Success;
    }
}
/// Write-backpressure notification; the flag tells whether backpressure is on.
#[derive(Debug, Copy, Clone)]
pub struct WrBackpressure(bool);

impl WrBackpressure {
    /// Returns `true` when write backpressure is enabled.
    #[inline]
    pub fn enabled(&self) -> bool {
        let Self(on) = *self;
        on
    }

    /// Acknowledges the notification: no response packet and no disconnect.
    #[inline]
    pub fn ack(self) -> ControlAck {
        ControlAck { disconnect: false, packet: Pkt::None }
    }
}
/// Shutdown notification.
#[derive(Debug, Copy, Clone)]
pub struct Shutdown;

impl Shutdown {
    /// Acknowledges the notification: no response packet and no disconnect flag.
    #[inline]
    pub fn ack(self) -> ControlAck {
        ControlAck { disconnect: false, packet: Pkt::None }
    }
}
/// Service error together with the DISCONNECT packet that will report it.
#[derive(Debug, Clone)]
pub struct Error<E> {
    err: E,
    pkt: codec::Disconnect,
}

impl<E> Error<E> {
    /// Wraps `err`.
    ///
    /// The prepared DISCONNECT uses `ImplementationSpecificError` and carries no
    /// other data.
    pub fn new(err: E) -> Self {
        let pkt = codec::Disconnect {
            reason_code: DisconnectReasonCode::ImplementationSpecificError,
            session_expiry_interval_secs: None,
            server_reference: None,
            reason_string: None,
            user_properties: UserProperties::default(),
        };
        Self { err, pkt }
    }

    /// Borrows the wrapped service error.
    #[inline]
    pub fn get_ref(&self) -> &E {
        let Self { err, .. } = self;
        err
    }

    /// Sets the reason string of the DISCONNECT packet.
    #[inline]
    #[must_use]
    pub fn reason_string(self, reason: ByteString) -> Self {
        let Self { err, mut pkt } = self;
        pkt.reason_string = Some(reason);
        Self { err, pkt }
    }

    /// Sets the server reference of the DISCONNECT packet.
    #[inline]
    #[must_use]
    pub fn server_reference(self, reference: ByteString) -> Self {
        let Self { err, mut pkt } = self;
        pkt.server_reference = Some(reference);
        Self { err, pkt }
    }

    /// Lets `f` edit the user properties of the DISCONNECT packet.
    #[inline]
    #[must_use]
    pub fn properties<F>(self, f: F) -> Self
    where
        F: FnOnce(&mut codec::UserProperties),
    {
        let Self { err, mut pkt } = self;
        f(&mut pkt.user_properties);
        Self { err, pkt }
    }

    /// Acks with the DISCONNECT packet using `reason` as its reason code, and
    /// requests that the connection be closed.
    #[inline]
    pub fn ack(self, reason: DisconnectReasonCode) -> ControlAck {
        let Self { mut pkt, .. } = self;
        pkt.reason_code = reason;
        ControlAck { disconnect: true, packet: Pkt::Disconnect(pkt) }
    }

    /// Acks with the DISCONNECT packet returned by `f`, which receives the error
    /// and the prepared packet. Requests that the connection be closed.
    #[inline]
    pub fn ack_with<F>(self, f: F) -> ControlAck
    where
        F: FnOnce(E, codec::Disconnect) -> codec::Disconnect,
    {
        let Self { err, pkt } = self;
        ControlAck { disconnect: true, packet: Pkt::Disconnect(f(err, pkt)) }
    }
}
#[derive(Debug, Clone)]
pub struct ProtocolError {
err: error::ProtocolError,
pkt: codec::Disconnect,
}
impl ProtocolError {
pub fn new(err: error::ProtocolError) -> Self {
Self {
pkt: codec::Disconnect {
session_expiry_interval_secs: None,
server_reference: None,
reason_string: None,
user_properties: UserProperties::default(),
reason_code: match err {
error::ProtocolError::Decode(error::DecodeError::InvalidLength) => {
DisconnectReasonCode::MalformedPacket
}
error::ProtocolError::Decode(error::DecodeError::MaxSizeExceeded) => {
DisconnectReasonCode::PacketTooLarge
}
error::ProtocolError::KeepAliveTimeout => {
DisconnectReasonCode::KeepAliveTimeout
}
error::ProtocolError::ProtocolViolation(ref e) => e.reason(),
error::ProtocolError::Encode(_) => {
DisconnectReasonCode::ImplementationSpecificError
}
_ => DisconnectReasonCode::ImplementationSpecificError,
},
},
err,
}
}
#[inline]
pub fn get_ref(&self) -> &error::ProtocolError {
&self.err
}
#[inline]
#[must_use]
pub fn reason_code(mut self, reason: DisconnectReasonCode) -> Self {
self.pkt.reason_code = reason;
self
}
#[inline]
#[must_use]
pub fn reason_string(mut self, reason: ByteString) -> Self {
self.pkt.reason_string = Some(reason);
self
}
#[inline]
#[must_use]
pub fn server_reference(mut self, reference: ByteString) -> Self {
self.pkt.server_reference = Some(reference);
self
}
#[inline]
#[must_use]
pub fn properties<F>(mut self, f: F) -> Self
where
F: FnOnce(&mut codec::UserProperties),
{
f(&mut self.pkt.user_properties);
self
}
#[inline]
pub fn ack(self) -> ControlAck {
ControlAck { packet: Pkt::Disconnect(self.pkt), disconnect: true }
}
#[inline]
pub fn ack_and_error(self) -> (ControlAck, error::ProtocolError) {
(ControlAck { packet: Pkt::Disconnect(self.pkt), disconnect: true }, self.err)
}
}
/// The peer went away, optionally with the I/O error that was observed.
#[derive(Debug)]
pub struct PeerGone(pub(crate) Option<io::Error>);

impl PeerGone {
    /// Borrows the I/O error, if one was recorded.
    #[inline]
    pub fn err(&self) -> Option<&io::Error> {
        let Self(err) = self;
        err.as_ref()
    }

    /// Moves the I/O error out, leaving `None` behind.
    #[inline]
    pub fn take(&mut self) -> Option<io::Error> {
        std::mem::take(&mut self.0)
    }

    /// Acknowledges the event: no response packet, and the connection is flagged
    /// for closing.
    #[inline]
    pub fn ack(self) -> ControlAck {
        ControlAck { disconnect: true, packet: Pkt::None }
    }
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroU16;

    use ntex_bytes::ByteString;

    use super::*;
    use crate::types::QoS;

    /// Subscription options requesting `qos` with all flags off.
    fn opts(qos: QoS) -> codec::SubscriptionOptions {
        codec::SubscriptionOptions {
            qos,
            no_local: false,
            retain_as_published: false,
            retain_handling: codec::RetainHandling::AtSubscribe,
        }
    }

    #[test]
    fn test_debug() {
        let mut sub = Subscribe::new(
            codec::Subscribe {
                packet_id: NonZeroU16::new(1).unwrap(),
                id: None,
                user_properties: Vec::new(),
                topic_filters: vec![(ByteString::from_static("a/b"), opts(QoS::AtLeastOnce))],
            },
            0,
        );
        let iter = sub.iter_mut();
        assert!(format!("{iter:?}").contains("SubscribeIter"));

        let mut unsub = Unsubscribe::new(
            codec::Unsubscribe {
                packet_id: NonZeroU16::new(2).unwrap(),
                user_properties: Vec::new(),
                topic_filters: vec![ByteString::from_static("a/b")],
            },
            0,
        );
        let uiter = unsub.iter_mut();
        assert!(format!("{uiter:?}").contains("UnsubscribeIter"));

        assert!(format!("{Ping:?}").contains("Ping"));
        assert!(format!("{:?}", WrBackpressure(false)).contains("WrBackpressure"));
        assert!(format!("{Shutdown:?}").contains("Shutdown"));
        assert!(format!("{:?}", PeerGone(None)).contains("PeerGone"));
    }

    #[test]
    fn test_subscribe_iter_mut_updates_status() {
        let mut sub = Subscribe::new(
            codec::Subscribe {
                packet_id: NonZeroU16::new(3).unwrap(),
                id: None,
                user_properties: Vec::new(),
                topic_filters: vec![
                    (ByteString::from_static("a/b"), opts(QoS::AtMostOnce)),
                    (ByteString::from_static("c/d"), opts(QoS::AtLeastOnce)),
                    (ByteString::from_static("e/f"), opts(QoS::ExactlyOnce)),
                ],
            },
            0,
        );

        // Hold every item at once before writing: the `&'a mut` items promise this
        // is allowed (run under Miri to check the aliasing model).
        let mut items: Vec<_> = sub.iter_mut().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[1].topic(), &ByteString::from_static("c/d"));
        items[0].confirm(QoS::AtMostOnce);
        items[2].confirm(QoS::ExactlyOnce);
        // items[1] is left untouched and keeps its initial status.
        drop(items);

        match sub.ack().packet {
            Pkt::Packet(codec::Packet::SubscribeAck(ack)) => {
                assert_eq!(ack.packet_id, NonZeroU16::new(3).unwrap());
                assert_eq!(ack.status.len(), 3);
                assert!(matches!(ack.status[0], codec::SubscribeAckReason::GrantedQos0));
                assert!(matches!(ack.status[1], codec::SubscribeAckReason::UnspecifiedError));
                assert!(matches!(ack.status[2], codec::SubscribeAckReason::GrantedQos2));
            }
            other => panic!("unexpected ack packet: {other:?}"),
        }
    }

    #[test]
    fn test_unsubscribe_iter_mut_updates_status() {
        let mut unsub = Unsubscribe::new(
            codec::Unsubscribe {
                packet_id: NonZeroU16::new(4).unwrap(),
                user_properties: Vec::new(),
                topic_filters: vec![
                    ByteString::from_static("a/b"),
                    ByteString::from_static("c/d"),
                ],
            },
            0,
        );

        let mut items: Vec<_> = unsub.iter_mut().collect();
        assert_eq!(items.len(), 2);
        assert_eq!(items[1].topic(), &ByteString::from_static("c/d"));
        items[0].fail(codec::UnsubscribeAckReason::UnspecifiedError);
        items[1].success();
        drop(items);

        match unsub.ack().packet {
            Pkt::Packet(codec::Packet::UnsubscribeAck(ack)) => {
                assert_eq!(ack.status.len(), 2);
                assert!(matches!(ack.status[0], codec::UnsubscribeAckReason::UnspecifiedError));
                assert!(matches!(ack.status[1], codec::UnsubscribeAckReason::Success));
            }
            other => panic!("unexpected ack packet: {other:?}"),
        }
    }
}