use crate::packet::publish::PublishPacket;
use crate::topic_matching::matches as topic_matches;
use crate::QoS;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Store of retained messages, holding at most one message per exact topic name.
///
/// The map lives behind an `Arc<tokio::sync::RwLock<..>>`, so cloning the store is
/// cheap and yields another handle to the *same* underlying map rather than a copy.
#[derive(Debug, Clone)]
pub struct RetainedMessageStore {
    // Retained messages keyed by topic name; a HashMap, so iteration order is unspecified.
    messages: Arc<RwLock<HashMap<String, RetainedMessage>>>,
}
/// A single retained message as kept by `RetainedMessageStore`.
#[derive(Debug, Clone)]
pub struct RetainedMessage {
    /// Topic name the message was published to.
    ///
    /// NOTE(review): `RetainedMessageStore::store` keys entries by its own separate
    /// `topic` argument; nothing here enforces that it equals this field.
    pub topic: String,
    /// Raw payload bytes (an owned copy).
    pub payload: Vec<u8>,
    /// QoS level recorded with the message.
    pub qos: QoS,
    /// MQTT v5 properties recorded with the message.
    pub properties: crate::protocol::v5::properties::Properties,
}
impl RetainedMessageStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            messages: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Sets or clears the retained message for `topic`.
    ///
    /// `Some(message)` replaces whatever was previously retained under `topic`;
    /// `None` removes the entry (a no-op if nothing is retained there).
    pub async fn store(&self, topic: impl Into<String>, message: Option<RetainedMessage>) {
        // Convert the key before taking the write lock so the (possible) String
        // allocation does not lengthen the exclusive critical section.
        let topic = topic.into();
        let mut messages = self.messages.write().await;
        match message {
            Some(msg) => {
                messages.insert(topic, msg);
            }
            None => {
                messages.remove(&topic);
            }
        }
    }

    /// Returns clones of every retained message whose topic matches `topic_filter`.
    ///
    /// Matching is delegated to `crate::topic_matching::matches`. The result order
    /// is unspecified (it follows `HashMap` iteration order).
    ///
    /// NOTE(review): whether `$`-prefixed topics are excluded from leading-wildcard
    /// filters depends entirely on `topic_matching::matches` — confirm there.
    pub async fn get_matching(&self, topic_filter: &str) -> Vec<RetainedMessage> {
        let messages = self.messages.read().await;
        messages
            .iter()
            .filter(|&(topic, _)| topic_matches(topic, topic_filter))
            .map(|(_, message)| message.clone())
            .collect()
    }

    /// Returns a clone of the message retained under exactly `topic`, if any.
    pub async fn get(&self, topic: &str) -> Option<RetainedMessage> {
        let messages = self.messages.read().await;
        messages.get(topic).cloned()
    }

    /// Removes every retained message.
    pub async fn clear_all(&self) {
        let mut messages = self.messages.write().await;
        messages.clear();
    }

    /// Returns the number of topics that currently have a retained message.
    pub async fn count(&self) -> usize {
        let messages = self.messages.read().await;
        messages.len()
    }
}
impl From<&PublishPacket> for RetainedMessage {
fn from(packet: &PublishPacket) -> Self {
Self {
topic: packet.topic_name.clone(),
payload: packet.payload.to_vec(),
qos: packet.qos,
properties: packet.properties.clone(),
}
}
}
impl Default for RetainedMessageStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestMessageBuilder;
    use crate::Properties;

    /// Builds a retained message with default properties.
    fn retained(topic: &str, payload: &[u8], qos: QoS) -> RetainedMessage {
        RetainedMessage {
            topic: topic.to_string(),
            payload: payload.to_vec(),
            qos,
            properties: Properties::default(),
        }
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = RetainedMessageStore::new();
        let msg = retained("test/topic", b"test payload", QoS::AtLeastOnce);
        store.store("test/topic", Some(msg)).await;

        let retrieved = store
            .get("test/topic")
            .await
            .expect("message stored under test/topic should be retrievable");
        assert_eq!(retrieved.topic, "test/topic");
        assert_eq!(&retrieved.payload[..], b"test payload");
        assert_eq!(retrieved.qos, QoS::AtLeastOnce);
    }

    #[tokio::test]
    async fn test_clear_retained_message() {
        let store = RetainedMessageStore::new();
        let msg = retained("test/topic", b"test payload", QoS::AtMostOnce);
        store.store("test/topic", Some(msg)).await;
        assert_eq!(store.count().await, 1);

        store.store("test/topic", None).await;
        assert_eq!(store.count().await, 0);
        assert!(store.get("test/topic").await.is_none());
    }

    #[tokio::test]
    async fn test_clear_missing_topic_is_noop() {
        let store = RetainedMessageStore::new();
        store.store("never/stored", None).await;
        assert_eq!(store.count().await, 0);
    }

    #[tokio::test]
    async fn test_store_overwrites_existing_topic() {
        let store = RetainedMessageStore::new();
        store
            .store("a/b", Some(retained("a/b", b"first", QoS::AtMostOnce)))
            .await;
        store
            .store("a/b", Some(retained("a/b", b"second", QoS::AtLeastOnce)))
            .await;

        assert_eq!(store.count().await, 1);
        let retrieved = store.get("a/b").await.expect("a/b should be retained");
        assert_eq!(&retrieved.payload[..], b"second");
        assert_eq!(retrieved.qos, QoS::AtLeastOnce);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let store = RetainedMessageStore::new();
        let handle = store.clone();
        handle
            .store("shared/topic", Some(retained("shared/topic", b"x", QoS::AtMostOnce)))
            .await;
        assert_eq!(store.count().await, 1);
        assert!(store.get("shared/topic").await.is_some());
    }

    #[tokio::test]
    async fn test_topic_matching() {
        let store = RetainedMessageStore::new();
        let topics = [
            "home/room1/temp",
            "home/room1/humidity",
            "home/room2/temp",
            "office/room1/temp",
        ];
        for topic in topics {
            let payload = format!("data for {topic}");
            let msg = retained(topic, payload.as_bytes(), QoS::AtMostOnce);
            store.store(topic, Some(msg)).await;
        }

        let matching = store.get_matching("home/room1/temp").await;
        assert_eq!(matching.len(), 1);
        assert_eq!(matching[0].topic, "home/room1/temp");

        assert_eq!(store.get_matching("home/+/temp").await.len(), 2);
        assert_eq!(store.get_matching("home/#").await.len(), 3);
        assert!(store.get_matching("garage/+/temp").await.is_empty());
    }

    #[tokio::test]
    async fn test_clear_all() {
        let store = RetainedMessageStore::new();
        let messages = TestMessageBuilder::new()
            .with_topic_prefix("topic")
            .build_retained_batch(5);
        for (i, msg) in messages.into_iter().enumerate() {
            store.store(format!("topic/{i}"), Some(msg)).await;
        }
        assert_eq!(store.count().await, 5);

        store.clear_all().await;
        assert_eq!(store.count().await, 0);
    }
}