use std::{cell::Cell, cell::RefCell, collections::VecDeque, num::NonZeroU16, rc::Rc};
use ntex::codec::{Decoder, Encoder};
use ntex::util::{BytesMut, HashSet, PoolId, PoolRef};
use ntex::{channel::pool, io::IoRef};
use crate::{error, error::SendPacketError, types::packet_type, v5::codec, QoS};
bitflags::bitflags! {
    /// Per-connection state bits stored in the `flags` cell of `MqttShared`.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Flags: u8 {
        /// Write back-pressure is active: `is_ready` reports `false` and
        /// `wait_readiness` queues the caller until it is disabled again.
        const WRB_ENABLED = 1 << 6;
        /// A publish-ack callback has been installed through `set_publish_ack`.
        const ON_PUBLISH_ACK = 1 << 5;
    }
}
/// State shared by the handles of one MQTT v5 connection: the io handle, send
/// capacity, in-flight acknowledgement tracking and readiness waiters.
pub struct MqttShared {
    /// Io handle used to encode outgoing packets and to close the connection.
    io: IoRef,
    /// Maximum number of in-flight packets; `credit()` is this value minus the
    /// current in-flight count (saturating at zero).
    cap: Cell<usize>,
    /// Value stored by `set_max_qos` (`QoS::AtLeastOnce` after construction).
    /// NOTE(review): presumably the maximum QoS accepted by the peer — confirm at the call site.
    max_qos: Cell<QoS>,
    /// Value stored by `set_receive_max` (`0` until set).
    receive_max: Cell<u16>,
    /// Value stored by `set_topic_alias_max` (`0` until set).
    topic_alias_max: Cell<u16>,
    /// Counter backing `next_id`: the last packet id handed out, reset to `0`
    /// after `u16::MAX` has been returned.
    inflight_idx: Cell<u16>,
    /// In-flight acknowledgement tracking and readiness waiters.
    queues: RefCell<MqttSharedQueues>,
    /// `WRB_ENABLED` / `ON_PUBLISH_ACK` state bits.
    flags: Cell<Flags>,
    /// Channel pools used to allocate ack and readiness channels.
    pool: Rc<MqttSinkPool>,
    /// Callback receiving acks for packets tracked without a receiver channel.
    /// The `bool` is `false` for a received ack and `true` when the entry is
    /// discarded while the queues are cleared.
    on_publish_ack: Cell<Option<Box<dyn Fn(codec::PublishAck, bool)>>>,
    /// MQTT v5 codec used for encoding and decoding packets.
    pub(super) codec: codec::Codec,
}
/// Bookkeeping for packets awaiting acknowledgement and for senders waiting on capacity.
pub(super) struct MqttSharedQueues {
    /// Unacknowledged packets in send order: `(packet id, ack channel — or `None`
    /// when the publish-ack callback is used instead, expected ack type)`.
    /// Acks are required to arrive in this same order.
    inflight: VecDeque<(NonZeroU16, Option<pool::Sender<Ack>>, AckType)>,
    /// Packet ids currently tracked in `inflight`; used to reject duplicate ids.
    inflight_ids: HashSet<NonZeroU16>,
    /// Readiness waiters, woken when in-flight capacity becomes available.
    waiters: VecDeque<pool::Sender<()>>,
}
/// Pools used to allocate the channels handed out by `MqttShared`.
pub(super) struct MqttSinkPool {
    /// Pool for ack channels (`wait_response` / `wait_packet_response`).
    queue: pool::Pool<Ack>,
    /// Pool for readiness channels (`wait_readiness`).
    waiters: pool::Pool<()>,
    /// Memory pool reference; `PoolId::P5` by default.
    /// NOTE(review): not read in this file — presumably used for io buffers elsewhere; confirm.
    pub(super) pool: Cell<PoolRef>,
}
impl Default for MqttSinkPool {
fn default() -> Self {
Self {
queue: pool::new(),
waiters: pool::new(),
pool: Cell::new(PoolId::P5.pool_ref()),
}
}
}
impl MqttShared {
pub(super) fn new(io: IoRef, codec: codec::Codec, pool: Rc<MqttSinkPool>) -> Self {
Self {
io,
pool,
codec,
cap: Cell::new(0),
queues: RefCell::new(MqttSharedQueues {
inflight: VecDeque::with_capacity(8),
inflight_ids: HashSet::default(),
waiters: VecDeque::new(),
}),
receive_max: Cell::new(0),
topic_alias_max: Cell::new(0),
max_qos: Cell::new(QoS::AtLeastOnce),
inflight_idx: Cell::new(0),
flags: Cell::new(Flags::empty()),
on_publish_ack: Cell::new(None),
}
}
pub(super) fn receive_max(&self) -> u16 {
self.receive_max.get()
}
pub(super) fn topic_alias_max(&self) -> u16 {
self.topic_alias_max.get()
}
pub(super) fn max_qos(&self) -> QoS {
self.max_qos.get()
}
pub(super) fn set_receive_max(&self, val: u16) {
self.receive_max.set(val);
}
pub(super) fn set_topic_alias_max(&self, val: u16) {
self.topic_alias_max.set(val);
}
pub(super) fn set_max_qos(&self, val: QoS) {
self.max_qos.set(val);
}
pub(super) fn close(&self, pkt: codec::Disconnect) {
if !self.is_closed() {
let _ = self.io.encode(codec::Packet::Disconnect(pkt), &self.codec);
self.io.close();
}
self.clear_queues();
}
pub(super) fn force_close(&self) {
self.io.force_close();
self.clear_queues();
}
pub(super) fn is_closed(&self) -> bool {
self.io.is_closed()
}
pub(super) fn credit(&self) -> usize {
self.cap.get().saturating_sub(self.queues.borrow().inflight.len())
}
pub(super) fn is_ready(&self) -> bool {
self.credit() > 0 && !self.flags.get().contains(Flags::WRB_ENABLED)
}
pub(super) fn next_id(&self) -> NonZeroU16 {
let idx = self.inflight_idx.get() + 1;
self.inflight_idx.set(idx);
let idx = if idx == u16::max_value() {
self.inflight_idx.set(0);
u16::max_value()
} else {
self.inflight_idx.set(idx);
idx
};
NonZeroU16::new(idx).unwrap()
}
pub(super) fn set_cap(&self, cap: usize) {
let mut queues = self.queues.borrow_mut();
'outer: for _ in 0..cap {
while let Some(tx) = queues.waiters.pop_front() {
if tx.send(()).is_ok() {
continue 'outer;
}
}
break;
}
self.cap.set(cap);
}
pub(super) fn set_publish_ack(&self, f: Box<dyn Fn(codec::PublishAck, bool)>) {
let mut flags = self.flags.get();
flags.insert(Flags::ON_PUBLISH_ACK);
self.flags.set(flags);
self.on_publish_ack.set(Some(f));
}
pub(super) fn encode_packet(&self, pkt: codec::Packet) -> Result<(), error::EncodeError> {
self.io.encode(pkt, &self.codec)
}
pub(super) fn drop_sink(&self) {
self.clear_queues();
self.io.close();
}
fn clear_queues(&self) {
let mut queues = self.queues.borrow_mut();
queues.waiters.clear();
if let Some(cb) = self.on_publish_ack.take() {
for (idx, tx, _) in queues.inflight.drain(..) {
if tx.is_none() {
(*cb)(codec::PublishAck { packet_id: idx, ..Default::default() }, true);
}
}
} else {
queues.inflight.clear()
}
}
pub(super) fn enable_wr_backpressure(&self) {
let mut flags = self.flags.get();
flags.insert(Flags::WRB_ENABLED);
self.flags.set(flags);
}
pub(super) fn disable_wr_backpressure(&self) {
let mut flags = self.flags.get();
flags.remove(Flags::WRB_ENABLED);
self.flags.set(flags);
let mut queues = self.queues.borrow_mut();
if queues.inflight.len() < self.cap.get() {
let mut num = self.cap.get() - queues.inflight.len();
while num > 0 {
if let Some(tx) = queues.waiters.pop_front() {
if tx.send(()).is_ok() {
num -= 1;
}
} else {
break;
}
}
}
}
pub(super) fn pkt_ack(&self, ack: Ack) -> Result<(), error::ProtocolError> {
self.pkt_ack_inner(ack).map_err(|e| {
self.close(codec::Disconnect {
reason_code: codec::DisconnectReasonCode::ImplementationSpecificError,
..Default::default()
});
e
})
}
fn pkt_ack_inner(&self, pkt: Ack) -> Result<(), error::ProtocolError> {
let mut queues = self.queues.borrow_mut();
if let Some((idx, tx, tp)) = queues.inflight.pop_front() {
if idx != pkt.packet_id() {
log::trace!(
"MQTT protocol error, packet_id order does not match, expected {}, got: {}",
idx,
pkt.packet_id()
);
Err(error::ProtocolError::packet_id_mismatch())
} else {
log::trace!("Ack packet with id: {}", pkt.packet_id());
queues.inflight_ids.remove(&pkt.packet_id());
if pkt.is_match(tp) {
if let Some(tx) = tx {
let _ = tx.send(pkt);
} else {
let cb = self.on_publish_ack.take().unwrap();
(*cb)(pkt.publish(), false);
self.on_publish_ack.set(Some(cb));
}
while let Some(tx) = queues.waiters.pop_front() {
if tx.send(()).is_ok() {
break;
}
}
Ok(())
} else {
log::trace!("MQTT protocol error, unexpeted packet");
Err(error::ProtocolError::unexpected_packet(
pkt.packet_type(),
tp.expected_str(),
))
}
}
} else {
log::trace!("Unexpected PublishAck packet");
Err(error::ProtocolError::generic_violation(
"Received PUBACK packet while there are no unacknowledged PUBLISH packets",
))
}
}
pub(super) fn wait_response(
&self,
id: NonZeroU16,
ack: AckType,
) -> Result<pool::Receiver<Ack>, SendPacketError> {
let mut queues = self.queues.borrow_mut();
if queues.inflight_ids.contains(&id) {
Err(SendPacketError::PacketIdInUse(id))
} else {
let (tx, rx) = self.pool.queue.channel();
queues.inflight.push_back((id, Some(tx), ack));
queues.inflight_ids.insert(id);
Ok(rx)
}
}
pub(super) fn wait_packet_response(
&self,
id: NonZeroU16,
ack: AckType,
pkt: codec::Packet,
) -> Result<pool::Receiver<Ack>, SendPacketError> {
let mut queues = self.queues.borrow_mut();
if queues.inflight_ids.contains(&id) {
Err(SendPacketError::PacketIdInUse(id))
} else {
match self.io.encode(pkt, &self.codec) {
Ok(_) => {
let (tx, rx) = self.pool.queue.channel();
queues.inflight.push_back((id, Some(tx), ack));
queues.inflight_ids.insert(id);
Ok(rx)
}
Err(e) => Err(SendPacketError::Encode(e)),
}
}
}
pub(super) fn wait_packet_response_no_block(
&self,
id: NonZeroU16,
ack: AckType,
pkt: codec::Packet,
) -> Result<(), SendPacketError> {
let mut queues = self.queues.borrow_mut();
if queues.inflight_ids.contains(&id) {
Err(SendPacketError::PacketIdInUse(id))
} else {
match self.io.encode(pkt, &self.codec) {
Ok(_) => {
queues.inflight.push_back((id, None, ack));
queues.inflight_ids.insert(id);
Ok(())
}
Err(e) => Err(SendPacketError::Encode(e)),
}
}
}
pub(super) fn wait_readiness(&self) -> Option<pool::Receiver<()>> {
let mut queues = self.queues.borrow_mut();
if queues.inflight.len() >= self.cap.get()
|| self.flags.get().contains(Flags::WRB_ENABLED)
{
let (tx, rx) = self.pool.waiters.channel();
queues.waiters.push_back(tx);
Some(rx)
} else {
None
}
}
}
impl Encoder for MqttShared {
    type Item = codec::Packet;
    type Error = error::EncodeError;

