use std::{
collections::VecDeque,
sync::{atomic::AtomicU16, Arc},
};
use arc_swap::{ArcSwap, ArcSwapOption};
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use mqtt_format::v3::{identifier::MPacketIdentifier, packet::MPublish, qos::MQualityOfService};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace};
use crate::error::MqttError;
use super::{message::MqttMessage, ClientConnection};
/// Bookkeeping state of an outgoing QoS 1 / QoS 2 message, stored in
/// `ClientState::sent_stored_messages` under its packet identifier until its
/// acknowledgement flow finishes.
#[derive(Debug)]
enum StoredMessage {
    /// Sent (or queued for sending) and not yet acknowledged. The message is
    /// kept so `set_new_connection` can re-send it on a new connection.
    Sent(MqttMessage),
    /// A PUBREC was received for this QoS 2 message. Only the identifier stays
    /// reserved; `receive_pubcomp` removes the entry.
    Acknowledged,
}
/// Client-side MQTT state: the active connection plus the packet-identifier
/// bookkeeping for in-flight QoS 1 and QoS 2 messages.
#[derive(Debug, Default)]
pub struct ClientState {
    /// Token for the current connection. `set_new_connection` cancels it and
    /// swaps in a fresh one, aborting writes still queued for the old connection.
    connection_token: ArcSwap<CancellationToken>,
    /// The currently active connection, if any. The `Arc` is cloned into
    /// spawned write tasks.
    conn: Arc<ArcSwapOption<ClientConnection>>,
    /// Counter from which outgoing packet identifiers are drawn (wraps around;
    /// 0 is skipped by `get_next_packet_entry`).
    current_id: AtomicU16,
    /// Outgoing QoS 1/2 messages awaiting acknowledgement, keyed by packet identifier.
    sent_stored_messages: DashMap<MPacketIdentifier, StoredMessage>,
    /// Identifiers of incoming QoS 2 messages recorded by
    /// `save_qos_exactly_once` and released by `receive_pubrel`.
    received_messages: DashSet<MPacketIdentifier>,
}
impl ClientState {
    /// Installs `conn` as the active connection and re-sends every stored
    /// message that has not yet been acknowledged.
    ///
    /// The previous connection token is cancelled, which aborts writes that
    /// were queued for the old connection. The old connection's writer is shut
    /// down in a background task.
    ///
    /// Pending messages are re-sent from a single background task with the
    /// DUP flag set. They are ordered by packet identifier, starting at the
    /// current identifier counter. Presumably identifiers at or above the
    /// counter were handed out before it last wrapped and are therefore older.
    pub(crate) async fn set_new_connection(&self, conn: Arc<ClientConnection>) {
        let old_token = self
            .connection_token
            .swap(Arc::new(CancellationToken::new()));
        old_token.cancel();

        if let Some(old_conn) = self.conn.swap(Some(conn)) {
            tokio::spawn(async move {
                let mut writer = old_conn.writer.lock().await;
                // Best effort: the old connection is being discarded anyway.
                let _ = writer.shutdown().await;
            });
        }

        let current_id = self.current_id.load(std::sync::atomic::Ordering::SeqCst);

        // Clone the messages out so no DashMap guard outlives this statement.
        let mut pending: VecDeque<(MPacketIdentifier, MqttMessage)> = self
            .sent_stored_messages
            .iter()
            .filter_map(|stored| match stored.value() {
                StoredMessage::Sent(msg) => Some((*stored.key(), msg.clone())),
                // NOTE(review): MQTT 3.1.1 §4.4 says unacknowledged PUBREL packets
                // should also be re-sent on reconnect; these entries are skipped
                // here. Confirm whether a PUBREL should be written for them.
                StoredMessage::Acknowledged => None,
            })
            .collect();

        if pending.is_empty() {
            return;
        }

        // DashMap iteration order is arbitrary: sort by identifier, then rotate
        // so identifiers >= `current_id` come before the smaller, newer ones.
        pending
            .make_contiguous()
            .sort_unstable_by_key(|(id, _)| id.0);
        let first_oldest = pending
            .iter()
            .position(|(id, _)| id.0 >= current_id)
            .unwrap_or(0);
        pending.rotate_left(first_oldest);

        trace!(count = pending.len(), "Re-sending unacknowledged messages");

        let conn = Arc::clone(&self.conn);
        let token = self.connection_token.load_full();
        tokio::spawn(Self::with_connection(token, conn, move |conn| async move {
            // One task holding the writer for the whole batch keeps the order.
            let mut writer = conn.writer.lock().await;
            for (id, msg) in pending {
                let packet = MPublish {
                    // These are re-deliveries, so the DUP flag must be set.
                    dup: true,
                    qos: msg.qos(),
                    retain: msg.retain(),
                    topic_name: mqtt_format::v3::strings::MString { value: msg.topic() },
                    id: Some(id),
                    payload: msg.payload(),
                };
                crate::write_packet(&mut *writer, packet).await?;
            }
            Ok(())
        }));
    }

