1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::TaskContext;
7use crate::error::{KojinError, TaskResult};
8use crate::task::Task;
9
10pub type TaskHandler = Arc<
12 dyn Fn(
13 serde_json::Value,
14 Arc<TaskContext>,
15 ) -> Pin<Box<dyn Future<Output = TaskResult<serde_json::Value>> + Send>>
16 + Send
17 + Sync,
18>;
19
20#[derive(Clone)]
22pub struct TaskRegistry {
23 handlers: HashMap<String, TaskHandler>,
24}
25
26impl TaskRegistry {
27 pub fn new() -> Self {
28 Self {
29 handlers: HashMap::new(),
30 }
31 }
32
33 pub fn register<T: Task>(&mut self) {
35 let handler: TaskHandler = Arc::new(|payload, ctx| {
36 Box::pin(async move {
37 let task: T = serde_json::from_value(payload)?;
38 let result = task.run(&ctx).await?;
39 Ok(serde_json::to_value(result)?)
40 })
41 });
42 self.handlers.insert(T::NAME.to_string(), handler);
43 }
44
45 pub fn get(&self, name: &str) -> Option<&TaskHandler> {
47 self.handlers.get(name)
48 }
49
50 pub fn contains(&self, name: &str) -> bool {
52 self.handlers.contains_key(name)
53 }
54
55 pub async fn dispatch(
57 &self,
58 name: &str,
59 payload: serde_json::Value,
60 ctx: Arc<TaskContext>,
61 ) -> TaskResult<serde_json::Value> {
62 let handler = self
63 .get(name)
64 .ok_or_else(|| KojinError::TaskNotFound(name.to_string()))?;
65 handler(payload, ctx).await
66 }
67}
68
69impl Default for TaskRegistry {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Adds two integers; exercises payload deserialization and output serialization.
    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    #[tokio::test]
    async fn register_and_dispatch() {
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        assert!(registry.contains("add"));
        assert!(!registry.contains("unknown"));

        let ctx = Arc::new(TaskContext::new());
        let result = registry
            .dispatch("add", serde_json::json!({"a": 3, "b": 4}), ctx)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!(7));
    }

    #[tokio::test]
    async fn dispatch_not_found() {
        let registry = TaskRegistry::new();
        let ctx = Arc::new(TaskContext::new());
        let result = registry
            .dispatch("unknown", serde_json::Value::Null, ctx)
            .await;
        assert!(matches!(result, Err(KojinError::TaskNotFound(_))));
    }

    #[tokio::test]
    async fn dispatch_invalid_payload_is_error() {
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let ctx = Arc::new(TaskContext::new());
        // `a` has the wrong type and `b` is missing, so deserialization must
        // fail. The failure must surface as an error from the handler, not as
        // a panic and not as a not-found error.
        let result = registry
            .dispatch("add", serde_json::json!({"a": "three"}), ctx)
            .await;
        assert!(result.is_err());
        assert!(!matches!(result, Err(KojinError::TaskNotFound(_))));
    }

    #[test]
    fn default_is_empty_and_get_finds_registered() {
        let mut registry = TaskRegistry::default();
        assert!(registry.get("add").is_none());
        assert!(!registry.contains("add"));

        registry.register::<AddTask>();
        assert!(registry.get("add").is_some());
        assert!(registry.contains("add"));
    }

    #[tokio::test]
    async fn dispatch_with_context() {
        #[derive(Debug, Serialize, Deserialize)]
        struct CtxTask;

        #[async_trait]
        impl Task for CtxTask {
            const NAME: &'static str = "ctx_task";
            type Output = String;

            async fn run(&self, ctx: &TaskContext) -> TaskResult<Self::Output> {
                let prefix = ctx.data::<String>().cloned().unwrap_or_default();
                Ok(format!("{prefix}done"))
            }
        }

        let mut registry = TaskRegistry::new();
        registry.register::<CtxTask>();

        let mut ctx = TaskContext::new();
        ctx.insert("prefix:".to_string());
        let ctx = Arc::new(ctx);

        let result = registry
            .dispatch("ctx_task", serde_json::json!(null), ctx)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!("prefix:done"));
    }
}