1#![allow(clippy::missing_errors_doc)]
2use crate::{Result, SerializedTask, Task};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7type TaskHandler = Arc<
9 dyn Fn(Vec<u8>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>
10 + Send
11 + Sync,
12>;
13
14pub struct TaskRegistry {
16 handlers: Arc<RwLock<HashMap<String, TaskHandler>>>,
17}
18
19impl TaskRegistry {
20 #[inline]
21 #[must_use]
22 pub fn new() -> Self {
23 Self {
24 handlers: Arc::new(RwLock::new(HashMap::new())),
25 }
26 }
27
28 pub async fn register<T>(&self, task: T)
30 where
31 T: Task + 'static,
32 {
33 let task_name = task.name().to_string();
34 let task = Arc::new(task);
35
36 let handler: TaskHandler = Arc::new(move |payload: Vec<u8>| {
37 let task = Arc::clone(&task);
38 Box::pin(async move {
39 let input: T::Input = serde_json::from_slice(&payload)
41 .map_err(|e| crate::CelersError::Deserialization(e.to_string()))?;
42
43 let output = task.execute(input).await?;
45
46 let output_bytes = serde_json::to_vec(&output)
48 .map_err(|e| crate::CelersError::Serialization(e.to_string()))?;
49
50 Ok(output_bytes)
51 })
52 });
53
54 self.handlers.write().await.insert(task_name, handler);
55 }
56
57 pub async fn execute(&self, task: &SerializedTask) -> Result<Vec<u8>> {
59 let handlers = self.handlers.read().await;
60
61 let handler = handlers.get(&task.metadata.name).ok_or_else(|| {
62 crate::CelersError::TaskExecution(format!(
63 "Task not found in registry: {}",
64 task.metadata.name
65 ))
66 })?;
67
68 handler(task.payload.clone()).await
69 }
70
71 pub async fn has_task(&self, name: &str) -> bool {
73 self.handlers.read().await.contains_key(name)
74 }
75
76 pub async fn list_tasks(&self) -> Vec<String> {
78 self.handlers.read().await.keys().cloned().collect()
79 }
80}
81
82impl Default for TaskRegistry {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Task;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct AddOutput {
        result: i32,
    }

    /// Minimal task that adds two integers; registered under the name `"add"`.
    struct AddTask;

    #[async_trait::async_trait]
    impl Task for AddTask {
        type Input = AddInput;
        type Output = AddOutput;

        async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
            Ok(AddOutput {
                result: input.a + input.b,
            })
        }

        fn name(&self) -> &'static str {
            "add"
        }
    }

    /// Builds a `SerializedTask` addressed to `name` carrying the given raw payload.
    fn serialized(name: &str, payload: Vec<u8>) -> SerializedTask {
        SerializedTask {
            metadata: crate::TaskMetadata::new(name.to_string()),
            payload,
        }
    }

    #[tokio::test]
    async fn test_registry() {
        let registry = TaskRegistry::new();

        registry.register(AddTask).await;

        assert!(registry.has_task("add").await);

        let task = serialized(
            "add",
            serde_json::to_vec(&AddInput { a: 2, b: 3 }).unwrap(),
        );

        let result = registry.execute(&task).await.unwrap();
        let output: AddOutput = serde_json::from_slice(&result).unwrap();

        assert_eq!(output, AddOutput { result: 5 });
    }

    #[tokio::test]
    async fn empty_registry_has_no_tasks() {
        let registry = TaskRegistry::default();

        assert!(!registry.has_task("add").await);
        assert!(registry.list_tasks().await.is_empty());
    }

    #[tokio::test]
    async fn list_tasks_reports_registered_names() {
        let registry = TaskRegistry::new();
        registry.register(AddTask).await;

        assert_eq!(registry.list_tasks().await, vec!["add".to_string()]);
    }

    #[tokio::test]
    async fn execute_unknown_task_is_an_error() {
        let registry = TaskRegistry::new();

        let result = registry.execute(&serialized("missing", b"{}".to_vec())).await;

        assert!(matches!(result, Err(crate::CelersError::TaskExecution(_))));
    }

    #[tokio::test]
    async fn execute_with_malformed_payload_is_a_deserialization_error() {
        let registry = TaskRegistry::new();
        registry.register(AddTask).await;

        let result = registry
            .execute(&serialized("add", b"not json".to_vec()))
            .await;

        assert!(matches!(result, Err(crate::CelersError::Deserialization(_))));
    }
}