    /// Runs `fun` against the current connection unless `token` is cancelled first.
    ///
    /// `token` is the connection token captured when the work was queued.
    /// `set_new_connection` cancels it, so work aimed at a replaced connection
    /// is abandoned. If no connection is set, `fun` is not run. Errors returned
    /// by `fun` are logged at debug level.
    async fn with_connection<'a, F, D>(
        token: Arc<CancellationToken>,
        conn: Arc<ArcSwapOption<ClientConnection>>,
        fun: F,
    ) where
        F: FnOnce(Arc<ClientConnection>) -> D,
        F: 'a,
        D: std::future::Future<Output = Result<(), MqttError>> + Send + 'a,
    {
        let conn = match conn.load_full() {
            Some(conn) => conn,
            None => {
                trace!("No connection available, skipping queued write");
                return;
            }
        };

        tokio::select! {
            result = fun(conn) => {
                if let Err(error) = result {
                    debug!(?error, "Failed to write to the connection");
                }
            }
            _ = token.cancelled() => {
                trace!("Connection token was cancelled before the write finished");
            }
        }
    }

    /// Publishes `msg` on the current connection from a background task.
    ///
    /// QoS 0 messages are sent without a packet identifier and are not stored.
    /// QoS 1 and QoS 2 messages get a fresh packet identifier and are stored
    /// until acknowledged, so `set_new_connection` can re-send them.
    ///
    /// # Panics
    ///
    /// Panics if a QoS 1 or QoS 2 message is sent while every packet identifier
    /// is in use.
    pub async fn send_message(&self, msg: MqttMessage) {
        match msg.qos() {
            MQualityOfService::AtMostOnce => self.spawn_publish(msg, None),
            MQualityOfService::AtLeastOnce | MQualityOfService::ExactlyOnce => {
                let entry = self
                    .get_next_packet_entry()
                    .expect("Exhausted available identifier slots");
                let id = *entry.key();
                // Consuming the entry releases its map shard lock before spawning.
                entry.or_insert(StoredMessage::Sent(msg.clone()));
                self.spawn_publish(msg, Some(id));
            }
        }
    }

    /// Spawns a task that writes `msg` as a first-delivery (DUP = 0) PUBLISH
    /// packet with identifier `id` on the current connection.
    fn spawn_publish(&self, msg: MqttMessage, id: Option<MPacketIdentifier>) {
        let conn = Arc::clone(&self.conn);
        let token = self.connection_token.load_full();
        tokio::spawn(Self::with_connection(token, conn, move |conn| async move {
            let packet = MPublish {
                dup: false,
                qos: msg.qos(),
                retain: msg.retain(),
                topic_name: mqtt_format::v3::strings::MString { value: msg.topic() },
                id,
                payload: msg.payload(),
            };
            let mut writer = conn.writer.lock().await;
            crate::write_packet(&mut *writer, packet).await?;
            Ok(())
        }));
    }

    /// Handles a PUBACK for the outgoing QoS 1 message stored under `id`, and
    /// removes it on success.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if nothing is stored under `id`, if the stored message
    /// is not QoS 1, or if it is a QoS 2 message that already got a PUBREC.
    pub fn receive_puback(&self, id: MPacketIdentifier) -> Result<(), ()> {
        let stored = match self.sent_stored_messages.get(&id) {
            Some(stored) => stored,
            None => {
                debug!(?id, "Received a PUBACK for a nonexistent message");
                return Err(());
            }
        };

        match stored.value() {
            StoredMessage::Sent(msg) if msg.qos() == MQualityOfService::AtLeastOnce => {}
            StoredMessage::Sent(msg) => {
                debug!(
                    ?id,
                    "Received a PUBACK for a non-QoS 1 message (QoS was {:?})",
                    msg.qos()
                );
                return Err(());
            }
            StoredMessage::Acknowledged => {
                debug!(
                    ?id,
                    "Received a PUBACK for an already acknowledged QoS 2 message"
                );
                return Err(());
            }
        }

        // The read guard must be released before `remove` write-locks the same shard.
        drop(stored);
        self.sent_stored_messages.remove(&id);
        trace!(?id, "Removed QoS 1 message after acknowledging it");
        Ok(())
    }

    /// Handles a PUBREC for the outgoing QoS 2 message stored under `id`.
    ///
    /// On success the stored message is replaced with `Acknowledged`. The
    /// identifier stays reserved until the matching PUBCOMP.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if nothing is stored under `id`, if the stored message
    /// is not QoS 2, or if a PUBREC was already received for it.
    pub fn receive_pubrec(&self, id: MPacketIdentifier) -> Result<(), ()> {
        let mut stored = match self.sent_stored_messages.get_mut(&id) {
            Some(stored) => stored,
            None => {
                debug!(?id, "Received a PUBREC for a nonexistent message");
                return Err(());
            }
        };

        match stored.value() {
            StoredMessage::Sent(msg) if msg.qos() == MQualityOfService::ExactlyOnce => {}
            StoredMessage::Sent(msg) => {
                debug!(
                    ?id,
                    "Received a PUBREC for a non-QoS 2 message (QoS was {:?})",
                    msg.qos()
                );
                return Err(());
            }
            StoredMessage::Acknowledged => {
                debug!(
                    ?id,
                    "Received a PUBREC for an already acknowledged QoS 2 message"
                );
                return Err(());
            }
        }

        // Transition in place while still holding the entry's lock.
        *stored.value_mut() = StoredMessage::Acknowledged;
        drop(stored);
        trace!(?id, "Acknowledged QoS 2 message, storing identifier");
        Ok(())
    }

    /// Handles a PUBCOMP for the QoS 2 message stored under `id`, and removes it
    /// on success, freeing the identifier.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if nothing is stored under `id` or if no PUBREC has
    /// been received for it yet.
    pub fn receive_pubcomp(&self, id: MPacketIdentifier) -> Result<(), ()> {
        let stored = match self.sent_stored_messages.get(&id) {
            Some(stored) => stored,
            None => {
                debug!(?id, "Received a PUBCOMP for a nonexistent message");
                return Err(());
            }
        };

        match stored.value() {
            StoredMessage::Sent(_) => {
                debug!(
                    ?id,
                    "Received a PUBCOMP for a message that has not received a PUBREC yet"
                );
                return Err(());
            }
            StoredMessage::Acknowledged => {}
        }

        // The read guard must be released before `remove` write-locks the same shard.
        drop(stored);
        self.sent_stored_messages.remove(&id);
        trace!(?id, "Removed QoS 2 message after completing it");
        Ok(())
    }

    /// Handles a PUBREL by forgetting the incoming QoS 2 identifier `id`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `id` was not recorded by `save_qos_exactly_once`.
    pub fn receive_pubrel(&self, id: MPacketIdentifier) -> Result<(), ()> {
        if self.received_messages.remove(&id).is_none() {
            debug!(
                ?id,
                "Received a pubrel, but no corresponding message exists"
            );
            return Err(());
        }
        Ok(())
    }

    /// Records the identifier `id` of an incoming QoS 2 message until its PUBREL.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `id` is already recorded (presumably a duplicate
    /// delivery — confirm against the caller).
    pub fn save_qos_exactly_once(&self, id: MPacketIdentifier) -> Result<(), ()> {
        // `insert` returns false when the identifier was already present.
        if !self.received_messages.insert(id) {
            debug!(?id, "Tried to save message id, but it already exists");
            return Err(());
        }
        Ok(())
    }

    /// Reserves the next free packet identifier and returns its vacant map entry.
    ///
    /// Identifiers are drawn from `current_id`, which wraps around; 0 is never
    /// returned. Identifiers still used by stored messages are skipped.
    ///
    /// Returns `None` when every identifier is in use. The returned entry holds
    /// its map shard's lock until it is consumed or dropped.
    fn get_next_packet_entry(&self) -> Option<Entry<MPacketIdentifier, StoredMessage>> {
        if self.sent_stored_messages.len() == u16::MAX as usize {
            debug!("We are storing u16::MAX messages, cannot allocate more messages");
            return None;
        }

        // Bounded so a map that fills concurrently cannot make this spin forever;
        // u16::MAX attempts visit every non-zero identifier at least once.
        for _ in 0..u16::MAX {
            let next_id = match self
                .current_id
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
            {
                // 0 is not a valid MQTT packet identifier; skip it on wrap-around.
                0 => self
                    .current_id
                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
                other => other,
            };

            let entry = self.sent_stored_messages.entry(MPacketIdentifier(next_id));
            if let Entry::Vacant(_) = entry {
                return Some(entry);
            }
            // Release the occupied entry's shard lock before touching the map again.
            // Holding it across another `entry`/`len` call deadlocks.
            drop(entry);
        }

        debug!("No free packet identifier found, cannot allocate more messages");
        None
    }
}