use std::{mem, sync::Arc, time::Instant};
use log::debug;
use mqttbytes::{v5::Publish, QoS};
use parking_lot::Mutex;
use tokio::sync::Semaphore;
use crate::client::mqtt::publish_store::store::Store;
use crate::client::Message;
use std::time::Duration;
/// Thread-safe wrapper around the inner publish `Store`.
///
/// The store itself is guarded by a mutex; the semaphore is used as a wake-up
/// signal for tasks waiting in `notified` that new work may be available.
pub(crate) struct PublishStore {
/// Queued and in-flight publishes.
store: Mutex<Store>,
/// Wake-up signal: a permit is added by `insert_to_send` (and by
/// `on_connect_cleanup` when entries remain) and consumed by `notified`.
semaphore: Semaphore,
}
impl PublishStore {
    /// Creates an empty store.
    ///
    /// The semaphore starts with one permit, so the first call to `notified`
    /// completes immediately even though nothing has been queued yet.
    pub fn new() -> Self {
        Self {
            store: Mutex::new(Store::new()),
            semaphore: Semaphore::new(1),
        }
    }

    /// Queues `message` for sending and adds one permit to wake a `notified` waiter.
    pub fn insert_to_send(&self, message: Message) {
        // The temporary lock guard is dropped at the end of this statement,
        // i.e. before the permit is released.
        self.store.lock().insert_to_send(message);
        self.semaphore.add_permits(1);
    }

    /// Number of messages currently held (queued or awaiting acknowledgement).
    pub fn tx_pending(&self) -> usize {
        self.store.lock().len()
    }

    /// Waits for one semaphore permit and consumes it (the permit is forgotten,
    /// not returned).
    ///
    /// # Panics
    /// Panics with "PublishStore closed" if the semaphore has been closed.
    pub(super) async fn notified(&self) {
        self.semaphore
            .acquire()
            .await
            .expect("PublishStore closed")
            .forget();
    }

    /// Hands out the next queued message as a publish with packet id `pkid`,
    /// or `None` if nothing is queued.
    pub(super) fn next_publish(&self, pkid: u16) -> Option<Arc<Publish>> {
        self.store.lock().next_message_to_send(pkid)
    }

    /// Drops the publish awaiting acknowledgement for `pkid`, if any.
    pub(super) fn remove_waiting_for_ack(&self, pkid: u16) {
        self.store.lock().remove_waiting_for_ack(pkid);
    }

    /// Prepares the store for a fresh connection and, if anything is left to
    /// send afterwards, adds one permit so a `notified` waiter wakes up.
    pub(super) fn on_connect_cleanup(&self) {
        let has_pending = {
            let mut store = self.store.lock();
            store.on_connect_cleanup();
            store.len() > 0
        };
        if has_pending {
            self.semaphore.add_permits(1);
        }
    }

    /// Drops every entry whose message expiry interval has elapsed.
    pub(super) fn remove_expired(&self) {
        self.store.lock().remove_expired();
    }
}
mod store {
use super::*;
/// A message that has been queued but not yet handed out for sending.
struct NotSendMessage {
message: Message,
/// Reference instant for message expiry: set on insertion, or carried over
/// from a `WaitingForAck` entry by `on_connect_cleanup`.
send_time: Instant,
}
/// A QoS 1 publish that has been handed out for sending and is kept until
/// `remove_waiting_for_ack` is called with its packet id.
struct WaitingForAck {
/// The publish exactly as it was handed out (packet id included).
publish: Arc<Publish>,
/// Instant from which `remove_expired` measures elapsed time against the
/// publish's message expiry interval.
send_time: Instant,
}
/// State of one live slot in the store.
enum PublishStoreEntry {
/// Queued, not yet handed out by `next_message_to_send`.
NotSend(NotSendMessage),
/// Handed out with QoS 1 and not yet removed by `remove_waiting_for_ack`.
WaitingForAck(WaitingForAck),
}
impl PublishStoreEntry {
    /// Message expiry interval (in seconds) carried by this entry, if any.
    ///
    /// For a queued message this is the value stored on the message; for a
    /// publish awaiting acknowledgement it is the value in the publish
    /// properties as sent (missing properties mean no interval).
    fn message_expiry_interval(&self) -> Option<u32> {
        match self {
            PublishStoreEntry::NotSend(ns) => ns.message.properties.message_expiry_interval,
            PublishStoreEntry::WaitingForAck(wfa) => wfa
                .publish
                .properties
                .as_ref()
                .and_then(|p| p.message_expiry_interval),
        }
    }

