use crate::error::{Error, Result};
use crate::server::Handler;
use crate::task::Task;
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
/// Pinned, heap-allocated, `Send` future: the erased return type of async handlers.
type BoxFuture<T> = std::pin::Pin<Box<dyn Future<Output = T> + Send>>;

/// Type-erased handler stored in a `ServeMux`, held behind an `Arc`.
enum HandlerWrapper {
    /// Plain function: `process_task` calls it and returns its result directly.
    Sync(Arc<dyn Fn(Task) -> Result<()> + Send + Sync>),
    /// Future-producing function: `process_task` awaits the future it returns.
    Async(Arc<dyn Fn(Task) -> BoxFuture<Result<()>> + Send + Sync>),
}
/// Reports whether `task_type` matches the routing `pattern`.
///
/// A pattern without `*` must equal `task_type` exactly. Otherwise each `*`
/// matches any (possibly empty) run of characters and everything else is
/// literal. The text before the first `*` must be a prefix of `task_type`, the
/// text after the last `*` must be a suffix, and every segment between stars
/// must appear, in order, between them. None of these pieces may overlap, so
/// `"ab*ba"` matches `"abba"` but not `"aba"`.
fn pattern_matches(pattern: &str, task_type: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == task_type;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    // `pattern` contains at least one `*`, so `split` yields at least two segments.
    let first = segments[0];
    let last = segments[segments.len() - 1];

    // Strip the prefix first and the suffix from what remains, so the two anchors
    // can never claim the same characters of `task_type`.
    let mut rest = match task_type
        .strip_prefix(first)
        .and_then(|after_prefix| after_prefix.strip_suffix(last))
    {
        Some(middle) => middle,
        None => return false,
    };

    // Interior segments must occur in order inside the unanchored middle only.
    // Taking the leftmost occurrence each time leaves the most room for later
    // segments, so this greedy scan never rejects a task type that matches.
    for &segment in &segments[1..segments.len() - 1] {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    true
}
/// Routes tasks to registered handlers based on the task's type.
///
/// Handlers are registered under patterns: either an exact task type or a
/// glob in which `*` matches any run of characters (see `pattern_matches`).
pub struct ServeMux {
    // Registered handlers, keyed by the pattern string they were registered under.
    handlers: HashMap<String, HandlerWrapper>,
}
impl ServeMux {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
}
}
pub fn handle_func<F>(&mut self, pattern: &str, func: F)
where
F: Fn(Task) -> Result<()> + Send + Sync + 'static,
{
self
.handlers
.insert(pattern.to_string(), HandlerWrapper::Sync(Arc::new(func)));
}
pub fn handle_async_func<F, Fut>(&mut self, pattern: &str, func: F)
where
F: Fn(Task) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<()>> + Send + 'static,
{
let func = Arc::new(func);
self.handlers.insert(
pattern.to_string(),
HandlerWrapper::Async(Arc::new(move |task: Task| {
let func = Arc::clone(&func);
Box::pin(async move { func(task).await })
})),
);
}
fn find_handler(&self, task_type: &str) -> Option<&HandlerWrapper> {
if let Some(handler) = self.handlers.get(task_type) {
return Some(handler);
}
for (pattern, handler) in &self.handlers {
if pattern_matches(pattern, task_type) {
return Some(handler);
}
}
None
}
}
impl Default for ServeMux {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Handler for ServeMux {
async fn process_task(&self, task: Task) -> Result<()> {
let task_type = task.get_type();
match self.find_handler(task_type) {
Some(HandlerWrapper::Sync(func)) => func(task),
Some(HandlerWrapper::Async(func)) => func(task).await,
None => Err(Error::other(format!(
"No handler registered for task type: {task_type}"
))),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::Task;

    #[tokio::test]
    async fn test_serve_mux_sync_handler() {
        let mut mux = ServeMux::new();
        mux.handle_func("test:task", |task: Task| {
            assert_eq!(task.get_type(), "test:task");
            Ok(())
        });

        let outcome = mux
            .process_task(Task::new("test:task", b"test payload").unwrap())
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_serve_mux_async_handler() {
        let mut mux = ServeMux::new();
        mux.handle_async_func("async:task", |task: Task| async move {
            assert_eq!(task.get_type(), "async:task");
            Ok(())
        });

        let outcome = mux
            .process_task(Task::new("async:task", b"test payload").unwrap())
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_serve_mux_no_handler() {
        let mux = ServeMux::new();

        let err = mux
            .process_task(Task::new("unknown:task", b"test payload").unwrap())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("No handler registered"));
    }

    #[tokio::test]
    async fn test_serve_mux_multiple_handlers() {
        let mut mux = ServeMux::new();
        mux.handle_func("email:send", |_task: Task| Ok(()));
        mux.handle_async_func("image:resize", |_task: Task| async move { Ok(()) });

        for task_type in ["email:send", "image:resize"] {
            let task = Task::new(task_type, b"test").unwrap();
            assert!(mux.process_task(task).await.is_ok());
        }
    }

    #[test]
    fn test_pattern_matches() {
        let cases: &[(&str, &str, bool)] = &[
            ("email:send", "email:send", true),
            ("email:send", "email:deliver", false),
            ("*", "any:task", true),
            ("*", "anything", true),
            ("email:*", "email:send", true),
            ("email:*", "email:deliver", true),
            ("email:*", "email:process:complex", true),
            ("email:*", "sms:send", false),
            ("*:send", "email:send", true),
            ("*:send", "sms:send", true),
            ("*:send", "email:deliver", false),
            ("email:*:done", "email:send:done", true),
            ("email:*:done", "email:process:task:done", true),
            ("email:*:done", "email:send:failed", false),
            ("email:*:done", "sms:send:done", false),
        ];
        for &(pattern, task_type, expected) in cases {
            assert_eq!(
                pattern_matches(pattern, task_type),
                expected,
                "pattern {pattern:?} against {task_type:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_serve_mux_wildcard_patterns() {
        let mut mux = ServeMux::new();
        mux.handle_func("email:*", |task: Task| {
            assert!(task.get_type().starts_with("email:"));
            Ok(())
        });
        mux.handle_async_func("*:send", |task: Task| async move {
            assert!(task.get_type().ends_with(":send"));
            Ok(())
        });

        for (task_type, should_succeed) in [
            ("email:send", true),
            ("email:deliver", true),
            ("sms:send", true),
            ("report:generate", false),
        ] {
            let task = Task::new(task_type, b"test").unwrap();
            assert_eq!(mux.process_task(task).await.is_ok(), should_succeed);
        }
    }

    #[tokio::test]
    async fn test_serve_mux_catch_all_pattern() {
        let mut mux = ServeMux::new();
        mux.handle_func("*", |_task: Task| Ok(()));

        for task_type in ["any:task:type", "another:completely:different:task"] {
            let task = Task::new(task_type, b"test").unwrap();
            assert!(mux.process_task(task).await.is_ok());
        }
    }
}