use std::collections::HashMap;
use std::sync::Arc;
use super::command_message::CommandMessage;
/// Per-dispatch information handed to a [`RouteHandler`] alongside the message.
#[derive(Debug, Clone, Default)]
pub struct RouteContext {
    /// Pattern recorded for the route that matched, if the router supplied one.
    pub matched_pattern: Option<String>,
    /// Segments captured by wildcards, in the order they appear in the pattern.
    pub captures: Vec<String>,
    /// Free-form JSON payload; `Null` unless set with [`RouteContext::with_user_data`].
    pub user_data: serde_json::Value,
}

impl RouteContext {
    /// Returns an empty context: no pattern, no captures, `Null` user data.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns an otherwise-empty context that records `pattern` as the match.
    pub fn with_pattern(pattern: &str) -> Self {
        let mut ctx = Self::new();
        ctx.matched_pattern = Some(pattern.to_owned());
        ctx
    }

    /// Builder-style setter that replaces the captured segments.
    pub fn with_captures(self, captures: Vec<String>) -> Self {
        Self { captures, ..self }
    }

    /// Builder-style setter that replaces the user data payload.
    pub fn with_user_data(self, data: serde_json::Value) -> Self {
        Self {
            user_data: data,
            ..self
        }
    }
}
/// Pairs a route identifier with the [`RouteContext`] describing a match.
///
/// NOTE(review): nothing in this file constructs `RouteMatch`; confirm it is
/// used by callers elsewhere, otherwise it may be dead code.
#[derive(Debug, Clone)]
pub struct RouteMatch {
/// Identifier of the matched route — presumably the topic or pattern text;
/// confirm against the code that builds this value.
pub route: String,
/// Context produced for the match (pattern, captures, user data).
pub context: RouteContext,
}
/// Something that turns an incoming [`CommandMessage`] into its reply.
///
/// Any `Send + Sync` closure or function with the shape
/// `Fn(CommandMessage, &RouteContext) -> CommandMessage` implements this
/// trait automatically through the blanket impl below.
pub trait RouteHandler: Send + Sync {
    /// Handles `msg` using the routing information in `ctx` and returns the reply.
    fn handle(&self, msg: CommandMessage, ctx: &RouteContext) -> CommandMessage;
}

