use crate::NorthwardResult;
use std::{future::Future, pin::Pin, sync::Arc};
use tokio::sync::RwLock;
/// Result type returned by every router callback.
pub type HandlerResult = NorthwardResult<()>;
/// Async callback that `MessageRouter::route_message` invokes with the message
/// topic and its raw payload bytes.
///
/// The returned future is `'static`, so it cannot borrow `topic` or `payload`;
/// anything it needs from them must be copied before the future is returned.
pub type MessageHandler = Arc<
    dyn Fn(&str, &[u8]) -> Pin<Box<dyn Future<Output = HandlerResult> + Send + 'static>>
        + Send
        + Sync,
>;
/// Async callback that `MessageRouter::handle_connected` invokes (no arguments).
pub type ConnectionHandler =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = HandlerResult> + Send + 'static>> + Send + Sync>;
/// Async callback that `MessageRouter::handle_disconnected` invokes with the
/// disconnect reason passed to it, if any.
pub type DisconnectionHandler = Arc<
    dyn Fn(Option<String>) -> Pin<Box<dyn Future<Output = HandlerResult> + Send + 'static>>
        + Send
        + Sync,
>;
/// A topic filter paired with a dispatch priority.
///
/// `MessageRouter` tries routes in descending `priority` order and dispatches
/// to the first whose `pattern` matches. New routes start at priority `0`.
#[derive(Debug, Clone)]
pub struct Route {
    /// MQTT topic filter; `+` and `#` levels act as wildcards.
    pub pattern: String,
    /// Dispatch priority; larger values are tried first.
    pub priority: u8,
}

impl Route {
    /// Builds a route for `pattern` with the lowest priority (`0`).
    pub fn new(pattern: impl Into<String>) -> Self {
        let pattern = pattern.into();
        Route { pattern, priority: 0 }
    }

    /// Consumes the route and returns it with `priority` replaced.
    pub fn with_priority(self, priority: u8) -> Self {
        Route { priority, ..self }
    }
}
/// Dispatches incoming messages to async handlers by MQTT topic filter and
/// runs optional callbacks on connect/disconnect.
///
/// Every piece of state sits behind an `Arc<tokio::sync::RwLock<_>>`, which is
/// why the registration methods only need `&self`.
pub struct MessageRouter {
    /// Registered `(route, handler)` pairs, re-sorted by descending
    /// `Route::priority` on every registration.
    routes: Arc<RwLock<Vec<(Route, MessageHandler)>>>,
    /// Fallback used when no route matches a topic.
    default_handler: Arc<RwLock<Option<MessageHandler>>>,
    /// Invoked by `handle_connected`, if set.
    connection_handler: Arc<RwLock<Option<ConnectionHandler>>>,
    /// Invoked by `handle_disconnected`, if set.
    disconnection_handler: Arc<RwLock<Option<DisconnectionHandler>>>,
}
impl MessageRouter {
    /// Creates a router with no routes, no default handler and no lifecycle
    /// callbacks.
    pub fn new() -> Self {
        let routes = Arc::new(RwLock::new(Vec::new()));
        let default_handler = Arc::new(RwLock::new(None));
        let connection_handler = Arc::new(RwLock::new(None));
        let disconnection_handler = Arc::new(RwLock::new(None));
        Self {
            routes,
            default_handler,
            connection_handler,
            disconnection_handler,
        }
    }

    /// Registers `handler` for the topic filter `pattern` at priority `0`.
    pub async fn register(&self, pattern: impl Into<String>, handler: MessageHandler) {
        self.register_with_route(Route::new(pattern), handler).await
    }

    /// Registers `handler` under `route`, keeping the table ordered by
    /// descending priority.
    pub async fn register_with_route(&self, route: Route, handler: MessageHandler) {
        let mut table = self.routes.write().await;
        table.push((route, handler));
        // `sort_by` is stable, so routes sharing a priority keep their
        // registration order.
        table.sort_by(|(left, _), (right, _)| right.priority.cmp(&left.priority));
    }

