use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::error::Result;
use crate::hlc::HlcTimestamp;
use crate::message::MessageEnvelope;
use crate::runtime::KernelId;
/// A `/`-separated topic name, or a subscription pattern over such names.
///
/// Wildcards only have meaning as whole segments:
/// * `*` matches exactly one level (`sensors/*/temp` matches `sensors/kitchen/temp`).
/// * `#` matches all remaining levels, including none, so `sensors/#` matches both
///   `sensors` and `sensors/a/b`. Segments written after `#` are not checked.
///
/// A segment such as `a*` is compared literally.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Topic(pub String);

impl Topic {
    /// Creates a topic from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self(name.into())
    }

    /// Returns the topic string.
    pub fn name(&self) -> &str {
        &self.0
    }

    /// Returns `true` if the name contains `*` or `#` anywhere, i.e. if [`Topic::matches`]
    /// treats it as a pattern rather than comparing it literally.
    pub fn is_pattern(&self) -> bool {
        self.0.contains('*') || self.0.contains('#')
    }

    /// Returns `true` if `other` is selected by `self`.
    ///
    /// A non-pattern topic matches only an identical string. A pattern is compared with
    /// `other` one `/`-separated level at a time using the wildcard rules on [`Topic`].
    pub fn matches(&self, other: &Topic) -> bool {
        if !self.is_pattern() {
            return self.0 == other.0;
        }
        let mut pattern_levels = self.0.split('/');
        let mut topic_levels = other.0.split('/');
        loop {
            match (pattern_levels.next(), topic_levels.next()) {
                // `#` swallows whatever is left of the topic -- even nothing, so that
                // `sensors/#` also selects the parent level `sensors`.
                (Some("#"), _) => return true,
                // Both sides ran out together: every level matched.
                (None, None) => return true,
                (Some("*"), Some(_)) => {}
                (Some(pattern), Some(level)) if pattern == level => {}
                // Mismatched level, or one side has levels the other lacks.
                _ => return false,
            }
        }
    }
}
impl std::fmt::Display for Topic {
    /// Writes the topic string through [`std::fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags apply exactly as they do for `str`
    /// (e.g. `format!("{:<20}", topic)` pads to 20 columns).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
impl From<&str> for Topic {
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl From<String> for Topic {
fn from(s: String) -> Self {
Self(s)
}
}
/// Tuning knobs for a [`PubSubBroker`]; can be assembled through [`PubSubBuilder`].
#[derive(Clone, Debug)]
pub struct PubSubConfig {
    /// Intended cap on subscriptions per topic.
    /// NOTE(review): not read by anything in this file -- confirm whether it is enforced.
    pub max_subscribers_per_topic: usize,
    /// Capacity handed to the broker's `tokio::sync::broadcast` channel.
    pub channel_buffer_size: usize,
    /// Maximum retained publications kept per topic; the oldest are dropped first.
    pub max_retained_messages: usize,
    /// Persistence switch.
    /// NOTE(review): not read by anything in this file.
    pub enable_persistence: bool,
    /// Broker-wide default quality-of-service level.
    pub default_qos: QoS,
}

impl Default for PubSubConfig {
    /// 256-slot channel, up to 1000 subscribers and 100 retained publications per topic,
    /// persistence off, `QoS::AtMostOnce`.
    fn default() -> Self {
        PubSubConfig {
            channel_buffer_size: 256,
            max_subscribers_per_topic: 1000,
            max_retained_messages: 100,
            enable_persistence: false,
            default_qos: QoS::default(),
        }
    }
}

/// Delivery-guarantee level recorded on each [`Publication`].
///
/// NOTE(review): nothing in this file acknowledges or redelivers publications, so the level
/// is currently carried as metadata only.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum QoS {
    /// The default level.
    #[default]
    AtMostOnce,
    AtLeastOnce,
    ExactlyOnce,
}
/// One message as delivered to subscribers.
#[derive(Clone, Debug)]
pub struct Publication {
    /// Concrete topic the message was published on.
    pub topic: Topic,
    /// Kernel that published the message.
    pub publisher: KernelId,
    /// The message itself.
    pub envelope: MessageEnvelope,
    /// Timestamp supplied by the publisher.
    pub timestamp: HlcTimestamp,
    /// Requested delivery level.
    pub qos: QoS,
    /// Broker-assigned sequence number; `0` until the broker sets it.
    pub sequence: u64,
    /// `true` for publications the broker also keeps in its retained buffer.
    pub retained: bool,
}

impl Publication {
    /// Creates a non-retained publication at `QoS::default()` with sequence `0`.
    pub fn new(
        topic: Topic,
        publisher: KernelId,
        envelope: MessageEnvelope,
        timestamp: HlcTimestamp,
    ) -> Self {
        Publication {
            topic,
            publisher,
            envelope,
            timestamp,
            qos: QoS::default(),
            sequence: 0,
            retained: false,
        }
    }