impl<F: Fn(CommandMessage, &RouteContext) -> CommandMessage + Send + Sync> RouteHandler for F {
    fn handle(&self, msg: CommandMessage, ctx: &RouteContext) -> CommandMessage {
        (*self)(msg, ctx)
    }
}
/// A registered route: the text it matches against and the handler it invokes.
struct Route {
    /// Handler to call when this route matches.
    handler: Arc<dyn RouteHandler>,
    /// Dot-separated pattern text; also reported back as the matched pattern.
    pattern: String,
    /// `true` for wildcard patterns; `false` compares `pattern` literally,
    /// with `"*"` acting as a catch-all.
    is_pattern: bool,
}
impl Route {
    /// Creates a route for `pattern` that dispatches to `handler`.
    ///
    /// When `is_pattern` is `true`, `pattern` is a dot-separated wildcard
    /// pattern (see [`Route::match_pattern`]). Otherwise it is compared
    /// literally against the topic, with `"*"` matching any topic.
    fn new<H: RouteHandler + 'static>(pattern: &str, is_pattern: bool, handler: H) -> Self {
        Self {
            pattern: pattern.to_string(),
            is_pattern,
            handler: Arc::new(handler),
        }
    }

    /// Tests `topic` against this route.
    ///
    /// On a match, returns a [`RouteContext`] carrying this route's pattern
    /// and, for wildcard patterns, the captured segments. Otherwise returns `None`.
    fn matches(&self, topic: &str) -> Option<RouteContext> {
        if self.is_pattern {
            self.match_pattern(topic)
        } else if self.pattern == topic || self.pattern == "*" {
            Some(RouteContext::with_pattern(&self.pattern))
        } else {
            None
        }
    }

    /// Matches `topic` against this route's dot-separated wildcard pattern.
    ///
    /// Segment rules:
    /// - a literal segment must equal the corresponding topic segment;
    /// - `*` matches exactly one segment and captures it;
    /// - `**` matches zero or more segments and captures them joined with `.`
    ///   (an empty string when it matches nothing). It may appear anywhere in
    ///   the pattern, and the segments that follow it must still match.
    ///
    /// The whole topic must be consumed. Captures are returned in pattern order.
    fn match_pattern(&self, topic: &str) -> Option<RouteContext> {
        let pattern_parts: Vec<&str> = self.pattern.split('.').collect();
        let topic_parts: Vec<&str> = topic.split('.').collect();
        let mut captures = Vec::new();
        if Self::match_segments(&pattern_parts, &topic_parts, &mut captures) {
            Some(RouteContext::with_pattern(&self.pattern).with_captures(captures))
        } else {
            None
        }
    }

    /// Recursive worker for [`Route::match_pattern`].
    ///
    /// Returns `true` if `pattern` consumes all of `topic`. On success, the
    /// wildcard captures have been appended to `captures`. On failure,
    /// `captures` is left exactly as it was on entry, so callers can backtrack.
    ///
    /// Backtracking over several interior `**` segments is polynomial in the
    /// topic length per `**`. Patterns are registered by the application and
    /// expected to be short.
    fn match_segments(pattern: &[&str], topic: &[&str], captures: &mut Vec<String>) -> bool {
        let mark = captures.len();
        let matched = match pattern.split_first() {
            // Pattern exhausted: succeed only if the topic is exhausted too.
            None => topic.is_empty(),
            // Trailing `**` swallows whatever remains, possibly nothing.
            Some((&"**", [])) => {
                captures.push(topic.join("."));
                true
            }
            // Interior `**`: try spans from longest to shortest so the capture
            // is greedy, like the trailing case. The following segments must
            // still match.
            Some((&"**", rest)) => (0..=topic.len()).rev().any(|take| {
                captures.truncate(mark);
                captures.push(topic[..take].join("."));
                Self::match_segments(rest, &topic[take..], captures)
            }),
            Some((&"*", rest)) => match topic.split_first() {
                Some((segment, topic_rest)) => {
                    captures.push(segment.to_string());
                    Self::match_segments(rest, topic_rest, captures)
                }
                None => false,
            },
            Some((literal, rest)) => match topic.split_first() {
                Some((segment, topic_rest)) if segment == literal => {
                    Self::match_segments(rest, topic_rest, captures)
                }
                _ => false,
            },
        };
        if !matched {
            // Undo any partial captures pushed by this attempt.
            captures.truncate(mark);
        }
        matched
    }
}
/// Dispatches [`CommandMessage`]s to handlers selected by a routing key.
///
/// Lookup order (see `route`): exact keys first, then wildcard patterns in
/// registration order, then the optional default handler.
pub struct TopicRouter {
/// Handlers keyed by exact routing key.
exact_routes: HashMap<String, Arc<dyn RouteHandler>>,
/// Wildcard routes, tried in registration order; the first match wins.
pattern_routes: Vec<Route>,
/// Fallback used when no exact or pattern route matches.
default_handler: Option<Arc<dyn RouteHandler>>,
/// `true`: route on the message subtopic, falling back to the full topic when
/// the subtopic is empty. `false`: always route on the full topic.
match_subtopic: bool,
}
impl TopicRouter {
    /// Creates an empty router that routes on the message subtopic.
    /// It falls back to the full topic when the subtopic is empty.
    pub fn new() -> Self {
        Self {
            exact_routes: HashMap::new(),
            pattern_routes: Vec::new(),
            default_handler: None,
            match_subtopic: true,
        }
    }

    /// Creates an empty router that always routes on the message's full topic.
    pub fn new_full_topic() -> Self {
        Self {
            match_subtopic: false,
            ..Self::new()
        }
    }