    /// Reference instant used for expiry computations.
    fn send_time(&self) -> Instant {
        match self {
            PublishStoreEntry::NotSend(ns) => ns.send_time,
            PublishStoreEntry::WaitingForAck(wfa) => wfa.send_time,
        }
    }

    /// Topic the entry is published to.
    fn topic(&self) -> &str {
        match self {
            PublishStoreEntry::NotSend(ns) => ns.message.topic(),
            PublishStoreEntry::WaitingForAck(wfa) => wfa.publish.topic.as_str(),
        }
    }

    /// Whether the entry carries the retain flag.
    fn retain(&self) -> bool {
        match self {
            PublishStoreEntry::NotSend(ns) => ns.message.retain,
            PublishStoreEntry::WaitingForAck(wfa) => wfa.publish.retain,
        }
    }
}
/// Ordered list of outgoing publishes.
///
/// Removed entries are replaced by `None` holes instead of being taken out of
/// the vector; `autovacuum` compacts the holes lazily.
pub(super) struct Store {
vec: Vec<Option<PublishStoreEntry>>,
/// When `autovacuum` last ran its (throttled) compaction check.
last_autovacuum_check: Instant,
}
impl Store {
    /// Minimum time between two compaction checks in `autovacuum`.
    const AUTOVACUUM_INTERVAL: Duration = Duration::from_millis(50);
    /// Vectors with fewer slots than this are never compacted.
    const AUTOVACUUM_MIN_SLOTS: usize = 64;

    /// Creates an empty store.
    pub fn new() -> Store {
        Store {
            vec: vec![],
            last_autovacuum_check: Instant::now(),
        }
    }

    /// Whole seconds in `duration`, saturating at `u32::MAX` instead of
    /// wrapping the way a plain `as u32` cast of `as_secs()` would.
    fn whole_secs(duration: Duration) -> u32 {
        duration.as_secs().min(u64::from(u32::MAX)) as u32
    }

    /// Queues `message` for sending.
    ///
    /// A retained message replaces, in place, the first live entry (queued or
    /// awaiting acknowledgement) that is also retained and has the same topic.
    /// Otherwise the message is appended to the end of the queue.
    pub fn insert_to_send(&mut self, message: Message) {
        let not_send = NotSendMessage {
            message,
            send_time: Instant::now(),
        };
        if not_send.message.retain {
            let topic = not_send.message.topic.as_str();
            if let Some(entry) = self
                .vec
                .iter_mut()
                .filter_map(|x| x.as_mut())
                .find(|e| e.retain() && e.topic() == topic)
            {
                *entry = PublishStoreEntry::NotSend(not_send);
                return;
            }
        }
        self.vec.push(Some(PublishStoreEntry::NotSend(not_send)));
    }

    /// Removes `None` holes from the vector.
    ///
    /// Runs only when at least `AUTOVACUUM_INTERVAL` has passed since the
    /// previous check, the vector has at least `AUTOVACUUM_MIN_SLOTS` slots,
    /// and at least a tenth of them are empty.
    fn autovacuum(&mut self) {
        if self.last_autovacuum_check.elapsed() < Self::AUTOVACUUM_INTERVAL {
            return;
        }
        self.last_autovacuum_check = Instant::now();
        let store_len = self.vec.len();
        if store_len < Self::AUTOVACUUM_MIN_SLOTS {
            return;
        }
        let empty_entries_num = self.vec.iter().filter(|e| e.is_none()).count();
        if empty_entries_num < store_len / 10 {
            return;
        }
        self.vec.retain(|e| e.is_some());
    }