    /// Serializes `item` into `dst` by delegating to the wrapped MQTT v5 codec.
    #[inline]
    fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let inner = &self.codec;
        inner.encode(item, dst)
    }
}
impl Decoder for MqttShared {
    type Item = (codec::Packet, u32);
    type Error = error::DecodeError;

    /// Parses the next packet from `src` by delegating to the wrapped MQTT v5
    /// codec. Returns `Ok(None)` whenever the codec does.
    #[inline]
    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let inner = &self.codec;
        inner.decode(src)
    }
}
/// Kind of acknowledgement expected for an in-flight packet.
#[derive(Copy, Clone)]
pub(super) enum AckType {
    /// A PUBACK is expected.
    Publish,
    /// A SUBACK is expected.
    Subscribe,
    /// An UNSUBACK is expected.
    Unsubscribe,
}
/// Acknowledgement packet matched against an in-flight request.
pub(super) enum Ack {
    /// PUBACK packet.
    Publish(codec::PublishAck),
    /// SUBACK packet.
    Subscribe(codec::SubscribeAck),
    /// UNSUBACK packet.
    Unsubscribe(codec::UnsubscribeAck),
}
impl Ack {
    /// Returns the MQTT control packet type constant for this ack.
    pub(super) fn packet_type(&self) -> u8 {
        match self {
            Ack::Publish(_) => packet_type::PUBACK,
            Ack::Subscribe(_) => packet_type::SUBACK,
            Ack::Unsubscribe(_) => packet_type::UNSUBACK,
        }
    }