    /// Registers `handler` for the exact routing key `topic`.
    ///
    /// Registering the same key again replaces the previous handler.
    pub fn register<H: RouteHandler + 'static>(&mut self, topic: &str, handler: H) {
        self.exact_routes.insert(topic.to_string(), Arc::new(handler));
    }

    /// Registers `handler` for a dot-separated wildcard pattern.
    ///
    /// `*` matches a single segment and `**` matches the remaining segments.
    /// Pattern routes are consulted after exact routes, in registration order,
    /// and the first matching pattern wins.
    pub fn register_pattern<H: RouteHandler + 'static>(&mut self, pattern: &str, handler: H) {
        self.pattern_routes.push(Route::new(pattern, true, handler));
    }

    /// Sets the fallback handler used when nothing else matches.
    /// Any previous default handler is replaced.
    pub fn set_default<H: RouteHandler + 'static>(&mut self, handler: H) {
        self.default_handler = Some(Arc::new(handler));
    }

    /// Computes the key `msg` is routed on.
    ///
    /// In subtopic mode this is `msg.subtopic()`, unless that is empty. In
    /// that case, and in full-topic mode, it is `msg.topic`.
    fn get_routing_key(&self, msg: &CommandMessage) -> String {
        if self.match_subtopic {
            let subtopic = msg.subtopic();
            if !subtopic.is_empty() {
                return subtopic;
            }
        }
        msg.topic.clone()
    }

    /// Dispatches `msg` to the handler selected by its routing key.
    ///
    /// Lookup order is exact routes, then pattern routes (first registered
    /// match wins), then the default handler. If none applies, the message is
    /// converted into an error response naming the routing key.
    ///
    /// Handler contexts depend on the route kind:
    /// - exact-route handlers get a context whose `matched_pattern` is the key;
    /// - pattern handlers get the pattern plus its wildcard captures;
    /// - the default handler gets an empty context.
    pub fn route(&self, msg: CommandMessage) -> CommandMessage {
        let routing_key = self.get_routing_key(&msg);

        if let Some(handler) = self.exact_routes.get(&routing_key) {
            // Record which key matched, so exact handlers can see which route
            // fired just as pattern handlers can.
            let ctx = RouteContext::with_pattern(&routing_key);
            return handler.handle(msg, &ctx);
        }

        let pattern_hit = self
            .pattern_routes
            .iter()
            .find_map(|route| route.matches(&routing_key).map(|ctx| (route, ctx)));
        if let Some((route, ctx)) = pattern_hit {
            return route.handler.handle(msg, &ctx);
        }

        if let Some(handler) = &self.default_handler {
            return handler.handle(msg, &RouteContext::new());
        }

        let error_msg = format!("No handler for topic: {}", routing_key);
        msg.into_error_response(&error_msg)
    }

    /// Reports whether `route` would find a handler for the routing key `topic`.
    ///
    /// `topic` is treated as an already-computed routing key; it is not run
    /// through subtopic extraction. This always returns `true` once a default
    /// handler is set.
    pub fn has_handler(&self, topic: &str) -> bool {
        self.exact_routes.contains_key(topic)
            || self
                .pattern_routes
                .iter()
                .any(|route| route.matches(topic).is_some())
            || self.default_handler.is_some()
    }

    /// Returns the exact routing keys that have handlers, sorted ascending.
    pub fn registered_topics(&self) -> Vec<String> {
        let mut topics: Vec<String> = self.exact_routes.keys().cloned().collect();
        // HashMap iteration order is unspecified; sort so output is stable.
        topics.sort_unstable();
        topics
    }

    /// Returns the registered wildcard patterns in registration order.
    /// This is also the order in which they are tried.
    pub fn registered_patterns(&self) -> Vec<String> {
        self.pattern_routes.iter().map(|r| r.pattern.clone()).collect()
    }
}
impl Default for TopicRouter {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`TopicRouter`].
///
/// Start with [`TopicRouterBuilder::new`] (subtopic routing) or
/// [`TopicRouterBuilder::new_full_topic`] (full-topic routing). Chain
/// `route` / `pattern` / `default`, then call `build`.
#[must_use = "a builder does nothing until `build` is called"]
pub struct TopicRouterBuilder {
    router: TopicRouter,
}

impl TopicRouterBuilder {
    /// Starts a builder for a router that routes on message subtopics
    /// (see [`TopicRouter::new`]).
    pub fn new() -> Self {
        Self {
            router: TopicRouter::new(),
        }
    }

    /// Starts a builder for a router that always routes on the full topic
    /// (see [`TopicRouter::new_full_topic`]).
    pub fn new_full_topic() -> Self {
        Self {
            router: TopicRouter::new_full_topic(),
        }
    }

    /// Adds an exact-key route; see [`TopicRouter::register`].
    pub fn route<H: RouteHandler + 'static>(mut self, topic: &str, handler: H) -> Self {
        self.router.register(topic, handler);
        self
    }

    /// Adds a wildcard route; see [`TopicRouter::register_pattern`].
    pub fn pattern<H: RouteHandler + 'static>(mut self, pattern: &str, handler: H) -> Self {
        self.router.register_pattern(pattern, handler);
        self
    }

    /// Sets the fallback handler; see [`TopicRouter::set_default`].
    ///
    /// Because this inherent method is named `default`, the path
    /// `TopicRouterBuilder::default()` resolves to it rather than to
    /// [`Default::default`]. Use [`TopicRouterBuilder::new`] (or
    /// `<TopicRouterBuilder as Default>::default()`) to start a builder.
    pub fn default<H: RouteHandler + 'static>(mut self, handler: H) -> Self {
        self.router.set_default(handler);
        self
    }

    /// Finishes building and returns the configured router.
    pub fn build(self) -> TopicRouter {
        self.router
    }
}