    /// Returns this publication with its QoS replaced by `qos`.
    pub fn with_qos(self, qos: QoS) -> Self {
        Publication { qos, ..self }
    }

    /// Returns this publication with its retained flag replaced by `retained`.
    pub fn with_retained(self, retained: bool) -> Self {
        Publication { retained, ..self }
    }
}
pub struct Subscription {
pub id: u64,
pub pattern: Topic,
pub subscriber: KernelId,
receiver: broadcast::Receiver<Publication>,
broker: Arc<PubSubBroker>,
}
impl Subscription {
pub async fn receive(&mut self) -> Option<Publication> {
loop {
match self.receiver.recv().await {
Ok(pub_msg) => {
if self.pattern.matches(&pub_msg.topic) {
return Some(pub_msg);
}
}
Err(broadcast::error::RecvError::Closed) => return None,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
}
}
}
pub fn try_receive(&mut self) -> Option<Publication> {
loop {
match self.receiver.try_recv() {
Ok(pub_msg) => {
if self.pattern.matches(&pub_msg.topic) {
return Some(pub_msg);
}
}
Err(_) => return None,
}
}
}
pub fn unsubscribe(self) {
self.broker.unsubscribe(self.id);
}
}
/// Snapshot of one topic's counters, as returned by [`PubSubBroker::topic_info`].
#[derive(Clone, Debug)]
pub struct TopicInfo {
    /// The topic (or subscription pattern) described.
    pub topic: Topic,
    /// Subscriptions registered under exactly this topic/pattern string.
    pub subscriber_count: usize,
    /// Publications sent to exactly this topic string.
    pub messages_published: u64,
    /// Retained publications currently stored for this topic.
    pub retained_count: usize,
}
/// Topic-based publish/subscribe hub.
///
/// All publications travel over one shared `tokio::sync::broadcast` channel, and each
/// [`Subscription`] filters that stream by its own pattern. Bookkeeping lives behind
/// `parking_lot::RwLock`s.
pub struct PubSubBroker {
    /// Settings supplied at construction.
    config: PubSubConfig,
    /// Sending half of the broadcast channel every subscription listens on.
    sender: broadcast::Sender<Publication>,
    /// Registered subscriptions, keyed by subscription id.
    subscriptions: RwLock<HashMap<u64, SubscriptionInfo>>,
    /// Source of subscription ids.
    subscription_counter: AtomicU64,
    /// Per-topic counters: keyed by pattern for subscriptions, by topic for publications.
    topic_stats: RwLock<HashMap<Topic, TopicStats>>,
    /// Retained publications per topic, oldest first.
    retained: RwLock<HashMap<Topic, Vec<Publication>>>,
    /// Source of publication sequence numbers.
    sequence: AtomicU64,
}
/// Broker-side record of one registered subscription.
struct SubscriptionInfo {
    /// Pattern the subscription was registered under; locates its `topic_stats` entry.
    pattern: Topic,
    /// Owning kernel. Stored but not read anywhere in this file, hence the allowance below.
    #[allow(dead_code)]
    subscriber: KernelId,
}
/// Per-topic counters kept by the broker.
#[derive(Clone, Debug, Default)]
struct TopicStats {
    /// Ids of subscriptions registered under this topic/pattern.
    subscribers: HashSet<u64>,
    /// Number of publications sent to this topic.
    messages_published: u64,
}
impl PubSubBroker {
    /// Creates a broker from `config`.
    ///
    /// The broker is returned in an [`Arc`] because every [`Subscription`] it hands out
    /// keeps a reference back to it. A `channel_buffer_size` of `0` is treated as `1`:
    /// `tokio::sync::broadcast::channel` panics on a zero capacity, and a bad config value
    /// should not crash the caller.
    pub fn new(config: PubSubConfig) -> Arc<Self> {
        let capacity = config.channel_buffer_size.max(1);
        // The initial receiver is discarded; subscribers obtain their own via `subscribe`.
        let (sender, _) = broadcast::channel(capacity);
        Arc::new(Self {
            config,
            sender,
            subscriptions: RwLock::new(HashMap::new()),
            subscription_counter: AtomicU64::new(0),
            topic_stats: RwLock::new(HashMap::new()),
            retained: RwLock::new(HashMap::new()),
            sequence: AtomicU64::new(0),
        })
    }