    /// Takes the oldest queued message and turns it into a `Publish` with
    /// packet id `pkid`. Returns `None` when nothing is queued.
    ///
    /// If the message has an expiry interval, the interval is reduced by the
    /// whole seconds spent in the queue. A QoS 1 publish stays in its slot as
    /// `WaitingForAck`; a QoS 0 publish leaves an empty slot behind.
    ///
    /// # Panics
    /// On a QoS 2 message, which this store does not handle.
    pub fn next_message_to_send(&mut self, pkid: u16) -> Option<Arc<Publish>> {
        let entry = self
            .vec
            .iter_mut()
            .find(|e| matches!(e, Some(PublishStoreEntry::NotSend(_))))?;
        let NotSendMessage {
            mut message,
            mut send_time,
        } = match entry.take() {
            Some(PublishStoreEntry::NotSend(nsm)) => nsm,
            // `find` above only matches queued entries.
            _ => return None,
        };
        if let Some(message_expiry_interval) = message.properties.message_expiry_interval {
            let delta_seconds = Self::whole_secs(send_time.elapsed());
            message.properties.message_expiry_interval =
                Some(message_expiry_interval.saturating_sub(delta_seconds));
            // The reduced interval is measured from `send_time + delta_seconds`,
            // so rebase `send_time` to that instant. Otherwise `remove_expired`
            // (and a re-send after `on_connect_cleanup`) would subtract the queued
            // time a second time and expire the message early.
            send_time += Duration::from_secs(u64::from(delta_seconds));
        }
        let publish = Arc::new(Publish {
            dup: false,
            qos: message.qos,
            retain: message.retain,
            topic: message.topic,
            pkid,
            properties: Some(message.properties),
            payload: message.payload,
        });
        match message.qos {
            QoS::AtMostOnce => {
                // Nothing to keep: the slot stays empty.
                self.autovacuum();
            }
            QoS::AtLeastOnce => {
                *entry = Some(PublishStoreEntry::WaitingForAck(WaitingForAck {
                    publish: Arc::clone(&publish),
                    send_time,
                }));
            }
            QoS::ExactlyOnce => unreachable!(),
        };
        Some(publish)
    }

    /// Clears the first slot awaiting acknowledgement for `pkid`.
    ///
    /// This is a no-op if no slot matches. Afterwards `autovacuum` gets a
    /// chance to run.
    pub fn remove_waiting_for_ack(&mut self, pkid: u16) {
        let acked = self.vec.iter_mut().find(|e| {
            matches!(e, Some(PublishStoreEntry::WaitingForAck(wfa)) if wfa.publish.pkid == pkid)
        });
        if let Some(slot) = acked {
            debug!("Received acknowledge for packet id {}", pkid);
            *slot = None;
        }
        self.autovacuum();
    }

    /// Clears every entry (queued or awaiting acknowledgement) whose expiry
    /// interval has elapsed since its `send_time`, counted in whole seconds.
    ///
    /// If anything was removed, logs the count and lets `autovacuum` run.
    pub fn remove_expired(&mut self) {
        let now = Instant::now();
        let mut removed = 0;
        for slot in self.vec.iter_mut() {
            let entry = match slot {
                Some(entry) => entry,
                None => continue,
            };
            let message_expiry_interval = match entry.message_expiry_interval() {
                Some(mei) => mei,
                None => continue,
            };
            let delta_seconds = Self::whole_secs(now.saturating_duration_since(entry.send_time()));
            if delta_seconds >= message_expiry_interval {
                log::debug!("timeout for not sent {}", entry.topic());
                removed += 1;
                *slot = None;
            }
        }
        if removed > 0 {
            log::info!("mqtt timeouted {} messages", removed);
            self.autovacuum();
        }
    }

    /// Prepares the store for a new connection.
    ///
    /// Drops expired entries, drops queued QoS 0 messages, and turns every
    /// publish awaiting acknowledgement back into a queued message (rebuilt
    /// with `Message::from`, keeping its `send_time`). Relative order is
    /// preserved and empty slots are removed.
    pub fn on_connect_cleanup(&mut self) {
        self.remove_expired();
        let old_vec = mem::take(&mut self.vec);
        self.vec = old_vec
            .into_iter()
            .filter_map(|slot| match slot? {
                PublishStoreEntry::NotSend(ns) if ns.message.qos > QoS::AtMostOnce => {
                    Some(PublishStoreEntry::NotSend(ns))
                }
                PublishStoreEntry::NotSend(_) => None,
                PublishStoreEntry::WaitingForAck(wfa) if wfa.publish.qos > QoS::AtMostOnce => {
                    // Avoid a deep copy when this store holds the last reference.
                    let publish: Publish =
                        Arc::try_unwrap(wfa.publish).unwrap_or_else(|shared| (*shared).clone());
                    Some(PublishStoreEntry::NotSend(NotSendMessage {
                        message: Message::from(publish),
                        send_time: wfa.send_time,
                    }))
                }
                PublishStoreEntry::WaitingForAck(_) => None,
            })
            .map(Some)
            .collect();
    }

    /// Number of live (non-empty) slots.
    pub fn len(&self) -> usize {
        self.vec.iter().flatten().count()
    }
}
}