impl Default for TopicRouterBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `TopicRouter` and `TopicRouterBuilder`, driven through real
// `CommandMessage` values from the sibling `command_message` module.
use super::*;
// Builds a read message whose topic is `TEST.<subtopic>`. The default
// (subtopic-routing) router is expected to route it on `<subtopic>` —
// presumably `CommandMessage::subtopic()` strips the leading `TEST.`; confirm.
fn make_msg(subtopic: &str) -> CommandMessage {
CommandMessage::read(&format!("TEST.{}", subtopic))
}
// An exact registration handles a message whose routing key equals it.
#[test]
fn test_exact_match() {
let mut router = TopicRouter::new();
router.register("test_topic", |msg: CommandMessage, _ctx: &RouteContext| {
msg.into_response(serde_json::json!({"matched": "exact"}))
});
let msg = make_msg("test_topic");
let result = router.route(msg);
assert!(result.success);
assert_eq!(result.data["matched"], "exact");
}
// `*` captures exactly one segment; a deeper topic does not match and, with no
// default handler, yields an error response.
#[test]
fn test_pattern_match_single_wildcard() {
let mut router = TopicRouter::new();
router.register_pattern("status.*", |msg: CommandMessage, ctx: &RouteContext| {
msg.into_response(serde_json::json!({
"matched": "pattern",
"captures": ctx.captures
}))
});
let msg = make_msg("status.ok");
let result = router.route(msg);
assert!(result.success);
assert_eq!(result.data["matched"], "pattern");
assert_eq!(result.data["captures"][0], "ok");
let msg2 = make_msg("status.sub.topic");
let result2 = router.route(msg2);
assert!(!result2.success);
}
// A trailing `**` captures all remaining segments joined with `.`.
#[test]
fn test_pattern_match_double_wildcard() {
let mut router = TopicRouter::new();
router.register_pattern("data.**", |msg: CommandMessage, ctx: &RouteContext| {
msg.into_response(serde_json::json!({
"matched": "double_wildcard",
"captures": ctx.captures
}))
});
let msg = make_msg("data.sensors.temperature.value");
let result = router.route(msg);
assert!(result.success);
assert_eq!(result.data["captures"][0], "sensors.temperature.value");
}
// With no exact or pattern match, the default handler is used.
#[test]
fn test_default_handler() {
let mut router = TopicRouter::new();
router.set_default(|msg: CommandMessage, _ctx: &RouteContext| {
msg.into_response(serde_json::json!({"matched": "default"}))
});
let msg = make_msg("unknown_topic");
let result = router.route(msg);
assert!(result.success);
assert_eq!(result.data["matched"], "default");
}
// With no handler at all, `route` returns an error response mentioning it.
#[test]
fn test_no_handler() {
let router = TopicRouter::new();
let msg = make_msg("unhandled");
let result = router.route(msg);
assert!(!result.success);
assert!(result.error_message.contains("No handler"));
}
// Exact routes win over patterns regardless of registration order; other
// keys still fall through to the pattern.
#[test]
fn test_exact_takes_priority() {
let mut router = TopicRouter::new();
router.register_pattern("test.*", |msg: CommandMessage, _ctx: &RouteContext| {
msg.into_response(serde_json::json!({"matched": "pattern"}))
});
router.register("test.specific", |msg: CommandMessage, _ctx: &RouteContext| {
msg.into_response(serde_json::json!({"matched": "exact"}))
});
let msg = make_msg("test.specific");
let result = router.route(msg);
assert_eq!(result.data["matched"], "exact");
let msg2 = make_msg("test.other");
let result2 = router.route(msg2);
assert_eq!(result2.data["matched"], "pattern");
}
// The builder wires exact, pattern and default routes. Once a default handler
// is set, `has_handler` is true for any key.
#[test]
fn test_builder() {
let router = TopicRouterBuilder::new()
.route("exact", |msg: CommandMessage, _: &RouteContext| {
msg.into_response(serde_json::Value::Null)
})
.pattern("pattern.*", |msg: CommandMessage, _: &RouteContext| {
msg.into_response(serde_json::Value::Null)
})
.default(|msg: CommandMessage, _: &RouteContext| {
msg.into_error_response("default handler")
})
.build();
assert!(router.has_handler("exact"));
assert!(router.has_handler("pattern.test"));
assert!(router.has_handler("anything_else"));
}
// A full-topic router routes on the whole topic string, not the subtopic.
#[test]
fn test_full_topic_router() {
let mut router = TopicRouter::new_full_topic();
router.register("domain.subtopic", |msg: CommandMessage, _ctx: &RouteContext| {
msg.into_response(serde_json::json!({"matched": "full"}))
});
let msg = CommandMessage::read("domain.subtopic");
let result = router.route(msg);
assert!(result.success);
assert_eq!(result.data["matched"], "full");
}
}