    /// Installs (or replaces) the handler used when no route matches.
    pub async fn set_default_handler(&self, handler: MessageHandler) {
        *self.default_handler.write().await = Some(handler);
    }

    /// Installs (or replaces) the callback run by `handle_connected`.
    pub async fn set_connection_handler(&self, handler: ConnectionHandler) {
        *self.connection_handler.write().await = Some(handler);
    }

    /// Installs (or replaces) the callback run by `handle_disconnected`.
    pub async fn set_disconnection_handler(&self, handler: DisconnectionHandler) {
        *self.disconnection_handler.write().await = Some(handler);
    }

    /// Dispatches a message to the first matching route, in priority order.
    ///
    /// Falls back to the default handler when no route matches. When neither
    /// exists, a warning is logged and `Ok(())` is returned. The handler's own
    /// result is returned unchanged.
    pub async fn route_message(&self, topic: &str, payload: &[u8]) -> HandlerResult {
        // Clone the handler out inside a scope so the read guard is dropped
        // before the handler future is awaited.
        let specific: Option<MessageHandler> = {
            let table = self.routes.read().await;
            table.iter().find_map(|(route, handler)| {
                mqtt_pattern_matches(&route.pattern, topic).then(|| Arc::clone(handler))
            })
        };

        match specific {
            Some(handler) => handler(topic, payload).await,
            None => {
                // Separate `let` so the guard does not live across the await.
                let fallback = self.default_handler.read().await.clone();
                match fallback {
                    Some(handler) => {
                        tracing::debug!("No specific route found for topic '{topic}', using default handler");
                        handler(topic, payload).await
                    }
                    None => {
                        tracing::warn!("No handler found for topic: {topic}");
                        Ok(())
                    }
                }
            }
        }
    }

    /// Runs the connection callback, or logs at debug level and returns
    /// `Ok(())` if none is installed.
    pub async fn handle_connected(&self) -> HandlerResult {
        let registered = self.connection_handler.read().await.clone();
        match registered {
            Some(on_connect) => on_connect().await,
            None => {
                tracing::debug!("Connection established but no connection handler registered");
                Ok(())
            }
        }
    }

    /// Runs the disconnection callback with `reason`, or logs at debug level
    /// and returns `Ok(())` if none is installed.
    pub async fn handle_disconnected(&self, reason: Option<String>) -> HandlerResult {
        let registered = self.disconnection_handler.read().await.clone();
        match registered {
            Some(on_disconnect) => on_disconnect(reason).await,
            None => {
                tracing::debug!("Connection lost but no disconnection handler registered");
                Ok(())
            }
        }
    }

