Skip to main content

kojin_core/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::TaskContext;
7use crate::error::{KojinError, TaskResult};
8use crate::task::Task;
9
10/// Type-erased task handler function.
11pub type TaskHandler = Arc<
12    dyn Fn(
13            serde_json::Value,
14            Arc<TaskContext>,
15        ) -> Pin<Box<dyn Future<Output = TaskResult<serde_json::Value>> + Send>>
16        + Send
17        + Sync,
18>;
19
20/// Registry mapping task names to type-erased handlers.
21#[derive(Clone)]
22pub struct TaskRegistry {
23    handlers: HashMap<String, TaskHandler>,
24}
25
26impl TaskRegistry {
27    pub fn new() -> Self {
28        Self {
29            handlers: HashMap::new(),
30        }
31    }
32
33    /// Register a task type.
34    pub fn register<T: Task>(&mut self) {
35        let handler: TaskHandler = Arc::new(|payload, ctx| {
36            Box::pin(async move {
37                let task: T = serde_json::from_value(payload)?;
38                let result = task.run(&ctx).await?;
39                Ok(serde_json::to_value(result)?)
40            })
41        });
42        self.handlers.insert(T::NAME.to_string(), handler);
43    }
44
45    /// Look up a handler by task name.
46    pub fn get(&self, name: &str) -> Option<&TaskHandler> {
47        self.handlers.get(name)
48    }
49
50    /// Check if a task is registered.
51    pub fn contains(&self, name: &str) -> bool {
52        self.handlers.contains_key(name)
53    }
54
55    /// Execute a task by name with the given payload.
56    pub async fn dispatch(
57        &self,
58        name: &str,
59        payload: serde_json::Value,
60        ctx: Arc<TaskContext>,
61    ) -> TaskResult<serde_json::Value> {
62        let handler = self
63            .get(name)
64            .ok_or_else(|| KojinError::TaskNotFound(name.to_string()))?;
65        handler(payload, ctx).await
66    }
67}
68
69impl Default for TaskRegistry {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    /// Deliberately shares `AddTask`'s `NAME` to exercise re-registration under an existing name.
    #[derive(Debug, Serialize, Deserialize)]
    struct SubTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for SubTask {
        const NAME: &'static str = "add";
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> TaskResult<Self::Output> {
            Ok(self.a - self.b)
        }
    }

    #[tokio::test]
    async fn register_and_dispatch() {
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        assert!(registry.contains("add"));
        assert!(!registry.contains("unknown"));

        let ctx = Arc::new(TaskContext::new());
        let result = registry
            .dispatch("add", serde_json::json!({"a": 3, "b": 4}), ctx)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!(7));
    }

    #[tokio::test]
    async fn dispatch_not_found() {
        let registry = TaskRegistry::new();
        let ctx = Arc::new(TaskContext::new());
        let result = registry
            .dispatch("unknown", serde_json::Value::Null, ctx)
            .await;
        assert!(matches!(result, Err(KojinError::TaskNotFound(_))));
    }

    #[tokio::test]
    async fn dispatch_with_context() {
        #[derive(Debug, Serialize, Deserialize)]
        struct CtxTask;

        #[async_trait]
        impl Task for CtxTask {
            const NAME: &'static str = "ctx_task";
            type Output = String;

            async fn run(&self, ctx: &TaskContext) -> TaskResult<Self::Output> {
                let prefix = ctx.data::<String>().cloned().unwrap_or_default();
                Ok(format!("{prefix}done"))
            }
        }

        let mut registry = TaskRegistry::new();
        registry.register::<CtxTask>();

        let mut ctx = TaskContext::new();
        ctx.insert("prefix:".to_string());
        let ctx = Arc::new(ctx);

        let result = registry
            .dispatch("ctx_task", serde_json::json!(null), ctx)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!("prefix:done"));
    }

    #[tokio::test]
    async fn dispatch_rejects_malformed_payload() {
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let ctx = Arc::new(TaskContext::new());
        // `a` must be an i32, so a string value cannot deserialize into `AddTask`.
        let result = registry
            .dispatch("add", serde_json::json!({"a": "three", "b": 4}), ctx)
            .await;
        assert!(result.is_err());
        // The failure comes from the handler, not from the name lookup.
        assert!(!matches!(result, Err(KojinError::TaskNotFound(_))));
    }

    #[tokio::test]
    async fn reregistering_a_name_replaces_the_handler() {
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();
        registry.register::<SubTask>();

        let ctx = Arc::new(TaskContext::new());
        let result = registry
            .dispatch("add", serde_json::json!({"a": 3, "b": 4}), ctx)
            .await
            .unwrap();
        // The later registration (`SubTask`, 3 - 4) wins over `AddTask`.
        assert_eq!(result, serde_json::json!(-1));
    }
}