    /// Registers `subscriber` for every publication whose topic matches `pattern` and
    /// returns the receiving handle.
    ///
    /// The handle only observes publications broadcast after this call; retained
    /// publications are not replayed (fetch them with [`PubSubBroker::get_retained`]).
    ///
    /// NOTE(review): `config.max_subscribers_per_topic` is not enforced here; doing so would
    /// need a fallible signature.
    pub fn subscribe(self: &Arc<Self>, subscriber: KernelId, pattern: Topic) -> Subscription {
        let id = self.subscription_counter.fetch_add(1, Ordering::Relaxed);
        self.subscriptions.write().insert(
            id,
            SubscriptionInfo {
                pattern: pattern.clone(),
                subscriber: subscriber.clone(),
            },
        );
        self.topic_stats
            .write()
            .entry(pattern.clone())
            .or_default()
            .subscribers
            .insert(id);
        Subscription {
            id,
            pattern,
            subscriber,
            receiver: self.sender.subscribe(),
            broker: Arc::clone(self),
        }
    }

    /// Forgets the subscription `subscription_id`; unknown ids are ignored.
    ///
    /// Only the broker's bookkeeping changes: a [`Subscription`] value that is still alive
    /// keeps receiving from the channel.
    pub fn unsubscribe(&self, subscription_id: u64) {
        // Take the record out first so the `subscriptions` lock is released before
        // `topic_stats` is locked.
        let removed = self.subscriptions.write().remove(&subscription_id);
        if let Some(info) = removed {
            if let Some(topic_stats) = self.topic_stats.write().get_mut(&info.pattern) {
                topic_stats.subscribers.remove(&subscription_id);
            }
        }
    }

    /// Publishes `envelope` on `topic` at the configured `default_qos` and returns the
    /// sequence number assigned to it.
    ///
    /// Delivery is best-effort: publishing while nobody is subscribed is not an error.
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn publish(
        &self,
        topic: Topic,
        publisher: KernelId,
        envelope: MessageEnvelope,
        timestamp: HlcTimestamp,
    ) -> Result<u64> {
        let qos = self.config.default_qos;
        Ok(self.dispatch(topic, publisher, envelope, timestamp, qos, false))
    }

    /// Like [`PubSubBroker::publish`], but tags the publication with an explicit `qos`.
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn publish_qos(
        &self,
        topic: Topic,
        publisher: KernelId,
        envelope: MessageEnvelope,
        timestamp: HlcTimestamp,
        qos: QoS,
    ) -> Result<u64> {
        Ok(self.dispatch(topic, publisher, envelope, timestamp, qos, false))
    }

    /// Like [`PubSubBroker::publish`], but also stores the publication in `topic`'s retained
    /// buffer, which keeps at most `config.max_retained_messages` entries (oldest evicted).
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn publish_retained(
        &self,
        topic: Topic,
        publisher: KernelId,
        envelope: MessageEnvelope,
        timestamp: HlcTimestamp,
    ) -> Result<u64> {
        let qos = self.config.default_qos;
        Ok(self.dispatch(topic, publisher, envelope, timestamp, qos, true))
    }

    /// Shared body of the `publish*` methods: assigns the next sequence number, bumps the
    /// topic's publish counter, stores the publication in the retained buffer when `retain`
    /// is set, then broadcasts it. Returns the assigned sequence number.
    fn dispatch(
        &self,
        topic: Topic,
        publisher: KernelId,
        envelope: MessageEnvelope,
        timestamp: HlcTimestamp,
        qos: QoS,
        retain: bool,
    ) -> u64 {
        let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
        let mut publication = Publication::new(topic.clone(), publisher, envelope, timestamp)
            .with_qos(qos)
            .with_retained(retain);
        publication.sequence = sequence;

        self.topic_stats
            .write()
            .entry(topic.clone())
            .or_default()
            .messages_published += 1;

        if retain {
            let limit = self.config.max_retained_messages;
            let mut retained = self.retained.write();
            let history = retained.entry(topic).or_default();
            history.push(publication.clone());
            // Keep only the newest `limit` publications; the oldest sit at the front.
            if history.len() > limit {
                let excess = history.len() - limit;
                history.drain(..excess);
            }
        }

        // `send` only fails when no receiver exists, i.e. nobody is subscribed right now;
        // that is not an error for a best-effort broadcast.
        let _ = self.sender.send(publication);
        sequence
    }

    /// Returns a copy of the publications retained for exactly `topic`, oldest first.
    /// Wildcard patterns are not expanded.
    pub fn get_retained(&self, topic: &Topic) -> Vec<Publication> {
        self.retained.read().get(topic).cloned().unwrap_or_default()
    }