    /// Number of registered routes (the default handler is not counted).
    pub async fn route_count(&self) -> usize {
        let table = self.routes.read().await;
        table.len()
    }
}
impl Default for MessageRouter {
fn default() -> Self {
Self::new()
}
}
/// Returns `true` when `topic` matches the MQTT topic filter `pattern`.
///
/// Supported filter forms:
/// - an exact topic string;
/// - `#` alone, which matches every topic (including `""`);
/// - `prefix/#`, which matches `prefix` itself and anything below it;
/// - filters with `+` levels, each matching exactly one topic level, optionally
///   ending in `/#`.
///
/// Any other filter (e.g. `#` in a non-final level) never matches except by
/// exact string equality.
///
/// NOTE(review): topics beginning with `$` are not excluded from wildcard
/// matches here, unlike the MQTT specification's rule for `$`-topics.
pub fn mqtt_pattern_matches(pattern: &str, topic: &str) -> bool {
    if pattern == topic || pattern == "#" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix("/#") {
        if prefix.contains('+') {
            return matches_mixed_wildcards(pattern, topic);
        }
        // `prefix.len()` is a *byte* length, so inspect the byte-sliced
        // remainder rather than indexing by char position. The match must end
        // exactly at `prefix` or continue at a level separator.
        return matches!(
            topic.strip_prefix(prefix),
            Some(rest) if rest.is_empty() || rest.starts_with('/')
        );
    }
    if pattern.contains('+') {
        return matches_with_single_level_wildcards(pattern, topic);
    }
    false
}
/// Compares the first `count` levels of a split filter with a split topic.
///
/// A `+` filter level accepts any topic level; every other level must match
/// exactly. Returns `false` if the topic has fewer than `count` levels.
///
/// # Panics
///
/// Panics if `pattern_parts` has fewer than `count` entries while the topic
/// has at least `count` levels.
fn matches_pattern_parts(pattern_parts: &[&str], topic_parts: &[&str], count: usize) -> bool {
    topic_parts.len() >= count
        && pattern_parts[..count]
            .iter()
            .zip(topic_parts)
            .all(|(wanted, actual)| *wanted == "+" || wanted == actual)
}
/// Matches a filter that may contain `+` levels but no trailing `#`.
///
/// Filter and topic must have the same number of `/`-separated levels.
fn matches_with_single_level_wildcards(pattern: &str, topic: &str) -> bool {
    let filter_levels: Vec<&str> = pattern.split('/').collect();
    let topic_levels: Vec<&str> = topic.split('/').collect();
    filter_levels.len() == topic_levels.len()
        && matches_pattern_parts(&filter_levels, &topic_levels, filter_levels.len())
}
/// Matches a filter that contains `+` levels and ends with a `#` level.
///
/// The levels before `#` must match the topic's leading levels; the topic may
/// have any number of further levels (including none). Returns `false` if the
/// filter's last level is not `#`.
fn matches_mixed_wildcards(pattern: &str, topic: &str) -> bool {
    let filter_levels: Vec<&str> = pattern.split('/').collect();
    match filter_levels.split_last() {
        Some((&last, fixed_levels)) if last == "#" => {
            let topic_levels: Vec<&str> = topic.split('/').collect();
            matches_pattern_parts(fixed_levels, &topic_levels, fixed_levels.len())
        }
        _ => false,
    }
}
/// A pre-parsed, validated MQTT topic filter; build one with
/// `CompiledPattern::new` and test topics with `matches`.
///
/// NOTE(review): `MessageRouter::route_message` in this file matches with
/// `mqtt_pattern_matches` on the raw pattern string, not with this type.
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    // The filter's `/`-separated levels, in order.
    parts: Vec<PatternPart>,
    // Set by `new` when the (final) level is `#`.
    has_multi_level_wildcard: bool,
}
/// One `/`-separated level of a compiled topic filter.
#[derive(Debug, Clone, PartialEq)]
enum PatternPart {
    /// A level that must equal the topic level exactly.
    Literal(String),
    /// `+`: matches any single topic level.
    SingleWildcard,
    /// `#`: matches all remaining levels; `new` only accepts it as the last level.
    MultiWildcard,
}
impl CompiledPattern {
pub fn new(pattern: &str) -> Result<Self, PatternError> {
if pattern.is_empty() {
return Err(PatternError::EmptyPattern);
}
let mut parts = Vec::new();
let mut has_multi_level_wildcard = false;
for (index, part) in pattern.split('/').enumerate() {
match part {
"+" => parts.push(PatternPart::SingleWildcard),
"#" => {
if index != pattern.split('/').count() - 1 {
return Err(PatternError::MultiLevelWildcardNotAtEnd);
}
parts.push(PatternPart::MultiWildcard);
has_multi_level_wildcard = true;
}
literal => {
if literal.contains('+') || literal.contains('#') {
return Err(PatternError::InvalidWildcardUsage);
}
parts.push(PatternPart::Literal(literal.to_string()));
}
}
}
Ok(Self {
parts,
has_multi_level_wildcard,
})
}
pub fn matches(&self, topic: &str) -> bool {
let topic_parts: Vec<&str> = topic.split('/').collect();
if self.has_multi_level_wildcard {
return self.matches_with_multi_level(&topic_parts);
}
if self.parts.len() != topic_parts.len() {
return false;
}
for (pattern_part, topic_part) in self.parts.iter().zip(topic_parts.iter()) {
match pattern_part {
PatternPart::Literal(literal) => {
if literal != topic_part {
return false;
}
}
PatternPart::SingleWildcard => {
continue;
}
PatternPart::MultiWildcard => {
return true;
}
}
}
true
}
fn matches_with_multi_level(&self, topic_parts: &[&str]) -> bool {
let pattern_parts_without_wildcard = &self.parts[..self.parts.len() - 1];
if topic_parts.len() < pattern_parts_without_wildcard.len() {
return false;
}
for (pattern_part, topic_part) in pattern_parts_without_wildcard
.iter()
.zip(topic_parts.iter())
{
match pattern_part {
PatternPart::Literal(literal) => {
if literal != topic_part {
return false;
}
}
PatternPart::SingleWildcard => {
continue;
}
PatternPart::MultiWildcard => {
return true;
}
}
}
true
}
}
/// Reasons `CompiledPattern::new` rejects a topic filter.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum PatternError {
    /// The filter string was empty.
    #[error("Pattern cannot be empty")]
    EmptyPattern,
    /// A `#` level appeared somewhere other than the last level.
    #[error("Multi-level wildcard (#) must be at the end of the pattern")]
    MultiLevelWildcardNotAtEnd,
    /// A level contained `+` or `#` together with other characters.
    #[error("Wildcards (+, #) cannot be mixed with literal text in the same level")]
    InvalidWildcardUsage,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `mqtt_pattern_matches` against `(pattern, topic, expected)` rows
    /// and names the failing row in the panic message.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(pattern, topic, expected) in cases {
            assert_eq!(
                mqtt_pattern_matches(pattern, topic),
                expected,
                "pattern {:?} against topic {:?}",
                pattern,
                topic
            );
        }
    }

    #[test]
    fn test_exact_match() {
        assert_cases(&[
            ("sensor/temperature", "sensor/temperature", true),
            ("sensor/temperature", "sensor/humidity", false),
        ]);
    }

    #[test]
    fn test_single_level_wildcard() {
        let filter = "sensor/+/temperature";
        assert_cases(&[
            (filter, "sensor/device1/temperature", true),
            (filter, "sensor/device2/temperature", true),
            (filter, "sensor/device1/humidity", false),
            (filter, "sensor/device1/data/temperature", false),
        ]);
    }

    #[test]
    fn test_multi_level_wildcard() {
        let filter = "sensor/#";
        assert_cases(&[
            (filter, "sensor/device1/temperature", true),
            (filter, "sensor/device1/data/temperature", true),
            (filter, "sensor", true),
            (filter, "device/sensor", false),
        ]);
    }

    #[test]
    fn test_compiled_pattern() {
        let single = CompiledPattern::new("sensor/+/temperature").unwrap();
        assert!(single.matches("sensor/device1/temperature"));
        assert!(!single.matches("sensor/device1/humidity"));

        let multi = CompiledPattern::new("sensor/#").unwrap();
        for topic in ["sensor/device1/temperature", "sensor/device1/data/temperature"].iter() {
            assert!(multi.matches(topic));
        }
    }

    #[test]
    fn test_pattern_errors() {
        let cases = [
            ("", PatternError::EmptyPattern),
            ("sensor/#/temperature", PatternError::MultiLevelWildcardNotAtEnd),
            ("sensor/device+/temperature", PatternError::InvalidWildcardUsage),
        ];
        for (pattern, expected) in cases.iter() {
            assert_eq!(CompiledPattern::new(pattern).unwrap_err(), *expected);
        }
    }

    #[test]
    fn test_complex_patterns() {
        assert_cases(&[
            ("building/+/floor/+/temperature", "building/A/floor/1/temperature", true),
            ("building/+/floor/+/temperature", "building/A/floor/1/humidity", false),
            ("sensor/+/#", "sensor/device1/temperature/current", true),
            ("sensor/+/#", "device/sensor/temperature", false),
        ]);
    }

    #[test]
    fn test_edge_cases() {
        assert_cases(&[
            ("+", "test", true),
            ("#", "", true),
            ("#", "any/topic/here", true),
            ("+", "root", true),
            ("+", "root/sub", false),
        ]);
    }
}