use async_trait::async_trait;
use crate::{Error, Result, Task};
use std::collections::HashMap;
use std::sync::Arc;
pub mod context;
pub use context::HandlerContext;
#[async_trait]
pub trait Handler: Send + Sync {
async fn handle(&self, task: &Task) -> Result<()> {
Err(Error::Handler("Handler::handle not implemented".into()))
}
async fn handle_with_context(&self, task: &Task, _ctx: &HandlerContext) -> Result<()> {
self.handle(task).await
}
}
/// Routes tasks to handlers keyed by the task's `task_type` string, with an
/// optional fallback handler for types that have no dedicated entry.
pub struct Mux {
/// Handlers registered for a specific task type, keyed by that type.
handlers: HashMap<String, Arc<dyn Handler>>,
/// Fallback used when `handlers` has no entry for a task's type.
default_handler: Option<Arc<dyn Handler>>,
}
impl Mux {
#[must_use]
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
default_handler: None,
}
}
pub fn handle<H: Handler + 'static>(&mut self, task_type: &str, handler: H) {
self.handlers.insert(task_type.to_string(), Arc::new(handler));
}
pub fn default_handler<H: Handler + 'static>(&mut self, handler: H) {
self.default_handler = Some(Arc::new(handler));
}
pub async fn process(&self, task: &Task) -> Result<()> {
let handler = self
.handlers
.get(&task.task_type)
.or(self.default_handler.as_ref())
.ok_or_else(|| {
Error::Handler(format!("No handler found for task_type: {}", task.task_type))
})?;
handler.handle(task).await
}
pub async fn process_with_context(&self, task: &Task, ctx: &HandlerContext) -> Result<()> {
let handler = self
.handlers
.get(&task.task_type)
.or(self.default_handler.as_ref())
.ok_or_else(|| {
Error::Handler(format!("No handler found for task_type: {}", task.task_type))
})?;
handler.handle_with_context(task, ctx).await
}
pub fn has_handler(&self, task_type: &str) -> bool {
self.handlers.contains_key(task_type) || self.default_handler.is_some()
}
pub fn handler_count(&self) -> usize {
self.handlers.len()
}
pub fn registered_types(&self) -> Vec<String> {
self.handlers.keys().cloned().collect()
}
pub fn remove(&mut self, task_type: &str) -> bool {
self.handlers.remove(task_type).is_some()
}
pub fn clear(&mut self) {
self.handlers.clear();
self.default_handler = None;
}
}
impl Default for Mux {
fn default() -> Self {
Self::new()
}
}
/// A [`Handler`] that logs each task it receives and reports success.
pub struct LogHandler {
/// Text placed at the start of every info-level log line.
pub prefix: String,
}
impl Default for LogHandler {
    /// Builds a `LogHandler` whose prefix is `"Task"`.
    fn default() -> Self {
        let prefix = String::from("Task");
        Self { prefix }
    }
}
#[async_trait]
impl Handler for LogHandler {
async fn handle(&self, task: &Task) -> Result<()> {
tracing::info!("{}: type={}, queue={}, id={}",
self.prefix, task.task_type, task.queue, task.id);
tracing::debug!("Task payload size: {} bytes", task.payload.len());
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler whose outcome is fixed at construction time.
    struct TestHandler {
        should_fail: bool,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, task: &Task) -> Result<()> {
            if self.should_fail {
                Err(Error::Handler("Test handler failed".to_string()))
            } else {
                tracing::info!("Handled task: {}", task.task_type);
                Ok(())
            }
        }
    }

    /// Builds a minimal task of the given type; all other fields carry
    /// fixed placeholder values.
    fn make_task(task_type: &str) -> Task {
        Task {
            id: "test-id".to_string(),
            task_type: task_type.to_string(),
            queue: "default".to_string(),
            payload: vec![1, 2, 3],
            options: Default::default(),
            status: Default::default(),
            created_at: 0,
            enqueued_at: None,
            processed_at: None,
            retry_cnt: 0,
            last_error: None,
        }
    }

    #[tokio::test]
    async fn test_mux_registration() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: false });
        assert!(mux.has_handler("test:type"));
        assert!(!mux.has_handler("unknown:type"));
        assert_eq!(mux.handler_count(), 1);
    }

    #[tokio::test]
    async fn test_mux_routing() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: false });
        let task = make_task("test:type");
        let result = mux.process(&task).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_mux_no_handler() {
        let mux = Mux::new();
        let task = make_task("unknown:type");
        let result = mux.process(&task).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_mux_default_handler() {
        let mut mux = Mux::new();
        mux.default_handler(TestHandler { should_fail: false });
        let task = make_task("unknown:type");
        let result = mux.process(&task).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_mux_handler_error_propagates() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: true });
        let task = make_task("test:type");
        let result = mux.process(&task).await;
        assert!(matches!(result, Err(Error::Handler(_))));
    }

    #[tokio::test]
    async fn test_mux_specific_handler_beats_default() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: true });
        mux.default_handler(TestHandler { should_fail: false });

        // The failing type-specific handler must be chosen over the default.
        assert!(mux.process(&make_task("test:type")).await.is_err());
        // Any other type falls through to the succeeding default.
        assert!(mux.process(&make_task("other:type")).await.is_ok());
    }

    #[test]
    fn test_mux_has_handler_with_default() {
        let mut mux = Mux::new();
        mux.default_handler(TestHandler { should_fail: false });
        assert!(mux.has_handler("anything"));
        // The default handler is not counted as a registered type.
        assert_eq!(mux.handler_count(), 0);
        assert!(mux.registered_types().is_empty());
    }

    #[test]
    fn test_mux_registered_types() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: false });
        assert_eq!(mux.registered_types(), vec!["test:type".to_string()]);
    }

    #[test]
    fn test_mux_remove_and_clear() {
        let mut mux = Mux::new();
        mux.handle("test:type", TestHandler { should_fail: false });
        mux.default_handler(TestHandler { should_fail: false });

        assert!(mux.remove("test:type"));
        assert!(!mux.remove("test:type"));
        assert_eq!(mux.handler_count(), 0);
        // The default handler survives `remove`.
        assert!(mux.has_handler("test:type"));

        mux.clear();
        assert!(!mux.has_handler("test:type"));
    }

    #[tokio::test]
    async fn test_log_handler() {
        let handler = LogHandler {
            prefix: "Test".to_string(),
        };
        let task = make_task("test:type");
        let result = handler.handle(&task).await;
        assert!(result.is_ok());
    }
}