    /// Discards every publication retained for exactly `topic`.
    pub fn clear_retained(&self, topic: &Topic) {
        self.retained.write().remove(topic);
    }

    /// Returns counters for exactly `topic`, or `None` if it has never been used as a
    /// subscription pattern or published to.
    pub fn topic_info(&self, topic: &Topic) -> Option<TopicInfo> {
        // Copy what we need and release `topic_stats` before touching `retained`.
        let (subscriber_count, messages_published) = {
            let stats = self.topic_stats.read();
            let entry = stats.get(topic)?;
            (entry.subscribers.len(), entry.messages_published)
        };
        let retained_count = self.retained.read().get(topic).map_or(0, Vec::len);
        Some(TopicInfo {
            topic: topic.clone(),
            subscriber_count,
            messages_published,
            retained_count,
        })
    }

    /// Returns the topics/patterns that currently have at least one registered
    /// subscription, in unspecified order.
    pub fn list_topics(&self) -> Vec<Topic> {
        self.topic_stats
            .read()
            .iter()
            .filter(|(_, stats)| !stats.subscribers.is_empty())
            .map(|(topic, _)| topic.clone())
            .collect()
    }

    /// Aggregates counters across every topic the broker has seen.
    ///
    /// `topic_count` also includes topics that were only published to and patterns whose
    /// subscribers have all unsubscribed.
    pub fn stats(&self) -> PubSubStats {
        let (topic_count, total_subscribers, total_messages_published) = {
            let stats = self.topic_stats.read();
            let subscribers: usize = stats.values().map(|s| s.subscribers.len()).sum();
            let messages: u64 = stats.values().map(|s| s.messages_published).sum();
            (stats.len(), subscribers, messages)
        };
        let retained_message_count = self.retained.read().values().map(Vec::len).sum();
        PubSubStats {
            topic_count,
            total_subscribers,
            total_messages_published,
            retained_message_count,
        }
    }
}
/// Broker-wide totals, as returned by [`PubSubBroker::stats`].
#[derive(Clone, Debug, Default)]
pub struct PubSubStats {
    /// Distinct topics/patterns the broker holds counters for.
    pub topic_count: usize,
    /// Registered subscriptions summed over all topics/patterns.
    pub total_subscribers: usize,
    /// Publications sent, summed over all topics.
    pub total_messages_published: u64,
    /// Retained publications currently stored, summed over all topics.
    pub retained_message_count: usize,
}
/// Fluent constructor for a [`PubSubBroker`], starting from the default configuration.
pub struct PubSubBuilder {
    /// Configuration accumulated so far.
    config: PubSubConfig,
}
impl PubSubBuilder {
    /// Starts a builder from [`PubSubConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: PubSubConfig::default(),
        }
    }

    /// Sets `max_subscribers_per_topic`.
    #[must_use]
    pub fn max_subscribers_per_topic(mut self, count: usize) -> Self {
        self.config.max_subscribers_per_topic = count;
        self
    }

    /// Sets `channel_buffer_size`, the capacity of the broker's broadcast channel.
    #[must_use]
    pub fn channel_buffer_size(mut self, size: usize) -> Self {
        self.config.channel_buffer_size = size;
        self
    }

    /// Sets `max_retained_messages`, the per-topic cap on retained publications.
    #[must_use]
    pub fn max_retained_messages(mut self, count: usize) -> Self {
        self.config.max_retained_messages = count;
        self
    }

    /// Sets `enable_persistence`.
    #[must_use]
    pub fn enable_persistence(mut self, enable: bool) -> Self {
        self.config.enable_persistence = enable;
        self
    }

    /// Sets `default_qos`.
    #[must_use]
    pub fn default_qos(mut self, qos: QoS) -> Self {
        self.config.default_qos = qos;
        self
    }

    /// Creates the [`PubSubBroker`] described by the accumulated configuration.
    #[must_use]
    pub fn build(self) -> Arc<PubSubBroker> {
        PubSubBroker::new(self.config)
    }
}
impl Default for PubSubBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh envelope for tests that do not care about the payload.
    fn envelope() -> MessageEnvelope {
        MessageEnvelope::empty(1, 2, HlcTimestamp::now(1))
    }

    #[test]
    fn test_topic_matching() {
        let pattern = Topic::new("sensors/*/temperature");
        let topic1 = Topic::new("sensors/kitchen/temperature");
        let topic2 = Topic::new("sensors/living_room/temperature");
        let topic3 = Topic::new("sensors/kitchen/humidity");
        assert!(pattern.matches(&topic1));
        assert!(pattern.matches(&topic2));
        assert!(!pattern.matches(&topic3));
    }

    #[test]
    fn test_topic_wildcard_hash() {
        let pattern = Topic::new("sensors/#");
        let topic1 = Topic::new("sensors/kitchen/temperature");
        let topic2 = Topic::new("sensors/a/b/c/d");
        assert!(pattern.matches(&topic1));
        assert!(pattern.matches(&topic2));
    }

    #[test]
    fn test_topic_exact_match() {
        let pattern = Topic::new("sensors/kitchen/temperature");
        let topic1 = Topic::new("sensors/kitchen/temperature");
        let topic2 = Topic::new("sensors/kitchen/humidity");
        assert!(pattern.matches(&topic1));
        assert!(!pattern.matches(&topic2));
    }

    #[test]
    fn test_pubsub_broker() {
        let broker = PubSubBuilder::new().build();
        let publisher = KernelId::new("publisher");
        let subscriber = KernelId::new("subscriber");
        let topic = Topic::new("test/topic");
        let mut subscription = broker.subscribe(subscriber, topic.clone());
        broker
            .publish(topic.clone(), publisher.clone(), envelope(), HlcTimestamp::now(1))
            .unwrap();
        // `broadcast::Sender::send` enqueues synchronously, so the publication is already
        // buffered; no sleep is needed before reading it.
        let received = subscription
            .try_receive()
            .expect("publication should already be buffered");
        assert_eq!(received.publisher, publisher);
        assert_eq!(received.topic, topic);
    }

    #[test]
    fn test_pattern_subscription_filters_topics() {
        let broker = PubSubBuilder::new().build();
        let kernel = KernelId::new("kernel");
        let mut sub = broker.subscribe(kernel.clone(), Topic::new("sensors/*/temperature"));
        broker
            .publish(
                Topic::new("sensors/kitchen/humidity"),
                kernel.clone(),
                envelope(),
                HlcTimestamp::now(1),
            )
            .unwrap();
        broker
            .publish(
                Topic::new("sensors/kitchen/temperature"),
                kernel.clone(),
                envelope(),
                HlcTimestamp::now(1),
            )
            .unwrap();
        let received = sub.try_receive().expect("matching publication");
        assert_eq!(received.topic, Topic::new("sensors/kitchen/temperature"));
        assert!(sub.try_receive().is_none());
    }

    #[test]
    fn test_publish_qos_is_recorded() {
        let broker = PubSubBuilder::new().build();
        let kernel = KernelId::new("kernel");
        let topic = Topic::new("qos");
        let mut sub = broker.subscribe(kernel.clone(), topic.clone());
        broker
            .publish_qos(topic, kernel, envelope(), HlcTimestamp::now(1), QoS::ExactlyOnce)
            .unwrap();
        assert_eq!(sub.try_receive().expect("publication").qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_retained_messages_are_capped() {
        let broker = PubSubBuilder::new().max_retained_messages(2).build();
        let topic = Topic::new("status");
        let kernel = KernelId::new("kernel");
        let sequences: Vec<u64> = (0..3)
            .map(|_| {
                broker
                    .publish_retained(topic.clone(), kernel.clone(), envelope(), HlcTimestamp::now(1))
                    .unwrap()
            })
            .collect();
        let retained: Vec<u64> = broker
            .get_retained(&topic)
            .iter()
            .map(|p| p.sequence)
            .collect();
        assert_eq!(retained, sequences[1..].to_vec());
        assert!(broker.get_retained(&topic).iter().all(|p| p.retained));
        assert_eq!(broker.topic_info(&topic).unwrap().retained_count, 2);
        broker.clear_retained(&topic);
        assert!(broker.get_retained(&topic).is_empty());
    }

    #[test]
    fn test_unsubscribe_updates_stats() {
        let broker = PubSubBuilder::new().build();
        let topic = Topic::new("test");
        let sub = broker.subscribe(KernelId::new("kernel"), topic.clone());
        assert_eq!(broker.list_topics(), vec![topic.clone()]);
        sub.unsubscribe();
        assert_eq!(broker.stats().total_subscribers, 0);
        assert!(broker.list_topics().is_empty());
    }

    #[test]
    fn test_pubsub_stats() {
        let broker = PubSubBuilder::new().build();
        let topic = Topic::new("test");
        let kernel = KernelId::new("kernel");
        let _sub = broker.subscribe(kernel.clone(), topic.clone());
        let stats = broker.stats();
        assert_eq!(stats.topic_count, 1);
        assert_eq!(stats.total_subscribers, 1);
    }
}