use crate::error::Result;
use crate::packet::publish::PublishPacket;
use crate::validation::strip_shared_subscription_prefix;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Thread-safe (`Send + Sync`), shareable handler invoked with each matching incoming publish.
pub type PublishCallback = Arc<dyn Fn(PublishPacket) + Sync + Send>;
/// Identifier returned by `CallbackManager` when a callback is registered.
pub type CallbackId = u64;
/// A registered callback plus the bookkeeping needed to route and remove it.
#[derive(Clone)]
pub(crate) struct CallbackEntry {
    /// Id handed back to the caller at registration time.
    id: CallbackId,
    /// The filter string exactly as the caller supplied it (any shared-subscription
    /// prefix is kept here; it is only stripped when routing).
    topic_filter: String,
    /// Handler to run for matching messages.
    callback: PublishCallback,
}
/// Routes incoming publishes to the callbacks registered for matching topic filters.
///
/// Filters without `+`/`#` are kept in a map keyed by filter string for direct lookup;
/// filters containing a wildcard are kept in a list and tested against each message.
/// Independently of routing, every registration is also recorded by id.
pub struct CallbackManager {
    /// Every registered entry, keyed by its id; consulted by `restore_callback`.
    callback_registry: Arc<Mutex<HashMap<CallbackId, CallbackEntry>>>,
    /// Non-wildcard entries, keyed by the filter with any shared-subscription prefix stripped.
    exact_callbacks: Arc<Mutex<HashMap<String, Vec<CallbackEntry>>>>,
    /// Entries whose (stripped) filter contains `+` or `#`; scanned on every dispatch.
    wildcard_callbacks: Arc<Mutex<Vec<CallbackEntry>>>,
    /// Counter from which the next `CallbackId` is taken.
    next_id: Arc<AtomicU64>,
}
impl CallbackManager {
    /// Creates an empty manager; the first id handed out is `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            exact_callbacks: Arc::new(Mutex::new(HashMap::new())),
            wildcard_callbacks: Arc::new(Mutex::new(Vec::new())),
            callback_registry: Arc::new(Mutex::new(HashMap::new())),
            next_id: Arc::new(AtomicU64::new(1)),
        }
    }

    /// Returns `true` if `filter` contains an MQTT wildcard character (`+` or `#`),
    /// i.e. it must be matched per message rather than looked up by exact key.
    fn is_wildcard_filter(filter: &str) -> bool {
        filter.contains('+') || filter.contains('#')
    }

    /// Registers `callback` for `topic_filter` and returns the newly allocated id.
    ///
    /// The entry is recorded in the id registry (so it can be re-routed later with
    /// [`Self::restore_callback`]) and added to the routing table used by
    /// [`Self::dispatch`]. A shared-subscription prefix (e.g. `$share/workers/`) is
    /// stripped for routing, but the filter is stored exactly as given, so
    /// [`Self::unregister`] must be called with that same string.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn register_with_id(
        &self,
        topic_filter: &str,
        callback: PublishCallback,
    ) -> Result<CallbackId> {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let entry = CallbackEntry {
            id,
            callback,
            topic_filter: topic_filter.to_string(),
        };
        self.callback_registry.lock().insert(id, entry.clone());
        self.register_internal(topic_filter, entry);
        Ok(id)
    }

    /// Same as [`Self::register_with_id`].
    ///
    /// # Errors
    ///
    /// Never fails at present; see [`Self::register_with_id`].
    pub fn register(&self, topic_filter: &str, callback: PublishCallback) -> Result<CallbackId> {
        self.register_with_id(topic_filter, callback)
    }

    /// Unconditionally adds `entry` to the routing table selected by `topic_filter`.
    fn register_internal(&self, topic_filter: &str, entry: CallbackEntry) {
        self.insert_routing_entry(topic_filter, entry, false);
    }

    /// Adds `entry` to the exact or wildcard routing table chosen by `topic_filter`
    /// (after stripping any shared-subscription prefix).
    ///
    /// When `skip_if_present` is set, an entry whose id is already routed is not
    /// added again. The presence check and the insertion happen under the same lock
    /// acquisition, so concurrent callers cannot both insert the same id.
    /// Returns `true` if `entry` was inserted.
    fn insert_routing_entry(
        &self,
        topic_filter: &str,
        entry: CallbackEntry,
        skip_if_present: bool,
    ) -> bool {
        let actual_filter = strip_shared_subscription_prefix(topic_filter);
        if Self::is_wildcard_filter(actual_filter) {
            let mut wildcards = self.wildcard_callbacks.lock();
            if skip_if_present && wildcards.iter().any(|e| e.id == entry.id) {
                return false;
            }
            wildcards.push(entry);
        } else {
            let mut exact = self.exact_callbacks.lock();
            let entries = exact.entry(actual_filter.to_string()).or_default();
            // If the id is present the vector is non-empty, so the early return
            // never leaves an empty vector behind.
            if skip_if_present && entries.iter().any(|e| e.id == entry.id) {
                return false;
            }
            entries.push(entry);
        }
        true
    }

    /// Returns a clone of the registry entry for `id`, if one exists.
    fn get_callback(&self, id: CallbackId) -> Option<CallbackEntry> {
        self.callback_registry.lock().get(&id).cloned()
    }

    /// Re-adds the callback registered under `id` to the routing table if it is not
    /// currently routed.
    ///
    /// Returns `true` if `id` is known to the registry (whether or not it had to be
    /// re-added), and `false` for unknown ids.
    #[must_use]
    pub fn restore_callback(&self, id: CallbackId) -> bool {
        let Some(entry) = self.get_callback(id) else {
            return false;
        };
        let topic_filter = entry.topic_filter.clone();
        // Check-and-insert is atomic inside the helper, so concurrent restores of the
        // same id cannot produce duplicate routing entries.
        let _ = self.insert_routing_entry(&topic_filter, entry, true);
        true
    }

    /// Removes every callback that was registered with exactly `topic_filter`.
    ///
    /// Filters are compared as the original strings, shared-subscription prefix
    /// included. Callbacks registered under a different string that routes to the
    /// same key (e.g. `$share/g/a/b` and `a/b`) are left in place.
    /// Returns `true` if anything was removed from the registry or the routing tables.
    #[must_use]
    pub fn unregister(&self, topic_filter: &str) -> bool {
        let actual_filter = strip_shared_subscription_prefix(topic_filter);

        let removed_from_registry = {
            let mut registry = self.callback_registry.lock();
            let before = registry.len();
            registry.retain(|_, entry| entry.topic_filter != topic_filter);
            registry.len() < before
        };

        let removed_from_callbacks = if Self::is_wildcard_filter(actual_filter) {
            let mut wildcards = self.wildcard_callbacks.lock();
            let before = wildcards.len();
            wildcards.retain(|entry| entry.topic_filter != topic_filter);
            wildcards.len() < before
        } else {
            let mut exact = self.exact_callbacks.lock();
            // Only drop entries registered under this exact string; other filters
            // that strip to the same key (shared vs. plain) must keep being routed.
            let (removed, now_empty) =
                exact
                    .get_mut(actual_filter)
                    .map_or((false, false), |entries| {
                        let before = entries.len();
                        entries.retain(|entry| entry.topic_filter != topic_filter);
                        (entries.len() < before, entries.is_empty())
                    });
            if now_empty {
                exact.remove(actual_filter);
            }
            removed
        };

        removed_from_registry || removed_from_callbacks
    }

    /// Runs every callback whose filter matches `message.topic_name`.
    ///
    /// Matching callbacks are collected while the routing locks are held. Each one is
    /// then invoked with its own clone of `message` on a task started with
    /// `tokio::spawn`. This function therefore does not wait for callbacks to finish,
    /// and no invocation order is guaranteed. With the `opentelemetry` feature, each
    /// invocation runs inside a `message_received` span under the remote context
    /// extracted from the message's user properties.
    ///
    /// # Errors
    ///
    /// Never fails at present; always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime, because `tokio::spawn` panics otherwise.
    pub fn dispatch(&self, message: &PublishPacket) -> Result<()> {
        let mut callbacks_to_call: Vec<PublishCallback> = Vec::new();
        {
            let exact = self.exact_callbacks.lock();
            if let Some(entries) = exact.get(&message.topic_name) {
                callbacks_to_call.extend(entries.iter().map(|entry| Arc::clone(&entry.callback)));
            }
        }
        {
            let wildcards = self.wildcard_callbacks.lock();
            callbacks_to_call.extend(
                wildcards
                    .iter()
                    .filter(|entry| {
                        let match_filter = strip_shared_subscription_prefix(&entry.topic_filter);
                        crate::topic_matching::matches(&message.topic_name, match_filter)
                    })
                    .map(|entry| Arc::clone(&entry.callback)),
            );
        }
        for callback in callbacks_to_call {
            let message = message.clone();
            #[cfg(feature = "opentelemetry")]
            {
                use crate::telemetry::propagation;
                let user_props = propagation::extract_user_properties(&message.properties);
                tokio::spawn(async move {
                    propagation::with_remote_context(&user_props, || {
                        let span = tracing::info_span!(
                            "message_received",
                            topic = %message.topic_name,
                            qos = ?message.qos,
                            payload_size = message.payload.len(),
                            retain = message.retain,
                        );
                        let _enter = span.enter();
                        callback(message);
                    });
                });
            }
            #[cfg(not(feature = "opentelemetry"))]
            {
                tokio::spawn(async move {
                    callback(message);
                });
            }
        }
        Ok(())
    }

    /// Number of entries in the id registry (each registration counts once, even when
    /// several share the same filter).
    #[must_use]
    pub fn callback_count(&self) -> usize {
        self.callback_registry.lock().len()
    }

    /// Removes all callbacks from both routing tables and the id registry.
    ///
    /// Each table is locked separately, one after another.
    pub fn clear(&self) {
        self.exact_callbacks.lock().clear();
        self.wildcard_callbacks.lock().clear();
        self.callback_registry.lock().clear();
    }
}
impl Default for CallbackManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Properties;
    use crate::QoS;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a QoS 0, non-retained, protocol-version-5 publish for `topic` with `payload`.
    fn publish(topic: &str, payload: Vec<u8>) -> PublishPacket {
        PublishPacket {
            topic_name: topic.to_string(),
            packet_id: None,
            payload: payload.into(),
            qos: QoS::AtMostOnce,
            retain: false,
            dup: false,
            properties: Properties::default(),
            protocol_version: 5,
            stream_id: None,
        }
    }

    /// Returns a callback that adds `step` to `counter` (relaxed ordering) per invocation.
    fn counting_callback(counter: &Arc<AtomicU32>, step: u32) -> PublishCallback {
        let counter = Arc::clone(counter);
        Arc::new(move |_msg: PublishPacket| {
            counter.fetch_add(step, Ordering::Relaxed);
        })
    }

    #[tokio::test]
    async fn test_exact_match_callback() {
        let manager = CallbackManager::new();
        let counter = Arc::new(AtomicU32::new(0));
        manager
            .register("test/topic", counting_callback(&counter, 1))
            .unwrap();

        // Only the identical topic should reach the callback.
        for (topic, expected) in [("test/topic", 1u32), ("test/other", 1)] {
            manager.dispatch(&publish(topic, vec![1, 2, 3])).unwrap();
            tokio::task::yield_now().await;
            assert_eq!(counter.load(Ordering::Relaxed), expected);
        }
    }

    #[tokio::test]
    async fn test_wildcard_callback() {
        let manager = CallbackManager::new();
        let counter = Arc::new(AtomicU32::new(0));
        manager
            .register("test/+/topic", counting_callback(&counter, 1))
            .unwrap();

        // `+` matches exactly one level, so "test/topic" must not match.
        for (topic, expected) in [
            ("test/foo/topic", 1u32),
            ("test/bar/topic", 2),
            ("test/topic", 2),
        ] {
            manager.dispatch(&publish(topic, Vec::new())).unwrap();
            tokio::task::yield_now().await;
            assert_eq!(counter.load(Ordering::Relaxed), expected);
        }
    }

    #[tokio::test]
    async fn test_multiple_callbacks() {
        let manager = CallbackManager::new();
        let first = Arc::new(AtomicU32::new(0));
        let second = Arc::new(AtomicU32::new(0));
        manager
            .register("test/topic", counting_callback(&first, 1))
            .unwrap();
        manager
            .register("test/topic", counting_callback(&second, 2))
            .unwrap();

        manager.dispatch(&publish("test/topic", Vec::new())).unwrap();
        tokio::task::yield_now().await;

        assert_eq!(first.load(Ordering::Relaxed), 1);
        assert_eq!(second.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_unregister() {
        let manager = CallbackManager::new();
        let counter = Arc::new(AtomicU32::new(0));
        manager
            .register("test/topic", counting_callback(&counter, 1))
            .unwrap();
        let message = publish("test/topic", Vec::new());

        manager.dispatch(&message).unwrap();
        tokio::task::yield_now().await;
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        // After unregistering, the same message must no longer be delivered.
        let _ = manager.unregister("test/topic");
        manager.dispatch(&message).unwrap();
        tokio::task::yield_now().await;
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_callback_count() {
        let manager = CallbackManager::new();
        assert_eq!(manager.callback_count(), 0);

        let noop: PublishCallback = Arc::new(|_msg: PublishPacket| {});
        // Duplicate filters still count as separate registrations.
        for (filter, expected) in [
            ("test/exact", 1usize),
            ("test/+/wildcard", 2),
            ("test/exact", 3),
        ] {
            manager.register(filter, Arc::clone(&noop)).unwrap();
            assert_eq!(manager.callback_count(), expected);
        }

        manager.clear();
        assert_eq!(manager.callback_count(), 0);
    }

    #[tokio::test]
    async fn test_shared_subscription_callback() {
        let manager = CallbackManager::new();
        let counter = Arc::new(AtomicU32::new(0));
        manager
            .register("$share/workers/tasks/#", counting_callback(&counter, 1))
            .unwrap();

        // The `$share/workers/` prefix is ignored for matching.
        for (topic, expected) in [
            ("tasks/job1", 1u32),
            ("tasks/job2", 2),
            ("other/topic", 2),
        ] {
            manager.dispatch(&publish(topic, vec![1, 2, 3])).unwrap();
            tokio::task::yield_now().await;
            assert_eq!(counter.load(Ordering::Relaxed), expected);
        }
    }

    #[tokio::test]
    async fn test_dispatch_does_not_block_on_slow_callback() {
        use std::time::{Duration, Instant};

        let manager = CallbackManager::new();
        let started = Arc::new(AtomicU32::new(0));
        let started_in_cb = Arc::clone(&started);
        // Deliberately blocks its thread to show dispatch does not run callbacks inline.
        let slow: PublishCallback = Arc::new(move |_msg: PublishPacket| {
            started_in_cb.fetch_add(1, Ordering::SeqCst);
            std::thread::sleep(Duration::from_millis(100));
        });
        manager.register("test/topic", slow).unwrap();
        let message = publish("test/topic", Vec::new());

        let begin = Instant::now();
        manager.dispatch(&message).unwrap();
        let dispatch_time = begin.elapsed();
        assert!(
            dispatch_time < Duration::from_millis(50),
            "dispatch should return immediately, took {dispatch_time:?}"
        );

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert_eq!(started.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_dispatch_handles_multiple_concurrent_callbacks() {
        let manager = CallbackManager::new();
        let counter = Arc::new(AtomicU32::new(0));

        for i in 0..5 {
            let hits = Arc::clone(&counter);
            let exact_cb: PublishCallback = Arc::new(move |_msg: PublishPacket| {
                hits.fetch_add(1, Ordering::SeqCst);
            });
            manager.register(&format!("test/topic{i}"), exact_cb).unwrap();
        }

        let hits = Arc::clone(&counter);
        let wildcard_cb: PublishCallback = Arc::new(move |_msg: PublishPacket| {
            hits.fetch_add(10, Ordering::SeqCst);
        });
        manager.register("test/#", wildcard_cb).unwrap();

        manager.dispatch(&publish("test/topic0", Vec::new())).unwrap();
        tokio::task::yield_now().await;
        // One exact hit (+1) plus the wildcard hit (+10).
        assert_eq!(counter.load(Ordering::SeqCst), 11);
    }
}