    /// Returns the packet identifier carried by this ack.
    pub(super) fn packet_id(&self) -> NonZeroU16 {
        match self {
            Ack::Publish(pkt) => pkt.packet_id,
            Ack::Subscribe(pkt) => pkt.packet_id,
            Ack::Unsubscribe(pkt) => pkt.packet_id,
        }
    }

    /// Unwraps a PUBACK.
    ///
    /// # Panics
    /// Panics if `self` is not `Ack::Publish`.
    pub(super) fn publish(self) -> codec::PublishAck {
        match self {
            Ack::Publish(pkt) => pkt,
            other => panic!("expected PUBACK, got packet type {}", other.packet_type()),
        }
    }

    /// Unwraps a SUBACK.
    ///
    /// # Panics
    /// Panics if `self` is not `Ack::Subscribe`.
    pub(super) fn subscribe(self) -> codec::SubscribeAck {
        match self {
            Ack::Subscribe(pkt) => pkt,
            other => panic!("expected SUBACK, got packet type {}", other.packet_type()),
        }
    }

    /// Unwraps an UNSUBACK.
    ///
    /// # Panics
    /// Panics if `self` is not `Ack::Unsubscribe`.
    pub(super) fn unsubscribe(self) -> codec::UnsubscribeAck {
        match self {
            Ack::Unsubscribe(pkt) => pkt,
            other => panic!("expected UNSUBACK, got packet type {}", other.packet_type()),
        }
    }

    /// Returns `true` if this ack is of the kind described by `tp`.
    pub(super) fn is_match(&self, tp: AckType) -> bool {
        matches!(
            (self, tp),
            (Ack::Publish(_), AckType::Publish)
                | (Ack::Subscribe(_), AckType::Subscribe)
                | (Ack::Unsubscribe(_), AckType::Unsubscribe)
        )
    }
}
impl AckType {
    /// Text describing the ack packet this kind expects. It is used as the
    /// "expected" part of unexpected-packet protocol errors.
    pub(super) fn expected_str(&self) -> &'static str {
        match *self {
            Self::Publish => "Expected PUBACK packet",
            Self::Subscribe => "Expected SUBACK packet",
            Self::Unsubscribe => "Expected UNSUBACK packet",
        }
    }
}